├── .gitignore
├── LICENSE
├── README.md
├── assets
│   ├── 1.png
│   ├── 2.png
│   ├── 3.png
│   ├── 4.png
│   ├── dataset_logo.png
│   ├── dataset_overview.png
│   ├── results.png
│   └── ui.png
├── configs
│   ├── index_pred_net.yml
│   ├── parsing_gen.yml
│   ├── parsing_token.yml
│   ├── sample_from_parsing.yml
│   ├── sample_from_pose.yml
│   ├── sampler.yml
│   ├── vqvae_bottom.yml
│   └── vqvae_top.yml
├── data
│   ├── __init__.py
│   ├── mask_dataset.py
│   ├── parsing_generation_segm_attr_dataset.py
│   ├── pose_attr_dataset.py
│   └── segm_attr_dataset.py
├── environment
│   └── text2human_env.yaml
├── models
│   ├── __init__.py
│   ├── archs
│   │   ├── __init__.py
│   │   ├── fcn_arch.py
│   │   ├── shape_attr_embedding_arch.py
│   │   ├── transformer_arch.py
│   │   ├── unet_arch.py
│   │   └── vqgan_arch.py
│   ├── hierarchy_inference_model.py
│   ├── hierarchy_vqgan_model.py
│   ├── losses
│   │   ├── __init__.py
│   │   ├── accuracy.py
│   │   ├── cross_entropy_loss.py
│   │   ├── segmentation_loss.py
│   │   └── vqgan_loss.py
│   ├── parsing_gen_model.py
│   ├── sample_model.py
│   ├── transformer_model.py
│   └── vqgan_model.py
├── sample_from_parsing.py
├── sample_from_pose.py
├── train_index_prediction.py
├── train_parsing_gen.py
├── train_parsing_token.py
├── train_sampler.py
├── train_vqvae.py
├── ui
│   ├── __init__.py
│   ├── color_blocks
│   │   ├── class_bag.png
│   │   ├── class_belt.png
│   │   ├── class_bg.png
│   │   ├── class_dress.png
│   │   ├── class_earstuds.png
│   │   ├── class_eyeglass.png
│   │   ├── class_face.png
│   │   ├── class_footwear.png
│   │   ├── class_glove.png
│   │   ├── class_hair.png
│   │   ├── class_headwear.png
│   │   ├── class_leggings.png
│   │   ├── class_necklace.png
│   │   ├── class_neckwear.png
│   │   ├── class_outer.png
│   │   ├── class_pants.png
│   │   ├── class_ring.png
│   │   ├── class_rompers.png
│   │   ├── class_skin.png
│   │   ├── class_skirt.png
│   │   ├── class_socks.png
│   │   ├── class_tie.png
│   │   ├── class_top.png
│   │   └── class_wrist.png
│   ├── icons
│   │   ├── icon_palette.png
│   │   └── icon_title.png
│   ├── mouse_event.py
│   └── ui.py
├── ui_demo.py
├── ui_util
│   ├── __init__.py
│   └── config.py
└── utils
    ├── __init__.py
    ├── language_utils.py
    ├── logger.py
    ├── options.py
    └── util.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | .cache/
3 | datasets/*
4 | experiments/*
5 | tb_logger/*
6 | results/*
7 | *.png
8 | *.txt
9 | *.pth
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | S-Lab License 1.0
2 |
3 | Copyright 2022 S-Lab
4 | Redistribution and use for non-commercial purpose in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
5 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
6 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
7 | 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
8 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
9 | 4. In the event that redistribution and/or use for commercial purpose in source or binary forms, with or without modification is required, please contact the contributor(s) of the work.
10 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Text2Human - Official PyTorch Implementation
2 |
3 |
4 |
5 | This repository provides the official PyTorch implementation for the following paper:
6 |
7 | **Text2Human: Text-Driven Controllable Human Image Generation**
8 | [Yuming Jiang](https://yumingj.github.io/), [Shuai Yang](https://williamyang1991.github.io/), [Haonan Qiu](http://haonanqiu.com/), [Wayne Wu](https://wywu.github.io/), [Chen Change Loy](https://www.mmlab-ntu.com/person/ccloy/) and [Ziwei Liu](https://liuziwei7.github.io/)
9 | In ACM Transactions on Graphics (Proceedings of SIGGRAPH), 2022.
10 |
11 | From [MMLab@NTU](https://www.mmlab-ntu.com/index.html) affiliated with S-Lab, Nanyang Technological University and SenseTime Research.
12 |
13 |
14 |
15 | Sample results (see `assets/1.png` to `assets/4.png`), generated from the following text descriptions:
16 |
17 | 1. The lady wears a short-sleeve T-shirt with pure color pattern, and a short and denim skirt.
18 | 2. The man wears a long and floral shirt, and long pants with the pure color pattern.
19 | 3. A lady is wearing a sleeveless pure-color shirt and long jeans
20 | 4. The man wears a short-sleeve T-shirt with the pure color pattern and a short pants with the pure color pattern.
25 |
26 |
27 |
28 | [**[Project Page]**](https://yumingj.github.io/projects/Text2Human.html) | [**[Paper]**](https://arxiv.org/pdf/2205.15996.pdf) | [**[Dataset]**](https://github.com/yumingj/DeepFashion-MultiModal) | [**[Demo Video]**](https://youtu.be/yKh4VORA_E0) | [**[Gradio Web Demo]**](https://huggingface.co/spaces/CVPR/Text2Human)
29 |
30 |
31 | ## Updates
32 |
33 | - [09/2022] :fire::fire::fire:**We have released a high-quality 3D human generative model [EVA3D](https://hongfz16.github.io/projects/EVA3D.html)!**:fire::fire::fire:
34 | - [07/2022] Release the model trained on [SHHQ dataset](https://stylegan-human.github.io/)!
35 | - [07/2022] Try out the web demo of [drawings-to-human](https://huggingface.co/spaces/CVPR/drawings-to-human)!
36 | - [06/2022] Integrated into [Huggingface Spaces 🤗](https://huggingface.co/spaces) using [Gradio](https://github.com/gradio-app/gradio). Try out the [Web Demo](https://huggingface.co/spaces/CVPR/Text2Human).
37 | - [05/2022] Paper and demo video are released.
38 | - [05/2022] Code is released.
39 | - [05/2022] This website is created.
40 |
41 |
42 | ## Installation
43 | **Clone this repo:**
44 | ```bash
45 | git clone https://github.com/yumingj/Text2Human.git
46 | cd Text2Human
47 | ```
48 | **Dependencies:**
49 |
50 | All dependencies for defining the environment are provided in `environment/text2human_env.yaml`.
51 | We recommend using [Anaconda](https://docs.anaconda.com/anaconda/install/) to manage the python environment:
52 | ```bash
53 | conda env create -f ./environment/text2human_env.yaml
54 | conda activate text2human
55 | pip install mmcv-full==1.2.1 -f https://download.openmmlab.com/mmcv/dist/cu101/torch1.7.0/index.html
56 | pip install mmsegmentation==0.9.0
57 | conda install -c huggingface tokenizers=0.9.4
58 | conda install -c huggingface transformers=4.0.0
59 | conda install -c conda-forge sentence-transformers=2.0.0
60 | ```
61 |
62 | If the environment file does not work for you, you may need to install the following packages on your own (a quick import check is sketched after the list):
63 | - Python 3.6
64 | - PyTorch 1.7.1
65 | - CUDA 10.1
66 | - [sentence-transformers](https://huggingface.co/sentence-transformers) 2.0.0
67 | - [tokenizers](https://pypi.org/project/tokenizers/) 0.9.4
68 | - [transformers](https://huggingface.co/docs/transformers/installation) 4.0.0
69 |
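After installation, a quick import check can catch version mismatches early. A minimal sketch, assuming the packages above installed successfully:

```python
# Sanity-check the core Text2Human dependencies.
import torch
import mmcv
import mmseg
import sentence_transformers

print('torch:', torch.__version__)            # expected: 1.7.1
print('cuda available:', torch.cuda.is_available())
print('mmcv-full:', mmcv.__version__)         # expected: 1.2.1
print('mmsegmentation:', mmseg.__version__)   # expected: 0.9.0
```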
70 | ## (1) Dataset Preparation
71 |
72 | In this work, we contribute a large-scale high-quality dataset with rich multi-modal annotations named [DeepFashion-MultiModal](https://github.com/yumingj/DeepFashion-MultiModal) Dataset.
73 | Here we pre-process the raw annotations of the original dataset for the task of text-driven controllable human image generation. The pre-processing pipeline consists of:
74 | - aligning the human body to the center of the image according to the human pose
75 | - fusing the clothing color and clothing fabric annotations into a single texture annotation
76 | - cleaning the annotations and filtering the images
77 | - splitting the whole dataset into training and testing sets
78 |
79 | You can download our processed dataset from this [Google Drive](https://drive.google.com/file/d/1KIoFfRZNQVn6RV_wTxG2wZmY8f2T_84B/view?usp=sharing). If you want to access the raw annotations, please refer to the [DeepFashion-MultiModal](https://github.com/yumingj/DeepFashion-MultiModal) Dataset.
80 |
81 | After downloading the dataset, unzip the file and place the contents under the `./datasets` folder with the following structure (a quick layout check is sketched after the tree):
82 | ```
83 | ./datasets
84 | ├── train_images
85 | ├── xxx.png
86 | ...
87 | ├── xxx.png
88 | └── xxx.png
89 | ├── test_images
90 | % the same structure as in train_images
91 | ├── densepose
92 | % the same structure as in train_images
93 | ├── segm
94 | % the same structure as in train_images
95 | ├── shape_ann
96 | ├── test_ann_file.txt
97 | ├── train_ann_file.txt
98 | └── val_ann_file.txt
99 | └── texture_ann
100 | ├── test
101 | ├── lower_fused.txt
102 | ├── outer_fused.txt
103 | └── upper_fused.txt
104 | ├── train
105 | % the same files as in test
106 | └── val
107 | % the same files as in test
108 | ```
109 |
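A quick way to confirm the layout is in place (a minimal sketch; adjust the paths if you unpacked the dataset elsewhere):

```python
# Check that the expected ./datasets sub-folders and annotation files exist.
import os

required = [
    './datasets/train_images',
    './datasets/test_images',
    './datasets/densepose',
    './datasets/segm',
    './datasets/shape_ann/train_ann_file.txt',
    './datasets/texture_ann/train/upper_fused.txt',
]
for path in required:
    status = 'ok' if os.path.exists(path) else 'MISSING'
    print(f'{status:8s} {path}')
```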
110 | ## (2) Sampling
111 |
112 | ### HuggingFace Demo
113 | [Full Web Demo](https://huggingface.co/spaces/CVPR/Text2Human)
114 |
115 | [Drawings-to-human](https://huggingface.co/spaces/CVPR/drawings-to-human)
116 |
117 |
118 | ### Colab
119 | [Unofficial Demo](https://colab.research.google.com/drive/1AVwbqLwMp_Gz3KTCgBTtnGVtXIlCZDPk#scrollTo=wMeHXDu11ebH) implemented by [@neverix](https://github.com/neverix).
120 |
121 |
122 | ### Pretrained Models
123 |
124 | Pretrained models can be downloaded from the model zoo below. Unzip the file and put the checkpoints under the `pretrained_models` folder with the following structure:
125 | ```
126 | pretrained_models
127 | ├── index_pred_net.pth
128 | ├── parsing_gen.pth
129 | ├── parsing_token.pth
130 | ├── sampler.pth
131 | ├── vqvae_bottom.pth
132 | └── vqvae_top.pth
133 | ```
134 |
135 | #### Model Zoo
136 | | Model | Dataset | Annotations |
137 | | :--- | :---- | :---- |
138 | | [Standard Model](https://drive.google.com/file/d/1VyI8_AbPwAUaZJPaPba8zxsFIWumlDen/view?usp=sharing) | [DeepFashion-Multimodal](https://github.com/yumingj/DeepFashion-MultiModal) | Follow the dataset preparation in Step(1) |
139 | | [Extended Model](https://drive.google.com/file/d/1hK1Yu2PA03UuDhewu_sC-WGiEO7h-j6G/view?usp=sharing) | [SHHQ](https://stylegan-human.github.io/) | Replace the annotations with the following ones: [densepose](https://drive.google.com/file/d/1nWRAdjoBqAjrFGaxtNClsvh6JHB8NOnn/view?usp=sharing), [segm](https://drive.google.com/file/d/1tz5i4zjPfYn1fWc5nGCflajQBrdruGj2/view?usp=sharing), [shape](https://drive.google.com/file/d/1Cqo62ffCKuiCiAPyIQ_HT7JqpLQQvCGu/view?usp=sharing), [texture](https://drive.google.com/file/d/1xyFyHGvlp-Qly7t8TT_IlSX7YJlikyy3/view?usp=sharing) |
140 |
141 | **Remark**: For fair research comparisons, we suggest using the standard model.
142 |
143 | ### Generation from Parsing Maps
144 | You can generate images from given parsing maps and pre-defined texture annotations:
145 | ```python
146 | python sample_from_parsing.py -opt ./configs/sample_from_parsing.yml
147 | ```
148 | The results are saved in the folder `./results/sampling_from_parsing`.
149 |
150 | ### Generation from Poses
151 | You can generate images from given human poses and pre-defined clothing shape and texture annotations:
152 | ```bash
153 | python sample_from_pose.py -opt ./configs/sample_from_pose.yml
154 | ```
155 |
156 | **Remarks**: The above two scripts generate images without language interaction. If you want to generate images from texts, you can use the Colab notebook or our user interface.
157 |
158 | ### User Interface
159 |
160 | ```bash
161 | python ui_demo.py
162 | ```
163 |
164 |
165 | The descriptions for shapes should follow this format:
166 | ```
167 | <gender>, <upper clothing>, <lower clothing>, <outer clothing type>, <other accessories>, ...
168 |
169 | Note: The outer clothing type and accessories can be omitted.
170 |
171 | Examples:
172 | man, sleeveless T-shirt, long pants
173 | woman, short-sleeve T-shirt, short jeans
174 | ```
175 |
176 | The descriptions for textures should follow this format:
177 | ```
178 | <upper clothing texture>, <lower clothing texture>, <outer clothing texture>
179 |
180 | Note: Currently, we only support 5 types of textures, i.e., pure color, stripe/spline, plaid/lattice,
181 | floral, denim. Your inputs should be restricted to these textures.
182 | ```
183 |
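For reference, a minimal sketch of how such comma-separated descriptions can be split into per-slot phrases before being mapped to attribute indices (the helper below is hypothetical; the actual text processing lives in `utils/language_utils.py`):

```python
# Hypothetical helper illustrating the comma-separated description format above.
def split_description(text: str) -> list:
    """Split a shape/texture description into its comma-separated phrases."""
    return [phrase.strip() for phrase in text.split(',') if phrase.strip()]

shape_phrases = split_description('man, sleeveless T-shirt, long pants')
print(shape_phrases)    # ['man', 'sleeveless T-shirt', 'long pants']

texture_phrases = split_description('pure color, denim')
print(texture_phrases)  # ['pure color', 'denim']
```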
184 | ## (3) Training Text2Human
185 |
186 | ### Stage I: Pose to Parsing
187 | Train the parsing generation network. If you want to skip the training of this network, you can download our pretrained model from [here](https://drive.google.com/file/d/1MNyFLGqIQcOMg_HhgwCmKqdwfQSjeg_6/view?usp=sharing).
188 | ```bash
189 | python train_parsing_gen.py -opt ./configs/parsing_gen.yml
190 | ```
191 |
192 | ### Stage II: Parsing to Human
193 |
194 | **Step 1: Train the top level of the hierarchical VQVAE.**
195 | We provide our pretrained model [here](https://drive.google.com/file/d/1TwypUg85gPFJtMwBLUjVS66FKR3oaTz8/view?usp=sharing). This model is trained by:
196 | ```bash
197 | python train_vqvae.py -opt ./configs/vqvae_top.yml
198 | ```
199 |
200 | **Step 2: Train the bottom level of the hierarchical VQVAE.**
201 | We provide our pretrained model [here](https://drive.google.com/file/d/15hzbY-RG-ILgzUqqGC0qMzlS4OayPdRH/view?usp=sharing). This model is trained by:
202 | ```bash
203 | python train_vqvae.py -opt ./configs/vqvae_bottom.yml
204 | ```
205 |
206 | **Stage 3 & 4: Train the sampler with mixture-of-experts.** To train the sampler, we first need to train a model to tokenize the parsing maps. You can access our pretrained parsing tokenization model [here](https://drive.google.com/file/d/1GLHoOeCP6sMao1-R63ahJMJF7-J00uir/view?usp=sharing).
207 | ```bash
208 | python train_parsing_token.py -opt ./configs/parsing_token.yml
209 | ```
210 |
211 | With the parsing tokenization model, the sampler is trained by:
212 | ```bash
213 | python train_sampler.py -opt ./configs/sampler.yml
214 | ```
215 | Our pretrained sampler is provided [here](https://drive.google.com/file/d/1OQO_kG2fK7eKiG1VJH1OL782X71UQAmS/view?usp=sharing).
216 |
217 | **Stage 5: Train the index prediction network.**
218 | We provide our pretrained index prediction network [here](https://drive.google.com/file/d/1rqhkQD-JGd7YBeIfDvMV-vjfbNHpIhYm/view?usp=sharing). It is trained by:
219 | ```bash
220 | python train_index_prediction.py -opt ./configs/index_pred_net.yml
221 | ```
222 |
223 |
224 | **Remarks**: In the config files, the pretrained-model paths point to our released checkpoints. If you want to train the models from scratch, please replace these paths with your own. We set the number of training epochs to a large value so that you can choose the best epoch for each model. For your reference, our pretrained parsing generation network is trained for 50 epochs, the top-level VQVAE for 135 epochs, the bottom-level VQVAE for 70 epochs, the parsing tokenization network for 20 epochs, the sampler for 95 epochs, and the index prediction network for 70 epochs.
225 |
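If you retrain any stage yourself, the downstream configs only need to point at your own checkpoints. A minimal sketch of the idea (the key name `top_vae_path` comes from `configs/vqvae_bottom.yml`; the checkpoint and output file names are hypothetical):

```python
# Point the bottom-level VQVAE config at a locally trained top-level checkpoint.
import yaml

with open('./configs/vqvae_bottom.yml', 'r') as f:
    opt = yaml.safe_load(f)

opt['top_vae_path'] = './experiments/vqvae_top/models/epoch_best.pth'  # hypothetical path

# Note: a YAML round-trip drops comments; editing the file by hand works just as well.
with open('./configs/vqvae_bottom_local.yml', 'w') as f:
    yaml.safe_dump(opt, f, sort_keys=False)
```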
226 | ## (4) Results
227 |
228 | Please visit our [Project Page](https://yumingj.github.io/projects/Text2Human.html#results) to view more results.
229 | You can select the attributes to customize the desired human images.
230 | [Results overview (`assets/results.png`)](https://yumingj.github.io/projects/Text2Human.html#results)
232 |
233 | ## DeepFashion-MultiModal Dataset
234 |
235 |
236 |
237 | In this work, we also propose **DeepFashion-MultiModal**, a large-scale high-quality human dataset with rich multi-modal annotations. It has the following properties:
238 | 1. It contains 44,096 high-resolution human images, including 12,701 full-body human images.
239 | 2. For each full-body image, we **manually annotate** the human parsing labels of 24 classes.
240 | 3. For each full-body image, we **manually annotate** the keypoints.
241 | 4. We extract DensePose for each human image.
242 | 5. Each image is **manually annotated** with attributes for both clothing shapes and textures.
243 | 6. We provide a textual description for each image.
244 |
245 |
246 |
247 | Please refer to [this repo](https://github.com/yumingj/DeepFashion-MultiModal) for more details about our proposed dataset.
248 |
249 | ## Citation
250 |
251 | If you find this work useful for your research, please consider citing our paper:
252 |
253 | ```bibtex
254 | @article{jiang2022text2human,
255 | title={Text2Human: Text-Driven Controllable Human Image Generation},
256 | author={Jiang, Yuming and Yang, Shuai and Qiu, Haonan and Wu, Wayne and Loy, Chen Change and Liu, Ziwei},
257 | journal={ACM Transactions on Graphics (TOG)},
258 | volume={41},
259 | number={4},
260 | articleno={162},
261 | pages={1--11},
262 | year={2022},
263 | publisher={ACM New York, NY, USA},
264 | doi={10.1145/3528223.3530104},
265 | }
266 | ```
267 |
268 | ## Acknowledgments
269 |
270 | Part of the code is borrowed from [unleashing-transformers](https://github.com/samb-t/unleashing-transformers), [taming-transformers](https://github.com/CompVis/taming-transformers) and [mmsegmentation](https://github.com/open-mmlab/mmsegmentation).
271 |
--------------------------------------------------------------------------------
/assets/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yumingj/Text2Human/9a66cbd4a628c144514810ce02fbc4c9924a8d9f/assets/1.png
--------------------------------------------------------------------------------
/assets/2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yumingj/Text2Human/9a66cbd4a628c144514810ce02fbc4c9924a8d9f/assets/2.png
--------------------------------------------------------------------------------
/assets/3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yumingj/Text2Human/9a66cbd4a628c144514810ce02fbc4c9924a8d9f/assets/3.png
--------------------------------------------------------------------------------
/assets/4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yumingj/Text2Human/9a66cbd4a628c144514810ce02fbc4c9924a8d9f/assets/4.png
--------------------------------------------------------------------------------
/assets/dataset_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yumingj/Text2Human/9a66cbd4a628c144514810ce02fbc4c9924a8d9f/assets/dataset_logo.png
--------------------------------------------------------------------------------
/assets/dataset_overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yumingj/Text2Human/9a66cbd4a628c144514810ce02fbc4c9924a8d9f/assets/dataset_overview.png
--------------------------------------------------------------------------------
/assets/results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yumingj/Text2Human/9a66cbd4a628c144514810ce02fbc4c9924a8d9f/assets/results.png
--------------------------------------------------------------------------------
/assets/ui.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yumingj/Text2Human/9a66cbd4a628c144514810ce02fbc4c9924a8d9f/assets/ui.png
--------------------------------------------------------------------------------
/configs/index_pred_net.yml:
--------------------------------------------------------------------------------
1 | name: index_prediction_network
2 | use_tb_logger: true
3 | set_CUDA_VISIBLE_DEVICES: ~
4 | gpu_ids: [3]
5 |
6 | # dataset configs
7 | batch_size: 4
8 | num_workers: 4
9 | train_img_dir: ./datasets/train_images
10 | test_img_dir: ./datasets/test_images
11 | segm_dir: ./datasets/segm
12 | pose_dir: ./datasets/densepose
13 | train_ann_file: ./datasets/texture_ann/train
14 | val_ann_file: ./datasets/texture_ann/val
15 | test_ann_file: ./datasets/texture_ann/test
16 | downsample_factor: 2
17 |
18 | model_type: VQGANTextureAwareSpatialHierarchyInferenceModel
19 | # network configs
20 | embed_dim: 256
21 | n_embed: 1024
22 | codebook_spatial_size: 2
23 |
24 | # bottom level vqvae
25 | bot_n_embed: 512
26 | bot_double_z: false
27 | bot_z_channels: 256
28 | bot_resolution: 512
29 | bot_in_channels: 3
30 | bot_out_ch: 3
31 | bot_ch: 128
32 | bot_ch_mult: [1, 1, 2, 4]
33 | bot_num_res_blocks: 2
34 | bot_attn_resolutions: [64]
35 | bot_dropout: 0.0
36 | bot_vae_path: ./pretrained_models/vqvae_bottom.pth
37 |
38 | # top level vqgan
39 | top_double_z: false
40 | top_z_channels: 256
41 | top_resolution: 512
42 | top_in_channels: 3
43 | top_out_ch: 3
44 | top_ch: 128
45 | top_ch_mult: [1, 1, 2, 2, 4]
46 | top_num_res_blocks: 2
47 | top_attn_resolutions: [32]
48 | top_dropout: 0.0
49 | top_vae_path: ./pretrained_models/vqvae_top.pth
50 |
51 | # unet configs
52 | encoder_in_channels: 256
53 | fc_in_channels: 64
54 | fc_in_index: 4
55 | fc_channels: 64
56 | fc_num_convs: 1
57 | fc_concat_input: False
58 | fc_dropout_ratio: 0.1
59 | fc_num_classes: 512
60 | fc_align_corners: False
61 |
62 | disc_layers: 3
63 | disc_weight_max: 1
64 | disc_start_step: 30001
65 | n_channels: 3
66 | ndf: 64
67 | nf: 128
68 | perceptual_weight: 1.0
69 |
70 | num_segm_classes: 24
71 |
72 | # training configs
73 | val_freq: 5
74 | print_freq: 100
75 | weight_decay: 0
76 | manual_seed: 2021
77 | num_epochs: 100
78 | lr: !!float 1.0e-04
79 | lr_decay: step
80 | gamma: 1.0
81 | step: 50
82 | optimizer: Adam
83 | loss_function: cross_entropy
84 |
85 |
--------------------------------------------------------------------------------
/configs/parsing_gen.yml:
--------------------------------------------------------------------------------
1 | name: parsing_generation
2 | use_tb_logger: true
3 | set_CUDA_VISIBLE_DEVICES: ~
4 | gpu_ids: [3]
5 |
6 | # dataset configs
7 | batch_size: 8
8 | num_workers: 4
9 | segm_dir: ./datasets/segm
10 | pose_dir: ./datasets/densepose
11 | train_ann_file: ./datasets/shape_ann/train_ann_file.txt
12 | val_ann_file: ./datasets/shape_ann/val_ann_file.txt
13 | test_ann_file: ./datasets/shape_ann/test_ann_file.txt
14 | downsample_factor: 2
15 |
16 | model_type: ParsingGenModel
17 | # network configs
18 | embedder_dim: 8
19 | embedder_out_dim: 128
20 | attr_class_num: [2, 4, 6, 5, 4, 3, 5, 5, 3, 2, 2, 2, 2, 2, 2]
21 | encoder_in_channels: 1
22 | fc_in_channels: 64
23 | fc_in_index: 4
24 | fc_channels: 64
25 | fc_num_convs: 1
26 | fc_concat_input: False
27 | fc_dropout_ratio: 0.1
28 | fc_num_classes: 24
29 | fc_align_corners: False
30 |
31 | # training configs
32 | val_freq: 5
33 | print_freq: 100
34 | weight_decay: 0
35 | manual_seed: 2021
36 | num_epochs: 100
37 | lr: !!float 1e-4
38 | lr_decay: step
39 | gamma: 0.1
40 | step: 50
41 |
--------------------------------------------------------------------------------
/configs/parsing_token.yml:
--------------------------------------------------------------------------------
1 | name: parsing_tokenization
2 | use_tb_logger: true
3 | set_CUDA_VISIBLE_DEVICES: ~
4 | gpu_ids: [3]
5 |
6 | # dataset configs
7 | batch_size: 4
8 | num_workers: 4
9 | train_img_dir: ./datasets/train_images
10 | test_img_dir: ./datasets/test_images
11 | segm_dir: ./datasets/segm
12 | pose_dir: ./datasets/densepose
13 | train_ann_file: ./datasets/texture_ann/train
14 | val_ann_file: ./datasets/texture_ann/val
15 | test_ann_file: ./datasets/texture_ann/test
16 | downsample_factor: 2
17 |
18 | model_type: VQSegmentationModel
19 | # network configs
20 | embed_dim: 32
21 | n_embed: 1024
22 | image_key: "segmentation"
23 | n_labels: 24
24 | double_z: false
25 | z_channels: 32
26 | resolution: 512
27 | in_channels: 24
28 | out_ch: 24
29 | ch: 64
30 | ch_mult: [1, 1, 2, 2, 4]
31 | num_res_blocks: 1
32 | attn_resolutions: [16]
33 | dropout: 0.0
34 |
35 | num_segm_classes: 24
36 |
37 |
38 | # training configs
39 | val_freq: 5
40 | print_freq: 100
41 | weight_decay: 0
42 | manual_seed: 2021
43 | num_epochs: 100
44 | lr: !!float 4.5e-05
45 | lr_decay: step
46 | gamma: 0.1
47 | step: 50
48 |
--------------------------------------------------------------------------------
/configs/sample_from_parsing.yml:
--------------------------------------------------------------------------------
1 | name: sample_from_parsing
2 | use_tb_logger: true
3 | set_CUDA_VISIBLE_DEVICES: ~
4 | gpu_ids: [3]
5 |
6 | # dataset configs
7 | batch_size: 4
8 | num_workers: 4
9 | test_img_dir: ./datasets/test_images
10 | segm_dir: ./datasets/segm
11 | pose_dir: ./datasets/densepose
12 | test_ann_file: ./datasets/texture_ann/test
13 | downsample_factor: 2
14 |
15 | model_type: SampleFromParsingModel
16 | # network configs
17 | embed_dim: 256
18 | n_embed: 1024
19 | codebook_spatial_size: 2
20 |
21 | # bottom level vqvae
22 | bot_n_embed: 512
23 | bot_codebook_spatial_size: 2
24 | bot_double_z: false
25 | bot_z_channels: 256
26 | bot_resolution: 512
27 | bot_in_channels: 3
28 | bot_out_ch: 3
29 | bot_ch: 128
30 | bot_ch_mult: [1, 1, 2, 4]
31 | bot_num_res_blocks: 2
32 | bot_attn_resolutions: [64]
33 | bot_dropout: 0.0
34 | bot_vae_path: ./pretrained_models/vqvae_bottom.pth
35 |
36 | # top level vqgan
37 | top_double_z: false
38 | top_z_channels: 256
39 | top_resolution: 512
40 | top_in_channels: 3
41 | top_out_ch: 3
42 | top_ch: 128
43 | top_ch_mult: [1, 1, 2, 2, 4]
44 | top_num_res_blocks: 2
45 | top_attn_resolutions: [32]
46 | top_dropout: 0.0
47 | top_vae_path: ./pretrained_models/vqvae_top.pth
48 |
49 | # unet configs
50 | index_pred_encoder_in_channels: 256
51 | index_pred_fc_in_channels: 64
52 | index_pred_fc_in_index: 4
53 | index_pred_fc_channels: 64
54 | index_pred_fc_num_convs: 1
55 | index_pred_fc_concat_input: False
56 | index_pred_fc_dropout_ratio: 0.1
57 | index_pred_fc_num_classes: 512
58 | index_pred_fc_align_corners: False
59 | pretrained_index_network: ./pretrained_models/index_pred_net.pth
60 |
61 | # segmentation tokenization
62 | segm_double_z: false
63 | segm_z_channels: 32
64 | segm_resolution: 512
65 | segm_in_channels: 24
66 | segm_out_ch: 24
67 | segm_ch: 64
68 | segm_ch_mult: [1, 1, 2, 2, 4]
69 | segm_num_res_blocks: 1
70 | segm_attn_resolutions: [16]
71 | segm_dropout: 0.0
72 | segm_num_segm_classes: 24
73 | segm_n_embed: 1024
74 | segm_embed_dim: 32
75 | segm_token_path: ./pretrained_models/parsing_token.pth
76 |
77 | # sampler configs
78 | codebook_size: 18432
79 | segm_codebook_size: 1024
80 | texture_codebook_size: 18
81 | bert_n_emb: 512
82 | bert_n_layers: 24
83 | bert_n_head: 8
84 | block_size: 512 # 32 x 16
85 | latent_shape: [32, 16]
86 | embd_pdrop: 0.0
87 | resid_pdrop: 0.0
88 | attn_pdrop: 0.0
89 | num_head: 18
90 | pretrained_sampler: ./pretrained_models/sampler.pth
91 |
92 | manual_seed: 2021
93 | sample_steps: 256
94 |
--------------------------------------------------------------------------------
/configs/sample_from_pose.yml:
--------------------------------------------------------------------------------
1 | name: sample_from_pose
2 | use_tb_logger: true
3 | set_CUDA_VISIBLE_DEVICES: ~
4 | gpu_ids: [3]
5 |
6 | # dataset configs
7 | batch_size: 4
8 | num_workers: 4
9 | pose_dir: ./datasets/densepose
10 | texture_ann_file: ./datasets/texture_ann/test
11 | shape_ann_path: ./datasets/shape_ann/test_ann_file.txt
12 | downsample_factor: 2
13 |
14 | model_type: SampleFromPoseModel
15 | # network configs
16 | embed_dim: 256
17 | n_embed: 1024
18 | codebook_spatial_size: 2
19 |
20 | # bottom level vqgan
21 | bot_n_embed: 512
22 | bot_codebook_spatial_size: 2
23 | bot_double_z: false
24 | bot_z_channels: 256
25 | bot_resolution: 512
26 | bot_in_channels: 3
27 | bot_out_ch: 3
28 | bot_ch: 128
29 | bot_ch_mult: [1, 1, 2, 4]
30 | bot_num_res_blocks: 2
31 | bot_attn_resolutions: [64]
32 | bot_dropout: 0.0
33 | bot_vae_path: ./pretrained_models/vqvae_bottom.pth
34 |
35 | # top level vqgan
36 | top_double_z: false
37 | top_z_channels: 256
38 | top_resolution: 512
39 | top_in_channels: 3
40 | top_out_ch: 3
41 | top_ch: 128
42 | top_ch_mult: [1, 1, 2, 2, 4]
43 | top_num_res_blocks: 2
44 | top_attn_resolutions: [32]
45 | top_dropout: 0.0
46 | top_vae_path: ./pretrained_models/vqvae_top.pth
47 |
48 | # unet configs
49 | index_pred_encoder_in_channels: 256
50 | index_pred_fc_in_channels: 64
51 | index_pred_fc_in_index: 4
52 | index_pred_fc_channels: 64
53 | index_pred_fc_num_convs: 1
54 | index_pred_fc_concat_input: False
55 | index_pred_fc_dropout_ratio: 0.1
56 | index_pred_fc_num_classes: 512
57 | index_pred_fc_align_corners: False
58 | pretrained_index_network: ./pretrained_models/index_pred_net.pth
59 |
60 | # segmentation tokenization
61 | segm_double_z: false
62 | segm_z_channels: 32
63 | segm_resolution: 512
64 | segm_in_channels: 24
65 | segm_out_ch: 24
66 | segm_ch: 64
67 | segm_ch_mult: [1, 1, 2, 2, 4]
68 | segm_num_res_blocks: 1
69 | segm_attn_resolutions: [16]
70 | segm_dropout: 0.0
71 | segm_num_segm_classes: 24
72 | segm_n_embed: 1024
73 | segm_embed_dim: 32
74 | segm_token_path: ./pretrained_models/parsing_token.pth
75 |
76 | # sampler configs
77 | codebook_size: 18432
78 | segm_codebook_size: 1024
79 | texture_codebook_size: 18
80 | bert_n_emb: 512
81 | bert_n_layers: 24
82 | bert_n_head: 8
83 | block_size: 512 # 32 x 16
84 | latent_shape: [32, 16]
85 | embd_pdrop: 0.0
86 | resid_pdrop: 0.0
87 | attn_pdrop: 0.0
88 | num_head: 18
89 | pretrained_sampler: ./pretrained_models/sampler.pth
90 |
91 | # shape network configs
92 | shape_embedder_dim: 8
93 | shape_embedder_out_dim: 128
94 | shape_attr_class_num: [2, 4, 6, 5, 4, 3, 5, 5, 3, 2, 2, 2, 2, 2, 2]
95 | shape_encoder_in_channels: 1
96 | shape_fc_in_channels: 64
97 | shape_fc_in_index: 4
98 | shape_fc_channels: 64
99 | shape_fc_num_convs: 1
100 | shape_fc_concat_input: False
101 | shape_fc_dropout_ratio: 0.1
102 | shape_fc_num_classes: 24
103 | shape_fc_align_corners: False
104 | pretrained_parsing_gen: ./pretrained_models/parsing_gen.pth
105 |
106 | manual_seed: 2021
107 | sample_steps: 256
108 |
--------------------------------------------------------------------------------
/configs/sampler.yml:
--------------------------------------------------------------------------------
1 | name: sampler
2 | use_tb_logger: true
3 | set_CUDA_VISIBLE_DEVICES: ~
4 | gpu_ids: [3]
5 |
6 | # dataset configs
7 | batch_size: 4
8 | num_workers: 1
9 | train_img_dir: ./datasets/train_images
10 | test_img_dir: ./datasets/test_images
11 | segm_dir: ./datasets/segm
12 | pose_dir: ./datasets/densepose
13 | train_ann_file: ./datasets/texture_ann/train
14 | val_ann_file: ./datasets/texture_ann/val
15 | test_ann_file: ./datasets/texture_ann/test
16 | downsample_factor: 2
17 |
18 | # pretrained models
19 | img_ae_path: ./pretrained_models/vqvae_top.pth
20 | segm_ae_path: ./pretrained_models/parsing_token.pth
21 |
22 | model_type: TransformerTextureAwareModel
23 | # network configs
24 |
25 | # image autoencoder
26 | img_embed_dim: 256
27 | img_n_embed: 1024
28 | img_double_z: false
29 | img_z_channels: 256
30 | img_resolution: 512
31 | img_in_channels: 3
32 | img_out_ch: 3
33 | img_ch: 128
34 | img_ch_mult: [1, 1, 2, 2, 4]
35 | img_num_res_blocks: 2
36 | img_attn_resolutions: [32]
37 | img_dropout: 0.0
38 |
39 | # segmentation tokenization
40 | segm_double_z: false
41 | segm_z_channels: 32
42 | segm_resolution: 512
43 | segm_in_channels: 24
44 | segm_out_ch: 24
45 | segm_ch: 64
46 | segm_ch_mult: [1, 1, 2, 2, 4]
47 | segm_num_res_blocks: 1
48 | segm_attn_resolutions: [16]
49 | segm_dropout: 0.0
50 | segm_num_segm_classes: 24
51 | segm_n_embed: 1024
52 | segm_embed_dim: 32
53 |
54 | # sampler configs
55 | codebook_size: 18432
56 | segm_codebook_size: 1024
57 | texture_codebook_size: 18
58 | bert_n_emb: 512
59 | bert_n_layers: 24
60 | bert_n_head: 8
61 | block_size: 512 # 32 x 16
62 | latent_shape: [32, 16]
63 | embd_pdrop: 0.0
64 | resid_pdrop: 0.0
65 | attn_pdrop: 0.0
66 | num_head: 18
67 |
68 | # loss configs
69 | loss_type: reweighted_elbo
70 | mask_schedule: random
71 |
72 | sample_steps: 256
73 |
74 | # training configs
75 | val_freq: 5
76 | print_freq: 100
77 | weight_decay: 0
78 | manual_seed: 2021
79 | num_epochs: 100
80 | lr: !!float 1e-4
81 | lr_decay: step
82 | gamma: 1.0
83 | step: 50
84 |
--------------------------------------------------------------------------------
/configs/vqvae_bottom.yml:
--------------------------------------------------------------------------------
1 | name: vqvae_bottom
2 | use_tb_logger: true
3 | set_CUDA_VISIBLE_DEVICES: ~
4 | gpu_ids: [3]
5 |
6 | # dataset configs
7 | batch_size: 4
8 | num_workers: 4
9 | train_img_dir: ./datasets/train_images
10 | test_img_dir: ./datasets/test_images
11 | segm_dir: ./datasets/segm
12 | pose_dir: ./datasets/densepose
13 | train_ann_file: ./datasets/texture_ann/train
14 | val_ann_file: ./datasets/texture_ann/val
15 | test_ann_file: ./datasets/texture_ann/test
16 | downsample_factor: 2
17 |
18 | model_type: HierarchyVQSpatialTextureAwareModel
19 | # network configs
20 | embed_dim: 256
21 | n_embed: 1024
22 | codebook_spatial_size: 2
23 |
24 | # bottom level vqvae
25 | bot_n_embed: 512
26 | bot_double_z: false
27 | bot_z_channels: 256
28 | bot_resolution: 512
29 | bot_in_channels: 3
30 | bot_out_ch: 3
31 | bot_ch: 128
32 | bot_ch_mult: [1, 1, 2, 4]
33 | bot_num_res_blocks: 2
34 | bot_attn_resolutions: [64]
35 | bot_dropout: 0.0
36 |
37 | # top level vqgan
38 | top_double_z: false
39 | top_z_channels: 256
40 | top_resolution: 512
41 | top_in_channels: 3
42 | top_out_ch: 3
43 | top_ch: 128
44 | top_ch_mult: [1, 1, 2, 2, 4]
45 | top_num_res_blocks: 2
46 | top_attn_resolutions: [32]
47 | top_dropout: 0.0
48 | top_vae_path: ./pretrained_models/vqvae_top.pth
49 |
50 | fix_decoder: false
51 |
52 | disc_layers: 3
53 | disc_weight_max: 1
54 | disc_start_step: 1
55 | n_channels: 3
56 | ndf: 64
57 | nf: 128
58 | perceptual_weight: 1.0
59 |
60 | num_segm_classes: 24
61 |
62 | # training configs
63 | val_freq: 5
64 | print_freq: 100
65 | weight_decay: 0
66 | manual_seed: 2021
67 | num_epochs: 1000
68 | lr: !!float 1.0e-04
69 | lr_decay: step
70 | gamma: 1.0
71 | step: 50
72 |
73 |
--------------------------------------------------------------------------------
/configs/vqvae_top.yml:
--------------------------------------------------------------------------------
1 | name: vqvae_top
2 | use_tb_logger: true
3 | set_CUDA_VISIBLE_DEVICES: ~
4 | gpu_ids: [3]
5 |
6 | # dataset configs
7 | batch_size: 4
8 | num_workers: 4
9 | train_img_dir: ./datasets/train_images
10 | test_img_dir: ./datasets/test_images
11 | segm_dir: ./datasets/segm
12 | pose_dir: ./datasets/densepose
13 | train_ann_file: ./datasets/texture_ann/train
14 | val_ann_file: ./datasets/texture_ann/val
15 | test_ann_file: ./datasets/texture_ann/test
16 | downsample_factor: 2
17 |
18 | model_type: VQImageSegmTextureModel
19 | # network configs
20 | embed_dim: 256
21 | n_embed: 1024
22 | double_z: false
23 | z_channels: 256
24 | resolution: 512
25 | in_channels: 3
26 | out_ch: 3
27 | ch: 128
28 | ch_mult: [1, 1, 2, 2, 4]
29 | num_res_blocks: 2
30 | attn_resolutions: [32]
31 | dropout: 0.0
32 |
33 | disc_layers: 3
34 | disc_weight_max: 1
35 | disc_start_step: 30001
36 | n_channels: 3
37 | ndf: 64
38 | nf: 128
39 | perceptual_weight: 1.0
40 |
41 | num_segm_classes: 24
42 |
43 |
44 | # training configs
45 | val_freq: 5
46 | print_freq: 100
47 | weight_decay: 0
48 | manual_seed: 2021
49 | num_epochs: 1000
50 | lr: !!float 1.0e-04
51 | lr_decay: step
52 | gamma: 1.0
53 | step: 50
54 |
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yumingj/Text2Human/9a66cbd4a628c144514810ce02fbc4c9924a8d9f/data/__init__.py
--------------------------------------------------------------------------------
/data/mask_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path
3 | import random
4 |
5 | import numpy as np
6 | import torch
7 | import torch.utils.data as data
8 | from PIL import Image
9 |
10 |
11 | class MaskDataset(data.Dataset):
12 |
13 | def __init__(self, segm_dir, ann_dir, downsample_factor=2, xflip=False):
14 |
15 | self._segm_path = segm_dir
16 | self._image_fnames = []
17 |
18 | self.downsample_factor = downsample_factor
19 | self.xflip = xflip
20 |
21 | # load attributes
22 | assert os.path.exists(f'{ann_dir}/upper_fused.txt')
23 | for idx, row in enumerate(
24 | open(os.path.join(f'{ann_dir}/upper_fused.txt'), 'r')):
25 | annotations = row.split()
26 | self._image_fnames.append(annotations[0])
27 |
28 | def _open_file(self, path_prefix, fname):
29 | return open(os.path.join(path_prefix, fname), 'rb')
30 |
31 | def _load_segm(self, raw_idx):
32 | fname = self._image_fnames[raw_idx]
33 | fname = f'{fname[:-4]}_segm.png'
34 | with self._open_file(self._segm_path, fname) as f:
35 | segm = Image.open(f)
36 | if self.downsample_factor != 1:
37 | width, height = segm.size
38 | width = width // self.downsample_factor
39 | height = height // self.downsample_factor
40 | segm = segm.resize(
41 | size=(width, height), resample=Image.NEAREST)
42 | segm = np.array(segm)
43 | # segm = segm[:, :, np.newaxis].transpose(2, 0, 1)
44 | return segm.astype(np.float32)
45 |
46 | def __getitem__(self, index):
47 | segm = self._load_segm(index)
48 |
49 | if self.xflip and random.random() > 0.5:
50 | segm = segm[:, ::-1].copy()
51 |
52 | segm = torch.from_numpy(segm).long()
53 |
54 | return_dict = {'segm': segm, 'img_name': self._image_fnames[index]}
55 |
56 | return return_dict
57 |
58 | def __len__(self):
59 | return len(self._image_fnames)
60 |
--------------------------------------------------------------------------------
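A minimal usage sketch for `data/mask_dataset.py` above, wrapping `MaskDataset` in a standard PyTorch `DataLoader` (the directory paths follow the dataset layout in the README and are assumptions, not values taken from the file):

```python
from torch.utils.data import DataLoader

from data.mask_dataset import MaskDataset

# Parsing maps are loaded as integer label maps keyed by 'segm'.
dataset = MaskDataset(
    segm_dir='./datasets/segm',
    ann_dir='./datasets/texture_ann/train',
    downsample_factor=2,
    xflip=True)
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4)

batch = next(iter(loader))
print(batch['segm'].shape, batch['segm'].dtype)  # (B, H, W), torch.int64
print(batch['img_name'][:2])
```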
/data/parsing_generation_segm_attr_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path
3 |
4 | import numpy as np
5 | import torch
6 | import torch.utils.data as data
7 | from PIL import Image
8 |
9 |
10 | class ParsingGenerationDeepFashionAttrSegmDataset(data.Dataset):
11 |
12 | def __init__(self, segm_dir, pose_dir, ann_file, downsample_factor=2):
13 | self._densepose_path = pose_dir
14 | self._segm_path = segm_dir
15 | self._image_fnames = []
16 | self.attrs = []
17 |
18 | self.downsample_factor = downsample_factor
19 |
20 | # training, ground-truth available
21 | assert os.path.exists(ann_file)
22 | for row in open(os.path.join(ann_file), 'r'):
23 | annotations = row.split()
24 | self._image_fnames.append(annotations[0])
25 | self.attrs.append([int(i) for i in annotations[1:]])
26 |
27 | def _open_file(self, path_prefix, fname):
28 | return open(os.path.join(path_prefix, fname), 'rb')
29 |
30 | def _load_densepose(self, raw_idx):
31 | fname = self._image_fnames[raw_idx]
32 | fname = f'{fname[:-4]}_densepose.png'
33 | with self._open_file(self._densepose_path, fname) as f:
34 | densepose = Image.open(f)
35 | if self.downsample_factor != 1:
36 | width, height = densepose.size
37 | width = width // self.downsample_factor
38 | height = height // self.downsample_factor
39 | densepose = densepose.resize(
40 | size=(width, height), resample=Image.NEAREST)
41 | # channel-wise IUV order, [3, H, W]
42 | densepose = np.array(densepose)[:, :, 2:].transpose(2, 0, 1)
43 | return densepose.astype(np.float32)
44 |
45 | def _load_segm(self, raw_idx):
46 | fname = self._image_fnames[raw_idx]
47 | fname = f'{fname[:-4]}_segm.png'
48 | with self._open_file(self._segm_path, fname) as f:
49 | segm = Image.open(f)
50 | if self.downsample_factor != 1:
51 | width, height = segm.size
52 | width = width // self.downsample_factor
53 | height = height // self.downsample_factor
54 | segm = segm.resize(
55 | size=(width, height), resample=Image.NEAREST)
56 | segm = np.array(segm)
57 | return segm.astype(np.float32)
58 |
59 | def __getitem__(self, index):
60 | pose = self._load_densepose(index)
61 | segm = self._load_segm(index)
62 | attr = self.attrs[index]
63 |
64 | pose = torch.from_numpy(pose)
65 | segm = torch.LongTensor(segm)
66 | attr = torch.LongTensor(attr)
67 |
68 | pose = pose / 12. - 1
69 |
70 | return_dict = {
71 | 'densepose': pose,
72 | 'segm': segm,
73 | 'attr': attr,
74 | 'img_name': self._image_fnames[index]
75 | }
76 |
77 | return return_dict
78 |
79 | def __len__(self):
80 | return len(self._image_fnames)
81 |
--------------------------------------------------------------------------------
/data/pose_attr_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path
3 | import random
4 |
5 | import numpy as np
6 | import torch
7 | import torch.utils.data as data
8 | from PIL import Image
9 |
10 |
11 | class DeepFashionAttrPoseDataset(data.Dataset):
12 |
13 | def __init__(self,
14 | pose_dir,
15 | texture_ann_dir,
16 | shape_ann_path,
17 | downsample_factor=2,
18 | xflip=False):
19 | self._densepose_path = pose_dir
20 | self._image_fnames_target = []
21 | self._image_fnames = []
22 | self.upper_fused_attrs = []
23 | self.lower_fused_attrs = []
24 | self.outer_fused_attrs = []
25 | self.shape_attrs = []
26 |
27 | self.downsample_factor = downsample_factor
28 | self.xflip = xflip
29 |
30 | # load attributes
31 | assert os.path.exists(f'{texture_ann_dir}/upper_fused.txt')
32 | for idx, row in enumerate(
33 | open(os.path.join(f'{texture_ann_dir}/upper_fused.txt'), 'r')):
34 | annotations = row.split()
35 | self._image_fnames_target.append(annotations[0])
36 | self._image_fnames.append(f'{annotations[0].split(".")[0]}.png')
37 | self.upper_fused_attrs.append(int(annotations[1]))
38 |
39 | assert len(self._image_fnames_target) == len(self.upper_fused_attrs)
40 |
41 | assert os.path.exists(f'{texture_ann_dir}/lower_fused.txt')
42 | for idx, row in enumerate(
43 | open(os.path.join(f'{texture_ann_dir}/lower_fused.txt'), 'r')):
44 | annotations = row.split()
45 | assert self._image_fnames_target[idx] == annotations[0]
46 | self.lower_fused_attrs.append(int(annotations[1]))
47 |
48 | assert len(self._image_fnames_target) == len(self.lower_fused_attrs)
49 |
50 | assert os.path.exists(f'{texture_ann_dir}/outer_fused.txt')
51 | for idx, row in enumerate(
52 | open(os.path.join(f'{texture_ann_dir}/outer_fused.txt'), 'r')):
53 | annotations = row.split()
54 | assert self._image_fnames_target[idx] == annotations[0]
55 | self.outer_fused_attrs.append(int(annotations[1]))
56 |
57 | assert len(self._image_fnames_target) == len(self.outer_fused_attrs)
58 |
59 | assert os.path.exists(shape_ann_path)
60 | for idx, row in enumerate(open(os.path.join(shape_ann_path), 'r')):
61 | annotations = row.split()
62 | assert self._image_fnames_target[idx] == annotations[0]
63 | self.shape_attrs.append([int(i) for i in annotations[1:]])
64 |
65 | def _open_file(self, path_prefix, fname):
66 | return open(os.path.join(path_prefix, fname), 'rb')
67 |
68 | def _load_densepose(self, raw_idx):
69 | fname = self._image_fnames[raw_idx]
70 | fname = f'{fname[:-4]}_densepose.png'
71 | with self._open_file(self._densepose_path, fname) as f:
72 | densepose = Image.open(f)
73 | if self.downsample_factor != 1:
74 | width, height = densepose.size
75 | width = width // self.downsample_factor
76 | height = height // self.downsample_factor
77 | densepose = densepose.resize(
78 | size=(width, height), resample=Image.NEAREST)
79 | # channel-wise IUV order, [3, H, W]
80 | densepose = np.array(densepose)[:, :, 2:].transpose(2, 0, 1)
81 | return densepose.astype(np.float32)
82 |
83 | def __getitem__(self, index):
84 | pose = self._load_densepose(index)
85 | shape_attr = self.shape_attrs[index]
86 | shape_attr = torch.LongTensor(shape_attr)
87 |
88 | if self.xflip and random.random() > 0.5:
89 | pose = pose[:, :, ::-1].copy()
90 |
91 | upper_fused_attr = self.upper_fused_attrs[index]
92 | lower_fused_attr = self.lower_fused_attrs[index]
93 | outer_fused_attr = self.outer_fused_attrs[index]
94 |
95 | pose = pose / 12. - 1
96 |
97 | return_dict = {
98 | 'densepose': pose,
99 | 'img_name': self._image_fnames_target[index],
100 | 'shape_attr': shape_attr,
101 | 'upper_fused_attr': upper_fused_attr,
102 | 'lower_fused_attr': lower_fused_attr,
103 | 'outer_fused_attr': outer_fused_attr,
104 | }
105 |
106 | return return_dict
107 |
108 | def __len__(self):
109 | return len(self._image_fnames)
110 |
--------------------------------------------------------------------------------
/data/segm_attr_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path
3 | import random
4 |
5 | import numpy as np
6 | import torch
7 | import torch.utils.data as data
8 | from PIL import Image
9 |
10 |
11 | class DeepFashionAttrSegmDataset(data.Dataset):
12 |
13 | def __init__(self,
14 | img_dir,
15 | segm_dir,
16 | pose_dir,
17 | ann_dir,
18 | downsample_factor=2,
19 | xflip=False):
20 | self._img_path = img_dir
21 | self._densepose_path = pose_dir
22 | self._segm_path = segm_dir
23 | self._image_fnames = []
24 | self.upper_fused_attrs = []
25 | self.lower_fused_attrs = []
26 | self.outer_fused_attrs = []
27 |
28 | self.downsample_factor = downsample_factor
29 | self.xflip = xflip
30 |
31 | # load attributes
32 | assert os.path.exists(f'{ann_dir}/upper_fused.txt')
33 | for idx, row in enumerate(
34 | open(os.path.join(f'{ann_dir}/upper_fused.txt'), 'r')):
35 | annotations = row.split()
36 | self._image_fnames.append(annotations[0])
37 | # assert self._image_fnames[idx] == annotations[0]
38 | self.upper_fused_attrs.append(int(annotations[1]))
39 |
40 | assert len(self._image_fnames) == len(self.upper_fused_attrs)
41 |
42 | assert os.path.exists(f'{ann_dir}/lower_fused.txt')
43 | for idx, row in enumerate(
44 | open(os.path.join(f'{ann_dir}/lower_fused.txt'), 'r')):
45 | annotations = row.split()
46 | assert self._image_fnames[idx] == annotations[0]
47 | self.lower_fused_attrs.append(int(annotations[1]))
48 |
49 | assert len(self._image_fnames) == len(self.lower_fused_attrs)
50 |
51 | assert os.path.exists(f'{ann_dir}/outer_fused.txt')
52 | for idx, row in enumerate(
53 | open(os.path.join(f'{ann_dir}/outer_fused.txt'), 'r')):
54 | annotations = row.split()
55 | assert self._image_fnames[idx] == annotations[0]
56 | self.outer_fused_attrs.append(int(annotations[1]))
57 |
58 | assert len(self._image_fnames) == len(self.outer_fused_attrs)
59 |
60 | # remove the overlapping item between upper cls and lower cls
61 | # cls 21 can appear with upper clothes
62 | # cls 4 can appear with lower clothes
63 | self.upper_cls = [1., 4.]
64 | self.lower_cls = [3., 5., 21.]
65 | self.outer_cls = [2.]
66 | self.other_cls = [
67 | 11., 18., 7., 8., 9., 10., 12., 16., 17., 19., 20., 22., 23., 15.,
68 | 14., 13., 0., 6.
69 | ]
70 |
71 | def _open_file(self, path_prefix, fname):
72 | return open(os.path.join(path_prefix, fname), 'rb')
73 |
74 | def _load_raw_image(self, raw_idx):
75 | fname = self._image_fnames[raw_idx]
76 | with self._open_file(self._img_path, fname) as f:
77 | image = Image.open(f)
78 | if self.downsample_factor != 1:
79 | width, height = image.size
80 | width = width // self.downsample_factor
81 | height = height // self.downsample_factor
82 | image = image.resize(
83 | size=(width, height), resample=Image.LANCZOS)
84 | image = np.array(image)
85 | if image.ndim == 2:
86 | image = image[:, :, np.newaxis] # HW => HWC
87 | image = image.transpose(2, 0, 1) # HWC => CHW
88 | return image
89 |
90 | def _load_densepose(self, raw_idx):
91 | fname = self._image_fnames[raw_idx]
92 | fname = f'{fname[:-4]}_densepose.png'
93 | with self._open_file(self._densepose_path, fname) as f:
94 | densepose = Image.open(f)
95 | if self.downsample_factor != 1:
96 | width, height = densepose.size
97 | width = width // self.downsample_factor
98 | height = height // self.downsample_factor
99 | densepose = densepose.resize(
100 | size=(width, height), resample=Image.NEAREST)
101 | # channel-wise IUV order, [3, H, W]
102 | densepose = np.array(densepose)[:, :, 2:].transpose(2, 0, 1)
103 | return densepose.astype(np.float32)
104 |
105 | def _load_segm(self, raw_idx):
106 | fname = self._image_fnames[raw_idx]
107 | fname = f'{fname[:-4]}_segm.png'
108 | with self._open_file(self._segm_path, fname) as f:
109 | segm = Image.open(f)
110 | if self.downsample_factor != 1:
111 | width, height = segm.size
112 | width = width // self.downsample_factor
113 | height = height // self.downsample_factor
114 | segm = segm.resize(
115 | size=(width, height), resample=Image.NEAREST)
116 | segm = np.array(segm)
117 | segm = segm[:, :, np.newaxis].transpose(2, 0, 1)
118 | return segm.astype(np.float32)
119 |
120 | def __getitem__(self, index):
121 | image = self._load_raw_image(index)
122 | pose = self._load_densepose(index)
123 | segm = self._load_segm(index)
124 |
125 | if self.xflip and random.random() > 0.5:
126 | assert image.ndim == 3 # CHW
127 | image = image[:, :, ::-1].copy()
128 | pose = pose[:, :, ::-1].copy()
129 | segm = segm[:, :, ::-1].copy()
130 |
131 | image = torch.from_numpy(image)
132 | segm = torch.from_numpy(segm)
133 |
134 | upper_fused_attr = self.upper_fused_attrs[index]
135 | lower_fused_attr = self.lower_fused_attrs[index]
136 | outer_fused_attr = self.outer_fused_attrs[index]
137 |
138 | # mask 0: denotes the common codebook,
139 | # mask (attr + 1): denotes the texture-specific codebook
140 | mask = torch.zeros_like(segm)
141 | if upper_fused_attr != 17:
142 | for cls in self.upper_cls:
143 | mask[segm == cls] = upper_fused_attr + 1
144 |
145 | if lower_fused_attr != 17:
146 | for cls in self.lower_cls:
147 | mask[segm == cls] = lower_fused_attr + 1
148 |
149 | if outer_fused_attr != 17:
150 | for cls in self.outer_cls:
151 | mask[segm == cls] = outer_fused_attr + 1
152 |
153 | pose = pose / 12. - 1
154 | image = image / 127.5 - 1
155 |
156 | return_dict = {
157 | 'image': image,
158 | 'densepose': pose,
159 | 'segm': segm,
160 | 'texture_mask': mask,
161 | 'img_name': self._image_fnames[index]
162 | }
163 |
164 | return return_dict
165 |
166 | def __len__(self):
167 | return len(self._image_fnames)
168 |
--------------------------------------------------------------------------------
/environment/text2human_env.yaml:
--------------------------------------------------------------------------------
1 | name: text2human
2 | channels:
3 | - pytorch
4 | - anaconda
5 | - defaults
6 | dependencies:
7 | - _libgcc_mutex=0.1=main
8 | - astroid=2.5=py36h06a4308_1
9 | - blas=1.0=mkl
10 | - brotlipy=0.7.0=py36h7b6447c_1000
11 | - ca-certificates=2021.10.26=h06a4308_2
12 | - certifi=2021.5.30=py36h06a4308_0
13 | - cffi=1.14.3=py36he30daa8_0
14 | - chardet=3.0.4=py36_1003
15 | - click=8.0.3=pyhd3eb1b0_0
16 | - cryptography=3.1.1=py36h1ba5d50_0
17 | - cudatoolkit=10.1.243=h6bb024c_0
18 | - dataclasses=0.8=pyh4f3eec9_6
19 | - dbus=1.13.18=hb2f20db_0
20 | - expat=2.2.10=he6710b0_2
21 | - filelock=3.4.0=pyhd3eb1b0_0
22 | - fontconfig=2.13.0=h9420a91_0
23 | - freetype=2.10.4=h5ab3b9f_0
24 | - glib=2.56.2=hd408876_0
25 | - gst-plugins-base=1.14.0=hbbd80ab_1
26 | - gstreamer=1.14.0=hb453b48_1
27 | - icu=58.2=he6710b0_3
28 | - idna=2.10=py_0
29 | - importlib-metadata=4.8.1=py36h06a4308_0
30 | - importlib_metadata=4.8.1=hd3eb1b0_0
31 | - intel-openmp=2020.2=254
32 | - isort=5.7.0=pyhd3eb1b0_0
33 | - joblib=1.0.1=pyhd3eb1b0_0
34 | - jpeg=9b=habf39ab_1
35 | - lazy-object-proxy=1.5.2=py36h27cfd23_0
36 | - lcms2=2.11=h396b838_0
37 | - ld_impl_linux-64=2.33.1=h53a641e_7
38 | - libffi=3.3=he6710b0_2
39 | - libgcc-ng=9.1.0=hdf63c60_0
40 | - libpng=1.6.37=hbc83047_0
41 | - libprotobuf=3.17.2=h4ff587b_1
42 | - libstdcxx-ng=9.1.0=hdf63c60_0
43 | - libtiff=4.2.0=h3942068_0
44 | - libuuid=1.0.3=h1bed415_2
45 | - libuv=1.40.0=h7b6447c_0
46 | - libwebp-base=1.2.0=h27cfd23_0
47 | - libxcb=1.14=h7b6447c_0
48 | - libxml2=2.9.10=hb55368b_3
49 | - lz4-c=1.9.3=h2531618_0
50 | - mccabe=0.6.1=py36_1
51 | - mkl=2020.2=256
52 | - mkl-service=2.3.0=py36he8ac12f_0
53 | - mkl_fft=1.3.0=py36h54f3939_0
54 | - mkl_random=1.1.1=py36h0573a6f_0
55 | - ncurses=6.2=he6710b0_1
56 | - ninja=1.10.2=h5e70eb0_2
57 | - numpy=1.19.2=py36h54aff64_0
58 | - numpy-base=1.19.2=py36hfa32c7d_0
59 | - olefile=0.46=py36_0
60 | - openssl=1.1.1m=h7f8727e_0
61 | - packaging=21.3=pyhd3eb1b0_0
62 | - pcre=8.44=he6710b0_0
63 | - pillow=8.1.2=py36he98fc37_0
64 | - pip=21.0.1=py36h06a4308_0
65 | - protobuf=3.17.2=py36h295c915_0
66 | - pycparser=2.20=py_2
67 | - pylint=2.7.2=py36h06a4308_1
68 | - pyopenssl=19.1.0=py_1
69 | - pyqt=5.9.2=py36h05f1152_2
70 | - pysocks=1.7.1=py36_0
71 | - python=3.6.13=hdb3f193_0
72 | - pytorch=1.7.1=py3.6_cuda10.1.243_cudnn7.6.3_0
73 | - qt=5.9.7=h5867ecd_1
74 | - readline=8.1=h27cfd23_0
75 | - regex=2021.8.3=py36h7f8727e_0
76 | - requests=2.24.0=py_0
77 | - setuptools=52.0.0=py36h06a4308_0
78 | - sip=4.19.8=py36hf484d3e_0
79 | - six=1.15.0=py36h06a4308_0
80 | - sqlite=3.35.2=hdfb4753_0
81 | - tk=8.6.10=hbc83047_0
82 | - toml=0.10.2=pyhd3eb1b0_0
83 | - torchvision=0.8.2=py36_cu101
84 | - tqdm=4.62.3=pyhd3eb1b0_1
85 | - typed-ast=1.4.2=py36h27cfd23_1
86 | - typing-extensions=3.10.0.2=hd3eb1b0_0
87 | - typing_extensions=3.10.0.2=pyh06a4308_0
88 | - urllib3=1.25.11=py_0
89 | - wheel=0.36.2=pyhd3eb1b0_0
90 | - wrapt=1.12.1=py36h7b6447c_1
91 | - xz=5.2.5=h7b6447c_0
92 | - yaml=0.2.5=h7b6447c_0
93 | - zipp=3.6.0=pyhd3eb1b0_0
94 | - zlib=1.2.11=h7b6447c_3
95 | - zstd=1.4.5=h9ceee32_0
96 | - pip:
97 | - addict==2.4.0
98 | - cycler==0.11.0
99 | - einops==0.4.0
100 | - kiwisolver==1.3.1
101 | - matplotlib==3.3.4
102 | - huggingface-hub==0.4.0
103 | - nltk==3.6.7
104 | - opencv-python==4.5.5.62
105 | - pyparsing==3.0.7
106 | - python-dateutil==2.8.2
107 | - pyyaml==6.0
108 | - scikit-learn==0.24.2
109 | - scipy==1.5.4
110 | - sentencepiece==0.1.96
111 | - terminaltables==3.1.10
112 | - threadpoolctl==3.0.0
113 | - yapf==0.32.0
114 | - lpips==0.1.4
115 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import importlib
3 | import logging
4 | import os.path as osp
5 |
6 | # automatically scan and import model modules
7 | # scan all the files under the 'models' folder and collect files ending with
8 | # '_model.py'
9 | model_folder = osp.dirname(osp.abspath(__file__))
10 | model_filenames = [
11 | osp.splitext(osp.basename(v))[0]
12 | for v in glob.glob(f'{model_folder}/*_model.py')
13 | ]
14 | # import all the model modules
15 | _model_modules = [
16 | importlib.import_module(f'models.{file_name}')
17 | for file_name in model_filenames
18 | ]
19 |
20 |
21 | def create_model(opt):
22 | """Create model.
23 |
24 | Args:
25 | opt (dict): Configuration. It contains:
26 | model_type (str): Model type.
27 | """
28 | model_type = opt['model_type']
29 |
30 | # dynamic instantiation
31 | for module in _model_modules:
32 | model_cls = getattr(module, model_type, None)
33 | if model_cls is not None:
34 | break
35 | if model_cls is None:
36 | raise ValueError(f'Model {model_type} is not found.')
37 |
38 | model = model_cls(opt)
39 |
40 | logger = logging.getLogger('base')
41 | logger.info(f'Model [{model.__class__.__name__}] is created.')
42 | return model
43 |
--------------------------------------------------------------------------------
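For reference, a minimal sketch of how `create_model` from `models/__init__.py` above is driven by a parsed option file (the training and sampling scripts pass the `-opt` YAML through the option utilities in `utils/`; plain `yaml.safe_load` is used here as a stand-in):

```python
import yaml

from models import create_model

# Load a config; opt['model_type'] selects which *_model.py class gets instantiated,
# e.g. VQImageSegmTextureModel for configs/vqvae_top.yml.
with open('./configs/vqvae_top.yml', 'r') as f:
    opt = yaml.safe_load(f)

model = create_model(opt)
```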
/models/archs/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yumingj/Text2Human/9a66cbd4a628c144514810ce02fbc4c9924a8d9f/models/archs/__init__.py
--------------------------------------------------------------------------------
/models/archs/shape_attr_embedding_arch.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch import nn
4 |
5 |
6 | class ShapeAttrEmbedding(nn.Module):
7 |
8 | def __init__(self, dim, out_dim, cls_num_list):
9 | super(ShapeAttrEmbedding, self).__init__()
10 |
11 | for idx, cls_num in enumerate(cls_num_list):
12 | setattr(
13 | self, f'attr_{idx}',
14 | nn.Sequential(
15 | nn.Linear(cls_num, dim), nn.LeakyReLU(),
16 | nn.Linear(dim, dim)))
17 | self.cls_num_list = cls_num_list
18 | self.attr_num = len(cls_num_list)
19 | self.fusion = nn.Sequential(
20 | nn.Linear(dim * self.attr_num, out_dim), nn.LeakyReLU(),
21 | nn.Linear(out_dim, out_dim))
22 |
23 | def forward(self, attr):
24 | attr_embedding_list = []
25 | for idx in range(self.attr_num):
26 | attr_embed_fc = getattr(self, f'attr_{idx}')
27 | attr_embedding_list.append(
28 | attr_embed_fc(
29 | F.one_hot(
30 | attr[:, idx],
31 | num_classes=self.cls_num_list[idx]).to(torch.float32)))
32 | attr_embedding = torch.cat(attr_embedding_list, dim=1)
33 | attr_embedding = self.fusion(attr_embedding)
34 |
35 | return attr_embedding
36 |
--------------------------------------------------------------------------------
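As a quick shape check, ShapeAttrEmbedding one-hot encodes each attribute label, pushes it through its own two-layer MLP, then concatenates the per-attribute embeddings and fuses them to out_dim. The sizes below are made-up toy values (the real ones come from the parsing-generation config), so this only illustrates the tensor shapes:

import torch
from models.archs.shape_attr_embedding_arch import ShapeAttrEmbedding

# Toy sizes: three attributes with 4, 3 and 5 classes respectively.
embedder = ShapeAttrEmbedding(dim=8, out_dim=16, cls_num_list=[4, 3, 5])
attr = torch.tensor([[0, 2, 4],
                     [3, 1, 0]])   # (batch=2, num_attrs=3) integer attribute labels
emb = embedder(attr)
print(emb.shape)                   # torch.Size([2, 16])
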
/models/archs/transformer_arch.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 |
9 | class CausalSelfAttention(nn.Module):
10 | """
11 | A vanilla multi-head masked self-attention layer with a projection at the end.
12 |     It is possible to use torch.nn.MultiheadAttention here, but an explicit
13 |     implementation is included to show that there is nothing too scary.
14 | """
15 |
16 | def __init__(self, bert_n_emb, bert_n_head, attn_pdrop, resid_pdrop,
17 | latent_shape, sampler):
18 | super().__init__()
19 | assert bert_n_emb % bert_n_head == 0
20 | # key, query, value projections for all heads
21 | self.key = nn.Linear(bert_n_emb, bert_n_emb)
22 | self.query = nn.Linear(bert_n_emb, bert_n_emb)
23 | self.value = nn.Linear(bert_n_emb, bert_n_emb)
24 | # regularization
25 | self.attn_drop = nn.Dropout(attn_pdrop)
26 | self.resid_drop = nn.Dropout(resid_pdrop)
27 | # output projection
28 | self.proj = nn.Linear(bert_n_emb, bert_n_emb)
29 | self.n_head = bert_n_head
30 | self.causal = True if sampler == 'autoregressive' else False
31 | if self.causal:
32 | block_size = np.prod(latent_shape)
33 | mask = torch.tril(torch.ones(block_size, block_size))
34 | self.register_buffer("mask", mask.view(1, 1, block_size,
35 | block_size))
36 |
37 | def forward(self, x, layer_past=None):
38 | B, T, C = x.size()
39 |
40 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim
41 | k = self.key(x).view(B, T, self.n_head,
42 | C // self.n_head).transpose(1,
43 | 2) # (B, nh, T, hs)
44 | q = self.query(x).view(B, T, self.n_head,
45 | C // self.n_head).transpose(1,
46 | 2) # (B, nh, T, hs)
47 | v = self.value(x).view(B, T, self.n_head,
48 | C // self.n_head).transpose(1,
49 | 2) # (B, nh, T, hs)
50 |
51 | present = torch.stack((k, v))
52 | if self.causal and layer_past is not None:
53 | past_key, past_value = layer_past
54 | k = torch.cat((past_key, k), dim=-2)
55 | v = torch.cat((past_value, v), dim=-2)
56 |
57 | # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
58 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
59 |
60 | if self.causal and layer_past is None:
61 | att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf'))
62 |
63 | att = F.softmax(att, dim=-1)
64 | att = self.attn_drop(att)
65 | y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
66 | # re-assemble all head outputs side by side
67 | y = y.transpose(1, 2).contiguous().view(B, T, C)
68 |
69 | # output projection
70 | y = self.resid_drop(self.proj(y))
71 | return y, present
72 |
73 |
74 | class Block(nn.Module):
75 | """ an unassuming Transformer block """
76 |
77 | def __init__(self, bert_n_emb, resid_pdrop, bert_n_head, attn_pdrop,
78 | latent_shape, sampler):
79 | super().__init__()
80 | self.ln1 = nn.LayerNorm(bert_n_emb)
81 | self.ln2 = nn.LayerNorm(bert_n_emb)
82 | self.attn = CausalSelfAttention(bert_n_emb, bert_n_head, attn_pdrop,
83 | resid_pdrop, latent_shape, sampler)
84 | self.mlp = nn.Sequential(
85 | nn.Linear(bert_n_emb, 4 * bert_n_emb),
86 | nn.GELU(), # nice
87 | nn.Linear(4 * bert_n_emb, bert_n_emb),
88 | nn.Dropout(resid_pdrop),
89 | )
90 |
91 | def forward(self, x, layer_past=None, return_present=False):
92 |
93 | attn, present = self.attn(self.ln1(x), layer_past)
94 | x = x + attn
95 | x = x + self.mlp(self.ln2(x))
96 |
97 | if layer_past is not None or return_present:
98 | return x, present
99 | return x
100 |
101 |
102 | class Transformer(nn.Module):
103 | """ the full GPT language model, with a context size of block_size """
104 |
105 | def __init__(self,
106 | codebook_size,
107 | segm_codebook_size,
108 | bert_n_emb,
109 | bert_n_layers,
110 | bert_n_head,
111 | block_size,
112 | latent_shape,
113 | embd_pdrop,
114 | resid_pdrop,
115 | attn_pdrop,
116 | sampler='absorbing'):
117 | super().__init__()
118 |
119 | self.vocab_size = codebook_size + 1
120 | self.n_embd = bert_n_emb
121 | self.block_size = block_size
122 | self.n_layers = bert_n_layers
123 | self.codebook_size = codebook_size
124 | self.segm_codebook_size = segm_codebook_size
125 | self.causal = sampler == 'autoregressive'
126 | if self.causal:
127 | self.vocab_size = codebook_size
128 |
129 | self.tok_emb = nn.Embedding(self.vocab_size, self.n_embd)
130 | self.pos_emb = nn.Parameter(
131 | torch.zeros(1, self.block_size, self.n_embd))
132 | self.segm_emb = nn.Embedding(self.segm_codebook_size, self.n_embd)
133 | self.start_tok = nn.Parameter(torch.zeros(1, 1, self.n_embd))
134 | self.drop = nn.Dropout(embd_pdrop)
135 |
136 | # transformer
137 | self.blocks = nn.Sequential(*[
138 | Block(bert_n_emb, resid_pdrop, bert_n_head, attn_pdrop,
139 | latent_shape, sampler) for _ in range(self.n_layers)
140 | ])
141 | # decoder head
142 | self.ln_f = nn.LayerNorm(self.n_embd)
143 | self.head = nn.Linear(self.n_embd, self.codebook_size, bias=False)
144 |
145 | def get_block_size(self):
146 | return self.block_size
147 |
148 | def _init_weights(self, module):
149 | if isinstance(module, (nn.Linear, nn.Embedding)):
150 | module.weight.data.normal_(mean=0.0, std=0.02)
151 | if isinstance(module, nn.Linear) and module.bias is not None:
152 | module.bias.data.zero_()
153 | elif isinstance(module, nn.LayerNorm):
154 | module.bias.data.zero_()
155 | module.weight.data.fill_(1.0)
156 |
157 | def forward(self, idx, segm_tokens, t=None):
158 | # each index maps to a (learnable) vector
159 | token_embeddings = self.tok_emb(idx)
160 |
161 | segm_embeddings = self.segm_emb(segm_tokens)
162 |
163 | if self.causal:
164 | token_embeddings = torch.cat((self.start_tok.repeat(
165 | token_embeddings.size(0), 1, 1), token_embeddings),
166 | dim=1)
167 |
168 | t = token_embeddings.shape[1]
169 | assert t <= self.block_size, "Cannot forward, model block size is exhausted."
170 | # each position maps to a (learnable) vector
171 |
172 | position_embeddings = self.pos_emb[:, :t, :]
173 |
174 | x = token_embeddings + position_embeddings + segm_embeddings
175 | x = self.drop(x)
176 | for block in self.blocks:
177 | x = block(x)
178 | x = self.ln_f(x)
179 | logits = self.head(x)
180 |
181 | return logits
182 |
183 |
184 | class TransformerMultiHead(nn.Module):
185 | """ the full GPT language model, with a context size of block_size """
186 |
187 | def __init__(self,
188 | codebook_size,
189 | segm_codebook_size,
190 | texture_codebook_size,
191 | bert_n_emb,
192 | bert_n_layers,
193 | bert_n_head,
194 | block_size,
195 | latent_shape,
196 | embd_pdrop,
197 | resid_pdrop,
198 | attn_pdrop,
199 | num_head,
200 | sampler='absorbing'):
201 | super().__init__()
202 |
203 | self.vocab_size = codebook_size + 1
204 | self.n_embd = bert_n_emb
205 | self.block_size = block_size
206 | self.n_layers = bert_n_layers
207 | self.codebook_size = codebook_size
208 | self.segm_codebook_size = segm_codebook_size
209 | self.texture_codebook_size = texture_codebook_size
210 | self.causal = sampler == 'autoregressive'
211 | if self.causal:
212 | self.vocab_size = codebook_size
213 |
214 | self.tok_emb = nn.Embedding(self.vocab_size, self.n_embd)
215 | self.pos_emb = nn.Parameter(
216 | torch.zeros(1, self.block_size, self.n_embd))
217 | self.segm_emb = nn.Embedding(self.segm_codebook_size, self.n_embd)
218 | self.texture_emb = nn.Embedding(self.texture_codebook_size,
219 | self.n_embd)
220 | self.start_tok = nn.Parameter(torch.zeros(1, 1, self.n_embd))
221 | self.drop = nn.Dropout(embd_pdrop)
222 |
223 | # transformer
224 | self.blocks = nn.Sequential(*[
225 | Block(bert_n_emb, resid_pdrop, bert_n_head, attn_pdrop,
226 | latent_shape, sampler) for _ in range(self.n_layers)
227 | ])
228 | # decoder head
229 | self.num_head = num_head
230 | self.head_class_num = codebook_size // self.num_head
231 | self.ln_f = nn.LayerNorm(self.n_embd)
232 | self.head_list = nn.ModuleList([
233 | nn.Linear(self.n_embd, self.head_class_num, bias=False)
234 | for _ in range(self.num_head)
235 | ])
236 |
237 | def get_block_size(self):
238 | return self.block_size
239 |
240 | def _init_weights(self, module):
241 | if isinstance(module, (nn.Linear, nn.Embedding)):
242 | module.weight.data.normal_(mean=0.0, std=0.02)
243 | if isinstance(module, nn.Linear) and module.bias is not None:
244 | module.bias.data.zero_()
245 | elif isinstance(module, nn.LayerNorm):
246 | module.bias.data.zero_()
247 | module.weight.data.fill_(1.0)
248 |
249 | def forward(self, idx, segm_tokens, texture_tokens, t=None):
250 | # each index maps to a (learnable) vector
251 | token_embeddings = self.tok_emb(idx)
252 | segm_embeddings = self.segm_emb(segm_tokens)
253 | texture_embeddings = self.texture_emb(texture_tokens)
254 |
255 | if self.causal:
256 | token_embeddings = torch.cat((self.start_tok.repeat(
257 | token_embeddings.size(0), 1, 1), token_embeddings),
258 | dim=1)
259 |
260 | t = token_embeddings.shape[1]
261 | assert t <= self.block_size, "Cannot forward, model block size is exhausted."
262 | # each position maps to a (learnable) vector
263 |
264 | position_embeddings = self.pos_emb[:, :t, :]
265 |
266 | x = token_embeddings + position_embeddings + segm_embeddings + texture_embeddings
267 | x = self.drop(x)
268 | for block in self.blocks:
269 | x = block(x)
270 | x = self.ln_f(x)
271 | logits_list = [self.head_list[i](x) for i in range(self.num_head)]
272 |
273 | return logits_list
274 |
--------------------------------------------------------------------------------
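To make the tensor flow of Transformer concrete, here is a toy forward pass: token indices and segmentation tokens of length block_size go in, and per-position logits over the codebook come out. Every hyperparameter below is a deliberately tiny placeholder, not a value from the sampler config:

import torch
from models.archs.transformer_arch import Transformer

latent_shape = (4, 2)                  # 4 x 2 token grid -> 8 positions
model = Transformer(
    codebook_size=16, segm_codebook_size=6,
    bert_n_emb=32, bert_n_layers=2, bert_n_head=4,
    block_size=8, latent_shape=latent_shape,
    embd_pdrop=0.0, resid_pdrop=0.0, attn_pdrop=0.0,
    sampler='absorbing')

idx = torch.randint(0, 17, (1, 8))     # 16 codebook entries + 1 mask token
segm_tokens = torch.randint(0, 6, (1, 8))
logits = model(idx, segm_tokens)
print(logits.shape)                    # torch.Size([1, 8, 16])
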
/models/hierarchy_inference_model.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import math
3 | from collections import OrderedDict
4 |
5 | import torch
6 | import torch.nn.functional as F
7 | from torchvision.utils import save_image
8 |
9 | from models.archs.fcn_arch import MultiHeadFCNHead
10 | from models.archs.unet_arch import UNet
11 | from models.archs.vqgan_arch import (Decoder, DecoderRes, Encoder,
12 | VectorQuantizerSpatialTextureAware,
13 | VectorQuantizerTexture)
14 | from models.losses.accuracy import accuracy
15 | from models.losses.cross_entropy_loss import CrossEntropyLoss
16 |
17 | logger = logging.getLogger('base')
18 |
19 |
20 | class VQGANTextureAwareSpatialHierarchyInferenceModel():
21 |
22 | def __init__(self, opt):
23 | self.opt = opt
24 | self.device = torch.device('cuda')
25 | self.is_train = opt['is_train']
26 |
27 | self.top_encoder = Encoder(
28 | ch=opt['top_ch'],
29 | num_res_blocks=opt['top_num_res_blocks'],
30 | attn_resolutions=opt['top_attn_resolutions'],
31 | ch_mult=opt['top_ch_mult'],
32 | in_channels=opt['top_in_channels'],
33 | resolution=opt['top_resolution'],
34 | z_channels=opt['top_z_channels'],
35 | double_z=opt['top_double_z'],
36 | dropout=opt['top_dropout']).to(self.device)
37 | self.decoder = Decoder(
38 | in_channels=opt['top_in_channels'],
39 | resolution=opt['top_resolution'],
40 | z_channels=opt['top_z_channels'],
41 | ch=opt['top_ch'],
42 | out_ch=opt['top_out_ch'],
43 | num_res_blocks=opt['top_num_res_blocks'],
44 | attn_resolutions=opt['top_attn_resolutions'],
45 | ch_mult=opt['top_ch_mult'],
46 | dropout=opt['top_dropout'],
47 | resamp_with_conv=True,
48 | give_pre_end=False).to(self.device)
49 | self.top_quantize = VectorQuantizerTexture(
50 | 1024, opt['embed_dim'], beta=0.25).to(self.device)
51 | self.top_quant_conv = torch.nn.Conv2d(opt["top_z_channels"],
52 | opt['embed_dim'],
53 | 1).to(self.device)
54 | self.top_post_quant_conv = torch.nn.Conv2d(opt['embed_dim'],
55 | opt["top_z_channels"],
56 | 1).to(self.device)
57 | self.load_top_pretrain_models()
58 |
59 | self.bot_encoder = Encoder(
60 | ch=opt['bot_ch'],
61 | num_res_blocks=opt['bot_num_res_blocks'],
62 | attn_resolutions=opt['bot_attn_resolutions'],
63 | ch_mult=opt['bot_ch_mult'],
64 | in_channels=opt['bot_in_channels'],
65 | resolution=opt['bot_resolution'],
66 | z_channels=opt['bot_z_channels'],
67 | double_z=opt['bot_double_z'],
68 | dropout=opt['bot_dropout']).to(self.device)
69 | self.bot_decoder_res = DecoderRes(
70 | in_channels=opt['bot_in_channels'],
71 | resolution=opt['bot_resolution'],
72 | z_channels=opt['bot_z_channels'],
73 | ch=opt['bot_ch'],
74 | num_res_blocks=opt['bot_num_res_blocks'],
75 | ch_mult=opt['bot_ch_mult'],
76 | dropout=opt['bot_dropout'],
77 | give_pre_end=False).to(self.device)
78 | self.bot_quantize = VectorQuantizerSpatialTextureAware(
79 | opt['bot_n_embed'],
80 | opt['embed_dim'],
81 | beta=0.25,
82 | spatial_size=opt['codebook_spatial_size']).to(self.device)
83 | self.bot_quant_conv = torch.nn.Conv2d(opt["bot_z_channels"],
84 | opt['embed_dim'],
85 | 1).to(self.device)
86 | self.bot_post_quant_conv = torch.nn.Conv2d(opt['embed_dim'],
87 | opt["bot_z_channels"],
88 | 1).to(self.device)
89 |
90 | self.load_bot_pretrain_network()
91 |
92 | self.guidance_encoder = UNet(
93 | in_channels=opt['encoder_in_channels']).to(self.device)
94 | self.index_decoder = MultiHeadFCNHead(
95 | in_channels=opt['fc_in_channels'],
96 | in_index=opt['fc_in_index'],
97 | channels=opt['fc_channels'],
98 | num_convs=opt['fc_num_convs'],
99 | concat_input=opt['fc_concat_input'],
100 | dropout_ratio=opt['fc_dropout_ratio'],
101 | num_classes=opt['fc_num_classes'],
102 | align_corners=opt['fc_align_corners'],
103 | num_head=18).to(self.device)
104 |
105 | self.init_training_settings()
106 |
107 | def init_training_settings(self):
108 | optim_params = []
109 | for v in self.guidance_encoder.parameters():
110 | if v.requires_grad:
111 | optim_params.append(v)
112 | for v in self.index_decoder.parameters():
113 | if v.requires_grad:
114 | optim_params.append(v)
115 | # set up optimizers
116 | if self.opt['optimizer'] == 'Adam':
117 | self.optimizer = torch.optim.Adam(
118 | optim_params,
119 | self.opt['lr'],
120 | weight_decay=self.opt['weight_decay'])
121 | elif self.opt['optimizer'] == 'SGD':
122 | self.optimizer = torch.optim.SGD(
123 | optim_params,
124 | self.opt['lr'],
125 | momentum=self.opt['momentum'],
126 | weight_decay=self.opt['weight_decay'])
127 | self.log_dict = OrderedDict()
128 | if self.opt['loss_function'] == 'cross_entropy':
129 | self.loss_func = CrossEntropyLoss().to(self.device)
130 |
131 | def load_top_pretrain_models(self):
132 | # load pretrained vqgan for segmentation mask
133 | top_vae_checkpoint = torch.load(self.opt['top_vae_path'])
134 | self.top_encoder.load_state_dict(
135 | top_vae_checkpoint['encoder'], strict=True)
136 | self.decoder.load_state_dict(
137 | top_vae_checkpoint['decoder'], strict=True)
138 | self.top_quantize.load_state_dict(
139 | top_vae_checkpoint['quantize'], strict=True)
140 | self.top_quant_conv.load_state_dict(
141 | top_vae_checkpoint['quant_conv'], strict=True)
142 | self.top_post_quant_conv.load_state_dict(
143 | top_vae_checkpoint['post_quant_conv'], strict=True)
144 | self.top_encoder.eval()
145 | self.top_quantize.eval()
146 | self.top_quant_conv.eval()
147 | self.top_post_quant_conv.eval()
148 |
149 | def load_bot_pretrain_network(self):
150 | checkpoint = torch.load(self.opt['bot_vae_path'])
151 | self.bot_encoder.load_state_dict(
152 | checkpoint['bot_encoder'], strict=True)
153 | self.bot_decoder_res.load_state_dict(
154 | checkpoint['bot_decoder_res'], strict=True)
155 | self.decoder.load_state_dict(checkpoint['decoder'], strict=True)
156 | self.bot_quantize.load_state_dict(
157 | checkpoint['bot_quantize'], strict=True)
158 | self.bot_quant_conv.load_state_dict(
159 | checkpoint['bot_quant_conv'], strict=True)
160 | self.bot_post_quant_conv.load_state_dict(
161 | checkpoint['bot_post_quant_conv'], strict=True)
162 |
163 | self.bot_encoder.eval()
164 | self.bot_decoder_res.eval()
165 | self.decoder.eval()
166 | self.bot_quantize.eval()
167 | self.bot_quant_conv.eval()
168 | self.bot_post_quant_conv.eval()
169 |
170 | def top_encode(self, x, mask):
171 | h = self.top_encoder(x)
172 | h = self.top_quant_conv(h)
173 | quant, _, _ = self.top_quantize(h, mask)
174 | quant = self.top_post_quant_conv(quant)
175 |
176 | return quant, quant
177 |
178 | def feed_data(self, data):
179 | self.image = data['image'].to(self.device)
180 | self.texture_mask = data['texture_mask'].float().to(self.device)
181 | self.get_gt_indices()
182 |
183 | self.texture_tokens = F.interpolate(
184 | self.texture_mask, size=(32, 16),
185 | mode='nearest').view(self.image.size(0), -1).long()
186 |
187 | def bot_encode(self, x, mask):
188 | h = self.bot_encoder(x)
189 | h = self.bot_quant_conv(h)
190 | _, _, (_, _, indices_list) = self.bot_quantize(h, mask)
191 |
192 | return indices_list
193 |
194 | def get_gt_indices(self):
195 | self.quant_t, self.feature_t = self.top_encode(self.image,
196 | self.texture_mask)
197 | self.gt_indices_list = self.bot_encode(self.image, self.texture_mask)
198 |
199 | def index_to_image(self, index_bottom_list, texture_mask):
200 | quant_b = self.bot_quantize.get_codebook_entry(
201 | index_bottom_list, texture_mask,
202 | (index_bottom_list[0].size(0), index_bottom_list[0].size(1),
203 | index_bottom_list[0].size(2),
204 | self.opt["bot_z_channels"])) #.permute(0, 3, 1, 2)
205 | quant_b = self.bot_post_quant_conv(quant_b)
206 | bot_dec_res = self.bot_decoder_res(quant_b)
207 |
208 | dec = self.decoder(self.quant_t, bot_h=bot_dec_res)
209 |
210 | return dec
211 |
212 | def get_vis(self, pred_img_index, rec_img_index, texture_mask, save_path):
213 | rec_img = self.index_to_image(rec_img_index, texture_mask)
214 | pred_img = self.index_to_image(pred_img_index, texture_mask)
215 |
216 | base_img = self.decoder(self.quant_t)
217 | img_cat = torch.cat([
218 | self.image,
219 | rec_img,
220 | base_img,
221 | pred_img,
222 | ], dim=3).detach()
223 | img_cat = ((img_cat + 1) / 2)
224 | img_cat = img_cat.clamp_(0, 1)
225 | save_image(img_cat, save_path, nrow=1, padding=4)
226 |
227 | def optimize_parameters(self):
228 | self.guidance_encoder.train()
229 | self.index_decoder.train()
230 |
231 | self.feature_enc = self.guidance_encoder(self.feature_t)
232 | self.memory_logits_list = self.index_decoder(self.feature_enc)
233 |
234 | loss = 0
235 | for i in range(18):
236 | loss += self.loss_func(
237 | self.memory_logits_list[i],
238 | self.gt_indices_list[i],
239 | ignore_index=-1)
240 |
241 | self.optimizer.zero_grad()
242 | loss.backward()
243 | self.optimizer.step()
244 |
245 | self.log_dict['loss_total'] = loss
246 |
247 | def inference(self, data_loader, save_dir):
248 | self.guidance_encoder.eval()
249 | self.index_decoder.eval()
250 |
251 | acc = 0
252 | num = 0
253 |
254 | for _, data in enumerate(data_loader):
255 | self.feed_data(data)
256 | img_name = data['img_name']
257 |
258 | num += self.image.size(0)
259 |
260 | texture_mask_flatten = self.texture_tokens.view(-1)
261 | min_encodings_indices_list = [
262 | torch.full(
263 | texture_mask_flatten.size(),
264 | fill_value=-1,
265 | dtype=torch.long,
266 | device=texture_mask_flatten.device) for _ in range(18)
267 | ]
268 | with torch.no_grad():
269 | self.feature_enc = self.guidance_encoder(self.feature_t)
270 | memory_logits_list = self.index_decoder(self.feature_enc)
271 | # memory_indices_pred = memory_logits.argmax(dim=1)
272 | batch_acc = 0
273 | for codebook_idx, memory_logits in enumerate(memory_logits_list):
274 | region_of_interest = texture_mask_flatten == codebook_idx
275 | if torch.sum(region_of_interest) > 0:
276 | memory_indices_pred = memory_logits.argmax(dim=1).view(-1)
277 | batch_acc += torch.sum(
278 | memory_indices_pred[region_of_interest] ==
279 | self.gt_indices_list[codebook_idx].view(
280 | -1)[region_of_interest])
281 | memory_indices_pred = memory_indices_pred
282 | min_encodings_indices_list[codebook_idx][
283 | region_of_interest] = memory_indices_pred[
284 | region_of_interest]
285 | min_encodings_indices_return_list = [
286 | min_encodings_indices.view(self.gt_indices_list[0].size())
287 | for min_encodings_indices in min_encodings_indices_list
288 | ]
289 | batch_acc = batch_acc / self.gt_indices_list[codebook_idx].numel(
290 | ) * self.image.size(0)
291 | acc += batch_acc
292 | self.get_vis(min_encodings_indices_return_list,
293 | self.gt_indices_list, self.texture_mask,
294 | f'{save_dir}/{img_name[0]}')
295 |
296 | self.guidance_encoder.train()
297 | self.index_decoder.train()
298 | return (acc / num).item()
299 |
300 | def load_network(self):
301 | checkpoint = torch.load(self.opt['pretrained_models'])
302 | self.guidance_encoder.load_state_dict(
303 | checkpoint['guidance_encoder'], strict=True)
304 | self.guidance_encoder.eval()
305 |
306 | self.index_decoder.load_state_dict(
307 | checkpoint['index_decoder'], strict=True)
308 | self.index_decoder.eval()
309 |
310 | def save_network(self, save_path):
311 | """Save networks.
312 |
313 | Args:
314 |             save_path (str): Path to save the checkpoint to. The saved
315 |                 dict contains the state dicts of the guidance_encoder
316 |                 and the index_decoder.
317 | """
318 |
319 | save_dict = {}
320 | save_dict['guidance_encoder'] = self.guidance_encoder.state_dict()
321 | save_dict['index_decoder'] = self.index_decoder.state_dict()
322 |
323 | torch.save(save_dict, save_path)
324 |
325 | def update_learning_rate(self, epoch):
326 | """Update learning rate.
327 |
328 | Args:
329 |             epoch (int): Current epoch. Used to compute the decayed
330 |                 learning rate according to the configured ``lr_decay``
331 |                 mode.
332 | """
333 | lr = self.optimizer.param_groups[0]['lr']
334 |
335 | if self.opt['lr_decay'] == 'step':
336 | lr = self.opt['lr'] * (
337 | self.opt['gamma']**(epoch // self.opt['step']))
338 | elif self.opt['lr_decay'] == 'cos':
339 | lr = self.opt['lr'] * (
340 | 1 + math.cos(math.pi * epoch / self.opt['num_epochs'])) / 2
341 | elif self.opt['lr_decay'] == 'linear':
342 | lr = self.opt['lr'] * (1 - epoch / self.opt['num_epochs'])
343 | elif self.opt['lr_decay'] == 'linear2exp':
344 | if epoch < self.opt['turning_point'] + 1:
345 |                 # the learning rate decays by 95%
346 | # at the turning point (1 / 95% = 1.0526)
347 | lr = self.opt['lr'] * (
348 | 1 - epoch / int(self.opt['turning_point'] * 1.0526))
349 | else:
350 | lr *= self.opt['gamma']
351 | elif self.opt['lr_decay'] == 'schedule':
352 | if epoch in self.opt['schedule']:
353 | lr *= self.opt['gamma']
354 | else:
355 | raise ValueError('Unknown lr mode {}'.format(self.opt['lr_decay']))
356 | # set learning rate
357 | for param_group in self.optimizer.param_groups:
358 | param_group['lr'] = lr
359 |
360 | return lr
361 |
362 | def get_current_log(self):
363 | return self.log_dict
364 |
--------------------------------------------------------------------------------
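update_learning_rate (shared almost verbatim by the model classes in this repo) is a purely epoch-based schedule. The standalone sketch below reproduces its 'step', 'cos' and 'linear' branches with illustrative hyperparameters (not values from any config), so the decay curves can be inspected without building the model:

import math

# Illustrative values; the real ones come from the training configs.
base_lr, num_epochs, gamma, step = 1e-4, 100, 0.5, 30

def decayed_lr(mode, epoch):
    if mode == 'step':
        return base_lr * gamma ** (epoch // step)
    if mode == 'cos':
        return base_lr * (1 + math.cos(math.pi * epoch / num_epochs)) / 2
    if mode == 'linear':
        return base_lr * (1 - epoch / num_epochs)
    raise ValueError(f'Unknown lr mode {mode}')

for mode in ('step', 'cos', 'linear'):
    print(mode, [round(decayed_lr(mode, e), 6) for e in (0, 50, 90)])
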
/models/hierarchy_vqgan_model.py:
--------------------------------------------------------------------------------
1 | import math
2 | import sys
3 | from collections import OrderedDict
4 |
5 | sys.path.append('..')
6 | import lpips
7 | import torch
8 | import torch.nn.functional as F
9 | from torchvision.utils import save_image
10 |
11 | from models.archs.vqgan_arch import (Decoder, DecoderRes, Discriminator,
12 | Encoder,
13 | VectorQuantizerSpatialTextureAware,
14 | VectorQuantizerTexture)
15 | from models.losses.vqgan_loss import (DiffAugment, adopt_weight,
16 | calculate_adaptive_weight, hinge_d_loss)
17 |
18 |
19 | class HierarchyVQSpatialTextureAwareModel():
20 |
21 | def __init__(self, opt):
22 | self.opt = opt
23 | self.device = torch.device('cuda')
24 | self.top_encoder = Encoder(
25 | ch=opt['top_ch'],
26 | num_res_blocks=opt['top_num_res_blocks'],
27 | attn_resolutions=opt['top_attn_resolutions'],
28 | ch_mult=opt['top_ch_mult'],
29 | in_channels=opt['top_in_channels'],
30 | resolution=opt['top_resolution'],
31 | z_channels=opt['top_z_channels'],
32 | double_z=opt['top_double_z'],
33 | dropout=opt['top_dropout']).to(self.device)
34 | self.decoder = Decoder(
35 | in_channels=opt['top_in_channels'],
36 | resolution=opt['top_resolution'],
37 | z_channels=opt['top_z_channels'],
38 | ch=opt['top_ch'],
39 | out_ch=opt['top_out_ch'],
40 | num_res_blocks=opt['top_num_res_blocks'],
41 | attn_resolutions=opt['top_attn_resolutions'],
42 | ch_mult=opt['top_ch_mult'],
43 | dropout=opt['top_dropout'],
44 | resamp_with_conv=True,
45 | give_pre_end=False).to(self.device)
46 | self.top_quantize = VectorQuantizerTexture(
47 | 1024, opt['embed_dim'], beta=0.25).to(self.device)
48 | self.top_quant_conv = torch.nn.Conv2d(opt["top_z_channels"],
49 | opt['embed_dim'],
50 | 1).to(self.device)
51 | self.top_post_quant_conv = torch.nn.Conv2d(opt['embed_dim'],
52 | opt["top_z_channels"],
53 | 1).to(self.device)
54 | self.load_top_pretrain_models()
55 |
56 | self.bot_encoder = Encoder(
57 | ch=opt['bot_ch'],
58 | num_res_blocks=opt['bot_num_res_blocks'],
59 | attn_resolutions=opt['bot_attn_resolutions'],
60 | ch_mult=opt['bot_ch_mult'],
61 | in_channels=opt['bot_in_channels'],
62 | resolution=opt['bot_resolution'],
63 | z_channels=opt['bot_z_channels'],
64 | double_z=opt['bot_double_z'],
65 | dropout=opt['bot_dropout']).to(self.device)
66 | self.bot_decoder_res = DecoderRes(
67 | in_channels=opt['bot_in_channels'],
68 | resolution=opt['bot_resolution'],
69 | z_channels=opt['bot_z_channels'],
70 | ch=opt['bot_ch'],
71 | num_res_blocks=opt['bot_num_res_blocks'],
72 | ch_mult=opt['bot_ch_mult'],
73 | dropout=opt['bot_dropout'],
74 | give_pre_end=False).to(self.device)
75 | self.bot_quantize = VectorQuantizerSpatialTextureAware(
76 | opt['bot_n_embed'],
77 | opt['embed_dim'],
78 | beta=0.25,
79 | spatial_size=opt['codebook_spatial_size']).to(self.device)
80 | self.bot_quant_conv = torch.nn.Conv2d(opt["bot_z_channels"],
81 | opt['embed_dim'],
82 | 1).to(self.device)
83 | self.bot_post_quant_conv = torch.nn.Conv2d(opt['embed_dim'],
84 | opt["bot_z_channels"],
85 | 1).to(self.device)
86 |
87 | self.disc = Discriminator(
88 | opt['n_channels'], opt['ndf'],
89 | n_layers=opt['disc_layers']).to(self.device)
90 | self.perceptual = lpips.LPIPS(net="vgg").to(self.device)
91 | self.perceptual_weight = opt['perceptual_weight']
92 | self.disc_start_step = opt['disc_start_step']
93 | self.disc_weight_max = opt['disc_weight_max']
94 | self.diff_aug = opt['diff_aug']
95 | self.policy = "color,translation"
96 |
97 | self.load_discriminator_models()
98 |
99 | self.disc.train()
100 |
101 | self.fix_decoder = opt['fix_decoder']
102 |
103 | self.init_training_settings()
104 |
105 | def load_top_pretrain_models(self):
106 | # load pretrained vqgan for segmentation mask
107 | top_vae_checkpoint = torch.load(self.opt['top_vae_path'])
108 | self.top_encoder.load_state_dict(
109 | top_vae_checkpoint['encoder'], strict=True)
110 | self.decoder.load_state_dict(
111 | top_vae_checkpoint['decoder'], strict=True)
112 | self.top_quantize.load_state_dict(
113 | top_vae_checkpoint['quantize'], strict=True)
114 | self.top_quant_conv.load_state_dict(
115 | top_vae_checkpoint['quant_conv'], strict=True)
116 | self.top_post_quant_conv.load_state_dict(
117 | top_vae_checkpoint['post_quant_conv'], strict=True)
118 | self.top_encoder.eval()
119 | self.top_quantize.eval()
120 | self.top_quant_conv.eval()
121 | self.top_post_quant_conv.eval()
122 |
123 | def init_training_settings(self):
124 | self.log_dict = OrderedDict()
125 | self.configure_optimizers()
126 |
127 | def configure_optimizers(self):
128 | optim_params = []
129 | for v in self.bot_encoder.parameters():
130 | if v.requires_grad:
131 | optim_params.append(v)
132 | for v in self.bot_decoder_res.parameters():
133 | if v.requires_grad:
134 | optim_params.append(v)
135 | for v in self.bot_quantize.parameters():
136 | if v.requires_grad:
137 | optim_params.append(v)
138 | for v in self.bot_quant_conv.parameters():
139 | if v.requires_grad:
140 | optim_params.append(v)
141 | for v in self.bot_post_quant_conv.parameters():
142 | if v.requires_grad:
143 | optim_params.append(v)
144 | if not self.fix_decoder:
145 | for name, v in self.decoder.named_parameters():
146 | if v.requires_grad:
147 | if 'up.0' in name:
148 | optim_params.append(v)
149 | if 'up.1' in name:
150 | optim_params.append(v)
151 | if 'up.2' in name:
152 | optim_params.append(v)
153 | if 'up.3' in name:
154 | optim_params.append(v)
155 |
156 | self.optimizer = torch.optim.Adam(optim_params, lr=self.opt['lr'])
157 |
158 | self.disc_optimizer = torch.optim.Adam(
159 | self.disc.parameters(), lr=self.opt['lr'])
160 |
161 | def load_discriminator_models(self):
162 | # load pretrained vqgan for segmentation mask
163 | top_vae_checkpoint = torch.load(self.opt['top_vae_path'])
164 | self.disc.load_state_dict(
165 | top_vae_checkpoint['discriminator'], strict=True)
166 |
167 | def save_network(self, save_path):
168 | """Save networks.
169 | """
170 |
171 | save_dict = {}
172 | save_dict['bot_encoder'] = self.bot_encoder.state_dict()
173 | save_dict['bot_decoder_res'] = self.bot_decoder_res.state_dict()
174 | save_dict['decoder'] = self.decoder.state_dict()
175 | save_dict['bot_quantize'] = self.bot_quantize.state_dict()
176 | save_dict['bot_quant_conv'] = self.bot_quant_conv.state_dict()
177 | save_dict['bot_post_quant_conv'] = self.bot_post_quant_conv.state_dict(
178 | )
179 | save_dict['discriminator'] = self.disc.state_dict()
180 | torch.save(save_dict, save_path)
181 |
182 | def load_network(self):
183 | checkpoint = torch.load(self.opt['pretrained_models'])
184 | self.bot_encoder.load_state_dict(
185 | checkpoint['bot_encoder'], strict=True)
186 | self.bot_decoder_res.load_state_dict(
187 | checkpoint['bot_decoder_res'], strict=True)
188 | self.decoder.load_state_dict(checkpoint['decoder'], strict=True)
189 | self.bot_quantize.load_state_dict(
190 | checkpoint['bot_quantize'], strict=True)
191 | self.bot_quant_conv.load_state_dict(
192 | checkpoint['bot_quant_conv'], strict=True)
193 | self.bot_post_quant_conv.load_state_dict(
194 | checkpoint['bot_post_quant_conv'], strict=True)
195 |
196 | def optimize_parameters(self, data, step):
197 | self.bot_encoder.train()
198 | self.bot_decoder_res.train()
199 | if not self.fix_decoder:
200 | self.decoder.train()
201 | self.bot_quantize.train()
202 | self.bot_quant_conv.train()
203 | self.bot_post_quant_conv.train()
204 |
205 | loss, d_loss = self.training_step(data, step)
206 | self.optimizer.zero_grad()
207 | loss.backward()
208 | self.optimizer.step()
209 |
210 | if step > self.disc_start_step:
211 | self.disc_optimizer.zero_grad()
212 | d_loss.backward()
213 | self.disc_optimizer.step()
214 |
215 | def top_encode(self, x, mask):
216 | h = self.top_encoder(x)
217 | h = self.top_quant_conv(h)
218 | quant, _, _ = self.top_quantize(h, mask)
219 | quant = self.top_post_quant_conv(quant)
220 | return quant
221 |
222 | def bot_encode(self, x, mask):
223 | h = self.bot_encoder(x)
224 | h = self.bot_quant_conv(h)
225 | quant, emb_loss, info = self.bot_quantize(h, mask)
226 | quant = self.bot_post_quant_conv(quant)
227 | bot_dec_res = self.bot_decoder_res(quant)
228 | return bot_dec_res, emb_loss, info
229 |
230 | def decode(self, quant_top, bot_dec_res):
231 | dec = self.decoder(quant_top, bot_h=bot_dec_res)
232 | return dec
233 |
234 | def forward_step(self, input, mask):
235 | with torch.no_grad():
236 | quant_top = self.top_encode(input, mask)
237 | bot_dec_res, diff, _ = self.bot_encode(input, mask)
238 | dec = self.decode(quant_top, bot_dec_res)
239 | return dec, diff
240 |
241 | def feed_data(self, data):
242 | x = data['image'].float().to(self.device)
243 | mask = data['texture_mask'].float().to(self.device)
244 |
245 | return x, mask
246 |
247 | def training_step(self, data, step):
248 | x, mask = self.feed_data(data)
249 | xrec, codebook_loss = self.forward_step(x, mask)
250 |
251 | # get recon/perceptual loss
252 | recon_loss = torch.abs(x.contiguous() - xrec.contiguous())
253 | p_loss = self.perceptual(x.contiguous(), xrec.contiguous())
254 | nll_loss = recon_loss + self.perceptual_weight * p_loss
255 | nll_loss = torch.mean(nll_loss)
256 |
257 | # augment for input to discriminator
258 | if self.diff_aug:
259 | xrec = DiffAugment(xrec, policy=self.policy)
260 |
261 | # update generator
262 | logits_fake = self.disc(xrec)
263 | g_loss = -torch.mean(logits_fake)
264 | last_layer = self.decoder.conv_out.weight
265 | d_weight = calculate_adaptive_weight(nll_loss, g_loss, last_layer,
266 | self.disc_weight_max)
267 | d_weight *= adopt_weight(1, step, self.disc_start_step)
268 | loss = nll_loss + d_weight * g_loss + codebook_loss
269 |
270 | self.log_dict["loss"] = loss
271 | self.log_dict["l1"] = recon_loss.mean().item()
272 | self.log_dict["perceptual"] = p_loss.mean().item()
273 | self.log_dict["nll_loss"] = nll_loss.item()
274 | self.log_dict["g_loss"] = g_loss.item()
275 | self.log_dict["d_weight"] = d_weight
276 | self.log_dict["codebook_loss"] = codebook_loss.item()
277 |
278 | if step > self.disc_start_step:
279 | if self.diff_aug:
280 | logits_real = self.disc(
281 | DiffAugment(x.contiguous().detach(), policy=self.policy))
282 | else:
283 | logits_real = self.disc(x.contiguous().detach())
284 | logits_fake = self.disc(xrec.contiguous().detach(
285 |             )) # detach so that generator isn't also updated
286 | d_loss = hinge_d_loss(logits_real, logits_fake)
287 | self.log_dict["d_loss"] = d_loss
288 | else:
289 | d_loss = None
290 |
291 | return loss, d_loss
292 |
293 | @torch.no_grad()
294 | def inference(self, data_loader, save_dir):
295 | self.bot_encoder.eval()
296 | self.bot_decoder_res.eval()
297 | self.decoder.eval()
298 | self.bot_quantize.eval()
299 | self.bot_quant_conv.eval()
300 | self.bot_post_quant_conv.eval()
301 |
302 | loss_total = 0
303 | num = 0
304 |
305 | for _, data in enumerate(data_loader):
306 | img_name = data['img_name'][0]
307 | x, mask = self.feed_data(data)
308 | xrec, _ = self.forward_step(x, mask)
309 |
310 | recon_loss = torch.abs(x.contiguous() - xrec.contiguous())
311 | p_loss = self.perceptual(x.contiguous(), xrec.contiguous())
312 | nll_loss = recon_loss + self.perceptual_weight * p_loss
313 | nll_loss = torch.mean(nll_loss)
314 | loss_total += nll_loss
315 |
316 | num += x.size(0)
317 |
318 | if x.shape[1] > 3:
319 | # colorize with random projection
320 | assert xrec.shape[1] > 3
321 | # convert logits to indices
322 | xrec = torch.argmax(xrec, dim=1, keepdim=True)
323 | xrec = F.one_hot(xrec, num_classes=x.shape[1])
324 | xrec = xrec.squeeze(1).permute(0, 3, 1, 2).float()
325 | x = self.to_rgb(x)
326 | xrec = self.to_rgb(xrec)
327 |
328 | img_cat = torch.cat([x, xrec], dim=3).detach()
329 | img_cat = ((img_cat + 1) / 2)
330 | img_cat = img_cat.clamp_(0, 1)
331 | save_image(
332 | img_cat, f'{save_dir}/{img_name}.png', nrow=1, padding=4)
333 |
334 | return (loss_total / num).item()
335 |
336 | def get_current_log(self):
337 | return self.log_dict
338 |
339 | def update_learning_rate(self, epoch):
340 | """Update learning rate.
341 |
342 | Args:
343 |             epoch (int): Current epoch. Used to compute the decayed
344 |                 learning rate according to the configured ``lr_decay``
345 |                 mode.
346 | """
347 | lr = self.optimizer.param_groups[0]['lr']
348 |
349 | if self.opt['lr_decay'] == 'step':
350 | lr = self.opt['lr'] * (
351 | self.opt['gamma']**(epoch // self.opt['step']))
352 | elif self.opt['lr_decay'] == 'cos':
353 | lr = self.opt['lr'] * (
354 | 1 + math.cos(math.pi * epoch / self.opt['num_epochs'])) / 2
355 | elif self.opt['lr_decay'] == 'linear':
356 | lr = self.opt['lr'] * (1 - epoch / self.opt['num_epochs'])
357 | elif self.opt['lr_decay'] == 'linear2exp':
358 | if epoch < self.opt['turning_point'] + 1:
359 |                 # the learning rate decays by 95%
360 | # at the turning point (1 / 95% = 1.0526)
361 | lr = self.opt['lr'] * (
362 | 1 - epoch / int(self.opt['turning_point'] * 1.0526))
363 | else:
364 | lr *= self.opt['gamma']
365 | elif self.opt['lr_decay'] == 'schedule':
366 | if epoch in self.opt['schedule']:
367 | lr *= self.opt['gamma']
368 | else:
369 | raise ValueError('Unknown lr mode {}'.format(self.opt['lr_decay']))
370 | # set learning rate
371 | for param_group in self.optimizer.param_groups:
372 | param_group['lr'] = lr
373 |
374 | return lr
375 |
--------------------------------------------------------------------------------
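In training_step, the adversarial term is scaled by the ratio of gradient norms of the reconstruction objective and the generator objective at the decoder's last layer (calculate_adaptive_weight), clamped to disc_weight_max, and zeroed until disc_start_step via adopt_weight. A minimal numerical sketch of that weighting, with made-up tensors standing in for the real networks:

import torch

last_layer = torch.nn.Parameter(torch.randn(4, 4))   # stands in for decoder.conv_out.weight
x = torch.randn(4)
nll_loss = (x @ last_layer).pow(2).mean()            # stands in for recon + perceptual loss
g_loss = -(x @ last_layer).mean()                    # stands in for -mean(logits_fake)

nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
disc_weight_max = 1.0                                # illustrative stand-in for opt['disc_weight_max']
d_weight = (nll_grads.norm() / (g_grads.norm() + 1e-4)).clamp(0.0, disc_weight_max).detach()

step, disc_start_step = 100, 3000                    # illustrative step counts
d_weight = d_weight * (1 if step >= disc_start_step else 0)  # adopt_weight gating
print(d_weight)                                      # tensor(0.): the GAN term is off early on
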
/models/losses/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yumingj/Text2Human/9a66cbd4a628c144514810ce02fbc4c9924a8d9f/models/losses/__init__.py
--------------------------------------------------------------------------------
/models/losses/accuracy.py:
--------------------------------------------------------------------------------
1 | def accuracy(pred, target, topk=1, thresh=None):
2 | """Calculate accuracy according to the prediction and target.
3 |
4 | Args:
5 | pred (torch.Tensor): The model prediction, shape (N, num_class, ...)
6 | target (torch.Tensor): The target of each prediction, shape (N, , ...)
7 | topk (int | tuple[int], optional): If the predictions in ``topk``
8 |             match the target, the predictions will be regarded as
9 | correct ones. Defaults to 1.
10 | thresh (float, optional): If not None, predictions with scores under
11 |             this threshold are considered incorrect. Defaults to None.
12 |
13 | Returns:
14 | float | tuple[float]: If the input ``topk`` is a single integer,
15 | the function will return a single float as accuracy. If
16 | ``topk`` is a tuple containing multiple integers, the
17 | function will return a tuple containing accuracies of
18 | each ``topk`` number.
19 | """
20 | assert isinstance(topk, (int, tuple))
21 | if isinstance(topk, int):
22 | topk = (topk, )
23 | return_single = True
24 | else:
25 | return_single = False
26 |
27 | maxk = max(topk)
28 | if pred.size(0) == 0:
29 | accu = [pred.new_tensor(0.) for i in range(len(topk))]
30 | return accu[0] if return_single else accu
31 | assert pred.ndim == target.ndim + 1
32 | assert pred.size(0) == target.size(0)
33 | assert maxk <= pred.size(1), \
34 | f'maxk {maxk} exceeds pred dimension {pred.size(1)}'
35 | pred_value, pred_label = pred.topk(maxk, dim=1)
36 | # transpose to shape (maxk, N, ...)
37 | pred_label = pred_label.transpose(0, 1)
38 | correct = pred_label.eq(target.unsqueeze(0).expand_as(pred_label))
39 | if thresh is not None:
40 | # Only prediction values larger than thresh are counted as correct
41 | correct = correct & (pred_value > thresh).t()
42 | res = []
43 | for k in topk:
44 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
45 | res.append(correct_k.mul_(100.0 / target.numel()))
46 | return res[0] if return_single else res
47 |
--------------------------------------------------------------------------------
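A small sanity check of accuracy with hand-built logits (illustrative numbers only); the function returns percentages, one per entry in topk:

import torch
from models.losses.accuracy import accuracy

pred = torch.tensor([[0.1, 0.7, 0.2],    # top-1: class 1
                     [0.6, 0.3, 0.1],    # top-1: class 0
                     [0.2, 0.3, 0.5]])   # top-1: class 2
target = torch.tensor([1, 0, 1])
top1, top2 = accuracy(pred, target, topk=(1, 2))
print(top1.item(), top2.item())          # ~66.67 (top-1) and 100.0 (top-2), in percent
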
/models/losses/cross_entropy_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | def reduce_loss(loss, reduction):
7 | """Reduce loss as specified.
8 |
9 | Args:
10 | loss (Tensor): Elementwise loss tensor.
11 | reduction (str): Options are "none", "mean" and "sum".
12 |
13 | Return:
14 | Tensor: Reduced loss tensor.
15 | """
16 | reduction_enum = F._Reduction.get_enum(reduction)
17 | # none: 0, elementwise_mean:1, sum: 2
18 | if reduction_enum == 0:
19 | return loss
20 | elif reduction_enum == 1:
21 | return loss.mean()
22 | elif reduction_enum == 2:
23 | return loss.sum()
24 |
25 |
26 | def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None):
27 | """Apply element-wise weight and reduce loss.
28 |
29 | Args:
30 | loss (Tensor): Element-wise loss.
31 | weight (Tensor): Element-wise weights.
32 | reduction (str): Same as built-in losses of PyTorch.
33 |         avg_factor (float): Average factor when computing the mean of losses.
34 |
35 | Returns:
36 | Tensor: Processed loss values.
37 | """
38 | # if weight is specified, apply element-wise weight
39 | if weight is not None:
40 | assert weight.dim() == loss.dim()
41 | if weight.dim() > 1:
42 | assert weight.size(1) == 1 or weight.size(1) == loss.size(1)
43 | loss = loss * weight
44 |
45 | # if avg_factor is not specified, just reduce the loss
46 | if avg_factor is None:
47 | loss = reduce_loss(loss, reduction)
48 | else:
49 | # if reduction is mean, then average the loss by avg_factor
50 | if reduction == 'mean':
51 | loss = loss.sum() / avg_factor
52 | # if reduction is 'none', then do nothing, otherwise raise an error
53 | elif reduction != 'none':
54 | raise ValueError('avg_factor can not be used with reduction="sum"')
55 | return loss
56 |
57 |
58 | def cross_entropy(pred,
59 | label,
60 | weight=None,
61 | class_weight=None,
62 | reduction='mean',
63 | avg_factor=None,
64 | ignore_index=-100):
65 | """The wrapper function for :func:`F.cross_entropy`"""
66 | # class_weight is a manual rescaling weight given to each class.
67 | # If given, has to be a Tensor of size C element-wise losses
68 | loss = F.cross_entropy(
69 | pred,
70 | label,
71 | weight=class_weight,
72 | reduction='none',
73 | ignore_index=ignore_index)
74 |
75 | # apply weights and do the reduction
76 | if weight is not None:
77 | weight = weight.float()
78 | loss = weight_reduce_loss(
79 | loss, weight=weight, reduction=reduction, avg_factor=avg_factor)
80 |
81 | return loss
82 |
83 |
84 | def _expand_onehot_labels(labels, label_weights, target_shape, ignore_index):
85 | """Expand onehot labels to match the size of prediction."""
86 | bin_labels = labels.new_zeros(target_shape)
87 | valid_mask = (labels >= 0) & (labels != ignore_index)
88 | inds = torch.nonzero(valid_mask, as_tuple=True)
89 |
90 | if inds[0].numel() > 0:
91 | if labels.dim() == 3:
92 | bin_labels[inds[0], labels[valid_mask], inds[1], inds[2]] = 1
93 | else:
94 | bin_labels[inds[0], labels[valid_mask]] = 1
95 |
96 | valid_mask = valid_mask.unsqueeze(1).expand(target_shape).float()
97 | if label_weights is None:
98 | bin_label_weights = valid_mask
99 | else:
100 | bin_label_weights = label_weights.unsqueeze(1).expand(target_shape)
101 | bin_label_weights *= valid_mask
102 |
103 | return bin_labels, bin_label_weights
104 |
105 |
106 | def binary_cross_entropy(pred,
107 | label,
108 | weight=None,
109 | reduction='mean',
110 | avg_factor=None,
111 | class_weight=None,
112 | ignore_index=255):
113 | """Calculate the binary CrossEntropy loss.
114 |
115 | Args:
116 | pred (torch.Tensor): The prediction with shape (N, 1).
117 | label (torch.Tensor): The learning label of the prediction.
118 | weight (torch.Tensor, optional): Sample-wise loss weight.
119 | reduction (str, optional): The method used to reduce the loss.
120 | Options are "none", "mean" and "sum".
121 | avg_factor (int, optional): Average factor that is used to average
122 | the loss. Defaults to None.
123 | class_weight (list[float], optional): The weight for each class.
124 | ignore_index (int | None): The label index to be ignored. Default: 255
125 |
126 | Returns:
127 | torch.Tensor: The calculated loss
128 | """
129 | if pred.dim() != label.dim():
130 | assert (pred.dim() == 2 and label.dim() == 1) or (
131 | pred.dim() == 4 and label.dim() == 3), \
132 | 'Only pred shape [N, C], label shape [N] or pred shape [N, C, ' \
133 | 'H, W], label shape [N, H, W] are supported'
134 | label, weight = _expand_onehot_labels(label, weight, pred.shape,
135 | ignore_index)
136 |
137 | # weighted element-wise losses
138 | if weight is not None:
139 | weight = weight.float()
140 | loss = F.binary_cross_entropy_with_logits(
141 | pred, label.float(), pos_weight=class_weight, reduction='none')
142 | # do the reduction for the weighted loss
143 | loss = weight_reduce_loss(
144 | loss, weight, reduction=reduction, avg_factor=avg_factor)
145 |
146 | return loss
147 |
148 |
149 | def mask_cross_entropy(pred,
150 | target,
151 | label,
152 | reduction='mean',
153 | avg_factor=None,
154 | class_weight=None,
155 | ignore_index=None):
156 | """Calculate the CrossEntropy loss for masks.
157 |
158 | Args:
159 | pred (torch.Tensor): The prediction with shape (N, C), C is the number
160 | of classes.
161 | target (torch.Tensor): The learning label of the prediction.
162 |         label (torch.Tensor): ``label`` indicates the class label of the
163 |             mask's corresponding object. It is used to select the mask of
164 |             the class the object belongs to when the mask prediction is
165 |             not class-agnostic.
166 | reduction (str, optional): The method used to reduce the loss.
167 | Options are "none", "mean" and "sum".
168 | avg_factor (int, optional): Average factor that is used to average
169 | the loss. Defaults to None.
170 | class_weight (list[float], optional): The weight for each class.
171 | ignore_index (None): Placeholder, to be consistent with other loss.
172 | Default: None.
173 |
174 | Returns:
175 | torch.Tensor: The calculated loss
176 | """
177 | assert ignore_index is None, 'BCE loss does not support ignore_index'
178 | # TODO: handle these two reserved arguments
179 | assert reduction == 'mean' and avg_factor is None
180 | num_rois = pred.size()[0]
181 | inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device)
182 | pred_slice = pred[inds, label].squeeze(1)
183 | return F.binary_cross_entropy_with_logits(
184 | pred_slice, target, weight=class_weight, reduction='mean')[None]
185 |
186 |
187 | class CrossEntropyLoss(nn.Module):
188 | """CrossEntropyLoss.
189 |
190 | Args:
191 | use_sigmoid (bool, optional): Whether the prediction uses sigmoid
192 |             or softmax. Defaults to False.
193 | use_mask (bool, optional): Whether to use mask cross entropy loss.
194 | Defaults to False.
195 |         reduction (str, optional): The method used to reduce the loss.
196 |             Options are "none", "mean" and "sum". Defaults to 'mean'.
197 | class_weight (list[float], optional): Weight of each class.
198 | Defaults to None.
199 | loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
200 | """
201 |
202 | def __init__(self,
203 | use_sigmoid=False,
204 | use_mask=False,
205 | reduction='mean',
206 | class_weight=None,
207 | loss_weight=1.0):
208 | super(CrossEntropyLoss, self).__init__()
209 | assert (use_sigmoid is False) or (use_mask is False)
210 | self.use_sigmoid = use_sigmoid
211 | self.use_mask = use_mask
212 | self.reduction = reduction
213 | self.loss_weight = loss_weight
214 | self.class_weight = class_weight
215 |
216 | if self.use_sigmoid:
217 | self.cls_criterion = binary_cross_entropy
218 | elif self.use_mask:
219 | self.cls_criterion = mask_cross_entropy
220 | else:
221 | self.cls_criterion = cross_entropy
222 |
223 | def forward(self,
224 | cls_score,
225 | label,
226 | weight=None,
227 | avg_factor=None,
228 | reduction_override=None,
229 | **kwargs):
230 | """Forward function."""
231 | assert reduction_override in (None, 'none', 'mean', 'sum')
232 | reduction = (
233 | reduction_override if reduction_override else self.reduction)
234 | if self.class_weight is not None:
235 | class_weight = cls_score.new_tensor(self.class_weight)
236 | else:
237 | class_weight = None
238 | loss_cls = self.loss_weight * self.cls_criterion(
239 | cls_score,
240 | label,
241 | weight,
242 | class_weight=class_weight,
243 | reduction=reduction,
244 | avg_factor=avg_factor,
245 | **kwargs)
246 | return loss_cls
247 |
--------------------------------------------------------------------------------
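Typical segmentation-style usage of CrossEntropyLoss as the models in this repo call it: (N, C, H, W) logits against an (N, H, W) integer label map, with ignore_index forwarded through **kwargs. The shapes below are illustrative only:

import torch
from models.losses.cross_entropy_loss import CrossEntropyLoss

criterion = CrossEntropyLoss()
logits = torch.randn(2, 5, 8, 8)          # 5 classes, toy spatial size
labels = torch.randint(0, 5, (2, 8, 8))
labels[0, :2, :2] = -1                    # mark a few pixels as ignored
loss = criterion(logits, labels, ignore_index=-1)
print(loss.item())
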
/models/losses/segmentation_loss.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 |
4 |
5 | class BCELoss(nn.Module):
6 |
7 | def forward(self, prediction, target):
8 | loss = F.binary_cross_entropy_with_logits(prediction, target)
9 | return loss, {}
10 |
11 |
12 | class BCELossWithQuant(nn.Module):
13 |
14 | def __init__(self, codebook_weight=1.):
15 | super().__init__()
16 | self.codebook_weight = codebook_weight
17 |
18 | def forward(self, qloss, target, prediction, split):
19 | bce_loss = F.binary_cross_entropy_with_logits(prediction, target)
20 | loss = bce_loss + self.codebook_weight * qloss
21 | return loss, {
22 | "{}/total_loss".format(split): loss.clone().detach().mean(),
23 | "{}/bce_loss".format(split): bce_loss.detach().mean(),
24 | "{}/quant_loss".format(split): qloss.detach().mean()
25 | }
26 |
--------------------------------------------------------------------------------
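BCELossWithQuant adds a weighted codebook (quantization) term to a binary cross-entropy over the mask logits and returns both the total loss and a per-split log dict. A minimal call with random tensors of matching shape; the 24 channels are only an illustrative choice:

import torch
from models.losses.segmentation_loss import BCELossWithQuant

criterion = BCELossWithQuant(codebook_weight=1.0)
prediction = torch.randn(2, 24, 16, 8)                 # segmentation logits
target = torch.randint(0, 2, (2, 24, 16, 8)).float()   # binary multi-channel masks
qloss = torch.tensor(0.05)                             # scalar quantization loss
loss, log = criterion(qloss, target, prediction, split='train')
print(loss.item(), log['train/bce_loss'].item())
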
/models/losses/vqgan_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 |
5 | def calculate_adaptive_weight(recon_loss, g_loss, last_layer, disc_weight_max):
6 | recon_grads = torch.autograd.grad(
7 | recon_loss, last_layer, retain_graph=True)[0]
8 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
9 |
10 | d_weight = torch.norm(recon_grads) / (torch.norm(g_grads) + 1e-4)
11 | d_weight = torch.clamp(d_weight, 0.0, disc_weight_max).detach()
12 | return d_weight
13 |
14 |
15 | def adopt_weight(weight, global_step, threshold=0, value=0.):
16 | if global_step < threshold:
17 | weight = value
18 | return weight
19 |
20 |
21 | @torch.jit.script
22 | def hinge_d_loss(logits_real, logits_fake):
23 | loss_real = torch.mean(F.relu(1. - logits_real))
24 | loss_fake = torch.mean(F.relu(1. + logits_fake))
25 | d_loss = 0.5 * (loss_real + loss_fake)
26 | return d_loss
27 |
28 |
29 | def DiffAugment(x, policy='', channels_first=True):
30 | if policy:
31 | if not channels_first:
32 | x = x.permute(0, 3, 1, 2)
33 | for p in policy.split(','):
34 | for f in AUGMENT_FNS[p]:
35 | x = f(x)
36 | if not channels_first:
37 | x = x.permute(0, 2, 3, 1)
38 | x = x.contiguous()
39 | return x
40 |
41 |
42 | def rand_brightness(x):
43 | x = x + (
44 | torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5)
45 | return x
46 |
47 |
48 | def rand_saturation(x):
49 | x_mean = x.mean(dim=1, keepdim=True)
50 | x = (x - x_mean) * (torch.rand(
51 | x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2) + x_mean
52 | return x
53 |
54 |
55 | def rand_contrast(x):
56 | x_mean = x.mean(dim=[1, 2, 3], keepdim=True)
57 | x = (x - x_mean) * (torch.rand(
58 | x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5) + x_mean
59 | return x
60 |
61 |
62 | def rand_translation(x, ratio=0.125):
63 | shift_x, shift_y = int(x.size(2) * ratio +
64 | 0.5), int(x.size(3) * ratio + 0.5)
65 | translation_x = torch.randint(
66 | -shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device)
67 | translation_y = torch.randint(
68 | -shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device)
69 | grid_batch, grid_x, grid_y = torch.meshgrid(
70 | torch.arange(x.size(0), dtype=torch.long, device=x.device),
71 | torch.arange(x.size(2), dtype=torch.long, device=x.device),
72 | torch.arange(x.size(3), dtype=torch.long, device=x.device),
73 | )
74 | grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1)
75 | grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1)
76 | x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
77 | x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x,
78 | grid_y].permute(0, 3, 1, 2)
79 | return x
80 |
81 |
82 | def rand_cutout(x, ratio=0.5):
83 | cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
84 | offset_x = torch.randint(
85 | 0,
86 | x.size(2) + (1 - cutout_size[0] % 2),
87 | size=[x.size(0), 1, 1],
88 | device=x.device)
89 | offset_y = torch.randint(
90 | 0,
91 | x.size(3) + (1 - cutout_size[1] % 2),
92 | size=[x.size(0), 1, 1],
93 | device=x.device)
94 | grid_batch, grid_x, grid_y = torch.meshgrid(
95 | torch.arange(x.size(0), dtype=torch.long, device=x.device),
96 | torch.arange(cutout_size[0], dtype=torch.long, device=x.device),
97 | torch.arange(cutout_size[1], dtype=torch.long, device=x.device),
98 | )
99 | grid_x = torch.clamp(
100 | grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1)
101 | grid_y = torch.clamp(
102 | grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1)
103 | mask = torch.ones(
104 | x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device)
105 | mask[grid_batch, grid_x, grid_y] = 0
106 | x = x * mask.unsqueeze(1)
107 | return x
108 |
109 |
110 | AUGMENT_FNS = {
111 | 'color': [rand_brightness, rand_saturation, rand_contrast],
112 | 'translation': [rand_translation],
113 | 'cutout': [rand_cutout],
114 | }
115 |
--------------------------------------------------------------------------------
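A short usage sketch of the helpers above: DiffAugment applies the differentiable color/translation augmentations to an (N, C, H, W) image batch, hinge_d_loss turns real/fake discriminator logits into the discriminator objective, and adopt_weight zeroes a weight before a given step. All tensors below are random placeholders:

import torch
from models.losses.vqgan_loss import DiffAugment, adopt_weight, hinge_d_loss

fake = torch.rand(2, 3, 64, 32) * 2 - 1                   # NCHW batch in [-1, 1]
fake_aug = DiffAugment(fake, policy='color,translation')  # same policy string the models use
print(fake_aug.shape)                                     # torch.Size([2, 3, 64, 32])

logits_real = torch.randn(2, 1, 6, 2)
logits_fake = torch.randn(2, 1, 6, 2)
d_loss = hinge_d_loss(logits_real, logits_fake)
print(d_loss.item())
print(adopt_weight(1.0, global_step=10, threshold=100))   # 0.0: weight disabled before step 100
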
/models/parsing_gen_model.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import math
3 | from collections import OrderedDict
4 |
5 | import mmcv
6 | import numpy as np
7 | import torch
8 | from torchvision.utils import save_image
9 |
10 | from models.archs.fcn_arch import FCNHead
11 | from models.archs.shape_attr_embedding_arch import ShapeAttrEmbedding
12 | from models.archs.unet_arch import ShapeUNet
13 | from models.losses.accuracy import accuracy
14 | from models.losses.cross_entropy_loss import CrossEntropyLoss
15 |
16 | logger = logging.getLogger('base')
17 |
18 |
19 | class ParsingGenModel():
20 | """Paring Generation model.
21 | """
22 |
23 | def __init__(self, opt):
24 | self.opt = opt
25 | self.device = torch.device('cuda')
26 | self.is_train = opt['is_train']
27 |
28 | self.attr_embedder = ShapeAttrEmbedding(
29 | dim=opt['embedder_dim'],
30 | out_dim=opt['embedder_out_dim'],
31 | cls_num_list=opt['attr_class_num']).to(self.device)
32 | self.parsing_encoder = ShapeUNet(
33 | in_channels=opt['encoder_in_channels']).to(self.device)
34 | self.parsing_decoder = FCNHead(
35 | in_channels=opt['fc_in_channels'],
36 | in_index=opt['fc_in_index'],
37 | channels=opt['fc_channels'],
38 | num_convs=opt['fc_num_convs'],
39 | concat_input=opt['fc_concat_input'],
40 | dropout_ratio=opt['fc_dropout_ratio'],
41 | num_classes=opt['fc_num_classes'],
42 | align_corners=opt['fc_align_corners'],
43 | ).to(self.device)
44 |
45 | self.init_training_settings()
46 |
47 | self.palette = [[0, 0, 0], [255, 250, 250], [220, 220, 220],
48 | [250, 235, 215], [255, 250, 205], [211, 211, 211],
49 | [70, 130, 180], [127, 255, 212], [0, 100, 0],
50 | [50, 205, 50], [255, 255, 0], [245, 222, 179],
51 | [255, 140, 0], [255, 0, 0], [16, 78, 139],
52 | [144, 238, 144], [50, 205, 174], [50, 155, 250],
53 | [160, 140, 88], [213, 140, 88], [90, 140, 90],
54 | [185, 210, 205], [130, 165, 180], [225, 141, 151]]
55 |
56 | def init_training_settings(self):
57 | optim_params = []
58 | for v in self.attr_embedder.parameters():
59 | if v.requires_grad:
60 | optim_params.append(v)
61 | for v in self.parsing_encoder.parameters():
62 | if v.requires_grad:
63 | optim_params.append(v)
64 | for v in self.parsing_decoder.parameters():
65 | if v.requires_grad:
66 | optim_params.append(v)
67 | # set up optimizers
68 | self.optimizer = torch.optim.Adam(
69 | optim_params,
70 | self.opt['lr'],
71 | weight_decay=self.opt['weight_decay'])
72 | self.log_dict = OrderedDict()
73 | self.entropy_loss = CrossEntropyLoss().to(self.device)
74 |
75 | def feed_data(self, data):
76 | self.pose = data['densepose'].to(self.device)
77 | self.attr = data['attr'].to(self.device)
78 | self.segm = data['segm'].to(self.device)
79 |
80 | def optimize_parameters(self):
81 | self.attr_embedder.train()
82 | self.parsing_encoder.train()
83 | self.parsing_decoder.train()
84 |
85 | self.attr_embedding = self.attr_embedder(self.attr)
86 | self.pose_enc = self.parsing_encoder(self.pose, self.attr_embedding)
87 | self.seg_logits = self.parsing_decoder(self.pose_enc)
88 |
89 | loss = self.entropy_loss(self.seg_logits, self.segm)
90 |
91 | self.optimizer.zero_grad()
92 | loss.backward()
93 | self.optimizer.step()
94 |
95 | self.log_dict['loss_total'] = loss
96 |
97 | def get_vis(self, save_path):
98 | img_cat = torch.cat([
99 | self.pose,
100 | self.segm,
101 | ], dim=3).detach()
102 | img_cat = ((img_cat + 1) / 2)
103 |
104 | img_cat = img_cat.clamp_(0, 1)
105 |
106 | save_image(img_cat, save_path, nrow=1, padding=4)
107 |
108 | def inference(self, data_loader, save_dir):
109 | self.attr_embedder.eval()
110 | self.parsing_encoder.eval()
111 | self.parsing_decoder.eval()
112 |
113 | acc = 0
114 | num = 0
115 |
116 | for _, data in enumerate(data_loader):
117 | pose = data['densepose'].to(self.device)
118 | attr = data['attr'].to(self.device)
119 | segm = data['segm'].to(self.device)
120 | img_name = data['img_name']
121 |
122 | num += pose.size(0)
123 | with torch.no_grad():
124 | attr_embedding = self.attr_embedder(attr)
125 | pose_enc = self.parsing_encoder(pose, attr_embedding)
126 | seg_logits = self.parsing_decoder(pose_enc)
127 | seg_pred = seg_logits.argmax(dim=1)
128 | acc += accuracy(seg_logits, segm)
129 | palette_label = self.palette_result(segm.cpu().numpy())
130 | palette_pred = self.palette_result(seg_pred.cpu().numpy())
131 | pose_numpy = ((pose[0] + 1) / 2. * 255.).expand(
132 | 3,
133 | pose[0].size(1),
134 | pose[0].size(2),
135 | ).cpu().numpy().clip(0, 255).astype(np.uint8).transpose(1, 2, 0)
136 | concat_result = np.concatenate(
137 | (pose_numpy, palette_pred, palette_label), axis=1)
138 | mmcv.imwrite(concat_result, f'{save_dir}/{img_name[0]}')
139 |
140 | self.attr_embedder.train()
141 | self.parsing_encoder.train()
142 | self.parsing_decoder.train()
143 | return (acc / num).item()
144 |
145 | def get_current_log(self):
146 | return self.log_dict
147 |
148 | def update_learning_rate(self, epoch):
149 |         """Update the learning rate for the current epoch and return it.
150 | 
151 |         Args:
152 |             epoch (int): Current epoch. The decay schedule is selected
153 |                 by opt['lr_decay'] ('step', 'cos', 'linear',
154 |                 'linear2exp' or 'schedule').
155 |         """
156 | lr = self.optimizer.param_groups[0]['lr']
157 |
158 | if self.opt['lr_decay'] == 'step':
159 | lr = self.opt['lr'] * (
160 | self.opt['gamma']**(epoch // self.opt['step']))
161 | elif self.opt['lr_decay'] == 'cos':
162 | lr = self.opt['lr'] * (
163 | 1 + math.cos(math.pi * epoch / self.opt['num_epochs'])) / 2
164 | elif self.opt['lr_decay'] == 'linear':
165 | lr = self.opt['lr'] * (1 - epoch / self.opt['num_epochs'])
166 | elif self.opt['lr_decay'] == 'linear2exp':
167 | if epoch < self.opt['turning_point'] + 1:
168 |                 # linear decay that reaches ~5% of the initial lr
169 |                 # at the turning point (since 1 / 0.95 = 1.0526)
170 | lr = self.opt['lr'] * (
171 | 1 - epoch / int(self.opt['turning_point'] * 1.0526))
172 | else:
173 | lr *= self.opt['gamma']
174 | elif self.opt['lr_decay'] == 'schedule':
175 | if epoch in self.opt['schedule']:
176 | lr *= self.opt['gamma']
177 | else:
178 | raise ValueError('Unknown lr mode {}'.format(self.opt['lr_decay']))
179 | # set learning rate
180 | for param_group in self.optimizer.param_groups:
181 | param_group['lr'] = lr
182 |
183 | return lr
184 |
185 | def save_network(self, save_path):
186 | """Save networks.
187 | """
188 |
189 | save_dict = {}
190 | save_dict['embedder'] = self.attr_embedder.state_dict()
191 | save_dict['encoder'] = self.parsing_encoder.state_dict()
192 | save_dict['decoder'] = self.parsing_decoder.state_dict()
193 |
194 | torch.save(save_dict, save_path)
195 |
196 | def load_network(self):
197 | checkpoint = torch.load(self.opt['pretrained_parsing_gen'])
198 |
199 | self.attr_embedder.load_state_dict(checkpoint['embedder'], strict=True)
200 | self.attr_embedder.eval()
201 |
202 | self.parsing_encoder.load_state_dict(
203 | checkpoint['encoder'], strict=True)
204 | self.parsing_encoder.eval()
205 |
206 | self.parsing_decoder.load_state_dict(
207 | checkpoint['decoder'], strict=True)
208 | self.parsing_decoder.eval()
209 |
210 | def palette_result(self, result):
211 | seg = result[0]
212 | palette = np.array(self.palette)
213 | assert palette.shape[1] == 3
214 | assert len(palette.shape) == 2
215 | color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
216 | for label, color in enumerate(palette):
217 | color_seg[seg == label, :] = color
218 | # convert to BGR
219 | color_seg = color_seg[..., ::-1]
220 | return color_seg
221 |
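222 | # Typical usage (a sketch mirroring the training loop in train_parsing_gen.py):
223 | #   model = ParsingGenModel(opt)
224 | #   for epoch in range(opt['num_epochs']):
225 | #       lr = model.update_learning_rate(epoch)
226 | #       for batch_data in train_loader:
227 | #           model.feed_data(batch_data)
228 | #           model.optimize_parameters()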
--------------------------------------------------------------------------------
/sample_from_parsing.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | import os.path as osp
4 | import random
5 |
6 | import torch
7 |
8 | from data.segm_attr_dataset import DeepFashionAttrSegmDataset
9 | from models import create_model
10 | from utils.logger import get_root_logger
11 | from utils.options import dict2str, dict_to_nonedict, parse
12 | from utils.util import make_exp_dirs, set_random_seed
13 |
14 |
15 | def main():
16 | # options
17 | parser = argparse.ArgumentParser()
18 | parser.add_argument('-opt', type=str, help='Path to option YAML file.')
19 | args = parser.parse_args()
20 | opt = parse(args.opt, is_train=False)
21 |
22 | # mkdir and loggers
23 | make_exp_dirs(opt)
24 | log_file = osp.join(opt['path']['log'], f"test_{opt['name']}.log")
25 | logger = get_root_logger(
26 | logger_name='base', log_level=logging.INFO, log_file=log_file)
27 | logger.info(dict2str(opt))
28 |
29 | # convert to NoneDict, which returns None for missing keys
30 | opt = dict_to_nonedict(opt)
31 |
32 | # random seed
33 | seed = opt['manual_seed']
34 | if seed is None:
35 | seed = random.randint(1, 10000)
36 | logger.info(f'Random seed: {seed}')
37 | set_random_seed(seed)
38 |
39 | test_dataset = DeepFashionAttrSegmDataset(
40 | img_dir=opt['test_img_dir'],
41 | segm_dir=opt['segm_dir'],
42 | pose_dir=opt['pose_dir'],
43 | ann_dir=opt['test_ann_file'])
44 | test_loader = torch.utils.data.DataLoader(
45 | dataset=test_dataset, batch_size=4, shuffle=False)
46 | logger.info(f'Number of test set: {len(test_dataset)}.')
47 |
48 | model = create_model(opt)
49 | _ = model.inference(test_loader, opt['path']['results_root'])
50 |
51 |
52 | if __name__ == '__main__':
53 | main()
54 |
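55 | # Example invocation (a sketch; assumes the option YAML lives under ./configs/):
56 | #   python sample_from_parsing.py -opt ./configs/sample_from_parsing.yml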
--------------------------------------------------------------------------------
/sample_from_pose.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | import os.path as osp
4 | import random
5 |
6 | import torch
7 |
8 | from data.pose_attr_dataset import DeepFashionAttrPoseDataset
9 | from models import create_model
10 | from utils.logger import get_root_logger
11 | from utils.options import dict2str, dict_to_nonedict, parse
12 | from utils.util import make_exp_dirs, set_random_seed
13 |
14 |
15 | def main():
16 | # options
17 | parser = argparse.ArgumentParser()
18 | parser.add_argument('-opt', type=str, help='Path to option YAML file.')
19 | args = parser.parse_args()
20 | opt = parse(args.opt, is_train=False)
21 |
22 | # mkdir and loggers
23 | make_exp_dirs(opt)
24 | log_file = osp.join(opt['path']['log'], f"test_{opt['name']}.log")
25 | logger = get_root_logger(
26 | logger_name='base', log_level=logging.INFO, log_file=log_file)
27 | logger.info(dict2str(opt))
28 |
29 | # convert to NoneDict, which returns None for missing keys
30 | opt = dict_to_nonedict(opt)
31 |
32 | # random seed
33 | seed = opt['manual_seed']
34 | if seed is None:
35 | seed = random.randint(1, 10000)
36 | logger.info(f'Random seed: {seed}')
37 | set_random_seed(seed)
38 |
39 | test_dataset = DeepFashionAttrPoseDataset(
40 | pose_dir=opt['pose_dir'],
41 | texture_ann_dir=opt['texture_ann_file'],
42 | shape_ann_path=opt['shape_ann_path'])
43 | test_loader = torch.utils.data.DataLoader(
44 | dataset=test_dataset, batch_size=4, shuffle=False)
45 | logger.info(f'Number of test set: {len(test_dataset)}.')
46 |
47 | model = create_model(opt)
48 | _ = model.inference(test_loader, opt['path']['results_root'])
49 |
50 |
51 | if __name__ == '__main__':
52 | main()
53 |
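54 | # Example invocation (a sketch; the config filename is assumed):
55 | #   python sample_from_pose.py -opt ./configs/sample_from_pose.yml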
--------------------------------------------------------------------------------
/train_index_prediction.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | import os
4 | import os.path as osp
5 | import random
6 | import time
7 |
8 | import torch
9 |
10 | from data.segm_attr_dataset import DeepFashionAttrSegmDataset
11 | from models import create_model
12 | from utils.logger import MessageLogger, get_root_logger, init_tb_logger
13 | from utils.options import dict2str, dict_to_nonedict, parse
14 | from utils.util import make_exp_dirs
15 |
16 |
17 | def main():
18 | # options
19 | parser = argparse.ArgumentParser()
20 | parser.add_argument('-opt', type=str, help='Path to option YAML file.')
21 | args = parser.parse_args()
22 | opt = parse(args.opt, is_train=True)
23 |
24 | # mkdir and loggers
25 | make_exp_dirs(opt)
26 | log_file = osp.join(opt['path']['log'], f"train_{opt['name']}.log")
27 | logger = get_root_logger(
28 | logger_name='base', log_level=logging.INFO, log_file=log_file)
29 | logger.info(dict2str(opt))
30 | # initialize tensorboard logger
31 | tb_logger = None
32 | if opt['use_tb_logger'] and 'debug' not in opt['name']:
33 | tb_logger = init_tb_logger(log_dir='./tb_logger/' + opt['name'])
34 |
35 | # convert to NoneDict, which returns None for missing keys
36 | opt = dict_to_nonedict(opt)
37 |
38 | # set up data loader
39 | train_dataset = DeepFashionAttrSegmDataset(
40 | img_dir=opt['train_img_dir'],
41 | segm_dir=opt['segm_dir'],
42 | pose_dir=opt['pose_dir'],
43 | ann_dir=opt['train_ann_file'],
44 | xflip=True)
45 | train_loader = torch.utils.data.DataLoader(
46 | dataset=train_dataset,
47 | batch_size=opt['batch_size'],
48 | shuffle=True,
49 | num_workers=opt['num_workers'],
50 | drop_last=True)
51 | logger.info(f'Number of train set: {len(train_dataset)}.')
52 | opt['max_iters'] = opt['num_epochs'] * len(
53 | train_dataset) // opt['batch_size']
54 |
55 | val_dataset = DeepFashionAttrSegmDataset(
56 | img_dir=opt['train_img_dir'],
57 | segm_dir=opt['segm_dir'],
58 | pose_dir=opt['pose_dir'],
59 | ann_dir=opt['val_ann_file'])
60 | val_loader = torch.utils.data.DataLoader(
61 | dataset=val_dataset, batch_size=1, shuffle=False)
62 | logger.info(f'Number of val set: {len(val_dataset)}.')
63 |
64 | test_dataset = DeepFashionAttrSegmDataset(
65 | img_dir=opt['test_img_dir'],
66 | segm_dir=opt['segm_dir'],
67 | pose_dir=opt['pose_dir'],
68 | ann_dir=opt['test_ann_file'])
69 | test_loader = torch.utils.data.DataLoader(
70 | dataset=test_dataset, batch_size=1, shuffle=False)
71 | logger.info(f'Number of test set: {len(test_dataset)}.')
72 |
73 | current_iter = 0
74 | best_epoch = None
75 | best_acc = 0
76 |
77 | model = create_model(opt)
78 |
79 | data_time, iter_time = 0, 0
80 | current_iter = 0
81 |
82 | # create message logger (formatted outputs)
83 | msg_logger = MessageLogger(opt, current_iter, tb_logger)
84 |
85 | for epoch in range(opt['num_epochs']):
86 | lr = model.update_learning_rate(epoch)
87 |
88 | for _, batch_data in enumerate(train_loader):
89 | data_time = time.time() - data_time
90 |
91 | current_iter += 1
92 |
93 | model.feed_data(batch_data)
94 | model.optimize_parameters()
95 |
96 | iter_time = time.time() - iter_time
97 | if current_iter % opt['print_freq'] == 0:
98 | log_vars = {'epoch': epoch, 'iter': current_iter}
99 | log_vars.update({'lrs': [lr]})
100 | log_vars.update({'time': iter_time, 'data_time': data_time})
101 | log_vars.update(model.get_current_log())
102 | msg_logger(log_vars)
103 |
104 | data_time = time.time()
105 | iter_time = time.time()
106 |
107 | if epoch % opt['val_freq'] == 0:
108 | save_dir = f'{opt["path"]["visualization"]}/valset/epoch_{epoch:03d}' # noqa
109 | os.makedirs(save_dir, exist_ok=opt['debug'])
110 | val_acc = model.inference(val_loader, save_dir)
111 |
112 | save_dir = f'{opt["path"]["visualization"]}/testset/epoch_{epoch:03d}' # noqa
113 | os.makedirs(save_dir, exist_ok=opt['debug'])
114 | test_acc = model.inference(test_loader, save_dir)
115 |
116 | logger.info(
117 | f'Epoch: {epoch}, val_acc: {val_acc: .4f}, test_acc: {test_acc: .4f}.'
118 | )
119 |
120 | if test_acc > best_acc:
121 | best_epoch = epoch
122 | best_acc = test_acc
123 |
124 | logger.info(f'Best epoch: {best_epoch}, '
125 | f'Best test acc: {best_acc: .4f}.')
126 |
127 | # save model
128 | model.save_network(
129 | f'{opt["path"]["models"]}/models_epoch{epoch}.pth')
130 |
131 |
132 | if __name__ == '__main__':
133 | main()
134 |
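135 | # Example invocation (a sketch; assumes the index-prediction option YAML under ./configs/):
136 | #   python train_index_prediction.py -opt ./configs/index_pred_net.yml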
--------------------------------------------------------------------------------
/train_parsing_gen.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | import os
4 | import os.path as osp
5 | import random
6 | import time
7 |
8 | import torch
9 |
10 | from data.parsing_generation_segm_attr_dataset import \
11 | ParsingGenerationDeepFashionAttrSegmDataset
12 | from models import create_model
13 | from utils.logger import MessageLogger, get_root_logger, init_tb_logger
14 | from utils.options import dict2str, dict_to_nonedict, parse
15 | from utils.util import make_exp_dirs
16 |
17 |
18 | def main():
19 | # options
20 | parser = argparse.ArgumentParser()
21 | parser.add_argument('-opt', type=str, help='Path to option YAML file.')
22 | args = parser.parse_args()
23 | opt = parse(args.opt, is_train=True)
24 |
25 | # mkdir and loggers
26 | make_exp_dirs(opt)
27 | log_file = osp.join(opt['path']['log'], f"train_{opt['name']}.log")
28 | logger = get_root_logger(
29 | logger_name='base', log_level=logging.INFO, log_file=log_file)
30 | logger.info(dict2str(opt))
31 | # initialize tensorboard logger
32 | tb_logger = None
33 | if opt['use_tb_logger'] and 'debug' not in opt['name']:
34 | tb_logger = init_tb_logger(log_dir='./tb_logger/' + opt['name'])
35 |
36 | # convert to NoneDict, which returns None for missing keys
37 | opt = dict_to_nonedict(opt)
38 |
39 | # set up data loader
40 | train_dataset = ParsingGenerationDeepFashionAttrSegmDataset(
41 | segm_dir=opt['segm_dir'],
42 | pose_dir=opt['pose_dir'],
43 | ann_file=opt['train_ann_file'])
44 | train_loader = torch.utils.data.DataLoader(
45 | dataset=train_dataset,
46 | batch_size=opt['batch_size'],
47 | shuffle=True,
48 | num_workers=opt['num_workers'],
49 | drop_last=True)
50 | logger.info(f'Number of train set: {len(train_dataset)}.')
51 | opt['max_iters'] = opt['num_epochs'] * len(
52 | train_dataset) // opt['batch_size']
53 |
54 | val_dataset = ParsingGenerationDeepFashionAttrSegmDataset(
55 | segm_dir=opt['segm_dir'],
56 | pose_dir=opt['pose_dir'],
57 | ann_file=opt['val_ann_file'])
58 | val_loader = torch.utils.data.DataLoader(
59 | dataset=val_dataset,
60 | batch_size=1,
61 | shuffle=False,
62 | num_workers=opt['num_workers'])
63 | logger.info(f'Number of val set: {len(val_dataset)}.')
64 |
65 | test_dataset = ParsingGenerationDeepFashionAttrSegmDataset(
66 | segm_dir=opt['segm_dir'],
67 | pose_dir=opt['pose_dir'],
68 | ann_file=opt['test_ann_file'])
69 | test_loader = torch.utils.data.DataLoader(
70 | dataset=test_dataset,
71 | batch_size=1,
72 | shuffle=False,
73 | num_workers=opt['num_workers'])
74 | logger.info(f'Number of test set: {len(test_dataset)}.')
75 |
76 | current_iter = 0
77 | best_epoch = None
78 | best_acc = 0
79 |
80 | model = create_model(opt)
81 |
82 | data_time, iter_time = 0, 0
83 | current_iter = 0
84 |
85 | # create message logger (formatted outputs)
86 | msg_logger = MessageLogger(opt, current_iter, tb_logger)
87 |
88 | for epoch in range(opt['num_epochs']):
89 | lr = model.update_learning_rate(epoch)
90 |
91 | for _, batch_data in enumerate(train_loader):
92 | data_time = time.time() - data_time
93 |
94 | current_iter += 1
95 |
96 | model.feed_data(batch_data)
97 | model.optimize_parameters()
98 |
99 | iter_time = time.time() - iter_time
100 | if current_iter % opt['print_freq'] == 0:
101 | log_vars = {'epoch': epoch, 'iter': current_iter}
102 | log_vars.update({'lrs': [lr]})
103 | log_vars.update({'time': iter_time, 'data_time': data_time})
104 | log_vars.update(model.get_current_log())
105 | msg_logger(log_vars)
106 |
107 | data_time = time.time()
108 | iter_time = time.time()
109 |
110 | if epoch % opt['val_freq'] == 0:
111 | save_dir = f'{opt["path"]["visualization"]}/valset/epoch_{epoch:03d}'
112 | os.makedirs(save_dir, exist_ok=opt['debug'])
113 | val_acc = model.inference(val_loader, save_dir)
114 |
115 | save_dir = f'{opt["path"]["visualization"]}/testset/epoch_{epoch:03d}'
116 | os.makedirs(save_dir, exist_ok=opt['debug'])
117 | test_acc = model.inference(test_loader, save_dir)
118 |
119 | logger.info(f'Epoch: {epoch}, '
120 | f'val_acc: {val_acc: .4f}, '
121 | f'test_acc: {test_acc: .4f}.')
122 |
123 | if test_acc > best_acc:
124 | best_epoch = epoch
125 | best_acc = test_acc
126 |
127 | logger.info(f'Best epoch: {best_epoch}, '
128 | f'Best test acc: {best_acc: .4f}.')
129 |
130 | # save model
131 | model.save_network(
132 | f'{opt["path"]["models"]}/parsing_generation_epoch{epoch}.pth')
133 |
134 |
135 | if __name__ == '__main__':
136 | main()
137 |
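138 | # Example invocation (a sketch; the parsing-generation config path is assumed):
139 | #   python train_parsing_gen.py -opt ./configs/parsing_gen.yml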
--------------------------------------------------------------------------------
/train_parsing_token.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | import os
4 | import os.path as osp
5 | import random
6 | import time
7 |
8 | import torch
9 |
10 | from data.mask_dataset import MaskDataset
11 | from models import create_model
12 | from utils.logger import MessageLogger, get_root_logger, init_tb_logger
13 | from utils.options import dict2str, dict_to_nonedict, parse
14 | from utils.util import make_exp_dirs
15 |
16 |
17 | def main():
18 | # options
19 | parser = argparse.ArgumentParser()
20 | parser.add_argument('-opt', type=str, help='Path to option YAML file.')
21 | args = parser.parse_args()
22 | opt = parse(args.opt, is_train=True)
23 |
24 | # mkdir and loggers
25 | make_exp_dirs(opt)
26 | log_file = osp.join(opt['path']['log'], f"train_{opt['name']}.log")
27 | logger = get_root_logger(
28 | logger_name='base', log_level=logging.INFO, log_file=log_file)
29 | logger.info(dict2str(opt))
30 | # initialize tensorboard logger
31 | tb_logger = None
32 | if opt['use_tb_logger'] and 'debug' not in opt['name']:
33 | tb_logger = init_tb_logger(log_dir='./tb_logger/' + opt['name'])
34 |
35 | # convert to NoneDict, which returns None for missing keys
36 | opt = dict_to_nonedict(opt)
37 |
38 | # set up data loader
39 | train_dataset = MaskDataset(
40 | segm_dir=opt['segm_dir'], ann_dir=opt['train_ann_file'], xflip=True)
41 | train_loader = torch.utils.data.DataLoader(
42 | dataset=train_dataset,
43 | batch_size=opt['batch_size'],
44 | shuffle=True,
45 | num_workers=opt['num_workers'],
46 | persistent_workers=True,
47 | drop_last=True)
48 | logger.info(f'Number of train set: {len(train_dataset)}.')
49 | opt['max_iters'] = opt['num_epochs'] * len(
50 | train_dataset) // opt['batch_size']
51 |
52 | val_dataset = MaskDataset(
53 | segm_dir=opt['segm_dir'], ann_dir=opt['val_ann_file'])
54 | val_loader = torch.utils.data.DataLoader(
55 | dataset=val_dataset, batch_size=1, shuffle=False)
56 | logger.info(f'Number of val set: {len(val_dataset)}.')
57 |
58 | test_dataset = MaskDataset(
59 | segm_dir=opt['segm_dir'], ann_dir=opt['test_ann_file'])
60 | test_loader = torch.utils.data.DataLoader(
61 | dataset=test_dataset, batch_size=1, shuffle=False)
62 | logger.info(f'Number of test set: {len(test_dataset)}.')
63 |
64 | current_iter = 0
65 | best_epoch = None
66 | best_loss = 100000
67 |
68 | model = create_model(opt)
69 |
70 | data_time, iter_time = 0, 0
71 | current_iter = 0
72 |
73 | # create message logger (formatted outputs)
74 | msg_logger = MessageLogger(opt, current_iter, tb_logger)
75 |
76 | for epoch in range(opt['num_epochs']):
77 | lr = model.update_learning_rate(epoch)
78 |
79 | for _, batch_data in enumerate(train_loader):
80 | data_time = time.time() - data_time
81 |
82 | current_iter += 1
83 |
84 | model.optimize_parameters(batch_data, current_iter)
85 |
86 | iter_time = time.time() - iter_time
87 | if current_iter % opt['print_freq'] == 0:
88 | log_vars = {'epoch': epoch, 'iter': current_iter}
89 | log_vars.update({'lrs': [lr]})
90 | log_vars.update({'time': iter_time, 'data_time': data_time})
91 | log_vars.update(model.get_current_log())
92 | msg_logger(log_vars)
93 |
94 | data_time = time.time()
95 | iter_time = time.time()
96 |
97 | if epoch % opt['val_freq'] == 0:
98 | save_dir = f'{opt["path"]["visualization"]}/valset/epoch_{epoch:03d}' # noqa
99 | os.makedirs(save_dir, exist_ok=opt['debug'])
100 | val_loss_total, _, _ = model.inference(val_loader, save_dir)
101 |
102 | save_dir = f'{opt["path"]["visualization"]}/testset/epoch_{epoch:03d}' # noqa
103 | os.makedirs(save_dir, exist_ok=opt['debug'])
104 | test_loss_total, _, _ = model.inference(test_loader, save_dir)
105 |
106 | logger.info(f'Epoch: {epoch}, '
107 | f'val_loss_total: {val_loss_total}, '
108 | f'test_loss_total: {test_loss_total}.')
109 |
110 | if test_loss_total < best_loss:
111 | best_epoch = epoch
112 | best_loss = test_loss_total
113 |
114 | logger.info(f'Best epoch: {best_epoch}, '
115 | f'Best test loss: {best_loss: .4f}.')
116 |
117 | # save model
118 | model.save_network(f'{opt["path"]["models"]}/epoch{epoch}.pth')
119 |
120 |
121 | if __name__ == '__main__':
122 | main()
123 |
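124 | # Example invocation (a sketch; the parsing-token config path is assumed):
125 | #   python train_parsing_token.py -opt ./configs/parsing_token.yml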
--------------------------------------------------------------------------------
/train_sampler.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | import os
4 | import os.path as osp
5 | import random
6 | import time
7 |
8 | import torch
9 |
10 | from data.segm_attr_dataset import DeepFashionAttrSegmDataset
11 | from models import create_model
12 | from utils.logger import MessageLogger, get_root_logger, init_tb_logger
13 | from utils.options import dict2str, dict_to_nonedict, parse
14 | from utils.util import make_exp_dirs
15 |
16 |
17 | def main():
18 | # options
19 | parser = argparse.ArgumentParser()
20 | parser.add_argument('-opt', type=str, help='Path to option YAML file.')
21 | args = parser.parse_args()
22 | opt = parse(args.opt, is_train=True)
23 |
24 | # mkdir and loggers
25 | make_exp_dirs(opt)
26 | log_file = osp.join(opt['path']['log'], f"train_{opt['name']}.log")
27 | logger = get_root_logger(
28 | logger_name='base', log_level=logging.INFO, log_file=log_file)
29 | logger.info(dict2str(opt))
30 | # initialize tensorboard logger
31 | tb_logger = None
32 | if opt['use_tb_logger'] and 'debug' not in opt['name']:
33 | tb_logger = init_tb_logger(log_dir='./tb_logger/' + opt['name'])
34 |
35 | # convert to NoneDict, which returns None for missing keys
36 | opt = dict_to_nonedict(opt)
37 |
38 | # set up data loader
39 | train_dataset = DeepFashionAttrSegmDataset(
40 | img_dir=opt['train_img_dir'],
41 | segm_dir=opt['segm_dir'],
42 | pose_dir=opt['pose_dir'],
43 | ann_dir=opt['train_ann_file'],
44 | xflip=True)
45 | train_loader = torch.utils.data.DataLoader(
46 | dataset=train_dataset,
47 | batch_size=opt['batch_size'],
48 | shuffle=True,
49 | num_workers=opt['num_workers'],
50 | persistent_workers=True,
51 | drop_last=True)
52 | logger.info(f'Number of train set: {len(train_dataset)}.')
53 | opt['max_iters'] = opt['num_epochs'] * len(
54 | train_dataset) // opt['batch_size']
55 |
56 | val_dataset = DeepFashionAttrSegmDataset(
57 | img_dir=opt['train_img_dir'],
58 | segm_dir=opt['segm_dir'],
59 | pose_dir=opt['pose_dir'],
60 | ann_dir=opt['val_ann_file'])
61 | val_loader = torch.utils.data.DataLoader(
62 | dataset=val_dataset, batch_size=opt['batch_size'], shuffle=False)
63 | logger.info(f'Number of val set: {len(val_dataset)}.')
64 |
65 | test_dataset = DeepFashionAttrSegmDataset(
66 | img_dir=opt['test_img_dir'],
67 | segm_dir=opt['segm_dir'],
68 | pose_dir=opt['pose_dir'],
69 | ann_dir=opt['test_ann_file'])
70 | test_loader = torch.utils.data.DataLoader(
71 | dataset=test_dataset, batch_size=opt['batch_size'], shuffle=False)
72 | logger.info(f'Number of test set: {len(test_dataset)}.')
73 |
74 | current_iter = 0
75 |
76 | model = create_model(opt)
77 |
78 | data_time, iter_time = 0, 0
79 | current_iter = 0
80 |
81 | # create message logger (formatted outputs)
82 | msg_logger = MessageLogger(opt, current_iter, tb_logger)
83 |
84 | for epoch in range(opt['num_epochs']):
85 | lr = model.update_learning_rate(epoch, current_iter)
86 |
87 | for _, batch_data in enumerate(train_loader):
88 | data_time = time.time() - data_time
89 |
90 | current_iter += 1
91 |
92 | model.feed_data(batch_data)
93 | model.optimize_parameters()
94 |
95 | iter_time = time.time() - iter_time
96 | if current_iter % opt['print_freq'] == 0:
97 | log_vars = {'epoch': epoch, 'iter': current_iter}
98 | log_vars.update({'lrs': [lr]})
99 | log_vars.update({'time': iter_time, 'data_time': data_time})
100 | log_vars.update(model.get_current_log())
101 | msg_logger(log_vars)
102 |
103 | data_time = time.time()
104 | iter_time = time.time()
105 |
106 | if epoch % opt['val_freq'] == 0 and epoch != 0:
107 | save_dir = f'{opt["path"]["visualization"]}/valset/epoch_{epoch:03d}' # noqa
108 | os.makedirs(save_dir, exist_ok=opt['debug'])
109 | model.inference(val_loader, save_dir)
110 |
111 | save_dir = f'{opt["path"]["visualization"]}/testset/epoch_{epoch:03d}' # noqa
112 | os.makedirs(save_dir, exist_ok=opt['debug'])
113 | model.inference(test_loader, save_dir)
114 |
115 | # save model
116 | model.save_network(
117 | model._denoise_fn,
118 | f'{opt["path"]["models"]}/sampler_epoch{epoch}.pth')
119 |
120 |
121 | if __name__ == '__main__':
122 | main()
123 |
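124 | # Example invocation (a sketch; the sampler config path is assumed):
125 | #   python train_sampler.py -opt ./configs/sampler.yml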
--------------------------------------------------------------------------------
/train_vqvae.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | import os
4 | import os.path as osp
5 | import random
6 | import time
7 |
8 | import torch
9 |
10 | from data.segm_attr_dataset import DeepFashionAttrSegmDataset
11 | from models import create_model
12 | from utils.logger import MessageLogger, get_root_logger, init_tb_logger
13 | from utils.options import dict2str, dict_to_nonedict, parse
14 | from utils.util import make_exp_dirs
15 |
16 |
17 | def main():
18 | # options
19 | parser = argparse.ArgumentParser()
20 | parser.add_argument('-opt', type=str, help='Path to option YAML file.')
21 | args = parser.parse_args()
22 | opt = parse(args.opt, is_train=True)
23 |
24 | # mkdir and loggers
25 | make_exp_dirs(opt)
26 | log_file = osp.join(opt['path']['log'], f"train_{opt['name']}.log")
27 | logger = get_root_logger(
28 | logger_name='base', log_level=logging.INFO, log_file=log_file)
29 | logger.info(dict2str(opt))
30 | # initialize tensorboard logger
31 | tb_logger = None
32 | if opt['use_tb_logger'] and 'debug' not in opt['name']:
33 | tb_logger = init_tb_logger(log_dir='./tb_logger/' + opt['name'])
34 |
35 | # convert to NoneDict, which returns None for missing keys
36 | opt = dict_to_nonedict(opt)
37 |
38 | # set up data loader
39 | train_dataset = DeepFashionAttrSegmDataset(
40 | img_dir=opt['train_img_dir'],
41 | segm_dir=opt['segm_dir'],
42 | pose_dir=opt['pose_dir'],
43 | ann_dir=opt['train_ann_file'],
44 | xflip=True)
45 | train_loader = torch.utils.data.DataLoader(
46 | dataset=train_dataset,
47 | batch_size=opt['batch_size'],
48 | shuffle=True,
49 | num_workers=opt['num_workers'],
50 | persistent_workers=True,
51 | drop_last=True)
52 | logger.info(f'Number of train set: {len(train_dataset)}.')
53 | opt['max_iters'] = opt['num_epochs'] * len(
54 | train_dataset) // opt['batch_size']
55 |
56 | val_dataset = DeepFashionAttrSegmDataset(
57 | img_dir=opt['train_img_dir'],
58 | segm_dir=opt['segm_dir'],
59 | pose_dir=opt['pose_dir'],
60 | ann_dir=opt['val_ann_file'])
61 | val_loader = torch.utils.data.DataLoader(
62 | dataset=val_dataset, batch_size=1, shuffle=False)
63 | logger.info(f'Number of val set: {len(val_dataset)}.')
64 |
65 | test_dataset = DeepFashionAttrSegmDataset(
66 | img_dir=opt['test_img_dir'],
67 | segm_dir=opt['segm_dir'],
68 | pose_dir=opt['pose_dir'],
69 | ann_dir=opt['test_ann_file'])
70 | test_loader = torch.utils.data.DataLoader(
71 | dataset=test_dataset, batch_size=1, shuffle=False)
72 | logger.info(f'Number of test set: {len(test_dataset)}.')
73 |
74 | current_iter = 0
75 | best_epoch = None
76 | best_loss = 100000
77 |
78 | model = create_model(opt)
79 |
80 | data_time, iter_time = 0, 0
81 | current_iter = 0
82 |
83 | # create message logger (formatted outputs)
84 | msg_logger = MessageLogger(opt, current_iter, tb_logger)
85 |
86 | for epoch in range(opt['num_epochs']):
87 | lr = model.update_learning_rate(epoch)
88 |
89 | for _, batch_data in enumerate(train_loader):
90 | data_time = time.time() - data_time
91 |
92 | current_iter += 1
93 |
94 | model.optimize_parameters(batch_data, current_iter)
95 |
96 | iter_time = time.time() - iter_time
97 | if current_iter % opt['print_freq'] == 0:
98 | log_vars = {'epoch': epoch, 'iter': current_iter}
99 | log_vars.update({'lrs': [lr]})
100 | log_vars.update({'time': iter_time, 'data_time': data_time})
101 | log_vars.update(model.get_current_log())
102 | msg_logger(log_vars)
103 |
104 | data_time = time.time()
105 | iter_time = time.time()
106 |
107 | if epoch % opt['val_freq'] == 0:
108 | save_dir = f'{opt["path"]["visualization"]}/valset/epoch_{epoch:03d}' # noqa
109 | os.makedirs(save_dir, exist_ok=opt['debug'])
110 | val_loss_total = model.inference(val_loader, save_dir)
111 |
112 | save_dir = f'{opt["path"]["visualization"]}/testset/epoch_{epoch:03d}' # noqa
113 | os.makedirs(save_dir, exist_ok=opt['debug'])
114 | test_loss_total = model.inference(test_loader, save_dir)
115 |
116 | logger.info(f'Epoch: {epoch}, '
117 | f'val_loss_total: {val_loss_total}, '
118 | f'test_loss_total: {test_loss_total}.')
119 |
120 | if test_loss_total < best_loss:
121 | best_epoch = epoch
122 | best_loss = test_loss_total
123 |
124 | logger.info(f'Best epoch: {best_epoch}, '
125 | f'Best test loss: {best_loss: .4f}.')
126 |
127 | # save model
128 | model.save_network(f'{opt["path"]["models"]}/epoch{epoch}.pth')
129 |
130 |
131 | if __name__ == '__main__':
132 | main()
133 |
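134 | # Example invocation (a sketch; assumes the VQVAE option YAMLs live under ./configs/):
135 | #   python train_vqvae.py -opt ./configs/vqvae_top.yml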
--------------------------------------------------------------------------------
/ui/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yumingj/Text2Human/9a66cbd4a628c144514810ce02fbc4c9924a8d9f/ui/__init__.py
--------------------------------------------------------------------------------
/ui/color_blocks/class_bag.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yumingj/Text2Human/9a66cbd4a628c144514810ce02fbc4c9924a8d9f/ui/color_blocks/class_bag.png
--------------------------------------------------------------------------------
/ui/color_blocks/class_belt.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yumingj/Text2Human/9a66cbd4a628c144514810ce02fbc4c9924a8d9f/ui/color_blocks/class_belt.png
--------------------------------------------------------------------------------
/ui/color_blocks/class_bg.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yumingj/Text2Human/9a66cbd4a628c144514810ce02fbc4c9924a8d9f/ui/color_blocks/class_bg.png
--------------------------------------------------------------------------------
/ui/color_blocks/class_dress.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yumingj/Text2Human/9a66cbd4a628c144514810ce02fbc4c9924a8d9f/ui/color_blocks/class_dress.png
--------------------------------------------------------------------------------
/ui/color_blocks/class_earstuds.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yumingj/Text2Human/9a66cbd4a628c144514810ce02fbc4c9924a8d9f/ui/color_blocks/class_earstuds.png
--------------------------------------------------------------------------------
/ui/color_blocks/class_eyeglass.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yumingj/Text2Human/9a66cbd4a628c144514810ce02fbc4c9924a8d9f/ui/color_blocks/class_eyeglass.png
--------------------------------------------------------------------------------
/ui/color_blocks/class_face.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yumingj/Text2Human/9a66cbd4a628c144514810ce02fbc4c9924a8d9f/ui/color_blocks/class_face.png
--------------------------------------------------------------------------------
/ui/color_blocks/class_footwear.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yumingj/Text2Human/9a66cbd4a628c144514810ce02fbc4c9924a8d9f/ui/color_blocks/class_footwear.png
--------------------------------------------------------------------------------
/ui/color_blocks/class_glove.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yumingj/Text2Human/9a66cbd4a628c144514810ce02fbc4c9924a8d9f/ui/color_blocks/class_glove.png
--------------------------------------------------------------------------------
/ui/color_blocks/class_hair.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yumingj/Text2Human/9a66cbd4a628c144514810ce02fbc4c9924a8d9f/ui/color_blocks/class_hair.png
--------------------------------------------------------------------------------
/ui/color_blocks/class_headwear.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yumingj/Text2Human/9a66cbd4a628c144514810ce02fbc4c9924a8d9f/ui/color_blocks/class_headwear.png
--------------------------------------------------------------------------------
/ui/color_blocks/class_leggings.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yumingj/Text2Human/9a66cbd4a628c144514810ce02fbc4c9924a8d9f/ui/color_blocks/class_leggings.png
--------------------------------------------------------------------------------
/ui/color_blocks/class_necklace.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yumingj/Text2Human/9a66cbd4a628c144514810ce02fbc4c9924a8d9f/ui/color_blocks/class_necklace.png
--------------------------------------------------------------------------------
/ui/color_blocks/class_neckwear.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yumingj/Text2Human/9a66cbd4a628c144514810ce02fbc4c9924a8d9f/ui/color_blocks/class_neckwear.png
--------------------------------------------------------------------------------
/ui/color_blocks/class_outer.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yumingj/Text2Human/9a66cbd4a628c144514810ce02fbc4c9924a8d9f/ui/color_blocks/class_outer.png
--------------------------------------------------------------------------------
/ui/color_blocks/class_pants.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yumingj/Text2Human/9a66cbd4a628c144514810ce02fbc4c9924a8d9f/ui/color_blocks/class_pants.png
--------------------------------------------------------------------------------
/ui/color_blocks/class_ring.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yumingj/Text2Human/9a66cbd4a628c144514810ce02fbc4c9924a8d9f/ui/color_blocks/class_ring.png
--------------------------------------------------------------------------------
/ui/color_blocks/class_rompers.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yumingj/Text2Human/9a66cbd4a628c144514810ce02fbc4c9924a8d9f/ui/color_blocks/class_rompers.png
--------------------------------------------------------------------------------
/ui/color_blocks/class_skin.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yumingj/Text2Human/9a66cbd4a628c144514810ce02fbc4c9924a8d9f/ui/color_blocks/class_skin.png
--------------------------------------------------------------------------------
/ui/color_blocks/class_skirt.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yumingj/Text2Human/9a66cbd4a628c144514810ce02fbc4c9924a8d9f/ui/color_blocks/class_skirt.png
--------------------------------------------------------------------------------
/ui/color_blocks/class_socks.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yumingj/Text2Human/9a66cbd4a628c144514810ce02fbc4c9924a8d9f/ui/color_blocks/class_socks.png
--------------------------------------------------------------------------------
/ui/color_blocks/class_tie.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yumingj/Text2Human/9a66cbd4a628c144514810ce02fbc4c9924a8d9f/ui/color_blocks/class_tie.png
--------------------------------------------------------------------------------
/ui/color_blocks/class_top.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yumingj/Text2Human/9a66cbd4a628c144514810ce02fbc4c9924a8d9f/ui/color_blocks/class_top.png
--------------------------------------------------------------------------------
/ui/color_blocks/class_wrist.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yumingj/Text2Human/9a66cbd4a628c144514810ce02fbc4c9924a8d9f/ui/color_blocks/class_wrist.png
--------------------------------------------------------------------------------
/ui/icons/icon_palette.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yumingj/Text2Human/9a66cbd4a628c144514810ce02fbc4c9924a8d9f/ui/icons/icon_palette.png
--------------------------------------------------------------------------------
/ui/icons/icon_title.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yumingj/Text2Human/9a66cbd4a628c144514810ce02fbc4c9924a8d9f/ui/icons/icon_title.png
--------------------------------------------------------------------------------
/ui/mouse_event.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import numpy as np
4 | from PyQt5.QtCore import *
5 | from PyQt5.QtGui import *
6 | from PyQt5.QtWidgets import *
7 |
8 | color_list = [
9 | QColor(0, 0, 0),
10 | QColor(255, 250, 250),
11 | QColor(220, 220, 220),
12 | QColor(250, 235, 215),
13 | QColor(255, 250, 205),
14 | QColor(211, 211, 211),
15 | QColor(70, 130, 180),
16 | QColor(127, 255, 212),
17 | QColor(0, 100, 0),
18 | QColor(50, 205, 50),
19 | QColor(255, 255, 0),
20 | QColor(245, 222, 179),
21 | QColor(255, 140, 0),
22 | QColor(255, 0, 0),
23 | QColor(16, 78, 139),
24 | QColor(144, 238, 144),
25 | QColor(50, 205, 174),
26 | QColor(50, 155, 250),
27 | QColor(160, 140, 88),
28 | QColor(213, 140, 88),
29 | QColor(90, 140, 90),
30 | QColor(185, 210, 205),
31 | QColor(130, 165, 180),
32 | QColor(225, 141, 151)
33 | ]
34 |
35 |
36 | class GraphicsScene(QGraphicsScene):
37 |
38 | def __init__(self, mode, size, parent=None):
39 | QGraphicsScene.__init__(self, parent)
40 | self.mode = mode
41 | self.size = size
42 | self.mouse_clicked = False
43 | self.prev_pt = None
44 |
45 | # self.masked_image = None
46 |
47 | # save the points
48 | self.mask_points = []
49 | for i in range(len(color_list)):
50 | self.mask_points.append([])
51 |
52 | # save the size of points
53 | self.size_points = []
54 | for i in range(len(color_list)):
55 | self.size_points.append([])
56 |
57 | # save the history of edit
58 | self.history = []
59 |
60 | def reset(self):
61 | # save the points
62 | self.mask_points = []
63 | for i in range(len(color_list)):
64 | self.mask_points.append([])
65 | # save the size of points
66 | self.size_points = []
67 | for i in range(len(color_list)):
68 | self.size_points.append([])
69 | # save the history of edit
70 | self.history = []
71 |
72 | self.mode = 0
73 | self.prev_pt = None
74 |
75 | def mousePressEvent(self, event):
76 | self.mouse_clicked = True
77 |
78 | def mouseReleaseEvent(self, event):
79 | self.prev_pt = None
80 | self.mouse_clicked = False
81 |
82 | def mouseMoveEvent(self, event): # drawing
83 | if self.mouse_clicked:
84 | if self.prev_pt:
85 | self.drawMask(self.prev_pt, event.scenePos(),
86 | color_list[self.mode], self.size)
87 | pts = {}
88 | pts['prev'] = (int(self.prev_pt.x()), int(self.prev_pt.y()))
89 | pts['curr'] = (int(event.scenePos().x()),
90 | int(event.scenePos().y()))
91 |
92 | self.size_points[self.mode].append(self.size)
93 | self.mask_points[self.mode].append(pts)
94 | self.history.append(self.mode)
95 | self.prev_pt = event.scenePos()
96 | else:
97 | self.prev_pt = event.scenePos()
98 |
99 | def drawMask(self, prev_pt, curr_pt, color, size):
100 | lineItem = QGraphicsLineItem(QLineF(prev_pt, curr_pt))
101 | lineItem.setPen(QPen(color, size, Qt.SolidLine)) # rect
102 | self.addItem(lineItem)
103 |
104 | def erase_prev_pt(self):
105 | self.prev_pt = None
106 |
107 | def reset_items(self):
108 | for i in range(len(self.items())):
109 | item = self.items()[0]
110 | self.removeItem(item)
111 |
112 | def undo(self):
113 | if len(self.items()) > 1:
114 | if len(self.items()) >= 9:
115 | for i in range(8):
116 | item = self.items()[0]
117 | self.removeItem(item)
118 | if self.history[-1] == self.mode:
119 | self.mask_points[self.mode].pop()
120 | self.size_points[self.mode].pop()
121 | self.history.pop()
122 | else:
123 | for i in range(len(self.items()) - 1):
124 | item = self.items()[0]
125 | self.removeItem(item)
126 | if self.history[-1] == self.mode:
127 | self.mask_points[self.mode].pop()
128 | self.size_points[self.mode].pop()
129 | self.history.pop()
130 |
--------------------------------------------------------------------------------
/ui/ui.py:
--------------------------------------------------------------------------------
1 | from PyQt5 import QtCore, QtGui, QtWidgets
2 | from PyQt5.QtCore import *
3 | from PyQt5.QtGui import *
4 | from PyQt5.QtWidgets import *
5 |
6 |
7 | class Ui_Form(object):
8 |
9 | def setupUi(self, Form):
10 | Form.setObjectName("Form")
11 | Form.resize(1250, 670)
12 |
13 | self.pushButton_2 = QtWidgets.QPushButton(Form)
14 | self.pushButton_2.setGeometry(QtCore.QRect(20, 60, 97, 27))
15 | self.pushButton_2.setObjectName("pushButton_2")
16 |
17 | self.pushButton_6 = QtWidgets.QPushButton(Form)
18 | self.pushButton_6.setGeometry(QtCore.QRect(20, 100, 97, 27))
19 | self.pushButton_6.setObjectName("pushButton_6")
20 |
21 | # Generate Parsing
22 | self.pushButton_0 = QtWidgets.QPushButton(Form)
23 | self.pushButton_0.setGeometry(QtCore.QRect(126, 60, 150, 27))
24 | self.pushButton_0.setObjectName("pushButton_0")
25 |
26 | # Generate Human
27 | self.pushButton_1 = QtWidgets.QPushButton(Form)
28 | self.pushButton_1.setGeometry(QtCore.QRect(126, 100, 150, 27))
29 | self.pushButton_1.setObjectName("pushButton_1")
30 |
31 | # shape text box
32 | self.label_heading_1 = QtWidgets.QLabel(Form)
33 | self.label_heading_1.setText('Describe the shape.')
34 | self.label_heading_1.setObjectName("label_heading_1")
35 | self.label_heading_1.setGeometry(QtCore.QRect(320, 20, 200, 20))
36 |
37 | self.message_box_1 = QtWidgets.QLineEdit(Form)
38 | self.message_box_1.setGeometry(QtCore.QRect(320, 50, 256, 80))
39 | self.message_box_1.setObjectName("message_box_1")
40 | self.message_box_1.setAlignment(Qt.AlignTop)
41 |
42 | # texture text box
43 | self.label_heading_2 = QtWidgets.QLabel(Form)
44 | self.label_heading_2.setText('Describe the textures.')
45 | self.label_heading_2.setObjectName("label_heading_2")
46 | self.label_heading_2.setGeometry(QtCore.QRect(620, 20, 200, 20))
47 |
48 | self.message_box_2 = QtWidgets.QLineEdit(Form)
49 | self.message_box_2.setGeometry(QtCore.QRect(620, 50, 256, 80))
50 | self.message_box_2.setObjectName("message_box_2")
51 | self.message_box_2.setAlignment(Qt.AlignTop)
52 |
53 | # title icon
54 | self.title_icon = QtWidgets.QLabel(Form)
55 | self.title_icon.setGeometry(QtCore.QRect(30, 10, 200, 50))
56 | self.title_icon.setPixmap(
57 | QtGui.QPixmap('./ui/icons/icon_title.png').scaledToWidth(200))
58 |
59 | # palette icon
60 | self.palette_icon = QtWidgets.QLabel(Form)
61 | self.palette_icon.setGeometry(QtCore.QRect(950, 10, 256, 128))
62 | self.palette_icon.setPixmap(
63 | QtGui.QPixmap('./ui/icons/icon_palette.png').scaledToWidth(256))
64 |
65 | # top
66 | self.pushButton_8 = QtWidgets.QPushButton(' top', Form)
67 | self.pushButton_8.setGeometry(QtCore.QRect(940, 120, 120, 27))
68 | self.pushButton_8.setObjectName("pushButton_8")
69 | self.pushButton_8.setStyleSheet(
70 | "text-align: left; padding-left: 10px;")
71 | self.pushButton_8.setIcon(QIcon('./ui/color_blocks/class_top.png'))
72 | # skin
73 | self.pushButton_9 = QtWidgets.QPushButton(' skin', Form)
74 | self.pushButton_9.setGeometry(QtCore.QRect(940, 165, 120, 27))
75 | self.pushButton_9.setObjectName("pushButton_9")
76 | self.pushButton_9.setStyleSheet(
77 | "text-align: left; padding-left: 10px;")
78 | self.pushButton_9.setIcon(QIcon('./ui/color_blocks/class_skin.png'))
79 | # outer
80 | self.pushButton_10 = QtWidgets.QPushButton(' outer', Form)
81 | self.pushButton_10.setGeometry(QtCore.QRect(940, 210, 120, 27))
82 | self.pushButton_10.setObjectName("pushButton_10")
83 | self.pushButton_10.setStyleSheet(
84 | "text-align: left; padding-left: 10px;")
85 | self.pushButton_10.setIcon(QIcon('./ui/color_blocks/class_outer.png'))
86 | # face
87 | self.pushButton_11 = QtWidgets.QPushButton(' face', Form)
88 | self.pushButton_11.setGeometry(QtCore.QRect(940, 255, 120, 27))
89 | self.pushButton_11.setObjectName("pushButton_11")
90 | self.pushButton_11.setStyleSheet(
91 | "text-align: left; padding-left: 10px;")
92 | self.pushButton_11.setIcon(QIcon('./ui/color_blocks/class_face.png'))
93 | # skirt
94 | self.pushButton_12 = QtWidgets.QPushButton(' skirt', Form)
95 | self.pushButton_12.setGeometry(QtCore.QRect(940, 300, 120, 27))
96 | self.pushButton_12.setObjectName("pushButton_12")
97 | self.pushButton_12.setStyleSheet(
98 | "text-align: left; padding-left: 10px;")
99 | self.pushButton_12.setIcon(QIcon('./ui/color_blocks/class_skirt.png'))
100 | # hair
101 | self.pushButton_13 = QtWidgets.QPushButton(' hair', Form)
102 | self.pushButton_13.setGeometry(QtCore.QRect(940, 345, 120, 27))
103 | self.pushButton_13.setObjectName("pushButton_13")
104 | self.pushButton_13.setStyleSheet(
105 | "text-align: left; padding-left: 10px;")
106 | self.pushButton_13.setIcon(QIcon('./ui/color_blocks/class_hair.png'))
107 | # dress
108 | self.pushButton_14 = QtWidgets.QPushButton(' dress', Form)
109 | self.pushButton_14.setGeometry(QtCore.QRect(940, 390, 120, 27))
110 | self.pushButton_14.setObjectName("pushButton_14")
111 | self.pushButton_14.setStyleSheet(
112 | "text-align: left; padding-left: 10px;")
113 | self.pushButton_14.setIcon(QIcon('./ui/color_blocks/class_dress.png'))
114 | # headwear
115 | self.pushButton_15 = QtWidgets.QPushButton(' headwear', Form)
116 | self.pushButton_15.setGeometry(QtCore.QRect(940, 435, 120, 27))
117 | self.pushButton_15.setObjectName("pushButton_15")
118 | self.pushButton_15.setStyleSheet(
119 | "text-align: left; padding-left: 10px;")
120 | self.pushButton_15.setIcon(
121 | QIcon('./ui/color_blocks/class_headwear.png'))
122 | # pants
123 | self.pushButton_16 = QtWidgets.QPushButton(' pants', Form)
124 | self.pushButton_16.setGeometry(QtCore.QRect(940, 480, 120, 27))
125 | self.pushButton_16.setObjectName("pushButton_16")
126 | self.pushButton_16.setStyleSheet(
127 | "text-align: left; padding-left: 10px;")
128 | self.pushButton_16.setIcon(QIcon('./ui/color_blocks/class_pants.png'))
129 | # eyeglasses
130 | self.pushButton_17 = QtWidgets.QPushButton(' eyeglass', Form)
131 | self.pushButton_17.setGeometry(QtCore.QRect(940, 525, 120, 27))
132 | self.pushButton_17.setObjectName("pushButton_17")
133 | self.pushButton_17.setStyleSheet(
134 | "text-align: left; padding-left: 10px;")
135 | self.pushButton_17.setIcon(
136 | QIcon('./ui/color_blocks/class_eyeglass.png'))
137 | # rompers
138 | self.pushButton_18 = QtWidgets.QPushButton(' rompers', Form)
139 | self.pushButton_18.setGeometry(QtCore.QRect(940, 570, 120, 27))
140 | self.pushButton_18.setObjectName("pushButton_18")
141 | self.pushButton_18.setStyleSheet(
142 | "text-align: left; padding-left: 10px;")
143 | self.pushButton_18.setIcon(
144 | QIcon('./ui/color_blocks/class_rompers.png'))
145 | # footwear
146 | self.pushButton_19 = QtWidgets.QPushButton(' footwear', Form)
147 | self.pushButton_19.setGeometry(QtCore.QRect(940, 615, 120, 27))
148 | self.pushButton_19.setObjectName("pushButton_19")
149 | self.pushButton_19.setStyleSheet(
150 | "text-align: left; padding-left: 10px;")
151 | self.pushButton_19.setIcon(
152 | QIcon('./ui/color_blocks/class_footwear.png'))
153 |
154 | # leggings
155 | self.pushButton_20 = QtWidgets.QPushButton(' leggings', Form)
156 | self.pushButton_20.setGeometry(QtCore.QRect(1100, 120, 120, 27))
157 |         self.pushButton_20.setObjectName("pushButton_20")
158 | self.pushButton_20.setStyleSheet(
159 | "text-align: left; padding-left: 10px;")
160 | self.pushButton_20.setIcon(
161 | QIcon('./ui/color_blocks/class_leggings.png'))
162 |
163 | # ring
164 | self.pushButton_21 = QtWidgets.QPushButton(' ring', Form)
165 | self.pushButton_21.setGeometry(QtCore.QRect(1100, 165, 120, 27))
166 |         self.pushButton_21.setObjectName("pushButton_21")
167 | self.pushButton_21.setStyleSheet(
168 | "text-align: left; padding-left: 10px;")
169 | self.pushButton_21.setIcon(QIcon('./ui/color_blocks/class_ring.png'))
170 |
171 | # belt
172 | self.pushButton_22 = QtWidgets.QPushButton(' belt', Form)
173 | self.pushButton_22.setGeometry(QtCore.QRect(1100, 210, 120, 27))
174 |         self.pushButton_22.setObjectName("pushButton_22")
175 | self.pushButton_22.setStyleSheet(
176 | "text-align: left; padding-left: 10px;")
177 | self.pushButton_22.setIcon(QIcon('./ui/color_blocks/class_belt.png'))
178 |
179 | # neckwear
180 | self.pushButton_23 = QtWidgets.QPushButton(' neckwear', Form)
181 | self.pushButton_23.setGeometry(QtCore.QRect(1100, 255, 120, 27))
182 |         self.pushButton_23.setObjectName("pushButton_23")
183 | self.pushButton_23.setStyleSheet(
184 | "text-align: left; padding-left: 10px;")
185 | self.pushButton_23.setIcon(
186 | QIcon('./ui/color_blocks/class_neckwear.png'))
187 |
188 | # wrist
189 | self.pushButton_24 = QtWidgets.QPushButton(' wrist', Form)
190 | self.pushButton_24.setGeometry(QtCore.QRect(1100, 300, 120, 27))
191 |         self.pushButton_24.setObjectName("pushButton_24")
192 | self.pushButton_24.setStyleSheet(
193 | "text-align: left; padding-left: 10px;")
194 | self.pushButton_24.setIcon(QIcon('./ui/color_blocks/class_wrist.png'))
195 |
196 | # socks
197 | self.pushButton_25 = QtWidgets.QPushButton(' socks', Form)
198 | self.pushButton_25.setGeometry(QtCore.QRect(1100, 345, 120, 27))
199 |         self.pushButton_25.setObjectName("pushButton_25")
200 | self.pushButton_25.setStyleSheet(
201 | "text-align: left; padding-left: 10px;")
202 | self.pushButton_25.setIcon(QIcon('./ui/color_blocks/class_socks.png'))
203 |
204 | # tie
205 | self.pushButton_26 = QtWidgets.QPushButton(' tie', Form)
206 | self.pushButton_26.setGeometry(QtCore.QRect(1100, 390, 120, 27))
207 |         self.pushButton_26.setObjectName("pushButton_26")
208 | self.pushButton_26.setStyleSheet(
209 | "text-align: left; padding-left: 10px;")
210 | self.pushButton_26.setIcon(QIcon('./ui/color_blocks/class_tie.png'))
211 |
212 |         # necklace
213 | self.pushButton_27 = QtWidgets.QPushButton(' necklace', Form)
214 | self.pushButton_27.setGeometry(QtCore.QRect(1100, 435, 120, 27))
215 |         self.pushButton_27.setObjectName("pushButton_27")
216 | self.pushButton_27.setStyleSheet(
217 | "text-align: left; padding-left: 10px;")
218 | self.pushButton_27.setIcon(
219 | QIcon('./ui/color_blocks/class_necklace.png'))
220 |
221 |         # earstuds
222 | self.pushButton_28 = QtWidgets.QPushButton(' earstuds', Form)
223 | self.pushButton_28.setGeometry(QtCore.QRect(1100, 480, 120, 27))
224 |         self.pushButton_28.setObjectName("pushButton_28")
225 | self.pushButton_28.setStyleSheet(
226 | "text-align: left; padding-left: 10px;")
227 | self.pushButton_28.setIcon(
228 | QIcon('./ui/color_blocks/class_earstuds.png'))
229 |
230 | # bag
231 | self.pushButton_29 = QtWidgets.QPushButton(' bag', Form)
232 | self.pushButton_29.setGeometry(QtCore.QRect(1100, 525, 120, 27))
233 | self.pushButton_29.setObjectName("pushButton_2`0`")
234 | self.pushButton_29.setStyleSheet(
235 | "text-align: left; padding-left: 10px;")
236 | self.pushButton_29.setIcon(QIcon('./ui/color_blocks/class_bag.png'))
237 |
238 | # glove
239 | self.pushButton_30 = QtWidgets.QPushButton(' glove', Form)
240 | self.pushButton_30.setGeometry(QtCore.QRect(1100, 570, 120, 27))
241 |         self.pushButton_30.setObjectName("pushButton_30")
242 | self.pushButton_30.setStyleSheet(
243 | "text-align: left; padding-left: 10px;")
244 | self.pushButton_30.setIcon(QIcon('./ui/color_blocks/class_glove.png'))
245 |
246 | # background
247 | self.pushButton_31 = QtWidgets.QPushButton(' background', Form)
248 | self.pushButton_31.setGeometry(QtCore.QRect(1100, 615, 120, 27))
249 |         self.pushButton_31.setObjectName("pushButton_31")
250 | self.pushButton_31.setStyleSheet(
251 | "text-align: left; padding-left: 10px;")
252 | self.pushButton_31.setIcon(QIcon('./ui/color_blocks/class_bg.png'))
253 |
254 | self.graphicsView = QtWidgets.QGraphicsView(Form)
255 | self.graphicsView.setGeometry(QtCore.QRect(20, 140, 256, 512))
256 | self.graphicsView.setObjectName("graphicsView")
257 | self.graphicsView_2 = QtWidgets.QGraphicsView(Form)
258 | self.graphicsView_2.setGeometry(QtCore.QRect(320, 140, 256, 512))
259 | self.graphicsView_2.setObjectName("graphicsView_2")
260 | self.graphicsView_3 = QtWidgets.QGraphicsView(Form)
261 | self.graphicsView_3.setGeometry(QtCore.QRect(620, 140, 256, 512))
262 | self.graphicsView_3.setObjectName("graphicsView_3")
263 |
264 | self.retranslateUi(Form)
265 | self.pushButton_2.clicked.connect(Form.open_densepose)
266 | self.pushButton_6.clicked.connect(Form.save_img)
267 | self.pushButton_8.clicked.connect(Form.top_mode)
268 | self.pushButton_9.clicked.connect(Form.skin_mode)
269 | self.pushButton_10.clicked.connect(Form.outer_mode)
270 | self.pushButton_11.clicked.connect(Form.face_mode)
271 | self.pushButton_12.clicked.connect(Form.skirt_mode)
272 | self.pushButton_13.clicked.connect(Form.hair_mode)
273 | self.pushButton_14.clicked.connect(Form.dress_mode)
274 | self.pushButton_15.clicked.connect(Form.headwear_mode)
275 | self.pushButton_16.clicked.connect(Form.pants_mode)
276 | self.pushButton_17.clicked.connect(Form.eyeglass_mode)
277 | self.pushButton_18.clicked.connect(Form.rompers_mode)
278 | self.pushButton_19.clicked.connect(Form.footwear_mode)
279 | self.pushButton_20.clicked.connect(Form.leggings_mode)
280 | self.pushButton_21.clicked.connect(Form.ring_mode)
281 | self.pushButton_22.clicked.connect(Form.belt_mode)
282 | self.pushButton_23.clicked.connect(Form.neckwear_mode)
283 | self.pushButton_24.clicked.connect(Form.wrist_mode)
284 | self.pushButton_25.clicked.connect(Form.socks_mode)
285 | self.pushButton_26.clicked.connect(Form.tie_mode)
286 |         self.pushButton_27.clicked.connect(Form.necklace_mode)
287 |         self.pushButton_28.clicked.connect(Form.earstuds_mode)
288 | self.pushButton_29.clicked.connect(Form.bag_mode)
289 | self.pushButton_30.clicked.connect(Form.glove_mode)
290 | self.pushButton_31.clicked.connect(Form.background_mode)
291 | self.pushButton_0.clicked.connect(Form.generate_parsing)
292 | self.pushButton_1.clicked.connect(Form.generate_human)
293 |
294 | QtCore.QMetaObject.connectSlotsByName(Form)
295 |
296 | def retranslateUi(self, Form):
297 | _translate = QtCore.QCoreApplication.translate
298 | Form.setWindowTitle(_translate("Form", "Text2Human"))
299 | self.pushButton_2.setText(_translate("Form", "Load Pose"))
300 | self.pushButton_6.setText(_translate("Form", "Save Image"))
301 |
302 | self.pushButton_0.setText(_translate("Form", "Generate Parsing"))
303 | self.pushButton_1.setText(_translate("Form", "Generate Human"))
304 |
305 |
306 | if __name__ == "__main__":
307 | import sys
308 | app = QtWidgets.QApplication(sys.argv)
309 | Form = QtWidgets.QWidget()
310 | ui = Ui_Form()
311 | ui.setupUi(Form)
312 | Form.show()
313 | sys.exit(app.exec_())
314 |
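315 | # Note: the slots wired up in setupUi (open_densepose, generate_parsing,
316 | # generate_human, the *_mode handlers, ...) are expected to be provided by the
317 | # widget passed in as Form; in this repo that is the Ex class in ui_demo.py,
318 | # which mixes in Ui_Form.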
--------------------------------------------------------------------------------
/ui_demo.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | import cv2
4 | import numpy as np
5 | import torch
6 | from PIL import Image
7 | from PyQt5.QtCore import *
8 | from PyQt5.QtGui import *
9 | from PyQt5.QtWidgets import *
10 |
11 | from models.sample_model import SampleFromPoseModel
12 | from ui.mouse_event import GraphicsScene
13 | from ui.ui import Ui_Form
14 | from utils.language_utils import (generate_shape_attributes,
15 | generate_texture_attributes)
16 | from utils.options import dict_to_nonedict, parse
17 |
18 | color_list = [(0, 0, 0), (255, 250, 250), (220, 220, 220), (250, 235, 215),
19 | (255, 250, 205), (211, 211, 211), (70, 130, 180),
20 | (127, 255, 212), (0, 100, 0), (50, 205, 50), (255, 255, 0),
21 | (245, 222, 179), (255, 140, 0), (255, 0, 0), (16, 78, 139),
22 | (144, 238, 144), (50, 205, 174), (50, 155, 250), (160, 140, 88),
23 | (213, 140, 88), (90, 140, 90), (185, 210, 205), (130, 165, 180),
24 | (225, 141, 151)]
25 |
26 |
27 | class Ex(QWidget, Ui_Form):
28 |
29 | def __init__(self, opt):
30 | super(Ex, self).__init__()
31 | self.setupUi(self)
32 | self.show()
33 |
34 | self.output_img = None
35 |
36 | self.mat_img = None
37 |
38 | self.mode = 0
39 | self.size = 6
40 | self.mask = None
41 | self.mask_m = None
42 | self.img = None
43 |
44 | # about UI
45 | self.mouse_clicked = False
46 | self.scene = QGraphicsScene()
47 | self.graphicsView.setScene(self.scene)
48 | self.graphicsView.setAlignment(Qt.AlignTop | Qt.AlignLeft)
49 | self.graphicsView.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
50 | self.graphicsView.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
51 |
52 | self.ref_scene = GraphicsScene(self.mode, self.size)
53 | self.graphicsView_2.setScene(self.ref_scene)
54 | self.graphicsView_2.setAlignment(Qt.AlignTop | Qt.AlignLeft)
55 | self.graphicsView_2.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
56 | self.graphicsView_2.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
57 |
58 | self.result_scene = QGraphicsScene()
59 | self.graphicsView_3.setScene(self.result_scene)
60 | self.graphicsView_3.setAlignment(Qt.AlignTop | Qt.AlignLeft)
61 | self.graphicsView_3.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
62 | self.graphicsView_3.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
63 |
64 | self.dlg = QColorDialog(self.graphicsView)
65 | self.color = None
66 |
67 | self.sample_model = SampleFromPoseModel(opt)
68 |
69 | def open_densepose(self):
70 | fileName, _ = QFileDialog.getOpenFileName(self, "Open File",
71 | QDir.currentPath())
72 | if fileName:
73 | image = QPixmap(fileName)
74 | mat_img = Image.open(fileName)
75 | self.pose_img = mat_img.copy()
76 | if image.isNull():
77 | QMessageBox.information(self, "Image Viewer",
78 | "Cannot load %s." % fileName)
79 | return
80 | image = image.scaled(self.graphicsView.size(),
81 | Qt.IgnoreAspectRatio)
82 |
83 | if len(self.scene.items()) > 0:
84 | self.scene.removeItem(self.scene.items()[-1])
85 | self.scene.addPixmap(image)
86 |
87 | self.ref_scene.clear()
88 | self.result_scene.clear()
89 |
90 | # load pose to model
91 | self.pose_img = np.array(
92 | self.pose_img.resize(
93 | size=(256, 512),
94 | resample=Image.LANCZOS))[:, :, 2:].transpose(
95 | 2, 0, 1).astype(np.float32)
96 | self.pose_img = self.pose_img / 12. - 1
97 |
98 | self.pose_img = torch.from_numpy(self.pose_img).unsqueeze(1)
99 |
100 | self.sample_model.feed_pose_data(self.pose_img)
101 |
102 | def generate_parsing(self):
103 | self.ref_scene.reset_items()
104 | self.ref_scene.reset()
105 |
106 | shape_texts = self.message_box_1.text()
107 |
108 | shape_attributes = generate_shape_attributes(shape_texts)
109 | shape_attributes = torch.LongTensor(shape_attributes).unsqueeze(0)
110 | self.sample_model.feed_shape_attributes(shape_attributes)
111 |
112 | self.sample_model.generate_parsing_map()
113 | self.sample_model.generate_quantized_segm()
114 |
115 | self.colored_segm = self.sample_model.palette_result(
116 | self.sample_model.segm[0].cpu())
117 |
118 | self.mask_m = cv2.cvtColor(
119 | cv2.cvtColor(self.colored_segm, cv2.COLOR_RGB2BGR),
120 | cv2.COLOR_BGR2RGB)
121 |
122 | qim = QImage(self.colored_segm.data.tobytes(),
123 | self.colored_segm.shape[1], self.colored_segm.shape[0],
124 | QImage.Format_RGB888)
125 |
126 | image = QPixmap.fromImage(qim)
127 |
128 | image = image.scaled(self.graphicsView.size(), Qt.IgnoreAspectRatio)
129 |
130 | if len(self.ref_scene.items()) > 0:
131 | self.ref_scene.removeItem(self.ref_scene.items()[-1])
132 | self.ref_scene.addPixmap(image)
133 |
134 | self.result_scene.clear()
135 |
136 | def generate_human(self):
137 | for i in range(24):
138 | self.mask_m = self.make_mask(self.mask_m,
139 | self.ref_scene.mask_points[i],
140 | self.ref_scene.size_points[i],
141 | color_list[i])
142 |
143 | seg_map = np.full(self.mask_m.shape[:-1], -1)
144 |
145 | # convert rgb to num
146 | for index, color in enumerate(color_list):
147 | seg_map[np.sum(self.mask_m == color, axis=2) == 3] = index
148 | assert (seg_map != -1).all()
149 |
150 | self.sample_model.segm = torch.from_numpy(seg_map).unsqueeze(
151 | 0).unsqueeze(0).to(self.sample_model.device)
152 | self.sample_model.generate_quantized_segm()
153 |
154 | texture_texts = self.message_box_2.text()
155 | texture_attributes = generate_texture_attributes(texture_texts)
156 |
157 | texture_attributes = torch.LongTensor(texture_attributes)
158 |
159 | self.sample_model.feed_texture_attributes(texture_attributes)
160 |
161 | self.sample_model.generate_texture_map()
162 | result = self.sample_model.sample_and_refine()
163 | result = result.permute(0, 2, 3, 1)
164 | result = result.detach().cpu().numpy()
165 | result = result * 255
166 |
167 | result = np.asarray(result[0, :, :, :], dtype=np.uint8)
168 |
169 | self.output_img = result
170 |
171 | qim = QImage(result.data.tobytes(), result.shape[1], result.shape[0],
172 | QImage.Format_RGB888)
173 | image = QPixmap.fromImage(qim)
174 |
175 | image = image.scaled(self.graphicsView.size(), Qt.IgnoreAspectRatio)
176 |
177 | if len(self.result_scene.items()) > 0:
178 | self.result_scene.removeItem(self.result_scene.items()[-1])
179 | self.result_scene.addPixmap(image)
180 |
181 | def top_mode(self):
182 | self.ref_scene.mode = 1
183 |
184 | def skin_mode(self):
185 | self.ref_scene.mode = 15
186 |
187 | def outer_mode(self):
188 | self.ref_scene.mode = 2
189 |
190 | def face_mode(self):
191 | self.ref_scene.mode = 14
192 |
193 | def skirt_mode(self):
194 | self.ref_scene.mode = 3
195 |
196 | def hair_mode(self):
197 | self.ref_scene.mode = 13
198 |
199 | def dress_mode(self):
200 | self.ref_scene.mode = 4
201 |
202 | def headwear_mode(self):
203 | self.ref_scene.mode = 7
204 |
205 | def pants_mode(self):
206 | self.ref_scene.mode = 5
207 |
208 | def eyeglass_mode(self):
209 | self.ref_scene.mode = 8
210 |
211 | def rompers_mode(self):
212 | self.ref_scene.mode = 21
213 |
214 | def footwear_mode(self):
215 | self.ref_scene.mode = 11
216 |
217 | def leggings_mode(self):
218 | self.ref_scene.mode = 6
219 |
220 | def ring_mode(self):
221 | self.ref_scene.mode = 16
222 |
223 | def belt_mode(self):
224 | self.ref_scene.mode = 10
225 |
226 | def neckwear_mode(self):
227 | self.ref_scene.mode = 9
228 |
229 | def wrist_mode(self):
230 | self.ref_scene.mode = 17
231 |
232 | def socks_mode(self):
233 | self.ref_scene.mode = 18
234 |
235 | def tie_mode(self):
236 | self.ref_scene.mode = 23
237 |
238 | def earstuds_mode(self):
239 | self.ref_scene.mode = 22
240 |
241 | def necklace_mode(self):
242 | self.ref_scene.mode = 20
243 |
244 | def bag_mode(self):
245 | self.ref_scene.mode = 12
246 |
247 | def glove_mode(self):
248 | self.ref_scene.mode = 19
249 |
250 | def background_mode(self):
251 | self.ref_scene.mode = 0
252 |
253 | def make_mask(self, mask, pts, sizes, color):
254 | if len(pts) > 0:
255 | for idx, pt in enumerate(pts):
256 | cv2.line(mask, pt['prev'], pt['curr'], color, sizes[idx])
257 | return mask
258 |
259 | def save_img(self):
260 |         if self.output_img is not None:
261 | fileName, _ = QFileDialog.getSaveFileName(self, "Save File",
262 | QDir.currentPath())
263 | cv2.imwrite(fileName + '.png', self.output_img[:, :, ::-1])
264 |
265 | def undo(self):
266 | self.scene.undo()
267 |
268 | def clear(self):
269 |
270 | self.ref_scene.reset_items()
271 | self.ref_scene.reset()
272 |
273 | self.ref_scene.clear()
274 |
275 | self.result_scene.clear()
276 |
277 |
278 | if __name__ == '__main__':
279 |
280 | app = QApplication(sys.argv)
281 | opt = './configs/sample_from_pose.yml'
282 | opt = parse(opt, is_train=False)
283 | opt = dict_to_nonedict(opt)
284 | ex = Ex(opt)
285 | sys.exit(app.exec_())
286 |
--------------------------------------------------------------------------------
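The key step in `Ex.generate_human` above is converting the edited color parsing map back into integer class labels before it is fed to the sample model. A minimal standalone sketch of that conversion, using a toy three-color palette and a synthetic mask (both illustrative, not taken from the repository):

import numpy as np

# toy palette: first three entries of the 24-color list in ui_demo.py
palette = [(0, 0, 0), (255, 250, 250), (220, 220, 220)]

# synthetic colored parsing map at the demo's 512x256 resolution
mask_rgb = np.zeros((512, 256, 3), dtype=np.uint8)
mask_rgb[:256] = palette[1]          # pretend the upper half belongs to class 1

seg_map = np.full(mask_rgb.shape[:-1], -1)
for index, color in enumerate(palette):
    # a pixel gets label `index` when all three channels match that palette color
    seg_map[np.sum(mask_rgb == color, axis=2) == 3] = index
assert (seg_map != -1).all()         # every pixel was matched to some class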
/ui_util/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yumingj/Text2Human/9a66cbd4a628c144514810ce02fbc4c9924a8d9f/ui_util/__init__.py
--------------------------------------------------------------------------------
/ui_util/config.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | import os
4 |
5 | import yaml
6 |
7 | logger = logging.getLogger()
8 |
9 | class Config(object):
10 | def __init__(self, filename=None):
11 | assert os.path.exists(filename), "ERROR: Config File doesn't exist."
12 | try:
13 | with open(filename, 'r') as f:
14 |                 self._cfg_dict = yaml.safe_load(f)
15 | # parent of IOError, OSError *and* WindowsError where available
16 | except EnvironmentError:
17 | logger.error('Please check the file with name of "%s"', filename)
18 | logger.info(' APP CONFIG '.center(80, '-'))
19 | logger.info(''.center(80, '-'))
20 |
21 | def __getattr__(self, name):
22 | value = self._cfg_dict[name]
23 | if isinstance(value, dict):
24 | value = DictAsMember(value)
25 | return value
26 |
--------------------------------------------------------------------------------
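Note that `Config.__getattr__` wraps nested dictionaries in `DictAsMember`, which is neither defined nor imported in this file. A minimal sketch of what such a helper could look like, plus hypothetical usage (the `demo.yml` path and the `model.name` keys are assumptions, not files or keys from this repository):

class DictAsMember(dict):
    # dict whose keys can also be read as attributes, recursing into sub-dicts
    def __getattr__(self, name):
        value = self[name]
        if isinstance(value, dict):
            value = DictAsMember(value)
        return value

# cfg = Config('demo.yml')      # hypothetical config file
# print(cfg.model.name)         # nested sections become attribute chains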
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yumingj/Text2Human/9a66cbd4a628c144514810ce02fbc4c9924a8d9f/utils/__init__.py
--------------------------------------------------------------------------------
/utils/language_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from sentence_transformers import SentenceTransformer, util
3 |
8 | # predefined shape text
9 | upper_length_text = [
10 | 'sleeveless', 'without sleeves', 'sleeves have been cut off', 'tank top',
11 | 'tank shirt', 'muscle shirt', 'short-sleeve', 'short sleeves',
12 | 'with short sleeves', 'medium-sleeve', 'medium sleeves',
13 | 'with medium sleeves', 'sleeves reach elbow', 'long-sleeve',
14 | 'long sleeves', 'with long sleeves'
15 | ]
16 | upper_length_attr = {
17 | 'sleeveless': 0,
18 | 'without sleeves': 0,
19 | 'sleeves have been cut off': 0,
20 | 'tank top': 0,
21 | 'tank shirt': 0,
22 | 'muscle shirt': 0,
23 | 'short-sleeve': 1,
24 | 'with short sleeves': 1,
25 | 'short sleeves': 1,
26 | 'medium-sleeve': 2,
27 | 'with medium sleeves': 2,
28 | 'medium sleeves': 2,
29 | 'sleeves reach elbow': 2,
30 | 'long-sleeve': 3,
31 | 'long sleeves': 3,
32 | 'with long sleeves': 3
33 | }
34 | lower_length_text = [
35 | 'three-point', 'medium', 'short', 'covering knee', 'cropped',
36 | 'three-quarter', 'long', 'slack', 'of long length'
37 | ]
38 | lower_length_attr = {
39 | 'three-point': 0,
40 | 'medium': 1,
41 | 'covering knee': 1,
42 | 'short': 1,
43 | 'cropped': 2,
44 | 'three-quarter': 2,
45 | 'long': 3,
46 | 'slack': 3,
47 | 'of long length': 3
48 | }
49 | socks_length_text = [
50 | 'socks', 'stocking', 'pantyhose', 'leggings', 'sheer hosiery'
51 | ]
52 | socks_length_attr = {
53 | 'socks': 0,
54 | 'stocking': 1,
55 | 'pantyhose': 1,
56 | 'leggings': 1,
57 | 'sheer hosiery': 1
58 | }
59 | hat_text = ['hat', 'cap', 'chapeau']
60 | eyeglasses_text = ['sunglasses']
61 | belt_text = ['belt', 'with a dress tied around the waist']
62 | outer_shape_text = [
63 | 'with outer clothing open', 'with outer clothing unzipped',
64 | 'covering inner clothes', 'with outer clothing zipped'
65 | ]
66 | outer_shape_attr = {
67 | 'with outer clothing open': 0,
68 | 'with outer clothing unzipped': 0,
69 | 'covering inner clothes': 1,
70 | 'with outer clothing zipped': 1
71 | }
72 |
73 | upper_types = [
74 | 'T-shirt', 'shirt', 'sweater', 'hoodie', 'tops', 'blouse', 'Basic Tee'
75 | ]
76 | outer_types = [
77 | 'jacket', 'outer clothing', 'coat', 'overcoat', 'blazer', 'outerwear',
78 | 'duffle', 'cardigan'
79 | ]
80 | skirt_types = ['skirt']
81 | dress_types = ['dress']
82 | pant_types = ['jeans', 'pants', 'trousers']
83 | rompers_types = ['rompers', 'bodysuit', 'jumpsuit']
84 |
85 | attr_names_list = [
86 | 'gender', 'hair length', '0 upper clothing length',
87 | '1 lower clothing length', '2 socks', '3 hat', '4 eyeglasses', '5 belt',
88 | '6 opening of outer clothing', '7 upper clothes', '8 outer clothing',
89 | '9 skirt', '10 dress', '11 pants', '12 rompers'
90 | ]
91 |
92 |
93 | def generate_shape_attributes(user_shape_texts):
94 | model = SentenceTransformer('all-MiniLM-L6-v2')
95 | parsed_texts = user_shape_texts.split(',')
96 |
97 | text_num = len(parsed_texts)
98 |
99 | human_attr = [0, 0]
100 | attr = [1, 3, 0, 0, 0, 3, 1, 1, 0, 0, 0, 0, 0]
101 |
102 | changed = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
103 | for text_id, text in enumerate(parsed_texts):
104 | user_embeddings = model.encode(text)
105 | if ('man' in text) and (text_id == 0):
106 | human_attr[0] = 0
107 | human_attr[1] = 0
108 |
109 | if ('woman' in text or 'lady' in text) and (text_id == 0):
110 | human_attr[0] = 1
111 | human_attr[1] = 2
112 |
113 | if (not changed[0]) and (text_id == 1):
114 | # upper length
115 | predefined_embeddings = model.encode(upper_length_text)
116 | similarities = util.dot_score(user_embeddings,
117 | predefined_embeddings)
118 | arg_idx = torch.argmax(similarities).item()
119 | attr[0] = upper_length_attr[upper_length_text[arg_idx]]
120 | changed[0] = 1
121 |
122 | if (not changed[1]) and ((text_num == 2 and text_id == 1) or
123 | (text_num > 2 and text_id == 2)):
124 | # lower length
125 | predefined_embeddings = model.encode(lower_length_text)
126 | similarities = util.dot_score(user_embeddings,
127 | predefined_embeddings)
128 | arg_idx = torch.argmax(similarities).item()
129 | attr[1] = lower_length_attr[lower_length_text[arg_idx]]
130 | changed[1] = 1
131 |
132 | if (not changed[2]) and (text_id > 2):
133 | # socks length
134 | predefined_embeddings = model.encode(socks_length_text)
135 | similarities = util.dot_score(user_embeddings,
136 | predefined_embeddings)
137 | arg_idx = torch.argmax(similarities).item()
138 | if similarities[0][arg_idx] > 0.7:
139 | attr[2] = arg_idx + 1
140 | changed[2] = 1
141 |
142 | if (not changed[3]) and (text_id > 2):
143 | # hat
144 | predefined_embeddings = model.encode(hat_text)
145 | similarities = util.dot_score(user_embeddings,
146 | predefined_embeddings)
147 | if similarities[0][0] > 0.7:
148 | attr[3] = 1
149 | changed[3] = 1
150 |
151 | if (not changed[4]) and (text_id > 2):
152 | # glasses
153 | predefined_embeddings = model.encode(eyeglasses_text)
154 | similarities = util.dot_score(user_embeddings,
155 | predefined_embeddings)
156 | arg_idx = torch.argmax(similarities).item()
157 | if similarities[0][arg_idx] > 0.7:
158 | attr[4] = arg_idx + 1
159 | changed[4] = 1
160 |
161 | if (not changed[5]) and (text_id > 2):
162 | # belt
163 | predefined_embeddings = model.encode(belt_text)
164 | similarities = util.dot_score(user_embeddings,
165 | predefined_embeddings)
166 | arg_idx = torch.argmax(similarities).item()
167 | if similarities[0][arg_idx] > 0.7:
168 | attr[5] = arg_idx + 1
169 | changed[5] = 1
170 |
171 | if (not changed[6]) and (text_id == 3):
172 | # outer coverage
173 | predefined_embeddings = model.encode(outer_shape_text)
174 | similarities = util.dot_score(user_embeddings,
175 | predefined_embeddings)
176 | arg_idx = torch.argmax(similarities).item()
177 | if similarities[0][arg_idx] > 0.7:
178 | attr[6] = arg_idx
179 | changed[6] = 1
180 |
181 | if (not changed[10]) and (text_num == 2 and text_id == 1):
182 | # dress_types
183 | predefined_embeddings = model.encode(dress_types)
184 | similarities = util.dot_score(user_embeddings,
185 | predefined_embeddings)
186 | similarity_skirt = util.dot_score(user_embeddings,
187 | model.encode(skirt_types))
188 | if similarities[0][0] > 0.5 and similarities[0][
189 | 0] > similarity_skirt[0][0]:
190 | attr[10] = 1
191 | attr[7] = 0
192 | attr[8] = 0
193 | attr[9] = 0
194 | attr[11] = 0
195 | attr[12] = 0
196 |
197 | changed[0] = 1
198 | changed[10] = 1
199 | changed[7] = 1
200 | changed[8] = 1
201 | changed[9] = 1
202 | changed[11] = 1
203 | changed[12] = 1
204 |
205 | if (not changed[12]) and (text_num == 2 and text_id == 1):
206 | # rompers_types
207 | predefined_embeddings = model.encode(rompers_types)
208 | similarities = util.dot_score(user_embeddings,
209 | predefined_embeddings)
210 | max_similarity = torch.max(similarities).item()
211 | if max_similarity > 0.6:
212 | attr[12] = 1
213 | attr[7] = 0
214 | attr[8] = 0
215 | attr[9] = 0
216 | attr[10] = 0
217 | attr[11] = 0
218 |
219 | changed[12] = 1
220 | changed[7] = 1
221 | changed[8] = 1
222 | changed[9] = 1
223 | changed[10] = 1
224 | changed[11] = 1
225 |
226 | if (not changed[7]) and (text_num > 2 and text_id == 1):
227 | # upper_types
228 | predefined_embeddings = model.encode(upper_types)
229 | similarities = util.dot_score(user_embeddings,
230 | predefined_embeddings)
231 | max_similarity = torch.max(similarities).item()
232 | if max_similarity > 0.6:
233 | attr[7] = 1
234 | changed[7] = 1
235 |
236 | if (not changed[8]) and (text_id == 3):
237 | # outer_types
238 | predefined_embeddings = model.encode(outer_types)
239 | similarities = util.dot_score(user_embeddings,
240 | predefined_embeddings)
241 | arg_idx = torch.argmax(similarities).item()
242 | if similarities[0][arg_idx] > 0.7:
243 | attr[6] = outer_shape_attr[outer_shape_text[arg_idx]]
244 | attr[8] = 1
245 | changed[8] = 1
246 |
247 | if (not changed[9]) and (text_num > 2 and text_id == 2):
248 | # skirt_types
249 | predefined_embeddings = model.encode(skirt_types)
250 | similarity_skirt = util.dot_score(user_embeddings,
251 | predefined_embeddings)
252 | similarity_dress = util.dot_score(user_embeddings,
253 | model.encode(dress_types))
254 | if similarity_skirt[0][0] > 0.7 and similarity_skirt[0][
255 | 0] > similarity_dress[0][0]:
256 | attr[9] = 1
257 | attr[10] = 0
258 | changed[9] = 1
259 | changed[10] = 1
260 |
261 | if (not changed[11]) and (text_num > 2 and text_id == 2):
262 | # pant_types
263 | predefined_embeddings = model.encode(pant_types)
264 | similarities = util.dot_score(user_embeddings,
265 | predefined_embeddings)
266 | max_similarity = torch.max(similarities).item()
267 | if max_similarity > 0.6:
268 | attr[11] = 1
269 | attr[9] = 0
270 | attr[10] = 0
271 | attr[12] = 0
272 | changed[11] = 1
273 | changed[9] = 1
274 | changed[10] = 1
275 | changed[12] = 1
276 |
277 | return human_attr + attr
278 |
279 |
280 | def generate_texture_attributes(user_text):
281 | parsed_texts = user_text.split(',')
282 |
283 | attr = []
284 | for text in parsed_texts:
285 | if ('pure color' in text) or ('solid color' in text):
286 | attr.append(4)
287 | elif ('spline' in text) or ('stripe' in text):
288 | attr.append(3)
289 | elif ('plaid' in text) or ('lattice' in text):
290 | attr.append(5)
291 | elif 'floral' in text:
292 | attr.append(1)
293 | elif 'denim' in text:
294 | attr.append(0)
295 | else:
296 | attr.append(17)
297 |
298 | if len(attr) == 1:
299 | attr.append(attr[0])
300 | attr.append(17)
301 |
302 | if len(attr) == 2:
303 | attr.append(17)
304 |
305 | return attr
306 |
307 |
308 | if __name__ == "__main__":
309 | user_request = input('Enter your request: ')
310 | while user_request != '\\q':
311 | attr = generate_shape_attributes(user_request)
312 | print(attr)
313 | for attr_name, attr_value in zip(attr_names_list, attr):
314 | print(attr_name, attr_value)
315 | user_request = input('Enter your request: ')
316 |
--------------------------------------------------------------------------------
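For reference, a short sketch of how these two helpers are driven from free-form text, mirroring the calls in ui_demo.py (the example sentences are made up; `generate_shape_attributes` downloads the 'all-MiniLM-L6-v2' SentenceTransformer on first use):

from utils.language_utils import (generate_shape_attributes,
                                  generate_texture_attributes)

shape = generate_shape_attributes('a lady, long-sleeve T-shirt, long pants')
print(len(shape), shape)   # 2 human attributes followed by 13 shape attributes

texture = generate_texture_attributes('pure color, denim')
print(texture)             # [4, 0, 17]: 'pure color' -> 4, 'denim' -> 0, padded with 17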
/utils/logger.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import logging
3 | import time
4 |
5 |
6 | class MessageLogger():
7 | """Message logger for printing.
8 |
9 | Args:
10 | opt (dict): Config. It contains the following keys:
11 | name (str): Exp name.
12 |             print_freq (int): Logging interval (in iterations).
13 |             max_iters (int): Total number of training iterations.
14 | use_tb_logger (bool): Use tensorboard logger.
15 | start_iter (int): Start iter. Default: 1.
16 | tb_logger (obj:`tb_logger`): Tensorboard logger. Default: None.
17 | """
18 |
19 | def __init__(self, opt, start_iter=1, tb_logger=None):
20 | self.exp_name = opt['name']
21 | self.interval = opt['print_freq']
22 | self.start_iter = start_iter
23 | self.max_iters = opt['max_iters']
24 | self.use_tb_logger = opt['use_tb_logger']
25 | self.tb_logger = tb_logger
26 | self.start_time = time.time()
27 | self.logger = get_root_logger()
28 |
29 | def __call__(self, log_vars):
30 | """Format logging message.
31 |
32 | Args:
33 | log_vars (dict): It contains the following keys:
34 | epoch (int): Epoch number.
35 | iter (int): Current iter.
36 | lrs (list): List for learning rates.
37 |
38 | time (float): Iter time.
39 | data_time (float): Data time for each iter.
40 | """
41 | # epoch, iter, learning rates
42 | epoch = log_vars.pop('epoch')
43 | current_iter = log_vars.pop('iter')
44 | lrs = log_vars.pop('lrs')
45 |
46 | message = (f'[{self.exp_name[:5]}..][epoch:{epoch:3d}, '
47 | f'iter:{current_iter:8,d}, lr:(')
48 | for v in lrs:
49 | message += f'{v:.3e},'
50 | message += ')] '
51 |
52 | # time and estimated time
53 | if 'time' in log_vars.keys():
54 | iter_time = log_vars.pop('time')
55 | data_time = log_vars.pop('data_time')
56 |
57 | total_time = time.time() - self.start_time
58 | time_sec_avg = total_time / (current_iter - self.start_iter + 1)
59 | eta_sec = time_sec_avg * (self.max_iters - current_iter - 1)
60 | eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
61 | message += f'[eta: {eta_str}, '
62 | message += f'time: {iter_time:.3f}, data_time: {data_time:.3f}] '
63 |
64 | # other items, especially losses
65 | for k, v in log_vars.items():
66 | message += f'{k}: {v:.4e} '
67 | # tensorboard logger
68 | if self.use_tb_logger and 'debug' not in self.exp_name:
69 | self.tb_logger.add_scalar(k, v, current_iter)
70 |
71 | self.logger.info(message)
72 |
73 |
74 | def init_tb_logger(log_dir):
75 | from torch.utils.tensorboard import SummaryWriter
76 | tb_logger = SummaryWriter(log_dir=log_dir)
77 | return tb_logger
78 |
79 |
80 | def get_root_logger(logger_name='base', log_level=logging.INFO, log_file=None):
81 | """Get the root logger.
82 |
83 | The logger will be initialized if it has not been initialized. By default a
84 | StreamHandler will be added. If `log_file` is specified, a FileHandler will
85 | also be added.
86 |
87 | Args:
88 | logger_name (str): root logger name. Default: base.
89 | log_file (str | None): The log filename. If specified, a FileHandler
90 | will be added to the root logger.
91 | log_level (int): The root logger level. Note that only the process of
92 | rank 0 is affected, while other processes will set the level to
93 | "Error" and be silent most of the time.
94 |
95 | Returns:
96 | logging.Logger: The root logger.
97 | """
98 | logger = logging.getLogger(logger_name)
99 | # if the logger has been initialized, just return it
100 | if logger.hasHandlers():
101 | return logger
102 |
103 | format_str = '%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s'
104 | logging.basicConfig(format=format_str, level=log_level)
105 |
106 | if log_file is not None:
107 | file_handler = logging.FileHandler(log_file, 'w')
108 | file_handler.setFormatter(logging.Formatter(format_str))
109 | file_handler.setLevel(log_level)
110 | logger.addHandler(file_handler)
111 |
112 | return logger
113 |
--------------------------------------------------------------------------------
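A minimal sketch of how these helpers fit together in a training loop; the option values, log paths, and loss name below are placeholders rather than values from a real config:

import logging

from utils.logger import MessageLogger, get_root_logger, init_tb_logger

opt = {'name': 'demo_exp', 'print_freq': 100,
       'max_iters': 100000, 'use_tb_logger': True}
tb_logger = init_tb_logger(log_dir='./tb_logger/demo_exp')
logger = get_root_logger(log_level=logging.INFO, log_file='./demo_exp.log')

msg_logger = MessageLogger(opt, start_iter=1, tb_logger=tb_logger)
# one logging call per `print_freq` iterations; losses are passed as extra keys
msg_logger({'epoch': 1, 'iter': 100, 'lrs': [1e-4],
            'time': 0.35, 'data_time': 0.02, 'loss_total': 0.1234})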
/utils/options.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | from collections import OrderedDict
4 |
5 | import yaml
6 |
7 |
8 | def ordered_yaml():
9 | """Support OrderedDict for yaml.
10 |
11 | Returns:
12 | yaml Loader and Dumper.
13 | """
14 | try:
15 | from yaml import CDumper as Dumper
16 | from yaml import CLoader as Loader
17 | except ImportError:
18 | from yaml import Dumper, Loader
19 |
20 | _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG
21 |
22 | def dict_representer(dumper, data):
23 | return dumper.represent_dict(data.items())
24 |
25 | def dict_constructor(loader, node):
26 | return OrderedDict(loader.construct_pairs(node))
27 |
28 | Dumper.add_representer(OrderedDict, dict_representer)
29 | Loader.add_constructor(_mapping_tag, dict_constructor)
30 | return Loader, Dumper
31 |
32 |
33 | def parse(opt_path, is_train=True):
34 | """Parse option file.
35 |
36 | Args:
37 | opt_path (str): Option file path.
38 |         is_train (bool): Indicate whether in training mode. Default: True.
39 |
40 | Returns:
41 | (dict): Options.
42 | """
43 | with open(opt_path, mode='r') as f:
44 | Loader, _ = ordered_yaml()
45 | opt = yaml.load(f, Loader=Loader)
46 |
47 | gpu_list = ','.join(str(x) for x in opt['gpu_ids'])
48 | if opt.get('set_CUDA_VISIBLE_DEVICES', None):
49 | os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
50 | print('export CUDA_VISIBLE_DEVICES=' + gpu_list, flush=True)
51 | else:
52 | print('gpu_list: ', gpu_list, flush=True)
53 |
54 | opt['is_train'] = is_train
55 |
56 | # paths
57 | opt['path'] = {}
58 | opt['path']['root'] = osp.abspath(
59 | osp.join(__file__, osp.pardir, osp.pardir))
60 | if is_train:
61 | experiments_root = osp.join(opt['path']['root'], 'experiments',
62 | opt['name'])
63 | opt['path']['experiments_root'] = experiments_root
64 | opt['path']['models'] = osp.join(experiments_root, 'models')
65 | opt['path']['log'] = experiments_root
66 | opt['path']['visualization'] = osp.join(experiments_root,
67 | 'visualization')
68 |
69 | # change some options for debug mode
70 | if 'debug' in opt['name']:
71 | opt['debug'] = True
72 | opt['val_freq'] = 1
73 | opt['print_freq'] = 1
74 | opt['save_checkpoint_freq'] = 1
75 | else: # test
76 | results_root = osp.join(opt['path']['root'], 'results', opt['name'])
77 | opt['path']['results_root'] = results_root
78 | opt['path']['log'] = results_root
79 | opt['path']['visualization'] = osp.join(results_root, 'visualization')
80 |
81 | return opt
82 |
83 |
84 | def dict2str(opt, indent_level=1):
85 | """dict to string for printing options.
86 |
87 | Args:
88 | opt (dict): Option dict.
89 | indent_level (int): Indent level. Default: 1.
90 |
91 | Return:
92 | (str): Option string for printing.
93 | """
94 | msg = ''
95 | for k, v in opt.items():
96 | if isinstance(v, dict):
97 | msg += ' ' * (indent_level * 2) + k + ':[\n'
98 | msg += dict2str(v, indent_level + 1)
99 | msg += ' ' * (indent_level * 2) + ']\n'
100 | else:
101 | msg += ' ' * (indent_level * 2) + k + ': ' + str(v) + '\n'
102 | return msg
103 |
104 |
105 | class NoneDict(dict):
106 | """None dict. It will return none if key is not in the dict."""
107 |
108 | def __missing__(self, key):
109 | return None
110 |
111 |
112 | def dict_to_nonedict(opt):
113 | """Convert to NoneDict, which returns None for missing keys.
114 |
115 | Args:
116 | opt (dict): Option dict.
117 |
118 | Returns:
119 | (dict): NoneDict for options.
120 | """
121 | if isinstance(opt, dict):
122 | new_opt = dict()
123 | for key, sub_opt in opt.items():
124 | new_opt[key] = dict_to_nonedict(sub_opt)
125 | return NoneDict(**new_opt)
126 | elif isinstance(opt, list):
127 | return [dict_to_nonedict(sub_opt) for sub_opt in opt]
128 | else:
129 | return opt
130 |
--------------------------------------------------------------------------------
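This mirrors how ui_demo.py loads its options; `dict_to_nonedict` makes missing keys return None instead of raising KeyError. The sketch assumes the repository's configs/sample_from_pose.yml is present and defines the keys `parse` reads (e.g. `name`, `gpu_ids`):

from utils.options import dict2str, dict_to_nonedict, parse

opt = parse('./configs/sample_from_pose.yml', is_train=False)
opt = dict_to_nonedict(opt)

print(dict2str(opt))               # pretty-print the nested options
print(opt['some_missing_key'])     # None, courtesy of NoneDict.__missing__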
/utils/util.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import random
4 | import sys
5 | import time
6 | from shutil import get_terminal_size
7 |
8 | import numpy as np
9 | import torch
10 |
11 | logger = logging.getLogger('base')
12 |
13 |
14 | def make_exp_dirs(opt):
15 | """Make dirs for experiments."""
16 | path_opt = opt['path'].copy()
17 | if opt['is_train']:
18 |         overwrite = 'debug' in opt['name']
19 | os.makedirs(path_opt.pop('experiments_root'), exist_ok=overwrite)
20 | os.makedirs(path_opt.pop('models'), exist_ok=overwrite)
21 | else:
22 | os.makedirs(path_opt.pop('results_root'))
23 |
24 |
25 | def set_random_seed(seed):
26 | """Set random seeds."""
27 | random.seed(seed)
28 | np.random.seed(seed)
29 | torch.manual_seed(seed)
30 | torch.cuda.manual_seed(seed)
31 | torch.cuda.manual_seed_all(seed)
32 |
33 |
34 | class ProgressBar(object):
35 | """A progress bar which can print the progress.
36 |
37 | Modified from:
38 | https://github.com/hellock/cvbase/blob/master/cvbase/progress.py
39 | """
40 |
41 | def __init__(self, task_num=0, bar_width=50, start=True):
42 | self.task_num = task_num
43 | max_bar_width = self._get_max_bar_width()
44 | self.bar_width = (
45 | bar_width if bar_width <= max_bar_width else max_bar_width)
46 | self.completed = 0
47 | if start:
48 | self.start()
49 |
50 | def _get_max_bar_width(self):
51 | terminal_width, _ = get_terminal_size()
52 | max_bar_width = min(int(terminal_width * 0.6), terminal_width - 50)
53 | if max_bar_width < 10:
54 | print(f'terminal width is too small ({terminal_width}), '
55 |                   'please consider widening the terminal for better '
56 | 'progressbar visualization')
57 | max_bar_width = 10
58 | return max_bar_width
59 |
60 | def start(self):
61 | if self.task_num > 0:
62 | sys.stdout.write(f"[{' ' * self.bar_width}] 0/{self.task_num}, "
63 | f'elapsed: 0s, ETA:\nStart...\n')
64 | else:
65 | sys.stdout.write('completed: 0, elapsed: 0s')
66 | sys.stdout.flush()
67 | self.start_time = time.time()
68 |
69 | def update(self, msg='In progress...'):
70 | self.completed += 1
71 | elapsed = time.time() - self.start_time
72 | fps = self.completed / elapsed
73 | if self.task_num > 0:
74 | percentage = self.completed / float(self.task_num)
75 | eta = int(elapsed * (1 - percentage) / percentage + 0.5)
76 | mark_width = int(self.bar_width * percentage)
77 | bar_chars = '>' * mark_width + '-' * (self.bar_width - mark_width)
78 | sys.stdout.write('\033[2F') # cursor up 2 lines
79 | sys.stdout.write(
80 | '\033[J'
81 | ) # clean the output (remove extra chars since last display)
82 | sys.stdout.write(
83 | f'[{bar_chars}] {self.completed}/{self.task_num}, '
84 | f'{fps:.1f} task/s, elapsed: {int(elapsed + 0.5)}s, '
85 | f'ETA: {eta:5}s\n{msg}\n')
86 | else:
87 | sys.stdout.write(
88 | f'completed: {self.completed}, elapsed: {int(elapsed + 0.5)}s, '
89 | f'{fps:.1f} tasks/s')
90 | sys.stdout.flush()
91 |
92 |
93 | class AverageMeter(object):
94 | """
95 | Computes and stores the average and current value
96 | Imported from
97 | https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
98 | """
99 |
100 | def __init__(self):
101 | self.reset()
102 |
103 | def reset(self):
104 | self.val = 0
105 | self.avg = 0 # running average = running sum / running count
106 | self.sum = 0 # running sum
107 | self.count = 0 # running count
108 |
109 | def update(self, val, n=1):
110 | # n = batch_size
111 |
112 | # val = batch accuracy for an attribute
113 |         self.val = val
114 |
115 | # sum = 100 * accumulative correct predictions for this attribute
116 | self.sum += val * n
117 |
118 | # count = total samples so far
119 | self.count += n
120 |
121 | # avg = 100 * avg accuracy for this attribute
122 | # for all the batches so far
123 | self.avg = self.sum / self.count
124 |
--------------------------------------------------------------------------------
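Finally, a toy loop exercising ProgressBar and AverageMeter from this file; the accuracy values and batch size are arbitrary placeholders:

import time

from utils.util import AverageMeter, ProgressBar, set_random_seed

set_random_seed(0)

acc_meter = AverageMeter()
pbar = ProgressBar(task_num=5)
for step in range(5):
    time.sleep(0.1)                                  # stand-in for real work
    acc_meter.update(val=0.8 + 0.01 * step, n=16)    # batch accuracy, batch size
    pbar.update(f'step {step}')
print(f'average: {acc_meter.avg:.4f} over {acc_meter.count} samples')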