├── .gitignore ├── LICENSE ├── README.md ├── app.py ├── assets ├── anime_ug.pth ├── clip_ug.pth ├── examples-anime │ ├── camping.jpg │ ├── hanfu_girl.jpg │ ├── miku-canny.png │ ├── miku.jpg │ ├── pose.png │ ├── pose_small.png │ ├── pose_withhandface.png │ ├── pose_withhandface_small.png │ └── random1.jpg ├── examples │ ├── astronautridinghouse-canny.png │ ├── astronautridinghouse-input.jpg │ ├── bedroom-input.jpg │ ├── bedroom-mlsd.png │ ├── ghibli-canny.png │ ├── ghibli-input.jpg │ ├── grassland-input.jpg │ ├── grassland-scribble.png │ ├── jeep-depth.png │ ├── jeep-input.jpg │ ├── nightstreet-canny.png │ ├── nightstreet-input.jpg │ ├── woodcar-depth.png │ └── woodcar-input.jpg └── figures │ ├── anime.png │ ├── prompt_free_diffusion.png │ ├── qualitative_show.png │ ├── reusability.png │ └── seecoder.png ├── configs └── model │ ├── autokl.yaml │ ├── clip.yaml │ ├── controlnet.yaml │ ├── openai_unet.yaml │ ├── pfd.yaml │ ├── seecoder.yaml │ └── swin.yaml ├── lib ├── __init__.py ├── cfg_helper.py ├── cfg_holder.py ├── log_service.py ├── model_zoo │ ├── __init__.py │ ├── attention.py │ ├── autokl.py │ ├── autokl_modules.py │ ├── autokl_utils.py │ ├── clip.py │ ├── common │ │ ├── get_model.py │ │ ├── get_optimizer.py │ │ ├── get_scheduler.py │ │ └── utils.py │ ├── controlnet.py │ ├── controlnet_annotator │ │ ├── canny │ │ │ └── __init__.py │ │ ├── hed │ │ │ └── __init__.py │ │ ├── midas │ │ │ ├── LICENSE │ │ │ ├── __init__.py │ │ │ ├── api.py │ │ │ ├── midas │ │ │ │ ├── __init__.py │ │ │ │ ├── base_model.py │ │ │ │ ├── blocks.py │ │ │ │ ├── dpt_depth.py │ │ │ │ ├── midas_net.py │ │ │ │ ├── midas_net_custom.py │ │ │ │ ├── transforms.py │ │ │ │ └── vit.py │ │ │ └── utils.py │ │ ├── mlsd │ │ │ ├── LICENSE │ │ │ ├── __init__.py │ │ │ ├── models │ │ │ │ ├── mbv2_mlsd_large.py │ │ │ │ └── mbv2_mlsd_tiny.py │ │ │ └── utils.py │ │ ├── openpose │ │ │ ├── LICENSE │ │ │ ├── __init__.py │ │ │ ├── body.py │ │ │ ├── face.py │ │ │ ├── hand.py │ │ │ ├── model.py │ │ │ └── util.py │ │ └── pidinet │ │ │ ├── LICENSE │ │ │ ├── __init__.py │ │ │ └── model.py │ ├── ddim.py │ ├── diffusion_utils.py │ ├── distributions.py │ ├── ema.py │ ├── openaimodel.py │ ├── pfd.py │ ├── sampler.py │ ├── seecoder.py │ └── swin.py ├── sync.py └── utils.py ├── requirements.txt └── tools ├── get_controlnet.py └── model_conversion.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .vscode/ 3 | data/ 4 | data 5 | log/ 6 | log 7 | pretrained/ 8 | pretrained 9 | assets/nosync/ 10 | assets/demo/temp/temp_* 11 | *.out 12 | gradio_cached_examples/ 13 | src/*/build 14 | src/*/dist 15 | src/*/*.egg-info/ 16 | extensions/ 17 | extensions 18 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 SHI Labs 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Prompt-Free Diffusion 2 | 3 | [![HuggingFace space](https://img.shields.io/badge/🤗-Huggingface%20Space-cyan.svg)](https://huggingface.co/spaces/shi-labs/Prompt-Free-Diffusion) 4 | [![Framework: PyTorch](https://img.shields.io/badge/Framework-PyTorch-orange.svg)](https://pytorch.org/) 5 | [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) 6 | 7 | This repo hosts the official implementation of: 8 | 9 | [Xingqian Xu](https://ifp-uiuc.github.io/), Jiayi Guo, Zhangyang Wang, Gao Huang, Irfan Essa, and [Humphrey Shi](https://www.humphreyshi.com/home), **Prompt-Free Diffusion: Taking "Text" out of Text-to-Image Diffusion Models**, [Paper arXiv Link](https://arxiv.org/abs/2305.16223). 10 | 11 | ## News 12 | 13 | - **[2023.06.20]: SDWebUI plugin released; repo at this [link](https://github.com/xingqian2018/sd-webui-prompt-free-diffusion)** 14 | - [2023.05.25]: Our demo is running on [HuggingFace🤗](https://huggingface.co/spaces/shi-labs/Prompt-Free-Diffusion) 15 | - [2023.05.25]: Repo created 16 | 17 | ## Introduction 18 | 19 | **Prompt-Free Diffusion** is a diffusion model that relies only on visual inputs to generate new images; prompting is handled by a **Semantic Context Encoder (SeeCoder)** that substitutes for the commonly used CLIP-based text encoder. SeeCoder is **reusable with most public T2I models as well as adaptive layers** like ControlNet, LoRA, T2I-Adapter, etc. Just drop it in and play! 20 | 21 |

22 | 23 |

24 | 25 | ## Performance 26 | 27 |

28 | 29 |

30 | 31 | ## Network 32 | 33 |

34 | 35 |

36 | 37 |

38 | 39 |

40 | 41 | ## Setup 42 | 43 | ``` 44 | conda create -n prompt-free-diffusion python=3.10 45 | conda activate prompt-free-diffusion 46 | pip install torch==2.0.0+cu117 torchvision==0.15.1 --extra-index-url https://download.pytorch.org/whl/cu117 47 | pip install -r requirements.txt 48 | ``` 49 | 50 | ## Demo 51 | 52 | We provide a WebUI powered by [Gradio](https://github.com/gradio-app/gradio). Start the WebUI with the following command: 53 | 54 | ``` 55 | python app.py 56 | ``` 57 | 58 | ## Pretrained models 59 | 60 | To support the full functionality of our demo, you need the following models located at these paths: 61 | 62 | ``` 63 | └── pretrained 64 | ├── pfd 65 | | ├── vae 66 | | │ └── sd-v2-0-base-autokl.pth 67 | | ├── diffuser 68 | | │ ├── AbyssOrangeMix-v2.safetensors 69 | | │ ├── AbyssOrangeMix-v3.safetensors 70 | | │ ├── Anything-v4.safetensors 71 | | │ ├── Deliberate-v2-0.safetensors 72 | | │ ├── OpenJouney-v4.safetensors 73 | | │ ├── RealisticVision-v2-0.safetensors 74 | | │ └── SD-v1-5.safetensors 75 | | └── seecoder 76 | | ├── seecoder-v1-0.safetensors 77 | | ├── seecoder-pa-v1-0.safetensors 78 | | └── seecoder-anime-v1-0.safetensors 79 | └── controlnet 80 | ├── control_sd15_canny_slimmed.safetensors 81 | ├── control_sd15_depth_slimmed.safetensors 82 | ├── control_sd15_hed_slimmed.safetensors 83 | ├── control_sd15_mlsd_slimmed.safetensors 84 | ├── control_sd15_normal_slimmed.safetensors 85 | ├── control_sd15_openpose_slimmed.safetensors 86 | ├── control_sd15_scribble_slimmed.safetensors 87 | ├── control_sd15_seg_slimmed.safetensors 88 | ├── control_v11p_sd15_canny_slimmed.safetensors 89 | ├── control_v11p_sd15_lineart_slimmed.safetensors 90 | ├── control_v11p_sd15_mlsd_slimmed.safetensors 91 | ├── control_v11p_sd15_openpose_slimmed.safetensors 92 | ├── control_v11p_sd15s2_lineart_anime_slimmed.safetensors 93 | ├── control_v11p_sd15_softedge_slimmed.safetensors 94 | └── preprocess 95 | ├── hed 96 | │ └── ControlNetHED.pth 97 | ├── midas 98 | │ └── dpt_hybrid-midas-501f0c75.pt 99 | ├── mlsd 100 | │ └── mlsd_large_512_fp32.pth 101 | ├── openpose 102 | │ ├── body_pose_model.pth 103 | │ ├── facenet.pth 104 | │ └── hand_pose_model.pth 105 | └── pidinet 106 | └── table5_pidinet.pth 107 | ``` 108 | 109 | All models can be downloaded from the [HuggingFace link](https://huggingface.co/shi-labs/prompt-free-diffusion). 110 | 111 | ## Tools 112 | 113 | We also provide tools to convert pretrained models from the sdwebui and diffusers libraries to this codebase; please modify the following files: 114 | 115 | ``` 116 | └── tools 117 |    ├── get_controlnet.py 118 |    └── model_conversion.py 119 | ``` 120 | 121 | You are expected to do some customized coding to make them work (e.g., changing the hardcoded input and output file paths). 122 | 123 | ## Performance Anime 124 | 125 |

126 | 127 |

128 | 129 | ## Citation 130 | 131 | ``` 132 | @article{xu2023prompt, 133 | title={Prompt-Free Diffusion: Taking "Text" out of Text-to-Image Diffusion Models}, 134 | author={Xu, Xingqian and Guo, Jiayi and Wang, Zhangyang and Huang, Gao and Essa, Irfan and Shi, Humphrey}, 135 | journal={arXiv preprint arXiv:2305.16223}, 136 | year={2023} 137 | } 138 | ``` 139 | 140 | ## Acknowledgement 141 | 142 | Parts of the code reorganize/reimplement code from the following repositories: [Versatile Diffusion official Github](https://github.com/SHI-Labs/Versatile-Diffusion) and [ControlNet sdwebui Github](https://github.com/Mikubill/sd-webui-controlnet), which are also greatly influenced by [LDM official Github](https://github.com/CompVis/latent-diffusion) and [DDPM official Github](https://github.com/lucidrains/denoising-diffusion-pytorch). 143 | -------------------------------------------------------------------------------- /assets/anime_ug.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Prompt-Free-Diffusion/f4295fc7ad80763e2ab562eb9aab06abb979b278/assets/anime_ug.pth -------------------------------------------------------------------------------- /assets/clip_ug.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Prompt-Free-Diffusion/f4295fc7ad80763e2ab562eb9aab06abb979b278/assets/clip_ug.pth -------------------------------------------------------------------------------- /assets/examples-anime/camping.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Prompt-Free-Diffusion/f4295fc7ad80763e2ab562eb9aab06abb979b278/assets/examples-anime/camping.jpg -------------------------------------------------------------------------------- /assets/examples-anime/hanfu_girl.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Prompt-Free-Diffusion/f4295fc7ad80763e2ab562eb9aab06abb979b278/assets/examples-anime/hanfu_girl.jpg -------------------------------------------------------------------------------- /assets/examples-anime/miku-canny.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Prompt-Free-Diffusion/f4295fc7ad80763e2ab562eb9aab06abb979b278/assets/examples-anime/miku-canny.png -------------------------------------------------------------------------------- /assets/examples-anime/miku.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Prompt-Free-Diffusion/f4295fc7ad80763e2ab562eb9aab06abb979b278/assets/examples-anime/miku.jpg -------------------------------------------------------------------------------- /assets/examples-anime/pose.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Prompt-Free-Diffusion/f4295fc7ad80763e2ab562eb9aab06abb979b278/assets/examples-anime/pose.png -------------------------------------------------------------------------------- /assets/examples-anime/pose_small.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Prompt-Free-Diffusion/f4295fc7ad80763e2ab562eb9aab06abb979b278/assets/examples-anime/pose_small.png
-------------------------------------------------------------------------------- /assets/examples-anime/pose_withhandface.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Prompt-Free-Diffusion/f4295fc7ad80763e2ab562eb9aab06abb979b278/assets/examples-anime/pose_withhandface.png -------------------------------------------------------------------------------- /assets/examples-anime/pose_withhandface_small.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Prompt-Free-Diffusion/f4295fc7ad80763e2ab562eb9aab06abb979b278/assets/examples-anime/pose_withhandface_small.png -------------------------------------------------------------------------------- /assets/examples-anime/random1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Prompt-Free-Diffusion/f4295fc7ad80763e2ab562eb9aab06abb979b278/assets/examples-anime/random1.jpg -------------------------------------------------------------------------------- /assets/examples/astronautridinghouse-canny.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Prompt-Free-Diffusion/f4295fc7ad80763e2ab562eb9aab06abb979b278/assets/examples/astronautridinghouse-canny.png -------------------------------------------------------------------------------- /assets/examples/astronautridinghouse-input.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Prompt-Free-Diffusion/f4295fc7ad80763e2ab562eb9aab06abb979b278/assets/examples/astronautridinghouse-input.jpg -------------------------------------------------------------------------------- /assets/examples/bedroom-input.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Prompt-Free-Diffusion/f4295fc7ad80763e2ab562eb9aab06abb979b278/assets/examples/bedroom-input.jpg -------------------------------------------------------------------------------- /assets/examples/bedroom-mlsd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Prompt-Free-Diffusion/f4295fc7ad80763e2ab562eb9aab06abb979b278/assets/examples/bedroom-mlsd.png -------------------------------------------------------------------------------- /assets/examples/ghibli-canny.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Prompt-Free-Diffusion/f4295fc7ad80763e2ab562eb9aab06abb979b278/assets/examples/ghibli-canny.png -------------------------------------------------------------------------------- /assets/examples/ghibli-input.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Prompt-Free-Diffusion/f4295fc7ad80763e2ab562eb9aab06abb979b278/assets/examples/ghibli-input.jpg -------------------------------------------------------------------------------- /assets/examples/grassland-input.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Prompt-Free-Diffusion/f4295fc7ad80763e2ab562eb9aab06abb979b278/assets/examples/grassland-input.jpg 
-------------------------------------------------------------------------------- /assets/examples/grassland-scribble.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Prompt-Free-Diffusion/f4295fc7ad80763e2ab562eb9aab06abb979b278/assets/examples/grassland-scribble.png -------------------------------------------------------------------------------- /assets/examples/jeep-depth.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Prompt-Free-Diffusion/f4295fc7ad80763e2ab562eb9aab06abb979b278/assets/examples/jeep-depth.png -------------------------------------------------------------------------------- /assets/examples/jeep-input.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Prompt-Free-Diffusion/f4295fc7ad80763e2ab562eb9aab06abb979b278/assets/examples/jeep-input.jpg -------------------------------------------------------------------------------- /assets/examples/nightstreet-canny.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Prompt-Free-Diffusion/f4295fc7ad80763e2ab562eb9aab06abb979b278/assets/examples/nightstreet-canny.png -------------------------------------------------------------------------------- /assets/examples/nightstreet-input.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Prompt-Free-Diffusion/f4295fc7ad80763e2ab562eb9aab06abb979b278/assets/examples/nightstreet-input.jpg -------------------------------------------------------------------------------- /assets/examples/woodcar-depth.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Prompt-Free-Diffusion/f4295fc7ad80763e2ab562eb9aab06abb979b278/assets/examples/woodcar-depth.png -------------------------------------------------------------------------------- /assets/examples/woodcar-input.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Prompt-Free-Diffusion/f4295fc7ad80763e2ab562eb9aab06abb979b278/assets/examples/woodcar-input.jpg -------------------------------------------------------------------------------- /assets/figures/anime.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Prompt-Free-Diffusion/f4295fc7ad80763e2ab562eb9aab06abb979b278/assets/figures/anime.png -------------------------------------------------------------------------------- /assets/figures/prompt_free_diffusion.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Prompt-Free-Diffusion/f4295fc7ad80763e2ab562eb9aab06abb979b278/assets/figures/prompt_free_diffusion.png -------------------------------------------------------------------------------- /assets/figures/qualitative_show.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Prompt-Free-Diffusion/f4295fc7ad80763e2ab562eb9aab06abb979b278/assets/figures/qualitative_show.png -------------------------------------------------------------------------------- /assets/figures/reusability.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Prompt-Free-Diffusion/f4295fc7ad80763e2ab562eb9aab06abb979b278/assets/figures/reusability.png -------------------------------------------------------------------------------- /assets/figures/seecoder.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Prompt-Free-Diffusion/f4295fc7ad80763e2ab562eb9aab06abb979b278/assets/figures/seecoder.png -------------------------------------------------------------------------------- /configs/model/autokl.yaml: -------------------------------------------------------------------------------- 1 | autokl: 2 | symbol: autokl 3 | find_unused_parameters: false 4 | 5 | autokl_v1: 6 | super_cfg: autokl 7 | type: autoencoderkl 8 | args: 9 | embed_dim: 4 10 | ddconfig: 11 | double_z: true 12 | z_channels: 4 13 | resolution: 256 14 | in_channels: 3 15 | out_ch: 3 16 | ch: 128 17 | ch_mult: [1, 2, 4, 4] 18 | num_res_blocks: 2 19 | attn_resolutions: [] 20 | dropout: 0.0 21 | lossconfig: null 22 | pth: pretrained/kl-f8.pth 23 | 24 | autokl_v2: 25 | super_cfg: autokl_v1 26 | pth: pretrained/pfd/vae/sd-v2-0-base-autokl.pth 27 | -------------------------------------------------------------------------------- /configs/model/clip.yaml: -------------------------------------------------------------------------------- 1 | ################ 2 | # clip for sd1 # 3 | ################ 4 | 5 | clip: 6 | symbol: clip 7 | args: {} 8 | 9 | clip_text_context_encoder_sdv1: 10 | super_cfg: clip 11 | type: clip_text_context_encoder_sdv1 12 | args: {} 13 | -------------------------------------------------------------------------------- /configs/model/controlnet.yaml: -------------------------------------------------------------------------------- 1 | controlnet: 2 | symbol: controlnet 3 | type: controlnet 4 | find_unused_parameters: false 5 | args: 6 | image_size: 32 # unused 7 | in_channels: 4 8 | hint_channels: 3 9 | model_channels: 320 10 | attention_resolutions: [ 4, 2, 1 ] 11 | num_res_blocks: 2 12 | channel_mult: [ 1, 2, 4, 4 ] 13 | num_heads: 8 14 | use_spatial_transformer: True 15 | transformer_depth: 1 16 | context_dim: 768 17 | use_checkpoint: True 18 | legacy: False 19 | -------------------------------------------------------------------------------- /configs/model/openai_unet.yaml: -------------------------------------------------------------------------------- 1 | openai_unet_sd: 2 | type: openai_unet 3 | args: 4 | image_size: null # no use 5 | in_channels: 4 6 | out_channels: 4 7 | model_channels: 320 8 | attention_resolutions: [ 4, 2, 1 ] 9 | num_res_blocks: [ 2, 2, 2, 2 ] 10 | channel_mult: [ 1, 2, 4, 4 ] 11 | # disable_self_attentions: [ False, False, False, False ] # converts the self-attention to a cross-attention layer if true 12 | num_heads: 8 13 | use_spatial_transformer: True 14 | transformer_depth: 1 15 | context_dim: 768 16 | use_checkpoint: True 17 | legacy: False 18 | 19 | ######### 20 | # v1 2d # 21 | ######### 22 | 23 | openai_unet_2d_v1: 24 | type: openai_unet_2d_next 25 | args: 26 | in_channels: 4 27 | out_channels: 4 28 | model_channels: 320 29 | attention_resolutions: [ 4, 2, 1 ] 30 | num_res_blocks: [ 2, 2, 2, 2 ] 31 | channel_mult: [ 1, 2, 4, 4 ] 32 | num_heads: 8 33 | context_dim: 768 34 | use_checkpoint: False 35 | parts: [global, data, context] 36 | -------------------------------------------------------------------------------- /configs/model/pfd.yaml: 
-------------------------------------------------------------------------------- 1 | pfd_base: 2 | symbol: pfd 3 | find_unused_parameters: true 4 | type: pfd 5 | args: 6 | beta_linear_start: 0.00085 7 | beta_linear_end: 0.012 8 | timesteps: 1000 9 | use_ema: false 10 | 11 | pfd_seecoder: 12 | super_cfg: pfd_base 13 | args: 14 | vae_cfg_list: 15 | - [image, MODEL(autokl_v2)] 16 | ctx_cfg_list: 17 | - [image, MODEL(seecoder)] 18 | diffuser_cfg_list: 19 | - [image, MODEL(openai_unet_2d_v1)] 20 | latent_scale_factor: 21 | image: 0.18215 22 | 23 | pdf_seecoder_pa: 24 | super_cfg: pfd_seecoder 25 | args: 26 | ctx_cfg_list: 27 | - [image, MODEL(seecoder_pa)] 28 | 29 | pfd_seecoder_with_controlnet: 30 | super_cfg: pfd_seecoder 31 | type: pfd_with_control 32 | args: 33 | ctl_cfg: MODEL(controlnet) 34 | -------------------------------------------------------------------------------- /configs/model/seecoder.yaml: -------------------------------------------------------------------------------- 1 | seecoder_base: 2 | symbol: seecoder 3 | args: {} 4 | 5 | seecoder: 6 | super_cfg: seecoder_base 7 | type: seecoder 8 | args: 9 | imencoder_cfg : MODEL(swin_large) 10 | imdecoder_cfg : MODEL(seecoder_decoder) 11 | qtransformer_cfg : MODEL(seecoder_query_transformer) 12 | 13 | seecoder_pa: 14 | super_cfg: seet 15 | type: seecoder 16 | args: 17 | imencoder_cfg : MODEL(swin_large) 18 | imdecoder_cfg : MODEL(seecoder_decoder) 19 | qtransformer_cfg : MODEL(seecoder_query_transformer_position_aware) 20 | 21 | ########### 22 | # decoder # 23 | ########### 24 | 25 | seecoder_decoder: 26 | super_cfg: seecoder_base 27 | type: seecoder_decoder 28 | args: 29 | inchannels: 30 | res3: 384 31 | res4: 768 32 | res5: 1536 33 | trans_input_tags: ['res3', 'res4', 'res5'] 34 | trans_dim: 768 35 | trans_dropout: 0.1 36 | trans_nheads: 8 37 | trans_feedforward_dim: 1024 38 | trans_num_layers: 6 39 | 40 | ##################### 41 | # query_transformer # 42 | ##################### 43 | 44 | seecoder_query_transformer: 45 | super_cfg: seecoder_base 46 | type: seecoder_query_transformer 47 | args: 48 | in_channels : 768 49 | hidden_dim: 768 50 | num_queries: [4, 144] 51 | nheads: 8 52 | num_layers: 9 53 | feedforward_dim: 2048 54 | pre_norm: False 55 | num_feature_levels: 3 56 | enforce_input_project: False 57 | with_fea2d_pos: false 58 | 59 | seecoder_query_transformer_position_aware: 60 | super_cfg: seecoder_query_transformer 61 | args: 62 | with_fea2d_pos: true 63 | -------------------------------------------------------------------------------- /configs/model/swin.yaml: -------------------------------------------------------------------------------- 1 | swin: 2 | symbol: swin 3 | args: {} 4 | 5 | swin_base: 6 | super_cfg: swin 7 | type: swin 8 | args: 9 | embed_dim: 128 10 | depths: [ 2, 2, 18, 2 ] 11 | num_heads: [ 4, 8, 16, 32 ] 12 | window_size: 7 13 | ape: False 14 | drop_path_rate: 0.3 15 | patch_norm: True 16 | strict_sd: False 17 | 18 | swin_large: 19 | super_cfg: swin 20 | type: swin 21 | args: 22 | embed_dim: 192 23 | depths: [ 2, 2, 18, 2 ] 24 | num_heads: [ 6, 12, 24, 48 ] 25 | window_size: 12 26 | ape: False 27 | drop_path_rate: 0.3 28 | patch_norm: True 29 | strict_sd: False 30 | 31 | -------------------------------------------------------------------------------- /lib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Prompt-Free-Diffusion/f4295fc7ad80763e2ab562eb9aab06abb979b278/lib/__init__.py 
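The `configs/model/*.yaml` entries above are what the `get_model` factory (defined later in this listing, in `lib/model_zoo/common/get_model.py`) consumes: each entry carries a registered `type`, an `args` block passed to the model class, and optionally a `pth`/`pretrained` checkpoint path. As a rough illustration only — the repo's real loading path goes through `lib/cfg_helper.py`, which is not shown in this section and which resolves `super_cfg` inheritance and `MODEL(...)` references — an entry whose `args` are all literals, such as `controlnet`, could be instantiated as sketched below. Using `easydict.EasyDict` as a stand-in for the attribute-style config object, and having the repo root on `PYTHONPATH`, are assumptions of this sketch:

```python
# Illustrative sketch (not the repo's official loading path): read the controlnet
# entry from configs/model/controlnet.yaml and hand it to the get_model factory.
# Assumptions: repo root is on PYTHONPATH; easydict.EasyDict stands in for the
# config object normally produced by lib/cfg_helper.py (not shown in this listing).
import yaml
from easydict import EasyDict

from lib.model_zoo.common.get_model import get_model

with open('configs/model/controlnet.yaml') as f:
    cfg = EasyDict(yaml.safe_load(f)).controlnet

# get_model reads cfg.type to pick the registered class, passes cfg.args as kwargs,
# and would load weights if cfg.pth / cfg.pretrained were set (they are not here).
net = get_model()(cfg, verbose=False)
print(sum(p.numel() for p in net.parameters()), 'parameters')
```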
-------------------------------------------------------------------------------- /lib/cfg_holder.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | def singleton(class_): 4 | instances = {} 5 | def getinstance(*args, **kwargs): 6 | if class_ not in instances: 7 | instances[class_] = class_(*args, **kwargs) 8 | return instances[class_] 9 | return getinstance 10 | 11 | ############## 12 | # cfg_holder # 13 | ############## 14 | 15 | @singleton 16 | class cfg_unique_holder(object): 17 | def __init__(self): 18 | self.cfg = None 19 | # this is use to track the main codes. 20 | self.code = set() 21 | def save_cfg(self, cfg): 22 | self.cfg = copy.deepcopy(cfg) 23 | def add_code(self, code): 24 | """ 25 | A new main code is reached and 26 | its name is added. 27 | """ 28 | self.code.add(code) 29 | -------------------------------------------------------------------------------- /lib/log_service.py: -------------------------------------------------------------------------------- 1 | import timeit 2 | import numpy as np 3 | import os 4 | import os.path as osp 5 | import shutil 6 | import copy 7 | import torch 8 | import torch.nn as nn 9 | import torch.distributed as dist 10 | from .cfg_holder import cfg_unique_holder as cfguh 11 | from . import sync 12 | 13 | def print_log(*console_info): 14 | grank, lrank, _ = sync.get_rank('all') 15 | if lrank!=0: 16 | return 17 | 18 | console_info = [str(i) for i in console_info] 19 | console_info = ' '.join(console_info) 20 | print(console_info) 21 | 22 | if grank!=0: 23 | return 24 | 25 | log_file = None 26 | try: 27 | log_file = cfguh().cfg.train.log_file 28 | except: 29 | try: 30 | log_file = cfguh().cfg.eval.log_file 31 | except: 32 | return 33 | if log_file is not None: 34 | with open(log_file, 'a') as f: 35 | f.write(console_info + '\n') 36 | 37 | class distributed_log_manager(object): 38 | def __init__(self): 39 | self.sum = {} 40 | self.cnt = {} 41 | self.time_check = timeit.default_timer() 42 | 43 | cfgt = cfguh().cfg.train 44 | self.ddp = sync.is_ddp() 45 | self.grank, self.lrank, _ = sync.get_rank('all') 46 | self.gwsize = sync.get_world_size('global') 47 | 48 | use_tensorboard = cfgt.get('log_tensorboard', False) and (self.grank==0) 49 | 50 | self.tb = None 51 | if use_tensorboard: 52 | import tensorboardX 53 | monitoring_dir = osp.join(cfguh().cfg.train.log_dir, 'tensorboard') 54 | self.tb = tensorboardX.SummaryWriter(osp.join(monitoring_dir)) 55 | 56 | def accumulate(self, n, **data): 57 | if n < 0: 58 | raise ValueError 59 | 60 | for itemn, di in data.items(): 61 | if itemn in self.sum: 62 | self.sum[itemn] += di * n 63 | self.cnt[itemn] += n 64 | else: 65 | self.sum[itemn] = di * n 66 | self.cnt[itemn] = n 67 | 68 | def get_mean_value_dict(self): 69 | value_gather = [ 70 | self.sum[itemn]/self.cnt[itemn] \ 71 | for itemn in sorted(self.sum.keys()) ] 72 | 73 | value_gather_tensor = torch.FloatTensor(value_gather).to(self.lrank) 74 | if self.ddp: 75 | dist.all_reduce(value_gather_tensor, op=dist.ReduceOp.SUM) 76 | value_gather_tensor /= self.gwsize 77 | 78 | mean = {} 79 | for idx, itemn in enumerate(sorted(self.sum.keys())): 80 | mean[itemn] = value_gather_tensor[idx].item() 81 | return mean 82 | 83 | def tensorboard_log(self, step, data, mode='train', **extra): 84 | if self.tb is None: 85 | return 86 | if mode == 'train': 87 | self.tb.add_scalar('other/epochn', extra['epochn'], step) 88 | if ('lr' in extra) and (extra['lr'] is not None): 89 | self.tb.add_scalar('other/lr', extra['lr'], 
step) 90 | for itemn, di in data.items(): 91 | if itemn.find('loss') == 0: 92 | self.tb.add_scalar('loss/'+itemn, di, step) 93 | elif itemn == 'Loss': 94 | self.tb.add_scalar('Loss', di, step) 95 | else: 96 | self.tb.add_scalar('other/'+itemn, di, step) 97 | elif mode == 'eval': 98 | if isinstance(data, dict): 99 | for itemn, di in data.items(): 100 | self.tb.add_scalar('eval/'+itemn, di, step) 101 | else: 102 | self.tb.add_scalar('eval', data, step) 103 | return 104 | 105 | def train_summary(self, itern, epochn, samplen, lr, tbstep=None): 106 | console_info = [ 107 | 'Iter:{}'.format(itern), 108 | 'Epoch:{}'.format(epochn), 109 | 'Sample:{}'.format(samplen),] 110 | 111 | if lr is not None: 112 | console_info += ['LR:{:.4E}'.format(lr)] 113 | 114 | mean = self.get_mean_value_dict() 115 | 116 | tbstep = itern if tbstep is None else tbstep 117 | self.tensorboard_log( 118 | tbstep, mean, mode='train', 119 | itern=itern, epochn=epochn, lr=lr) 120 | 121 | loss = mean.pop('Loss') 122 | mean_info = ['Loss:{:.4f}'.format(loss)] + [ 123 | '{}:{:.4f}'.format(itemn, mean[itemn]) \ 124 | for itemn in sorted(mean.keys()) \ 125 | if itemn.find('loss') == 0 126 | ] 127 | console_info += mean_info 128 | console_info.append('Time:{:.2f}s'.format( 129 | timeit.default_timer() - self.time_check)) 130 | return ' , '.join(console_info) 131 | 132 | def clear(self): 133 | self.sum = {} 134 | self.cnt = {} 135 | self.time_check = timeit.default_timer() 136 | 137 | def tensorboard_close(self): 138 | if self.tb is not None: 139 | self.tb.close() 140 | 141 | # ----- also include some small utils ----- 142 | 143 | def torch_to_numpy(*argv): 144 | if len(argv) > 1: 145 | data = list(argv) 146 | else: 147 | data = argv[0] 148 | 149 | if isinstance(data, torch.Tensor): 150 | return data.to('cpu').detach().numpy() 151 | 152 | elif isinstance(data, (list, tuple)): 153 | out = [] 154 | for di in data: 155 | out.append(torch_to_numpy(di)) 156 | return out 157 | 158 | elif isinstance(data, dict): 159 | out = {} 160 | for ni, di in data.items(): 161 | out[ni] = torch_to_numpy(di) 162 | return out 163 | 164 | else: 165 | return data 166 | -------------------------------------------------------------------------------- /lib/model_zoo/__init__.py: -------------------------------------------------------------------------------- 1 | from .common.get_model import get_model 2 | from .common.get_optimizer import get_optimizer 3 | from .common.get_scheduler import get_scheduler 4 | from .common.utils import get_unit 5 | -------------------------------------------------------------------------------- /lib/model_zoo/autokl.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from contextlib import contextmanager 5 | from lib.model_zoo.common.get_model import get_model, register 6 | 7 | # from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer 8 | 9 | from .autokl_modules import Encoder, Decoder 10 | from .distributions import DiagonalGaussianDistribution 11 | 12 | from .autokl_utils import LPIPSWithDiscriminator 13 | 14 | @register('autoencoderkl') 15 | class AutoencoderKL(nn.Module): 16 | def __init__(self, 17 | ddconfig, 18 | lossconfig, 19 | embed_dim,): 20 | super().__init__() 21 | self.encoder = Encoder(**ddconfig) 22 | self.decoder = Decoder(**ddconfig) 23 | if lossconfig is not None: 24 | self.loss = LPIPSWithDiscriminator(**lossconfig) 25 | assert ddconfig["double_z"] 26 | self.quant_conv = 
torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) 27 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) 28 | self.embed_dim = embed_dim 29 | 30 | @torch.no_grad() 31 | def encode(self, x, out_posterior=False): 32 | return self.encode_trainable(x, out_posterior) 33 | 34 | def encode_trainable(self, x, out_posterior=False): 35 | x = x*2-1 36 | h = self.encoder(x) 37 | moments = self.quant_conv(h) 38 | posterior = DiagonalGaussianDistribution(moments) 39 | if out_posterior: 40 | return posterior 41 | else: 42 | return posterior.sample() 43 | 44 | @torch.no_grad() 45 | def decode(self, z): 46 | dec = self.decode_trainable(z) 47 | dec = torch.clamp(dec, 0, 1) 48 | return dec 49 | 50 | def decode_trainable(self, z): 51 | z = self.post_quant_conv(z) 52 | dec = self.decoder(z) 53 | dec = (dec+1)/2 54 | return dec 55 | 56 | def apply_model(self, input, sample_posterior=True): 57 | posterior = self.encode_trainable(input, out_posterior=True) 58 | if sample_posterior: 59 | z = posterior.sample() 60 | else: 61 | z = posterior.mode() 62 | dec = self.decode_trainable(z) 63 | return dec, posterior 64 | 65 | def get_input(self, batch, k): 66 | x = batch[k] 67 | if len(x.shape) == 3: 68 | x = x[..., None] 69 | x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() 70 | return x 71 | 72 | def forward(self, x, optimizer_idx, global_step): 73 | reconstructions, posterior = self.apply_model(x) 74 | 75 | if optimizer_idx == 0: 76 | # train encoder+decoder+logvar 77 | aeloss, log_dict_ae = self.loss(x, reconstructions, posterior, optimizer_idx, global_step=global_step, 78 | last_layer=self.get_last_layer(), split="train") 79 | return aeloss, log_dict_ae 80 | 81 | if optimizer_idx == 1: 82 | # train the discriminator 83 | discloss, log_dict_disc = self.loss(x, reconstructions, posterior, optimizer_idx, global_step=global_step, 84 | last_layer=self.get_last_layer(), split="train") 85 | 86 | return discloss, log_dict_disc 87 | 88 | def validation_step(self, batch, batch_idx): 89 | inputs = self.get_input(batch, self.image_key) 90 | reconstructions, posterior = self(inputs) 91 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step, 92 | last_layer=self.get_last_layer(), split="val") 93 | 94 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step, 95 | last_layer=self.get_last_layer(), split="val") 96 | 97 | self.log("val/rec_loss", log_dict_ae["val/rec_loss"]) 98 | self.log_dict(log_dict_ae) 99 | self.log_dict(log_dict_disc) 100 | return self.log_dict 101 | 102 | def configure_optimizers(self): 103 | lr = self.learning_rate 104 | opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ 105 | list(self.decoder.parameters())+ 106 | list(self.quant_conv.parameters())+ 107 | list(self.post_quant_conv.parameters()), 108 | lr=lr, betas=(0.5, 0.9)) 109 | opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), 110 | lr=lr, betas=(0.5, 0.9)) 111 | return [opt_ae, opt_disc], [] 112 | 113 | def get_last_layer(self): 114 | return self.decoder.conv_out.weight 115 | 116 | @torch.no_grad() 117 | def log_images(self, batch, only_inputs=False, **kwargs): 118 | log = dict() 119 | x = self.get_input(batch, self.image_key) 120 | x = x.to(self.device) 121 | if not only_inputs: 122 | xrec, posterior = self(x) 123 | if x.shape[1] > 3: 124 | # colorize with random projection 125 | assert xrec.shape[1] > 3 126 | x = self.to_rgb(x) 127 | xrec = self.to_rgb(xrec) 128 | log["samples"] = 
self.decode(torch.randn_like(posterior.sample())) 129 | log["reconstructions"] = xrec 130 | log["inputs"] = x 131 | return log 132 | 133 | def to_rgb(self, x): 134 | assert self.image_key == "segmentation" 135 | if not hasattr(self, "colorize"): 136 | self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) 137 | x = F.conv2d(x, weight=self.colorize) 138 | x = 2.*(x-x.min())/(x.max()-x.min()) - 1. 139 | return x 140 | 141 | @register('autoencoderkl_customnorm') 142 | class AutoencoderKL_CustomNorm(AutoencoderKL): 143 | def __init__(self, *args, **kwargs): 144 | super().__init__(*args, **kwargs) 145 | self.mean = torch.Tensor([0.48145466, 0.4578275, 0.40821073]) # CLIP image normalization mean 146 | self.std = torch.Tensor([0.26862954, 0.26130258, 0.27577711]) # CLIP image normalization std 147 | 148 | def encode_trainable(self, x, out_posterior=False): 149 | m = self.mean[None, :, None, None].to(x.device).to(x.dtype) 150 | s = self.std[None, :, None, None].to(x.device).to(x.dtype) 151 | x = (x-m)/s 152 | h = self.encoder(x) 153 | moments = self.quant_conv(h) 154 | posterior = DiagonalGaussianDistribution(moments) 155 | if out_posterior: 156 | return posterior 157 | else: 158 | return posterior.sample() 159 | 160 | def decode_trainable(self, z): 161 | m = self.mean[None, :, None, None].to(z.device).to(z.dtype) 162 | s = self.std[None, :, None, None].to(z.device).to(z.dtype) 163 | z = self.post_quant_conv(z) 164 | dec = self.decoder(z) 165 | dec = (dec+1)/2 166 | return dec 167 | -------------------------------------------------------------------------------- /lib/model_zoo/common/get_model.py: -------------------------------------------------------------------------------- 1 | from email.policy import strict 2 | import torch 3 | import torchvision.models 4 | import os.path as osp 5 | import copy 6 | from ...log_service import print_log 7 | from .utils import \ 8 | get_total_param, get_total_param_sum, \ 9 | get_unit 10 | 11 | # def load_state_dict(net, model_path): 12 | # if isinstance(net, dict): 13 | # for ni, neti in net.items(): 14 | # paras = torch.load(model_path[ni], map_location=torch.device('cpu')) 15 | # new_paras = neti.state_dict() 16 | # new_paras.update(paras) 17 | # neti.load_state_dict(new_paras) 18 | # else: 19 | # paras = torch.load(model_path, map_location=torch.device('cpu')) 20 | # new_paras = net.state_dict() 21 | # new_paras.update(paras) 22 | # net.load_state_dict(new_paras) 23 | # return 24 | 25 | # def save_state_dict(net, path): 26 | # if isinstance(net, (torch.nn.DataParallel, 27 | # torch.nn.parallel.DistributedDataParallel)): 28 | # torch.save(net.module.state_dict(), path) 29 | # else: 30 | # torch.save(net.state_dict(), path) 31 | 32 | def singleton(class_): 33 | instances = {} 34 | def getinstance(*args, **kwargs): 35 | if class_ not in instances: 36 | instances[class_] = class_(*args, **kwargs) 37 | return instances[class_] 38 | return getinstance 39 | 40 | def preprocess_model_args(args): 41 | # If args has layer_units, get the corresponding 42 | # units. 43 | # If args get backbone, get the backbone model.
44 | args = copy.deepcopy(args) 45 | if 'layer_units' in args: 46 | layer_units = [ 47 | get_unit()(i) for i in args.layer_units 48 | ] 49 | args.layer_units = layer_units 50 | if 'backbone' in args: 51 | args.backbone = get_model()(args.backbone) 52 | return args 53 | 54 | @singleton 55 | class get_model(object): 56 | def __init__(self): 57 | self.model = {} 58 | 59 | def register(self, model, name): 60 | self.model[name] = model 61 | 62 | def __call__(self, cfg, verbose=True): 63 | """ 64 | Construct model based on the config. 65 | """ 66 | if cfg is None: 67 | return None 68 | 69 | t = cfg.type 70 | 71 | # the register is in each file 72 | if t.find('pfd')==0: 73 | from .. import pfd 74 | elif t=='autoencoderkl': 75 | from .. import autokl 76 | elif (t.find('clip')==0) or (t.find('openclip')==0): 77 | from .. import clip 78 | elif t.find('openai_unet')==0: 79 | from .. import openaimodel 80 | elif t.find('controlnet')==0: 81 | from .. import controlnet 82 | elif t.find('seecoder')==0: 83 | from .. import seecoder 84 | elif t.find('swin')==0: 85 | from .. import swin 86 | 87 | args = preprocess_model_args(cfg.args) 88 | net = self.model[t](**args) 89 | 90 | pretrained = cfg.get('pretrained', None) 91 | if pretrained is None: # backward compatible 92 | pretrained = cfg.get('pth', None) 93 | map_location = cfg.get('map_location', 'cpu') 94 | strict_sd = cfg.get('strict_sd', True) 95 | 96 | if pretrained is not None: 97 | if osp.splitext(pretrained)[1] == '.pth': 98 | sd = torch.load(pretrained, map_location=map_location) 99 | elif osp.splitext(pretrained)[1] == '.ckpt': 100 | sd = torch.load(pretrained, map_location=map_location)['state_dict'] 101 | elif osp.splitext(pretrained)[1] == '.safetensors': 102 | from safetensors.torch import load_file 103 | from collections import OrderedDict 104 | sd = load_file(pretrained, map_location) 105 | sd = OrderedDict(sd) 106 | net.load_state_dict(sd, strict=strict_sd) 107 | if verbose: 108 | print_log('Load model from [{}] strict [{}].'.format(pretrained, strict_sd)) 109 | 110 | # display param_num & param_sum 111 | if verbose: 112 | print_log( 113 | 'Load {} with total {} parameters,' 114 | '{:.3f} parameter sum.'.format( 115 | t, 116 | get_total_param(net), 117 | get_total_param_sum(net) )) 118 | return net 119 | 120 | def register(name): 121 | def wrapper(class_): 122 | get_model().register(class_, name) 123 | return class_ 124 | return wrapper 125 | -------------------------------------------------------------------------------- /lib/model_zoo/common/get_optimizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.optim as optim 3 | import numpy as np 4 | import itertools 5 | 6 | def singleton(class_): 7 | instances = {} 8 | def getinstance(*args, **kwargs): 9 | if class_ not in instances: 10 | instances[class_] = class_(*args, **kwargs) 11 | return instances[class_] 12 | return getinstance 13 | 14 | class get_optimizer(object): 15 | def __init__(self): 16 | self.optimizer = {} 17 | self.register(optim.SGD, 'sgd') 18 | self.register(optim.Adam, 'adam') 19 | self.register(optim.AdamW, 'adamw') 20 | 21 | def register(self, optim, name): 22 | self.optimizer[name] = optim 23 | 24 | def __call__(self, net, cfg): 25 | if cfg is None: 26 | return None 27 | t = cfg.type 28 | if isinstance(net, (torch.nn.DataParallel, 29 | torch.nn.parallel.DistributedDataParallel)): 30 | netm = net.module 31 | else: 32 | netm = net 33 | pg = getattr(netm, 'parameter_group', None) 34 | 35 | if pg is not 
None: 36 | params = [] 37 | for group_name, module_or_para in pg.items(): 38 | if not isinstance(module_or_para, list): 39 | module_or_para = [module_or_para] 40 | 41 | grouped_params = [mi.parameters() if isinstance(mi, torch.nn.Module) else [mi] for mi in module_or_para] 42 | grouped_params = itertools.chain(*grouped_params) 43 | pg_dict = {'params':grouped_params, 'name':group_name} 44 | params.append(pg_dict) 45 | else: 46 | params = net.parameters() 47 | return self.optimizer[t](params, lr=0, **cfg.args) 48 | -------------------------------------------------------------------------------- /lib/model_zoo/common/get_scheduler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.optim as optim 3 | import numpy as np 4 | import copy 5 | from ... import sync 6 | from ...cfg_holder import cfg_unique_holder as cfguh 7 | 8 | def singleton(class_): 9 | instances = {} 10 | def getinstance(*args, **kwargs): 11 | if class_ not in instances: 12 | instances[class_] = class_(*args, **kwargs) 13 | return instances[class_] 14 | return getinstance 15 | 16 | @singleton 17 | class get_scheduler(object): 18 | def __init__(self): 19 | self.lr_scheduler = {} 20 | 21 | def register(self, lrsf, name): 22 | self.lr_scheduler[name] = lrsf 23 | 24 | def __call__(self, cfg): 25 | if cfg is None: 26 | return None 27 | if isinstance(cfg, list): 28 | schedulers = [] 29 | for ci in cfg: 30 | t = ci.type 31 | schedulers.append( 32 | self.lr_scheduler[t](**ci.args)) 33 | if len(schedulers) == 0: 34 | raise ValueError 35 | else: 36 | return compose_scheduler(schedulers) 37 | t = cfg.type 38 | return self.lr_scheduler[t](**cfg.args) 39 | 40 | 41 | def register(name): 42 | def wrapper(class_): 43 | get_scheduler().register(class_, name) 44 | return class_ 45 | return wrapper 46 | 47 | class template_scheduler(object): 48 | def __init__(self, step): 49 | self.step = step 50 | 51 | def __getitem__(self, idx): 52 | raise ValueError 53 | 54 | def set_lr(self, optim, new_lr, pg_lrscale=None): 55 | """ 56 | Set Each parameter_groups in optim with new_lr 57 | New_lr can be find according to the idx. 58 | pg_lrscale tells how to scale each pg. 
59 | """ 60 | # new_lr = self.__getitem__(idx) 61 | pg_lrscale = copy.deepcopy(pg_lrscale) 62 | for pg in optim.param_groups: 63 | if pg_lrscale is None: 64 | pg['lr'] = new_lr 65 | else: 66 | pg['lr'] = new_lr * pg_lrscale.pop(pg['name']) 67 | assert (pg_lrscale is None) or (len(pg_lrscale)==0), \ 68 | "pg_lrscale doesn't match pg" 69 | 70 | @register('constant') 71 | class constant_scheduler(template_scheduler): 72 | def __init__(self, lr, step): 73 | super().__init__(step) 74 | self.lr = lr 75 | 76 | def __getitem__(self, idx): 77 | if idx >= self.step: 78 | raise ValueError 79 | return self.lr 80 | 81 | @register('poly') 82 | class poly_scheduler(template_scheduler): 83 | def __init__(self, start_lr, end_lr, power, step): 84 | super().__init__(step) 85 | self.start_lr = start_lr 86 | self.end_lr = end_lr 87 | self.power = power 88 | 89 | def __getitem__(self, idx): 90 | if idx >= self.step: 91 | raise ValueError 92 | a, b = self.start_lr, self.end_lr 93 | p, n = self.power, self.step 94 | return b + (a-b)*((1-idx/n)**p) 95 | 96 | @register('linear') 97 | class linear_scheduler(template_scheduler): 98 | def __init__(self, start_lr, end_lr, step): 99 | super().__init__(step) 100 | self.start_lr = start_lr 101 | self.end_lr = end_lr 102 | 103 | def __getitem__(self, idx): 104 | if idx >= self.step: 105 | raise ValueError 106 | a, b, n = self.start_lr, self.end_lr, self.step 107 | return b + (a-b)*(1-idx/n) 108 | 109 | @register('multistage') 110 | class constant_scheduler(template_scheduler): 111 | def __init__(self, start_lr, milestones, gamma, step): 112 | super().__init__(step) 113 | self.start_lr = start_lr 114 | m = [0] + milestones + [step] 115 | lr_iter = start_lr 116 | self.lr = [] 117 | for ms, me in zip(m[0:-1], m[1:]): 118 | for _ in range(ms, me): 119 | self.lr.append(lr_iter) 120 | lr_iter *= gamma 121 | 122 | def __getitem__(self, idx): 123 | if idx >= self.step: 124 | raise ValueError 125 | return self.lr[idx] 126 | 127 | class compose_scheduler(template_scheduler): 128 | def __init__(self, schedulers): 129 | self.schedulers = schedulers 130 | self.step = [si.step for si in schedulers] 131 | self.step_milestone = [] 132 | acc = 0 133 | for i in self.step: 134 | acc += i 135 | self.step_milestone.append(acc) 136 | self.step = sum(self.step) 137 | 138 | def __getitem__(self, idx): 139 | if idx >= self.step: 140 | raise ValueError 141 | ms = self.step_milestone 142 | for idx, (mi, mj) in enumerate(zip(ms[:-1], ms[1:])): 143 | if mi <= idx < mj: 144 | return self.schedulers[idx-mi] 145 | raise ValueError 146 | 147 | #################### 148 | # lambda schedular # 149 | #################### 150 | 151 | class LambdaWarmUpCosineScheduler(template_scheduler): 152 | """ 153 | note: use with a base_lr of 1.0 154 | """ 155 | def __init__(self, 156 | base_lr, 157 | warm_up_steps, 158 | lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): 159 | cfgt = cfguh().cfg.train 160 | bs = cfgt.batch_size 161 | if 'gradacc_every' not in cfgt: 162 | print('Warning, gradacc_every is not found in xml, use 1 as default.') 163 | acc = cfgt.get('gradacc_every', 1) 164 | self.lr_multi = base_lr * bs * acc 165 | self.lr_warm_up_steps = warm_up_steps 166 | self.lr_start = lr_start 167 | self.lr_min = lr_min 168 | self.lr_max = lr_max 169 | self.lr_max_decay_steps = max_decay_steps 170 | self.last_lr = 0. 
171 | self.verbosity_interval = verbosity_interval 172 | 173 | def schedule(self, n): 174 | if self.verbosity_interval > 0: 175 | if n % self.verbosity_interval == 0: 176 | print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") 177 | if n < self.lr_warm_up_steps: 178 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start 179 | self.last_lr = lr 180 | return lr 181 | else: 182 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) 183 | t = min(t, 1.0) 184 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 185 | 1 + np.cos(t * np.pi)) 186 | self.last_lr = lr 187 | return lr 188 | 189 | def __getitem__(self, idx): 190 | return self.schedule(idx) * self.lr_multi 191 | 192 | class LambdaWarmUpCosineScheduler2(template_scheduler): 193 | """ 194 | supports repeated iterations, configurable via lists 195 | note: use with a base_lr of 1.0. 196 | """ 197 | def __init__(self, 198 | base_lr, 199 | warm_up_steps, 200 | f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): 201 | cfgt = cfguh().cfg.train 202 | # bs = cfgt.batch_size 203 | # if 'gradacc_every' not in cfgt: 204 | # print('Warning, gradacc_every is not found in xml, use 1 as default.') 205 | # acc = cfgt.get('gradacc_every', 1) 206 | # self.lr_multi = base_lr * bs * acc 207 | self.lr_multi = base_lr 208 | assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) 209 | self.lr_warm_up_steps = warm_up_steps 210 | self.f_start = f_start 211 | self.f_min = f_min 212 | self.f_max = f_max 213 | self.cycle_lengths = cycle_lengths 214 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) 215 | self.last_f = 0. 216 | self.verbosity_interval = verbosity_interval 217 | 218 | def find_in_interval(self, n): 219 | interval = 0 220 | for cl in self.cum_cycles[1:]: 221 | if n <= cl: 222 | return interval 223 | interval += 1 224 | 225 | def schedule(self, n): 226 | cycle = self.find_in_interval(n) 227 | n = n - self.cum_cycles[cycle] 228 | if self.verbosity_interval > 0: 229 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 230 | f"current cycle {cycle}") 231 | if n < self.lr_warm_up_steps[cycle]: 232 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 233 | self.last_f = f 234 | return f 235 | else: 236 | t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) 237 | t = min(t, 1.0) 238 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( 239 | 1 + np.cos(t * np.pi)) 240 | self.last_f = f 241 | return f 242 | 243 | def __getitem__(self, idx): 244 | return self.schedule(idx) * self.lr_multi 245 | 246 | @register('stable_diffusion_linear') 247 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): 248 | def schedule(self, n): 249 | cycle = self.find_in_interval(n) 250 | n = n - self.cum_cycles[cycle] 251 | if self.verbosity_interval > 0: 252 | if n % self.verbosity_interval == 0: 253 | print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 254 | f"current cycle {cycle}") 255 | if n < self.lr_warm_up_steps[cycle]: 256 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 257 | self.last_f = f 258 | return f 259 | else: 260 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle]) 261 | self.last_f = f 262 | return f 
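The schedulers defined above in `get_scheduler.py` are plain step-indexed lookups: `scheduler[idx]` returns the learning rate (or multiplier) for iteration `idx`, and `template_scheduler.set_lr` writes that value into an optimizer's parameter groups. A minimal usage sketch follows; it assumes the repo root is on `PYTHONPATH`, and the tiny model and optimizer are placeholders rather than anything from this codebase:

```python
# Minimal sketch of driving an optimizer with the step-indexed schedulers above
# (lib/model_zoo/common/get_scheduler.py). The linear layer and SGD optimizer are
# placeholders; poly_scheduler and set_lr come directly from the file above.
import torch
from lib.model_zoo.common.get_scheduler import poly_scheduler

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.SGD(model.parameters(), lr=0.0)

# Polynomial decay from 1e-3 to 1e-5 with power 0.9 over 1000 iterations.
scheduler = poly_scheduler(start_lr=1e-3, end_lr=1e-5, power=0.9, step=1000)

for it in range(1000):
    lr = scheduler[it]                 # look up the lr for this iteration
    scheduler.set_lr(optimizer, lr)    # apply it to every parameter group
    # ... forward / backward / optimizer.step() would go here ...
```

Config-driven construction works the same way: `get_scheduler()(cfg)` looks up the registered name in `cfg.type` (e.g. `'poly'` or `'stable_diffusion_linear'`) and forwards `cfg.args` to the class, so the snippet above mirrors what a training loop would get back from the registry.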
-------------------------------------------------------------------------------- /lib/model_zoo/common/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import copy 6 | import functools 7 | import itertools 8 | 9 | import matplotlib.pyplot as plt 10 | 11 | ######## 12 | # unit # 13 | ######## 14 | 15 | def singleton(class_): 16 | instances = {} 17 | def getinstance(*args, **kwargs): 18 | if class_ not in instances: 19 | instances[class_] = class_(*args, **kwargs) 20 | return instances[class_] 21 | return getinstance 22 | 23 | def str2value(v): 24 | v = v.strip() 25 | try: 26 | return int(v) 27 | except: 28 | pass 29 | try: 30 | return float(v) 31 | except: 32 | pass 33 | if v in ('True', 'true'): 34 | return True 35 | elif v in ('False', 'false'): 36 | return False 37 | else: 38 | return v 39 | 40 | @singleton 41 | class get_unit(object): 42 | def __init__(self): 43 | self.unit = {} 44 | self.register('none', None) 45 | 46 | # general convolution 47 | self.register('conv' , nn.Conv2d) 48 | self.register('bn' , nn.BatchNorm2d) 49 | self.register('relu' , nn.ReLU) 50 | self.register('relu6' , nn.ReLU6) 51 | self.register('lrelu' , nn.LeakyReLU) 52 | self.register('dropout' , nn.Dropout) 53 | self.register('dropout2d', nn.Dropout2d) 54 | self.register('sine', Sine) 55 | self.register('relusine', ReLUSine) 56 | 57 | def register(self, 58 | name, 59 | unitf,): 60 | 61 | self.unit[name] = unitf 62 | 63 | def __call__(self, name): 64 | if name is None: 65 | return None 66 | i = name.find('(') 67 | i = len(name) if i==-1 else i 68 | t = name[:i] 69 | f = self.unit[t] 70 | args = name[i:].strip('()') 71 | if len(args) == 0: 72 | args = {} 73 | return f 74 | else: 75 | args = args.split('=') 76 | args = [[','.join(i.split(',')[:-1]), i.split(',')[-1]] for i in args] 77 | args = list(itertools.chain.from_iterable(args)) 78 | args = [i.strip() for i in args if len(i)>0] 79 | kwargs = {} 80 | for k, v in zip(args[::2], args[1::2]): 81 | if v[0]=='(' and v[-1]==')': 82 | kwargs[k] = tuple([str2value(i) for i in v.strip('()').split(',')]) 83 | elif v[0]=='[' and v[-1]==']': 84 | kwargs[k] = [str2value(i) for i in v.strip('[]').split(',')] 85 | else: 86 | kwargs[k] = str2value(v) 87 | return functools.partial(f, **kwargs) 88 | 89 | def register(name): 90 | def wrapper(class_): 91 | get_unit().register(name, class_) 92 | return class_ 93 | return wrapper 94 | 95 | class Sine(object): 96 | def __init__(self, freq, gain=1): 97 | self.freq = freq 98 | self.gain = gain 99 | self.repr = 'sine(freq={}, gain={})'.format(freq, gain) 100 | 101 | def __call__(self, x, gain=1): 102 | act_gain = self.gain * gain 103 | return torch.sin(self.freq * x) * act_gain 104 | 105 | def __repr__(self,): 106 | return self.repr 107 | 108 | class ReLUSine(nn.Module): 109 | def __init(self): 110 | super().__init__() 111 | 112 | def forward(self, input): 113 | a = torch.sin(30 * input) 114 | b = nn.ReLU(inplace=False)(input) 115 | return a+b 116 | 117 | @register('lrelu_agc') 118 | # class lrelu_agc(nn.Module): 119 | class lrelu_agc(object): 120 | """ 121 | The lrelu layer with alpha, gain and clamp 122 | """ 123 | def __init__(self, alpha=0.1, gain=1, clamp=None): 124 | # super().__init__() 125 | self.alpha = alpha 126 | if gain == 'sqrt_2': 127 | self.gain = np.sqrt(2) 128 | else: 129 | self.gain = gain 130 | self.clamp = clamp 131 | self.repr = 'lrelu_agc(alpha={}, gain={}, 
clamp={})'.format( 132 | alpha, gain, clamp) 133 | 134 | # def forward(self, x, gain=1): 135 | def __call__(self, x, gain=1): 136 | x = F.leaky_relu(x, negative_slope=self.alpha, inplace=True) 137 | act_gain = self.gain * gain 138 | act_clamp = self.clamp * gain if self.clamp is not None else None 139 | if act_gain != 1: 140 | x = x * act_gain 141 | if act_clamp is not None: 142 | x = x.clamp(-act_clamp, act_clamp) 143 | return x 144 | 145 | def __repr__(self,): 146 | return self.repr 147 | 148 | #################### 149 | # spatial encoding # 150 | #################### 151 | 152 | @register('se') 153 | class SpatialEncoding(nn.Module): 154 | def __init__(self, 155 | in_dim, 156 | out_dim, 157 | sigma = 6, 158 | cat_input=True, 159 | require_grad=False,): 160 | 161 | super().__init__() 162 | assert out_dim % (2*in_dim) == 0, "dimension must be dividable" 163 | 164 | n = out_dim // 2 // in_dim 165 | m = 2**np.linspace(0, sigma, n) 166 | m = np.stack([m] + [np.zeros_like(m)]*(in_dim-1), axis=-1) 167 | m = np.concatenate([np.roll(m, i, axis=-1) for i in range(in_dim)], axis=0) 168 | self.emb = torch.FloatTensor(m) 169 | if require_grad: 170 | self.emb = nn.Parameter(self.emb, requires_grad=True) 171 | self.in_dim = in_dim 172 | self.out_dim = out_dim 173 | self.sigma = sigma 174 | self.cat_input = cat_input 175 | self.require_grad = require_grad 176 | 177 | def forward(self, x, format='[n x c]'): 178 | """ 179 | Args: 180 | x: [n x m1], 181 | m1 usually is 2 182 | Outputs: 183 | y: [n x m2] 184 | m2 dimention number 185 | """ 186 | if format == '[bs x c x 2D]': 187 | xshape = x.shape 188 | x = x.permute(0, 2, 3, 1).contiguous() 189 | x = x.view(-1, x.size(-1)) 190 | elif format == '[n x c]': 191 | pass 192 | else: 193 | raise ValueError 194 | 195 | if not self.require_grad: 196 | self.emb = self.emb.to(x.device) 197 | y = torch.mm(x, self.emb.T) 198 | if self.cat_input: 199 | z = torch.cat([x, torch.sin(y), torch.cos(y)], dim=-1) 200 | else: 201 | z = torch.cat([torch.sin(y), torch.cos(y)], dim=-1) 202 | 203 | if format == '[bs x c x 2D]': 204 | z = z.view(xshape[0], xshape[2], xshape[3], -1) 205 | z = z.permute(0, 3, 1, 2).contiguous() 206 | return z 207 | 208 | def extra_repr(self): 209 | outstr = 'SpatialEncoding (in={}, out={}, sigma={}, cat_input={}, require_grad={})'.format( 210 | self.in_dim, self.out_dim, self.sigma, self.cat_input, self.require_grad) 211 | return outstr 212 | 213 | @register('rffe') 214 | class RFFEncoding(SpatialEncoding): 215 | """ 216 | Random Fourier Features 217 | """ 218 | def __init__(self, 219 | in_dim, 220 | out_dim, 221 | sigma = 6, 222 | cat_input=True, 223 | require_grad=False,): 224 | 225 | super().__init__(in_dim, out_dim, sigma, cat_input, require_grad) 226 | n = out_dim // 2 227 | m = np.random.normal(0, sigma, size=(n, in_dim)) 228 | self.emb = torch.FloatTensor(m) 229 | if require_grad: 230 | self.emb = nn.Parameter(self.emb, requires_grad=True) 231 | 232 | def extra_repr(self): 233 | outstr = 'RFFEncoding (in={}, out={}, sigma={}, cat_input={}, require_grad={})'.format( 234 | self.in_dim, self.out_dim, self.sigma, self.cat_input, self.require_grad) 235 | return outstr 236 | 237 | ########## 238 | # helper # 239 | ########## 240 | 241 | def freeze(net): 242 | for m in net.modules(): 243 | if isinstance(m, ( 244 | nn.BatchNorm2d, 245 | nn.SyncBatchNorm,)): 246 | # inplace_abn not supported 247 | m.eval() 248 | for pi in net.parameters(): 249 | pi.requires_grad = False 250 | return net 251 | 252 | def common_init(m): 253 | if isinstance(m, ( 254 | 
nn.Conv2d, 255 | nn.ConvTranspose2d,)): 256 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 257 | if m.bias is not None: 258 | nn.init.constant_(m.bias, 0) 259 | elif isinstance(m, ( 260 | nn.BatchNorm2d, 261 | nn.SyncBatchNorm,)): 262 | nn.init.constant_(m.weight, 1) 263 | nn.init.constant_(m.bias, 0) 264 | else: 265 | pass 266 | 267 | def init_module(module): 268 | """ 269 | Args: 270 | module: [nn.module] list or nn.module 271 | a list of module to be initialized. 272 | """ 273 | if isinstance(module, (list, tuple)): 274 | module = list(module) 275 | else: 276 | module = [module] 277 | 278 | for mi in module: 279 | for mii in mi.modules(): 280 | common_init(mii) 281 | 282 | def get_total_param(net): 283 | if getattr(net, 'parameters', None) is None: 284 | return 0 285 | return sum(p.numel() for p in net.parameters()) 286 | 287 | def get_total_param_sum(net): 288 | if getattr(net, 'parameters', None) is None: 289 | return 0 290 | with torch.no_grad(): 291 | s = sum(p.cpu().detach().numpy().sum().item() for p in net.parameters()) 292 | return s 293 | -------------------------------------------------------------------------------- /lib/model_zoo/controlnet_annotator/canny/__init__.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | 3 | 4 | def apply_canny(img, low_threshold, high_threshold): 5 | return cv2.Canny(img, low_threshold, high_threshold) 6 | -------------------------------------------------------------------------------- /lib/model_zoo/controlnet_annotator/hed/__init__.py: -------------------------------------------------------------------------------- 1 | # This is an improved version and model of HED edge detection with Apache License, Version 2.0. 2 | # Please use this implementation in your products 3 | # This implementation may produce slightly different results from Saining Xie's official implementations, 4 | # but it generates smoother edges and is more suitable for ControlNet as well as other image-to-image translations. 
5 | # Different from official models and other implementations, this is an RGB-input model (rather than BGR) 6 | # and in this way it works better for gradio's RGB protocol 7 | 8 | import os 9 | import cv2 10 | import torch 11 | import numpy as np 12 | 13 | from einops import rearrange 14 | import os 15 | 16 | models_path = 'pretrained/controlnet/preprocess' 17 | 18 | def safe_step(x, step=2): 19 | y = x.astype(np.float32) * float(step + 1) 20 | y = y.astype(np.int32).astype(np.float32) / float(step) 21 | return y 22 | 23 | class DoubleConvBlock(torch.nn.Module): 24 | def __init__(self, input_channel, output_channel, layer_number): 25 | super().__init__() 26 | self.convs = torch.nn.Sequential() 27 | self.convs.append(torch.nn.Conv2d(in_channels=input_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1)) 28 | for i in range(1, layer_number): 29 | self.convs.append(torch.nn.Conv2d(in_channels=output_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1)) 30 | self.projection = torch.nn.Conv2d(in_channels=output_channel, out_channels=1, kernel_size=(1, 1), stride=(1, 1), padding=0) 31 | 32 | def __call__(self, x, down_sampling=False): 33 | h = x 34 | if down_sampling: 35 | h = torch.nn.functional.max_pool2d(h, kernel_size=(2, 2), stride=(2, 2)) 36 | for conv in self.convs: 37 | h = conv(h) 38 | h = torch.nn.functional.relu(h) 39 | return h, self.projection(h) 40 | 41 | 42 | class ControlNetHED_Apache2(torch.nn.Module): 43 | def __init__(self): 44 | super().__init__() 45 | self.norm = torch.nn.Parameter(torch.zeros(size=(1, 3, 1, 1))) 46 | self.block1 = DoubleConvBlock(input_channel=3, output_channel=64, layer_number=2) 47 | self.block2 = DoubleConvBlock(input_channel=64, output_channel=128, layer_number=2) 48 | self.block3 = DoubleConvBlock(input_channel=128, output_channel=256, layer_number=3) 49 | self.block4 = DoubleConvBlock(input_channel=256, output_channel=512, layer_number=3) 50 | self.block5 = DoubleConvBlock(input_channel=512, output_channel=512, layer_number=3) 51 | 52 | def __call__(self, x): 53 | h = x - self.norm 54 | h, projection1 = self.block1(h) 55 | h, projection2 = self.block2(h, down_sampling=True) 56 | h, projection3 = self.block3(h, down_sampling=True) 57 | h, projection4 = self.block4(h, down_sampling=True) 58 | h, projection5 = self.block5(h, down_sampling=True) 59 | return projection1, projection2, projection3, projection4, projection5 60 | 61 | 62 | netNetwork = None 63 | remote_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/ControlNetHED.pth" 64 | modeldir = os.path.join(models_path, "hed") 65 | old_modeldir = os.path.dirname(os.path.realpath(__file__)) 66 | 67 | 68 | def load_file_from_url(url, model_dir=None, progress=True, file_name=None): 69 | """Load file form http url, will download models if necessary. 70 | 71 | Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py 72 | 73 | Args: 74 | url (str): URL to be downloaded. 75 | model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir. 76 | Default: None. 77 | progress (bool): Whether to show the download progress. Default: True. 78 | file_name (str): The downloaded file name. If None, use the file name in the url. Default: None. 79 | 80 | Returns: 81 | str: The path to the downloaded file. 
82 | """ 83 | from torch.hub import download_url_to_file, get_dir 84 | from urllib.parse import urlparse 85 | if model_dir is None: # use the pytorch hub_dir 86 | hub_dir = get_dir() 87 | model_dir = os.path.join(hub_dir, 'checkpoints') 88 | 89 | os.makedirs(model_dir, exist_ok=True) 90 | 91 | parts = urlparse(url) 92 | filename = os.path.basename(parts.path) 93 | if file_name is not None: 94 | filename = file_name 95 | cached_file = os.path.abspath(os.path.join(model_dir, filename)) 96 | if not os.path.exists(cached_file): 97 | print(f'Downloading: "{url}" to {cached_file}\n') 98 | download_url_to_file(url, cached_file, hash_prefix=None, progress=progress) 99 | return cached_file 100 | 101 | 102 | def apply_hed(input_image, is_safe=False, device='cpu'): 103 | global netNetwork 104 | if netNetwork is None: 105 | modelpath = os.path.join(modeldir, "ControlNetHED.pth") 106 | old_modelpath = os.path.join(old_modeldir, "ControlNetHED.pth") 107 | if os.path.exists(old_modelpath): 108 | modelpath = old_modelpath 109 | elif not os.path.exists(modelpath): 110 | load_file_from_url(remote_model_path, model_dir=modeldir) 111 | netNetwork = ControlNetHED_Apache2().to(device) 112 | netNetwork.load_state_dict(torch.load(modelpath, map_location='cpu')) 113 | netNetwork.to(device).float().eval() 114 | 115 | assert input_image.ndim == 3 116 | H, W, C = input_image.shape 117 | with torch.no_grad(): 118 | image_hed = torch.from_numpy(input_image.copy()).float().to(device) 119 | image_hed = rearrange(image_hed, 'h w c -> 1 c h w') 120 | edges = netNetwork(image_hed) 121 | edges = [e.detach().cpu().numpy().astype(np.float32)[0, 0] for e in edges] 122 | edges = [cv2.resize(e, (W, H), interpolation=cv2.INTER_LINEAR) for e in edges] 123 | edges = np.stack(edges, axis=2) 124 | edge = 1 / (1 + np.exp(-np.mean(edges, axis=2).astype(np.float64))) 125 | if is_safe: 126 | edge = safe_step(edge) 127 | edge = (edge * 255.0).clip(0, 255).astype(np.uint8) 128 | return edge 129 | 130 | 131 | def unload_hed_model(): 132 | global netNetwork 133 | if netNetwork is not None: 134 | netNetwork.cpu() 135 | -------------------------------------------------------------------------------- /lib/model_zoo/controlnet_annotator/midas/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Intel ISL (Intel Intelligent Systems Lab) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /lib/model_zoo/controlnet_annotator/midas/__init__.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import torch 4 | 5 | from einops import rearrange 6 | from .api import MiDaSInference 7 | 8 | model = None 9 | 10 | def unload_midas_model(): 11 | global model 12 | if model is not None: 13 | model = model.cpu() 14 | 15 | def apply_midas(input_image, a=np.pi * 2.0, bg_th=0.1, device='cpu'): 16 | global model 17 | if model is None: 18 | model = MiDaSInference(model_type="dpt_hybrid") 19 | model = model.to(device) 20 | 21 | assert input_image.ndim == 3 22 | image_depth = input_image 23 | with torch.no_grad(): 24 | image_depth = torch.from_numpy(image_depth).float() 25 | image_depth = image_depth.to(device) 26 | image_depth = image_depth / 127.5 - 1.0 27 | image_depth = rearrange(image_depth, 'h w c -> 1 c h w') 28 | depth = model(image_depth)[0] 29 | 30 | depth_pt = depth.clone() 31 | depth_pt -= torch.min(depth_pt) 32 | depth_pt /= torch.max(depth_pt) 33 | depth_pt = depth_pt.cpu().numpy() 34 | depth_image = (depth_pt * 255.0).clip(0, 255).astype(np.uint8) 35 | 36 | depth_np = depth.cpu().numpy() 37 | x = cv2.Sobel(depth_np, cv2.CV_32F, 1, 0, ksize=3) 38 | y = cv2.Sobel(depth_np, cv2.CV_32F, 0, 1, ksize=3) 39 | z = np.ones_like(x) * a 40 | x[depth_pt < bg_th] = 0 41 | y[depth_pt < bg_th] = 0 42 | normal = np.stack([x, y, z], axis=2) 43 | normal /= np.sum(normal ** 2.0, axis=2, keepdims=True) ** 0.5 44 | normal_image = (normal * 127.5 + 127.5).clip(0, 255).astype(np.uint8)[:, :, ::-1] 45 | 46 | return depth_image, normal_image 47 | -------------------------------------------------------------------------------- /lib/model_zoo/controlnet_annotator/midas/api.py: -------------------------------------------------------------------------------- 1 | # based on https://github.com/isl-org/MiDaS 2 | 3 | import cv2 4 | import torch 5 | import torch.nn as nn 6 | import os 7 | models_path = 'pretrained/controlnet/preprocess' 8 | 9 | from torchvision.transforms import Compose 10 | 11 | from .midas.dpt_depth import DPTDepthModel 12 | from .midas.midas_net import MidasNet 13 | from .midas.midas_net_custom import MidasNet_small 14 | from .midas.transforms import Resize, NormalizeImage, PrepareForNet 15 | 16 | base_model_path = os.path.join(models_path, "midas") 17 | old_modeldir = os.path.dirname(os.path.realpath(__file__)) 18 | remote_model_path = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/dpt_hybrid-midas-501f0c75.pt" 19 | 20 | ISL_PATHS = { 21 | "dpt_large": os.path.join(base_model_path, "dpt_large-midas-2f21e586.pt"), 22 | "dpt_hybrid": os.path.join(base_model_path, "dpt_hybrid-midas-501f0c75.pt"), 23 | "midas_v21": "", 24 | "midas_v21_small": "", 25 | } 26 | 27 | OLD_ISL_PATHS = { 28 | "dpt_large": os.path.join(old_modeldir, "dpt_large-midas-2f21e586.pt"), 29 | "dpt_hybrid": os.path.join(old_modeldir, "dpt_hybrid-midas-501f0c75.pt"), 30 | "midas_v21": "", 31 | "midas_v21_small": "", 32 | } 33 | 34 | 35 | def disabled_train(self, mode=True): 36 | """Overwrite model.train with this function to make sure train/eval mode 37 | does not change anymore.""" 38 | return self 39 | 40 | 41 | def load_midas_transform(model_type): 42 | # https://github.com/isl-org/MiDaS/blob/master/run.py 43 | # load transform only 44 | if model_type == "dpt_large": # DPT-Large 45 | net_w, net_h = 384, 384 46 | resize_mode = "minimal" 47 | 
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 48 | 49 | elif model_type == "dpt_hybrid": # DPT-Hybrid 50 | net_w, net_h = 384, 384 51 | resize_mode = "minimal" 52 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 53 | 54 | elif model_type == "midas_v21": 55 | net_w, net_h = 384, 384 56 | resize_mode = "upper_bound" 57 | normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 58 | 59 | elif model_type == "midas_v21_small": 60 | net_w, net_h = 256, 256 61 | resize_mode = "upper_bound" 62 | normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 63 | 64 | else: 65 | assert False, f"model_type '{model_type}' not implemented, use: --model_type large" 66 | 67 | transform = Compose( 68 | [ 69 | Resize( 70 | net_w, 71 | net_h, 72 | resize_target=None, 73 | keep_aspect_ratio=True, 74 | ensure_multiple_of=32, 75 | resize_method=resize_mode, 76 | image_interpolation_method=cv2.INTER_CUBIC, 77 | ), 78 | normalization, 79 | PrepareForNet(), 80 | ] 81 | ) 82 | 83 | return transform 84 | 85 | 86 | def load_file_from_url(url, model_dir=None, progress=True, file_name=None): 87 | """Load file form http url, will download models if necessary. 88 | 89 | Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py 90 | 91 | Args: 92 | url (str): URL to be downloaded. 93 | model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir. 94 | Default: None. 95 | progress (bool): Whether to show the download progress. Default: True. 96 | file_name (str): The downloaded file name. If None, use the file name in the url. Default: None. 97 | 98 | Returns: 99 | str: The path to the downloaded file. 100 | """ 101 | from torch.hub import download_url_to_file, get_dir 102 | from urllib.parse import urlparse 103 | if model_dir is None: # use the pytorch hub_dir 104 | hub_dir = get_dir() 105 | model_dir = os.path.join(hub_dir, 'checkpoints') 106 | 107 | os.makedirs(model_dir, exist_ok=True) 108 | 109 | parts = urlparse(url) 110 | filename = os.path.basename(parts.path) 111 | if file_name is not None: 112 | filename = file_name 113 | cached_file = os.path.abspath(os.path.join(model_dir, filename)) 114 | if not os.path.exists(cached_file): 115 | print(f'Downloading: "{url}" to {cached_file}\n') 116 | download_url_to_file(url, cached_file, hash_prefix=None, progress=progress) 117 | return cached_file 118 | 119 | 120 | def load_model(model_type): 121 | # https://github.com/isl-org/MiDaS/blob/master/run.py 122 | # load network 123 | model_path = ISL_PATHS[model_type] 124 | old_model_path = OLD_ISL_PATHS[model_type] 125 | if model_type == "dpt_large": # DPT-Large 126 | model = DPTDepthModel( 127 | path=model_path, 128 | backbone="vitl16_384", 129 | non_negative=True, 130 | ) 131 | net_w, net_h = 384, 384 132 | resize_mode = "minimal" 133 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 134 | 135 | elif model_type == "dpt_hybrid": # DPT-Hybrid 136 | if os.path.exists(old_model_path): 137 | model_path = old_model_path 138 | elif not os.path.exists(model_path): 139 | load_file_from_url(remote_model_path, model_dir=base_model_path) 140 | 141 | model = DPTDepthModel( 142 | path=model_path, 143 | backbone="vitb_rn50_384", 144 | non_negative=True, 145 | ) 146 | net_w, net_h = 384, 384 147 | resize_mode = "minimal" 148 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 149 | 150 | elif model_type == 
"midas_v21": 151 | model = MidasNet(model_path, non_negative=True) 152 | net_w, net_h = 384, 384 153 | resize_mode = "upper_bound" 154 | normalization = NormalizeImage( 155 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 156 | ) 157 | 158 | elif model_type == "midas_v21_small": 159 | model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True, 160 | non_negative=True, blocks={'expand': True}) 161 | net_w, net_h = 256, 256 162 | resize_mode = "upper_bound" 163 | normalization = NormalizeImage( 164 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 165 | ) 166 | 167 | else: 168 | print(f"model_type '{model_type}' not implemented, use: --model_type large") 169 | assert False 170 | 171 | transform = Compose( 172 | [ 173 | Resize( 174 | net_w, 175 | net_h, 176 | resize_target=None, 177 | keep_aspect_ratio=True, 178 | ensure_multiple_of=32, 179 | resize_method=resize_mode, 180 | image_interpolation_method=cv2.INTER_CUBIC, 181 | ), 182 | normalization, 183 | PrepareForNet(), 184 | ] 185 | ) 186 | 187 | return model.eval(), transform 188 | 189 | 190 | class MiDaSInference(nn.Module): 191 | MODEL_TYPES_TORCH_HUB = [ 192 | "DPT_Large", 193 | "DPT_Hybrid", 194 | "MiDaS_small" 195 | ] 196 | MODEL_TYPES_ISL = [ 197 | "dpt_large", 198 | "dpt_hybrid", 199 | "midas_v21", 200 | "midas_v21_small", 201 | ] 202 | 203 | def __init__(self, model_type): 204 | super().__init__() 205 | assert (model_type in self.MODEL_TYPES_ISL) 206 | model, _ = load_model(model_type) 207 | self.model = model 208 | self.model.train = disabled_train 209 | 210 | def forward(self, x): 211 | with torch.no_grad(): 212 | prediction = self.model(x) 213 | return prediction 214 | 215 | -------------------------------------------------------------------------------- /lib/model_zoo/controlnet_annotator/midas/midas/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Prompt-Free-Diffusion/f4295fc7ad80763e2ab562eb9aab06abb979b278/lib/model_zoo/controlnet_annotator/midas/midas/__init__.py -------------------------------------------------------------------------------- /lib/model_zoo/controlnet_annotator/midas/midas/base_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class BaseModel(torch.nn.Module): 5 | def load(self, path): 6 | """Load model from file. 
7 | 8 | Args: 9 | path (str): file path 10 | """ 11 | parameters = torch.load(path, map_location=torch.device('cpu')) 12 | 13 | if "optimizer" in parameters: 14 | parameters = parameters["model"] 15 | 16 | self.load_state_dict(parameters) 17 | -------------------------------------------------------------------------------- /lib/model_zoo/controlnet_annotator/midas/midas/blocks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .vit import ( 5 | _make_pretrained_vitb_rn50_384, 6 | _make_pretrained_vitl16_384, 7 | _make_pretrained_vitb16_384, 8 | forward_vit, 9 | ) 10 | 11 | def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",): 12 | if backbone == "vitl16_384": 13 | pretrained = _make_pretrained_vitl16_384( 14 | use_pretrained, hooks=hooks, use_readout=use_readout 15 | ) 16 | scratch = _make_scratch( 17 | [256, 512, 1024, 1024], features, groups=groups, expand=expand 18 | ) # ViT-L/16 - 85.0% Top1 (backbone) 19 | elif backbone == "vitb_rn50_384": 20 | pretrained = _make_pretrained_vitb_rn50_384( 21 | use_pretrained, 22 | hooks=hooks, 23 | use_vit_only=use_vit_only, 24 | use_readout=use_readout, 25 | ) 26 | scratch = _make_scratch( 27 | [256, 512, 768, 768], features, groups=groups, expand=expand 28 | ) # ViT-H/16 - 85.0% Top1 (backbone) 29 | elif backbone == "vitb16_384": 30 | pretrained = _make_pretrained_vitb16_384( 31 | use_pretrained, hooks=hooks, use_readout=use_readout 32 | ) 33 | scratch = _make_scratch( 34 | [96, 192, 384, 768], features, groups=groups, expand=expand 35 | ) # ViT-B/16 - 84.6% Top1 (backbone) 36 | elif backbone == "resnext101_wsl": 37 | pretrained = _make_pretrained_resnext101_wsl(use_pretrained) 38 | scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3 39 | elif backbone == "efficientnet_lite3": 40 | pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable) 41 | scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3 42 | else: 43 | print(f"Backbone '{backbone}' not implemented") 44 | assert False 45 | 46 | return pretrained, scratch 47 | 48 | 49 | def _make_scratch(in_shape, out_shape, groups=1, expand=False): 50 | scratch = nn.Module() 51 | 52 | out_shape1 = out_shape 53 | out_shape2 = out_shape 54 | out_shape3 = out_shape 55 | out_shape4 = out_shape 56 | if expand==True: 57 | out_shape1 = out_shape 58 | out_shape2 = out_shape*2 59 | out_shape3 = out_shape*4 60 | out_shape4 = out_shape*8 61 | 62 | scratch.layer1_rn = nn.Conv2d( 63 | in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups 64 | ) 65 | scratch.layer2_rn = nn.Conv2d( 66 | in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups 67 | ) 68 | scratch.layer3_rn = nn.Conv2d( 69 | in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups 70 | ) 71 | scratch.layer4_rn = nn.Conv2d( 72 | in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups 73 | ) 74 | 75 | return scratch 76 | 77 | 78 | def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False): 79 | efficientnet = torch.hub.load( 80 | "rwightman/gen-efficientnet-pytorch", 81 | "tf_efficientnet_lite3", 82 | pretrained=use_pretrained, 83 | exportable=exportable 84 | ) 85 | return 
_make_efficientnet_backbone(efficientnet) 86 | 87 | 88 | def _make_efficientnet_backbone(effnet): 89 | pretrained = nn.Module() 90 | 91 | pretrained.layer1 = nn.Sequential( 92 | effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2] 93 | ) 94 | pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3]) 95 | pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5]) 96 | pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9]) 97 | 98 | return pretrained 99 | 100 | 101 | def _make_resnet_backbone(resnet): 102 | pretrained = nn.Module() 103 | pretrained.layer1 = nn.Sequential( 104 | resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1 105 | ) 106 | 107 | pretrained.layer2 = resnet.layer2 108 | pretrained.layer3 = resnet.layer3 109 | pretrained.layer4 = resnet.layer4 110 | 111 | return pretrained 112 | 113 | 114 | def _make_pretrained_resnext101_wsl(use_pretrained): 115 | resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl") 116 | return _make_resnet_backbone(resnet) 117 | 118 | 119 | 120 | class Interpolate(nn.Module): 121 | """Interpolation module. 122 | """ 123 | 124 | def __init__(self, scale_factor, mode, align_corners=False): 125 | """Init. 126 | 127 | Args: 128 | scale_factor (float): scaling 129 | mode (str): interpolation mode 130 | """ 131 | super(Interpolate, self).__init__() 132 | 133 | self.interp = nn.functional.interpolate 134 | self.scale_factor = scale_factor 135 | self.mode = mode 136 | self.align_corners = align_corners 137 | 138 | def forward(self, x): 139 | """Forward pass. 140 | 141 | Args: 142 | x (tensor): input 143 | 144 | Returns: 145 | tensor: interpolated data 146 | """ 147 | 148 | x = self.interp( 149 | x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners 150 | ) 151 | 152 | return x 153 | 154 | 155 | class ResidualConvUnit(nn.Module): 156 | """Residual convolution module. 157 | """ 158 | 159 | def __init__(self, features): 160 | """Init. 161 | 162 | Args: 163 | features (int): number of features 164 | """ 165 | super().__init__() 166 | 167 | self.conv1 = nn.Conv2d( 168 | features, features, kernel_size=3, stride=1, padding=1, bias=True 169 | ) 170 | 171 | self.conv2 = nn.Conv2d( 172 | features, features, kernel_size=3, stride=1, padding=1, bias=True 173 | ) 174 | 175 | self.relu = nn.ReLU(inplace=True) 176 | 177 | def forward(self, x): 178 | """Forward pass. 179 | 180 | Args: 181 | x (tensor): input 182 | 183 | Returns: 184 | tensor: output 185 | """ 186 | out = self.relu(x) 187 | out = self.conv1(out) 188 | out = self.relu(out) 189 | out = self.conv2(out) 190 | 191 | return out + x 192 | 193 | 194 | class FeatureFusionBlock(nn.Module): 195 | """Feature fusion block. 196 | """ 197 | 198 | def __init__(self, features): 199 | """Init. 200 | 201 | Args: 202 | features (int): number of features 203 | """ 204 | super(FeatureFusionBlock, self).__init__() 205 | 206 | self.resConfUnit1 = ResidualConvUnit(features) 207 | self.resConfUnit2 = ResidualConvUnit(features) 208 | 209 | def forward(self, *xs): 210 | """Forward pass. 211 | 212 | Returns: 213 | tensor: output 214 | """ 215 | output = xs[0] 216 | 217 | if len(xs) == 2: 218 | output += self.resConfUnit1(xs[1]) 219 | 220 | output = self.resConfUnit2(output) 221 | 222 | output = nn.functional.interpolate( 223 | output, scale_factor=2, mode="bilinear", align_corners=True 224 | ) 225 | 226 | return output 227 | 228 | 229 | 230 | 231 | class ResidualConvUnit_custom(nn.Module): 232 | """Residual convolution module. 
233 | """ 234 | 235 | def __init__(self, features, activation, bn): 236 | """Init. 237 | 238 | Args: 239 | features (int): number of features 240 | """ 241 | super().__init__() 242 | 243 | self.bn = bn 244 | 245 | self.groups=1 246 | 247 | self.conv1 = nn.Conv2d( 248 | features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups 249 | ) 250 | 251 | self.conv2 = nn.Conv2d( 252 | features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups 253 | ) 254 | 255 | if self.bn==True: 256 | self.bn1 = nn.BatchNorm2d(features) 257 | self.bn2 = nn.BatchNorm2d(features) 258 | 259 | self.activation = activation 260 | 261 | self.skip_add = nn.quantized.FloatFunctional() 262 | 263 | def forward(self, x): 264 | """Forward pass. 265 | 266 | Args: 267 | x (tensor): input 268 | 269 | Returns: 270 | tensor: output 271 | """ 272 | 273 | out = self.activation(x) 274 | out = self.conv1(out) 275 | if self.bn==True: 276 | out = self.bn1(out) 277 | 278 | out = self.activation(out) 279 | out = self.conv2(out) 280 | if self.bn==True: 281 | out = self.bn2(out) 282 | 283 | if self.groups > 1: 284 | out = self.conv_merge(out) 285 | 286 | return self.skip_add.add(out, x) 287 | 288 | # return out + x 289 | 290 | 291 | class FeatureFusionBlock_custom(nn.Module): 292 | """Feature fusion block. 293 | """ 294 | 295 | def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True): 296 | """Init. 297 | 298 | Args: 299 | features (int): number of features 300 | """ 301 | super(FeatureFusionBlock_custom, self).__init__() 302 | 303 | self.deconv = deconv 304 | self.align_corners = align_corners 305 | 306 | self.groups=1 307 | 308 | self.expand = expand 309 | out_features = features 310 | if self.expand==True: 311 | out_features = features//2 312 | 313 | self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1) 314 | 315 | self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn) 316 | self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn) 317 | 318 | self.skip_add = nn.quantized.FloatFunctional() 319 | 320 | def forward(self, *xs): 321 | """Forward pass. 
322 | 323 | Returns: 324 | tensor: output 325 | """ 326 | output = xs[0] 327 | 328 | if len(xs) == 2: 329 | res = self.resConfUnit1(xs[1]) 330 | output = self.skip_add.add(output, res) 331 | # output += res 332 | 333 | output = self.resConfUnit2(output) 334 | 335 | output = nn.functional.interpolate( 336 | output, scale_factor=2, mode="bilinear", align_corners=self.align_corners 337 | ) 338 | 339 | output = self.out_conv(output) 340 | 341 | return output 342 | 343 | -------------------------------------------------------------------------------- /lib/model_zoo/controlnet_annotator/midas/midas/dpt_depth.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .base_model import BaseModel 6 | from .blocks import ( 7 | FeatureFusionBlock, 8 | FeatureFusionBlock_custom, 9 | Interpolate, 10 | _make_encoder, 11 | forward_vit, 12 | ) 13 | 14 | 15 | def _make_fusion_block(features, use_bn): 16 | return FeatureFusionBlock_custom( 17 | features, 18 | nn.ReLU(False), 19 | deconv=False, 20 | bn=use_bn, 21 | expand=False, 22 | align_corners=True, 23 | ) 24 | 25 | 26 | class DPT(BaseModel): 27 | def __init__( 28 | self, 29 | head, 30 | features=256, 31 | backbone="vitb_rn50_384", 32 | readout="project", 33 | channels_last=False, 34 | use_bn=False, 35 | ): 36 | 37 | super(DPT, self).__init__() 38 | 39 | self.channels_last = channels_last 40 | 41 | hooks = { 42 | "vitb_rn50_384": [0, 1, 8, 11], 43 | "vitb16_384": [2, 5, 8, 11], 44 | "vitl16_384": [5, 11, 17, 23], 45 | } 46 | 47 | # Instantiate backbone and reassemble blocks 48 | self.pretrained, self.scratch = _make_encoder( 49 | backbone, 50 | features, 51 | False, # Set to true of you want to train from scratch, uses ImageNet weights 52 | groups=1, 53 | expand=False, 54 | exportable=False, 55 | hooks=hooks[backbone], 56 | use_readout=readout, 57 | ) 58 | 59 | self.scratch.refinenet1 = _make_fusion_block(features, use_bn) 60 | self.scratch.refinenet2 = _make_fusion_block(features, use_bn) 61 | self.scratch.refinenet3 = _make_fusion_block(features, use_bn) 62 | self.scratch.refinenet4 = _make_fusion_block(features, use_bn) 63 | 64 | self.scratch.output_conv = head 65 | 66 | 67 | def forward(self, x): 68 | if self.channels_last == True: 69 | x.contiguous(memory_format=torch.channels_last) 70 | 71 | layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x) 72 | 73 | layer_1_rn = self.scratch.layer1_rn(layer_1) 74 | layer_2_rn = self.scratch.layer2_rn(layer_2) 75 | layer_3_rn = self.scratch.layer3_rn(layer_3) 76 | layer_4_rn = self.scratch.layer4_rn(layer_4) 77 | 78 | path_4 = self.scratch.refinenet4(layer_4_rn) 79 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 80 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 81 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 82 | 83 | out = self.scratch.output_conv(path_1) 84 | 85 | return out 86 | 87 | 88 | class DPTDepthModel(DPT): 89 | def __init__(self, path=None, non_negative=True, **kwargs): 90 | features = kwargs["features"] if "features" in kwargs else 256 91 | 92 | head = nn.Sequential( 93 | nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1), 94 | Interpolate(scale_factor=2, mode="bilinear", align_corners=True), 95 | nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), 96 | nn.ReLU(True), 97 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 98 | nn.ReLU(True) if non_negative else nn.Identity(), 99 | nn.Identity(), 100 | ) 
101 | 102 | super().__init__(head, **kwargs) 103 | 104 | if path is not None: 105 | self.load(path) 106 | 107 | def forward(self, x): 108 | return super().forward(x).squeeze(dim=1) 109 | 110 | -------------------------------------------------------------------------------- /lib/model_zoo/controlnet_annotator/midas/midas/midas_net.py: -------------------------------------------------------------------------------- 1 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets. 2 | This file contains code that is adapted from 3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | 8 | from .base_model import BaseModel 9 | from .blocks import FeatureFusionBlock, Interpolate, _make_encoder 10 | 11 | 12 | class MidasNet(BaseModel): 13 | """Network for monocular depth estimation. 14 | """ 15 | 16 | def __init__(self, path=None, features=256, non_negative=True): 17 | """Init. 18 | 19 | Args: 20 | path (str, optional): Path to saved model. Defaults to None. 21 | features (int, optional): Number of features. Defaults to 256. 22 | backbone (str, optional): Backbone network for encoder. Defaults to resnet50 23 | """ 24 | print("Loading weights: ", path) 25 | 26 | super(MidasNet, self).__init__() 27 | 28 | use_pretrained = False if path is None else True 29 | 30 | self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained) 31 | 32 | self.scratch.refinenet4 = FeatureFusionBlock(features) 33 | self.scratch.refinenet3 = FeatureFusionBlock(features) 34 | self.scratch.refinenet2 = FeatureFusionBlock(features) 35 | self.scratch.refinenet1 = FeatureFusionBlock(features) 36 | 37 | self.scratch.output_conv = nn.Sequential( 38 | nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), 39 | Interpolate(scale_factor=2, mode="bilinear"), 40 | nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), 41 | nn.ReLU(True), 42 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 43 | nn.ReLU(True) if non_negative else nn.Identity(), 44 | ) 45 | 46 | if path: 47 | self.load(path) 48 | 49 | def forward(self, x): 50 | """Forward pass. 51 | 52 | Args: 53 | x (tensor): input data (image) 54 | 55 | Returns: 56 | tensor: depth 57 | """ 58 | 59 | layer_1 = self.pretrained.layer1(x) 60 | layer_2 = self.pretrained.layer2(layer_1) 61 | layer_3 = self.pretrained.layer3(layer_2) 62 | layer_4 = self.pretrained.layer4(layer_3) 63 | 64 | layer_1_rn = self.scratch.layer1_rn(layer_1) 65 | layer_2_rn = self.scratch.layer2_rn(layer_2) 66 | layer_3_rn = self.scratch.layer3_rn(layer_3) 67 | layer_4_rn = self.scratch.layer4_rn(layer_4) 68 | 69 | path_4 = self.scratch.refinenet4(layer_4_rn) 70 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 71 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 72 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 73 | 74 | out = self.scratch.output_conv(path_1) 75 | 76 | return torch.squeeze(out, dim=1) 77 | -------------------------------------------------------------------------------- /lib/model_zoo/controlnet_annotator/midas/midas/midas_net_custom.py: -------------------------------------------------------------------------------- 1 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets. 
2 | This file contains code that is adapted from 3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | 8 | from .base_model import BaseModel 9 | from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder 10 | 11 | 12 | class MidasNet_small(BaseModel): 13 | """Network for monocular depth estimation. 14 | """ 15 | 16 | def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True, 17 | blocks={'expand': True}): 18 | """Init. 19 | 20 | Args: 21 | path (str, optional): Path to saved model. Defaults to None. 22 | features (int, optional): Number of features. Defaults to 256. 23 | backbone (str, optional): Backbone network for encoder. Defaults to resnet50 24 | """ 25 | print("Loading weights: ", path) 26 | 27 | super(MidasNet_small, self).__init__() 28 | 29 | use_pretrained = False if path else True 30 | 31 | self.channels_last = channels_last 32 | self.blocks = blocks 33 | self.backbone = backbone 34 | 35 | self.groups = 1 36 | 37 | features1=features 38 | features2=features 39 | features3=features 40 | features4=features 41 | self.expand = False 42 | if "expand" in self.blocks and self.blocks['expand'] == True: 43 | self.expand = True 44 | features1=features 45 | features2=features*2 46 | features3=features*4 47 | features4=features*8 48 | 49 | self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable) 50 | 51 | self.scratch.activation = nn.ReLU(False) 52 | 53 | self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) 54 | self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) 55 | self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) 56 | self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners) 57 | 58 | 59 | self.scratch.output_conv = nn.Sequential( 60 | nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups), 61 | Interpolate(scale_factor=2, mode="bilinear"), 62 | nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1), 63 | self.scratch.activation, 64 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 65 | nn.ReLU(True) if non_negative else nn.Identity(), 66 | nn.Identity(), 67 | ) 68 | 69 | if path: 70 | self.load(path) 71 | 72 | 73 | def forward(self, x): 74 | """Forward pass. 
75 | 76 | Args: 77 | x (tensor): input data (image) 78 | 79 | Returns: 80 | tensor: depth 81 | """ 82 | if self.channels_last==True: 83 | print("self.channels_last = ", self.channels_last) 84 | x.contiguous(memory_format=torch.channels_last) 85 | 86 | 87 | layer_1 = self.pretrained.layer1(x) 88 | layer_2 = self.pretrained.layer2(layer_1) 89 | layer_3 = self.pretrained.layer3(layer_2) 90 | layer_4 = self.pretrained.layer4(layer_3) 91 | 92 | layer_1_rn = self.scratch.layer1_rn(layer_1) 93 | layer_2_rn = self.scratch.layer2_rn(layer_2) 94 | layer_3_rn = self.scratch.layer3_rn(layer_3) 95 | layer_4_rn = self.scratch.layer4_rn(layer_4) 96 | 97 | 98 | path_4 = self.scratch.refinenet4(layer_4_rn) 99 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 100 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 101 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 102 | 103 | out = self.scratch.output_conv(path_1) 104 | 105 | return torch.squeeze(out, dim=1) 106 | 107 | 108 | 109 | def fuse_model(m): 110 | prev_previous_type = nn.Identity() 111 | prev_previous_name = '' 112 | previous_type = nn.Identity() 113 | previous_name = '' 114 | for name, module in m.named_modules(): 115 | if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU: 116 | # print("FUSED ", prev_previous_name, previous_name, name) 117 | torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True) 118 | elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d: 119 | # print("FUSED ", prev_previous_name, previous_name) 120 | torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True) 121 | # elif previous_type == nn.Conv2d and type(module) == nn.ReLU: 122 | # print("FUSED ", previous_name, name) 123 | # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True) 124 | 125 | prev_previous_type = previous_type 126 | prev_previous_name = previous_name 127 | previous_type = type(module) 128 | previous_name = name -------------------------------------------------------------------------------- /lib/model_zoo/controlnet_annotator/midas/midas/transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import math 4 | 5 | 6 | def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): 7 | """Rezise the sample to ensure the given size. Keeps aspect ratio. 8 | 9 | Args: 10 | sample (dict): sample 11 | size (tuple): image size 12 | 13 | Returns: 14 | tuple: new size 15 | """ 16 | shape = list(sample["disparity"].shape) 17 | 18 | if shape[0] >= size[0] and shape[1] >= size[1]: 19 | return sample 20 | 21 | scale = [0, 0] 22 | scale[0] = size[0] / shape[0] 23 | scale[1] = size[1] / shape[1] 24 | 25 | scale = max(scale) 26 | 27 | shape[0] = math.ceil(scale * shape[0]) 28 | shape[1] = math.ceil(scale * shape[1]) 29 | 30 | # resize 31 | sample["image"] = cv2.resize( 32 | sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method 33 | ) 34 | 35 | sample["disparity"] = cv2.resize( 36 | sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST 37 | ) 38 | sample["mask"] = cv2.resize( 39 | sample["mask"].astype(np.float32), 40 | tuple(shape[::-1]), 41 | interpolation=cv2.INTER_NEAREST, 42 | ) 43 | sample["mask"] = sample["mask"].astype(bool) 44 | 45 | return tuple(shape) 46 | 47 | 48 | class Resize(object): 49 | """Resize sample to given size (width, height). 
50 | """ 51 | 52 | def __init__( 53 | self, 54 | width, 55 | height, 56 | resize_target=True, 57 | keep_aspect_ratio=False, 58 | ensure_multiple_of=1, 59 | resize_method="lower_bound", 60 | image_interpolation_method=cv2.INTER_AREA, 61 | ): 62 | """Init. 63 | 64 | Args: 65 | width (int): desired output width 66 | height (int): desired output height 67 | resize_target (bool, optional): 68 | True: Resize the full sample (image, mask, target). 69 | False: Resize image only. 70 | Defaults to True. 71 | keep_aspect_ratio (bool, optional): 72 | True: Keep the aspect ratio of the input sample. 73 | Output sample might not have the given width and height, and 74 | resize behaviour depends on the parameter 'resize_method'. 75 | Defaults to False. 76 | ensure_multiple_of (int, optional): 77 | Output width and height is constrained to be multiple of this parameter. 78 | Defaults to 1. 79 | resize_method (str, optional): 80 | "lower_bound": Output will be at least as large as the given size. 81 | "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) 82 | "minimal": Scale as least as possible. (Output size might be smaller than given size.) 83 | Defaults to "lower_bound". 84 | """ 85 | self.__width = width 86 | self.__height = height 87 | 88 | self.__resize_target = resize_target 89 | self.__keep_aspect_ratio = keep_aspect_ratio 90 | self.__multiple_of = ensure_multiple_of 91 | self.__resize_method = resize_method 92 | self.__image_interpolation_method = image_interpolation_method 93 | 94 | def constrain_to_multiple_of(self, x, min_val=0, max_val=None): 95 | y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) 96 | 97 | if max_val is not None and y > max_val: 98 | y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int) 99 | 100 | if y < min_val: 101 | y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int) 102 | 103 | return y 104 | 105 | def get_size(self, width, height): 106 | # determine new height and width 107 | scale_height = self.__height / height 108 | scale_width = self.__width / width 109 | 110 | if self.__keep_aspect_ratio: 111 | if self.__resize_method == "lower_bound": 112 | # scale such that output size is lower bound 113 | if scale_width > scale_height: 114 | # fit width 115 | scale_height = scale_width 116 | else: 117 | # fit height 118 | scale_width = scale_height 119 | elif self.__resize_method == "upper_bound": 120 | # scale such that output size is upper bound 121 | if scale_width < scale_height: 122 | # fit width 123 | scale_height = scale_width 124 | else: 125 | # fit height 126 | scale_width = scale_height 127 | elif self.__resize_method == "minimal": 128 | # scale as least as possbile 129 | if abs(1 - scale_width) < abs(1 - scale_height): 130 | # fit width 131 | scale_height = scale_width 132 | else: 133 | # fit height 134 | scale_width = scale_height 135 | else: 136 | raise ValueError( 137 | f"resize_method {self.__resize_method} not implemented" 138 | ) 139 | 140 | if self.__resize_method == "lower_bound": 141 | new_height = self.constrain_to_multiple_of( 142 | scale_height * height, min_val=self.__height 143 | ) 144 | new_width = self.constrain_to_multiple_of( 145 | scale_width * width, min_val=self.__width 146 | ) 147 | elif self.__resize_method == "upper_bound": 148 | new_height = self.constrain_to_multiple_of( 149 | scale_height * height, max_val=self.__height 150 | ) 151 | new_width = self.constrain_to_multiple_of( 152 | scale_width * width, max_val=self.__width 
153 | ) 154 | elif self.__resize_method == "minimal": 155 | new_height = self.constrain_to_multiple_of(scale_height * height) 156 | new_width = self.constrain_to_multiple_of(scale_width * width) 157 | else: 158 | raise ValueError(f"resize_method {self.__resize_method} not implemented") 159 | 160 | return (new_width, new_height) 161 | 162 | def __call__(self, sample): 163 | width, height = self.get_size( 164 | sample["image"].shape[1], sample["image"].shape[0] 165 | ) 166 | 167 | # resize sample 168 | sample["image"] = cv2.resize( 169 | sample["image"], 170 | (width, height), 171 | interpolation=self.__image_interpolation_method, 172 | ) 173 | 174 | if self.__resize_target: 175 | if "disparity" in sample: 176 | sample["disparity"] = cv2.resize( 177 | sample["disparity"], 178 | (width, height), 179 | interpolation=cv2.INTER_NEAREST, 180 | ) 181 | 182 | if "depth" in sample: 183 | sample["depth"] = cv2.resize( 184 | sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST 185 | ) 186 | 187 | sample["mask"] = cv2.resize( 188 | sample["mask"].astype(np.float32), 189 | (width, height), 190 | interpolation=cv2.INTER_NEAREST, 191 | ) 192 | sample["mask"] = sample["mask"].astype(bool) 193 | 194 | return sample 195 | 196 | 197 | class NormalizeImage(object): 198 | """Normlize image by given mean and std. 199 | """ 200 | 201 | def __init__(self, mean, std): 202 | self.__mean = mean 203 | self.__std = std 204 | 205 | def __call__(self, sample): 206 | sample["image"] = (sample["image"] - self.__mean) / self.__std 207 | 208 | return sample 209 | 210 | 211 | class PrepareForNet(object): 212 | """Prepare sample for usage as network input. 213 | """ 214 | 215 | def __init__(self): 216 | pass 217 | 218 | def __call__(self, sample): 219 | image = np.transpose(sample["image"], (2, 0, 1)) 220 | sample["image"] = np.ascontiguousarray(image).astype(np.float32) 221 | 222 | if "mask" in sample: 223 | sample["mask"] = sample["mask"].astype(np.float32) 224 | sample["mask"] = np.ascontiguousarray(sample["mask"]) 225 | 226 | if "disparity" in sample: 227 | disparity = sample["disparity"].astype(np.float32) 228 | sample["disparity"] = np.ascontiguousarray(disparity) 229 | 230 | if "depth" in sample: 231 | depth = sample["depth"].astype(np.float32) 232 | sample["depth"] = np.ascontiguousarray(depth) 233 | 234 | return sample 235 | -------------------------------------------------------------------------------- /lib/model_zoo/controlnet_annotator/midas/utils.py: -------------------------------------------------------------------------------- 1 | """Utils for monoDepth.""" 2 | import sys 3 | import re 4 | import numpy as np 5 | import cv2 6 | import torch 7 | 8 | 9 | def read_pfm(path): 10 | """Read pfm file. 
11 | 12 | Args: 13 | path (str): path to file 14 | 15 | Returns: 16 | tuple: (data, scale) 17 | """ 18 | with open(path, "rb") as file: 19 | 20 | color = None 21 | width = None 22 | height = None 23 | scale = None 24 | endian = None 25 | 26 | header = file.readline().rstrip() 27 | if header.decode("ascii") == "PF": 28 | color = True 29 | elif header.decode("ascii") == "Pf": 30 | color = False 31 | else: 32 | raise Exception("Not a PFM file: " + path) 33 | 34 | dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii")) 35 | if dim_match: 36 | width, height = list(map(int, dim_match.groups())) 37 | else: 38 | raise Exception("Malformed PFM header.") 39 | 40 | scale = float(file.readline().decode("ascii").rstrip()) 41 | if scale < 0: 42 | # little-endian 43 | endian = "<" 44 | scale = -scale 45 | else: 46 | # big-endian 47 | endian = ">" 48 | 49 | data = np.fromfile(file, endian + "f") 50 | shape = (height, width, 3) if color else (height, width) 51 | 52 | data = np.reshape(data, shape) 53 | data = np.flipud(data) 54 | 55 | return data, scale 56 | 57 | 58 | def write_pfm(path, image, scale=1): 59 | """Write pfm file. 60 | 61 | Args: 62 | path (str): pathto file 63 | image (array): data 64 | scale (int, optional): Scale. Defaults to 1. 65 | """ 66 | 67 | with open(path, "wb") as file: 68 | color = None 69 | 70 | if image.dtype.name != "float32": 71 | raise Exception("Image dtype must be float32.") 72 | 73 | image = np.flipud(image) 74 | 75 | if len(image.shape) == 3 and image.shape[2] == 3: # color image 76 | color = True 77 | elif ( 78 | len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1 79 | ): # greyscale 80 | color = False 81 | else: 82 | raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.") 83 | 84 | file.write("PF\n" if color else "Pf\n".encode()) 85 | file.write("%d %d\n".encode() % (image.shape[1], image.shape[0])) 86 | 87 | endian = image.dtype.byteorder 88 | 89 | if endian == "<" or endian == "=" and sys.byteorder == "little": 90 | scale = -scale 91 | 92 | file.write("%f\n".encode() % scale) 93 | 94 | image.tofile(file) 95 | 96 | 97 | def read_image(path): 98 | """Read image and output RGB image (0-1). 99 | 100 | Args: 101 | path (str): path to file 102 | 103 | Returns: 104 | array: RGB image (0-1) 105 | """ 106 | img = cv2.imread(path) 107 | 108 | if img.ndim == 2: 109 | img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) 110 | 111 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0 112 | 113 | return img 114 | 115 | 116 | def resize_image(img): 117 | """Resize image and make it fit for network. 118 | 119 | Args: 120 | img (array): image 121 | 122 | Returns: 123 | tensor: data ready for network 124 | """ 125 | height_orig = img.shape[0] 126 | width_orig = img.shape[1] 127 | 128 | if width_orig > height_orig: 129 | scale = width_orig / 384 130 | else: 131 | scale = height_orig / 384 132 | 133 | height = (np.ceil(height_orig / scale / 32) * 32).astype(int) 134 | width = (np.ceil(width_orig / scale / 32) * 32).astype(int) 135 | 136 | img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA) 137 | 138 | img_resized = ( 139 | torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float() 140 | ) 141 | img_resized = img_resized.unsqueeze(0) 142 | 143 | return img_resized 144 | 145 | 146 | def resize_depth(depth, width, height): 147 | """Resize depth map and bring to CPU (numpy). 
148 | 149 | Args: 150 | depth (tensor): depth 151 | width (int): image width 152 | height (int): image height 153 | 154 | Returns: 155 | array: processed depth 156 | """ 157 | depth = torch.squeeze(depth[0, :, :, :]).to("cpu") 158 | 159 | depth_resized = cv2.resize( 160 | depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC 161 | ) 162 | 163 | return depth_resized 164 | 165 | def write_depth(path, depth, bits=1): 166 | """Write depth map to pfm and png file. 167 | 168 | Args: 169 | path (str): filepath without extension 170 | depth (array): depth 171 | """ 172 | write_pfm(path + ".pfm", depth.astype(np.float32)) 173 | 174 | depth_min = depth.min() 175 | depth_max = depth.max() 176 | 177 | max_val = (2**(8*bits))-1 178 | 179 | if depth_max - depth_min > np.finfo("float").eps: 180 | out = max_val * (depth - depth_min) / (depth_max - depth_min) 181 | else: 182 | out = np.zeros(depth.shape, dtype=depth.type) 183 | 184 | if bits == 1: 185 | cv2.imwrite(path + ".png", out.astype("uint8")) 186 | elif bits == 2: 187 | cv2.imwrite(path + ".png", out.astype("uint16")) 188 | 189 | return 190 | -------------------------------------------------------------------------------- /lib/model_zoo/controlnet_annotator/mlsd/LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2021-present NAVER Corp. 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
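As a quick numeric illustration of the depth-to-PNG scaling performed by `write_depth` in the MiDaS annotator utilities shown above (just before this license): the depth range is stretched linearly over the full 8-bit or 16-bit range before the PNG is written. A minimal sketch with made-up depth values:

```python
import numpy as np

depth = np.array([[0.5, 2.0], [3.5, 5.0]], dtype=np.float32)  # placeholder depth values
bits = 2                                  # 16-bit PNG, as in write_depth(path, depth, bits=2)
max_val = (2 ** (8 * bits)) - 1           # 65535
out = max_val * (depth - depth.min()) / (depth.max() - depth.min())
print(out.astype("uint16"))
# [[    0 21845]
#  [43690 65535]]
```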
-------------------------------------------------------------------------------- /lib/model_zoo/controlnet_annotator/mlsd/__init__.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import torch 4 | import os 5 | 6 | from einops import rearrange 7 | from .models.mbv2_mlsd_tiny import MobileV2_MLSD_Tiny 8 | from .models.mbv2_mlsd_large import MobileV2_MLSD_Large 9 | from .utils import pred_lines 10 | 11 | models_path = 'pretrained/controlnet/preprocess' 12 | 13 | mlsdmodel = None 14 | remote_model_path = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/mlsd_large_512_fp32.pth" 15 | old_modeldir = os.path.dirname(os.path.realpath(__file__)) 16 | modeldir = os.path.join(models_path, "mlsd") 17 | 18 | def unload_mlsd_model(): 19 | global mlsdmodel 20 | if mlsdmodel is not None: 21 | mlsdmodel = mlsdmodel.cpu() 22 | 23 | def load_file_from_url(url, model_dir=None, progress=True, file_name=None): 24 | """Load file form http url, will download models if necessary. 25 | 26 | Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py 27 | 28 | Args: 29 | url (str): URL to be downloaded. 30 | model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir. 31 | Default: None. 32 | progress (bool): Whether to show the download progress. Default: True. 33 | file_name (str): The downloaded file name. If None, use the file name in the url. Default: None. 34 | 35 | Returns: 36 | str: The path to the downloaded file. 37 | """ 38 | from torch.hub import download_url_to_file, get_dir 39 | from urllib.parse import urlparse 40 | if model_dir is None: # use the pytorch hub_dir 41 | hub_dir = get_dir() 42 | model_dir = os.path.join(hub_dir, 'checkpoints') 43 | 44 | os.makedirs(model_dir, exist_ok=True) 45 | 46 | parts = urlparse(url) 47 | filename = os.path.basename(parts.path) 48 | if file_name is not None: 49 | filename = file_name 50 | cached_file = os.path.abspath(os.path.join(model_dir, filename)) 51 | if not os.path.exists(cached_file): 52 | print(f'Downloading: "{url}" to {cached_file}\n') 53 | download_url_to_file(url, cached_file, hash_prefix=None, progress=progress) 54 | return cached_file 55 | 56 | def apply_mlsd(input_image, thr_v, thr_d, device='cpu'): 57 | global modelpath, mlsdmodel 58 | if mlsdmodel is None: 59 | modelpath = os.path.join(modeldir, "mlsd_large_512_fp32.pth") 60 | old_modelpath = os.path.join(old_modeldir, "mlsd_large_512_fp32.pth") 61 | if os.path.exists(old_modelpath): 62 | modelpath = old_modelpath 63 | elif not os.path.exists(modelpath): 64 | load_file_from_url(remote_model_path, model_dir=modeldir) 65 | mlsdmodel = MobileV2_MLSD_Large() 66 | mlsdmodel.load_state_dict(torch.load(modelpath), strict=True) 67 | mlsdmodel = mlsdmodel.to(device).eval() 68 | 69 | model = mlsdmodel 70 | assert input_image.ndim == 3 71 | img = input_image 72 | img_output = np.zeros_like(img) 73 | try: 74 | with torch.no_grad(): 75 | lines = pred_lines(img, model, [img.shape[0], img.shape[1]], thr_v, thr_d) 76 | for line in lines: 77 | x_start, y_start, x_end, y_end = [int(val) for val in line] 78 | cv2.line(img_output, (x_start, y_start), (x_end, y_end), [255, 255, 255], 1) 79 | except Exception as e: 80 | pass 81 | return img_output[:, :, 0] 82 | -------------------------------------------------------------------------------- /lib/model_zoo/controlnet_annotator/mlsd/models/mbv2_mlsd_large.py: 
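For orientation, here is a minimal usage sketch of the `apply_mlsd` preprocessor defined in `mlsd/__init__.py` above. The input path is a placeholder, and the 0.1 / 0.1 score and distance thresholds are commonly used MLSD defaults rather than values fixed by this repo; the checkpoint is downloaded automatically into `pretrained/controlnet/preprocess/mlsd` on first use (assuming the repository root is on `PYTHONPATH` with the requirements installed):

```python
import cv2
from lib.model_zoo.controlnet_annotator.mlsd import apply_mlsd

img = cv2.imread("input.jpg")                                   # any H x W x 3 uint8 image
line_map = apply_mlsd(img, thr_v=0.1, thr_d=0.1, device="cpu")  # score / distance thresholds
# apply_mlsd returns a single-channel uint8 map with the detected straight
# lines drawn in white, ready to be used as a ControlNet MLSD condition.
cv2.imwrite("input_mlsd.png", line_map)
```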
-------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | import torch.nn as nn 5 | import torch.utils.model_zoo as model_zoo 6 | from torch.nn import functional as F 7 | 8 | 9 | class BlockTypeA(nn.Module): 10 | def __init__(self, in_c1, in_c2, out_c1, out_c2, upscale = True): 11 | super(BlockTypeA, self).__init__() 12 | self.conv1 = nn.Sequential( 13 | nn.Conv2d(in_c2, out_c2, kernel_size=1), 14 | nn.BatchNorm2d(out_c2), 15 | nn.ReLU(inplace=True) 16 | ) 17 | self.conv2 = nn.Sequential( 18 | nn.Conv2d(in_c1, out_c1, kernel_size=1), 19 | nn.BatchNorm2d(out_c1), 20 | nn.ReLU(inplace=True) 21 | ) 22 | self.upscale = upscale 23 | 24 | def forward(self, a, b): 25 | b = self.conv1(b) 26 | a = self.conv2(a) 27 | if self.upscale: 28 | b = F.interpolate(b, scale_factor=2.0, mode='bilinear', align_corners=True) 29 | return torch.cat((a, b), dim=1) 30 | 31 | 32 | class BlockTypeB(nn.Module): 33 | def __init__(self, in_c, out_c): 34 | super(BlockTypeB, self).__init__() 35 | self.conv1 = nn.Sequential( 36 | nn.Conv2d(in_c, in_c, kernel_size=3, padding=1), 37 | nn.BatchNorm2d(in_c), 38 | nn.ReLU() 39 | ) 40 | self.conv2 = nn.Sequential( 41 | nn.Conv2d(in_c, out_c, kernel_size=3, padding=1), 42 | nn.BatchNorm2d(out_c), 43 | nn.ReLU() 44 | ) 45 | 46 | def forward(self, x): 47 | x = self.conv1(x) + x 48 | x = self.conv2(x) 49 | return x 50 | 51 | class BlockTypeC(nn.Module): 52 | def __init__(self, in_c, out_c): 53 | super(BlockTypeC, self).__init__() 54 | self.conv1 = nn.Sequential( 55 | nn.Conv2d(in_c, in_c, kernel_size=3, padding=5, dilation=5), 56 | nn.BatchNorm2d(in_c), 57 | nn.ReLU() 58 | ) 59 | self.conv2 = nn.Sequential( 60 | nn.Conv2d(in_c, in_c, kernel_size=3, padding=1), 61 | nn.BatchNorm2d(in_c), 62 | nn.ReLU() 63 | ) 64 | self.conv3 = nn.Conv2d(in_c, out_c, kernel_size=1) 65 | 66 | def forward(self, x): 67 | x = self.conv1(x) 68 | x = self.conv2(x) 69 | x = self.conv3(x) 70 | return x 71 | 72 | def _make_divisible(v, divisor, min_value=None): 73 | """ 74 | This function is taken from the original tf repo. 75 | It ensures that all layers have a channel number that is divisible by 8 76 | It can be seen here: 77 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py 78 | :param v: 79 | :param divisor: 80 | :param min_value: 81 | :return: 82 | """ 83 | if min_value is None: 84 | min_value = divisor 85 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 86 | # Make sure that round down does not go down by more than 10%. 
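# Illustration (not part of the original file): with divisor=8,
#   _make_divisible(32, 8) -> 32     _make_divisible(30, 8) -> 32
#   _make_divisible(10, 8) -> 16     (plain rounding gives 8, a >10% drop, so it is bumped up)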
87 | if new_v < 0.9 * v: 88 | new_v += divisor 89 | return new_v 90 | 91 | 92 | class ConvBNReLU(nn.Sequential): 93 | def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1): 94 | self.channel_pad = out_planes - in_planes 95 | self.stride = stride 96 | #padding = (kernel_size - 1) // 2 97 | 98 | # TFLite uses slightly different padding than PyTorch 99 | if stride == 2: 100 | padding = 0 101 | else: 102 | padding = (kernel_size - 1) // 2 103 | 104 | super(ConvBNReLU, self).__init__( 105 | nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False), 106 | nn.BatchNorm2d(out_planes), 107 | nn.ReLU6(inplace=True) 108 | ) 109 | self.max_pool = nn.MaxPool2d(kernel_size=stride, stride=stride) 110 | 111 | 112 | def forward(self, x): 113 | # TFLite uses different padding 114 | if self.stride == 2: 115 | x = F.pad(x, (0, 1, 0, 1), "constant", 0) 116 | #print(x.shape) 117 | 118 | for module in self: 119 | if not isinstance(module, nn.MaxPool2d): 120 | x = module(x) 121 | return x 122 | 123 | 124 | class InvertedResidual(nn.Module): 125 | def __init__(self, inp, oup, stride, expand_ratio): 126 | super(InvertedResidual, self).__init__() 127 | self.stride = stride 128 | assert stride in [1, 2] 129 | 130 | hidden_dim = int(round(inp * expand_ratio)) 131 | self.use_res_connect = self.stride == 1 and inp == oup 132 | 133 | layers = [] 134 | if expand_ratio != 1: 135 | # pw 136 | layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1)) 137 | layers.extend([ 138 | # dw 139 | ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim), 140 | # pw-linear 141 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 142 | nn.BatchNorm2d(oup), 143 | ]) 144 | self.conv = nn.Sequential(*layers) 145 | 146 | def forward(self, x): 147 | if self.use_res_connect: 148 | return x + self.conv(x) 149 | else: 150 | return self.conv(x) 151 | 152 | 153 | class MobileNetV2(nn.Module): 154 | def __init__(self, pretrained=True): 155 | """ 156 | MobileNet V2 main class 157 | Args: 158 | num_classes (int): Number of classes 159 | width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount 160 | inverted_residual_setting: Network structure 161 | round_nearest (int): Round the number of channels in each layer to be a multiple of this number 162 | Set to 1 to turn off rounding 163 | block: Module specifying inverted residual building block for mobilenet 164 | """ 165 | super(MobileNetV2, self).__init__() 166 | 167 | block = InvertedResidual 168 | input_channel = 32 169 | last_channel = 1280 170 | width_mult = 1.0 171 | round_nearest = 8 172 | 173 | inverted_residual_setting = [ 174 | # t, c, n, s 175 | [1, 16, 1, 1], 176 | [6, 24, 2, 2], 177 | [6, 32, 3, 2], 178 | [6, 64, 4, 2], 179 | [6, 96, 3, 1], 180 | #[6, 160, 3, 2], 181 | #[6, 320, 1, 1], 182 | ] 183 | 184 | # only check the first element, assuming user knows t,c,n,s are required 185 | if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4: 186 | raise ValueError("inverted_residual_setting should be non-empty " 187 | "or a 4-element list, got {}".format(inverted_residual_setting)) 188 | 189 | # building first layer 190 | input_channel = _make_divisible(input_channel * width_mult, round_nearest) 191 | self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) 192 | features = [ConvBNReLU(4, input_channel, stride=2)] 193 | # building inverted residual blocks 194 | for t, c, n, s in inverted_residual_setting: 195 | output_channel = 
_make_divisible(c * width_mult, round_nearest) 196 | for i in range(n): 197 | stride = s if i == 0 else 1 198 | features.append(block(input_channel, output_channel, stride, expand_ratio=t)) 199 | input_channel = output_channel 200 | 201 | self.features = nn.Sequential(*features) 202 | self.fpn_selected = [1, 3, 6, 10, 13] 203 | # weight initialization 204 | for m in self.modules(): 205 | if isinstance(m, nn.Conv2d): 206 | nn.init.kaiming_normal_(m.weight, mode='fan_out') 207 | if m.bias is not None: 208 | nn.init.zeros_(m.bias) 209 | elif isinstance(m, nn.BatchNorm2d): 210 | nn.init.ones_(m.weight) 211 | nn.init.zeros_(m.bias) 212 | elif isinstance(m, nn.Linear): 213 | nn.init.normal_(m.weight, 0, 0.01) 214 | nn.init.zeros_(m.bias) 215 | if pretrained: 216 | self._load_pretrained_model() 217 | 218 | def _forward_impl(self, x): 219 | # This exists since TorchScript doesn't support inheritance, so the superclass method 220 | # (this one) needs to have a name other than `forward` that can be accessed in a subclass 221 | fpn_features = [] 222 | for i, f in enumerate(self.features): 223 | if i > self.fpn_selected[-1]: 224 | break 225 | x = f(x) 226 | if i in self.fpn_selected: 227 | fpn_features.append(x) 228 | 229 | c1, c2, c3, c4, c5 = fpn_features 230 | return c1, c2, c3, c4, c5 231 | 232 | 233 | def forward(self, x): 234 | return self._forward_impl(x) 235 | 236 | def _load_pretrained_model(self): 237 | pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/mobilenet_v2-b0353104.pth') 238 | model_dict = {} 239 | state_dict = self.state_dict() 240 | for k, v in pretrain_dict.items(): 241 | if k in state_dict: 242 | model_dict[k] = v 243 | state_dict.update(model_dict) 244 | self.load_state_dict(state_dict) 245 | 246 | 247 | class MobileV2_MLSD_Large(nn.Module): 248 | def __init__(self): 249 | super(MobileV2_MLSD_Large, self).__init__() 250 | 251 | self.backbone = MobileNetV2(pretrained=False) 252 | ## A, B 253 | self.block15 = BlockTypeA(in_c1= 64, in_c2= 96, 254 | out_c1= 64, out_c2=64, 255 | upscale=False) 256 | self.block16 = BlockTypeB(128, 64) 257 | 258 | ## A, B 259 | self.block17 = BlockTypeA(in_c1 = 32, in_c2 = 64, 260 | out_c1= 64, out_c2= 64) 261 | self.block18 = BlockTypeB(128, 64) 262 | 263 | ## A, B 264 | self.block19 = BlockTypeA(in_c1=24, in_c2=64, 265 | out_c1=64, out_c2=64) 266 | self.block20 = BlockTypeB(128, 64) 267 | 268 | ## A, B, C 269 | self.block21 = BlockTypeA(in_c1=16, in_c2=64, 270 | out_c1=64, out_c2=64) 271 | self.block22 = BlockTypeB(128, 64) 272 | 273 | self.block23 = BlockTypeC(64, 16) 274 | 275 | def forward(self, x): 276 | c1, c2, c3, c4, c5 = self.backbone(x) 277 | 278 | x = self.block15(c4, c5) 279 | x = self.block16(x) 280 | 281 | x = self.block17(c3, x) 282 | x = self.block18(x) 283 | 284 | x = self.block19(c2, x) 285 | x = self.block20(x) 286 | 287 | x = self.block21(c1, x) 288 | x = self.block22(x) 289 | x = self.block23(x) 290 | x = x[:, 7:, :, :] 291 | 292 | return x -------------------------------------------------------------------------------- /lib/model_zoo/controlnet_annotator/mlsd/models/mbv2_mlsd_tiny.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | import torch.nn as nn 5 | import torch.utils.model_zoo as model_zoo 6 | from torch.nn import functional as F 7 | 8 | 9 | class BlockTypeA(nn.Module): 10 | def __init__(self, in_c1, in_c2, out_c1, out_c2, upscale = True): 11 | super(BlockTypeA, self).__init__() 12 | self.conv1 = nn.Sequential( 13 | 
nn.Conv2d(in_c2, out_c2, kernel_size=1), 14 | nn.BatchNorm2d(out_c2), 15 | nn.ReLU(inplace=True) 16 | ) 17 | self.conv2 = nn.Sequential( 18 | nn.Conv2d(in_c1, out_c1, kernel_size=1), 19 | nn.BatchNorm2d(out_c1), 20 | nn.ReLU(inplace=True) 21 | ) 22 | self.upscale = upscale 23 | 24 | def forward(self, a, b): 25 | b = self.conv1(b) 26 | a = self.conv2(a) 27 | b = F.interpolate(b, scale_factor=2.0, mode='bilinear', align_corners=True) 28 | return torch.cat((a, b), dim=1) 29 | 30 | 31 | class BlockTypeB(nn.Module): 32 | def __init__(self, in_c, out_c): 33 | super(BlockTypeB, self).__init__() 34 | self.conv1 = nn.Sequential( 35 | nn.Conv2d(in_c, in_c, kernel_size=3, padding=1), 36 | nn.BatchNorm2d(in_c), 37 | nn.ReLU() 38 | ) 39 | self.conv2 = nn.Sequential( 40 | nn.Conv2d(in_c, out_c, kernel_size=3, padding=1), 41 | nn.BatchNorm2d(out_c), 42 | nn.ReLU() 43 | ) 44 | 45 | def forward(self, x): 46 | x = self.conv1(x) + x 47 | x = self.conv2(x) 48 | return x 49 | 50 | class BlockTypeC(nn.Module): 51 | def __init__(self, in_c, out_c): 52 | super(BlockTypeC, self).__init__() 53 | self.conv1 = nn.Sequential( 54 | nn.Conv2d(in_c, in_c, kernel_size=3, padding=5, dilation=5), 55 | nn.BatchNorm2d(in_c), 56 | nn.ReLU() 57 | ) 58 | self.conv2 = nn.Sequential( 59 | nn.Conv2d(in_c, in_c, kernel_size=3, padding=1), 60 | nn.BatchNorm2d(in_c), 61 | nn.ReLU() 62 | ) 63 | self.conv3 = nn.Conv2d(in_c, out_c, kernel_size=1) 64 | 65 | def forward(self, x): 66 | x = self.conv1(x) 67 | x = self.conv2(x) 68 | x = self.conv3(x) 69 | return x 70 | 71 | def _make_divisible(v, divisor, min_value=None): 72 | """ 73 | This function is taken from the original tf repo. 74 | It ensures that all layers have a channel number that is divisible by 8 75 | It can be seen here: 76 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py 77 | :param v: 78 | :param divisor: 79 | :param min_value: 80 | :return: 81 | """ 82 | if min_value is None: 83 | min_value = divisor 84 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 85 | # Make sure that round down does not go down by more than 10%. 
86 | if new_v < 0.9 * v: 87 | new_v += divisor 88 | return new_v 89 | 90 | 91 | class ConvBNReLU(nn.Sequential): 92 | def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1): 93 | self.channel_pad = out_planes - in_planes 94 | self.stride = stride 95 | #padding = (kernel_size - 1) // 2 96 | 97 | # TFLite uses slightly different padding than PyTorch 98 | if stride == 2: 99 | padding = 0 100 | else: 101 | padding = (kernel_size - 1) // 2 102 | 103 | super(ConvBNReLU, self).__init__( 104 | nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False), 105 | nn.BatchNorm2d(out_planes), 106 | nn.ReLU6(inplace=True) 107 | ) 108 | self.max_pool = nn.MaxPool2d(kernel_size=stride, stride=stride) 109 | 110 | 111 | def forward(self, x): 112 | # TFLite uses different padding 113 | if self.stride == 2: 114 | x = F.pad(x, (0, 1, 0, 1), "constant", 0) 115 | #print(x.shape) 116 | 117 | for module in self: 118 | if not isinstance(module, nn.MaxPool2d): 119 | x = module(x) 120 | return x 121 | 122 | 123 | class InvertedResidual(nn.Module): 124 | def __init__(self, inp, oup, stride, expand_ratio): 125 | super(InvertedResidual, self).__init__() 126 | self.stride = stride 127 | assert stride in [1, 2] 128 | 129 | hidden_dim = int(round(inp * expand_ratio)) 130 | self.use_res_connect = self.stride == 1 and inp == oup 131 | 132 | layers = [] 133 | if expand_ratio != 1: 134 | # pw 135 | layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1)) 136 | layers.extend([ 137 | # dw 138 | ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim), 139 | # pw-linear 140 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 141 | nn.BatchNorm2d(oup), 142 | ]) 143 | self.conv = nn.Sequential(*layers) 144 | 145 | def forward(self, x): 146 | if self.use_res_connect: 147 | return x + self.conv(x) 148 | else: 149 | return self.conv(x) 150 | 151 | 152 | class MobileNetV2(nn.Module): 153 | def __init__(self, pretrained=True): 154 | """ 155 | MobileNet V2 main class 156 | Args: 157 | num_classes (int): Number of classes 158 | width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount 159 | inverted_residual_setting: Network structure 160 | round_nearest (int): Round the number of channels in each layer to be a multiple of this number 161 | Set to 1 to turn off rounding 162 | block: Module specifying inverted residual building block for mobilenet 163 | """ 164 | super(MobileNetV2, self).__init__() 165 | 166 | block = InvertedResidual 167 | input_channel = 32 168 | last_channel = 1280 169 | width_mult = 1.0 170 | round_nearest = 8 171 | 172 | inverted_residual_setting = [ 173 | # t, c, n, s 174 | [1, 16, 1, 1], 175 | [6, 24, 2, 2], 176 | [6, 32, 3, 2], 177 | [6, 64, 4, 2], 178 | #[6, 96, 3, 1], 179 | #[6, 160, 3, 2], 180 | #[6, 320, 1, 1], 181 | ] 182 | 183 | # only check the first element, assuming user knows t,c,n,s are required 184 | if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4: 185 | raise ValueError("inverted_residual_setting should be non-empty " 186 | "or a 4-element list, got {}".format(inverted_residual_setting)) 187 | 188 | # building first layer 189 | input_channel = _make_divisible(input_channel * width_mult, round_nearest) 190 | self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) 191 | features = [ConvBNReLU(4, input_channel, stride=2)] 192 | # building inverted residual blocks 193 | for t, c, n, s in inverted_residual_setting: 194 | output_channel = 
_make_divisible(c * width_mult, round_nearest) 195 | for i in range(n): 196 | stride = s if i == 0 else 1 197 | features.append(block(input_channel, output_channel, stride, expand_ratio=t)) 198 | input_channel = output_channel 199 | self.features = nn.Sequential(*features) 200 | 201 | self.fpn_selected = [3, 6, 10] 202 | # weight initialization 203 | for m in self.modules(): 204 | if isinstance(m, nn.Conv2d): 205 | nn.init.kaiming_normal_(m.weight, mode='fan_out') 206 | if m.bias is not None: 207 | nn.init.zeros_(m.bias) 208 | elif isinstance(m, nn.BatchNorm2d): 209 | nn.init.ones_(m.weight) 210 | nn.init.zeros_(m.bias) 211 | elif isinstance(m, nn.Linear): 212 | nn.init.normal_(m.weight, 0, 0.01) 213 | nn.init.zeros_(m.bias) 214 | 215 | #if pretrained: 216 | # self._load_pretrained_model() 217 | 218 | def _forward_impl(self, x): 219 | # This exists since TorchScript doesn't support inheritance, so the superclass method 220 | # (this one) needs to have a name other than `forward` that can be accessed in a subclass 221 | fpn_features = [] 222 | for i, f in enumerate(self.features): 223 | if i > self.fpn_selected[-1]: 224 | break 225 | x = f(x) 226 | if i in self.fpn_selected: 227 | fpn_features.append(x) 228 | 229 | c2, c3, c4 = fpn_features 230 | return c2, c3, c4 231 | 232 | 233 | def forward(self, x): 234 | return self._forward_impl(x) 235 | 236 | def _load_pretrained_model(self): 237 | pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/mobilenet_v2-b0353104.pth') 238 | model_dict = {} 239 | state_dict = self.state_dict() 240 | for k, v in pretrain_dict.items(): 241 | if k in state_dict: 242 | model_dict[k] = v 243 | state_dict.update(model_dict) 244 | self.load_state_dict(state_dict) 245 | 246 | 247 | class MobileV2_MLSD_Tiny(nn.Module): 248 | def __init__(self): 249 | super(MobileV2_MLSD_Tiny, self).__init__() 250 | 251 | self.backbone = MobileNetV2(pretrained=True) 252 | 253 | self.block12 = BlockTypeA(in_c1= 32, in_c2= 64, 254 | out_c1= 64, out_c2=64) 255 | self.block13 = BlockTypeB(128, 64) 256 | 257 | self.block14 = BlockTypeA(in_c1 = 24, in_c2 = 64, 258 | out_c1= 32, out_c2= 32) 259 | self.block15 = BlockTypeB(64, 64) 260 | 261 | self.block16 = BlockTypeC(64, 16) 262 | 263 | def forward(self, x): 264 | c2, c3, c4 = self.backbone(x) 265 | 266 | x = self.block12(c3, c4) 267 | x = self.block13(x) 268 | x = self.block14(c2, x) 269 | x = self.block15(x) 270 | x = self.block16(x) 271 | x = x[:, 7:, :, :] 272 | #print(x.shape) 273 | x = F.interpolate(x, scale_factor=2.0, mode='bilinear', align_corners=True) 274 | 275 | return x -------------------------------------------------------------------------------- /lib/model_zoo/controlnet_annotator/openpose/LICENSE: -------------------------------------------------------------------------------- 1 | OPENPOSE: MULTIPERSON KEYPOINT DETECTION 2 | SOFTWARE LICENSE AGREEMENT 3 | ACADEMIC OR NON-PROFIT ORGANIZATION NONCOMMERCIAL RESEARCH USE ONLY 4 | 5 | BY USING OR DOWNLOADING THE SOFTWARE, YOU ARE AGREEING TO THE TERMS OF THIS LICENSE AGREEMENT. IF YOU DO NOT AGREE WITH THESE TERMS, YOU MAY NOT USE OR DOWNLOAD THE SOFTWARE. 6 | 7 | This is a license agreement ("Agreement") between your academic institution or non-profit organization or self (called "Licensee" or "You" in this Agreement) and Carnegie Mellon University (called "Licensor" in this Agreement). All rights not specifically granted to you in this Agreement are reserved for Licensor. 
8 | 9 | RESERVATION OF OWNERSHIP AND GRANT OF LICENSE: 10 | Licensor retains exclusive ownership of any copy of the Software (as defined below) licensed under this Agreement and hereby grants to Licensee a personal, non-exclusive, 11 | non-transferable license to use the Software for noncommercial research purposes, without the right to sublicense, pursuant to the terms and conditions of this Agreement. As used in this Agreement, the term "Software" means (i) the actual copy of all or any portion of code for program routines made accessible to Licensee by Licensor pursuant to this Agreement, inclusive of backups, updates, and/or merged copies permitted hereunder or subsequently supplied by Licensor, including all or any file structures, programming instructions, user interfaces and screen formats and sequences as well as any and all documentation and instructions related to it, and (ii) all or any derivatives and/or modifications created or made by You to any of the items specified in (i). 12 | 13 | CONFIDENTIALITY: Licensee acknowledges that the Software is proprietary to Licensor, and as such, Licensee agrees to receive all such materials in confidence and use the Software only in accordance with the terms of this Agreement. Licensee agrees to use reasonable effort to protect the Software from unauthorized use, reproduction, distribution, or publication. 14 | 15 | COPYRIGHT: The Software is owned by Licensor and is protected by United 16 | States copyright laws and applicable international treaties and/or conventions. 17 | 18 | PERMITTED USES: The Software may be used for your own noncommercial internal research purposes. You understand and agree that Licensor is not obligated to implement any suggestions and/or feedback you might provide regarding the Software, but to the extent Licensor does so, you are not entitled to any compensation related thereto. 19 | 20 | DERIVATIVES: You may create derivatives of or make modifications to the Software, however, You agree that all and any such derivatives and modifications will be owned by Licensor and become a part of the Software licensed to You under this Agreement. You may only use such derivatives and modifications for your own noncommercial internal research purposes, and you may not otherwise use, distribute or copy such derivatives and modifications in violation of this Agreement. 21 | 22 | BACKUPS: If Licensee is an organization, it may make that number of copies of the Software necessary for internal noncommercial use at a single site within its organization provided that all information appearing in or on the original labels, including the copyright and trademark notices are copied onto the labels of the copies. 23 | 24 | USES NOT PERMITTED: You may not distribute, copy or use the Software except as explicitly permitted herein. Licensee has not been granted any trademark license as part of this Agreement and may not use the name or mark “OpenPose", "Carnegie Mellon" or any renditions thereof without the prior written permission of Licensor. 25 | 26 | You may not sell, rent, lease, sublicense, lend, time-share or transfer, in whole or in part, or provide third parties access to prior or present versions (or any parts thereof) of the Software. 27 | 28 | ASSIGNMENT: You may not assign this Agreement or your rights hereunder without the prior written consent of Licensor. Any attempted assignment without such consent shall be null and void. 
29 | 30 | TERM: The term of the license granted by this Agreement is from Licensee's acceptance of this Agreement by downloading the Software or by using the Software until terminated as provided below. 31 | 32 | The Agreement automatically terminates without notice if you fail to comply with any provision of this Agreement. Licensee may terminate this Agreement by ceasing using the Software. Upon any termination of this Agreement, Licensee will delete any and all copies of the Software. You agree that all provisions which operate to protect the proprietary rights of Licensor shall remain in force should breach occur and that the obligation of confidentiality described in this Agreement is binding in perpetuity and, as such, survives the term of the Agreement. 33 | 34 | FEE: Provided Licensee abides completely by the terms and conditions of this Agreement, there is no fee due to Licensor for Licensee's use of the Software in accordance with this Agreement. 35 | 36 | DISCLAIMER OF WARRANTIES: THE SOFTWARE IS PROVIDED "AS-IS" WITHOUT WARRANTY OF ANY KIND INCLUDING ANY WARRANTIES OF PERFORMANCE OR MERCHANTABILITY OR FITNESS FOR A PARTICULAR USE OR PURPOSE OR OF NON-INFRINGEMENT. LICENSEE BEARS ALL RISK RELATING TO QUALITY AND PERFORMANCE OF THE SOFTWARE AND RELATED MATERIALS. 37 | 38 | SUPPORT AND MAINTENANCE: No Software support or training by the Licensor is provided as part of this Agreement. 39 | 40 | EXCLUSIVE REMEDY AND LIMITATION OF LIABILITY: To the maximum extent permitted under applicable law, Licensor shall not be liable for direct, indirect, special, incidental, or consequential damages or lost profits related to Licensee's use of and/or inability to use the Software, even if Licensor is advised of the possibility of such damage. 41 | 42 | EXPORT REGULATION: Licensee agrees to comply with any and all applicable 43 | U.S. export control laws, regulations, and/or other laws related to embargoes and sanction programs administered by the Office of Foreign Assets Control. 44 | 45 | SEVERABILITY: If any provision(s) of this Agreement shall be held to be invalid, illegal, or unenforceable by a court or other tribunal of competent jurisdiction, the validity, legality and enforceability of the remaining provisions shall not in any way be affected or impaired thereby. 46 | 47 | NO IMPLIED WAIVERS: No failure or delay by Licensor in enforcing any right or remedy under this Agreement shall be construed as a waiver of any future or other exercise of such right or remedy by Licensor. 48 | 49 | GOVERNING LAW: This Agreement shall be construed and enforced in accordance with the laws of the Commonwealth of Pennsylvania without reference to conflict of laws principles. You consent to the personal jurisdiction of the courts of this County and waive their rights to venue outside of Allegheny County, Pennsylvania. 50 | 51 | ENTIRE AGREEMENT AND AMENDMENTS: This Agreement constitutes the sole and entire agreement between Licensee and Licensor as to the matter set forth herein and supersedes any previous agreements, understandings, and arrangements between the parties relating hereto. 52 | 53 | 54 | 55 | ************************************************************************ 56 | 57 | THIRD-PARTY SOFTWARE NOTICES AND INFORMATION 58 | 59 | This project incorporates material from the project(s) listed below (collectively, "Third Party Code"). This Third Party Code is licensed to you under their original license terms set forth below. 
We reserves all other rights not expressly granted, whether by implication, estoppel or otherwise. 60 | 61 | 1. Caffe, version 1.0.0, (https://github.com/BVLC/caffe/) 62 | 63 | COPYRIGHT 64 | 65 | All contributions by the University of California: 66 | Copyright (c) 2014-2017 The Regents of the University of California (Regents) 67 | All rights reserved. 68 | 69 | All other contributions: 70 | Copyright (c) 2014-2017, the respective contributors 71 | All rights reserved. 72 | 73 | Caffe uses a shared copyright model: each contributor holds copyright over 74 | their contributions to Caffe. The project versioning records all such 75 | contribution and copyright details. If a contributor wants to further mark 76 | their specific copyright on a particular contribution, they should indicate 77 | their copyright solely in the commit message of the change when it is 78 | committed. 79 | 80 | LICENSE 81 | 82 | Redistribution and use in source and binary forms, with or without 83 | modification, are permitted provided that the following conditions are met: 84 | 85 | 1. Redistributions of source code must retain the above copyright notice, this 86 | list of conditions and the following disclaimer. 87 | 2. Redistributions in binary form must reproduce the above copyright notice, 88 | this list of conditions and the following disclaimer in the documentation 89 | and/or other materials provided with the distribution. 90 | 91 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 92 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 93 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 94 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 95 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 96 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 97 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 98 | ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 99 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 100 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 101 | 102 | CONTRIBUTION AGREEMENT 103 | 104 | By contributing to the BVLC/caffe repository through pull-request, comment, 105 | or otherwise, the contributor releases their content to the 106 | license and copyright terms herein. 107 | 108 | ************END OF THIRD-PARTY SOFTWARE NOTICES AND INFORMATION********** -------------------------------------------------------------------------------- /lib/model_zoo/controlnet_annotator/openpose/__init__.py: -------------------------------------------------------------------------------- 1 | # Openpose 2 | # Original from CMU https://github.com/CMU-Perceptual-Computing-Lab/openpose 3 | # 2nd Edited by https://github.com/Hzzone/pytorch-openpose 4 | # 3rd Edited by ControlNet 5 | # 4th Edited by ControlNet (added face and correct hands) 6 | # 5th Edited by ControlNet (Improved JSON serialization/deserialization, and lots of bug fixs) 7 | # This preprocessor is licensed by CMU for non-commercial use only. 8 | 9 | 10 | import os 11 | os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" 12 | 13 | import json 14 | import torch 15 | import numpy as np 16 | from . 
import util 17 | from .body import Body, BodyResult, Keypoint 18 | from .hand import Hand 19 | from .face import Face 20 | 21 | models_path = "pretrained/controlnet/preprocess" 22 | 23 | from typing import NamedTuple, Tuple, List, Callable, Union 24 | 25 | body_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/body_pose_model.pth" 26 | hand_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/hand_pose_model.pth" 27 | face_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/facenet.pth" 28 | 29 | HandResult = List[Keypoint] 30 | FaceResult = List[Keypoint] 31 | 32 | class PoseResult(NamedTuple): 33 | body: BodyResult 34 | left_hand: Union[HandResult, None] 35 | right_hand: Union[HandResult, None] 36 | face: Union[FaceResult, None] 37 | 38 | def draw_poses(poses: List[PoseResult], H, W, draw_body=True, draw_hand=True, draw_face=True): 39 | """ 40 | Draw the detected poses on an empty canvas. 41 | 42 | Args: 43 | poses (List[PoseResult]): A list of PoseResult objects containing the detected poses. 44 | H (int): The height of the canvas. 45 | W (int): The width of the canvas. 46 | draw_body (bool, optional): Whether to draw body keypoints. Defaults to True. 47 | draw_hand (bool, optional): Whether to draw hand keypoints. Defaults to True. 48 | draw_face (bool, optional): Whether to draw face keypoints. Defaults to True. 49 | 50 | Returns: 51 | numpy.ndarray: A 3D numpy array representing the canvas with the drawn poses. 52 | """ 53 | canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8) 54 | 55 | for pose in poses: 56 | if draw_body: 57 | canvas = util.draw_bodypose(canvas, pose.body.keypoints) 58 | 59 | if draw_hand: 60 | canvas = util.draw_handpose(canvas, pose.left_hand) 61 | canvas = util.draw_handpose(canvas, pose.right_hand) 62 | 63 | if draw_face: 64 | canvas = util.draw_facepose(canvas, pose.face) 65 | 66 | return canvas 67 | 68 | def encode_poses_as_json(poses: List[PoseResult], canvas_height: int, canvas_width: int) -> str: 69 | """ Encode the pose as a JSON string following openpose JSON output format: 70 | https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/doc/02_output.md 71 | """ 72 | def compress_keypoints(keypoints: Union[List[Keypoint], None]) -> Union[List[float], None]: 73 | if not keypoints: 74 | return None 75 | 76 | return [ 77 | value 78 | for keypoint in keypoints 79 | for value in ( 80 | [float(keypoint.x), float(keypoint.y), 1.0] 81 | if keypoint is not None 82 | else [0.0, 0.0, 0.0] 83 | ) 84 | ] 85 | 86 | return json.dumps({ 87 | 'people': [ 88 | { 89 | 'pose_keypoints_2d': compress_keypoints(pose.body.keypoints), 90 | "face_keypoints_2d": compress_keypoints(pose.face), 91 | "hand_left_keypoints_2d": compress_keypoints(pose.left_hand), 92 | "hand_right_keypoints_2d":compress_keypoints(pose.right_hand), 93 | } 94 | for pose in poses 95 | ], 96 | 'canvas_height': canvas_height, 97 | 'canvas_width': canvas_width, 98 | }, indent=4) 99 | 100 | def load_file_from_url(url, model_dir=None, progress=True, file_name=None): 101 | """Load file form http url, will download models if necessary. 102 | 103 | Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py 104 | 105 | Args: 106 | url (str): URL to be downloaded. 107 | model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir. 108 | Default: None. 109 | progress (bool): Whether to show the download progress. Default: True. 110 | file_name (str): The downloaded file name. 
If None, use the file name in the url. Default: None. 111 | 112 | Returns: 113 | str: The path to the downloaded file. 114 | """ 115 | from torch.hub import download_url_to_file, get_dir 116 | from urllib.parse import urlparse 117 | if model_dir is None: # use the pytorch hub_dir 118 | hub_dir = get_dir() 119 | model_dir = os.path.join(hub_dir, 'checkpoints') 120 | 121 | os.makedirs(model_dir, exist_ok=True) 122 | 123 | parts = urlparse(url) 124 | filename = os.path.basename(parts.path) 125 | if file_name is not None: 126 | filename = file_name 127 | cached_file = os.path.abspath(os.path.join(model_dir, filename)) 128 | if not os.path.exists(cached_file): 129 | print(f'Downloading: "{url}" to {cached_file}\n') 130 | download_url_to_file(url, cached_file, hash_prefix=None, progress=progress) 131 | return cached_file 132 | 133 | class OpenposeDetector: 134 | """ 135 | A class for detecting human poses in images using the Openpose model. 136 | 137 | Attributes: 138 | model_dir (str): Path to the directory where the pose models are stored. 139 | """ 140 | model_dir = os.path.join(models_path, "openpose") 141 | 142 | def __init__(self, device): 143 | self.device = device 144 | self.body_estimation = None 145 | self.hand_estimation = None 146 | self.face_estimation = None 147 | 148 | def load_model(self): 149 | """ 150 | Load the Openpose body, hand, and face models. 151 | """ 152 | body_modelpath = os.path.join(self.model_dir, "body_pose_model.pth") 153 | hand_modelpath = os.path.join(self.model_dir, "hand_pose_model.pth") 154 | face_modelpath = os.path.join(self.model_dir, "facenet.pth") 155 | 156 | if not os.path.exists(body_modelpath): 157 | load_file_from_url(body_model_path, model_dir=self.model_dir) 158 | 159 | if not os.path.exists(hand_modelpath): 160 | load_file_from_url(hand_model_path, model_dir=self.model_dir) 161 | 162 | if not os.path.exists(face_modelpath): 163 | load_file_from_url(face_model_path, model_dir=self.model_dir) 164 | 165 | self.body_estimation = Body(body_modelpath) 166 | self.hand_estimation = Hand(hand_modelpath) 167 | self.face_estimation = Face(face_modelpath) 168 | 169 | def unload_model(self): 170 | """ 171 | Unload the Openpose models by moving them to the CPU. 
172 | """ 173 | if self.body_estimation is not None: 174 | self.body_estimation.model.to("cpu") 175 | self.hand_estimation.model.to("cpu") 176 | self.face_estimation.model.to("cpu") 177 | 178 | def detect_hands(self, body: BodyResult, oriImg) -> Tuple[Union[HandResult, None], Union[HandResult, None]]: 179 | left_hand = None 180 | right_hand = None 181 | H, W, _ = oriImg.shape 182 | for x, y, w, is_left in util.handDetect(body, oriImg): 183 | peaks = self.hand_estimation(oriImg[y:y+w, x:x+w, :]).astype(np.float32) 184 | if peaks.ndim == 2 and peaks.shape[1] == 2: 185 | peaks[:, 0] = np.where(peaks[:, 0] < 1e-6, -1, peaks[:, 0] + x) / float(W) 186 | peaks[:, 1] = np.where(peaks[:, 1] < 1e-6, -1, peaks[:, 1] + y) / float(H) 187 | 188 | hand_result = [ 189 | Keypoint(x=peak[0], y=peak[1]) 190 | for peak in peaks 191 | ] 192 | 193 | if is_left: 194 | left_hand = hand_result 195 | else: 196 | right_hand = hand_result 197 | 198 | return left_hand, right_hand 199 | 200 | def detect_face(self, body: BodyResult, oriImg) -> Union[FaceResult, None]: 201 | face = util.faceDetect(body, oriImg) 202 | if face is None: 203 | return None 204 | 205 | x, y, w = face 206 | H, W, _ = oriImg.shape 207 | heatmaps = self.face_estimation(oriImg[y:y+w, x:x+w, :]) 208 | peaks = self.face_estimation.compute_peaks_from_heatmaps(heatmaps).astype(np.float32) 209 | if peaks.ndim == 2 and peaks.shape[1] == 2: 210 | peaks[:, 0] = np.where(peaks[:, 0] < 1e-6, -1, peaks[:, 0] + x) / float(W) 211 | peaks[:, 1] = np.where(peaks[:, 1] < 1e-6, -1, peaks[:, 1] + y) / float(H) 212 | return [ 213 | Keypoint(x=peak[0], y=peak[1]) 214 | for peak in peaks 215 | ] 216 | 217 | return None 218 | 219 | def detect_poses(self, oriImg, include_hand=False, include_face=False) -> List[PoseResult]: 220 | """ 221 | Detect poses in the given image. 222 | Args: 223 | oriImg (numpy.ndarray): The input image for pose detection. 224 | include_hand (bool, optional): Whether to include hand detection. Defaults to False. 225 | include_face (bool, optional): Whether to include face detection. Defaults to False. 226 | 227 | Returns: 228 | List[PoseResult]: A list of PoseResult objects containing the detected poses. 
229 | """ 230 | if self.body_estimation is None: 231 | self.load_model() 232 | 233 | self.body_estimation.model.to(self.device) 234 | self.hand_estimation.model.to(self.device) 235 | self.face_estimation.model.to(self.device) 236 | 237 | self.body_estimation.cn_device = self.device 238 | self.hand_estimation.cn_device = self.device 239 | self.face_estimation.cn_device = self.device 240 | 241 | oriImg = oriImg[:, :, ::-1].copy() 242 | H, W, C = oriImg.shape 243 | with torch.no_grad(): 244 | candidate, subset = self.body_estimation(oriImg) 245 | bodies = self.body_estimation.format_body_result(candidate, subset) 246 | 247 | results = [] 248 | for body in bodies: 249 | left_hand, right_hand, face = (None,) * 3 250 | if include_hand: 251 | left_hand, right_hand = self.detect_hands(body, oriImg) 252 | if include_face: 253 | face = self.detect_face(body, oriImg) 254 | 255 | results.append(PoseResult(BodyResult( 256 | keypoints=[ 257 | Keypoint( 258 | x=keypoint.x / float(W), 259 | y=keypoint.y / float(H) 260 | ) if keypoint is not None else None 261 | for keypoint in body.keypoints 262 | ], 263 | total_score=body.total_score, 264 | total_parts=body.total_parts 265 | ), left_hand, right_hand, face)) 266 | 267 | return results 268 | 269 | def __call__( 270 | self, oriImg, include_body=True, include_hand=False, include_face=False, 271 | json_pose_callback: Callable[[str], None] = None, 272 | ): 273 | """ 274 | Detect and draw poses in the given image. 275 | 276 | Args: 277 | oriImg (numpy.ndarray): The input image for pose detection and drawing. 278 | include_body (bool, optional): Whether to include body keypoints. Defaults to True. 279 | include_hand (bool, optional): Whether to include hand keypoints. Defaults to False. 280 | include_face (bool, optional): Whether to include face keypoints. Defaults to False. 281 | json_pose_callback (Callable, optional): A callback that accepts the pose JSON string. 282 | 283 | Returns: 284 | numpy.ndarray: The image with detected and drawn poses. 
285 | """ 286 | H, W, _ = oriImg.shape 287 | poses = self.detect_poses(oriImg, include_hand, include_face) 288 | if json_pose_callback: 289 | json_pose_callback(encode_poses_as_json(poses, H, W)) 290 | return draw_poses(poses, H, W, draw_body=include_body, draw_hand=include_hand, draw_face=include_face) 291 | 292 | class OpenposeModel(object): 293 | def __init__(self) -> None: 294 | self.model_openpose = None 295 | 296 | def run_model( 297 | self, 298 | img: np.ndarray, 299 | include_body: bool, 300 | include_hand: bool, 301 | include_face: bool, 302 | json_pose_callback: Callable[[str], None] = None, 303 | device = 'cpu', ): 304 | 305 | if json_pose_callback is None: 306 | json_pose_callback = lambda x: None 307 | 308 | if self.model_openpose is None: 309 | self.model_openpose = OpenposeDetector(device=device) 310 | 311 | return self.model_openpose( 312 | img, 313 | include_body=include_body, 314 | include_hand=include_hand, 315 | include_face=include_face, 316 | json_pose_callback=json_pose_callback) 317 | 318 | def unload(self): 319 | if self.model_openpose is not None: 320 | self.model_openpose.unload_model() 321 | -------------------------------------------------------------------------------- /lib/model_zoo/controlnet_annotator/openpose/hand.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import json 3 | import numpy as np 4 | import math 5 | import time 6 | from scipy.ndimage.filters import gaussian_filter 7 | import matplotlib.pyplot as plt 8 | import matplotlib 9 | import torch 10 | from skimage.measure import label 11 | 12 | from .model import handpose_model 13 | from . import util 14 | 15 | class Hand(object): 16 | def __init__(self, model_path): 17 | self.model = handpose_model() 18 | # if torch.cuda.is_available(): 19 | # self.model = self.model.cuda() 20 | # print('cuda') 21 | model_dict = util.transfer(self.model, torch.load(model_path)) 22 | self.model.load_state_dict(model_dict) 23 | self.model.eval() 24 | 25 | def __call__(self, oriImgRaw): 26 | scale_search = [0.5, 1.0, 1.5, 2.0] 27 | # scale_search = [0.5] 28 | boxsize = 368 29 | stride = 8 30 | padValue = 128 31 | thre = 0.05 32 | multiplier = [x * boxsize for x in scale_search] 33 | 34 | wsize = 128 35 | heatmap_avg = np.zeros((wsize, wsize, 22)) 36 | 37 | Hr, Wr, Cr = oriImgRaw.shape 38 | 39 | oriImg = cv2.GaussianBlur(oriImgRaw, (0, 0), 0.8) 40 | 41 | for m in range(len(multiplier)): 42 | scale = multiplier[m] 43 | imageToTest = util.smart_resize(oriImg, (scale, scale)) 44 | 45 | imageToTest_padded, pad = util.padRightDownCorner(imageToTest, stride, padValue) 46 | im = np.transpose(np.float32(imageToTest_padded[:, :, :, np.newaxis]), (3, 2, 0, 1)) / 256 - 0.5 47 | im = np.ascontiguousarray(im) 48 | 49 | data = torch.from_numpy(im).float() 50 | if torch.cuda.is_available(): 51 | data = data.cuda() 52 | 53 | with torch.no_grad(): 54 | data = data.to(self.cn_device) 55 | output = self.model(data).cpu().numpy() 56 | 57 | # extract outputs, resize, and remove padding 58 | heatmap = np.transpose(np.squeeze(output), (1, 2, 0)) # output 1 is heatmaps 59 | heatmap = util.smart_resize_k(heatmap, fx=stride, fy=stride) 60 | heatmap = heatmap[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :] 61 | heatmap = util.smart_resize(heatmap, (wsize, wsize)) 62 | 63 | heatmap_avg += heatmap / len(multiplier) 64 | 65 | all_peaks = [] 66 | for part in range(21): 67 | map_ori = heatmap_avg[:, :, part] 68 | one_heatmap = gaussian_filter(map_ori, sigma=3) 69 
| binary = np.ascontiguousarray(one_heatmap > thre, dtype=np.uint8) 70 | 71 | if np.sum(binary) == 0: 72 | all_peaks.append([0, 0]) 73 | continue 74 | label_img, label_numbers = label(binary, return_num=True, connectivity=binary.ndim) 75 | max_index = np.argmax([np.sum(map_ori[label_img == i]) for i in range(1, label_numbers + 1)]) + 1 76 | label_img[label_img != max_index] = 0 77 | map_ori[label_img == 0] = 0 78 | 79 | y, x = util.npmax(map_ori) 80 | y = int(float(y) * float(Hr) / float(wsize)) 81 | x = int(float(x) * float(Wr) / float(wsize)) 82 | all_peaks.append([x, y]) 83 | return np.array(all_peaks) 84 | 85 | if __name__ == "__main__": 86 | hand_estimation = Hand('../model/hand_pose_model.pth') 87 | 88 | # test_image = '../images/hand.jpg' 89 | test_image = '../images/hand.jpg' 90 | oriImg = cv2.imread(test_image) # B,G,R order 91 | peaks = hand_estimation(oriImg) 92 | canvas = util.draw_handpose(oriImg, peaks, True) 93 | cv2.imshow('', canvas) 94 | cv2.waitKey(0) -------------------------------------------------------------------------------- /lib/model_zoo/controlnet_annotator/openpose/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from collections import OrderedDict 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | def make_layers(block, no_relu_layers): 8 | layers = [] 9 | for layer_name, v in block.items(): 10 | if 'pool' in layer_name: 11 | layer = nn.MaxPool2d(kernel_size=v[0], stride=v[1], 12 | padding=v[2]) 13 | layers.append((layer_name, layer)) 14 | else: 15 | conv2d = nn.Conv2d(in_channels=v[0], out_channels=v[1], 16 | kernel_size=v[2], stride=v[3], 17 | padding=v[4]) 18 | layers.append((layer_name, conv2d)) 19 | if layer_name not in no_relu_layers: 20 | layers.append(('relu_'+layer_name, nn.ReLU(inplace=True))) 21 | 22 | return nn.Sequential(OrderedDict(layers)) 23 | 24 | class bodypose_model(nn.Module): 25 | def __init__(self): 26 | super(bodypose_model, self).__init__() 27 | 28 | # these layers have no relu layer 29 | no_relu_layers = ['conv5_5_CPM_L1', 'conv5_5_CPM_L2', 'Mconv7_stage2_L1',\ 30 | 'Mconv7_stage2_L2', 'Mconv7_stage3_L1', 'Mconv7_stage3_L2',\ 31 | 'Mconv7_stage4_L1', 'Mconv7_stage4_L2', 'Mconv7_stage5_L1',\ 32 | 'Mconv7_stage5_L2', 'Mconv7_stage6_L1', 'Mconv7_stage6_L1'] 33 | blocks = {} 34 | block0 = OrderedDict([ 35 | ('conv1_1', [3, 64, 3, 1, 1]), 36 | ('conv1_2', [64, 64, 3, 1, 1]), 37 | ('pool1_stage1', [2, 2, 0]), 38 | ('conv2_1', [64, 128, 3, 1, 1]), 39 | ('conv2_2', [128, 128, 3, 1, 1]), 40 | ('pool2_stage1', [2, 2, 0]), 41 | ('conv3_1', [128, 256, 3, 1, 1]), 42 | ('conv3_2', [256, 256, 3, 1, 1]), 43 | ('conv3_3', [256, 256, 3, 1, 1]), 44 | ('conv3_4', [256, 256, 3, 1, 1]), 45 | ('pool3_stage1', [2, 2, 0]), 46 | ('conv4_1', [256, 512, 3, 1, 1]), 47 | ('conv4_2', [512, 512, 3, 1, 1]), 48 | ('conv4_3_CPM', [512, 256, 3, 1, 1]), 49 | ('conv4_4_CPM', [256, 128, 3, 1, 1]) 50 | ]) 51 | 52 | 53 | # Stage 1 54 | block1_1 = OrderedDict([ 55 | ('conv5_1_CPM_L1', [128, 128, 3, 1, 1]), 56 | ('conv5_2_CPM_L1', [128, 128, 3, 1, 1]), 57 | ('conv5_3_CPM_L1', [128, 128, 3, 1, 1]), 58 | ('conv5_4_CPM_L1', [128, 512, 1, 1, 0]), 59 | ('conv5_5_CPM_L1', [512, 38, 1, 1, 0]) 60 | ]) 61 | 62 | block1_2 = OrderedDict([ 63 | ('conv5_1_CPM_L2', [128, 128, 3, 1, 1]), 64 | ('conv5_2_CPM_L2', [128, 128, 3, 1, 1]), 65 | ('conv5_3_CPM_L2', [128, 128, 3, 1, 1]), 66 | ('conv5_4_CPM_L2', [128, 512, 1, 1, 0]), 67 | ('conv5_5_CPM_L2', [512, 19, 1, 1, 0]) 68 | ]) 69 | blocks['block1_1'] = block1_1 70 | 
blocks['block1_2'] = block1_2 71 | 72 | self.model0 = make_layers(block0, no_relu_layers) 73 | 74 | # Stages 2 - 6 75 | for i in range(2, 7): 76 | blocks['block%d_1' % i] = OrderedDict([ 77 | ('Mconv1_stage%d_L1' % i, [185, 128, 7, 1, 3]), 78 | ('Mconv2_stage%d_L1' % i, [128, 128, 7, 1, 3]), 79 | ('Mconv3_stage%d_L1' % i, [128, 128, 7, 1, 3]), 80 | ('Mconv4_stage%d_L1' % i, [128, 128, 7, 1, 3]), 81 | ('Mconv5_stage%d_L1' % i, [128, 128, 7, 1, 3]), 82 | ('Mconv6_stage%d_L1' % i, [128, 128, 1, 1, 0]), 83 | ('Mconv7_stage%d_L1' % i, [128, 38, 1, 1, 0]) 84 | ]) 85 | 86 | blocks['block%d_2' % i] = OrderedDict([ 87 | ('Mconv1_stage%d_L2' % i, [185, 128, 7, 1, 3]), 88 | ('Mconv2_stage%d_L2' % i, [128, 128, 7, 1, 3]), 89 | ('Mconv3_stage%d_L2' % i, [128, 128, 7, 1, 3]), 90 | ('Mconv4_stage%d_L2' % i, [128, 128, 7, 1, 3]), 91 | ('Mconv5_stage%d_L2' % i, [128, 128, 7, 1, 3]), 92 | ('Mconv6_stage%d_L2' % i, [128, 128, 1, 1, 0]), 93 | ('Mconv7_stage%d_L2' % i, [128, 19, 1, 1, 0]) 94 | ]) 95 | 96 | for k in blocks.keys(): 97 | blocks[k] = make_layers(blocks[k], no_relu_layers) 98 | 99 | self.model1_1 = blocks['block1_1'] 100 | self.model2_1 = blocks['block2_1'] 101 | self.model3_1 = blocks['block3_1'] 102 | self.model4_1 = blocks['block4_1'] 103 | self.model5_1 = blocks['block5_1'] 104 | self.model6_1 = blocks['block6_1'] 105 | 106 | self.model1_2 = blocks['block1_2'] 107 | self.model2_2 = blocks['block2_2'] 108 | self.model3_2 = blocks['block3_2'] 109 | self.model4_2 = blocks['block4_2'] 110 | self.model5_2 = blocks['block5_2'] 111 | self.model6_2 = blocks['block6_2'] 112 | 113 | 114 | def forward(self, x): 115 | 116 | out1 = self.model0(x) 117 | 118 | out1_1 = self.model1_1(out1) 119 | out1_2 = self.model1_2(out1) 120 | out2 = torch.cat([out1_1, out1_2, out1], 1) 121 | 122 | out2_1 = self.model2_1(out2) 123 | out2_2 = self.model2_2(out2) 124 | out3 = torch.cat([out2_1, out2_2, out1], 1) 125 | 126 | out3_1 = self.model3_1(out3) 127 | out3_2 = self.model3_2(out3) 128 | out4 = torch.cat([out3_1, out3_2, out1], 1) 129 | 130 | out4_1 = self.model4_1(out4) 131 | out4_2 = self.model4_2(out4) 132 | out5 = torch.cat([out4_1, out4_2, out1], 1) 133 | 134 | out5_1 = self.model5_1(out5) 135 | out5_2 = self.model5_2(out5) 136 | out6 = torch.cat([out5_1, out5_2, out1], 1) 137 | 138 | out6_1 = self.model6_1(out6) 139 | out6_2 = self.model6_2(out6) 140 | 141 | return out6_1, out6_2 142 | 143 | class handpose_model(nn.Module): 144 | def __init__(self): 145 | super(handpose_model, self).__init__() 146 | 147 | # these layers have no relu layer 148 | no_relu_layers = ['conv6_2_CPM', 'Mconv7_stage2', 'Mconv7_stage3',\ 149 | 'Mconv7_stage4', 'Mconv7_stage5', 'Mconv7_stage6'] 150 | # stage 1 151 | block1_0 = OrderedDict([ 152 | ('conv1_1', [3, 64, 3, 1, 1]), 153 | ('conv1_2', [64, 64, 3, 1, 1]), 154 | ('pool1_stage1', [2, 2, 0]), 155 | ('conv2_1', [64, 128, 3, 1, 1]), 156 | ('conv2_2', [128, 128, 3, 1, 1]), 157 | ('pool2_stage1', [2, 2, 0]), 158 | ('conv3_1', [128, 256, 3, 1, 1]), 159 | ('conv3_2', [256, 256, 3, 1, 1]), 160 | ('conv3_3', [256, 256, 3, 1, 1]), 161 | ('conv3_4', [256, 256, 3, 1, 1]), 162 | ('pool3_stage1', [2, 2, 0]), 163 | ('conv4_1', [256, 512, 3, 1, 1]), 164 | ('conv4_2', [512, 512, 3, 1, 1]), 165 | ('conv4_3', [512, 512, 3, 1, 1]), 166 | ('conv4_4', [512, 512, 3, 1, 1]), 167 | ('conv5_1', [512, 512, 3, 1, 1]), 168 | ('conv5_2', [512, 512, 3, 1, 1]), 169 | ('conv5_3_CPM', [512, 128, 3, 1, 1]) 170 | ]) 171 | 172 | block1_1 = OrderedDict([ 173 | ('conv6_1_CPM', [128, 512, 1, 1, 0]), 174 | ('conv6_2_CPM', 
[512, 22, 1, 1, 0]) 175 | ]) 176 | 177 | blocks = {} 178 | blocks['block1_0'] = block1_0 179 | blocks['block1_1'] = block1_1 180 | 181 | # stage 2-6 182 | for i in range(2, 7): 183 | blocks['block%d' % i] = OrderedDict([ 184 | ('Mconv1_stage%d' % i, [150, 128, 7, 1, 3]), 185 | ('Mconv2_stage%d' % i, [128, 128, 7, 1, 3]), 186 | ('Mconv3_stage%d' % i, [128, 128, 7, 1, 3]), 187 | ('Mconv4_stage%d' % i, [128, 128, 7, 1, 3]), 188 | ('Mconv5_stage%d' % i, [128, 128, 7, 1, 3]), 189 | ('Mconv6_stage%d' % i, [128, 128, 1, 1, 0]), 190 | ('Mconv7_stage%d' % i, [128, 22, 1, 1, 0]) 191 | ]) 192 | 193 | for k in blocks.keys(): 194 | blocks[k] = make_layers(blocks[k], no_relu_layers) 195 | 196 | self.model1_0 = blocks['block1_0'] 197 | self.model1_1 = blocks['block1_1'] 198 | self.model2 = blocks['block2'] 199 | self.model3 = blocks['block3'] 200 | self.model4 = blocks['block4'] 201 | self.model5 = blocks['block5'] 202 | self.model6 = blocks['block6'] 203 | 204 | def forward(self, x): 205 | out1_0 = self.model1_0(x) 206 | out1_1 = self.model1_1(out1_0) 207 | concat_stage2 = torch.cat([out1_1, out1_0], 1) 208 | out_stage2 = self.model2(concat_stage2) 209 | concat_stage3 = torch.cat([out_stage2, out1_0], 1) 210 | out_stage3 = self.model3(concat_stage3) 211 | concat_stage4 = torch.cat([out_stage3, out1_0], 1) 212 | out_stage4 = self.model4(concat_stage4) 213 | concat_stage5 = torch.cat([out_stage4, out1_0], 1) 214 | out_stage5 = self.model5(concat_stage5) 215 | concat_stage6 = torch.cat([out_stage5, out1_0], 1) 216 | out_stage6 = self.model6(concat_stage6) 217 | return out_stage6 218 | 219 | -------------------------------------------------------------------------------- /lib/model_zoo/controlnet_annotator/pidinet/LICENSE: -------------------------------------------------------------------------------- 1 | It is just for research purpose, and commercial use should be contacted with authors first. 2 | 3 | Copyright (c) 2021 Zhuo Su 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
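
A quick sanity check of the two CPM-style networks defined in openpose/model.py above can be done without any pretrained weights, since both are plain feed-forward models. The sketch below is not part of the repository: the import path simply mirrors the directory layout (and assumes the repo root is on PYTHONPATH), and the 184x184 input size is an arbitrary assumption. It shows that bodypose_model returns a 38-channel PAF map and a 19-channel keypoint heatmap, while handpose_model returns a 22-channel hand heatmap, all at 1/8 of the input resolution (three stride-2 poolings).

import torch
# import path mirrors the directory layout shown in this repo; adjust if needed
from lib.model_zoo.controlnet_annotator.openpose.model import bodypose_model, handpose_model

body = bodypose_model().eval()
hand = handpose_model().eval()
with torch.no_grad():
    dummy = torch.randn(1, 3, 184, 184)   # B, C, H, W; the size is an arbitrary choice
    pafs, heatmaps = body(dummy)          # stage-6 outputs: L1 branch (PAFs), L2 branch (heatmaps)
    hand_maps = hand(dummy)               # stage-6 hand heatmaps

print(pafs.shape)       # torch.Size([1, 38, 23, 23])
print(heatmaps.shape)   # torch.Size([1, 19, 23, 23])
print(hand_maps.shape)  # torch.Size([1, 22, 23, 23])
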
-------------------------------------------------------------------------------- /lib/model_zoo/controlnet_annotator/pidinet/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from einops import rearrange 5 | from .model import pidinet 6 | 7 | models_path = 'pretrained/controlnet/preprocess' 8 | 9 | netNetwork = None 10 | remote_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/table5_pidinet.pth" 11 | modeldir = os.path.join(models_path, "pidinet") 12 | old_modeldir = os.path.dirname(os.path.realpath(__file__)) 13 | 14 | def safe_step(x, step=2): 15 | y = x.astype(np.float32) * float(step + 1) 16 | y = y.astype(np.int32).astype(np.float32) / float(step) 17 | return y 18 | 19 | def load_file_from_url(url, model_dir=None, progress=True, file_name=None): 20 | """Load file form http url, will download models if necessary. 21 | 22 | Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py 23 | 24 | Args: 25 | url (str): URL to be downloaded. 26 | model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir. 27 | Default: None. 28 | progress (bool): Whether to show the download progress. Default: True. 29 | file_name (str): The downloaded file name. If None, use the file name in the url. Default: None. 30 | 31 | Returns: 32 | str: The path to the downloaded file. 33 | """ 34 | from torch.hub import download_url_to_file, get_dir 35 | from urllib.parse import urlparse 36 | if model_dir is None: # use the pytorch hub_dir 37 | hub_dir = get_dir() 38 | model_dir = os.path.join(hub_dir, 'checkpoints') 39 | 40 | os.makedirs(model_dir, exist_ok=True) 41 | 42 | parts = urlparse(url) 43 | filename = os.path.basename(parts.path) 44 | if file_name is not None: 45 | filename = file_name 46 | cached_file = os.path.abspath(os.path.join(model_dir, filename)) 47 | if not os.path.exists(cached_file): 48 | print(f'Downloading: "{url}" to {cached_file}\n') 49 | download_url_to_file(url, cached_file, hash_prefix=None, progress=progress) 50 | return cached_file 51 | 52 | def load_state_dict(ckpt_path, location='cpu'): 53 | def get_state_dict(d): 54 | return d.get('state_dict', d) 55 | 56 | _, extension = os.path.splitext(ckpt_path) 57 | if extension.lower() == ".safetensors": 58 | import safetensors.torch 59 | state_dict = safetensors.torch.load_file(ckpt_path, device=location) 60 | else: 61 | state_dict = get_state_dict(torch.load( 62 | ckpt_path, map_location=torch.device(location))) 63 | state_dict = get_state_dict(state_dict) 64 | print(f'Loaded state_dict from [{ckpt_path}]') 65 | return state_dict 66 | 67 | def apply_pidinet(input_image, is_safe=False, apply_fliter=False, device='cpu'): 68 | global netNetwork 69 | if netNetwork is None: 70 | modelpath = os.path.join(modeldir, "table5_pidinet.pth") 71 | old_modelpath = os.path.join(old_modeldir, "table5_pidinet.pth") 72 | if os.path.exists(old_modelpath): 73 | modelpath = old_modelpath 74 | elif not os.path.exists(modelpath): 75 | load_file_from_url(remote_model_path, model_dir=modeldir) 76 | netNetwork = pidinet() 77 | ckp = load_state_dict(modelpath) 78 | netNetwork.load_state_dict({k.replace('module.',''):v for k, v in ckp.items()}) 79 | 80 | netNetwork = netNetwork.to(device) 81 | netNetwork.eval() 82 | assert input_image.ndim == 3 83 | input_image = input_image[:, :, ::-1].copy() 84 | with torch.no_grad(): 85 | image_pidi = 
torch.from_numpy(input_image).float().to(device) 86 | image_pidi = image_pidi / 255.0 87 | image_pidi = rearrange(image_pidi, 'h w c -> 1 c h w') 88 | edge = netNetwork(image_pidi)[-1] 89 | edge = edge.cpu().numpy() 90 | if apply_fliter: 91 | edge = edge > 0.5 92 | if is_safe: 93 | edge = safe_step(edge) 94 | edge = (edge * 255.0).clip(0, 255).astype(np.uint8) 95 | 96 | return edge[0][0] 97 | 98 | def unload_pid_model(): 99 | global netNetwork 100 | if netNetwork is not None: 101 | netNetwork.cpu() -------------------------------------------------------------------------------- /lib/model_zoo/diffusion_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | import numpy as np 6 | from einops import repeat 7 | 8 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): 9 | if schedule == "linear": 10 | betas = ( 11 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 12 | ) 13 | 14 | elif schedule == "cosine": 15 | timesteps = ( 16 | torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s 17 | ) 18 | alphas = timesteps / (1 + cosine_s) * np.pi / 2 19 | alphas = torch.cos(alphas).pow(2) 20 | alphas = alphas / alphas[0] 21 | betas = 1 - alphas[1:] / alphas[:-1] 22 | betas = np.clip(betas, a_min=0, a_max=0.999) 23 | 24 | elif schedule == "sqrt_linear": 25 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) 26 | elif schedule == "sqrt": 27 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 28 | else: 29 | raise ValueError(f"schedule '{schedule}' unknown.") 30 | return betas.numpy() 31 | 32 | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): 33 | if ddim_discr_method == 'uniform': 34 | c = num_ddpm_timesteps // num_ddim_timesteps 35 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) 36 | elif ddim_discr_method == 'quad': 37 | ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) 38 | else: 39 | raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') 40 | 41 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps 42 | # add one to get the final alpha values right (the ones from first scale to data during sampling) 43 | steps_out = ddim_timesteps + 1 44 | if verbose: 45 | print(f'Selected timesteps for ddim sampler: {steps_out}') 46 | return steps_out 47 | 48 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): 49 | # select alphas for computing the variance schedule 50 | alphas = alphacums[ddim_timesteps] 51 | alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) 52 | 53 | # according the the formula provided in https://arxiv.org/abs/2010.02502 54 | sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) 55 | if verbose: 56 | print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') 57 | print(f'For the chosen value of eta, which is {eta}, ' 58 | f'this results in the following sigma_t schedule for ddim sampler {sigmas}') 59 | return sigmas, alphas, alphas_prev 60 | 61 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): 62 | """ 63 | Create a beta schedule that discretizes the given alpha_t_bar function, 64 | which 
defines the cumulative product of (1-beta) over time from t = [0,1]. 65 | :param num_diffusion_timesteps: the number of betas to produce. 66 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and 67 | produces the cumulative product of (1-beta) up to that 68 | part of the diffusion process. 69 | :param max_beta: the maximum beta to use; use values lower than 1 to 70 | prevent singularities. 71 | """ 72 | betas = [] 73 | for i in range(num_diffusion_timesteps): 74 | t1 = i / num_diffusion_timesteps 75 | t2 = (i + 1) / num_diffusion_timesteps 76 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 77 | return np.array(betas) 78 | 79 | def extract_into_tensor(a, t, x_shape): 80 | b, *_ = t.shape 81 | out = a.gather(-1, t) 82 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 83 | 84 | def checkpoint(func, inputs, params, flag): 85 | """ 86 | Evaluate a function without caching intermediate activations, allowing for 87 | reduced memory at the expense of extra compute in the backward pass. 88 | :param func: the function to evaluate. 89 | :param inputs: the argument sequence to pass to `func`. 90 | :param params: a sequence of parameters `func` depends on but does not 91 | explicitly take as arguments. 92 | :param flag: if False, disable gradient checkpointing. 93 | """ 94 | if flag: 95 | args = tuple(inputs) + tuple(params) 96 | return CheckpointFunction.apply(func, len(inputs), *args) 97 | else: 98 | return func(*inputs) 99 | 100 | class CheckpointFunction(torch.autograd.Function): 101 | @staticmethod 102 | def forward(ctx, run_function, length, *args): 103 | ctx.run_function = run_function 104 | ctx.input_tensors = list(args[:length]) 105 | ctx.input_params = list(args[length:]) 106 | 107 | with torch.no_grad(): 108 | output_tensors = ctx.run_function(*ctx.input_tensors) 109 | return output_tensors 110 | 111 | @staticmethod 112 | def backward(ctx, *output_grads): 113 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 114 | with torch.enable_grad(): 115 | # Fixes a bug where the first op in run_function modifies the 116 | # Tensor storage in place, which is not allowed for detach()'d 117 | # Tensors. 118 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 119 | output_tensors = ctx.run_function(*shallow_copies) 120 | input_grads = torch.autograd.grad( 121 | output_tensors, 122 | ctx.input_tensors + ctx.input_params, 123 | output_grads, 124 | allow_unused=True, 125 | ) 126 | del ctx.input_tensors 127 | del ctx.input_params 128 | del output_tensors 129 | return (None, None) + input_grads 130 | 131 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): 132 | """ 133 | Create sinusoidal timestep embeddings. 134 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 135 | These may be fractional. 136 | :param dim: the dimension of the output. 137 | :param max_period: controls the minimum frequency of the embeddings. 138 | :return: an [N x dim] Tensor of positional embeddings. 
139 | """ 140 | if not repeat_only: 141 | half = dim // 2 142 | freqs = torch.exp( 143 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 144 | ).to(device=timesteps.device) 145 | args = timesteps[:, None].float() * freqs[None] 146 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 147 | if dim % 2: 148 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 149 | else: 150 | embedding = repeat(timesteps, 'b -> b d', d=dim) 151 | return embedding 152 | 153 | def zero_module(module): 154 | """ 155 | Zero out the parameters of a module and return it. 156 | """ 157 | for p in module.parameters(): 158 | p.detach().zero_() 159 | return module 160 | 161 | def scale_module(module, scale): 162 | """ 163 | Scale the parameters of a module and return it. 164 | """ 165 | for p in module.parameters(): 166 | p.detach().mul_(scale) 167 | return module 168 | 169 | def mean_flat(tensor): 170 | """ 171 | Take the mean over all non-batch dimensions. 172 | """ 173 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 174 | 175 | def normalization(channels): 176 | """ 177 | Make a standard normalization layer. 178 | :param channels: number of input channels. 179 | :return: an nn.Module for normalization. 180 | """ 181 | return GroupNorm32(32, channels) 182 | 183 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 184 | class SiLU(nn.Module): 185 | def forward(self, x): 186 | return x * torch.sigmoid(x) 187 | 188 | class GroupNorm32(nn.GroupNorm): 189 | def forward(self, x): 190 | # return super().forward(x.float()).type(x.dtype) 191 | return super().forward(x) 192 | 193 | def conv_nd(dims, *args, **kwargs): 194 | """ 195 | Create a 1D, 2D, or 3D convolution module. 196 | """ 197 | if dims == 1: 198 | return nn.Conv1d(*args, **kwargs) 199 | elif dims == 2: 200 | return nn.Conv2d(*args, **kwargs) 201 | elif dims == 3: 202 | return nn.Conv3d(*args, **kwargs) 203 | raise ValueError(f"unsupported dimensions: {dims}") 204 | 205 | def linear(*args, **kwargs): 206 | """ 207 | Create a linear module. 208 | """ 209 | return nn.Linear(*args, **kwargs) 210 | 211 | def avg_pool_nd(dims, *args, **kwargs): 212 | """ 213 | Create a 1D, 2D, or 3D average pooling module. 
214 | """ 215 | if dims == 1: 216 | return nn.AvgPool1d(*args, **kwargs) 217 | elif dims == 2: 218 | return nn.AvgPool2d(*args, **kwargs) 219 | elif dims == 3: 220 | return nn.AvgPool3d(*args, **kwargs) 221 | raise ValueError(f"unsupported dimensions: {dims}") 222 | 223 | class HybridConditioner(nn.Module): 224 | 225 | def __init__(self, c_concat_config, c_crossattn_config): 226 | super().__init__() 227 | self.concat_conditioner = instantiate_from_config(c_concat_config) 228 | self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) 229 | 230 | def forward(self, c_concat, c_crossattn): 231 | c_concat = self.concat_conditioner(c_concat) 232 | c_crossattn = self.crossattn_conditioner(c_crossattn) 233 | return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} 234 | 235 | def noise_like(x, repeat=False): 236 | noise = torch.randn_like(x) 237 | if repeat: 238 | bs = x.shape[0] 239 | noise = noise[0:1].repeat(bs, *((1,) * (len(x.shape) - 1))) 240 | return noise 241 | 242 | ########################## 243 | # inherit from ldm.utils # 244 | ########################## 245 | 246 | def count_params(model, verbose=False): 247 | total_params = sum(p.numel() for p in model.parameters()) 248 | if verbose: 249 | print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.") 250 | return total_params 251 | -------------------------------------------------------------------------------- /lib/model_zoo/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self): 36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 37 | return x 38 | 39 | def kl(self, other=None): 40 | if self.deterministic: 41 | return torch.Tensor([0.]) 42 | else: 43 | if other is None: 44 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 45 | + self.var - 1.0 - self.logvar, 46 | dim=[1, 2, 3]) 47 | else: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean - other.mean, 2) / other.var 50 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 51 | dim=[1, 2, 3]) 52 | 53 | def nll(self, sample, dims=[1,2,3]): 54 | if self.deterministic: 55 | return torch.Tensor([0.]) 56 | logtwopi = np.log(2.0 * np.pi) 57 | return 0.5 * torch.sum( 58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 59 | dim=dims) 60 | 61 | def mode(self): 62 | return self.mean 63 | 64 | 65 | def normal_kl(mean1, logvar1, mean2, logvar2): 66 | """ 67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 68 | Compute the KL 
divergence between two gaussians. 69 | Shapes are automatically broadcasted, so batches can be compared to 70 | scalars, among other use cases. 71 | """ 72 | tensor = None 73 | for obj in (mean1, logvar1, mean2, logvar2): 74 | if isinstance(obj, torch.Tensor): 75 | tensor = obj 76 | break 77 | assert tensor is not None, "at least one argument must be a Tensor" 78 | 79 | # Force variances to be Tensors. Broadcasting helps convert scalars to 80 | # Tensors, but it does not work for torch.exp(). 81 | logvar1, logvar2 = [ 82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 83 | for x in (logvar1, logvar2) 84 | ] 85 | 86 | return 0.5 * ( 87 | -1.0 88 | + logvar2 89 | - logvar1 90 | + torch.exp(logvar1 - logvar2) 91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 92 | ) 93 | -------------------------------------------------------------------------------- /lib/model_zoo/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | class LitEma(nn.Module): 5 | def __init__(self, model, decay=0.9999, use_num_updates=True): 6 | super().__init__() 7 | if decay < 0.0 or decay > 1.0: 8 | raise ValueError('Decay must be between 0 and 1') 9 | 10 | self.m_name2s_name = {} 11 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 12 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_updates 13 | else torch.tensor(-1,dtype=torch.int)) 14 | 15 | for name, p in model.named_parameters(): 16 | if p.requires_grad: 17 | #remove as '.'-character is not allowed in buffers 18 | s_name = name.replace('.','') 19 | self.m_name2s_name.update({name:s_name}) 20 | self.register_buffer(s_name,p.clone().detach().data) 21 | 22 | self.collected_params = [] 23 | 24 | def forward(self, model): 25 | decay = self.decay 26 | 27 | if self.num_updates >= 0: 28 | self.num_updates += 1 29 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) 30 | 31 | one_minus_decay = 1.0 - decay 32 | 33 | with torch.no_grad(): 34 | m_param = dict(model.named_parameters()) 35 | shadow_params = dict(self.named_buffers()) 36 | 37 | for key in m_param: 38 | if m_param[key].requires_grad: 39 | sname = self.m_name2s_name[key] 40 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 41 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 42 | else: 43 | assert not key in self.m_name2s_name 44 | 45 | def copy_to(self, model): 46 | m_param = dict(model.named_parameters()) 47 | shadow_params = dict(self.named_buffers()) 48 | for key in m_param: 49 | if m_param[key].requires_grad: 50 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 51 | else: 52 | assert not key in self.m_name2s_name 53 | 54 | def store(self, parameters): 55 | """ 56 | Save the current parameters for restoring later. 57 | Args: 58 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 59 | temporarily stored. 60 | """ 61 | self.collected_params = [param.clone() for param in parameters] 62 | 63 | def restore(self, parameters): 64 | """ 65 | Restore the parameters stored with the `store` method. 66 | Useful to validate the model with EMA parameters without affecting the 67 | original optimization process. Store the parameters before the 68 | `copy_to` method. After validation (or model saving), use this to 69 | restore the former parameters. 
70 | Args: 71 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 72 | updated with the stored parameters. 73 | """ 74 | for c_param, param in zip(self.collected_params, parameters): 75 | param.data.copy_(c_param.data) 76 | -------------------------------------------------------------------------------- /lib/model_zoo/sampler.py: -------------------------------------------------------------------------------- 1 | """SAMPLING ONLY.""" 2 | 3 | import torch 4 | import numpy as np 5 | from tqdm import tqdm 6 | from functools import partial 7 | 8 | from .diffusion_utils import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like 9 | 10 | def append_dims(x, target_dims): 11 | dims_to_append = target_dims - x.ndim 12 | if dims_to_append < 0: 13 | raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less') 14 | return x[(...,) + (None,) * dims_to_append] 15 | 16 | def default_noise_sampler(x): 17 | return lambda sigma, sigma_next: torch.randn_like(x) 18 | 19 | def get_ancestral_step(sigma_from, sigma_to, eta=1.): 20 | if not eta: 21 | return sigma_to, 0. 22 | sigma_up = min(sigma_to, eta * (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5) 23 | sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5 24 | return sigma_down, sigma_up 25 | 26 | def to_d(x, sigma, denoised): 27 | return (x - denoised) / append_dims(sigma, x.ndim) 28 | 29 | class Sampler(object): 30 | def __init__(self, net, type="ddim", steps=50, output_dim=[512, 512], n_samples=4, scale=7.5): 31 | super().__init__() 32 | self.net = net 33 | self.type = type 34 | self.steps = steps 35 | self.output_dim = output_dim 36 | self.n_samples = n_samples 37 | self.scale = scale 38 | self.sigmas = ((1 - net.alphas_cumprod) / net.alphas_cumprod) ** 0.5 39 | self.log_sigmas = self.sigmas.log() 40 | 41 | def t_to_sigma(self, t): 42 | t = t.float() 43 | low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac() 44 | log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx] 45 | return log_sigma.exp() 46 | 47 | def get_sigmas(self, n=None): 48 | def append_zero(x): 49 | return torch.cat([x, x.new_zeros([1])]) 50 | if n is None: 51 | return append_zero(self.sigmas.flip(0)) 52 | t_max = len(self.sigmas) - 1 53 | t = torch.linspace(t_max, 0, n, device=self.sigmas.device) 54 | return append_zero(self.t_to_sigma(t)) 55 | 56 | @torch.no_grad() 57 | def sample(self, x_info, c_info): 58 | h, w = self.output_dim 59 | shape = [self.n_samples, 4, h//8, w//8] 60 | device, dtype = self.net.get_device(), self.net.get_dtype() 61 | 62 | if ('xt' in x_info) and (x_info['xt'] is not None): 63 | xt = x_info['xt'].astype(dtype).to(device) 64 | x_info['x'] = xt 65 | elif ('x0' in x_info) and (x_info['x0'] is not None): 66 | x0 = x_info['x0'].type(dtype).to(device) 67 | ts = timesteps[x_info['x0_forward_timesteps']].repeat(self.n_samples) 68 | ts = torch.Tensor(ts).long().to(device) 69 | timesteps = timesteps[:x_info['x0_forward_timesteps']] 70 | x0_nz = self.model.q_sample(x0, ts) 71 | x_info['x'] = x0_nz 72 | else: 73 | x_info['x'] = torch.randn(shape, device=device, dtype=dtype) 74 | 75 | sigmas = self.get_sigmas(n=self.steps) 76 | 77 | if self.type == 'eular_a': 78 | rv = self.sample_euler_ancestral( 79 | x_info=x_info, 80 | c_info=c_info, 81 | sigmas = sigmas) 82 | return rv 83 | 84 | @torch.no_grad() 85 | def sample_euler_ancestral( 86 | self, x_info, c_info, sigmas, eta=1., s_noise=1.,): 87 | 88 | x = x_info['x'] 89 | x = x * sigmas[0] 90 | 91 | 
noise_sampler = default_noise_sampler(x) 92 | 93 | s_in = x.new_ones([x.shape[0]]) 94 | for i in range(len(sigmas)-1): 95 | denoised = self.net.apply_model(x, sigmas[i] * s_in, ) 96 | 97 | sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta) 98 | d = to_d(x, sigmas[i], denoised) 99 | # Euler method 100 | dt = sigma_down - sigmas[i] 101 | x = x + d * dt 102 | if sigmas[i + 1] > 0: 103 | x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up 104 | return x 105 | -------------------------------------------------------------------------------- /lib/sync.py: -------------------------------------------------------------------------------- 1 | from multiprocessing import shared_memory 2 | # import multiprocessing 3 | # if hasattr(multiprocessing, "shared_memory"): 4 | # from multiprocessing import shared_memory 5 | # else: 6 | # # workaround for single gpu inference on colab 7 | # shared_memory = None 8 | 9 | import random 10 | import pickle 11 | import time 12 | import copy 13 | import torch 14 | import torch.distributed as dist 15 | from lib.cfg_holder import cfg_unique_holder as cfguh 16 | 17 | def singleton(class_): 18 | instances = {} 19 | def getinstance(*args, **kwargs): 20 | if class_ not in instances: 21 | instances[class_] = class_(*args, **kwargs) 22 | return instances[class_] 23 | return getinstance 24 | 25 | def is_ddp(): 26 | return dist.is_available() and dist.is_initialized() 27 | 28 | def get_rank(type='local'): 29 | ddp = is_ddp() 30 | global_rank = dist.get_rank() if ddp else 0 31 | local_world_size = torch.cuda.device_count() 32 | if type == 'global': 33 | return global_rank 34 | elif type == 'local': 35 | return global_rank % local_world_size 36 | elif type == 'node': 37 | return global_rank // local_world_size 38 | elif type == 'all': 39 | return global_rank, \ 40 | global_rank % local_world_size, \ 41 | global_rank // local_world_size 42 | else: 43 | assert False, 'Unknown type' 44 | 45 | def get_world_size(type='local'): 46 | ddp = is_ddp() 47 | global_rank = dist.get_rank() if ddp else 0 48 | global_world_size = dist.get_world_size() if ddp else 1 49 | local_world_size = torch.cuda.device_count() 50 | if type == 'global': 51 | return global_world_size 52 | elif type == 'local': 53 | return local_world_size 54 | elif type == 'node': 55 | return global_world_size // local_world_size 56 | elif type == 'all': 57 | return global_world_size, local_world_size, \ 58 | global_world_size // local_world_size 59 | else: 60 | assert False, 'Unknown type' 61 | 62 | class barrier_lock(object): 63 | def __init__(self, n): 64 | self.n = n 65 | id = int(random.random()*10000) + int(time.time())*10000 66 | self.lock_shmname = 'barrier_lock_{}'.format(id) 67 | lock_shm = shared_memory.SharedMemory( 68 | name=self.lock_shmname, create=True, size=n) 69 | for i in range(n): 70 | lock_shm.buf[i] = 0 71 | lock_shm.close() 72 | 73 | def destroy(self): 74 | try: 75 | lock_shm = shared_memory.SharedMemory( 76 | name=self.lock_shmname) 77 | lock_shm.close() 78 | lock_shm.unlink() 79 | except: 80 | return 81 | 82 | def wait(self, k): 83 | lock_shm = shared_memory.SharedMemory( 84 | name=self.lock_shmname) 85 | assert lock_shm.buf[k] == 0, 'Two waits on the same id is not allowed.' 
86 | lock_shm.buf[k] = 1 87 | if k == 0: 88 | while sum([lock_shm.buf[i]==0 for i in range(self.n)]) != 0: 89 | pass 90 | for i in range(self.n): 91 | lock_shm.buf[i] = 0 92 | return 93 | else: 94 | while lock_shm.buf[k] != 0: 95 | pass 96 | 97 | class default_lock(object): 98 | def __init__(self): 99 | id = int(random.random()*10000) + int(time.time())*10000 100 | self.lock_shmname = 'lock_{}'.format(id) 101 | lock_shm = shared_memory.SharedMemory( 102 | name=self.lock_shmname, create=True, size=2) 103 | for i in range(2): 104 | lock_shm.buf[i] = 0 105 | lock_shm.close() 106 | 107 | def destroy(self): 108 | try: 109 | lock_shm = shared_memory.SharedMemory( 110 | name=self.lock_shmname) 111 | lock_shm.close() 112 | lock_shm.unlink() 113 | except: 114 | return 115 | 116 | def lock(self, k): 117 | lock_shm = shared_memory.SharedMemory( 118 | name=self.lock_shmname) 119 | while lock_shm.buf[0] == 1: 120 | pass 121 | lock_shm.buf[0] = 1 122 | lock_shm.buf[1] = k 123 | 124 | def unlock(self, k): 125 | lock_shm = shared_memory.SharedMemory( 126 | name=self.lock_shmname) 127 | if lock_shm.buf[1] != k: 128 | return 129 | lock_shm.buf[0] = 0 130 | return 131 | 132 | class nodewise_sync_global(object): 133 | """ 134 | This is the global part of nodewise_sync that need to call at master process 135 | before spawn. 136 | """ 137 | def __init__(self): 138 | self.local_world_size = get_world_size('local') 139 | self.reg_lock = default_lock() 140 | self.b_lock = barrier_lock(self.local_world_size) 141 | id = int(random.random()*10000) + int(time.time())*10000 142 | self.id_shmname = 'nodewise_sync_id_shm_{}'.format(id) 143 | 144 | def destroy(self): 145 | self.reg_lock.destroy() 146 | self.b_lock.destroy() 147 | try: 148 | shm = shared_memory.SharedMemory(name=self.id_shmname) 149 | shm.close() 150 | shm.unlink() 151 | except: 152 | return 153 | 154 | @singleton 155 | class nodewise_sync(object): 156 | """ 157 | A class that centralize nodewise sync activities. 158 | The backend is multiprocess sharememory, not torch, as torch not support this. 159 | """ 160 | def __init__(self): 161 | pass 162 | 163 | def copy_global(self, reference): 164 | self.local_world_size = reference.local_world_size 165 | self.b_lock = reference.b_lock 166 | self.reg_lock = reference.reg_lock 167 | self.id_shmname = reference.id_shmname 168 | return self 169 | 170 | def local_init(self): 171 | self.ddp = is_ddp() 172 | self.global_rank, self.local_rank, self.node_rank = get_rank('all') 173 | self.global_world_size, self.local_world_size, self.nodes = get_world_size('all') 174 | if self.local_rank == 0: 175 | temp = int(random.random()*10000) + int(time.time())*10000 176 | temp = pickle.dumps(temp) 177 | shm = shared_memory.SharedMemory( 178 | name=self.id_shmname, create=True, size=len(temp)) 179 | shm.close() 180 | return self 181 | 182 | def random_sync_id(self): 183 | assert self.local_rank is not None, 'Not initialized!' 
184 | if self.local_rank == 0: 185 | sync_id = int(random.random()*10000) + int(time.time())*10000 186 | data = pickle.dumps(sync_id) 187 | shm = shared_memory.SharedMemory(name=self.id_shmname) 188 | shm.buf[0:len(data)] = data[0:len(data)] 189 | self.barrier() 190 | shm.close() 191 | else: 192 | self.barrier() 193 | shm = shared_memory.SharedMemory(name=self.id_shmname) 194 | sync_id = pickle.loads(shm.buf) 195 | shm.close() 196 | return sync_id 197 | 198 | def barrier(self): 199 | self.b_lock.wait(self.local_rank) 200 | 201 | def lock(self): 202 | self.reg_lock.lock(self.local_rank) 203 | 204 | def unlock(self): 205 | self.reg_lock.unlock(self.local_rank) 206 | 207 | def broadcast_r0(self, data=None): 208 | assert self.local_rank is not None, 'Not initialized!' 209 | id = self.random_sync_id() 210 | shmname = 'broadcast_r0_{}'.format(id) 211 | if self.local_rank == 0: 212 | assert data!=None, 'Rank 0 needs to input data!' 213 | data = pickle.dumps(data) 214 | datan = len(data) 215 | load_info_shm = shared_memory.SharedMemory( 216 | name=shmname, create=True, size=datan) 217 | load_info_shm.buf[0:datan] = data[0:datan] 218 | self.barrier() 219 | self.barrier() 220 | load_info_shm.close() 221 | load_info_shm.unlink() 222 | return None 223 | else: 224 | assert data==None, 'Rank other than 1 should input None as data!' 225 | self.barrier() 226 | shm = shared_memory.SharedMemory(name=shmname) 227 | data = pickle.loads(shm.buf) 228 | shm.close() 229 | self.barrier() 230 | return data 231 | 232 | def destroy(self): 233 | self.barrier.destroy() 234 | try: 235 | shm = shared_memory.SharedMemory(name=self.id_shmname) 236 | shm.close() 237 | shm.unlink() 238 | except: 239 | return 240 | 241 | # import contextlib 242 | 243 | # @contextlib.contextmanager 244 | # def weight_sync(module, sync): 245 | # assert isinstance(module, torch.nn.Module) 246 | # if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel): 247 | # yield 248 | # else: 249 | # with module.no_sync(): 250 | # yield 251 | 252 | # def weight_sync(net): 253 | # for parameters in net.parameters(): 254 | # dist.all_reduce(parameters, dist.ReduceOp.AVG) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib==3.4.2 2 | pyyaml==5.4.1 3 | easydict==1.9 4 | tensorboardx==2.6 5 | tensorboard==2.12.0 6 | protobuf==3.20.3 7 | lpips==0.1.3 8 | fsspec==2022.7.1 9 | 10 | tqdm==4.60.0 11 | transformers==4.24.0 12 | torchmetrics==0.7.3 13 | 14 | einops==0.3.0 15 | omegaconf==2.1.1 16 | open_clip_torch==2.0.2 17 | webdataset==0.2.5 18 | gradio==3.17.1 19 | 20 | safetensors==0.3.1 21 | 22 | opencv-python 23 | scikit-image 24 | -------------------------------------------------------------------------------- /tools/get_controlnet.py: -------------------------------------------------------------------------------- 1 | # A tool to get and slim downloaded controlnet 2 | 3 | import torch 4 | from safetensors.torch import load_file, save_file 5 | from collections import OrderedDict 6 | import os.path as osp 7 | 8 | in_path = 'pretrained/controlnet/sdwebui_compatible/control_v11p_sd15_canny.pth' 9 | out_path = 'pretrained/controlnet/control_v11p_sd15_canny_slimmed.safetensors' 10 | 11 | sd = torch.load(in_path) 12 | 13 | sdnew = [[ni.replace('control_model.', ''), vi] for ni, vi in sd.items()] 14 | save_file(OrderedDict(sdnew), out_path) 15 | 
--------------------------------------------------------------------------------
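
tools/get_controlnet.py above hard-codes a single input/output pair and simply strips the 'control_model.' prefix before re-saving the ControlNet weights as safetensors. A minimal follow-up sketch, not part of the repository, that reloads the slimmed file and confirms the prefix is gone (the path is the out_path hard-coded in the script):

from safetensors.torch import load_file

slimmed_path = 'pretrained/controlnet/control_v11p_sd15_canny_slimmed.safetensors'
state_dict = load_file(slimmed_path)

# after slimming, no key should still carry the 'control_model.' prefix
assert not any(k.startswith('control_model.') for k in state_dict)
print(f'{len(state_dict)} tensors loaded, e.g. {next(iter(state_dict))}')
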