├── .gitignore
├── LICENSE
├── README.md
├── app.py
├── assets
│   ├── anime_ug.pth
│   ├── clip_ug.pth
│   ├── examples-anime
│   │   ├── camping.jpg
│   │   ├── hanfu_girl.jpg
│   │   ├── miku-canny.png
│   │   ├── miku.jpg
│   │   ├── pose.png
│   │   ├── pose_small.png
│   │   ├── pose_withhandface.png
│   │   ├── pose_withhandface_small.png
│   │   └── random1.jpg
│   ├── examples
│   │   ├── astronautridinghouse-canny.png
│   │   ├── astronautridinghouse-input.jpg
│   │   ├── bedroom-input.jpg
│   │   ├── bedroom-mlsd.png
│   │   ├── ghibli-canny.png
│   │   ├── ghibli-input.jpg
│   │   ├── grassland-input.jpg
│   │   ├── grassland-scribble.png
│   │   ├── jeep-depth.png
│   │   ├── jeep-input.jpg
│   │   ├── nightstreet-canny.png
│   │   ├── nightstreet-input.jpg
│   │   ├── woodcar-depth.png
│   │   └── woodcar-input.jpg
│   └── figures
│       ├── anime.png
│       ├── prompt_free_diffusion.png
│       ├── qualitative_show.png
│       ├── reusability.png
│       └── seecoder.png
├── configs
│   └── model
│       ├── autokl.yaml
│       ├── clip.yaml
│       ├── controlnet.yaml
│       ├── openai_unet.yaml
│       ├── pfd.yaml
│       ├── seecoder.yaml
│       └── swin.yaml
├── lib
│   ├── __init__.py
│   ├── cfg_helper.py
│   ├── cfg_holder.py
│   ├── log_service.py
│   ├── model_zoo
│   │   ├── __init__.py
│   │   ├── attention.py
│   │   ├── autokl.py
│   │   ├── autokl_modules.py
│   │   ├── autokl_utils.py
│   │   ├── clip.py
│   │   ├── common
│   │   │   ├── get_model.py
│   │   │   ├── get_optimizer.py
│   │   │   ├── get_scheduler.py
│   │   │   └── utils.py
│   │   ├── controlnet.py
│   │   ├── controlnet_annotator
│   │   │   ├── canny
│   │   │   │   └── __init__.py
│   │   │   ├── hed
│   │   │   │   └── __init__.py
│   │   │   ├── midas
│   │   │   │   ├── LICENSE
│   │   │   │   ├── __init__.py
│   │   │   │   ├── api.py
│   │   │   │   ├── midas
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   ├── base_model.py
│   │   │   │   │   ├── blocks.py
│   │   │   │   │   ├── dpt_depth.py
│   │   │   │   │   ├── midas_net.py
│   │   │   │   │   ├── midas_net_custom.py
│   │   │   │   │   ├── transforms.py
│   │   │   │   │   └── vit.py
│   │   │   │   └── utils.py
│   │   │   ├── mlsd
│   │   │   │   ├── LICENSE
│   │   │   │   ├── __init__.py
│   │   │   │   ├── models
│   │   │   │   │   ├── mbv2_mlsd_large.py
│   │   │   │   │   └── mbv2_mlsd_tiny.py
│   │   │   │   └── utils.py
│   │   │   ├── openpose
│   │   │   │   ├── LICENSE
│   │   │   │   ├── __init__.py
│   │   │   │   ├── body.py
│   │   │   │   ├── face.py
│   │   │   │   ├── hand.py
│   │   │   │   ├── model.py
│   │   │   │   └── util.py
│   │   │   └── pidinet
│   │   │       ├── LICENSE
│   │   │       ├── __init__.py
│   │   │       └── model.py
│   │   ├── ddim.py
│   │   ├── diffusion_utils.py
│   │   ├── distributions.py
│   │   ├── ema.py
│   │   ├── openaimodel.py
│   │   ├── pfd.py
│   │   ├── sampler.py
│   │   ├── seecoder.py
│   │   └── swin.py
│   ├── sync.py
│   └── utils.py
├── requirements.txt
└── tools
    ├── get_controlnet.py
    └── model_conversion.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | .vscode/
3 | data/
4 | data
5 | log/
6 | log
7 | pretrained/
8 | pretrained
9 | assets/nosync/
10 | assets/demo/temp/temp_*
11 | *.out
12 | gradio_cached_examples/
13 | src/*/build
14 | src/*/dist
15 | src/*/*.egg-info/
16 | extensions/
17 | extensions
18 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 SHI Labs
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Prompt-Free Diffusion
2 |
3 | [HuggingFace Space](https://huggingface.co/spaces/shi-labs/Prompt-Free-Diffusion)
4 | [Framework: PyTorch](https://pytorch.org/)
5 | [License: MIT](https://opensource.org/licenses/MIT)
6 |
7 | This repo hosts the official implementation of:
8 |
9 | [Xingqian Xu](https://ifp-uiuc.github.io/), Jiayi Guo, Zhangyang Wang, Gao Huang, Irfan Essa, and [Humphrey Shi](https://www.humphreyshi.com/home), **Prompt-Free Diffusion: Taking "Text" out of Text-to-Image Diffusion Models**, [Paper arXiv Link](https://arxiv.org/abs/2305.16223).
10 |
11 | ## News
12 |
13 | - **[2023.06.20]: An SDWebUI plugin has been created; see the repo at this [link](https://github.com/xingqian2018/sd-webui-prompt-free-diffusion)**
14 | - [2023.05.25]: Our demo is running on [HuggingFace🤗](https://huggingface.co/spaces/shi-labs/Prompt-Free-Diffusion)
15 | - [2023.05.25]: Repo created
16 |
17 | ## Introduction
18 |
19 | **Prompt-Free Diffusion** is a diffusion model that relies only on visual inputs to generate new images: the commonly used CLIP-based text encoder is replaced by a **Semantic Context Encoder (SeeCoder)**. SeeCoder is **reusable with most public T2I models as well as adaptive layers** such as ControlNet, LoRA, T2I-Adapter, etc. Just drop it in and play!
20 |
21 |
22 |
23 |
24 |
25 | ## Performance
26 |
27 |
28 |
29 |
30 |
31 | ## Network
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 | ## Setup
42 |
43 | ```
44 | conda create -n prompt-free-diffusion python=3.10
45 | conda activate prompt-free-diffusion
46 | pip install torch==2.0.0+cu117 torchvision==0.15.1 --extra-index-url https://download.pytorch.org/whl/cu117
47 | pip install -r requirements.txt
48 | ```
49 |
50 | ## Demo
51 |
52 | We provide a WebUI powered by [Gradio](https://github.com/gradio-app/gradio). Start it with the following command:
53 |
54 | ```
55 | python app.py
56 | ```
57 |
58 | ## Pretrained models
59 |
60 | To support the full functionality of our demo, you need the following models located at these paths:
61 |
62 | ```
63 | └── pretrained
64 |     ├── pfd
65 |     |   ├── vae
66 |     |   │   └── sd-v2-0-base-autokl.pth
67 |     |   ├── diffuser
68 |     |   │   ├── AbyssOrangeMix-v2.safetensors
69 |     |   │   ├── AbyssOrangeMix-v3.safetensors
70 |     |   │   ├── Anything-v4.safetensors
71 |     |   │   ├── Deliberate-v2-0.safetensors
72 |     |   │   ├── OpenJouney-v4.safetensors
73 |     |   │   ├── RealisticVision-v2-0.safetensors
74 |     |   │   └── SD-v1-5.safetensors
75 |     |   └── seecoder
76 |     |       ├── seecoder-v1-0.safetensors
77 |     |       ├── seecoder-pa-v1-0.safetensors
78 |     |       └── seecoder-anime-v1-0.safetensors
79 |     └── controlnet
80 |         ├── control_sd15_canny_slimmed.safetensors
81 |         ├── control_sd15_depth_slimmed.safetensors
82 |         ├── control_sd15_hed_slimmed.safetensors
83 |         ├── control_sd15_mlsd_slimmed.safetensors
84 |         ├── control_sd15_normal_slimmed.safetensors
85 |         ├── control_sd15_openpose_slimmed.safetensors
86 |         ├── control_sd15_scribble_slimmed.safetensors
87 |         ├── control_sd15_seg_slimmed.safetensors
88 |         ├── control_v11p_sd15_canny_slimmed.safetensors
89 |         ├── control_v11p_sd15_lineart_slimmed.safetensors
90 |         ├── control_v11p_sd15_mlsd_slimmed.safetensors
91 |         ├── control_v11p_sd15_openpose_slimmed.safetensors
92 |         ├── control_v11p_sd15s2_lineart_anime_slimmed.safetensors
93 |         ├── control_v11p_sd15_softedge_slimmed.safetensors
94 |         └── preprocess
95 |             ├── hed
96 |             │   └── ControlNetHED.pth
97 |             ├── midas
98 |             │   └── dpt_hybrid-midas-501f0c75.pt
99 |             ├── mlsd
100 |             │   └── mlsd_large_512_fp32.pth
101 |             ├── openpose
102 |             │   ├── body_pose_model.pth
103 |             │   ├── facenet.pth
104 |             │   └── hand_pose_model.pth
105 |             └── pidinet
106 |                 └── table5_pidinet.pth
107 | ```
108 |
109 | All models can be downloaded at [HuggingFace link](https://huggingface.co/shi-labs/prompt-free-diffusion).
110 |
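If you prefer to script the download, here is a minimal sketch (not part of this repo) that pulls the whole model repository with `huggingface_hub`; depending on how the files are organized on the Hub, you may still need to move them around to match the layout above.

```
# Sketch only: fetch the released weights into ./pretrained with huggingface_hub.
# Assumes `pip install huggingface_hub`; adjust local_dir if your layout differs.
from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="shi-labs/prompt-free-diffusion",  # HuggingFace repo linked above
    local_dir="pretrained",
)
```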
111 | ## Tools
112 |
113 | We also provide tools to convert pretrained models from SDWebUI and the diffusers library to this codebase. Please modify the following files:
114 |
115 | ```
116 | └── tools
117 |     ├── get_controlnet.py
118 |     └── model_conversion.py
119 | ```
120 |
121 | You are expected to do some customized coding to make them work (e.g. changing the hardcoded input/output file paths); a generic conversion pattern is sketched below for reference.
122 |
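For orientation only, the typical pattern looks roughly like the sketch below (this is **not** the content of `tools/model_conversion.py`; paths and key handling are placeholders):

```
# Illustrative sketch: load a source checkpoint, keep the tensor entries, re-save as safetensors.
import torch
from safetensors.torch import save_file

src = "path/to/source-model.ckpt"                        # placeholder input path
dst = "pretrained/pfd/diffuser/converted.safetensors"    # placeholder output path

sd = torch.load(src, map_location="cpu")
sd = sd.get("state_dict", sd)                            # sdwebui-style checkpoints nest weights here
sd = {k: v.contiguous() for k, v in sd.items() if isinstance(v, torch.Tensor)}
save_file(sd, dst)
```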
123 | ## Performance Anime
124 |
125 |
126 |
127 |
128 |
129 | ## Citation
130 |
131 | ```
132 | @article{xu2023prompt,
133 | title={Prompt-Free Diffusion: Taking "Text" out of Text-to-Image Diffusion Models},
134 | author={Xu, Xingqian and Guo, Jiayi and Wang, Zhangyang and Huang, Gao and Essa, Irfan and Shi, Humphrey},
135 | journal={arXiv preprint arXiv:2305.16223},
136 | year={2023}
137 | }
138 | ```
139 |
140 | ## Acknowledgement
141 |
142 | Part of the code reorganizes/reimplements code from the following repositories: [Versatile Diffusion official GitHub](https://github.com/SHI-Labs/Versatile-Diffusion) and [ControlNet sdwebui GitHub](https://github.com/Mikubill/sd-webui-controlnet), which are in turn greatly influenced by the [LDM official GitHub](https://github.com/CompVis/latent-diffusion) and the [DDPM official GitHub](https://github.com/lucidrains/denoising-diffusion-pytorch).
143 |
--------------------------------------------------------------------------------
/assets/anime_ug.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SHI-Labs/Prompt-Free-Diffusion/f4295fc7ad80763e2ab562eb9aab06abb979b278/assets/anime_ug.pth
--------------------------------------------------------------------------------
/assets/clip_ug.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SHI-Labs/Prompt-Free-Diffusion/f4295fc7ad80763e2ab562eb9aab06abb979b278/assets/clip_ug.pth
--------------------------------------------------------------------------------
/assets/examples-anime/camping.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SHI-Labs/Prompt-Free-Diffusion/f4295fc7ad80763e2ab562eb9aab06abb979b278/assets/examples-anime/camping.jpg
--------------------------------------------------------------------------------
/assets/examples-anime/hanfu_girl.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SHI-Labs/Prompt-Free-Diffusion/f4295fc7ad80763e2ab562eb9aab06abb979b278/assets/examples-anime/hanfu_girl.jpg
--------------------------------------------------------------------------------
/assets/examples-anime/miku-canny.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SHI-Labs/Prompt-Free-Diffusion/f4295fc7ad80763e2ab562eb9aab06abb979b278/assets/examples-anime/miku-canny.png
--------------------------------------------------------------------------------
/assets/examples-anime/miku.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SHI-Labs/Prompt-Free-Diffusion/f4295fc7ad80763e2ab562eb9aab06abb979b278/assets/examples-anime/miku.jpg
--------------------------------------------------------------------------------
/assets/examples-anime/pose.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SHI-Labs/Prompt-Free-Diffusion/f4295fc7ad80763e2ab562eb9aab06abb979b278/assets/examples-anime/pose.png
--------------------------------------------------------------------------------
/assets/examples-anime/pose_small.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SHI-Labs/Prompt-Free-Diffusion/f4295fc7ad80763e2ab562eb9aab06abb979b278/assets/examples-anime/pose_small.png
--------------------------------------------------------------------------------
/assets/examples-anime/pose_withhandface.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SHI-Labs/Prompt-Free-Diffusion/f4295fc7ad80763e2ab562eb9aab06abb979b278/assets/examples-anime/pose_withhandface.png
--------------------------------------------------------------------------------
/assets/examples-anime/pose_withhandface_small.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SHI-Labs/Prompt-Free-Diffusion/f4295fc7ad80763e2ab562eb9aab06abb979b278/assets/examples-anime/pose_withhandface_small.png
--------------------------------------------------------------------------------
/assets/examples-anime/random1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SHI-Labs/Prompt-Free-Diffusion/f4295fc7ad80763e2ab562eb9aab06abb979b278/assets/examples-anime/random1.jpg
--------------------------------------------------------------------------------
/assets/examples/astronautridinghouse-canny.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SHI-Labs/Prompt-Free-Diffusion/f4295fc7ad80763e2ab562eb9aab06abb979b278/assets/examples/astronautridinghouse-canny.png
--------------------------------------------------------------------------------
/assets/examples/astronautridinghouse-input.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SHI-Labs/Prompt-Free-Diffusion/f4295fc7ad80763e2ab562eb9aab06abb979b278/assets/examples/astronautridinghouse-input.jpg
--------------------------------------------------------------------------------
/assets/examples/bedroom-input.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SHI-Labs/Prompt-Free-Diffusion/f4295fc7ad80763e2ab562eb9aab06abb979b278/assets/examples/bedroom-input.jpg
--------------------------------------------------------------------------------
/assets/examples/bedroom-mlsd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SHI-Labs/Prompt-Free-Diffusion/f4295fc7ad80763e2ab562eb9aab06abb979b278/assets/examples/bedroom-mlsd.png
--------------------------------------------------------------------------------
/assets/examples/ghibli-canny.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SHI-Labs/Prompt-Free-Diffusion/f4295fc7ad80763e2ab562eb9aab06abb979b278/assets/examples/ghibli-canny.png
--------------------------------------------------------------------------------
/assets/examples/ghibli-input.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SHI-Labs/Prompt-Free-Diffusion/f4295fc7ad80763e2ab562eb9aab06abb979b278/assets/examples/ghibli-input.jpg
--------------------------------------------------------------------------------
/assets/examples/grassland-input.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SHI-Labs/Prompt-Free-Diffusion/f4295fc7ad80763e2ab562eb9aab06abb979b278/assets/examples/grassland-input.jpg
--------------------------------------------------------------------------------
/assets/examples/grassland-scribble.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SHI-Labs/Prompt-Free-Diffusion/f4295fc7ad80763e2ab562eb9aab06abb979b278/assets/examples/grassland-scribble.png
--------------------------------------------------------------------------------
/assets/examples/jeep-depth.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SHI-Labs/Prompt-Free-Diffusion/f4295fc7ad80763e2ab562eb9aab06abb979b278/assets/examples/jeep-depth.png
--------------------------------------------------------------------------------
/assets/examples/jeep-input.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SHI-Labs/Prompt-Free-Diffusion/f4295fc7ad80763e2ab562eb9aab06abb979b278/assets/examples/jeep-input.jpg
--------------------------------------------------------------------------------
/assets/examples/nightstreet-canny.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SHI-Labs/Prompt-Free-Diffusion/f4295fc7ad80763e2ab562eb9aab06abb979b278/assets/examples/nightstreet-canny.png
--------------------------------------------------------------------------------
/assets/examples/nightstreet-input.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SHI-Labs/Prompt-Free-Diffusion/f4295fc7ad80763e2ab562eb9aab06abb979b278/assets/examples/nightstreet-input.jpg
--------------------------------------------------------------------------------
/assets/examples/woodcar-depth.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SHI-Labs/Prompt-Free-Diffusion/f4295fc7ad80763e2ab562eb9aab06abb979b278/assets/examples/woodcar-depth.png
--------------------------------------------------------------------------------
/assets/examples/woodcar-input.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SHI-Labs/Prompt-Free-Diffusion/f4295fc7ad80763e2ab562eb9aab06abb979b278/assets/examples/woodcar-input.jpg
--------------------------------------------------------------------------------
/assets/figures/anime.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SHI-Labs/Prompt-Free-Diffusion/f4295fc7ad80763e2ab562eb9aab06abb979b278/assets/figures/anime.png
--------------------------------------------------------------------------------
/assets/figures/prompt_free_diffusion.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SHI-Labs/Prompt-Free-Diffusion/f4295fc7ad80763e2ab562eb9aab06abb979b278/assets/figures/prompt_free_diffusion.png
--------------------------------------------------------------------------------
/assets/figures/qualitative_show.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SHI-Labs/Prompt-Free-Diffusion/f4295fc7ad80763e2ab562eb9aab06abb979b278/assets/figures/qualitative_show.png
--------------------------------------------------------------------------------
/assets/figures/reusability.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SHI-Labs/Prompt-Free-Diffusion/f4295fc7ad80763e2ab562eb9aab06abb979b278/assets/figures/reusability.png
--------------------------------------------------------------------------------
/assets/figures/seecoder.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SHI-Labs/Prompt-Free-Diffusion/f4295fc7ad80763e2ab562eb9aab06abb979b278/assets/figures/seecoder.png
--------------------------------------------------------------------------------
/configs/model/autokl.yaml:
--------------------------------------------------------------------------------
1 | autokl:
2 | symbol: autokl
3 | find_unused_parameters: false
4 |
5 | autokl_v1:
6 | super_cfg: autokl
7 | type: autoencoderkl
8 | args:
9 | embed_dim: 4
10 | ddconfig:
11 | double_z: true
12 | z_channels: 4
13 | resolution: 256
14 | in_channels: 3
15 | out_ch: 3
16 | ch: 128
17 | ch_mult: [1, 2, 4, 4]
18 | num_res_blocks: 2
19 | attn_resolutions: []
20 | dropout: 0.0
21 | lossconfig: null
22 | pth: pretrained/kl-f8.pth
23 |
24 | autokl_v2:
25 | super_cfg: autokl_v1
26 | pth: pretrained/pfd/vae/sd-v2-0-base-autokl.pth
27 |
--------------------------------------------------------------------------------
/configs/model/clip.yaml:
--------------------------------------------------------------------------------
1 | ################
2 | # clip for sd1 #
3 | ################
4 |
5 | clip:
6 | symbol: clip
7 | args: {}
8 |
9 | clip_text_context_encoder_sdv1:
10 | super_cfg: clip
11 | type: clip_text_context_encoder_sdv1
12 | args: {}
13 |
--------------------------------------------------------------------------------
/configs/model/controlnet.yaml:
--------------------------------------------------------------------------------
1 | controlnet:
2 | symbol: controlnet
3 | type: controlnet
4 | find_unused_parameters: false
5 | args:
6 | image_size: 32 # unused
7 | in_channels: 4
8 | hint_channels: 3
9 | model_channels: 320
10 | attention_resolutions: [ 4, 2, 1 ]
11 | num_res_blocks: 2
12 | channel_mult: [ 1, 2, 4, 4 ]
13 | num_heads: 8
14 | use_spatial_transformer: True
15 | transformer_depth: 1
16 | context_dim: 768
17 | use_checkpoint: True
18 | legacy: False
19 |
--------------------------------------------------------------------------------
/configs/model/openai_unet.yaml:
--------------------------------------------------------------------------------
1 | openai_unet_sd:
2 | type: openai_unet
3 | args:
4 | image_size: null # no use
5 | in_channels: 4
6 | out_channels: 4
7 | model_channels: 320
8 | attention_resolutions: [ 4, 2, 1 ]
9 | num_res_blocks: [ 2, 2, 2, 2 ]
10 | channel_mult: [ 1, 2, 4, 4 ]
11 | # disable_self_attentions: [ False, False, False, False ] # converts the self-attention to a cross-attention layer if true
12 | num_heads: 8
13 | use_spatial_transformer: True
14 | transformer_depth: 1
15 | context_dim: 768
16 | use_checkpoint: True
17 | legacy: False
18 |
19 | #########
20 | # v1 2d #
21 | #########
22 |
23 | openai_unet_2d_v1:
24 | type: openai_unet_2d_next
25 | args:
26 | in_channels: 4
27 | out_channels: 4
28 | model_channels: 320
29 | attention_resolutions: [ 4, 2, 1 ]
30 | num_res_blocks: [ 2, 2, 2, 2 ]
31 | channel_mult: [ 1, 2, 4, 4 ]
32 | num_heads: 8
33 | context_dim: 768
34 | use_checkpoint: False
35 | parts: [global, data, context]
36 |
--------------------------------------------------------------------------------
/configs/model/pfd.yaml:
--------------------------------------------------------------------------------
1 | pfd_base:
2 | symbol: pfd
3 | find_unused_parameters: true
4 | type: pfd
5 | args:
6 | beta_linear_start: 0.00085
7 | beta_linear_end: 0.012
8 | timesteps: 1000
9 | use_ema: false
10 |
11 | pfd_seecoder:
12 | super_cfg: pfd_base
13 | args:
14 | vae_cfg_list:
15 | - [image, MODEL(autokl_v2)]
16 | ctx_cfg_list:
17 | - [image, MODEL(seecoder)]
18 | diffuser_cfg_list:
19 | - [image, MODEL(openai_unet_2d_v1)]
20 | latent_scale_factor:
21 | image: 0.18215
22 |
23 | pdf_seecoder_pa:
24 | super_cfg: pfd_seecoder
25 | args:
26 | ctx_cfg_list:
27 | - [image, MODEL(seecoder_pa)]
28 |
29 | pfd_seecoder_with_controlnet:
30 | super_cfg: pfd_seecoder
31 | type: pfd_with_control
32 | args:
33 | ctl_cfg: MODEL(controlnet)
34 |
--------------------------------------------------------------------------------
/configs/model/seecoder.yaml:
--------------------------------------------------------------------------------
1 | seecoder_base:
2 | symbol: seecoder
3 | args: {}
4 |
5 | seecoder:
6 | super_cfg: seecoder_base
7 | type: seecoder
8 | args:
9 | imencoder_cfg : MODEL(swin_large)
10 | imdecoder_cfg : MODEL(seecoder_decoder)
11 | qtransformer_cfg : MODEL(seecoder_query_transformer)
12 |
13 | seecoder_pa:
14 |   super_cfg: seecoder_base
15 | type: seecoder
16 | args:
17 | imencoder_cfg : MODEL(swin_large)
18 | imdecoder_cfg : MODEL(seecoder_decoder)
19 | qtransformer_cfg : MODEL(seecoder_query_transformer_position_aware)
20 |
21 | ###########
22 | # decoder #
23 | ###########
24 |
25 | seecoder_decoder:
26 | super_cfg: seecoder_base
27 | type: seecoder_decoder
28 | args:
29 | inchannels:
30 | res3: 384
31 | res4: 768
32 | res5: 1536
33 | trans_input_tags: ['res3', 'res4', 'res5']
34 | trans_dim: 768
35 | trans_dropout: 0.1
36 | trans_nheads: 8
37 | trans_feedforward_dim: 1024
38 | trans_num_layers: 6
39 |
40 | #####################
41 | # query_transformer #
42 | #####################
43 |
44 | seecoder_query_transformer:
45 | super_cfg: seecoder_base
46 | type: seecoder_query_transformer
47 | args:
48 | in_channels : 768
49 | hidden_dim: 768
50 | num_queries: [4, 144]
51 | nheads: 8
52 | num_layers: 9
53 | feedforward_dim: 2048
54 | pre_norm: False
55 | num_feature_levels: 3
56 | enforce_input_project: False
57 | with_fea2d_pos: false
58 |
59 | seecoder_query_transformer_position_aware:
60 | super_cfg: seecoder_query_transformer
61 | args:
62 | with_fea2d_pos: true
63 |
--------------------------------------------------------------------------------
/configs/model/swin.yaml:
--------------------------------------------------------------------------------
1 | swin:
2 | symbol: swin
3 | args: {}
4 |
5 | swin_base:
6 | super_cfg: swin
7 | type: swin
8 | args:
9 | embed_dim: 128
10 | depths: [ 2, 2, 18, 2 ]
11 | num_heads: [ 4, 8, 16, 32 ]
12 | window_size: 7
13 | ape: False
14 | drop_path_rate: 0.3
15 | patch_norm: True
16 | strict_sd: False
17 |
18 | swin_large:
19 | super_cfg: swin
20 | type: swin
21 | args:
22 | embed_dim: 192
23 | depths: [ 2, 2, 18, 2 ]
24 | num_heads: [ 6, 12, 24, 48 ]
25 | window_size: 12
26 | ape: False
27 | drop_path_rate: 0.3
28 | patch_norm: True
29 | strict_sd: False
30 |
31 |
--------------------------------------------------------------------------------
/lib/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SHI-Labs/Prompt-Free-Diffusion/f4295fc7ad80763e2ab562eb9aab06abb979b278/lib/__init__.py
--------------------------------------------------------------------------------
/lib/cfg_holder.py:
--------------------------------------------------------------------------------
1 | import copy
2 |
3 | def singleton(class_):
4 | instances = {}
5 | def getinstance(*args, **kwargs):
6 | if class_ not in instances:
7 | instances[class_] = class_(*args, **kwargs)
8 | return instances[class_]
9 | return getinstance
10 |
11 | ##############
12 | # cfg_holder #
13 | ##############
14 |
15 | @singleton
16 | class cfg_unique_holder(object):
17 | def __init__(self):
18 | self.cfg = None
19 |         # this is used to track the main codes.
20 | self.code = set()
21 | def save_cfg(self, cfg):
22 | self.cfg = copy.deepcopy(cfg)
23 | def add_code(self, code):
24 | """
25 | A new main code is reached and
26 | its name is added.
27 | """
28 | self.code.add(code)
29 |
--------------------------------------------------------------------------------
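Usage note (not part of the file above): `cfg_unique_holder` is a process-wide singleton, so a config saved once by the launcher is visible to every module that imports the holder (e.g. `log_service.print_log` reads `cfguh().cfg.train.log_file`). A minimal sketch, assuming the config object supports attribute access (e.g. an EasyDict):

```
from easydict import EasyDict as edict           # assumption: attribute-style config dicts
from lib.cfg_holder import cfg_unique_holder as cfguh

cfg = edict({'train': {'log_file': 'log/train.log', 'batch_size': 8}})
cfguh().save_cfg(cfg)                            # store once, e.g. in the launcher
print(cfguh().cfg.train.log_file)                # read anywhere else in the codebase
```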
/lib/log_service.py:
--------------------------------------------------------------------------------
1 | import timeit
2 | import numpy as np
3 | import os
4 | import os.path as osp
5 | import shutil
6 | import copy
7 | import torch
8 | import torch.nn as nn
9 | import torch.distributed as dist
10 | from .cfg_holder import cfg_unique_holder as cfguh
11 | from . import sync
12 |
13 | def print_log(*console_info):
14 | grank, lrank, _ = sync.get_rank('all')
15 | if lrank!=0:
16 | return
17 |
18 | console_info = [str(i) for i in console_info]
19 | console_info = ' '.join(console_info)
20 | print(console_info)
21 |
22 | if grank!=0:
23 | return
24 |
25 | log_file = None
26 | try:
27 | log_file = cfguh().cfg.train.log_file
28 | except:
29 | try:
30 | log_file = cfguh().cfg.eval.log_file
31 | except:
32 | return
33 | if log_file is not None:
34 | with open(log_file, 'a') as f:
35 | f.write(console_info + '\n')
36 |
37 | class distributed_log_manager(object):
38 | def __init__(self):
39 | self.sum = {}
40 | self.cnt = {}
41 | self.time_check = timeit.default_timer()
42 |
43 | cfgt = cfguh().cfg.train
44 | self.ddp = sync.is_ddp()
45 | self.grank, self.lrank, _ = sync.get_rank('all')
46 | self.gwsize = sync.get_world_size('global')
47 |
48 | use_tensorboard = cfgt.get('log_tensorboard', False) and (self.grank==0)
49 |
50 | self.tb = None
51 | if use_tensorboard:
52 | import tensorboardX
53 | monitoring_dir = osp.join(cfguh().cfg.train.log_dir, 'tensorboard')
54 | self.tb = tensorboardX.SummaryWriter(osp.join(monitoring_dir))
55 |
56 | def accumulate(self, n, **data):
57 | if n < 0:
58 | raise ValueError
59 |
60 | for itemn, di in data.items():
61 | if itemn in self.sum:
62 | self.sum[itemn] += di * n
63 | self.cnt[itemn] += n
64 | else:
65 | self.sum[itemn] = di * n
66 | self.cnt[itemn] = n
67 |
68 | def get_mean_value_dict(self):
69 | value_gather = [
70 | self.sum[itemn]/self.cnt[itemn] \
71 | for itemn in sorted(self.sum.keys()) ]
72 |
73 | value_gather_tensor = torch.FloatTensor(value_gather).to(self.lrank)
74 | if self.ddp:
75 | dist.all_reduce(value_gather_tensor, op=dist.ReduceOp.SUM)
76 | value_gather_tensor /= self.gwsize
77 |
78 | mean = {}
79 | for idx, itemn in enumerate(sorted(self.sum.keys())):
80 | mean[itemn] = value_gather_tensor[idx].item()
81 | return mean
82 |
83 | def tensorboard_log(self, step, data, mode='train', **extra):
84 | if self.tb is None:
85 | return
86 | if mode == 'train':
87 | self.tb.add_scalar('other/epochn', extra['epochn'], step)
88 | if ('lr' in extra) and (extra['lr'] is not None):
89 | self.tb.add_scalar('other/lr', extra['lr'], step)
90 | for itemn, di in data.items():
91 | if itemn.find('loss') == 0:
92 | self.tb.add_scalar('loss/'+itemn, di, step)
93 | elif itemn == 'Loss':
94 | self.tb.add_scalar('Loss', di, step)
95 | else:
96 | self.tb.add_scalar('other/'+itemn, di, step)
97 | elif mode == 'eval':
98 | if isinstance(data, dict):
99 | for itemn, di in data.items():
100 | self.tb.add_scalar('eval/'+itemn, di, step)
101 | else:
102 | self.tb.add_scalar('eval', data, step)
103 | return
104 |
105 | def train_summary(self, itern, epochn, samplen, lr, tbstep=None):
106 | console_info = [
107 | 'Iter:{}'.format(itern),
108 | 'Epoch:{}'.format(epochn),
109 | 'Sample:{}'.format(samplen),]
110 |
111 | if lr is not None:
112 | console_info += ['LR:{:.4E}'.format(lr)]
113 |
114 | mean = self.get_mean_value_dict()
115 |
116 | tbstep = itern if tbstep is None else tbstep
117 | self.tensorboard_log(
118 | tbstep, mean, mode='train',
119 | itern=itern, epochn=epochn, lr=lr)
120 |
121 | loss = mean.pop('Loss')
122 | mean_info = ['Loss:{:.4f}'.format(loss)] + [
123 | '{}:{:.4f}'.format(itemn, mean[itemn]) \
124 | for itemn in sorted(mean.keys()) \
125 | if itemn.find('loss') == 0
126 | ]
127 | console_info += mean_info
128 | console_info.append('Time:{:.2f}s'.format(
129 | timeit.default_timer() - self.time_check))
130 | return ' , '.join(console_info)
131 |
132 | def clear(self):
133 | self.sum = {}
134 | self.cnt = {}
135 | self.time_check = timeit.default_timer()
136 |
137 | def tensorboard_close(self):
138 | if self.tb is not None:
139 | self.tb.close()
140 |
141 | # ----- also include some small utils -----
142 |
143 | def torch_to_numpy(*argv):
144 | if len(argv) > 1:
145 | data = list(argv)
146 | else:
147 | data = argv[0]
148 |
149 | if isinstance(data, torch.Tensor):
150 | return data.to('cpu').detach().numpy()
151 |
152 | elif isinstance(data, (list, tuple)):
153 | out = []
154 | for di in data:
155 | out.append(torch_to_numpy(di))
156 | return out
157 |
158 | elif isinstance(data, dict):
159 | out = {}
160 | for ni, di in data.items():
161 | out[ni] = torch_to_numpy(di)
162 | return out
163 |
164 | else:
165 | return data
166 |
--------------------------------------------------------------------------------
/lib/model_zoo/__init__.py:
--------------------------------------------------------------------------------
1 | from .common.get_model import get_model
2 | from .common.get_optimizer import get_optimizer
3 | from .common.get_scheduler import get_scheduler
4 | from .common.utils import get_unit
5 |
--------------------------------------------------------------------------------
/lib/model_zoo/autokl.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from contextlib import contextmanager
5 | from lib.model_zoo.common.get_model import get_model, register
6 |
7 | # from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
8 |
9 | from .autokl_modules import Encoder, Decoder
10 | from .distributions import DiagonalGaussianDistribution
11 |
12 | from .autokl_utils import LPIPSWithDiscriminator
13 |
14 | @register('autoencoderkl')
15 | class AutoencoderKL(nn.Module):
16 | def __init__(self,
17 | ddconfig,
18 | lossconfig,
19 | embed_dim,):
20 | super().__init__()
21 | self.encoder = Encoder(**ddconfig)
22 | self.decoder = Decoder(**ddconfig)
23 | if lossconfig is not None:
24 | self.loss = LPIPSWithDiscriminator(**lossconfig)
25 | assert ddconfig["double_z"]
26 | self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
27 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
28 | self.embed_dim = embed_dim
29 |
30 | @torch.no_grad()
31 | def encode(self, x, out_posterior=False):
32 | return self.encode_trainable(x, out_posterior)
33 |
34 | def encode_trainable(self, x, out_posterior=False):
35 | x = x*2-1
36 | h = self.encoder(x)
37 | moments = self.quant_conv(h)
38 | posterior = DiagonalGaussianDistribution(moments)
39 | if out_posterior:
40 | return posterior
41 | else:
42 | return posterior.sample()
43 |
44 | @torch.no_grad()
45 | def decode(self, z):
46 | dec = self.decode_trainable(z)
47 | dec = torch.clamp(dec, 0, 1)
48 | return dec
49 |
50 | def decode_trainable(self, z):
51 | z = self.post_quant_conv(z)
52 | dec = self.decoder(z)
53 | dec = (dec+1)/2
54 | return dec
55 |
56 | def apply_model(self, input, sample_posterior=True):
57 | posterior = self.encode_trainable(input, out_posterior=True)
58 | if sample_posterior:
59 | z = posterior.sample()
60 | else:
61 | z = posterior.mode()
62 | dec = self.decode_trainable(z)
63 | return dec, posterior
64 |
65 | def get_input(self, batch, k):
66 | x = batch[k]
67 | if len(x.shape) == 3:
68 | x = x[..., None]
69 | x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
70 | return x
71 |
72 | def forward(self, x, optimizer_idx, global_step):
73 | reconstructions, posterior = self.apply_model(x)
74 |
75 | if optimizer_idx == 0:
76 | # train encoder+decoder+logvar
77 | aeloss, log_dict_ae = self.loss(x, reconstructions, posterior, optimizer_idx, global_step=global_step,
78 | last_layer=self.get_last_layer(), split="train")
79 | return aeloss, log_dict_ae
80 |
81 | if optimizer_idx == 1:
82 | # train the discriminator
83 | discloss, log_dict_disc = self.loss(x, reconstructions, posterior, optimizer_idx, global_step=global_step,
84 | last_layer=self.get_last_layer(), split="train")
85 |
86 | return discloss, log_dict_disc
87 |
88 | def validation_step(self, batch, batch_idx):
89 | inputs = self.get_input(batch, self.image_key)
90 | reconstructions, posterior = self(inputs)
91 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
92 | last_layer=self.get_last_layer(), split="val")
93 |
94 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
95 | last_layer=self.get_last_layer(), split="val")
96 |
97 | self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
98 | self.log_dict(log_dict_ae)
99 | self.log_dict(log_dict_disc)
100 | return self.log_dict
101 |
102 | def configure_optimizers(self):
103 | lr = self.learning_rate
104 | opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
105 | list(self.decoder.parameters())+
106 | list(self.quant_conv.parameters())+
107 | list(self.post_quant_conv.parameters()),
108 | lr=lr, betas=(0.5, 0.9))
109 | opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
110 | lr=lr, betas=(0.5, 0.9))
111 | return [opt_ae, opt_disc], []
112 |
113 | def get_last_layer(self):
114 | return self.decoder.conv_out.weight
115 |
116 | @torch.no_grad()
117 | def log_images(self, batch, only_inputs=False, **kwargs):
118 | log = dict()
119 | x = self.get_input(batch, self.image_key)
120 | x = x.to(self.device)
121 | if not only_inputs:
122 | xrec, posterior = self(x)
123 | if x.shape[1] > 3:
124 | # colorize with random projection
125 | assert xrec.shape[1] > 3
126 | x = self.to_rgb(x)
127 | xrec = self.to_rgb(xrec)
128 | log["samples"] = self.decode(torch.randn_like(posterior.sample()))
129 | log["reconstructions"] = xrec
130 | log["inputs"] = x
131 | return log
132 |
133 | def to_rgb(self, x):
134 | assert self.image_key == "segmentation"
135 | if not hasattr(self, "colorize"):
136 | self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
137 | x = F.conv2d(x, weight=self.colorize)
138 | x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
139 | return x
140 |
141 | @register('autoencoderkl_customnorm')
142 | class AutoencoderKL_CustomNorm(AutoencoderKL):
143 | def __init__(self, *args, **kwargs):
144 | super().__init__(*args, **kwargs)
145 | self.mean = torch.Tensor([0.48145466, 0.4578275, 0.40821073])
146 | self.std = torch.Tensor([0.26862954, 0.26130258, 0.27577711])
147 |
148 | def encode_trainable(self, x, out_posterior=False):
149 |         m = self.mean[None, :, None, None].to(x.device).to(x.dtype)
150 |         s = self.std[None, :, None, None].to(x.device).to(x.dtype)
151 | x = (x-m)/s
152 | h = self.encoder(x)
153 | moments = self.quant_conv(h)
154 | posterior = DiagonalGaussianDistribution(moments)
155 | if out_posterior:
156 | return posterior
157 | else:
158 | return posterior.sample()
159 |
160 | def decode_trainable(self, z):
161 | m = self.mean[None, :, None, None].to(z.device).to(z.dtype)
162 | s = self.std[None, :, None, None].to(z.device).to(z.dtype)
163 | z = self.post_quant_conv(z)
164 | dec = self.decoder(z)
165 | dec = (dec+1)/2
166 | return dec
167 |
--------------------------------------------------------------------------------
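Usage note (not part of the file above): a minimal sketch of a VAE round trip, assuming the `ddconfig` from `configs/model/autokl.yaml` and inputs in `[0, 1]`:

```
import torch
from lib.model_zoo.autokl import AutoencoderKL

ddconfig = dict(double_z=True, z_channels=4, resolution=256, in_channels=3,
                out_ch=3, ch=128, ch_mult=[1, 2, 4, 4], num_res_blocks=2,
                attn_resolutions=[], dropout=0.0)
vae = AutoencoderKL(ddconfig=ddconfig, lossconfig=None, embed_dim=4).eval()

img = torch.rand(1, 3, 256, 256)   # inputs are expected in [0, 1]
z = vae.encode(img)                # rescaled to [-1, 1] internally, returns a sampled latent
rec = vae.decode(z)                # decoded back to [0, 1] (clamped)
```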
/lib/model_zoo/common/get_model.py:
--------------------------------------------------------------------------------
1 | from email.policy import strict
2 | import torch
3 | import torchvision.models
4 | import os.path as osp
5 | import copy
6 | from ...log_service import print_log
7 | from .utils import \
8 | get_total_param, get_total_param_sum, \
9 | get_unit
10 |
11 | # def load_state_dict(net, model_path):
12 | # if isinstance(net, dict):
13 | # for ni, neti in net.items():
14 | # paras = torch.load(model_path[ni], map_location=torch.device('cpu'))
15 | # new_paras = neti.state_dict()
16 | # new_paras.update(paras)
17 | # neti.load_state_dict(new_paras)
18 | # else:
19 | # paras = torch.load(model_path, map_location=torch.device('cpu'))
20 | # new_paras = net.state_dict()
21 | # new_paras.update(paras)
22 | # net.load_state_dict(new_paras)
23 | # return
24 |
25 | # def save_state_dict(net, path):
26 | # if isinstance(net, (torch.nn.DataParallel,
27 | # torch.nn.parallel.DistributedDataParallel)):
28 | # torch.save(net.module.state_dict(), path)
29 | # else:
30 | # torch.save(net.state_dict(), path)
31 |
32 | def singleton(class_):
33 | instances = {}
34 | def getinstance(*args, **kwargs):
35 | if class_ not in instances:
36 | instances[class_] = class_(*args, **kwargs)
37 | return instances[class_]
38 | return getinstance
39 |
40 | def preprocess_model_args(args):
41 | # If args has layer_units, get the corresponding
42 | # units.
43 |     # If args has backbone, get the backbone model.
44 | args = copy.deepcopy(args)
45 | if 'layer_units' in args:
46 | layer_units = [
47 | get_unit()(i) for i in args.layer_units
48 | ]
49 | args.layer_units = layer_units
50 | if 'backbone' in args:
51 | args.backbone = get_model()(args.backbone)
52 | return args
53 |
54 | @singleton
55 | class get_model(object):
56 | def __init__(self):
57 | self.model = {}
58 |
59 | def register(self, model, name):
60 | self.model[name] = model
61 |
62 | def __call__(self, cfg, verbose=True):
63 | """
64 | Construct model based on the config.
65 | """
66 | if cfg is None:
67 | return None
68 |
69 | t = cfg.type
70 |
71 | # the register is in each file
72 | if t.find('pfd')==0:
73 | from .. import pfd
74 | elif t=='autoencoderkl':
75 | from .. import autokl
76 | elif (t.find('clip')==0) or (t.find('openclip')==0):
77 | from .. import clip
78 | elif t.find('openai_unet')==0:
79 | from .. import openaimodel
80 | elif t.find('controlnet')==0:
81 | from .. import controlnet
82 | elif t.find('seecoder')==0:
83 | from .. import seecoder
84 | elif t.find('swin')==0:
85 | from .. import swin
86 |
87 | args = preprocess_model_args(cfg.args)
88 | net = self.model[t](**args)
89 |
90 | pretrained = cfg.get('pretrained', None)
91 | if pretrained is None: # backward compatible
92 | pretrained = cfg.get('pth', None)
93 | map_location = cfg.get('map_location', 'cpu')
94 | strict_sd = cfg.get('strict_sd', True)
95 |
96 | if pretrained is not None:
97 | if osp.splitext(pretrained)[1] == '.pth':
98 | sd = torch.load(pretrained, map_location=map_location)
99 | elif osp.splitext(pretrained)[1] == '.ckpt':
100 | sd = torch.load(pretrained, map_location=map_location)['state_dict']
101 | elif osp.splitext(pretrained)[1] == '.safetensors':
102 | from safetensors.torch import load_file
103 | from collections import OrderedDict
104 | sd = load_file(pretrained, map_location)
105 | sd = OrderedDict(sd)
106 | net.load_state_dict(sd, strict=strict_sd)
107 | if verbose:
108 | print_log('Load model from [{}] strict [{}].'.format(pretrained, strict_sd))
109 |
110 | # display param_num & param_sum
111 | if verbose:
112 | print_log(
113 | 'Load {} with total {} parameters,'
114 | '{:.3f} parameter sum.'.format(
115 | t,
116 | get_total_param(net),
117 | get_total_param_sum(net) ))
118 | return net
119 |
120 | def register(name):
121 | def wrapper(class_):
122 | get_model().register(class_, name)
123 | return class_
124 | return wrapper
125 |
--------------------------------------------------------------------------------
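Usage note (not part of the file above): a minimal sketch of the registry pattern, using a hypothetical `toy_mlp` entry to show how `type` selects the registered class and `args` are forwarded to its constructor (attribute-style config dicts such as EasyDict assumed):

```
import torch.nn as nn
from easydict import EasyDict as edict                  # assumption
from lib.model_zoo.common.get_model import get_model, register

@register('toy_mlp')                                    # hypothetical name, for illustration only
class ToyMLP(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.fc = nn.Linear(in_dim, out_dim)

cfg = edict({'type': 'toy_mlp', 'args': {'in_dim': 4, 'out_dim': 2}})
net = get_model()(cfg, verbose=False)                   # looks up 'toy_mlp', calls ToyMLP(**cfg.args)
```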
/lib/model_zoo/common/get_optimizer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.optim as optim
3 | import numpy as np
4 | import itertools
5 |
6 | def singleton(class_):
7 | instances = {}
8 | def getinstance(*args, **kwargs):
9 | if class_ not in instances:
10 | instances[class_] = class_(*args, **kwargs)
11 | return instances[class_]
12 | return getinstance
13 |
14 | class get_optimizer(object):
15 | def __init__(self):
16 | self.optimizer = {}
17 | self.register(optim.SGD, 'sgd')
18 | self.register(optim.Adam, 'adam')
19 | self.register(optim.AdamW, 'adamw')
20 |
21 | def register(self, optim, name):
22 | self.optimizer[name] = optim
23 |
24 | def __call__(self, net, cfg):
25 | if cfg is None:
26 | return None
27 | t = cfg.type
28 | if isinstance(net, (torch.nn.DataParallel,
29 | torch.nn.parallel.DistributedDataParallel)):
30 | netm = net.module
31 | else:
32 | netm = net
33 | pg = getattr(netm, 'parameter_group', None)
34 |
35 | if pg is not None:
36 | params = []
37 | for group_name, module_or_para in pg.items():
38 | if not isinstance(module_or_para, list):
39 | module_or_para = [module_or_para]
40 |
41 | grouped_params = [mi.parameters() if isinstance(mi, torch.nn.Module) else [mi] for mi in module_or_para]
42 | grouped_params = itertools.chain(*grouped_params)
43 | pg_dict = {'params':grouped_params, 'name':group_name}
44 | params.append(pg_dict)
45 | else:
46 | params = net.parameters()
47 | return self.optimizer[t](params, lr=0, **cfg.args)
48 |
--------------------------------------------------------------------------------
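Usage note (not part of the file above): a minimal sketch, assuming attribute-style config dicts. If the network defines a `parameter_group` dict, each named entry becomes its own torch param group; the `lr=0` start is filled in later by a scheduler's `set_lr`:

```
import torch.nn as nn
from easydict import EasyDict as edict                  # assumption
from lib.model_zoo.common.get_optimizer import get_optimizer

net = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))
net.parameter_group = {'backbone': net[0], 'head': net[1]}   # optional named groups

cfg = edict({'type': 'adamw', 'args': {'weight_decay': 0.01}})
optim = get_optimizer()(net, cfg)                       # AdamW with one param group per name
```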
/lib/model_zoo/common/get_scheduler.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.optim as optim
3 | import numpy as np
4 | import copy
5 | from ... import sync
6 | from ...cfg_holder import cfg_unique_holder as cfguh
7 |
8 | def singleton(class_):
9 | instances = {}
10 | def getinstance(*args, **kwargs):
11 | if class_ not in instances:
12 | instances[class_] = class_(*args, **kwargs)
13 | return instances[class_]
14 | return getinstance
15 |
16 | @singleton
17 | class get_scheduler(object):
18 | def __init__(self):
19 | self.lr_scheduler = {}
20 |
21 | def register(self, lrsf, name):
22 | self.lr_scheduler[name] = lrsf
23 |
24 | def __call__(self, cfg):
25 | if cfg is None:
26 | return None
27 | if isinstance(cfg, list):
28 | schedulers = []
29 | for ci in cfg:
30 | t = ci.type
31 | schedulers.append(
32 | self.lr_scheduler[t](**ci.args))
33 | if len(schedulers) == 0:
34 | raise ValueError
35 | else:
36 | return compose_scheduler(schedulers)
37 | t = cfg.type
38 | return self.lr_scheduler[t](**cfg.args)
39 |
40 |
41 | def register(name):
42 | def wrapper(class_):
43 | get_scheduler().register(class_, name)
44 | return class_
45 | return wrapper
46 |
47 | class template_scheduler(object):
48 | def __init__(self, step):
49 | self.step = step
50 |
51 | def __getitem__(self, idx):
52 | raise ValueError
53 |
54 | def set_lr(self, optim, new_lr, pg_lrscale=None):
55 | """
56 | Set Each parameter_groups in optim with new_lr
57 | New_lr can be find according to the idx.
58 | pg_lrscale tells how to scale each pg.
59 | """
60 | # new_lr = self.__getitem__(idx)
61 | pg_lrscale = copy.deepcopy(pg_lrscale)
62 | for pg in optim.param_groups:
63 | if pg_lrscale is None:
64 | pg['lr'] = new_lr
65 | else:
66 | pg['lr'] = new_lr * pg_lrscale.pop(pg['name'])
67 | assert (pg_lrscale is None) or (len(pg_lrscale)==0), \
68 | "pg_lrscale doesn't match pg"
69 |
70 | @register('constant')
71 | class constant_scheduler(template_scheduler):
72 | def __init__(self, lr, step):
73 | super().__init__(step)
74 | self.lr = lr
75 |
76 | def __getitem__(self, idx):
77 | if idx >= self.step:
78 | raise ValueError
79 | return self.lr
80 |
81 | @register('poly')
82 | class poly_scheduler(template_scheduler):
83 | def __init__(self, start_lr, end_lr, power, step):
84 | super().__init__(step)
85 | self.start_lr = start_lr
86 | self.end_lr = end_lr
87 | self.power = power
88 |
89 | def __getitem__(self, idx):
90 | if idx >= self.step:
91 | raise ValueError
92 | a, b = self.start_lr, self.end_lr
93 | p, n = self.power, self.step
94 | return b + (a-b)*((1-idx/n)**p)
95 |
96 | @register('linear')
97 | class linear_scheduler(template_scheduler):
98 | def __init__(self, start_lr, end_lr, step):
99 | super().__init__(step)
100 | self.start_lr = start_lr
101 | self.end_lr = end_lr
102 |
103 | def __getitem__(self, idx):
104 | if idx >= self.step:
105 | raise ValueError
106 | a, b, n = self.start_lr, self.end_lr, self.step
107 | return b + (a-b)*(1-idx/n)
108 |
109 | @register('multistage')
110 | class multistage_scheduler(template_scheduler):
111 | def __init__(self, start_lr, milestones, gamma, step):
112 | super().__init__(step)
113 | self.start_lr = start_lr
114 | m = [0] + milestones + [step]
115 | lr_iter = start_lr
116 | self.lr = []
117 | for ms, me in zip(m[0:-1], m[1:]):
118 | for _ in range(ms, me):
119 | self.lr.append(lr_iter)
120 | lr_iter *= gamma
121 |
122 | def __getitem__(self, idx):
123 | if idx >= self.step:
124 | raise ValueError
125 | return self.lr[idx]
126 |
127 | class compose_scheduler(template_scheduler):
128 | def __init__(self, schedulers):
129 | self.schedulers = schedulers
130 | self.step = [si.step for si in schedulers]
131 | self.step_milestone = []
132 | acc = 0
133 | for i in self.step:
134 | acc += i
135 | self.step_milestone.append(acc)
136 | self.step = sum(self.step)
137 |
138 | def __getitem__(self, idx):
139 | if idx >= self.step:
140 | raise ValueError
141 |         ms = [0] + self.step_milestone
142 |         for i, (mi, mj) in enumerate(zip(ms[:-1], ms[1:])):
143 |             if mi <= idx < mj:
144 |                 return self.schedulers[i][idx-mi]
145 | raise ValueError
146 |
147 | ####################
148 | # lambda schedular #
149 | ####################
150 |
151 | class LambdaWarmUpCosineScheduler(template_scheduler):
152 | """
153 | note: use with a base_lr of 1.0
154 | """
155 | def __init__(self,
156 | base_lr,
157 | warm_up_steps,
158 | lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
159 | cfgt = cfguh().cfg.train
160 | bs = cfgt.batch_size
161 | if 'gradacc_every' not in cfgt:
162 |             print('Warning: gradacc_every not found in the config, using 1 as default.')
163 | acc = cfgt.get('gradacc_every', 1)
164 | self.lr_multi = base_lr * bs * acc
165 | self.lr_warm_up_steps = warm_up_steps
166 | self.lr_start = lr_start
167 | self.lr_min = lr_min
168 | self.lr_max = lr_max
169 | self.lr_max_decay_steps = max_decay_steps
170 | self.last_lr = 0.
171 | self.verbosity_interval = verbosity_interval
172 |
173 | def schedule(self, n):
174 | if self.verbosity_interval > 0:
175 | if n % self.verbosity_interval == 0:
176 | print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
177 | if n < self.lr_warm_up_steps:
178 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
179 | self.last_lr = lr
180 | return lr
181 | else:
182 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
183 | t = min(t, 1.0)
184 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
185 | 1 + np.cos(t * np.pi))
186 | self.last_lr = lr
187 | return lr
188 |
189 | def __getitem__(self, idx):
190 | return self.schedule(idx) * self.lr_multi
191 |
192 | class LambdaWarmUpCosineScheduler2(template_scheduler):
193 | """
194 | supports repeated iterations, configurable via lists
195 | note: use with a base_lr of 1.0.
196 | """
197 | def __init__(self,
198 | base_lr,
199 | warm_up_steps,
200 | f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
201 | cfgt = cfguh().cfg.train
202 | # bs = cfgt.batch_size
203 | # if 'gradacc_every' not in cfgt:
204 | # print('Warning, gradacc_every is not found in xml, use 1 as default.')
205 | # acc = cfgt.get('gradacc_every', 1)
206 | # self.lr_multi = base_lr * bs * acc
207 | self.lr_multi = base_lr
208 | assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
209 | self.lr_warm_up_steps = warm_up_steps
210 | self.f_start = f_start
211 | self.f_min = f_min
212 | self.f_max = f_max
213 | self.cycle_lengths = cycle_lengths
214 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
215 | self.last_f = 0.
216 | self.verbosity_interval = verbosity_interval
217 |
218 | def find_in_interval(self, n):
219 | interval = 0
220 | for cl in self.cum_cycles[1:]:
221 | if n <= cl:
222 | return interval
223 | interval += 1
224 |
225 | def schedule(self, n):
226 | cycle = self.find_in_interval(n)
227 | n = n - self.cum_cycles[cycle]
228 | if self.verbosity_interval > 0:
229 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
230 | f"current cycle {cycle}")
231 | if n < self.lr_warm_up_steps[cycle]:
232 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
233 | self.last_f = f
234 | return f
235 | else:
236 | t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
237 | t = min(t, 1.0)
238 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
239 | 1 + np.cos(t * np.pi))
240 | self.last_f = f
241 | return f
242 |
243 | def __getitem__(self, idx):
244 | return self.schedule(idx) * self.lr_multi
245 |
246 | @register('stable_diffusion_linear')
247 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
248 | def schedule(self, n):
249 | cycle = self.find_in_interval(n)
250 | n = n - self.cum_cycles[cycle]
251 | if self.verbosity_interval > 0:
252 | if n % self.verbosity_interval == 0:
253 | print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
254 | f"current cycle {cycle}")
255 | if n < self.lr_warm_up_steps[cycle]:
256 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
257 | self.last_f = f
258 | return f
259 | else:
260 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
261 | self.last_f = f
262 | return f
--------------------------------------------------------------------------------
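Usage note (not part of the file above): schedulers here are step-indexed lookup tables; `sched[it]` returns the learning rate for iteration `it`, and `set_lr` writes it into the optimizer's param groups. A minimal sketch with the registered `poly` scheduler (attribute-style config dicts assumed):

```
import torch
from easydict import EasyDict as edict                  # assumption
from lib.model_zoo.common.get_scheduler import get_scheduler

cfg = edict({'type': 'poly',
             'args': {'start_lr': 1e-4, 'end_lr': 1e-6, 'power': 0.9, 'step': 1000}})
sched = get_scheduler()(cfg)                            # -> poly_scheduler instance

optim = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=1e-4)
for it in range(1000):
    sched.set_lr(optim, sched[it])                      # look up this step's lr and apply it
```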
/lib/model_zoo/common/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | import copy
6 | import functools
7 | import itertools
8 |
9 | import matplotlib.pyplot as plt
10 |
11 | ########
12 | # unit #
13 | ########
14 |
15 | def singleton(class_):
16 | instances = {}
17 | def getinstance(*args, **kwargs):
18 | if class_ not in instances:
19 | instances[class_] = class_(*args, **kwargs)
20 | return instances[class_]
21 | return getinstance
22 |
23 | def str2value(v):
24 | v = v.strip()
25 | try:
26 | return int(v)
27 | except:
28 | pass
29 | try:
30 | return float(v)
31 | except:
32 | pass
33 | if v in ('True', 'true'):
34 | return True
35 | elif v in ('False', 'false'):
36 | return False
37 | else:
38 | return v
39 |
40 | @singleton
41 | class get_unit(object):
42 | def __init__(self):
43 | self.unit = {}
44 | self.register('none', None)
45 |
46 | # general convolution
47 | self.register('conv' , nn.Conv2d)
48 | self.register('bn' , nn.BatchNorm2d)
49 | self.register('relu' , nn.ReLU)
50 | self.register('relu6' , nn.ReLU6)
51 | self.register('lrelu' , nn.LeakyReLU)
52 | self.register('dropout' , nn.Dropout)
53 | self.register('dropout2d', nn.Dropout2d)
54 | self.register('sine', Sine)
55 | self.register('relusine', ReLUSine)
56 |
57 | def register(self,
58 | name,
59 | unitf,):
60 |
61 | self.unit[name] = unitf
62 |
63 | def __call__(self, name):
64 | if name is None:
65 | return None
66 | i = name.find('(')
67 | i = len(name) if i==-1 else i
68 | t = name[:i]
69 | f = self.unit[t]
70 | args = name[i:].strip('()')
71 | if len(args) == 0:
72 | args = {}
73 | return f
74 | else:
75 | args = args.split('=')
76 | args = [[','.join(i.split(',')[:-1]), i.split(',')[-1]] for i in args]
77 | args = list(itertools.chain.from_iterable(args))
78 | args = [i.strip() for i in args if len(i)>0]
79 | kwargs = {}
80 | for k, v in zip(args[::2], args[1::2]):
81 | if v[0]=='(' and v[-1]==')':
82 | kwargs[k] = tuple([str2value(i) for i in v.strip('()').split(',')])
83 | elif v[0]=='[' and v[-1]==']':
84 | kwargs[k] = [str2value(i) for i in v.strip('[]').split(',')]
85 | else:
86 | kwargs[k] = str2value(v)
87 | return functools.partial(f, **kwargs)
88 |
89 | def register(name):
90 | def wrapper(class_):
91 | get_unit().register(name, class_)
92 | return class_
93 | return wrapper
94 |
95 | class Sine(object):
96 | def __init__(self, freq, gain=1):
97 | self.freq = freq
98 | self.gain = gain
99 | self.repr = 'sine(freq={}, gain={})'.format(freq, gain)
100 |
101 | def __call__(self, x, gain=1):
102 | act_gain = self.gain * gain
103 | return torch.sin(self.freq * x) * act_gain
104 |
105 | def __repr__(self,):
106 | return self.repr
107 |
108 | class ReLUSine(nn.Module):
109 |     def __init__(self):
110 | super().__init__()
111 |
112 | def forward(self, input):
113 | a = torch.sin(30 * input)
114 | b = nn.ReLU(inplace=False)(input)
115 | return a+b
116 |
117 | @register('lrelu_agc')
118 | # class lrelu_agc(nn.Module):
119 | class lrelu_agc(object):
120 | """
121 | The lrelu layer with alpha, gain and clamp
122 | """
123 | def __init__(self, alpha=0.1, gain=1, clamp=None):
124 | # super().__init__()
125 | self.alpha = alpha
126 | if gain == 'sqrt_2':
127 | self.gain = np.sqrt(2)
128 | else:
129 | self.gain = gain
130 | self.clamp = clamp
131 | self.repr = 'lrelu_agc(alpha={}, gain={}, clamp={})'.format(
132 | alpha, gain, clamp)
133 |
134 | # def forward(self, x, gain=1):
135 | def __call__(self, x, gain=1):
136 | x = F.leaky_relu(x, negative_slope=self.alpha, inplace=True)
137 | act_gain = self.gain * gain
138 | act_clamp = self.clamp * gain if self.clamp is not None else None
139 | if act_gain != 1:
140 | x = x * act_gain
141 | if act_clamp is not None:
142 | x = x.clamp(-act_clamp, act_clamp)
143 | return x
144 |
145 | def __repr__(self,):
146 | return self.repr
147 |
148 | ####################
149 | # spatial encoding #
150 | ####################
151 |
152 | @register('se')
153 | class SpatialEncoding(nn.Module):
154 | def __init__(self,
155 | in_dim,
156 | out_dim,
157 | sigma = 6,
158 | cat_input=True,
159 | require_grad=False,):
160 |
161 | super().__init__()
162 | assert out_dim % (2*in_dim) == 0, "dimension must be dividable"
163 |
164 | n = out_dim // 2 // in_dim
165 | m = 2**np.linspace(0, sigma, n)
166 | m = np.stack([m] + [np.zeros_like(m)]*(in_dim-1), axis=-1)
167 | m = np.concatenate([np.roll(m, i, axis=-1) for i in range(in_dim)], axis=0)
168 | self.emb = torch.FloatTensor(m)
169 | if require_grad:
170 | self.emb = nn.Parameter(self.emb, requires_grad=True)
171 | self.in_dim = in_dim
172 | self.out_dim = out_dim
173 | self.sigma = sigma
174 | self.cat_input = cat_input
175 | self.require_grad = require_grad
176 |
177 | def forward(self, x, format='[n x c]'):
178 | """
179 | Args:
180 | x: [n x m1],
181 | m1 usually is 2
182 | Outputs:
183 | y: [n x m2]
184 | m2 dimention number
185 | """
186 | if format == '[bs x c x 2D]':
187 | xshape = x.shape
188 | x = x.permute(0, 2, 3, 1).contiguous()
189 | x = x.view(-1, x.size(-1))
190 | elif format == '[n x c]':
191 | pass
192 | else:
193 | raise ValueError
194 |
195 | if not self.require_grad:
196 | self.emb = self.emb.to(x.device)
197 | y = torch.mm(x, self.emb.T)
198 | if self.cat_input:
199 | z = torch.cat([x, torch.sin(y), torch.cos(y)], dim=-1)
200 | else:
201 | z = torch.cat([torch.sin(y), torch.cos(y)], dim=-1)
202 |
203 | if format == '[bs x c x 2D]':
204 | z = z.view(xshape[0], xshape[2], xshape[3], -1)
205 | z = z.permute(0, 3, 1, 2).contiguous()
206 | return z
207 |
208 | def extra_repr(self):
209 | outstr = 'SpatialEncoding (in={}, out={}, sigma={}, cat_input={}, require_grad={})'.format(
210 | self.in_dim, self.out_dim, self.sigma, self.cat_input, self.require_grad)
211 | return outstr
212 |
213 | @register('rffe')
214 | class RFFEncoding(SpatialEncoding):
215 | """
216 | Random Fourier Features
217 | """
218 | def __init__(self,
219 | in_dim,
220 | out_dim,
221 | sigma = 6,
222 | cat_input=True,
223 | require_grad=False,):
224 |
225 | super().__init__(in_dim, out_dim, sigma, cat_input, require_grad)
226 | n = out_dim // 2
227 | m = np.random.normal(0, sigma, size=(n, in_dim))
228 | self.emb = torch.FloatTensor(m)
229 | if require_grad:
230 | self.emb = nn.Parameter(self.emb, requires_grad=True)
231 |
232 | def extra_repr(self):
233 | outstr = 'RFFEncoding (in={}, out={}, sigma={}, cat_input={}, require_grad={})'.format(
234 | self.in_dim, self.out_dim, self.sigma, self.cat_input, self.require_grad)
235 | return outstr
236 |
237 | ##########
238 | # helper #
239 | ##########
240 |
241 | def freeze(net):
242 | for m in net.modules():
243 | if isinstance(m, (
244 | nn.BatchNorm2d,
245 | nn.SyncBatchNorm,)):
246 | # inplace_abn not supported
247 | m.eval()
248 | for pi in net.parameters():
249 | pi.requires_grad = False
250 | return net
251 |
252 | def common_init(m):
253 | if isinstance(m, (
254 | nn.Conv2d,
255 | nn.ConvTranspose2d,)):
256 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
257 | if m.bias is not None:
258 | nn.init.constant_(m.bias, 0)
259 | elif isinstance(m, (
260 | nn.BatchNorm2d,
261 | nn.SyncBatchNorm,)):
262 | nn.init.constant_(m.weight, 1)
263 | nn.init.constant_(m.bias, 0)
264 | else:
265 | pass
266 |
267 | def init_module(module):
268 | """
269 | Args:
270 | module: [nn.module] list or nn.module
271 | a list of module to be initialized.
272 | """
273 | if isinstance(module, (list, tuple)):
274 | module = list(module)
275 | else:
276 | module = [module]
277 |
278 | for mi in module:
279 | for mii in mi.modules():
280 | common_init(mii)
281 |
282 | def get_total_param(net):
283 | if getattr(net, 'parameters', None) is None:
284 | return 0
285 | return sum(p.numel() for p in net.parameters())
286 |
287 | def get_total_param_sum(net):
288 | if getattr(net, 'parameters', None) is None:
289 | return 0
290 | with torch.no_grad():
291 | s = sum(p.cpu().detach().numpy().sum().item() for p in net.parameters())
292 | return s
293 |
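294 | # Editor's note -- illustrative sketch only, not part of the original file.
295 | # How the string spec in get_unit.__call__ resolves: the text before '(' picks
296 | # the registered unit, the 'key=value' pairs are parsed (str2value is assumed
297 | # to turn '0.2' into a float, as its use above implies) and bound via
298 | # functools.partial.
299 | if __name__ == '__main__':
300 |     relu_cls = get_unit()('relu')                       # no kwargs -> the class itself (nn.ReLU)
301 |     lrelu_fn = get_unit()('lrelu(negative_slope=0.2)')  # -> functools.partial(nn.LeakyReLU, negative_slope=0.2)
302 |     print(relu_cls, lrelu_fn)
303 |     print(lrelu_fn())                                   # LeakyReLU(negative_slope=0.2)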
--------------------------------------------------------------------------------
/lib/model_zoo/controlnet_annotator/canny/__init__.py:
--------------------------------------------------------------------------------
1 | import cv2
2 |
3 |
4 | def apply_canny(img, low_threshold, high_threshold):
5 |     return cv2.Canny(img, low_threshold, high_threshold)
6 |
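7 | # Editor's note -- illustrative usage sketch, not part of the original file.
8 | # apply_canny is a thin wrapper around cv2.Canny: it takes a uint8 image
9 | # (H x W or H x W x C) and returns a single-channel uint8 edge map.
10 | if __name__ == '__main__':
11 |     import numpy as np
12 |     dummy = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
13 |     edges = apply_canny(dummy, low_threshold=100, high_threshold=200)
14 |     print(edges.shape, edges.dtype)  # (64, 64) uint8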
--------------------------------------------------------------------------------
/lib/model_zoo/controlnet_annotator/hed/__init__.py:
--------------------------------------------------------------------------------
1 | # This is an improved version and model of HED edge detection with Apache License, Version 2.0.
2 | # Please use this implementation in your products
3 | # This implementation may produce slightly different results from Saining Xie's official implementations,
4 | # but it generates smoother edges and is more suitable for ControlNet as well as other image-to-image translations.
5 | # Different from official models and other implementations, this is an RGB-input model (rather than BGR)
6 | # and in this way it works better for gradio's RGB protocol
7 |
8 | import os
9 | import cv2
10 | import torch
11 | import numpy as np
12 |
13 | from einops import rearrange
14 | import os
15 |
16 | models_path = 'pretrained/controlnet/preprocess'
17 |
18 | def safe_step(x, step=2):
19 |     y = x.astype(np.float32) * float(step + 1)
20 |     y = y.astype(np.int32).astype(np.float32) / float(step)
21 |     return y
22 |
23 | class DoubleConvBlock(torch.nn.Module):
24 |     def __init__(self, input_channel, output_channel, layer_number):
25 |         super().__init__()
26 |         self.convs = torch.nn.Sequential()
27 |         self.convs.append(torch.nn.Conv2d(in_channels=input_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1))
28 |         for i in range(1, layer_number):
29 |             self.convs.append(torch.nn.Conv2d(in_channels=output_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1))
30 |         self.projection = torch.nn.Conv2d(in_channels=output_channel, out_channels=1, kernel_size=(1, 1), stride=(1, 1), padding=0)
31 |
32 |     def __call__(self, x, down_sampling=False):
33 |         h = x
34 |         if down_sampling:
35 |             h = torch.nn.functional.max_pool2d(h, kernel_size=(2, 2), stride=(2, 2))
36 |         for conv in self.convs:
37 |             h = conv(h)
38 |             h = torch.nn.functional.relu(h)
39 |         return h, self.projection(h)
40 |
41 |
42 | class ControlNetHED_Apache2(torch.nn.Module):
43 |     def __init__(self):
44 |         super().__init__()
45 |         self.norm = torch.nn.Parameter(torch.zeros(size=(1, 3, 1, 1)))
46 |         self.block1 = DoubleConvBlock(input_channel=3, output_channel=64, layer_number=2)
47 |         self.block2 = DoubleConvBlock(input_channel=64, output_channel=128, layer_number=2)
48 |         self.block3 = DoubleConvBlock(input_channel=128, output_channel=256, layer_number=3)
49 |         self.block4 = DoubleConvBlock(input_channel=256, output_channel=512, layer_number=3)
50 |         self.block5 = DoubleConvBlock(input_channel=512, output_channel=512, layer_number=3)
51 |
52 |     def __call__(self, x):
53 |         h = x - self.norm
54 |         h, projection1 = self.block1(h)
55 |         h, projection2 = self.block2(h, down_sampling=True)
56 |         h, projection3 = self.block3(h, down_sampling=True)
57 |         h, projection4 = self.block4(h, down_sampling=True)
58 |         h, projection5 = self.block5(h, down_sampling=True)
59 |         return projection1, projection2, projection3, projection4, projection5
60 |
61 |
62 | netNetwork = None
63 | remote_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/ControlNetHED.pth"
64 | modeldir = os.path.join(models_path, "hed")
65 | old_modeldir = os.path.dirname(os.path.realpath(__file__))
66 |
67 |
68 | def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
69 |     """Load a file from an HTTP URL, downloading the model if necessary.
70 |
71 |     Ref: https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
72 |
73 |     Args:
74 |         url (str): URL to be downloaded.
75 |         model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir.
76 |             Default: None.
77 |         progress (bool): Whether to show the download progress. Default: True.
78 |         file_name (str): The downloaded file name. If None, use the file name in the url. Default: None.
79 |
80 |     Returns:
81 |         str: The path to the downloaded file.
82 |     """
83 |     from torch.hub import download_url_to_file, get_dir
84 |     from urllib.parse import urlparse
85 |     if model_dir is None: # use the pytorch hub_dir
86 |         hub_dir = get_dir()
87 |         model_dir = os.path.join(hub_dir, 'checkpoints')
88 |
89 |     os.makedirs(model_dir, exist_ok=True)
90 |
91 |     parts = urlparse(url)
92 |     filename = os.path.basename(parts.path)
93 |     if file_name is not None:
94 |         filename = file_name
95 |     cached_file = os.path.abspath(os.path.join(model_dir, filename))
96 |     if not os.path.exists(cached_file):
97 |         print(f'Downloading: "{url}" to {cached_file}\n')
98 |         download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
99 |     return cached_file
100 |
101 |
102 | def apply_hed(input_image, is_safe=False, device='cpu'):
103 |     global netNetwork
104 |     if netNetwork is None:
105 |         modelpath = os.path.join(modeldir, "ControlNetHED.pth")
106 |         old_modelpath = os.path.join(old_modeldir, "ControlNetHED.pth")
107 |         if os.path.exists(old_modelpath):
108 |             modelpath = old_modelpath
109 |         elif not os.path.exists(modelpath):
110 |             load_file_from_url(remote_model_path, model_dir=modeldir)
111 |         netNetwork = ControlNetHED_Apache2().to(device)
112 |         netNetwork.load_state_dict(torch.load(modelpath, map_location='cpu'))
113 |         netNetwork.to(device).float().eval()
114 |
115 |     assert input_image.ndim == 3
116 |     H, W, C = input_image.shape
117 |     with torch.no_grad():
118 |         image_hed = torch.from_numpy(input_image.copy()).float().to(device)
119 |         image_hed = rearrange(image_hed, 'h w c -> 1 c h w')
120 |         edges = netNetwork(image_hed)
121 |         edges = [e.detach().cpu().numpy().astype(np.float32)[0, 0] for e in edges]
122 |         edges = [cv2.resize(e, (W, H), interpolation=cv2.INTER_LINEAR) for e in edges]
123 |         edges = np.stack(edges, axis=2)
124 |         edge = 1 / (1 + np.exp(-np.mean(edges, axis=2).astype(np.float64)))
125 |         if is_safe:
126 |             edge = safe_step(edge)
127 |         edge = (edge * 255.0).clip(0, 255).astype(np.uint8)
128 |     return edge
129 |
130 |
131 | def unload_hed_model():
132 |     global netNetwork
133 |     if netNetwork is not None:
134 |         netNetwork.cpu()
135 |
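136 |
137 | # Editor's note -- illustrative usage sketch, not part of the original file.
138 | # apply_hed lazily downloads ControlNetHED.pth into pretrained/controlnet/preprocess/hed
139 | # on first use, caches the network in the module-level `netNetwork`, and returns a
140 | # soft-edge map as a uint8 array with the same spatial size as the input.
141 | if __name__ == '__main__':
142 |     rgb = np.random.randint(0, 256, size=(256, 256, 3), dtype=np.uint8)
143 |     soft_edge = apply_hed(rgb, is_safe=False, device='cpu')
144 |     print(soft_edge.shape, soft_edge.dtype)  # (256, 256) uint8
145 |     unload_hed_model()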
--------------------------------------------------------------------------------
/lib/model_zoo/controlnet_annotator/midas/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Intel ISL (Intel Intelligent Systems Lab)
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/lib/model_zoo/controlnet_annotator/midas/__init__.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import torch
4 |
5 | from einops import rearrange
6 | from .api import MiDaSInference
7 |
8 | model = None
9 |
10 | def unload_midas_model():
11 |     global model
12 |     if model is not None:
13 |         model = model.cpu()
14 |
15 | def apply_midas(input_image, a=np.pi * 2.0, bg_th=0.1, device='cpu'):
16 |     global model
17 |     if model is None:
18 |         model = MiDaSInference(model_type="dpt_hybrid")
19 |     model = model.to(device)
20 |
21 |     assert input_image.ndim == 3
22 |     image_depth = input_image
23 |     with torch.no_grad():
24 |         image_depth = torch.from_numpy(image_depth).float()
25 |         image_depth = image_depth.to(device)
26 |         image_depth = image_depth / 127.5 - 1.0
27 |         image_depth = rearrange(image_depth, 'h w c -> 1 c h w')
28 |         depth = model(image_depth)[0]
29 |
30 |         depth_pt = depth.clone()
31 |         depth_pt -= torch.min(depth_pt)
32 |         depth_pt /= torch.max(depth_pt)
33 |         depth_pt = depth_pt.cpu().numpy()
34 |         depth_image = (depth_pt * 255.0).clip(0, 255).astype(np.uint8)
35 |
36 |         depth_np = depth.cpu().numpy()
37 |         x = cv2.Sobel(depth_np, cv2.CV_32F, 1, 0, ksize=3)
38 |         y = cv2.Sobel(depth_np, cv2.CV_32F, 0, 1, ksize=3)
39 |         z = np.ones_like(x) * a
40 |         x[depth_pt < bg_th] = 0
41 |         y[depth_pt < bg_th] = 0
42 |         normal = np.stack([x, y, z], axis=2)
43 |         normal /= np.sum(normal ** 2.0, axis=2, keepdims=True) ** 0.5
44 |         normal_image = (normal * 127.5 + 127.5).clip(0, 255).astype(np.uint8)[:, :, ::-1]
45 |
46 |     return depth_image, normal_image
47 |
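48 | # Editor's note -- illustrative usage sketch, not part of the original file.
49 | # apply_midas lazily builds a dpt_hybrid MiDaSInference model (the checkpoint is
50 | # downloaded on first use) and returns a uint8 depth map plus a pseudo normal map
51 | # derived from the depth gradients. The input is fed to the DPT network without
52 | # resizing, so use a size the ViT-hybrid backbone accepts (384 x 384 is safe).
53 | if __name__ == '__main__':
54 |     rgb = np.random.randint(0, 256, size=(384, 384, 3), dtype=np.uint8)
55 |     depth_image, normal_image = apply_midas(rgb, device='cpu')
56 |     print(depth_image.shape, normal_image.shape)  # (384, 384) and (384, 384, 3)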
--------------------------------------------------------------------------------
/lib/model_zoo/controlnet_annotator/midas/api.py:
--------------------------------------------------------------------------------
1 | # based on https://github.com/isl-org/MiDaS
2 |
3 | import cv2
4 | import torch
5 | import torch.nn as nn
6 | import os
7 | models_path = 'pretrained/controlnet/preprocess'
8 |
9 | from torchvision.transforms import Compose
10 |
11 | from .midas.dpt_depth import DPTDepthModel
12 | from .midas.midas_net import MidasNet
13 | from .midas.midas_net_custom import MidasNet_small
14 | from .midas.transforms import Resize, NormalizeImage, PrepareForNet
15 |
16 | base_model_path = os.path.join(models_path, "midas")
17 | old_modeldir = os.path.dirname(os.path.realpath(__file__))
18 | remote_model_path = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/dpt_hybrid-midas-501f0c75.pt"
19 |
20 | ISL_PATHS = {
21 | "dpt_large": os.path.join(base_model_path, "dpt_large-midas-2f21e586.pt"),
22 | "dpt_hybrid": os.path.join(base_model_path, "dpt_hybrid-midas-501f0c75.pt"),
23 | "midas_v21": "",
24 | "midas_v21_small": "",
25 | }
26 |
27 | OLD_ISL_PATHS = {
28 | "dpt_large": os.path.join(old_modeldir, "dpt_large-midas-2f21e586.pt"),
29 | "dpt_hybrid": os.path.join(old_modeldir, "dpt_hybrid-midas-501f0c75.pt"),
30 | "midas_v21": "",
31 | "midas_v21_small": "",
32 | }
33 |
34 |
35 | def disabled_train(self, mode=True):
36 | """Overwrite model.train with this function to make sure train/eval mode
37 | does not change anymore."""
38 | return self
39 |
40 |
41 | def load_midas_transform(model_type):
42 | # https://github.com/isl-org/MiDaS/blob/master/run.py
43 | # load transform only
44 | if model_type == "dpt_large": # DPT-Large
45 | net_w, net_h = 384, 384
46 | resize_mode = "minimal"
47 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
48 |
49 | elif model_type == "dpt_hybrid": # DPT-Hybrid
50 | net_w, net_h = 384, 384
51 | resize_mode = "minimal"
52 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
53 |
54 | elif model_type == "midas_v21":
55 | net_w, net_h = 384, 384
56 | resize_mode = "upper_bound"
57 | normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
58 |
59 | elif model_type == "midas_v21_small":
60 | net_w, net_h = 256, 256
61 | resize_mode = "upper_bound"
62 | normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
63 |
64 | else:
65 | assert False, f"model_type '{model_type}' not implemented, use: --model_type large"
66 |
67 | transform = Compose(
68 | [
69 | Resize(
70 | net_w,
71 | net_h,
72 | resize_target=None,
73 | keep_aspect_ratio=True,
74 | ensure_multiple_of=32,
75 | resize_method=resize_mode,
76 | image_interpolation_method=cv2.INTER_CUBIC,
77 | ),
78 | normalization,
79 | PrepareForNet(),
80 | ]
81 | )
82 |
83 | return transform
84 |
85 |
86 | def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
87 |     """Load a file from an HTTP URL, downloading the model if necessary.
88 |
89 | Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
90 |
91 | Args:
92 | url (str): URL to be downloaded.
93 | model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir.
94 | Default: None.
95 | progress (bool): Whether to show the download progress. Default: True.
96 | file_name (str): The downloaded file name. If None, use the file name in the url. Default: None.
97 |
98 | Returns:
99 | str: The path to the downloaded file.
100 | """
101 | from torch.hub import download_url_to_file, get_dir
102 | from urllib.parse import urlparse
103 | if model_dir is None: # use the pytorch hub_dir
104 | hub_dir = get_dir()
105 | model_dir = os.path.join(hub_dir, 'checkpoints')
106 |
107 | os.makedirs(model_dir, exist_ok=True)
108 |
109 | parts = urlparse(url)
110 | filename = os.path.basename(parts.path)
111 | if file_name is not None:
112 | filename = file_name
113 | cached_file = os.path.abspath(os.path.join(model_dir, filename))
114 | if not os.path.exists(cached_file):
115 | print(f'Downloading: "{url}" to {cached_file}\n')
116 | download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
117 | return cached_file
118 |
119 |
120 | def load_model(model_type):
121 | # https://github.com/isl-org/MiDaS/blob/master/run.py
122 | # load network
123 | model_path = ISL_PATHS[model_type]
124 | old_model_path = OLD_ISL_PATHS[model_type]
125 | if model_type == "dpt_large": # DPT-Large
126 | model = DPTDepthModel(
127 | path=model_path,
128 | backbone="vitl16_384",
129 | non_negative=True,
130 | )
131 | net_w, net_h = 384, 384
132 | resize_mode = "minimal"
133 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
134 |
135 | elif model_type == "dpt_hybrid": # DPT-Hybrid
136 | if os.path.exists(old_model_path):
137 | model_path = old_model_path
138 | elif not os.path.exists(model_path):
139 | load_file_from_url(remote_model_path, model_dir=base_model_path)
140 |
141 | model = DPTDepthModel(
142 | path=model_path,
143 | backbone="vitb_rn50_384",
144 | non_negative=True,
145 | )
146 | net_w, net_h = 384, 384
147 | resize_mode = "minimal"
148 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
149 |
150 | elif model_type == "midas_v21":
151 | model = MidasNet(model_path, non_negative=True)
152 | net_w, net_h = 384, 384
153 | resize_mode = "upper_bound"
154 | normalization = NormalizeImage(
155 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
156 | )
157 |
158 | elif model_type == "midas_v21_small":
159 | model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True,
160 | non_negative=True, blocks={'expand': True})
161 | net_w, net_h = 256, 256
162 | resize_mode = "upper_bound"
163 | normalization = NormalizeImage(
164 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
165 | )
166 |
167 | else:
168 | print(f"model_type '{model_type}' not implemented, use: --model_type large")
169 | assert False
170 |
171 | transform = Compose(
172 | [
173 | Resize(
174 | net_w,
175 | net_h,
176 | resize_target=None,
177 | keep_aspect_ratio=True,
178 | ensure_multiple_of=32,
179 | resize_method=resize_mode,
180 | image_interpolation_method=cv2.INTER_CUBIC,
181 | ),
182 | normalization,
183 | PrepareForNet(),
184 | ]
185 | )
186 |
187 | return model.eval(), transform
188 |
189 |
190 | class MiDaSInference(nn.Module):
191 | MODEL_TYPES_TORCH_HUB = [
192 | "DPT_Large",
193 | "DPT_Hybrid",
194 | "MiDaS_small"
195 | ]
196 | MODEL_TYPES_ISL = [
197 | "dpt_large",
198 | "dpt_hybrid",
199 | "midas_v21",
200 | "midas_v21_small",
201 | ]
202 |
203 | def __init__(self, model_type):
204 | super().__init__()
205 | assert (model_type in self.MODEL_TYPES_ISL)
206 | model, _ = load_model(model_type)
207 | self.model = model
208 | self.model.train = disabled_train
209 |
210 | def forward(self, x):
211 | with torch.no_grad():
212 | prediction = self.model(x)
213 | return prediction
214 |
215 |
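216 | # Editor's note -- illustrative sketch, not part of the original file.
217 | # load_midas_transform only builds the preprocessing pipeline (no checkpoint is
218 | # downloaded); it maps {"image": H x W x 3 float in [0, 1]} to a CHW float32
219 | # array whose sides are multiples of 32.
220 | if __name__ == '__main__':
221 |     import numpy as np
222 |     transform = load_midas_transform("dpt_hybrid")
223 |     sample = {"image": np.random.rand(480, 640, 3).astype(np.float32)}
224 |     out = transform(sample)["image"]
225 |     print(out.shape, out.dtype)  # (3, 384, 512) float32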
--------------------------------------------------------------------------------
/lib/model_zoo/controlnet_annotator/midas/midas/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SHI-Labs/Prompt-Free-Diffusion/f4295fc7ad80763e2ab562eb9aab06abb979b278/lib/model_zoo/controlnet_annotator/midas/midas/__init__.py
--------------------------------------------------------------------------------
/lib/model_zoo/controlnet_annotator/midas/midas/base_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class BaseModel(torch.nn.Module):
5 | def load(self, path):
6 | """Load model from file.
7 |
8 | Args:
9 | path (str): file path
10 | """
11 | parameters = torch.load(path, map_location=torch.device('cpu'))
12 |
13 | if "optimizer" in parameters:
14 | parameters = parameters["model"]
15 |
16 | self.load_state_dict(parameters)
17 |
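18 |
19 | # Editor's note -- illustrative sketch, not part of the original file.
20 | # BaseModel.load accepts either a bare state_dict or a training checkpoint that
21 | # stores the weights under "model" alongside an "optimizer" entry.
22 | if __name__ == '__main__':
23 |     import os
24 |     import tempfile
25 |
26 |     class TinyModel(BaseModel):
27 |         def __init__(self):
28 |             super().__init__()
29 |             self.fc = torch.nn.Linear(4, 2)
30 |
31 |     path = os.path.join(tempfile.mkdtemp(), "tiny.pth")
32 |     torch.save({"model": TinyModel().state_dict(), "optimizer": {}}, path)
33 |     TinyModel().load(path)  # unwraps the checkpoint, then calls load_state_dict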
--------------------------------------------------------------------------------
/lib/model_zoo/controlnet_annotator/midas/midas/blocks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from .vit import (
5 | _make_pretrained_vitb_rn50_384,
6 | _make_pretrained_vitl16_384,
7 | _make_pretrained_vitb16_384,
8 | forward_vit,
9 | )
10 |
11 | def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",):
12 | if backbone == "vitl16_384":
13 | pretrained = _make_pretrained_vitl16_384(
14 | use_pretrained, hooks=hooks, use_readout=use_readout
15 | )
16 | scratch = _make_scratch(
17 | [256, 512, 1024, 1024], features, groups=groups, expand=expand
18 | ) # ViT-L/16 - 85.0% Top1 (backbone)
19 | elif backbone == "vitb_rn50_384":
20 | pretrained = _make_pretrained_vitb_rn50_384(
21 | use_pretrained,
22 | hooks=hooks,
23 | use_vit_only=use_vit_only,
24 | use_readout=use_readout,
25 | )
26 | scratch = _make_scratch(
27 | [256, 512, 768, 768], features, groups=groups, expand=expand
28 |         ) # ViT-B/16 + ResNet-50 hybrid (backbone)
29 | elif backbone == "vitb16_384":
30 | pretrained = _make_pretrained_vitb16_384(
31 | use_pretrained, hooks=hooks, use_readout=use_readout
32 | )
33 | scratch = _make_scratch(
34 | [96, 192, 384, 768], features, groups=groups, expand=expand
35 | ) # ViT-B/16 - 84.6% Top1 (backbone)
36 | elif backbone == "resnext101_wsl":
37 | pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
38 |         scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # resnext101_wsl
39 | elif backbone == "efficientnet_lite3":
40 | pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable)
41 | scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3
42 | else:
43 | print(f"Backbone '{backbone}' not implemented")
44 | assert False
45 |
46 | return pretrained, scratch
47 |
48 |
49 | def _make_scratch(in_shape, out_shape, groups=1, expand=False):
50 | scratch = nn.Module()
51 |
52 | out_shape1 = out_shape
53 | out_shape2 = out_shape
54 | out_shape3 = out_shape
55 | out_shape4 = out_shape
56 | if expand==True:
57 | out_shape1 = out_shape
58 | out_shape2 = out_shape*2
59 | out_shape3 = out_shape*4
60 | out_shape4 = out_shape*8
61 |
62 | scratch.layer1_rn = nn.Conv2d(
63 | in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
64 | )
65 | scratch.layer2_rn = nn.Conv2d(
66 | in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
67 | )
68 | scratch.layer3_rn = nn.Conv2d(
69 | in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
70 | )
71 | scratch.layer4_rn = nn.Conv2d(
72 | in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
73 | )
74 |
75 | return scratch
76 |
77 |
78 | def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
79 | efficientnet = torch.hub.load(
80 | "rwightman/gen-efficientnet-pytorch",
81 | "tf_efficientnet_lite3",
82 | pretrained=use_pretrained,
83 | exportable=exportable
84 | )
85 | return _make_efficientnet_backbone(efficientnet)
86 |
87 |
88 | def _make_efficientnet_backbone(effnet):
89 | pretrained = nn.Module()
90 |
91 | pretrained.layer1 = nn.Sequential(
92 | effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2]
93 | )
94 | pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3])
95 | pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5])
96 | pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9])
97 |
98 | return pretrained
99 |
100 |
101 | def _make_resnet_backbone(resnet):
102 | pretrained = nn.Module()
103 | pretrained.layer1 = nn.Sequential(
104 | resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
105 | )
106 |
107 | pretrained.layer2 = resnet.layer2
108 | pretrained.layer3 = resnet.layer3
109 | pretrained.layer4 = resnet.layer4
110 |
111 | return pretrained
112 |
113 |
114 | def _make_pretrained_resnext101_wsl(use_pretrained):
115 | resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
116 | return _make_resnet_backbone(resnet)
117 |
118 |
119 |
120 | class Interpolate(nn.Module):
121 | """Interpolation module.
122 | """
123 |
124 | def __init__(self, scale_factor, mode, align_corners=False):
125 | """Init.
126 |
127 | Args:
128 | scale_factor (float): scaling
129 | mode (str): interpolation mode
130 | """
131 | super(Interpolate, self).__init__()
132 |
133 | self.interp = nn.functional.interpolate
134 | self.scale_factor = scale_factor
135 | self.mode = mode
136 | self.align_corners = align_corners
137 |
138 | def forward(self, x):
139 | """Forward pass.
140 |
141 | Args:
142 | x (tensor): input
143 |
144 | Returns:
145 | tensor: interpolated data
146 | """
147 |
148 | x = self.interp(
149 | x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners
150 | )
151 |
152 | return x
153 |
154 |
155 | class ResidualConvUnit(nn.Module):
156 | """Residual convolution module.
157 | """
158 |
159 | def __init__(self, features):
160 | """Init.
161 |
162 | Args:
163 | features (int): number of features
164 | """
165 | super().__init__()
166 |
167 | self.conv1 = nn.Conv2d(
168 | features, features, kernel_size=3, stride=1, padding=1, bias=True
169 | )
170 |
171 | self.conv2 = nn.Conv2d(
172 | features, features, kernel_size=3, stride=1, padding=1, bias=True
173 | )
174 |
175 | self.relu = nn.ReLU(inplace=True)
176 |
177 | def forward(self, x):
178 | """Forward pass.
179 |
180 | Args:
181 | x (tensor): input
182 |
183 | Returns:
184 | tensor: output
185 | """
186 | out = self.relu(x)
187 | out = self.conv1(out)
188 | out = self.relu(out)
189 | out = self.conv2(out)
190 |
191 | return out + x
192 |
193 |
194 | class FeatureFusionBlock(nn.Module):
195 | """Feature fusion block.
196 | """
197 |
198 | def __init__(self, features):
199 | """Init.
200 |
201 | Args:
202 | features (int): number of features
203 | """
204 | super(FeatureFusionBlock, self).__init__()
205 |
206 | self.resConfUnit1 = ResidualConvUnit(features)
207 | self.resConfUnit2 = ResidualConvUnit(features)
208 |
209 | def forward(self, *xs):
210 | """Forward pass.
211 |
212 | Returns:
213 | tensor: output
214 | """
215 | output = xs[0]
216 |
217 | if len(xs) == 2:
218 | output += self.resConfUnit1(xs[1])
219 |
220 | output = self.resConfUnit2(output)
221 |
222 | output = nn.functional.interpolate(
223 | output, scale_factor=2, mode="bilinear", align_corners=True
224 | )
225 |
226 | return output
227 |
228 |
229 |
230 |
231 | class ResidualConvUnit_custom(nn.Module):
232 | """Residual convolution module.
233 | """
234 |
235 | def __init__(self, features, activation, bn):
236 | """Init.
237 |
238 | Args:
239 | features (int): number of features
240 | """
241 | super().__init__()
242 |
243 | self.bn = bn
244 |
245 | self.groups=1
246 |
247 | self.conv1 = nn.Conv2d(
248 | features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
249 | )
250 |
251 | self.conv2 = nn.Conv2d(
252 | features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
253 | )
254 |
255 | if self.bn==True:
256 | self.bn1 = nn.BatchNorm2d(features)
257 | self.bn2 = nn.BatchNorm2d(features)
258 |
259 | self.activation = activation
260 |
261 | self.skip_add = nn.quantized.FloatFunctional()
262 |
263 | def forward(self, x):
264 | """Forward pass.
265 |
266 | Args:
267 | x (tensor): input
268 |
269 | Returns:
270 | tensor: output
271 | """
272 |
273 | out = self.activation(x)
274 | out = self.conv1(out)
275 | if self.bn==True:
276 | out = self.bn1(out)
277 |
278 | out = self.activation(out)
279 | out = self.conv2(out)
280 | if self.bn==True:
281 | out = self.bn2(out)
282 |
283 | if self.groups > 1:
284 | out = self.conv_merge(out)
285 |
286 | return self.skip_add.add(out, x)
287 |
288 | # return out + x
289 |
290 |
291 | class FeatureFusionBlock_custom(nn.Module):
292 | """Feature fusion block.
293 | """
294 |
295 | def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True):
296 | """Init.
297 |
298 | Args:
299 | features (int): number of features
300 | """
301 | super(FeatureFusionBlock_custom, self).__init__()
302 |
303 | self.deconv = deconv
304 | self.align_corners = align_corners
305 |
306 | self.groups=1
307 |
308 | self.expand = expand
309 | out_features = features
310 | if self.expand==True:
311 | out_features = features//2
312 |
313 | self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
314 |
315 | self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
316 | self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
317 |
318 | self.skip_add = nn.quantized.FloatFunctional()
319 |
320 | def forward(self, *xs):
321 | """Forward pass.
322 |
323 | Returns:
324 | tensor: output
325 | """
326 | output = xs[0]
327 |
328 | if len(xs) == 2:
329 | res = self.resConfUnit1(xs[1])
330 | output = self.skip_add.add(output, res)
331 | # output += res
332 |
333 | output = self.resConfUnit2(output)
334 |
335 | output = nn.functional.interpolate(
336 | output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
337 | )
338 |
339 | output = self.out_conv(output)
340 |
341 | return output
342 |
343 |
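344 | # Editor's note -- illustrative sketch, not part of the original file.
345 | # With expand=True, _make_scratch widens the reassemble convolutions by
346 | # 1x/2x/4x/8x, which is how MidasNet_small matches the channel counts of its
347 | # FeatureFusionBlock_custom stages.
348 | if __name__ == '__main__':
349 |     scratch = _make_scratch([32, 48, 136, 384], 64, groups=1, expand=True)
350 |     print(scratch.layer1_rn.out_channels,   # 64
351 |           scratch.layer2_rn.out_channels,   # 128
352 |           scratch.layer3_rn.out_channels,   # 256
353 |           scratch.layer4_rn.out_channels)   # 512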
--------------------------------------------------------------------------------
/lib/model_zoo/controlnet_annotator/midas/midas/dpt_depth.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from .base_model import BaseModel
6 | from .blocks import (
7 | FeatureFusionBlock,
8 | FeatureFusionBlock_custom,
9 | Interpolate,
10 | _make_encoder,
11 | forward_vit,
12 | )
13 |
14 |
15 | def _make_fusion_block(features, use_bn):
16 | return FeatureFusionBlock_custom(
17 | features,
18 | nn.ReLU(False),
19 | deconv=False,
20 | bn=use_bn,
21 | expand=False,
22 | align_corners=True,
23 | )
24 |
25 |
26 | class DPT(BaseModel):
27 | def __init__(
28 | self,
29 | head,
30 | features=256,
31 | backbone="vitb_rn50_384",
32 | readout="project",
33 | channels_last=False,
34 | use_bn=False,
35 | ):
36 |
37 | super(DPT, self).__init__()
38 |
39 | self.channels_last = channels_last
40 |
41 | hooks = {
42 | "vitb_rn50_384": [0, 1, 8, 11],
43 | "vitb16_384": [2, 5, 8, 11],
44 | "vitl16_384": [5, 11, 17, 23],
45 | }
46 |
47 | # Instantiate backbone and reassemble blocks
48 | self.pretrained, self.scratch = _make_encoder(
49 | backbone,
50 | features,
51 |             False, # Set to True if you want to train from scratch, uses ImageNet weights
52 | groups=1,
53 | expand=False,
54 | exportable=False,
55 | hooks=hooks[backbone],
56 | use_readout=readout,
57 | )
58 |
59 | self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
60 | self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
61 | self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
62 | self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
63 |
64 | self.scratch.output_conv = head
65 |
66 |
67 | def forward(self, x):
68 | if self.channels_last == True:
69 | x.contiguous(memory_format=torch.channels_last)
70 |
71 | layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)
72 |
73 | layer_1_rn = self.scratch.layer1_rn(layer_1)
74 | layer_2_rn = self.scratch.layer2_rn(layer_2)
75 | layer_3_rn = self.scratch.layer3_rn(layer_3)
76 | layer_4_rn = self.scratch.layer4_rn(layer_4)
77 |
78 | path_4 = self.scratch.refinenet4(layer_4_rn)
79 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
80 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
81 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
82 |
83 | out = self.scratch.output_conv(path_1)
84 |
85 | return out
86 |
87 |
88 | class DPTDepthModel(DPT):
89 | def __init__(self, path=None, non_negative=True, **kwargs):
90 | features = kwargs["features"] if "features" in kwargs else 256
91 |
92 | head = nn.Sequential(
93 | nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
94 | Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
95 | nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
96 | nn.ReLU(True),
97 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
98 | nn.ReLU(True) if non_negative else nn.Identity(),
99 | nn.Identity(),
100 | )
101 |
102 | super().__init__(head, **kwargs)
103 |
104 | if path is not None:
105 | self.load(path)
106 |
107 | def forward(self, x):
108 | return super().forward(x).squeeze(dim=1)
109 |
110 |
--------------------------------------------------------------------------------
/lib/model_zoo/controlnet_annotator/midas/midas/midas_net.py:
--------------------------------------------------------------------------------
1 | """MidasNet: Network for monocular depth estimation trained by mixing several datasets.
2 | This file contains code that is adapted from
3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
4 | """
5 | import torch
6 | import torch.nn as nn
7 |
8 | from .base_model import BaseModel
9 | from .blocks import FeatureFusionBlock, Interpolate, _make_encoder
10 |
11 |
12 | class MidasNet(BaseModel):
13 | """Network for monocular depth estimation.
14 | """
15 |
16 | def __init__(self, path=None, features=256, non_negative=True):
17 | """Init.
18 |
19 | Args:
20 | path (str, optional): Path to saved model. Defaults to None.
21 | features (int, optional): Number of features. Defaults to 256.
22 |             backbone: fixed to "resnext101_wsl" (not a constructor argument).
23 | """
24 | print("Loading weights: ", path)
25 |
26 | super(MidasNet, self).__init__()
27 |
28 | use_pretrained = False if path is None else True
29 |
30 | self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained)
31 |
32 | self.scratch.refinenet4 = FeatureFusionBlock(features)
33 | self.scratch.refinenet3 = FeatureFusionBlock(features)
34 | self.scratch.refinenet2 = FeatureFusionBlock(features)
35 | self.scratch.refinenet1 = FeatureFusionBlock(features)
36 |
37 | self.scratch.output_conv = nn.Sequential(
38 | nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
39 | Interpolate(scale_factor=2, mode="bilinear"),
40 | nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
41 | nn.ReLU(True),
42 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
43 | nn.ReLU(True) if non_negative else nn.Identity(),
44 | )
45 |
46 | if path:
47 | self.load(path)
48 |
49 | def forward(self, x):
50 | """Forward pass.
51 |
52 | Args:
53 | x (tensor): input data (image)
54 |
55 | Returns:
56 | tensor: depth
57 | """
58 |
59 | layer_1 = self.pretrained.layer1(x)
60 | layer_2 = self.pretrained.layer2(layer_1)
61 | layer_3 = self.pretrained.layer3(layer_2)
62 | layer_4 = self.pretrained.layer4(layer_3)
63 |
64 | layer_1_rn = self.scratch.layer1_rn(layer_1)
65 | layer_2_rn = self.scratch.layer2_rn(layer_2)
66 | layer_3_rn = self.scratch.layer3_rn(layer_3)
67 | layer_4_rn = self.scratch.layer4_rn(layer_4)
68 |
69 | path_4 = self.scratch.refinenet4(layer_4_rn)
70 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
71 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
72 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
73 |
74 | out = self.scratch.output_conv(path_1)
75 |
76 | return torch.squeeze(out, dim=1)
77 |
--------------------------------------------------------------------------------
/lib/model_zoo/controlnet_annotator/midas/midas/midas_net_custom.py:
--------------------------------------------------------------------------------
1 | """MidasNet: Network for monocular depth estimation trained by mixing several datasets.
2 | This file contains code that is adapted from
3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
4 | """
5 | import torch
6 | import torch.nn as nn
7 |
8 | from .base_model import BaseModel
9 | from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder
10 |
11 |
12 | class MidasNet_small(BaseModel):
13 | """Network for monocular depth estimation.
14 | """
15 |
16 | def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True,
17 | blocks={'expand': True}):
18 | """Init.
19 |
20 | Args:
21 | path (str, optional): Path to saved model. Defaults to None.
22 |             features (int, optional): Number of features. Defaults to 64.
23 |             backbone (str, optional): Backbone network for encoder. Defaults to "efficientnet_lite3".
24 | """
25 | print("Loading weights: ", path)
26 |
27 | super(MidasNet_small, self).__init__()
28 |
29 | use_pretrained = False if path else True
30 |
31 | self.channels_last = channels_last
32 | self.blocks = blocks
33 | self.backbone = backbone
34 |
35 | self.groups = 1
36 |
37 | features1=features
38 | features2=features
39 | features3=features
40 | features4=features
41 | self.expand = False
42 | if "expand" in self.blocks and self.blocks['expand'] == True:
43 | self.expand = True
44 | features1=features
45 | features2=features*2
46 | features3=features*4
47 | features4=features*8
48 |
49 | self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable)
50 |
51 | self.scratch.activation = nn.ReLU(False)
52 |
53 | self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
54 | self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
55 | self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
56 | self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners)
57 |
58 |
59 | self.scratch.output_conv = nn.Sequential(
60 | nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups),
61 | Interpolate(scale_factor=2, mode="bilinear"),
62 | nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1),
63 | self.scratch.activation,
64 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
65 | nn.ReLU(True) if non_negative else nn.Identity(),
66 | nn.Identity(),
67 | )
68 |
69 | if path:
70 | self.load(path)
71 |
72 |
73 | def forward(self, x):
74 | """Forward pass.
75 |
76 | Args:
77 | x (tensor): input data (image)
78 |
79 | Returns:
80 | tensor: depth
81 | """
82 | if self.channels_last==True:
83 | print("self.channels_last = ", self.channels_last)
84 | x.contiguous(memory_format=torch.channels_last)
85 |
86 |
87 | layer_1 = self.pretrained.layer1(x)
88 | layer_2 = self.pretrained.layer2(layer_1)
89 | layer_3 = self.pretrained.layer3(layer_2)
90 | layer_4 = self.pretrained.layer4(layer_3)
91 |
92 | layer_1_rn = self.scratch.layer1_rn(layer_1)
93 | layer_2_rn = self.scratch.layer2_rn(layer_2)
94 | layer_3_rn = self.scratch.layer3_rn(layer_3)
95 | layer_4_rn = self.scratch.layer4_rn(layer_4)
96 |
97 |
98 | path_4 = self.scratch.refinenet4(layer_4_rn)
99 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
100 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
101 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
102 |
103 | out = self.scratch.output_conv(path_1)
104 |
105 | return torch.squeeze(out, dim=1)
106 |
107 |
108 |
109 | def fuse_model(m):
110 | prev_previous_type = nn.Identity()
111 | prev_previous_name = ''
112 | previous_type = nn.Identity()
113 | previous_name = ''
114 | for name, module in m.named_modules():
115 | if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU:
116 | # print("FUSED ", prev_previous_name, previous_name, name)
117 | torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True)
118 | elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d:
119 | # print("FUSED ", prev_previous_name, previous_name)
120 | torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True)
121 | # elif previous_type == nn.Conv2d and type(module) == nn.ReLU:
122 | # print("FUSED ", previous_name, name)
123 | # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True)
124 |
125 | prev_previous_type = previous_type
126 | prev_previous_name = previous_name
127 | previous_type = type(module)
128 | previous_name = name
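129 |
130 |
131 | # Editor's note -- illustrative sketch, not part of the original file.
132 | # fuse_model walks named_modules and fuses Conv2d+BatchNorm2d(+ReLU) runs in
133 | # place via torch.quantization.fuse_modules, which expects the model in eval mode.
134 | if __name__ == '__main__':
135 |     m = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU()).eval()
136 |     fuse_model(m)
137 |     print(m)  # the BatchNorm folds into the Conv and the ReLU joins it as ConvReLU2d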
--------------------------------------------------------------------------------
/lib/model_zoo/controlnet_annotator/midas/midas/transforms.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | import math
4 |
5 |
6 | def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
7 |     """Resize the sample to ensure the given size. Keeps aspect ratio.
8 |
9 | Args:
10 | sample (dict): sample
11 | size (tuple): image size
12 |
13 | Returns:
14 |         tuple: new size (or the unchanged sample if it is already large enough)
15 | """
16 | shape = list(sample["disparity"].shape)
17 |
18 | if shape[0] >= size[0] and shape[1] >= size[1]:
19 | return sample
20 |
21 | scale = [0, 0]
22 | scale[0] = size[0] / shape[0]
23 | scale[1] = size[1] / shape[1]
24 |
25 | scale = max(scale)
26 |
27 | shape[0] = math.ceil(scale * shape[0])
28 | shape[1] = math.ceil(scale * shape[1])
29 |
30 | # resize
31 | sample["image"] = cv2.resize(
32 | sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
33 | )
34 |
35 | sample["disparity"] = cv2.resize(
36 | sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
37 | )
38 | sample["mask"] = cv2.resize(
39 | sample["mask"].astype(np.float32),
40 | tuple(shape[::-1]),
41 | interpolation=cv2.INTER_NEAREST,
42 | )
43 | sample["mask"] = sample["mask"].astype(bool)
44 |
45 | return tuple(shape)
46 |
47 |
48 | class Resize(object):
49 | """Resize sample to given size (width, height).
50 | """
51 |
52 | def __init__(
53 | self,
54 | width,
55 | height,
56 | resize_target=True,
57 | keep_aspect_ratio=False,
58 | ensure_multiple_of=1,
59 | resize_method="lower_bound",
60 | image_interpolation_method=cv2.INTER_AREA,
61 | ):
62 | """Init.
63 |
64 | Args:
65 | width (int): desired output width
66 | height (int): desired output height
67 | resize_target (bool, optional):
68 | True: Resize the full sample (image, mask, target).
69 | False: Resize image only.
70 | Defaults to True.
71 | keep_aspect_ratio (bool, optional):
72 | True: Keep the aspect ratio of the input sample.
73 | Output sample might not have the given width and height, and
74 | resize behaviour depends on the parameter 'resize_method'.
75 | Defaults to False.
76 | ensure_multiple_of (int, optional):
77 | Output width and height is constrained to be multiple of this parameter.
78 | Defaults to 1.
79 | resize_method (str, optional):
80 | "lower_bound": Output will be at least as large as the given size.
81 |                 "upper_bound": Output will be at most as large as the given size. (Output size might be smaller than given size.)
82 |                 "minimal": Scale as little as possible. (Output size might be smaller than given size.)
83 | Defaults to "lower_bound".
84 | """
85 | self.__width = width
86 | self.__height = height
87 |
88 | self.__resize_target = resize_target
89 | self.__keep_aspect_ratio = keep_aspect_ratio
90 | self.__multiple_of = ensure_multiple_of
91 | self.__resize_method = resize_method
92 | self.__image_interpolation_method = image_interpolation_method
93 |
94 | def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
95 | y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
96 |
97 | if max_val is not None and y > max_val:
98 | y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
99 |
100 | if y < min_val:
101 | y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
102 |
103 | return y
104 |
105 | def get_size(self, width, height):
106 | # determine new height and width
107 | scale_height = self.__height / height
108 | scale_width = self.__width / width
109 |
110 | if self.__keep_aspect_ratio:
111 | if self.__resize_method == "lower_bound":
112 | # scale such that output size is lower bound
113 | if scale_width > scale_height:
114 | # fit width
115 | scale_height = scale_width
116 | else:
117 | # fit height
118 | scale_width = scale_height
119 | elif self.__resize_method == "upper_bound":
120 | # scale such that output size is upper bound
121 | if scale_width < scale_height:
122 | # fit width
123 | scale_height = scale_width
124 | else:
125 | # fit height
126 | scale_width = scale_height
127 | elif self.__resize_method == "minimal":
128 |                 # scale as little as possible
129 | if abs(1 - scale_width) < abs(1 - scale_height):
130 | # fit width
131 | scale_height = scale_width
132 | else:
133 | # fit height
134 | scale_width = scale_height
135 | else:
136 | raise ValueError(
137 | f"resize_method {self.__resize_method} not implemented"
138 | )
139 |
140 | if self.__resize_method == "lower_bound":
141 | new_height = self.constrain_to_multiple_of(
142 | scale_height * height, min_val=self.__height
143 | )
144 | new_width = self.constrain_to_multiple_of(
145 | scale_width * width, min_val=self.__width
146 | )
147 | elif self.__resize_method == "upper_bound":
148 | new_height = self.constrain_to_multiple_of(
149 | scale_height * height, max_val=self.__height
150 | )
151 | new_width = self.constrain_to_multiple_of(
152 | scale_width * width, max_val=self.__width
153 | )
154 | elif self.__resize_method == "minimal":
155 | new_height = self.constrain_to_multiple_of(scale_height * height)
156 | new_width = self.constrain_to_multiple_of(scale_width * width)
157 | else:
158 | raise ValueError(f"resize_method {self.__resize_method} not implemented")
159 |
160 | return (new_width, new_height)
161 |
162 | def __call__(self, sample):
163 | width, height = self.get_size(
164 | sample["image"].shape[1], sample["image"].shape[0]
165 | )
166 |
167 | # resize sample
168 | sample["image"] = cv2.resize(
169 | sample["image"],
170 | (width, height),
171 | interpolation=self.__image_interpolation_method,
172 | )
173 |
174 | if self.__resize_target:
175 | if "disparity" in sample:
176 | sample["disparity"] = cv2.resize(
177 | sample["disparity"],
178 | (width, height),
179 | interpolation=cv2.INTER_NEAREST,
180 | )
181 |
182 | if "depth" in sample:
183 | sample["depth"] = cv2.resize(
184 | sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
185 | )
186 |
187 | sample["mask"] = cv2.resize(
188 | sample["mask"].astype(np.float32),
189 | (width, height),
190 | interpolation=cv2.INTER_NEAREST,
191 | )
192 | sample["mask"] = sample["mask"].astype(bool)
193 |
194 | return sample
195 |
196 |
197 | class NormalizeImage(object):
198 |     """Normalize image by given mean and std.
199 | """
200 |
201 | def __init__(self, mean, std):
202 | self.__mean = mean
203 | self.__std = std
204 |
205 | def __call__(self, sample):
206 | sample["image"] = (sample["image"] - self.__mean) / self.__std
207 |
208 | return sample
209 |
210 |
211 | class PrepareForNet(object):
212 | """Prepare sample for usage as network input.
213 | """
214 |
215 | def __init__(self):
216 | pass
217 |
218 | def __call__(self, sample):
219 | image = np.transpose(sample["image"], (2, 0, 1))
220 | sample["image"] = np.ascontiguousarray(image).astype(np.float32)
221 |
222 | if "mask" in sample:
223 | sample["mask"] = sample["mask"].astype(np.float32)
224 | sample["mask"] = np.ascontiguousarray(sample["mask"])
225 |
226 | if "disparity" in sample:
227 | disparity = sample["disparity"].astype(np.float32)
228 | sample["disparity"] = np.ascontiguousarray(disparity)
229 |
230 | if "depth" in sample:
231 | depth = sample["depth"].astype(np.float32)
232 | sample["depth"] = np.ascontiguousarray(depth)
233 |
234 | return sample
235 |
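236 | # Editor's note -- illustrative sketch, not part of the original file.
237 | # The three transforms are meant to be composed (see api.load_midas_transform);
238 | # a sample is a dict whose "image" entry is an H x W x 3 float array in [0, 1].
239 | if __name__ == '__main__':
240 |     resize = Resize(384, 384, resize_target=None, keep_aspect_ratio=True,
241 |                     ensure_multiple_of=32, resize_method="minimal",
242 |                     image_interpolation_method=cv2.INTER_CUBIC)
243 |     normalize = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
244 |     prepare = PrepareForNet()
245 |     sample = {"image": np.random.rand(480, 640, 3).astype(np.float32)}
246 |     out = prepare(normalize(resize(sample)))["image"]
247 |     print(out.shape)  # (3, 384, 512): CHW, sides rounded to multiples of 32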
--------------------------------------------------------------------------------
/lib/model_zoo/controlnet_annotator/midas/utils.py:
--------------------------------------------------------------------------------
1 | """Utils for monoDepth."""
2 | import sys
3 | import re
4 | import numpy as np
5 | import cv2
6 | import torch
7 |
8 |
9 | def read_pfm(path):
10 | """Read pfm file.
11 |
12 | Args:
13 | path (str): path to file
14 |
15 | Returns:
16 | tuple: (data, scale)
17 | """
18 | with open(path, "rb") as file:
19 |
20 | color = None
21 | width = None
22 | height = None
23 | scale = None
24 | endian = None
25 |
26 | header = file.readline().rstrip()
27 | if header.decode("ascii") == "PF":
28 | color = True
29 | elif header.decode("ascii") == "Pf":
30 | color = False
31 | else:
32 | raise Exception("Not a PFM file: " + path)
33 |
34 | dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii"))
35 | if dim_match:
36 | width, height = list(map(int, dim_match.groups()))
37 | else:
38 | raise Exception("Malformed PFM header.")
39 |
40 | scale = float(file.readline().decode("ascii").rstrip())
41 | if scale < 0:
42 | # little-endian
43 | endian = "<"
44 | scale = -scale
45 | else:
46 | # big-endian
47 | endian = ">"
48 |
49 | data = np.fromfile(file, endian + "f")
50 | shape = (height, width, 3) if color else (height, width)
51 |
52 | data = np.reshape(data, shape)
53 | data = np.flipud(data)
54 |
55 | return data, scale
56 |
57 |
58 | def write_pfm(path, image, scale=1):
59 | """Write pfm file.
60 |
61 | Args:
62 |         path (str): path to file
63 | image (array): data
64 | scale (int, optional): Scale. Defaults to 1.
65 | """
66 |
67 | with open(path, "wb") as file:
68 | color = None
69 |
70 | if image.dtype.name != "float32":
71 | raise Exception("Image dtype must be float32.")
72 |
73 | image = np.flipud(image)
74 |
75 | if len(image.shape) == 3 and image.shape[2] == 3: # color image
76 | color = True
77 | elif (
78 | len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1
79 | ): # greyscale
80 | color = False
81 | else:
82 | raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.")
83 |
84 |         file.write(("PF\n" if color else "Pf\n").encode())
85 | file.write("%d %d\n".encode() % (image.shape[1], image.shape[0]))
86 |
87 | endian = image.dtype.byteorder
88 |
89 | if endian == "<" or endian == "=" and sys.byteorder == "little":
90 | scale = -scale
91 |
92 | file.write("%f\n".encode() % scale)
93 |
94 | image.tofile(file)
95 |
96 |
97 | def read_image(path):
98 | """Read image and output RGB image (0-1).
99 |
100 | Args:
101 | path (str): path to file
102 |
103 | Returns:
104 | array: RGB image (0-1)
105 | """
106 | img = cv2.imread(path)
107 |
108 | if img.ndim == 2:
109 | img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
110 |
111 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0
112 |
113 | return img
114 |
115 |
116 | def resize_image(img):
117 |     """Resize image and make it fit for the network.
118 |
119 | Args:
120 | img (array): image
121 |
122 | Returns:
123 | tensor: data ready for network
124 | """
125 | height_orig = img.shape[0]
126 | width_orig = img.shape[1]
127 |
128 | if width_orig > height_orig:
129 | scale = width_orig / 384
130 | else:
131 | scale = height_orig / 384
132 |
133 | height = (np.ceil(height_orig / scale / 32) * 32).astype(int)
134 | width = (np.ceil(width_orig / scale / 32) * 32).astype(int)
135 |
136 | img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)
137 |
138 | img_resized = (
139 | torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float()
140 | )
141 | img_resized = img_resized.unsqueeze(0)
142 |
143 | return img_resized
144 |
145 |
146 | def resize_depth(depth, width, height):
147 | """Resize depth map and bring to CPU (numpy).
148 |
149 | Args:
150 | depth (tensor): depth
151 | width (int): image width
152 | height (int): image height
153 |
154 | Returns:
155 | array: processed depth
156 | """
157 | depth = torch.squeeze(depth[0, :, :, :]).to("cpu")
158 |
159 | depth_resized = cv2.resize(
160 | depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC
161 | )
162 |
163 | return depth_resized
164 |
165 | def write_depth(path, depth, bits=1):
166 | """Write depth map to pfm and png file.
167 |
168 | Args:
169 | path (str): filepath without extension
170 | depth (array): depth
171 | """
172 | write_pfm(path + ".pfm", depth.astype(np.float32))
173 |
174 | depth_min = depth.min()
175 | depth_max = depth.max()
176 |
177 | max_val = (2**(8*bits))-1
178 |
179 | if depth_max - depth_min > np.finfo("float").eps:
180 | out = max_val * (depth - depth_min) / (depth_max - depth_min)
181 | else:
182 |         out = np.zeros(depth.shape, dtype=depth.dtype)
183 |
184 | if bits == 1:
185 | cv2.imwrite(path + ".png", out.astype("uint8"))
186 | elif bits == 2:
187 | cv2.imwrite(path + ".png", out.astype("uint16"))
188 |
189 | return
190 |
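191 |
192 | # Editor's note -- illustrative sketch, not part of the original file.
193 | # write_depth stores the raw depth as .pfm plus a normalized 8-bit (bits=1) or
194 | # 16-bit (bits=2) .png next to it; read_pfm recovers the float data and scale.
195 | if __name__ == '__main__':
196 |     import os
197 |     import tempfile
198 |     depth = np.random.rand(120, 160).astype(np.float32) * 10.0
199 |     base = os.path.join(tempfile.mkdtemp(), "depth")
200 |     write_depth(base, depth, bits=2)
201 |     data, scale = read_pfm(base + ".pfm")
202 |     print(data.shape, scale, sorted(os.listdir(os.path.dirname(base))))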
--------------------------------------------------------------------------------
/lib/model_zoo/controlnet_annotator/mlsd/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "{}"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright 2021-present NAVER Corp.
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
--------------------------------------------------------------------------------
/lib/model_zoo/controlnet_annotator/mlsd/__init__.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import torch
4 | import os
5 |
6 | from einops import rearrange
7 | from .models.mbv2_mlsd_tiny import MobileV2_MLSD_Tiny
8 | from .models.mbv2_mlsd_large import MobileV2_MLSD_Large
9 | from .utils import pred_lines
10 |
11 | models_path = 'pretrained/controlnet/preprocess'
12 |
13 | mlsdmodel = None
14 | remote_model_path = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/mlsd_large_512_fp32.pth"
15 | old_modeldir = os.path.dirname(os.path.realpath(__file__))
16 | modeldir = os.path.join(models_path, "mlsd")
17 |
18 | def unload_mlsd_model():
19 | global mlsdmodel
20 | if mlsdmodel is not None:
21 | mlsdmodel = mlsdmodel.cpu()
22 |
23 | def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
24 | """Load file form http url, will download models if necessary.
25 |
26 | Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
27 |
28 | Args:
29 | url (str): URL to be downloaded.
30 | model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir.
31 | Default: None.
32 | progress (bool): Whether to show the download progress. Default: True.
33 | file_name (str): The downloaded file name. If None, use the file name in the url. Default: None.
34 |
35 | Returns:
36 | str: The path to the downloaded file.
37 | """
38 | from torch.hub import download_url_to_file, get_dir
39 | from urllib.parse import urlparse
40 | if model_dir is None: # use the pytorch hub_dir
41 | hub_dir = get_dir()
42 | model_dir = os.path.join(hub_dir, 'checkpoints')
43 |
44 | os.makedirs(model_dir, exist_ok=True)
45 |
46 | parts = urlparse(url)
47 | filename = os.path.basename(parts.path)
48 | if file_name is not None:
49 | filename = file_name
50 | cached_file = os.path.abspath(os.path.join(model_dir, filename))
51 | if not os.path.exists(cached_file):
52 | print(f'Downloading: "{url}" to {cached_file}\n')
53 | download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
54 | return cached_file
55 |
56 | def apply_mlsd(input_image, thr_v, thr_d, device='cpu'):
57 | global modelpath, mlsdmodel
58 | if mlsdmodel is None:
59 | modelpath = os.path.join(modeldir, "mlsd_large_512_fp32.pth")
60 | old_modelpath = os.path.join(old_modeldir, "mlsd_large_512_fp32.pth")
61 | if os.path.exists(old_modelpath):
62 | modelpath = old_modelpath
63 | elif not os.path.exists(modelpath):
64 | load_file_from_url(remote_model_path, model_dir=modeldir)
65 | mlsdmodel = MobileV2_MLSD_Large()
66 | mlsdmodel.load_state_dict(torch.load(modelpath), strict=True)
67 | mlsdmodel = mlsdmodel.to(device).eval()
68 |
69 | model = mlsdmodel
70 | assert input_image.ndim == 3
71 | img = input_image
72 | img_output = np.zeros_like(img)
73 | try:
74 | with torch.no_grad():
75 | lines = pred_lines(img, model, [img.shape[0], img.shape[1]], thr_v, thr_d)
76 | for line in lines:
77 | x_start, y_start, x_end, y_end = [int(val) for val in line]
78 | cv2.line(img_output, (x_start, y_start), (x_end, y_end), [255, 255, 255], 1)
79 |     except Exception:
80 |         pass  # line prediction can fail on some inputs; fall back to a blank line map
81 | return img_output[:, :, 0]
82 |
--------------------------------------------------------------------------------
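
A minimal usage sketch for `apply_mlsd` above, assuming the repository root is on PYTHONPATH. The thresholds and output filename are illustrative values, not defaults taken from this repo; on first call the MLSD checkpoint is downloaded into `pretrained/controlnet/preprocess/mlsd`.

    import cv2
    from lib.model_zoo.controlnet_annotator.mlsd import apply_mlsd

    img = cv2.imread("examples/bedroom-input.jpg")               # H x W x 3 uint8
    edge = apply_mlsd(img, thr_v=0.1, thr_d=0.1, device="cpu")   # single-channel line map, same H x W
    cv2.imwrite("bedroom-mlsd-pred.png", edge)
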
/lib/model_zoo/controlnet_annotator/mlsd/models/mbv2_mlsd_large.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import torch
4 | import torch.nn as nn
5 | import torch.utils.model_zoo as model_zoo
6 | from torch.nn import functional as F
7 |
8 |
9 | class BlockTypeA(nn.Module):
10 | def __init__(self, in_c1, in_c2, out_c1, out_c2, upscale = True):
11 | super(BlockTypeA, self).__init__()
12 | self.conv1 = nn.Sequential(
13 | nn.Conv2d(in_c2, out_c2, kernel_size=1),
14 | nn.BatchNorm2d(out_c2),
15 | nn.ReLU(inplace=True)
16 | )
17 | self.conv2 = nn.Sequential(
18 | nn.Conv2d(in_c1, out_c1, kernel_size=1),
19 | nn.BatchNorm2d(out_c1),
20 | nn.ReLU(inplace=True)
21 | )
22 | self.upscale = upscale
23 |
24 | def forward(self, a, b):
25 | b = self.conv1(b)
26 | a = self.conv2(a)
27 | if self.upscale:
28 | b = F.interpolate(b, scale_factor=2.0, mode='bilinear', align_corners=True)
29 | return torch.cat((a, b), dim=1)
30 |
31 |
32 | class BlockTypeB(nn.Module):
33 | def __init__(self, in_c, out_c):
34 | super(BlockTypeB, self).__init__()
35 | self.conv1 = nn.Sequential(
36 | nn.Conv2d(in_c, in_c, kernel_size=3, padding=1),
37 | nn.BatchNorm2d(in_c),
38 | nn.ReLU()
39 | )
40 | self.conv2 = nn.Sequential(
41 | nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
42 | nn.BatchNorm2d(out_c),
43 | nn.ReLU()
44 | )
45 |
46 | def forward(self, x):
47 | x = self.conv1(x) + x
48 | x = self.conv2(x)
49 | return x
50 |
51 | class BlockTypeC(nn.Module):
52 | def __init__(self, in_c, out_c):
53 | super(BlockTypeC, self).__init__()
54 | self.conv1 = nn.Sequential(
55 | nn.Conv2d(in_c, in_c, kernel_size=3, padding=5, dilation=5),
56 | nn.BatchNorm2d(in_c),
57 | nn.ReLU()
58 | )
59 | self.conv2 = nn.Sequential(
60 | nn.Conv2d(in_c, in_c, kernel_size=3, padding=1),
61 | nn.BatchNorm2d(in_c),
62 | nn.ReLU()
63 | )
64 | self.conv3 = nn.Conv2d(in_c, out_c, kernel_size=1)
65 |
66 | def forward(self, x):
67 | x = self.conv1(x)
68 | x = self.conv2(x)
69 | x = self.conv3(x)
70 | return x
71 |
72 | def _make_divisible(v, divisor, min_value=None):
73 | """
74 | This function is taken from the original tf repo.
75 | It ensures that all layers have a channel number that is divisible by 8
76 | It can be seen here:
77 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
78 | :param v:
79 | :param divisor:
80 | :param min_value:
81 | :return:
82 | """
83 | if min_value is None:
84 | min_value = divisor
85 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
86 | # Make sure that round down does not go down by more than 10%.
87 | if new_v < 0.9 * v:
88 | new_v += divisor
89 | return new_v
90 |
91 |
92 | class ConvBNReLU(nn.Sequential):
93 | def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
94 | self.channel_pad = out_planes - in_planes
95 | self.stride = stride
96 | #padding = (kernel_size - 1) // 2
97 |
98 | # TFLite uses slightly different padding than PyTorch
99 | if stride == 2:
100 | padding = 0
101 | else:
102 | padding = (kernel_size - 1) // 2
103 |
104 | super(ConvBNReLU, self).__init__(
105 | nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
106 | nn.BatchNorm2d(out_planes),
107 | nn.ReLU6(inplace=True)
108 | )
109 | self.max_pool = nn.MaxPool2d(kernel_size=stride, stride=stride)
110 |
111 |
112 | def forward(self, x):
113 | # TFLite uses different padding
114 | if self.stride == 2:
115 | x = F.pad(x, (0, 1, 0, 1), "constant", 0)
116 | #print(x.shape)
117 |
118 | for module in self:
119 | if not isinstance(module, nn.MaxPool2d):
120 | x = module(x)
121 | return x
122 |
123 |
124 | class InvertedResidual(nn.Module):
125 | def __init__(self, inp, oup, stride, expand_ratio):
126 | super(InvertedResidual, self).__init__()
127 | self.stride = stride
128 | assert stride in [1, 2]
129 |
130 | hidden_dim = int(round(inp * expand_ratio))
131 | self.use_res_connect = self.stride == 1 and inp == oup
132 |
133 | layers = []
134 | if expand_ratio != 1:
135 | # pw
136 | layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
137 | layers.extend([
138 | # dw
139 | ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
140 | # pw-linear
141 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
142 | nn.BatchNorm2d(oup),
143 | ])
144 | self.conv = nn.Sequential(*layers)
145 |
146 | def forward(self, x):
147 | if self.use_res_connect:
148 | return x + self.conv(x)
149 | else:
150 | return self.conv(x)
151 |
152 |
153 | class MobileNetV2(nn.Module):
154 | def __init__(self, pretrained=True):
155 |         """
156 |         Truncated MobileNet V2 backbone used as the MLSD feature extractor.
157 | 
158 |         Unlike torchvision's MobileNetV2, this variant hard-codes the width
159 |         multiplier, rounding and inverted-residual settings, keeps only the
160 |         early stages needed for the FPN features, and expects a 4-channel input.
161 | 
162 |         Args:
163 |             pretrained (bool): load ImageNet MobileNetV2 weights for the matching keys
164 |         """
165 | super(MobileNetV2, self).__init__()
166 |
167 | block = InvertedResidual
168 | input_channel = 32
169 | last_channel = 1280
170 | width_mult = 1.0
171 | round_nearest = 8
172 |
173 | inverted_residual_setting = [
174 | # t, c, n, s
175 | [1, 16, 1, 1],
176 | [6, 24, 2, 2],
177 | [6, 32, 3, 2],
178 | [6, 64, 4, 2],
179 | [6, 96, 3, 1],
180 | #[6, 160, 3, 2],
181 | #[6, 320, 1, 1],
182 | ]
183 |
184 | # only check the first element, assuming user knows t,c,n,s are required
185 |         if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
186 |             raise ValueError("inverted_residual_setting should be non-empty and "
187 |                              "each entry should have 4 elements (t, c, n, s), got {}".format(inverted_residual_setting))
188 |
189 | # building first layer
190 | input_channel = _make_divisible(input_channel * width_mult, round_nearest)
191 | self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
192 | features = [ConvBNReLU(4, input_channel, stride=2)]
193 | # building inverted residual blocks
194 | for t, c, n, s in inverted_residual_setting:
195 | output_channel = _make_divisible(c * width_mult, round_nearest)
196 | for i in range(n):
197 | stride = s if i == 0 else 1
198 | features.append(block(input_channel, output_channel, stride, expand_ratio=t))
199 | input_channel = output_channel
200 |
201 | self.features = nn.Sequential(*features)
202 | self.fpn_selected = [1, 3, 6, 10, 13]
203 | # weight initialization
204 | for m in self.modules():
205 | if isinstance(m, nn.Conv2d):
206 | nn.init.kaiming_normal_(m.weight, mode='fan_out')
207 | if m.bias is not None:
208 | nn.init.zeros_(m.bias)
209 | elif isinstance(m, nn.BatchNorm2d):
210 | nn.init.ones_(m.weight)
211 | nn.init.zeros_(m.bias)
212 | elif isinstance(m, nn.Linear):
213 | nn.init.normal_(m.weight, 0, 0.01)
214 | nn.init.zeros_(m.bias)
215 | if pretrained:
216 | self._load_pretrained_model()
217 |
218 | def _forward_impl(self, x):
219 | # This exists since TorchScript doesn't support inheritance, so the superclass method
220 | # (this one) needs to have a name other than `forward` that can be accessed in a subclass
221 | fpn_features = []
222 | for i, f in enumerate(self.features):
223 | if i > self.fpn_selected[-1]:
224 | break
225 | x = f(x)
226 | if i in self.fpn_selected:
227 | fpn_features.append(x)
228 |
229 | c1, c2, c3, c4, c5 = fpn_features
230 | return c1, c2, c3, c4, c5
231 |
232 |
233 | def forward(self, x):
234 | return self._forward_impl(x)
235 |
236 | def _load_pretrained_model(self):
237 | pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/mobilenet_v2-b0353104.pth')
238 | model_dict = {}
239 | state_dict = self.state_dict()
240 | for k, v in pretrain_dict.items():
241 | if k in state_dict:
242 | model_dict[k] = v
243 | state_dict.update(model_dict)
244 | self.load_state_dict(state_dict)
245 |
246 |
247 | class MobileV2_MLSD_Large(nn.Module):
248 | def __init__(self):
249 | super(MobileV2_MLSD_Large, self).__init__()
250 |
251 | self.backbone = MobileNetV2(pretrained=False)
252 | ## A, B
253 | self.block15 = BlockTypeA(in_c1= 64, in_c2= 96,
254 | out_c1= 64, out_c2=64,
255 | upscale=False)
256 | self.block16 = BlockTypeB(128, 64)
257 |
258 | ## A, B
259 | self.block17 = BlockTypeA(in_c1 = 32, in_c2 = 64,
260 | out_c1= 64, out_c2= 64)
261 | self.block18 = BlockTypeB(128, 64)
262 |
263 | ## A, B
264 | self.block19 = BlockTypeA(in_c1=24, in_c2=64,
265 | out_c1=64, out_c2=64)
266 | self.block20 = BlockTypeB(128, 64)
267 |
268 | ## A, B, C
269 | self.block21 = BlockTypeA(in_c1=16, in_c2=64,
270 | out_c1=64, out_c2=64)
271 | self.block22 = BlockTypeB(128, 64)
272 |
273 | self.block23 = BlockTypeC(64, 16)
274 |
275 | def forward(self, x):
276 | c1, c2, c3, c4, c5 = self.backbone(x)
277 |
278 | x = self.block15(c4, c5)
279 | x = self.block16(x)
280 |
281 | x = self.block17(c3, x)
282 | x = self.block18(x)
283 |
284 | x = self.block19(c2, x)
285 | x = self.block20(x)
286 |
287 | x = self.block21(c1, x)
288 | x = self.block22(x)
289 | x = self.block23(x)
290 | x = x[:, 7:, :, :]
291 |
292 | return x
--------------------------------------------------------------------------------
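
As a quick sanity check on the architecture above: the backbone's first convolution takes a 4-channel input, and after the final `x[:, 7:, :, :]` slice the head emits a 9-channel map at half the input resolution. A sketch (input size arbitrary; import path assumes the repository root is on PYTHONPATH):

    import torch
    from lib.model_zoo.controlnet_annotator.mlsd.models.mbv2_mlsd_large import MobileV2_MLSD_Large

    model = MobileV2_MLSD_Large().eval()
    with torch.no_grad():
        out = model(torch.randn(1, 4, 512, 512))   # random 4-channel input
    print(out.shape)                               # expected: torch.Size([1, 9, 256, 256])
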
/lib/model_zoo/controlnet_annotator/mlsd/models/mbv2_mlsd_tiny.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import torch
4 | import torch.nn as nn
5 | import torch.utils.model_zoo as model_zoo
6 | from torch.nn import functional as F
7 |
8 |
9 | class BlockTypeA(nn.Module):
10 | def __init__(self, in_c1, in_c2, out_c1, out_c2, upscale = True):
11 | super(BlockTypeA, self).__init__()
12 | self.conv1 = nn.Sequential(
13 | nn.Conv2d(in_c2, out_c2, kernel_size=1),
14 | nn.BatchNorm2d(out_c2),
15 | nn.ReLU(inplace=True)
16 | )
17 | self.conv2 = nn.Sequential(
18 | nn.Conv2d(in_c1, out_c1, kernel_size=1),
19 | nn.BatchNorm2d(out_c1),
20 | nn.ReLU(inplace=True)
21 | )
22 | self.upscale = upscale
23 |
24 | def forward(self, a, b):
25 | b = self.conv1(b)
26 | a = self.conv2(a)
27 | b = F.interpolate(b, scale_factor=2.0, mode='bilinear', align_corners=True)
28 | return torch.cat((a, b), dim=1)
29 |
30 |
31 | class BlockTypeB(nn.Module):
32 | def __init__(self, in_c, out_c):
33 | super(BlockTypeB, self).__init__()
34 | self.conv1 = nn.Sequential(
35 | nn.Conv2d(in_c, in_c, kernel_size=3, padding=1),
36 | nn.BatchNorm2d(in_c),
37 | nn.ReLU()
38 | )
39 | self.conv2 = nn.Sequential(
40 | nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
41 | nn.BatchNorm2d(out_c),
42 | nn.ReLU()
43 | )
44 |
45 | def forward(self, x):
46 | x = self.conv1(x) + x
47 | x = self.conv2(x)
48 | return x
49 |
50 | class BlockTypeC(nn.Module):
51 | def __init__(self, in_c, out_c):
52 | super(BlockTypeC, self).__init__()
53 | self.conv1 = nn.Sequential(
54 | nn.Conv2d(in_c, in_c, kernel_size=3, padding=5, dilation=5),
55 | nn.BatchNorm2d(in_c),
56 | nn.ReLU()
57 | )
58 | self.conv2 = nn.Sequential(
59 | nn.Conv2d(in_c, in_c, kernel_size=3, padding=1),
60 | nn.BatchNorm2d(in_c),
61 | nn.ReLU()
62 | )
63 | self.conv3 = nn.Conv2d(in_c, out_c, kernel_size=1)
64 |
65 | def forward(self, x):
66 | x = self.conv1(x)
67 | x = self.conv2(x)
68 | x = self.conv3(x)
69 | return x
70 |
71 | def _make_divisible(v, divisor, min_value=None):
72 | """
73 | This function is taken from the original tf repo.
74 | It ensures that all layers have a channel number that is divisible by 8
75 | It can be seen here:
76 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
77 | :param v:
78 | :param divisor:
79 | :param min_value:
80 | :return:
81 | """
82 | if min_value is None:
83 | min_value = divisor
84 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
85 | # Make sure that round down does not go down by more than 10%.
86 | if new_v < 0.9 * v:
87 | new_v += divisor
88 | return new_v
89 |
90 |
91 | class ConvBNReLU(nn.Sequential):
92 | def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
93 | self.channel_pad = out_planes - in_planes
94 | self.stride = stride
95 | #padding = (kernel_size - 1) // 2
96 |
97 | # TFLite uses slightly different padding than PyTorch
98 | if stride == 2:
99 | padding = 0
100 | else:
101 | padding = (kernel_size - 1) // 2
102 |
103 | super(ConvBNReLU, self).__init__(
104 | nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
105 | nn.BatchNorm2d(out_planes),
106 | nn.ReLU6(inplace=True)
107 | )
108 | self.max_pool = nn.MaxPool2d(kernel_size=stride, stride=stride)
109 |
110 |
111 | def forward(self, x):
112 | # TFLite uses different padding
113 | if self.stride == 2:
114 | x = F.pad(x, (0, 1, 0, 1), "constant", 0)
115 | #print(x.shape)
116 |
117 | for module in self:
118 | if not isinstance(module, nn.MaxPool2d):
119 | x = module(x)
120 | return x
121 |
122 |
123 | class InvertedResidual(nn.Module):
124 | def __init__(self, inp, oup, stride, expand_ratio):
125 | super(InvertedResidual, self).__init__()
126 | self.stride = stride
127 | assert stride in [1, 2]
128 |
129 | hidden_dim = int(round(inp * expand_ratio))
130 | self.use_res_connect = self.stride == 1 and inp == oup
131 |
132 | layers = []
133 | if expand_ratio != 1:
134 | # pw
135 | layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
136 | layers.extend([
137 | # dw
138 | ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
139 | # pw-linear
140 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
141 | nn.BatchNorm2d(oup),
142 | ])
143 | self.conv = nn.Sequential(*layers)
144 |
145 | def forward(self, x):
146 | if self.use_res_connect:
147 | return x + self.conv(x)
148 | else:
149 | return self.conv(x)
150 |
151 |
152 | class MobileNetV2(nn.Module):
153 | def __init__(self, pretrained=True):
154 |         """
155 |         Truncated MobileNet V2 backbone used as the MLSD-tiny feature extractor.
156 | 
157 |         Unlike torchvision's MobileNetV2, this variant hard-codes the width
158 |         multiplier, rounding and inverted-residual settings, keeps only the
159 |         early stages needed for the FPN features, and expects a 4-channel input.
160 | 
161 |         Args:
162 |             pretrained (bool): kept for interface parity; pretrained loading is disabled below
163 |         """
164 | super(MobileNetV2, self).__init__()
165 |
166 | block = InvertedResidual
167 | input_channel = 32
168 | last_channel = 1280
169 | width_mult = 1.0
170 | round_nearest = 8
171 |
172 | inverted_residual_setting = [
173 | # t, c, n, s
174 | [1, 16, 1, 1],
175 | [6, 24, 2, 2],
176 | [6, 32, 3, 2],
177 | [6, 64, 4, 2],
178 | #[6, 96, 3, 1],
179 | #[6, 160, 3, 2],
180 | #[6, 320, 1, 1],
181 | ]
182 |
183 | # only check the first element, assuming user knows t,c,n,s are required
184 |         if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
185 |             raise ValueError("inverted_residual_setting should be non-empty and "
186 |                              "each entry should have 4 elements (t, c, n, s), got {}".format(inverted_residual_setting))
187 |
188 | # building first layer
189 | input_channel = _make_divisible(input_channel * width_mult, round_nearest)
190 | self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
191 | features = [ConvBNReLU(4, input_channel, stride=2)]
192 | # building inverted residual blocks
193 | for t, c, n, s in inverted_residual_setting:
194 | output_channel = _make_divisible(c * width_mult, round_nearest)
195 | for i in range(n):
196 | stride = s if i == 0 else 1
197 | features.append(block(input_channel, output_channel, stride, expand_ratio=t))
198 | input_channel = output_channel
199 | self.features = nn.Sequential(*features)
200 |
201 | self.fpn_selected = [3, 6, 10]
202 | # weight initialization
203 | for m in self.modules():
204 | if isinstance(m, nn.Conv2d):
205 | nn.init.kaiming_normal_(m.weight, mode='fan_out')
206 | if m.bias is not None:
207 | nn.init.zeros_(m.bias)
208 | elif isinstance(m, nn.BatchNorm2d):
209 | nn.init.ones_(m.weight)
210 | nn.init.zeros_(m.bias)
211 | elif isinstance(m, nn.Linear):
212 | nn.init.normal_(m.weight, 0, 0.01)
213 | nn.init.zeros_(m.bias)
214 |
215 | #if pretrained:
216 | # self._load_pretrained_model()
217 |
218 | def _forward_impl(self, x):
219 | # This exists since TorchScript doesn't support inheritance, so the superclass method
220 | # (this one) needs to have a name other than `forward` that can be accessed in a subclass
221 | fpn_features = []
222 | for i, f in enumerate(self.features):
223 | if i > self.fpn_selected[-1]:
224 | break
225 | x = f(x)
226 | if i in self.fpn_selected:
227 | fpn_features.append(x)
228 |
229 | c2, c3, c4 = fpn_features
230 | return c2, c3, c4
231 |
232 |
233 | def forward(self, x):
234 | return self._forward_impl(x)
235 |
236 | def _load_pretrained_model(self):
237 | pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/mobilenet_v2-b0353104.pth')
238 | model_dict = {}
239 | state_dict = self.state_dict()
240 | for k, v in pretrain_dict.items():
241 | if k in state_dict:
242 | model_dict[k] = v
243 | state_dict.update(model_dict)
244 | self.load_state_dict(state_dict)
245 |
246 |
247 | class MobileV2_MLSD_Tiny(nn.Module):
248 | def __init__(self):
249 | super(MobileV2_MLSD_Tiny, self).__init__()
250 |
251 | self.backbone = MobileNetV2(pretrained=True)
252 |
253 | self.block12 = BlockTypeA(in_c1= 32, in_c2= 64,
254 | out_c1= 64, out_c2=64)
255 | self.block13 = BlockTypeB(128, 64)
256 |
257 | self.block14 = BlockTypeA(in_c1 = 24, in_c2 = 64,
258 | out_c1= 32, out_c2= 32)
259 | self.block15 = BlockTypeB(64, 64)
260 |
261 | self.block16 = BlockTypeC(64, 16)
262 |
263 | def forward(self, x):
264 | c2, c3, c4 = self.backbone(x)
265 |
266 | x = self.block12(c3, c4)
267 | x = self.block13(x)
268 | x = self.block14(c2, x)
269 | x = self.block15(x)
270 | x = self.block16(x)
271 | x = x[:, 7:, :, :]
272 | #print(x.shape)
273 | x = F.interpolate(x, scale_factor=2.0, mode='bilinear', align_corners=True)
274 |
275 | return x
--------------------------------------------------------------------------------
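
The tiny variant uses fewer backbone stages but ends with a 2x bilinear upsample, so its output should land on the same half-resolution grid as the large model. A matching sketch (same assumptions as the large-model check above):

    import torch
    from lib.model_zoo.controlnet_annotator.mlsd.models.mbv2_mlsd_tiny import MobileV2_MLSD_Tiny

    model = MobileV2_MLSD_Tiny().eval()
    with torch.no_grad():
        out = model(torch.randn(1, 4, 512, 512))
    print(out.shape)   # expected: torch.Size([1, 9, 256, 256]), matching the large variant
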
/lib/model_zoo/controlnet_annotator/openpose/LICENSE:
--------------------------------------------------------------------------------
1 | OPENPOSE: MULTIPERSON KEYPOINT DETECTION
2 | SOFTWARE LICENSE AGREEMENT
3 | ACADEMIC OR NON-PROFIT ORGANIZATION NONCOMMERCIAL RESEARCH USE ONLY
4 |
5 | BY USING OR DOWNLOADING THE SOFTWARE, YOU ARE AGREEING TO THE TERMS OF THIS LICENSE AGREEMENT. IF YOU DO NOT AGREE WITH THESE TERMS, YOU MAY NOT USE OR DOWNLOAD THE SOFTWARE.
6 |
7 | This is a license agreement ("Agreement") between your academic institution or non-profit organization or self (called "Licensee" or "You" in this Agreement) and Carnegie Mellon University (called "Licensor" in this Agreement). All rights not specifically granted to you in this Agreement are reserved for Licensor.
8 |
9 | RESERVATION OF OWNERSHIP AND GRANT OF LICENSE:
10 | Licensor retains exclusive ownership of any copy of the Software (as defined below) licensed under this Agreement and hereby grants to Licensee a personal, non-exclusive,
11 | non-transferable license to use the Software for noncommercial research purposes, without the right to sublicense, pursuant to the terms and conditions of this Agreement. As used in this Agreement, the term "Software" means (i) the actual copy of all or any portion of code for program routines made accessible to Licensee by Licensor pursuant to this Agreement, inclusive of backups, updates, and/or merged copies permitted hereunder or subsequently supplied by Licensor, including all or any file structures, programming instructions, user interfaces and screen formats and sequences as well as any and all documentation and instructions related to it, and (ii) all or any derivatives and/or modifications created or made by You to any of the items specified in (i).
12 |
13 | CONFIDENTIALITY: Licensee acknowledges that the Software is proprietary to Licensor, and as such, Licensee agrees to receive all such materials in confidence and use the Software only in accordance with the terms of this Agreement. Licensee agrees to use reasonable effort to protect the Software from unauthorized use, reproduction, distribution, or publication.
14 |
15 | COPYRIGHT: The Software is owned by Licensor and is protected by United
16 | States copyright laws and applicable international treaties and/or conventions.
17 |
18 | PERMITTED USES: The Software may be used for your own noncommercial internal research purposes. You understand and agree that Licensor is not obligated to implement any suggestions and/or feedback you might provide regarding the Software, but to the extent Licensor does so, you are not entitled to any compensation related thereto.
19 |
20 | DERIVATIVES: You may create derivatives of or make modifications to the Software, however, You agree that all and any such derivatives and modifications will be owned by Licensor and become a part of the Software licensed to You under this Agreement. You may only use such derivatives and modifications for your own noncommercial internal research purposes, and you may not otherwise use, distribute or copy such derivatives and modifications in violation of this Agreement.
21 |
22 | BACKUPS: If Licensee is an organization, it may make that number of copies of the Software necessary for internal noncommercial use at a single site within its organization provided that all information appearing in or on the original labels, including the copyright and trademark notices are copied onto the labels of the copies.
23 |
24 | USES NOT PERMITTED: You may not distribute, copy or use the Software except as explicitly permitted herein. Licensee has not been granted any trademark license as part of this Agreement and may not use the name or mark “OpenPose", "Carnegie Mellon" or any renditions thereof without the prior written permission of Licensor.
25 |
26 | You may not sell, rent, lease, sublicense, lend, time-share or transfer, in whole or in part, or provide third parties access to prior or present versions (or any parts thereof) of the Software.
27 |
28 | ASSIGNMENT: You may not assign this Agreement or your rights hereunder without the prior written consent of Licensor. Any attempted assignment without such consent shall be null and void.
29 |
30 | TERM: The term of the license granted by this Agreement is from Licensee's acceptance of this Agreement by downloading the Software or by using the Software until terminated as provided below.
31 |
32 | The Agreement automatically terminates without notice if you fail to comply with any provision of this Agreement. Licensee may terminate this Agreement by ceasing using the Software. Upon any termination of this Agreement, Licensee will delete any and all copies of the Software. You agree that all provisions which operate to protect the proprietary rights of Licensor shall remain in force should breach occur and that the obligation of confidentiality described in this Agreement is binding in perpetuity and, as such, survives the term of the Agreement.
33 |
34 | FEE: Provided Licensee abides completely by the terms and conditions of this Agreement, there is no fee due to Licensor for Licensee's use of the Software in accordance with this Agreement.
35 |
36 | DISCLAIMER OF WARRANTIES: THE SOFTWARE IS PROVIDED "AS-IS" WITHOUT WARRANTY OF ANY KIND INCLUDING ANY WARRANTIES OF PERFORMANCE OR MERCHANTABILITY OR FITNESS FOR A PARTICULAR USE OR PURPOSE OR OF NON-INFRINGEMENT. LICENSEE BEARS ALL RISK RELATING TO QUALITY AND PERFORMANCE OF THE SOFTWARE AND RELATED MATERIALS.
37 |
38 | SUPPORT AND MAINTENANCE: No Software support or training by the Licensor is provided as part of this Agreement.
39 |
40 | EXCLUSIVE REMEDY AND LIMITATION OF LIABILITY: To the maximum extent permitted under applicable law, Licensor shall not be liable for direct, indirect, special, incidental, or consequential damages or lost profits related to Licensee's use of and/or inability to use the Software, even if Licensor is advised of the possibility of such damage.
41 |
42 | EXPORT REGULATION: Licensee agrees to comply with any and all applicable
43 | U.S. export control laws, regulations, and/or other laws related to embargoes and sanction programs administered by the Office of Foreign Assets Control.
44 |
45 | SEVERABILITY: If any provision(s) of this Agreement shall be held to be invalid, illegal, or unenforceable by a court or other tribunal of competent jurisdiction, the validity, legality and enforceability of the remaining provisions shall not in any way be affected or impaired thereby.
46 |
47 | NO IMPLIED WAIVERS: No failure or delay by Licensor in enforcing any right or remedy under this Agreement shall be construed as a waiver of any future or other exercise of such right or remedy by Licensor.
48 |
49 | GOVERNING LAW: This Agreement shall be construed and enforced in accordance with the laws of the Commonwealth of Pennsylvania without reference to conflict of laws principles. You consent to the personal jurisdiction of the courts of this County and waive their rights to venue outside of Allegheny County, Pennsylvania.
50 |
51 | ENTIRE AGREEMENT AND AMENDMENTS: This Agreement constitutes the sole and entire agreement between Licensee and Licensor as to the matter set forth herein and supersedes any previous agreements, understandings, and arrangements between the parties relating hereto.
52 |
53 |
54 |
55 | ************************************************************************
56 |
57 | THIRD-PARTY SOFTWARE NOTICES AND INFORMATION
58 |
59 | This project incorporates material from the project(s) listed below (collectively, "Third Party Code"). This Third Party Code is licensed to you under their original license terms set forth below. We reserves all other rights not expressly granted, whether by implication, estoppel or otherwise.
60 |
61 | 1. Caffe, version 1.0.0, (https://github.com/BVLC/caffe/)
62 |
63 | COPYRIGHT
64 |
65 | All contributions by the University of California:
66 | Copyright (c) 2014-2017 The Regents of the University of California (Regents)
67 | All rights reserved.
68 |
69 | All other contributions:
70 | Copyright (c) 2014-2017, the respective contributors
71 | All rights reserved.
72 |
73 | Caffe uses a shared copyright model: each contributor holds copyright over
74 | their contributions to Caffe. The project versioning records all such
75 | contribution and copyright details. If a contributor wants to further mark
76 | their specific copyright on a particular contribution, they should indicate
77 | their copyright solely in the commit message of the change when it is
78 | committed.
79 |
80 | LICENSE
81 |
82 | Redistribution and use in source and binary forms, with or without
83 | modification, are permitted provided that the following conditions are met:
84 |
85 | 1. Redistributions of source code must retain the above copyright notice, this
86 | list of conditions and the following disclaimer.
87 | 2. Redistributions in binary form must reproduce the above copyright notice,
88 | this list of conditions and the following disclaimer in the documentation
89 | and/or other materials provided with the distribution.
90 |
91 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
92 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
93 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
94 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
95 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
96 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
97 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
98 | ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
99 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
100 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
101 |
102 | CONTRIBUTION AGREEMENT
103 |
104 | By contributing to the BVLC/caffe repository through pull-request, comment,
105 | or otherwise, the contributor releases their content to the
106 | license and copyright terms herein.
107 |
108 | ************END OF THIRD-PARTY SOFTWARE NOTICES AND INFORMATION**********
--------------------------------------------------------------------------------
/lib/model_zoo/controlnet_annotator/openpose/__init__.py:
--------------------------------------------------------------------------------
1 | # Openpose
2 | # Original from CMU https://github.com/CMU-Perceptual-Computing-Lab/openpose
3 | # 2nd Edited by https://github.com/Hzzone/pytorch-openpose
4 | # 3rd Edited by ControlNet
5 | # 4th Edited by ControlNet (added face and correct hands)
6 | # 5th Edited by ControlNet (improved JSON serialization/deserialization, and lots of bug fixes)
7 | # This preprocessor is licensed by CMU for non-commercial use only.
8 |
9 |
10 | import os
11 | os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
12 |
13 | import json
14 | import torch
15 | import numpy as np
16 | from . import util
17 | from .body import Body, BodyResult, Keypoint
18 | from .hand import Hand
19 | from .face import Face
20 |
21 | models_path = "pretrained/controlnet/preprocess"
22 |
23 | from typing import NamedTuple, Tuple, List, Callable, Union
24 |
25 | body_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/body_pose_model.pth"
26 | hand_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/hand_pose_model.pth"
27 | face_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/facenet.pth"
28 |
29 | HandResult = List[Keypoint]
30 | FaceResult = List[Keypoint]
31 |
32 | class PoseResult(NamedTuple):
33 | body: BodyResult
34 | left_hand: Union[HandResult, None]
35 | right_hand: Union[HandResult, None]
36 | face: Union[FaceResult, None]
37 |
38 | def draw_poses(poses: List[PoseResult], H, W, draw_body=True, draw_hand=True, draw_face=True):
39 | """
40 | Draw the detected poses on an empty canvas.
41 |
42 | Args:
43 | poses (List[PoseResult]): A list of PoseResult objects containing the detected poses.
44 | H (int): The height of the canvas.
45 | W (int): The width of the canvas.
46 | draw_body (bool, optional): Whether to draw body keypoints. Defaults to True.
47 | draw_hand (bool, optional): Whether to draw hand keypoints. Defaults to True.
48 | draw_face (bool, optional): Whether to draw face keypoints. Defaults to True.
49 |
50 | Returns:
51 | numpy.ndarray: A 3D numpy array representing the canvas with the drawn poses.
52 | """
53 | canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8)
54 |
55 | for pose in poses:
56 | if draw_body:
57 | canvas = util.draw_bodypose(canvas, pose.body.keypoints)
58 |
59 | if draw_hand:
60 | canvas = util.draw_handpose(canvas, pose.left_hand)
61 | canvas = util.draw_handpose(canvas, pose.right_hand)
62 |
63 | if draw_face:
64 | canvas = util.draw_facepose(canvas, pose.face)
65 |
66 | return canvas
67 |
68 | def encode_poses_as_json(poses: List[PoseResult], canvas_height: int, canvas_width: int) -> str:
69 | """ Encode the pose as a JSON string following openpose JSON output format:
70 | https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/doc/02_output.md
71 | """
72 | def compress_keypoints(keypoints: Union[List[Keypoint], None]) -> Union[List[float], None]:
73 | if not keypoints:
74 | return None
75 |
76 | return [
77 | value
78 | for keypoint in keypoints
79 | for value in (
80 | [float(keypoint.x), float(keypoint.y), 1.0]
81 | if keypoint is not None
82 | else [0.0, 0.0, 0.0]
83 | )
84 | ]
85 |
86 | return json.dumps({
87 | 'people': [
88 | {
89 | 'pose_keypoints_2d': compress_keypoints(pose.body.keypoints),
90 | "face_keypoints_2d": compress_keypoints(pose.face),
91 | "hand_left_keypoints_2d": compress_keypoints(pose.left_hand),
92 | "hand_right_keypoints_2d":compress_keypoints(pose.right_hand),
93 | }
94 | for pose in poses
95 | ],
96 | 'canvas_height': canvas_height,
97 | 'canvas_width': canvas_width,
98 | }, indent=4)
99 |
100 | def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
101 | """Load file form http url, will download models if necessary.
102 |
103 | Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
104 |
105 | Args:
106 | url (str): URL to be downloaded.
107 | model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir.
108 | Default: None.
109 | progress (bool): Whether to show the download progress. Default: True.
110 | file_name (str): The downloaded file name. If None, use the file name in the url. Default: None.
111 |
112 | Returns:
113 | str: The path to the downloaded file.
114 | """
115 | from torch.hub import download_url_to_file, get_dir
116 | from urllib.parse import urlparse
117 | if model_dir is None: # use the pytorch hub_dir
118 | hub_dir = get_dir()
119 | model_dir = os.path.join(hub_dir, 'checkpoints')
120 |
121 | os.makedirs(model_dir, exist_ok=True)
122 |
123 | parts = urlparse(url)
124 | filename = os.path.basename(parts.path)
125 | if file_name is not None:
126 | filename = file_name
127 | cached_file = os.path.abspath(os.path.join(model_dir, filename))
128 | if not os.path.exists(cached_file):
129 | print(f'Downloading: "{url}" to {cached_file}\n')
130 | download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
131 | return cached_file
132 |
133 | class OpenposeDetector:
134 | """
135 | A class for detecting human poses in images using the Openpose model.
136 |
137 | Attributes:
138 | model_dir (str): Path to the directory where the pose models are stored.
139 | """
140 | model_dir = os.path.join(models_path, "openpose")
141 |
142 | def __init__(self, device):
143 | self.device = device
144 | self.body_estimation = None
145 | self.hand_estimation = None
146 | self.face_estimation = None
147 |
148 | def load_model(self):
149 | """
150 | Load the Openpose body, hand, and face models.
151 | """
152 | body_modelpath = os.path.join(self.model_dir, "body_pose_model.pth")
153 | hand_modelpath = os.path.join(self.model_dir, "hand_pose_model.pth")
154 | face_modelpath = os.path.join(self.model_dir, "facenet.pth")
155 |
156 | if not os.path.exists(body_modelpath):
157 | load_file_from_url(body_model_path, model_dir=self.model_dir)
158 |
159 | if not os.path.exists(hand_modelpath):
160 | load_file_from_url(hand_model_path, model_dir=self.model_dir)
161 |
162 | if not os.path.exists(face_modelpath):
163 | load_file_from_url(face_model_path, model_dir=self.model_dir)
164 |
165 | self.body_estimation = Body(body_modelpath)
166 | self.hand_estimation = Hand(hand_modelpath)
167 | self.face_estimation = Face(face_modelpath)
168 |
169 | def unload_model(self):
170 | """
171 | Unload the Openpose models by moving them to the CPU.
172 | """
173 | if self.body_estimation is not None:
174 | self.body_estimation.model.to("cpu")
175 | self.hand_estimation.model.to("cpu")
176 | self.face_estimation.model.to("cpu")
177 |
178 | def detect_hands(self, body: BodyResult, oriImg) -> Tuple[Union[HandResult, None], Union[HandResult, None]]:
179 | left_hand = None
180 | right_hand = None
181 | H, W, _ = oriImg.shape
182 | for x, y, w, is_left in util.handDetect(body, oriImg):
183 | peaks = self.hand_estimation(oriImg[y:y+w, x:x+w, :]).astype(np.float32)
184 | if peaks.ndim == 2 and peaks.shape[1] == 2:
185 | peaks[:, 0] = np.where(peaks[:, 0] < 1e-6, -1, peaks[:, 0] + x) / float(W)
186 | peaks[:, 1] = np.where(peaks[:, 1] < 1e-6, -1, peaks[:, 1] + y) / float(H)
187 |
188 | hand_result = [
189 | Keypoint(x=peak[0], y=peak[1])
190 | for peak in peaks
191 | ]
192 |
193 | if is_left:
194 | left_hand = hand_result
195 | else:
196 | right_hand = hand_result
197 |
198 | return left_hand, right_hand
199 |
200 | def detect_face(self, body: BodyResult, oriImg) -> Union[FaceResult, None]:
201 | face = util.faceDetect(body, oriImg)
202 | if face is None:
203 | return None
204 |
205 | x, y, w = face
206 | H, W, _ = oriImg.shape
207 | heatmaps = self.face_estimation(oriImg[y:y+w, x:x+w, :])
208 | peaks = self.face_estimation.compute_peaks_from_heatmaps(heatmaps).astype(np.float32)
209 | if peaks.ndim == 2 and peaks.shape[1] == 2:
210 | peaks[:, 0] = np.where(peaks[:, 0] < 1e-6, -1, peaks[:, 0] + x) / float(W)
211 | peaks[:, 1] = np.where(peaks[:, 1] < 1e-6, -1, peaks[:, 1] + y) / float(H)
212 | return [
213 | Keypoint(x=peak[0], y=peak[1])
214 | for peak in peaks
215 | ]
216 |
217 | return None
218 |
219 | def detect_poses(self, oriImg, include_hand=False, include_face=False) -> List[PoseResult]:
220 | """
221 | Detect poses in the given image.
222 | Args:
223 | oriImg (numpy.ndarray): The input image for pose detection.
224 | include_hand (bool, optional): Whether to include hand detection. Defaults to False.
225 | include_face (bool, optional): Whether to include face detection. Defaults to False.
226 |
227 | Returns:
228 | List[PoseResult]: A list of PoseResult objects containing the detected poses.
229 | """
230 | if self.body_estimation is None:
231 | self.load_model()
232 |
233 | self.body_estimation.model.to(self.device)
234 | self.hand_estimation.model.to(self.device)
235 | self.face_estimation.model.to(self.device)
236 |
237 | self.body_estimation.cn_device = self.device
238 | self.hand_estimation.cn_device = self.device
239 | self.face_estimation.cn_device = self.device
240 |
241 | oriImg = oriImg[:, :, ::-1].copy()
242 | H, W, C = oriImg.shape
243 | with torch.no_grad():
244 | candidate, subset = self.body_estimation(oriImg)
245 | bodies = self.body_estimation.format_body_result(candidate, subset)
246 |
247 | results = []
248 | for body in bodies:
249 | left_hand, right_hand, face = (None,) * 3
250 | if include_hand:
251 | left_hand, right_hand = self.detect_hands(body, oriImg)
252 | if include_face:
253 | face = self.detect_face(body, oriImg)
254 |
255 | results.append(PoseResult(BodyResult(
256 | keypoints=[
257 | Keypoint(
258 | x=keypoint.x / float(W),
259 | y=keypoint.y / float(H)
260 | ) if keypoint is not None else None
261 | for keypoint in body.keypoints
262 | ],
263 | total_score=body.total_score,
264 | total_parts=body.total_parts
265 | ), left_hand, right_hand, face))
266 |
267 | return results
268 |
269 | def __call__(
270 | self, oriImg, include_body=True, include_hand=False, include_face=False,
271 | json_pose_callback: Callable[[str], None] = None,
272 | ):
273 | """
274 | Detect and draw poses in the given image.
275 |
276 | Args:
277 | oriImg (numpy.ndarray): The input image for pose detection and drawing.
278 | include_body (bool, optional): Whether to include body keypoints. Defaults to True.
279 | include_hand (bool, optional): Whether to include hand keypoints. Defaults to False.
280 | include_face (bool, optional): Whether to include face keypoints. Defaults to False.
281 | json_pose_callback (Callable, optional): A callback that accepts the pose JSON string.
282 |
283 | Returns:
284 | numpy.ndarray: The image with detected and drawn poses.
285 | """
286 | H, W, _ = oriImg.shape
287 | poses = self.detect_poses(oriImg, include_hand, include_face)
288 | if json_pose_callback:
289 | json_pose_callback(encode_poses_as_json(poses, H, W))
290 | return draw_poses(poses, H, W, draw_body=include_body, draw_hand=include_hand, draw_face=include_face)
291 |
292 | class OpenposeModel(object):
293 | def __init__(self) -> None:
294 | self.model_openpose = None
295 |
296 | def run_model(
297 | self,
298 | img: np.ndarray,
299 | include_body: bool,
300 | include_hand: bool,
301 | include_face: bool,
302 | json_pose_callback: Callable[[str], None] = None,
303 | device = 'cpu', ):
304 |
305 | if json_pose_callback is None:
306 | json_pose_callback = lambda x: None
307 |
308 | if self.model_openpose is None:
309 | self.model_openpose = OpenposeDetector(device=device)
310 |
311 | return self.model_openpose(
312 | img,
313 | include_body=include_body,
314 | include_hand=include_hand,
315 | include_face=include_face,
316 | json_pose_callback=json_pose_callback)
317 |
318 | def unload(self):
319 | if self.model_openpose is not None:
320 | self.model_openpose.unload_model()
321 |
--------------------------------------------------------------------------------
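
A minimal end-to-end sketch of the detector above. The example image is one of the repo's sample inputs; the callback and output filenames are illustrative, the three pose checkpoints are downloaded into `pretrained/controlnet/preprocess/openpose` on first use, and the repository root is assumed to be on PYTHONPATH.

    import cv2
    from lib.model_zoo.controlnet_annotator.openpose import OpenposeModel

    def save_json(s: str) -> None:
        # receives the OpenPose-format JSON produced by encode_poses_as_json
        with open("pose.json", "w") as f:
            f.write(s)

    openpose = OpenposeModel()
    img = cv2.imread("examples-anime/hanfu_girl.jpg")   # H x W x 3 uint8

    pose_map = openpose.run_model(
        img,
        include_body=True,
        include_hand=True,
        include_face=True,
        json_pose_callback=save_json,
        device="cpu")

    cv2.imwrite("pose_vis.png", pose_map)   # H x W x 3 canvas with the drawn skeleton
    openpose.unload()                       # move the sub-models back to the CPU
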
/lib/model_zoo/controlnet_annotator/openpose/hand.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import json
3 | import numpy as np
4 | import math
5 | import time
6 | from scipy.ndimage import gaussian_filter  # scipy.ndimage.filters is deprecated in recent SciPy
7 | import matplotlib.pyplot as plt
8 | import matplotlib
9 | import torch
10 | from skimage.measure import label
11 |
12 | from .model import handpose_model
13 | from . import util
14 |
15 | class Hand(object):
16 | def __init__(self, model_path):
17 | self.model = handpose_model()
18 | # if torch.cuda.is_available():
19 | # self.model = self.model.cuda()
20 | # print('cuda')
21 | model_dict = util.transfer(self.model, torch.load(model_path))
22 | self.model.load_state_dict(model_dict)
23 | self.model.eval()
24 |
25 | def __call__(self, oriImgRaw):
26 | scale_search = [0.5, 1.0, 1.5, 2.0]
27 | # scale_search = [0.5]
28 | boxsize = 368
29 | stride = 8
30 | padValue = 128
31 | thre = 0.05
32 | multiplier = [x * boxsize for x in scale_search]
33 |
34 | wsize = 128
35 | heatmap_avg = np.zeros((wsize, wsize, 22))
36 |
37 | Hr, Wr, Cr = oriImgRaw.shape
38 |
39 | oriImg = cv2.GaussianBlur(oriImgRaw, (0, 0), 0.8)
40 |
41 | for m in range(len(multiplier)):
42 | scale = multiplier[m]
43 | imageToTest = util.smart_resize(oriImg, (scale, scale))
44 |
45 | imageToTest_padded, pad = util.padRightDownCorner(imageToTest, stride, padValue)
46 | im = np.transpose(np.float32(imageToTest_padded[:, :, :, np.newaxis]), (3, 2, 0, 1)) / 256 - 0.5
47 | im = np.ascontiguousarray(im)
48 |
49 | data = torch.from_numpy(im).float()
50 |         # no unconditional .cuda() transfer here: the tensor is moved to
51 |         # self.cn_device right below, which also covers the CPU-only case
52 |
53 | with torch.no_grad():
54 | data = data.to(self.cn_device)
55 | output = self.model(data).cpu().numpy()
56 |
57 | # extract outputs, resize, and remove padding
58 | heatmap = np.transpose(np.squeeze(output), (1, 2, 0)) # output 1 is heatmaps
59 | heatmap = util.smart_resize_k(heatmap, fx=stride, fy=stride)
60 | heatmap = heatmap[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :]
61 | heatmap = util.smart_resize(heatmap, (wsize, wsize))
62 |
63 | heatmap_avg += heatmap / len(multiplier)
64 |
65 | all_peaks = []
66 | for part in range(21):
67 | map_ori = heatmap_avg[:, :, part]
68 | one_heatmap = gaussian_filter(map_ori, sigma=3)
69 | binary = np.ascontiguousarray(one_heatmap > thre, dtype=np.uint8)
70 |
71 | if np.sum(binary) == 0:
72 | all_peaks.append([0, 0])
73 | continue
74 | label_img, label_numbers = label(binary, return_num=True, connectivity=binary.ndim)
75 | max_index = np.argmax([np.sum(map_ori[label_img == i]) for i in range(1, label_numbers + 1)]) + 1
76 | label_img[label_img != max_index] = 0
77 | map_ori[label_img == 0] = 0
78 |
79 | y, x = util.npmax(map_ori)
80 | y = int(float(y) * float(Hr) / float(wsize))
81 | x = int(float(x) * float(Wr) / float(wsize))
82 | all_peaks.append([x, y])
83 | return np.array(all_peaks)
84 |
85 | if __name__ == "__main__":
86 | hand_estimation = Hand('../model/hand_pose_model.pth')
87 |
88 | # test_image = '../images/hand.jpg'
89 | test_image = '../images/hand.jpg'
90 | oriImg = cv2.imread(test_image) # B,G,R order
91 | peaks = hand_estimation(oriImg)
92 | canvas = util.draw_handpose(oriImg, peaks, True)
93 | cv2.imshow('', canvas)
94 | cv2.waitKey(0)
--------------------------------------------------------------------------------
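
The per-keypoint peak picking in `Hand.__call__` above (smooth each averaged heatmap, keep the strongest connected blob over the threshold, take its argmax) can be illustrated on a toy heatmap; the values below are synthetic, not real network output.

    import numpy as np
    from scipy.ndimage import gaussian_filter
    from skimage.measure import label

    heatmap = np.zeros((128, 128), dtype=np.float32)
    heatmap[38:43, 58:63] = 1.0    # strong response
    heatmap[88:93, 18:23] = 0.3    # weaker competing blob

    smoothed = gaussian_filter(heatmap, sigma=3)
    binary = (smoothed > 0.05).astype(np.uint8)
    blobs, n = label(binary, return_num=True, connectivity=2)

    # keep the blob with the largest total response, then take its argmax
    best = 1 + np.argmax([smoothed[blobs == i].sum() for i in range(1, n + 1)])
    y, x = np.unravel_index(np.argmax(np.where(blobs == best, smoothed, 0)), smoothed.shape)
    print(x, y)   # roughly (60, 40): the stronger blob wins
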
/lib/model_zoo/controlnet_annotator/openpose/model.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | 
3 | import torch
4 | import torch.nn as nn
5 | 
6 |
7 | def make_layers(block, no_relu_layers):
8 | layers = []
9 | for layer_name, v in block.items():
10 | if 'pool' in layer_name:
11 | layer = nn.MaxPool2d(kernel_size=v[0], stride=v[1],
12 | padding=v[2])
13 | layers.append((layer_name, layer))
14 | else:
15 | conv2d = nn.Conv2d(in_channels=v[0], out_channels=v[1],
16 | kernel_size=v[2], stride=v[3],
17 | padding=v[4])
18 | layers.append((layer_name, conv2d))
19 | if layer_name not in no_relu_layers:
20 | layers.append(('relu_'+layer_name, nn.ReLU(inplace=True)))
21 |
22 | return nn.Sequential(OrderedDict(layers))
23 |
24 | class bodypose_model(nn.Module):
25 | def __init__(self):
26 | super(bodypose_model, self).__init__()
27 |
28 | # these layers have no relu layer
29 | no_relu_layers = ['conv5_5_CPM_L1', 'conv5_5_CPM_L2', 'Mconv7_stage2_L1',\
30 | 'Mconv7_stage2_L2', 'Mconv7_stage3_L1', 'Mconv7_stage3_L2',\
31 | 'Mconv7_stage4_L1', 'Mconv7_stage4_L2', 'Mconv7_stage5_L1',\
32 | 'Mconv7_stage5_L2', 'Mconv7_stage6_L1', 'Mconv7_stage6_L2']
33 | blocks = {}
34 | block0 = OrderedDict([
35 | ('conv1_1', [3, 64, 3, 1, 1]),
36 | ('conv1_2', [64, 64, 3, 1, 1]),
37 | ('pool1_stage1', [2, 2, 0]),
38 | ('conv2_1', [64, 128, 3, 1, 1]),
39 | ('conv2_2', [128, 128, 3, 1, 1]),
40 | ('pool2_stage1', [2, 2, 0]),
41 | ('conv3_1', [128, 256, 3, 1, 1]),
42 | ('conv3_2', [256, 256, 3, 1, 1]),
43 | ('conv3_3', [256, 256, 3, 1, 1]),
44 | ('conv3_4', [256, 256, 3, 1, 1]),
45 | ('pool3_stage1', [2, 2, 0]),
46 | ('conv4_1', [256, 512, 3, 1, 1]),
47 | ('conv4_2', [512, 512, 3, 1, 1]),
48 | ('conv4_3_CPM', [512, 256, 3, 1, 1]),
49 | ('conv4_4_CPM', [256, 128, 3, 1, 1])
50 | ])
51 |
52 |
53 | # Stage 1
54 | block1_1 = OrderedDict([
55 | ('conv5_1_CPM_L1', [128, 128, 3, 1, 1]),
56 | ('conv5_2_CPM_L1', [128, 128, 3, 1, 1]),
57 | ('conv5_3_CPM_L1', [128, 128, 3, 1, 1]),
58 | ('conv5_4_CPM_L1', [128, 512, 1, 1, 0]),
59 | ('conv5_5_CPM_L1', [512, 38, 1, 1, 0])
60 | ])
61 |
62 | block1_2 = OrderedDict([
63 | ('conv5_1_CPM_L2', [128, 128, 3, 1, 1]),
64 | ('conv5_2_CPM_L2', [128, 128, 3, 1, 1]),
65 | ('conv5_3_CPM_L2', [128, 128, 3, 1, 1]),
66 | ('conv5_4_CPM_L2', [128, 512, 1, 1, 0]),
67 | ('conv5_5_CPM_L2', [512, 19, 1, 1, 0])
68 | ])
69 | blocks['block1_1'] = block1_1
70 | blocks['block1_2'] = block1_2
71 |
72 | self.model0 = make_layers(block0, no_relu_layers)
73 |
74 | # Stages 2 - 6
75 | for i in range(2, 7):
76 | blocks['block%d_1' % i] = OrderedDict([
77 | ('Mconv1_stage%d_L1' % i, [185, 128, 7, 1, 3]),
78 | ('Mconv2_stage%d_L1' % i, [128, 128, 7, 1, 3]),
79 | ('Mconv3_stage%d_L1' % i, [128, 128, 7, 1, 3]),
80 | ('Mconv4_stage%d_L1' % i, [128, 128, 7, 1, 3]),
81 | ('Mconv5_stage%d_L1' % i, [128, 128, 7, 1, 3]),
82 | ('Mconv6_stage%d_L1' % i, [128, 128, 1, 1, 0]),
83 | ('Mconv7_stage%d_L1' % i, [128, 38, 1, 1, 0])
84 | ])
85 |
86 | blocks['block%d_2' % i] = OrderedDict([
87 | ('Mconv1_stage%d_L2' % i, [185, 128, 7, 1, 3]),
88 | ('Mconv2_stage%d_L2' % i, [128, 128, 7, 1, 3]),
89 | ('Mconv3_stage%d_L2' % i, [128, 128, 7, 1, 3]),
90 | ('Mconv4_stage%d_L2' % i, [128, 128, 7, 1, 3]),
91 | ('Mconv5_stage%d_L2' % i, [128, 128, 7, 1, 3]),
92 | ('Mconv6_stage%d_L2' % i, [128, 128, 1, 1, 0]),
93 | ('Mconv7_stage%d_L2' % i, [128, 19, 1, 1, 0])
94 | ])
95 |
96 | for k in blocks.keys():
97 | blocks[k] = make_layers(blocks[k], no_relu_layers)
98 |
99 | self.model1_1 = blocks['block1_1']
100 | self.model2_1 = blocks['block2_1']
101 | self.model3_1 = blocks['block3_1']
102 | self.model4_1 = blocks['block4_1']
103 | self.model5_1 = blocks['block5_1']
104 | self.model6_1 = blocks['block6_1']
105 |
106 | self.model1_2 = blocks['block1_2']
107 | self.model2_2 = blocks['block2_2']
108 | self.model3_2 = blocks['block3_2']
109 | self.model4_2 = blocks['block4_2']
110 | self.model5_2 = blocks['block5_2']
111 | self.model6_2 = blocks['block6_2']
112 |
113 |
114 | def forward(self, x):
115 |
116 | out1 = self.model0(x)
117 |
118 | out1_1 = self.model1_1(out1)
119 | out1_2 = self.model1_2(out1)
120 | out2 = torch.cat([out1_1, out1_2, out1], 1)
121 |
122 | out2_1 = self.model2_1(out2)
123 | out2_2 = self.model2_2(out2)
124 | out3 = torch.cat([out2_1, out2_2, out1], 1)
125 |
126 | out3_1 = self.model3_1(out3)
127 | out3_2 = self.model3_2(out3)
128 | out4 = torch.cat([out3_1, out3_2, out1], 1)
129 |
130 | out4_1 = self.model4_1(out4)
131 | out4_2 = self.model4_2(out4)
132 | out5 = torch.cat([out4_1, out4_2, out1], 1)
133 |
134 | out5_1 = self.model5_1(out5)
135 | out5_2 = self.model5_2(out5)
136 | out6 = torch.cat([out5_1, out5_2, out1], 1)
137 |
138 | out6_1 = self.model6_1(out6)
139 | out6_2 = self.model6_2(out6)
140 |
141 | return out6_1, out6_2
142 |
143 | class handpose_model(nn.Module):
144 | def __init__(self):
145 | super(handpose_model, self).__init__()
146 |
147 | # these layers have no relu layer
148 | no_relu_layers = ['conv6_2_CPM', 'Mconv7_stage2', 'Mconv7_stage3',\
149 | 'Mconv7_stage4', 'Mconv7_stage5', 'Mconv7_stage6']
150 | # stage 1
151 | block1_0 = OrderedDict([
152 | ('conv1_1', [3, 64, 3, 1, 1]),
153 | ('conv1_2', [64, 64, 3, 1, 1]),
154 | ('pool1_stage1', [2, 2, 0]),
155 | ('conv2_1', [64, 128, 3, 1, 1]),
156 | ('conv2_2', [128, 128, 3, 1, 1]),
157 | ('pool2_stage1', [2, 2, 0]),
158 | ('conv3_1', [128, 256, 3, 1, 1]),
159 | ('conv3_2', [256, 256, 3, 1, 1]),
160 | ('conv3_3', [256, 256, 3, 1, 1]),
161 | ('conv3_4', [256, 256, 3, 1, 1]),
162 | ('pool3_stage1', [2, 2, 0]),
163 | ('conv4_1', [256, 512, 3, 1, 1]),
164 | ('conv4_2', [512, 512, 3, 1, 1]),
165 | ('conv4_3', [512, 512, 3, 1, 1]),
166 | ('conv4_4', [512, 512, 3, 1, 1]),
167 | ('conv5_1', [512, 512, 3, 1, 1]),
168 | ('conv5_2', [512, 512, 3, 1, 1]),
169 | ('conv5_3_CPM', [512, 128, 3, 1, 1])
170 | ])
171 |
172 | block1_1 = OrderedDict([
173 | ('conv6_1_CPM', [128, 512, 1, 1, 0]),
174 | ('conv6_2_CPM', [512, 22, 1, 1, 0])
175 | ])
176 |
177 | blocks = {}
178 | blocks['block1_0'] = block1_0
179 | blocks['block1_1'] = block1_1
180 |
181 | # stage 2-6
182 | for i in range(2, 7):
183 | blocks['block%d' % i] = OrderedDict([
184 | ('Mconv1_stage%d' % i, [150, 128, 7, 1, 3]),
185 | ('Mconv2_stage%d' % i, [128, 128, 7, 1, 3]),
186 | ('Mconv3_stage%d' % i, [128, 128, 7, 1, 3]),
187 | ('Mconv4_stage%d' % i, [128, 128, 7, 1, 3]),
188 | ('Mconv5_stage%d' % i, [128, 128, 7, 1, 3]),
189 | ('Mconv6_stage%d' % i, [128, 128, 1, 1, 0]),
190 | ('Mconv7_stage%d' % i, [128, 22, 1, 1, 0])
191 | ])
192 |
193 | for k in blocks.keys():
194 | blocks[k] = make_layers(blocks[k], no_relu_layers)
195 |
196 | self.model1_0 = blocks['block1_0']
197 | self.model1_1 = blocks['block1_1']
198 | self.model2 = blocks['block2']
199 | self.model3 = blocks['block3']
200 | self.model4 = blocks['block4']
201 | self.model5 = blocks['block5']
202 | self.model6 = blocks['block6']
203 |
204 | def forward(self, x):
205 | out1_0 = self.model1_0(x)
206 | out1_1 = self.model1_1(out1_0)
207 | concat_stage2 = torch.cat([out1_1, out1_0], 1)
208 | out_stage2 = self.model2(concat_stage2)
209 | concat_stage3 = torch.cat([out_stage2, out1_0], 1)
210 | out_stage3 = self.model3(concat_stage3)
211 | concat_stage4 = torch.cat([out_stage3, out1_0], 1)
212 | out_stage4 = self.model4(concat_stage4)
213 | concat_stage5 = torch.cat([out_stage4, out1_0], 1)
214 | out_stage5 = self.model5(concat_stage5)
215 | concat_stage6 = torch.cat([out_stage5, out1_0], 1)
216 | out_stage6 = self.model6(concat_stage6)
217 | return out_stage6
218 |
219 |
--------------------------------------------------------------------------------
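As a quick sanity check of the two CPM-style networks above, the following hedged sketch (the import path is assumed from this repository's layout; weights are random, so only the output shapes are meaningful) runs dummy inputs through both models:

import torch
# assumed import path, based on this repository's layout
from lib.model_zoo.controlnet_annotator.openpose.model import bodypose_model, handpose_model

body = bodypose_model().eval()
hand = handpose_model().eval()
x = torch.randn(1, 3, 368, 368)              # dummy RGB crop at the canonical boxsize
with torch.no_grad():
    pafs, heatmaps = body(x)                 # stage-6 L1 (PAFs) and L2 (part heatmaps)
    hand_maps = hand(x)                      # 22 hand keypoint heatmaps
print(pafs.shape, heatmaps.shape, hand_maps.shape)
# expected: (1, 38, 46, 46), (1, 19, 46, 46), (1, 22, 46, 46) -- 368 / 8 after three pools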
/lib/model_zoo/controlnet_annotator/pidinet/LICENSE:
--------------------------------------------------------------------------------
1 | It is just for research purpose, and commercial use should be contacted with authors first.
2 |
3 | Copyright (c) 2021 Zhuo Su
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/lib/model_zoo/controlnet_annotator/pidinet/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | from einops import rearrange
5 | from .model import pidinet
6 |
7 | models_path = 'pretrained/controlnet/preprocess'
8 |
9 | netNetwork = None
10 | remote_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/table5_pidinet.pth"
11 | modeldir = os.path.join(models_path, "pidinet")
12 | old_modeldir = os.path.dirname(os.path.realpath(__file__))
13 |
14 | def safe_step(x, step=2):
15 | y = x.astype(np.float32) * float(step + 1)
16 | y = y.astype(np.int32).astype(np.float32) / float(step)
17 | return y
18 |
19 | def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
20 | """Load file form http url, will download models if necessary.
21 |
22 | Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
23 |
24 | Args:
25 | url (str): URL to be downloaded.
26 | model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir.
27 | Default: None.
28 | progress (bool): Whether to show the download progress. Default: True.
29 | file_name (str): The downloaded file name. If None, use the file name in the url. Default: None.
30 |
31 | Returns:
32 | str: The path to the downloaded file.
33 | """
34 | from torch.hub import download_url_to_file, get_dir
35 | from urllib.parse import urlparse
36 | if model_dir is None: # use the pytorch hub_dir
37 | hub_dir = get_dir()
38 | model_dir = os.path.join(hub_dir, 'checkpoints')
39 |
40 | os.makedirs(model_dir, exist_ok=True)
41 |
42 | parts = urlparse(url)
43 | filename = os.path.basename(parts.path)
44 | if file_name is not None:
45 | filename = file_name
46 | cached_file = os.path.abspath(os.path.join(model_dir, filename))
47 | if not os.path.exists(cached_file):
48 | print(f'Downloading: "{url}" to {cached_file}\n')
49 | download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
50 | return cached_file
51 |
52 | def load_state_dict(ckpt_path, location='cpu'):
53 | def get_state_dict(d):
54 | return d.get('state_dict', d)
55 |
56 | _, extension = os.path.splitext(ckpt_path)
57 | if extension.lower() == ".safetensors":
58 | import safetensors.torch
59 | state_dict = safetensors.torch.load_file(ckpt_path, device=location)
60 | else:
61 | state_dict = get_state_dict(torch.load(
62 | ckpt_path, map_location=torch.device(location)))
63 | state_dict = get_state_dict(state_dict)
64 | print(f'Loaded state_dict from [{ckpt_path}]')
65 | return state_dict
66 |
67 | def apply_pidinet(input_image, is_safe=False, apply_fliter=False, device='cpu'):
68 | global netNetwork
69 | if netNetwork is None:
70 | modelpath = os.path.join(modeldir, "table5_pidinet.pth")
71 | old_modelpath = os.path.join(old_modeldir, "table5_pidinet.pth")
72 | if os.path.exists(old_modelpath):
73 | modelpath = old_modelpath
74 | elif not os.path.exists(modelpath):
75 | load_file_from_url(remote_model_path, model_dir=modeldir)
76 | netNetwork = pidinet()
77 | ckp = load_state_dict(modelpath)
78 | netNetwork.load_state_dict({k.replace('module.',''):v for k, v in ckp.items()})
79 |
80 | netNetwork = netNetwork.to(device)
81 | netNetwork.eval()
82 | assert input_image.ndim == 3
83 | input_image = input_image[:, :, ::-1].copy()
84 | with torch.no_grad():
85 | image_pidi = torch.from_numpy(input_image).float().to(device)
86 | image_pidi = image_pidi / 255.0
87 | image_pidi = rearrange(image_pidi, 'h w c -> 1 c h w')
88 | edge = netNetwork(image_pidi)[-1]
89 | edge = edge.cpu().numpy()
90 | if apply_fliter:
91 | edge = edge > 0.5
92 | if is_safe:
93 | edge = safe_step(edge)
94 | edge = (edge * 255.0).clip(0, 255).astype(np.uint8)
95 |
96 | return edge[0][0]
97 |
98 | def unload_pid_model():
99 | global netNetwork
100 | if netNetwork is not None:
101 | netNetwork.cpu()
--------------------------------------------------------------------------------
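A hedged usage sketch of `apply_pidinet` above: the image path is illustrative, the first call downloads `table5_pidinet.pth` into `pretrained/controlnet/preprocess/pidinet` if it is not already cached, and the `apply_fliter` keyword is spelled as in the source.

import cv2
# assumed import path, based on this repository's layout
from lib.model_zoo.controlnet_annotator.pidinet import apply_pidinet

img = cv2.imread('some_input.jpg')           # HWC, uint8, BGR (path is illustrative)
edge = apply_pidinet(img, is_safe=False, apply_fliter=False, device='cpu')
print(edge.shape, edge.dtype)                # (H, W), uint8 edge map
cv2.imwrite('edge.png', edge)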
/lib/model_zoo/diffusion_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import torch
4 | import torch.nn as nn
5 | import numpy as np
6 | from einops import repeat
7 |
8 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
9 | if schedule == "linear":
10 | betas = (
11 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
12 | )
13 |
14 | elif schedule == "cosine":
15 | timesteps = (
16 | torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
17 | )
18 | alphas = timesteps / (1 + cosine_s) * np.pi / 2
19 | alphas = torch.cos(alphas).pow(2)
20 | alphas = alphas / alphas[0]
21 | betas = 1 - alphas[1:] / alphas[:-1]
22 | betas = np.clip(betas, a_min=0, a_max=0.999)
23 |
24 | elif schedule == "sqrt_linear":
25 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
26 | elif schedule == "sqrt":
27 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
28 | else:
29 | raise ValueError(f"schedule '{schedule}' unknown.")
30 | return betas.numpy()
31 |
32 | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
33 | if ddim_discr_method == 'uniform':
34 | c = num_ddpm_timesteps // num_ddim_timesteps
35 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
36 | elif ddim_discr_method == 'quad':
37 | ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
38 | else:
39 | raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
40 |
41 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps
42 | # add one to get the final alpha values right (the ones from first scale to data during sampling)
43 | steps_out = ddim_timesteps + 1
44 | if verbose:
45 | print(f'Selected timesteps for ddim sampler: {steps_out}')
46 | return steps_out
47 |
48 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
49 | # select alphas for computing the variance schedule
50 | alphas = alphacums[ddim_timesteps]
51 | alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
52 |
53 | # according to the formula provided in https://arxiv.org/abs/2010.02502
54 | sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
55 | if verbose:
56 | print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
57 | print(f'For the chosen value of eta, which is {eta}, '
58 | f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
59 | return sigmas, alphas, alphas_prev
60 |
61 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
62 | """
63 | Create a beta schedule that discretizes the given alpha_t_bar function,
64 | which defines the cumulative product of (1-beta) over time from t = [0,1].
65 | :param num_diffusion_timesteps: the number of betas to produce.
66 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
67 | produces the cumulative product of (1-beta) up to that
68 | part of the diffusion process.
69 | :param max_beta: the maximum beta to use; use values lower than 1 to
70 | prevent singularities.
71 | """
72 | betas = []
73 | for i in range(num_diffusion_timesteps):
74 | t1 = i / num_diffusion_timesteps
75 | t2 = (i + 1) / num_diffusion_timesteps
76 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
77 | return np.array(betas)
78 |
79 | def extract_into_tensor(a, t, x_shape):
80 | b, *_ = t.shape
81 | out = a.gather(-1, t)
82 | return out.reshape(b, *((1,) * (len(x_shape) - 1)))
83 |
84 | def checkpoint(func, inputs, params, flag):
85 | """
86 | Evaluate a function without caching intermediate activations, allowing for
87 | reduced memory at the expense of extra compute in the backward pass.
88 | :param func: the function to evaluate.
89 | :param inputs: the argument sequence to pass to `func`.
90 | :param params: a sequence of parameters `func` depends on but does not
91 | explicitly take as arguments.
92 | :param flag: if False, disable gradient checkpointing.
93 | """
94 | if flag:
95 | args = tuple(inputs) + tuple(params)
96 | return CheckpointFunction.apply(func, len(inputs), *args)
97 | else:
98 | return func(*inputs)
99 |
100 | class CheckpointFunction(torch.autograd.Function):
101 | @staticmethod
102 | def forward(ctx, run_function, length, *args):
103 | ctx.run_function = run_function
104 | ctx.input_tensors = list(args[:length])
105 | ctx.input_params = list(args[length:])
106 |
107 | with torch.no_grad():
108 | output_tensors = ctx.run_function(*ctx.input_tensors)
109 | return output_tensors
110 |
111 | @staticmethod
112 | def backward(ctx, *output_grads):
113 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
114 | with torch.enable_grad():
115 | # Fixes a bug where the first op in run_function modifies the
116 | # Tensor storage in place, which is not allowed for detach()'d
117 | # Tensors.
118 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
119 | output_tensors = ctx.run_function(*shallow_copies)
120 | input_grads = torch.autograd.grad(
121 | output_tensors,
122 | ctx.input_tensors + ctx.input_params,
123 | output_grads,
124 | allow_unused=True,
125 | )
126 | del ctx.input_tensors
127 | del ctx.input_params
128 | del output_tensors
129 | return (None, None) + input_grads
130 |
131 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
132 | """
133 | Create sinusoidal timestep embeddings.
134 | :param timesteps: a 1-D Tensor of N indices, one per batch element.
135 | These may be fractional.
136 | :param dim: the dimension of the output.
137 | :param max_period: controls the minimum frequency of the embeddings.
138 | :return: an [N x dim] Tensor of positional embeddings.
139 | """
140 | if not repeat_only:
141 | half = dim // 2
142 | freqs = torch.exp(
143 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
144 | ).to(device=timesteps.device)
145 | args = timesteps[:, None].float() * freqs[None]
146 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
147 | if dim % 2:
148 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
149 | else:
150 | embedding = repeat(timesteps, 'b -> b d', d=dim)
151 | return embedding
152 |
153 | def zero_module(module):
154 | """
155 | Zero out the parameters of a module and return it.
156 | """
157 | for p in module.parameters():
158 | p.detach().zero_()
159 | return module
160 |
161 | def scale_module(module, scale):
162 | """
163 | Scale the parameters of a module and return it.
164 | """
165 | for p in module.parameters():
166 | p.detach().mul_(scale)
167 | return module
168 |
169 | def mean_flat(tensor):
170 | """
171 | Take the mean over all non-batch dimensions.
172 | """
173 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
174 |
175 | def normalization(channels):
176 | """
177 | Make a standard normalization layer.
178 | :param channels: number of input channels.
179 | :return: an nn.Module for normalization.
180 | """
181 | return GroupNorm32(32, channels)
182 |
183 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
184 | class SiLU(nn.Module):
185 | def forward(self, x):
186 | return x * torch.sigmoid(x)
187 |
188 | class GroupNorm32(nn.GroupNorm):
189 | def forward(self, x):
190 | # return super().forward(x.float()).type(x.dtype)
191 | return super().forward(x)
192 |
193 | def conv_nd(dims, *args, **kwargs):
194 | """
195 | Create a 1D, 2D, or 3D convolution module.
196 | """
197 | if dims == 1:
198 | return nn.Conv1d(*args, **kwargs)
199 | elif dims == 2:
200 | return nn.Conv2d(*args, **kwargs)
201 | elif dims == 3:
202 | return nn.Conv3d(*args, **kwargs)
203 | raise ValueError(f"unsupported dimensions: {dims}")
204 |
205 | def linear(*args, **kwargs):
206 | """
207 | Create a linear module.
208 | """
209 | return nn.Linear(*args, **kwargs)
210 |
211 | def avg_pool_nd(dims, *args, **kwargs):
212 | """
213 | Create a 1D, 2D, or 3D average pooling module.
214 | """
215 | if dims == 1:
216 | return nn.AvgPool1d(*args, **kwargs)
217 | elif dims == 2:
218 | return nn.AvgPool2d(*args, **kwargs)
219 | elif dims == 3:
220 | return nn.AvgPool3d(*args, **kwargs)
221 | raise ValueError(f"unsupported dimensions: {dims}")
222 |
223 | class HybridConditioner(nn.Module):  # NOTE: depends on instantiate_from_config, which is not imported in this module
224 |
225 | def __init__(self, c_concat_config, c_crossattn_config):
226 | super().__init__()
227 | self.concat_conditioner = instantiate_from_config(c_concat_config)
228 | self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
229 |
230 | def forward(self, c_concat, c_crossattn):
231 | c_concat = self.concat_conditioner(c_concat)
232 | c_crossattn = self.crossattn_conditioner(c_crossattn)
233 | return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
234 |
235 | def noise_like(x, repeat=False):
236 | noise = torch.randn_like(x)
237 | if repeat:
238 | bs = x.shape[0]
239 | noise = noise[0:1].repeat(bs, *((1,) * (len(x.shape) - 1)))
240 | return noise
241 |
242 | ##########################
243 | # inherit from ldm.utils #
244 | ##########################
245 |
246 | def count_params(model, verbose=False):
247 | total_params = sum(p.numel() for p in model.parameters())
248 | if verbose:
249 | print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
250 | return total_params
251 |
--------------------------------------------------------------------------------
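A small, hedged sketch of the two most commonly used helpers above: a noise schedule and a sinusoidal timestep embedding (the values quoted in the comments follow directly from the default arguments).

import torch
from lib.model_zoo.diffusion_utils import make_beta_schedule, timestep_embedding

betas = make_beta_schedule("linear", n_timestep=1000)   # numpy array of length 1000
print(betas[0], betas[-1])                              # ~1e-4 and ~2e-2 with the default endpoints

t = torch.tensor([0, 250, 999])                         # three example timesteps
emb = timestep_embedding(t, dim=320)                    # (3, 320) sinusoidal embedding
print(emb.shape)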
/lib/model_zoo/distributions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | class AbstractDistribution:
6 | def sample(self):
7 | raise NotImplementedError()
8 |
9 | def mode(self):
10 | raise NotImplementedError()
11 |
12 |
13 | class DiracDistribution(AbstractDistribution):
14 | def __init__(self, value):
15 | self.value = value
16 |
17 | def sample(self):
18 | return self.value
19 |
20 | def mode(self):
21 | return self.value
22 |
23 |
24 | class DiagonalGaussianDistribution(object):
25 | def __init__(self, parameters, deterministic=False):
26 | self.parameters = parameters
27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
29 | self.deterministic = deterministic
30 | self.std = torch.exp(0.5 * self.logvar)
31 | self.var = torch.exp(self.logvar)
32 | if self.deterministic:
33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
34 |
35 | def sample(self):
36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
37 | return x
38 |
39 | def kl(self, other=None):
40 | if self.deterministic:
41 | return torch.Tensor([0.])
42 | else:
43 | if other is None:
44 | return 0.5 * torch.sum(torch.pow(self.mean, 2)
45 | + self.var - 1.0 - self.logvar,
46 | dim=[1, 2, 3])
47 | else:
48 | return 0.5 * torch.sum(
49 | torch.pow(self.mean - other.mean, 2) / other.var
50 | + self.var / other.var - 1.0 - self.logvar + other.logvar,
51 | dim=[1, 2, 3])
52 |
53 | def nll(self, sample, dims=[1,2,3]):
54 | if self.deterministic:
55 | return torch.Tensor([0.])
56 | logtwopi = np.log(2.0 * np.pi)
57 | return 0.5 * torch.sum(
58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
59 | dim=dims)
60 |
61 | def mode(self):
62 | return self.mean
63 |
64 |
65 | def normal_kl(mean1, logvar1, mean2, logvar2):
66 | """
67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
68 | Compute the KL divergence between two gaussians.
69 | Shapes are automatically broadcasted, so batches can be compared to
70 | scalars, among other use cases.
71 | """
72 | tensor = None
73 | for obj in (mean1, logvar1, mean2, logvar2):
74 | if isinstance(obj, torch.Tensor):
75 | tensor = obj
76 | break
77 | assert tensor is not None, "at least one argument must be a Tensor"
78 |
79 | # Force variances to be Tensors. Broadcasting helps convert scalars to
80 | # Tensors, but it does not work for torch.exp().
81 | logvar1, logvar2 = [
82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
83 | for x in (logvar1, logvar2)
84 | ]
85 |
86 | return 0.5 * (
87 | -1.0
88 | + logvar2
89 | - logvar1
90 | + torch.exp(logvar1 - logvar2)
91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
92 | )
93 |
--------------------------------------------------------------------------------
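A hedged sketch of `DiagonalGaussianDistribution`: the channel dimension of the input is split in half into a mean and a log-variance, so 8 parameter channels give a 4-channel latent.

import torch
from lib.model_zoo.distributions import DiagonalGaussianDistribution

params = torch.randn(2, 8, 32, 32)      # first 4 channels = mean, last 4 = log-variance
dist = DiagonalGaussianDistribution(params)
z = dist.sample()                       # (2, 4, 32, 32) reparameterized sample
kl = dist.kl()                          # per-sample KL against a standard normal, shape (2,)
nll = dist.nll(z)                       # per-sample negative log-likelihood, shape (2,)
print(z.shape, kl.shape, nll.shape)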
/lib/model_zoo/ema.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | class LitEma(nn.Module):
5 | def __init__(self, model, decay=0.9999, use_num_updates=True):
6 | super().__init__()
7 | if decay < 0.0 or decay > 1.0:
8 | raise ValueError('Decay must be between 0 and 1')
9 |
10 | self.m_name2s_name = {}
11 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
12 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_updates
13 | else torch.tensor(-1,dtype=torch.int))
14 |
15 | for name, p in model.named_parameters():
16 | if p.requires_grad:
17 | # remove '.' since the character is not allowed in buffer names
18 | s_name = name.replace('.','')
19 | self.m_name2s_name.update({name:s_name})
20 | self.register_buffer(s_name,p.clone().detach().data)
21 |
22 | self.collected_params = []
23 |
24 | def forward(self, model):
25 | decay = self.decay
26 |
27 | if self.num_updates >= 0:
28 | self.num_updates += 1
29 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates))
30 |
31 | one_minus_decay = 1.0 - decay
32 |
33 | with torch.no_grad():
34 | m_param = dict(model.named_parameters())
35 | shadow_params = dict(self.named_buffers())
36 |
37 | for key in m_param:
38 | if m_param[key].requires_grad:
39 | sname = self.m_name2s_name[key]
40 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
41 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
42 | else:
43 | assert not key in self.m_name2s_name
44 |
45 | def copy_to(self, model):
46 | m_param = dict(model.named_parameters())
47 | shadow_params = dict(self.named_buffers())
48 | for key in m_param:
49 | if m_param[key].requires_grad:
50 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
51 | else:
52 | assert not key in self.m_name2s_name
53 |
54 | def store(self, parameters):
55 | """
56 | Save the current parameters for restoring later.
57 | Args:
58 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
59 | temporarily stored.
60 | """
61 | self.collected_params = [param.clone() for param in parameters]
62 |
63 | def restore(self, parameters):
64 | """
65 | Restore the parameters stored with the `store` method.
66 | Useful to validate the model with EMA parameters without affecting the
67 | original optimization process. Store the parameters before the
68 | `copy_to` method. After validation (or model saving), use this to
69 | restore the former parameters.
70 | Args:
71 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
72 | updated with the stored parameters.
73 | """
74 | for c_param, param in zip(self.collected_params, parameters):
75 | param.data.copy_(c_param.data)
76 |
--------------------------------------------------------------------------------
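A hedged sketch of the intended `LitEma` workflow: update the shadow weights after each optimizer step, then temporarily swap them in for evaluation.

import torch
from torch import nn
from lib.model_zoo.ema import LitEma

model = nn.Linear(4, 4)
ema = LitEma(model, decay=0.999)

for _ in range(10):                      # stand-in for a training loop
    with torch.no_grad():
        model.weight.add_(0.01)          # stand-in for an optimizer update
    ema(model)                           # EMA update of the shadow buffers

ema.store(model.parameters())            # keep the live weights
ema.copy_to(model)                       # evaluate with the EMA weights
# ... run validation here ...
ema.restore(model.parameters())          # put the live weights back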
/lib/model_zoo/sampler.py:
--------------------------------------------------------------------------------
1 | """SAMPLING ONLY."""
2 |
3 | import torch
4 | import numpy as np
5 | from tqdm import tqdm
6 | from functools import partial
7 |
8 | from .diffusion_utils import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
9 |
10 | def append_dims(x, target_dims):
11 | dims_to_append = target_dims - x.ndim
12 | if dims_to_append < 0:
13 | raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
14 | return x[(...,) + (None,) * dims_to_append]
15 |
16 | def default_noise_sampler(x):
17 | return lambda sigma, sigma_next: torch.randn_like(x)
18 |
19 | def get_ancestral_step(sigma_from, sigma_to, eta=1.):
20 | if not eta:
21 | return sigma_to, 0.
22 | sigma_up = min(sigma_to, eta * (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5)
23 | sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5
24 | return sigma_down, sigma_up
25 |
26 | def to_d(x, sigma, denoised):
27 | return (x - denoised) / append_dims(sigma, x.ndim)
28 |
29 | class Sampler(object):
30 | def __init__(self, net, type="ddim", steps=50, output_dim=[512, 512], n_samples=4, scale=7.5):
31 | super().__init__()
32 | self.net = net
33 | self.type = type
34 | self.steps = steps
35 | self.output_dim = output_dim
36 | self.n_samples = n_samples
37 | self.scale = scale
38 | self.sigmas = ((1 - net.alphas_cumprod) / net.alphas_cumprod) ** 0.5
39 | self.log_sigmas = self.sigmas.log()
40 |
41 | def t_to_sigma(self, t):
42 | t = t.float()
43 | low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac()
44 | log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
45 | return log_sigma.exp()
46 |
47 | def get_sigmas(self, n=None):
48 | def append_zero(x):
49 | return torch.cat([x, x.new_zeros([1])])
50 | if n is None:
51 | return append_zero(self.sigmas.flip(0))
52 | t_max = len(self.sigmas) - 1
53 | t = torch.linspace(t_max, 0, n, device=self.sigmas.device)
54 | return append_zero(self.t_to_sigma(t))
55 |
56 | @torch.no_grad()
57 | def sample(self, x_info, c_info):
58 | h, w = self.output_dim
59 | shape = [self.n_samples, 4, h//8, w//8]
60 | device, dtype = self.net.get_device(), self.net.get_dtype()
61 |
62 | if ('xt' in x_info) and (x_info['xt'] is not None):
63 | xt = x_info['xt'].type(dtype).to(device)
64 | x_info['x'] = xt
65 | elif ('x0' in x_info) and (x_info['x0'] is not None):
66 | x0 = x_info['x0'].type(dtype).to(device)
67 | ts = timesteps[x_info['x0_forward_timesteps']].repeat(self.n_samples)  # NOTE: 'timesteps' is not defined anywhere in this sampler, so this x0 branch cannot run as written
68 | ts = torch.Tensor(ts).long().to(device)
69 | timesteps = timesteps[:x_info['x0_forward_timesteps']]
70 | x0_nz = self.net.q_sample(x0, ts)
71 | x_info['x'] = x0_nz
72 | else:
73 | x_info['x'] = torch.randn(shape, device=device, dtype=dtype)
74 |
75 | sigmas = self.get_sigmas(n=self.steps)
76 |
77 | if self.type == 'eular_a':
78 | rv = self.sample_euler_ancestral(
79 | x_info=x_info,
80 | c_info=c_info,
81 | sigmas = sigmas)
82 | return rv
83 |
84 | @torch.no_grad()
85 | def sample_euler_ancestral(
86 | self, x_info, c_info, sigmas, eta=1., s_noise=1.,):
87 |
88 | x = x_info['x']
89 | x = x * sigmas[0]
90 |
91 | noise_sampler = default_noise_sampler(x)
92 |
93 | s_in = x.new_ones([x.shape[0]])
94 | for i in range(len(sigmas)-1):
95 | denoised = self.net.apply_model(x, sigmas[i] * s_in, )
96 |
97 | sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
98 | d = to_d(x, sigmas[i], denoised)
99 | # Euler method
100 | dt = sigma_down - sigmas[i]
101 | x = x + d * dt
102 | if sigmas[i + 1] > 0:
103 | x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
104 | return x
105 |
--------------------------------------------------------------------------------
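The Euler-ancestral update used by `sample_euler_ancestral` can be exercised on its own. A hedged toy sketch follows, with a fake "denoiser" that always predicts an all-zero latent (no network involved), just to show the deterministic Euler step plus the ancestral noise injection:

import torch
from lib.model_zoo.sampler import get_ancestral_step, to_d

sigmas = torch.tensor([14.6, 7.0, 0.0])      # illustrative 2-step schedule ending at 0
x = torch.randn(1, 4, 8, 8) * sigmas[0]      # start from pure noise at the largest sigma
denoised = torch.zeros_like(x)               # fake model output: always a zero latent

for i in range(len(sigmas) - 1):
    sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1])
    d = to_d(x, sigmas[i], denoised)         # derivative toward the "denoised" estimate
    x = x + d * (sigma_down - sigmas[i])     # deterministic Euler step
    if sigmas[i + 1] > 0:
        x = x + torch.randn_like(x) * sigma_up   # ancestral noise injection
print(x.abs().mean())                        # ends exactly at the fake denoised target (all zeros)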
/lib/sync.py:
--------------------------------------------------------------------------------
1 | from multiprocessing import shared_memory
2 | # import multiprocessing
3 | # if hasattr(multiprocessing, "shared_memory"):
4 | # from multiprocessing import shared_memory
5 | # else:
6 | # # workaround for single gpu inference on colab
7 | # shared_memory = None
8 |
9 | import random
10 | import pickle
11 | import time
12 | import copy
13 | import torch
14 | import torch.distributed as dist
15 | from lib.cfg_holder import cfg_unique_holder as cfguh
16 |
17 | def singleton(class_):
18 | instances = {}
19 | def getinstance(*args, **kwargs):
20 | if class_ not in instances:
21 | instances[class_] = class_(*args, **kwargs)
22 | return instances[class_]
23 | return getinstance
24 |
25 | def is_ddp():
26 | return dist.is_available() and dist.is_initialized()
27 |
28 | def get_rank(type='local'):
29 | ddp = is_ddp()
30 | global_rank = dist.get_rank() if ddp else 0
31 | local_world_size = torch.cuda.device_count()
32 | if type == 'global':
33 | return global_rank
34 | elif type == 'local':
35 | return global_rank % local_world_size
36 | elif type == 'node':
37 | return global_rank // local_world_size
38 | elif type == 'all':
39 | return global_rank, \
40 | global_rank % local_world_size, \
41 | global_rank // local_world_size
42 | else:
43 | assert False, 'Unknown type'
44 |
45 | def get_world_size(type='local'):
46 | ddp = is_ddp()
47 | global_rank = dist.get_rank() if ddp else 0
48 | global_world_size = dist.get_world_size() if ddp else 1
49 | local_world_size = torch.cuda.device_count()
50 | if type == 'global':
51 | return global_world_size
52 | elif type == 'local':
53 | return local_world_size
54 | elif type == 'node':
55 | return global_world_size // local_world_size
56 | elif type == 'all':
57 | return global_world_size, local_world_size, \
58 | global_world_size // local_world_size
59 | else:
60 | assert False, 'Unknown type'
61 |
62 | class barrier_lock(object):
63 | def __init__(self, n):
64 | self.n = n
65 | id = int(random.random()*10000) + int(time.time())*10000
66 | self.lock_shmname = 'barrier_lock_{}'.format(id)
67 | lock_shm = shared_memory.SharedMemory(
68 | name=self.lock_shmname, create=True, size=n)
69 | for i in range(n):
70 | lock_shm.buf[i] = 0
71 | lock_shm.close()
72 |
73 | def destroy(self):
74 | try:
75 | lock_shm = shared_memory.SharedMemory(
76 | name=self.lock_shmname)
77 | lock_shm.close()
78 | lock_shm.unlink()
79 | except:
80 | return
81 |
82 | def wait(self, k):
83 | lock_shm = shared_memory.SharedMemory(
84 | name=self.lock_shmname)
85 | assert lock_shm.buf[k] == 0, 'Two waits on the same id are not allowed.'
86 | lock_shm.buf[k] = 1
87 | if k == 0:
88 | while sum([lock_shm.buf[i]==0 for i in range(self.n)]) != 0:
89 | pass
90 | for i in range(self.n):
91 | lock_shm.buf[i] = 0
92 | return
93 | else:
94 | while lock_shm.buf[k] != 0:
95 | pass
96 |
97 | class default_lock(object):
98 | def __init__(self):
99 | id = int(random.random()*10000) + int(time.time())*10000
100 | self.lock_shmname = 'lock_{}'.format(id)
101 | lock_shm = shared_memory.SharedMemory(
102 | name=self.lock_shmname, create=True, size=2)
103 | for i in range(2):
104 | lock_shm.buf[i] = 0
105 | lock_shm.close()
106 |
107 | def destroy(self):
108 | try:
109 | lock_shm = shared_memory.SharedMemory(
110 | name=self.lock_shmname)
111 | lock_shm.close()
112 | lock_shm.unlink()
113 | except:
114 | return
115 |
116 | def lock(self, k):
117 | lock_shm = shared_memory.SharedMemory(
118 | name=self.lock_shmname)
119 | while lock_shm.buf[0] == 1:
120 | pass
121 | lock_shm.buf[0] = 1
122 | lock_shm.buf[1] = k
123 |
124 | def unlock(self, k):
125 | lock_shm = shared_memory.SharedMemory(
126 | name=self.lock_shmname)
127 | if lock_shm.buf[1] != k:
128 | return
129 | lock_shm.buf[0] = 0
130 | return
131 |
132 | class nodewise_sync_global(object):
133 | """
134 | This is the global part of nodewise_sync that needs to be created in the master process
135 | before spawn.
136 | """
137 | def __init__(self):
138 | self.local_world_size = get_world_size('local')
139 | self.reg_lock = default_lock()
140 | self.b_lock = barrier_lock(self.local_world_size)
141 | id = int(random.random()*10000) + int(time.time())*10000
142 | self.id_shmname = 'nodewise_sync_id_shm_{}'.format(id)
143 |
144 | def destroy(self):
145 | self.reg_lock.destroy()
146 | self.b_lock.destroy()
147 | try:
148 | shm = shared_memory.SharedMemory(name=self.id_shmname)
149 | shm.close()
150 | shm.unlink()
151 | except:
152 | return
153 |
154 | @singleton
155 | class nodewise_sync(object):
156 | """
157 | A class that centralizes nodewise sync activities.
158 | The backend is multiprocessing shared memory rather than torch, since torch does not support this.
159 | """
160 | def __init__(self):
161 | pass
162 |
163 | def copy_global(self, reference):
164 | self.local_world_size = reference.local_world_size
165 | self.b_lock = reference.b_lock
166 | self.reg_lock = reference.reg_lock
167 | self.id_shmname = reference.id_shmname
168 | return self
169 |
170 | def local_init(self):
171 | self.ddp = is_ddp()
172 | self.global_rank, self.local_rank, self.node_rank = get_rank('all')
173 | self.global_world_size, self.local_world_size, self.nodes = get_world_size('all')
174 | if self.local_rank == 0:
175 | temp = int(random.random()*10000) + int(time.time())*10000
176 | temp = pickle.dumps(temp)
177 | shm = shared_memory.SharedMemory(
178 | name=self.id_shmname, create=True, size=len(temp))
179 | shm.close()
180 | return self
181 |
182 | def random_sync_id(self):
183 | assert self.local_rank is not None, 'Not initialized!'
184 | if self.local_rank == 0:
185 | sync_id = int(random.random()*10000) + int(time.time())*10000
186 | data = pickle.dumps(sync_id)
187 | shm = shared_memory.SharedMemory(name=self.id_shmname)
188 | shm.buf[0:len(data)] = data[0:len(data)]
189 | self.barrier()
190 | shm.close()
191 | else:
192 | self.barrier()
193 | shm = shared_memory.SharedMemory(name=self.id_shmname)
194 | sync_id = pickle.loads(shm.buf)
195 | shm.close()
196 | return sync_id
197 |
198 | def barrier(self):
199 | self.b_lock.wait(self.local_rank)
200 |
201 | def lock(self):
202 | self.reg_lock.lock(self.local_rank)
203 |
204 | def unlock(self):
205 | self.reg_lock.unlock(self.local_rank)
206 |
207 | def broadcast_r0(self, data=None):
208 | assert self.local_rank is not None, 'Not initialized!'
209 | id = self.random_sync_id()
210 | shmname = 'broadcast_r0_{}'.format(id)
211 | if self.local_rank == 0:
212 | assert data!=None, 'Rank 0 needs to input data!'
213 | data = pickle.dumps(data)
214 | datan = len(data)
215 | load_info_shm = shared_memory.SharedMemory(
216 | name=shmname, create=True, size=datan)
217 | load_info_shm.buf[0:datan] = data[0:datan]
218 | self.barrier()
219 | self.barrier()
220 | load_info_shm.close()
221 | load_info_shm.unlink()
222 | return None
223 | else:
224 | assert data==None, 'Ranks other than 0 should input None as data!'
225 | self.barrier()
226 | shm = shared_memory.SharedMemory(name=shmname)
227 | data = pickle.loads(shm.buf)
228 | shm.close()
229 | self.barrier()
230 | return data
231 |
232 | def destroy(self):
233 | self.b_lock.destroy()
234 | try:
235 | shm = shared_memory.SharedMemory(name=self.id_shmname)
236 | shm.close()
237 | shm.unlink()
238 | except:
239 | return
240 |
241 | # import contextlib
242 |
243 | # @contextlib.contextmanager
244 | # def weight_sync(module, sync):
245 | # assert isinstance(module, torch.nn.Module)
246 | # if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel):
247 | # yield
248 | # else:
249 | # with module.no_sync():
250 | # yield
251 |
252 | # def weight_sync(net):
253 | # for parameters in net.parameters():
254 | # dist.all_reduce(parameters, dist.ReduceOp.AVG)
--------------------------------------------------------------------------------
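A hedged sketch of the shared-memory barrier above, run across two local processes (assumes the repo root is on PYTHONPATH, since `lib.sync` also imports torch and `lib.cfg_holder`):

import multiprocessing as mp
from lib.sync import barrier_lock        # assumed import, per this repository's layout

def worker(lock, rank):
    print(f'rank {rank} reached the barrier')
    lock.wait(rank)                      # blocks until every rank has called wait()
    print(f'rank {rank} passed the barrier')

if __name__ == '__main__':
    n = 2
    lock = barrier_lock(n)               # allocates the shared-memory flag array
    procs = [mp.Process(target=worker, args=(lock, k)) for k in range(n)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
    lock.destroy()                       # unlink the shared memory when done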
/requirements.txt:
--------------------------------------------------------------------------------
1 | matplotlib==3.4.2
2 | pyyaml==5.4.1
3 | easydict==1.9
4 | tensorboardx==2.6
5 | tensorboard==2.12.0
6 | protobuf==3.20.3
7 | lpips==0.1.3
8 | fsspec==2022.7.1
9 |
10 | tqdm==4.60.0
11 | transformers==4.24.0
12 | torchmetrics==0.7.3
13 |
14 | einops==0.3.0
15 | omegaconf==2.1.1
16 | open_clip_torch==2.0.2
17 | webdataset==0.2.5
18 | gradio==3.17.1
19 |
20 | safetensors==0.3.1
21 |
22 | opencv-python
23 | scikit-image
24 |
--------------------------------------------------------------------------------
/tools/get_controlnet.py:
--------------------------------------------------------------------------------
1 | # A tool to slim a downloaded ControlNet checkpoint: strip the 'control_model.' prefix and save it as safetensors
2 |
3 | import torch
4 | from safetensors.torch import load_file, save_file
5 | from collections import OrderedDict
6 | import os.path as osp
7 |
8 | in_path = 'pretrained/controlnet/sdwebui_compatible/control_v11p_sd15_canny.pth'
9 | out_path = 'pretrained/controlnet/control_v11p_sd15_canny_slimmed.safetensors'
10 |
11 | sd = torch.load(in_path)
12 |
13 | sdnew = [[ni.replace('control_model.', ''), vi] for ni, vi in sd.items()]
14 | save_file(OrderedDict(sdnew), out_path)
15 |
--------------------------------------------------------------------------------
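A hedged follow-up check on the slimmed file written by the script above (the path reuses the `out_path` defined in the script):

from safetensors.torch import load_file

sd_slim = load_file('pretrained/controlnet/control_v11p_sd15_canny_slimmed.safetensors')
print(len(sd_slim), next(iter(sd_slim)))   # keys no longer carry the 'control_model.' prefix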