├── .gitignore ├── LICENSE.txt ├── README.md ├── setup.py ├── src └── controlnet_aux │ ├── __init__.py │ ├── anyline │ └── __init__.py │ ├── canny │ └── __init__.py │ ├── dwpose │ ├── __init__.py │ ├── dwpose_config │ │ ├── __init__.py │ │ └── dwpose-l_384x288.py │ ├── util.py │ ├── wholebody.py │ └── yolox_config │ │ ├── __init__.py │ │ └── yolox_l_8xb8-300e_coco.py │ ├── hed │ └── __init__.py │ ├── leres │ ├── __init__.py │ ├── leres │ │ ├── LICENSE │ │ ├── Resnet.py │ │ ├── Resnext_torch.py │ │ ├── __init__.py │ │ ├── depthmap.py │ │ ├── multi_depth_model_woauxi.py │ │ ├── net_tools.py │ │ └── network_auxi.py │ └── pix2pix │ │ ├── LICENSE │ │ ├── __init__.py │ │ ├── models │ │ ├── __init__.py │ │ ├── base_model.py │ │ ├── base_model_hg.py │ │ ├── networks.py │ │ └── pix2pix4depth_model.py │ │ ├── options │ │ ├── __init__.py │ │ ├── base_options.py │ │ └── test_options.py │ │ └── util │ │ ├── __init__.py │ │ └── util.py │ ├── lineart │ ├── LICENSE │ └── __init__.py │ ├── lineart_anime │ ├── LICENSE │ └── __init__.py │ ├── lineart_standard │ └── __init__.py │ ├── mediapipe_face │ ├── __init__.py │ └── mediapipe_face_common.py │ ├── midas │ ├── LICENSE │ ├── __init__.py │ ├── api.py │ ├── midas │ │ ├── __init__.py │ │ ├── base_model.py │ │ ├── blocks.py │ │ ├── dpt_depth.py │ │ ├── midas_net.py │ │ ├── midas_net_custom.py │ │ ├── transforms.py │ │ └── vit.py │ └── utils.py │ ├── mlsd │ ├── LICENSE │ ├── __init__.py │ ├── models │ │ ├── __init__.py │ │ ├── mbv2_mlsd_large.py │ │ └── mbv2_mlsd_tiny.py │ └── utils.py │ ├── normalbae │ ├── LICENSE │ ├── __init__.py │ └── nets │ │ ├── NNET.py │ │ ├── __init__.py │ │ ├── baseline.py │ │ └── submodules │ │ ├── __init__.py │ │ ├── decoder.py │ │ ├── efficientnet_repo │ │ ├── .gitignore │ │ ├── BENCHMARK.md │ │ ├── LICENSE │ │ ├── README.md │ │ ├── __init__.py │ │ ├── caffe2_benchmark.py │ │ ├── caffe2_validate.py │ │ ├── geffnet │ │ │ ├── __init__.py │ │ │ ├── activations │ │ │ │ ├── __init__.py │ │ │ │ ├── activations.py │ │ │ │ ├── activations_jit.py │ │ │ │ └── activations_me.py │ │ │ ├── config.py │ │ │ ├── conv2d_layers.py │ │ │ ├── efficientnet_builder.py │ │ │ ├── gen_efficientnet.py │ │ │ ├── helpers.py │ │ │ ├── mobilenetv3.py │ │ │ ├── model_factory.py │ │ │ └── version.py │ │ ├── hubconf.py │ │ ├── onnx_export.py │ │ ├── onnx_optimize.py │ │ ├── onnx_to_caffe.py │ │ ├── onnx_validate.py │ │ ├── requirements.txt │ │ ├── setup.py │ │ ├── utils.py │ │ └── validate.py │ │ ├── encoder.py │ │ └── submodules.py │ ├── open_pose │ ├── LICENSE │ ├── __init__.py │ ├── body.py │ ├── face.py │ ├── hand.py │ ├── model.py │ └── util.py │ ├── pidi │ ├── LICENSE │ ├── __init__.py │ └── model.py │ ├── processor.py │ ├── segment_anything │ ├── __init__.py │ ├── automatic_mask_generator.py │ ├── build_sam.py │ ├── modeling │ │ ├── __init__.py │ │ ├── common.py │ │ ├── image_encoder.py │ │ ├── mask_decoder.py │ │ ├── prompt_encoder.py │ │ ├── sam.py │ │ ├── tiny_vit_sam.py │ │ └── transformer.py │ ├── predictor.py │ └── utils │ │ ├── __init__.py │ │ ├── amg.py │ │ ├── onnx.py │ │ └── transforms.py │ ├── shuffle │ └── __init__.py │ ├── teed │ ├── Fsmish.py │ ├── LICENSE.txt │ ├── Xsmish.py │ ├── __init__.py │ └── ted.py │ ├── tests │ ├── requirements.txt │ ├── test_image.png │ ├── test_processor.py │ └── test_processor_pytest.py │ ├── util.py │ └── zoe │ ├── LICENSE │ ├── __init__.py │ └── zoedepth │ ├── __init__.py │ ├── models │ ├── __init__.py │ ├── base_models │ │ ├── __init__.py │ │ ├── midas.py │ │ └── midas_repo │ │ │ ├── LICENSE │ │ │ ├── README.md │ 
│ │ ├── __init__.py │ │ │ ├── hubconf.py │ │ │ └── midas │ │ │ ├── __init__.py │ │ │ ├── backbones │ │ │ ├── __init__.py │ │ │ ├── beit.py │ │ │ ├── levit.py │ │ │ ├── next_vit.py │ │ │ ├── swin.py │ │ │ ├── swin2.py │ │ │ ├── swin_common.py │ │ │ ├── timm_adapter.py │ │ │ ├── utils.py │ │ │ └── vit.py │ │ │ ├── base_model.py │ │ │ ├── blocks.py │ │ │ ├── dpt_depth.py │ │ │ ├── midas_net.py │ │ │ ├── midas_net_custom.py │ │ │ ├── model_loader.py │ │ │ └── transforms.py │ ├── builder.py │ ├── depth_model.py │ ├── layers │ │ ├── __init__.py │ │ ├── attractor.py │ │ ├── dist_layers.py │ │ ├── localbins_layers.py │ │ └── patch_transformer.py │ ├── model_io.py │ ├── zoedepth │ │ ├── __init__.py │ │ ├── config_zoedepth.json │ │ ├── config_zoedepth_kitti.json │ │ └── zoedepth_v1.py │ └── zoedepth_nk │ │ ├── __init__.py │ │ ├── config_zoedepth_nk.json │ │ └── zoedepth_nk_v1.py │ └── utils │ ├── __init__.py │ ├── arg_utils.py │ ├── config.py │ └── easydict │ └── __init__.py └── tests └── test_controlnet_aux.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Initially taken from Github's Python gitignore file 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # tests and logs 12 | tests/fixtures/cached_*_text.txt 13 | logs/ 14 | lightning_logs/ 15 | lang_code_data/ 16 | tests/outputs 17 | 18 | # Distribution / packaging 19 | .Python 20 | build/ 21 | develop-eggs/ 22 | dist/ 23 | downloads/ 24 | eggs/ 25 | .eggs/ 26 | lib/ 27 | lib64/ 28 | parts/ 29 | sdist/ 30 | var/ 31 | wheels/ 32 | *.egg-info/ 33 | .installed.cfg 34 | *.egg 35 | MANIFEST 36 | 37 | # PyInstaller 38 | # Usually these files are written by a python script from a template 39 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
40 | *.manifest 41 | *.spec 42 | 43 | # Installer logs 44 | pip-log.txt 45 | pip-delete-this-directory.txt 46 | 47 | # Unit test / coverage reports 48 | htmlcov/ 49 | .tox/ 50 | .nox/ 51 | .coverage 52 | .coverage.* 53 | .cache 54 | nosetests.xml 55 | coverage.xml 56 | *.cover 57 | .hypothesis/ 58 | .pytest_cache/ 59 | 60 | # Translations 61 | *.mo 62 | *.pot 63 | 64 | # Django stuff: 65 | *.log 66 | local_settings.py 67 | db.sqlite3 68 | 69 | # Flask stuff: 70 | instance/ 71 | .webassets-cache 72 | 73 | # Scrapy stuff: 74 | .scrapy 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyBuilder 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # IPython 86 | profile_default/ 87 | ipython_config.py 88 | 89 | # pyenv 90 | .python-version 91 | 92 | # celery beat schedule file 93 | celerybeat-schedule 94 | 95 | # SageMath parsed files 96 | *.sage.py 97 | 98 | # Environments 99 | .env 100 | .venv 101 | env/ 102 | venv/ 103 | ENV/ 104 | env.bak/ 105 | venv.bak/ 106 | 107 | # Spyder project settings 108 | .spyderproject 109 | .spyproject 110 | 111 | # Rope project settings 112 | .ropeproject 113 | 114 | # mkdocs documentation 115 | /site 116 | 117 | # mypy 118 | .mypy_cache/ 119 | .dmypy.json 120 | dmypy.json 121 | 122 | # Pyre type checker 123 | .pyre/ 124 | 125 | # vscode 126 | .vs 127 | .vscode 128 | 129 | # Pycharm 130 | .idea 131 | 132 | # TF code 133 | tensorflow_code 134 | 135 | # Models 136 | proc_data 137 | 138 | # examples 139 | runs 140 | /runs_old 141 | /wandb 142 | /examples/runs 143 | /examples/**/*.args 144 | /examples/rag/sweep 145 | 146 | # data 147 | /data 148 | serialization_dir 149 | 150 | # emacs 151 | *.*~ 152 | debug.env 153 | 154 | # vim 155 | .*.swp 156 | 157 | #ctags 158 | tags 159 | 160 | # pre-commit 161 | .pre-commit* 162 | 163 | # .lock 164 | *.lock 165 | 166 | # DS_Store (MacOS) 167 | .DS_Store 168 | # RL pipelines may produce mp4 outputs 169 | *.mp4 170 | 171 | # dependencies 172 | /transformers 173 | 174 | # ruff 175 | .ruff_cache 176 | 177 | wandb 178 | 179 | -------------------------------------------------------------------------------- /src/controlnet_aux/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.0.10" 2 | 3 | from .anyline import AnylineDetector 4 | from .canny import CannyDetector 5 | from .dwpose import DWposeDetector 6 | from .hed import HEDdetector 7 | from .leres import LeresDetector 8 | from .lineart import LineartDetector 9 | from .lineart_anime import LineartAnimeDetector 10 | from .lineart_standard import LineartStandardDetector 11 | from .mediapipe_face import MediapipeFaceDetector 12 | from .midas import MidasDetector 13 | from .mlsd import MLSDdetector 14 | from .normalbae import NormalBaeDetector 15 | from .open_pose import OpenposeDetector 16 | from .pidi import PidiNetDetector 17 | from .segment_anything import SamDetector 18 | from .shuffle import ContentShuffleDetector 19 | from .teed import TEEDdetector 20 | from .zoe import ZoeDetector 21 | -------------------------------------------------------------------------------- /src/controlnet_aux/anyline/__init__.py: -------------------------------------------------------------------------------- 1 | # code based in https://github.com/TheMistoAI/ComfyUI-Anyline/blob/main/anyline.py 2 | import os 3 | 4 | import cv2 5 | import numpy as np 6 | import torch 7 | from einops import rearrange 8 | from huggingface_hub import hf_hub_download 9 | from PIL import Image 10 | from skimage import 
morphology 11 | 12 | from ..teed.ted import TED 13 | from ..util import HWC3, resize_image, safe_step 14 | 15 | 16 | class AnylineDetector: 17 | def __init__(self, model): 18 | self.model = model 19 | 20 | @classmethod 21 | def from_pretrained(cls, pretrained_model_or_path, filename=None, subfolder=None): 22 | if os.path.isdir(pretrained_model_or_path): 23 | model_path = os.path.join(pretrained_model_or_path, filename) 24 | else: 25 | model_path = hf_hub_download( 26 | pretrained_model_or_path, filename, subfolder=subfolder 27 | ) 28 | 29 | model = TED() 30 | model.load_state_dict(torch.load(model_path, map_location="cpu")) 31 | 32 | return cls(model) 33 | 34 | def to(self, device): 35 | self.model.to(device) 36 | return self 37 | 38 | def __call__( 39 | self, 40 | input_image, 41 | detect_resolution=1280, 42 | guassian_sigma=2.0, 43 | intensity_threshold=3, 44 | output_type="pil", 45 | ): 46 | device = next(iter(self.model.parameters())).device 47 | 48 | if not isinstance(input_image, np.ndarray): 49 | input_image = np.array(input_image, dtype=np.uint8) 50 | output_type = output_type or "pil" 51 | else: 52 | output_type = output_type or "np" 53 | 54 | original_height, original_width, _ = input_image.shape 55 | 56 | input_image = HWC3(input_image) 57 | input_image = resize_image(input_image, detect_resolution) 58 | 59 | assert input_image.ndim == 3 60 | height, width, _ = input_image.shape 61 | with torch.no_grad(): 62 | image_teed = torch.from_numpy(input_image.copy()).float().to(device) 63 | image_teed = rearrange(image_teed, "h w c -> 1 c h w") 64 | edges = self.model(image_teed) 65 | edges = [e.detach().cpu().numpy().astype(np.float32)[0, 0] for e in edges] 66 | edges = [ 67 | cv2.resize(e, (width, height), interpolation=cv2.INTER_LINEAR) 68 | for e in edges 69 | ] 70 | edges = np.stack(edges, axis=2) 71 | edge = 1 / (1 + np.exp(-np.mean(edges, axis=2).astype(np.float64))) 72 | edge = safe_step(edge, 2) 73 | edge = (edge * 255.0).clip(0, 255).astype(np.uint8) 74 | 75 | mteed_result = edge 76 | mteed_result = HWC3(mteed_result) 77 | 78 | x = input_image.astype(np.float32) 79 | g = cv2.GaussianBlur(x, (0, 0), guassian_sigma) 80 | intensity = np.min(g - x, axis=2).clip(0, 255) 81 | intensity /= max(16, np.median(intensity[intensity > intensity_threshold])) 82 | intensity *= 127 83 | lineart_result = intensity.clip(0, 255).astype(np.uint8) 84 | 85 | lineart_result = HWC3(lineart_result) 86 | 87 | lineart_result = self.get_intensity_mask( 88 | lineart_result, lower_bound=0, upper_bound=255 89 | ) 90 | 91 | cleaned = morphology.remove_small_objects( 92 | lineart_result.astype(bool), min_size=36, connectivity=1 93 | ) 94 | lineart_result = lineart_result * cleaned 95 | final_result = self.combine_layers(mteed_result, lineart_result) 96 | 97 | final_result = cv2.resize( 98 | final_result, 99 | (original_width, original_height), 100 | interpolation=cv2.INTER_LINEAR, 101 | ) 102 | 103 | if output_type == "pil": 104 | final_result = Image.fromarray(final_result) 105 | 106 | return final_result 107 | 108 | def get_intensity_mask(self, image_array, lower_bound, upper_bound): 109 | mask = image_array[:, :, 0] 110 | mask = np.where((mask >= lower_bound) & (mask <= upper_bound), mask, 0) 111 | mask = np.expand_dims(mask, 2).repeat(3, axis=2) 112 | return mask 113 | 114 | def combine_layers(self, base_layer, top_layer): 115 | mask = top_layer.astype(bool) 116 | temp = 1 - (1 - top_layer) * (1 - base_layer) 117 | result = base_layer * (~mask) + temp * mask 118 | return result 119 | 
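The AnylineDetector above follows the same two-step pattern as the other detectors in this package: load weights with `from_pretrained`, then call the instance on an image. A minimal usage sketch follows; the Hugging Face repo id, filename, and subfolder are assumptions (they correspond to the public MistoLine/Anyline release) and the input path is hypothetical.

```python
import torch
from PIL import Image
from controlnet_aux import AnylineDetector

# Assumed checkpoint location for the MTEED weights; adjust if hosted elsewhere.
anyline = AnylineDetector.from_pretrained(
    "TheMistoAI/MistoLine", filename="MTEED.pth", subfolder="Anyline"
)
anyline = anyline.to("cuda" if torch.cuda.is_available() else "cpu")

image = Image.open("input.png")  # hypothetical input image
# Returns a PIL image by default; pass output_type="np" for a NumPy array.
edges = anyline(image, detect_resolution=1280, guassian_sigma=2.0, intensity_threshold=3)
edges.save("anyline_edges.png")
```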
-------------------------------------------------------------------------------- /src/controlnet_aux/canny/__init__.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import cv2 3 | import numpy as np 4 | from PIL import Image 5 | from ..util import HWC3, resize_image 6 | 7 | class CannyDetector: 8 | def __call__(self, input_image=None, low_threshold=100, high_threshold=200, detect_resolution=512, image_resolution=512, output_type=None, **kwargs): 9 | if "img" in kwargs: 10 | warnings.warn("img is deprecated, please use `input_image=...` instead.", DeprecationWarning) 11 | input_image = kwargs.pop("img") 12 | 13 | if input_image is None: 14 | raise ValueError("input_image must be defined.") 15 | 16 | if not isinstance(input_image, np.ndarray): 17 | input_image = np.array(input_image, dtype=np.uint8) 18 | output_type = output_type or "pil" 19 | else: 20 | output_type = output_type or "np" 21 | 22 | input_image = HWC3(input_image) 23 | input_image = resize_image(input_image, detect_resolution) 24 | 25 | detected_map = cv2.Canny(input_image, low_threshold, high_threshold) 26 | detected_map = HWC3(detected_map) 27 | 28 | img = resize_image(input_image, image_resolution) 29 | H, W, C = img.shape 30 | 31 | detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR) 32 | 33 | if output_type == "pil": 34 | detected_map = Image.fromarray(detected_map) 35 | 36 | return detected_map 37 | -------------------------------------------------------------------------------- /src/controlnet_aux/dwpose/__init__.py: -------------------------------------------------------------------------------- 1 | # Openpose 2 | # Original from CMU https://github.com/CMU-Perceptual-Computing-Lab/openpose 3 | # 2nd Edited by https://github.com/Hzzone/pytorch-openpose 4 | # 3rd Edited by ControlNet 5 | # 4th Edited by ControlNet (added face and correct hands) 6 | 7 | import os 8 | os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" 9 | 10 | import cv2 11 | import torch 12 | import numpy as np 13 | from PIL import Image 14 | 15 | from ..util import HWC3, resize_image 16 | from . 
import util 17 | 18 | 19 | def draw_pose(pose, H, W): 20 | bodies = pose['bodies'] 21 | faces = pose['faces'] 22 | hands = pose['hands'] 23 | candidate = bodies['candidate'] 24 | subset = bodies['subset'] 25 | 26 | canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8) 27 | canvas = util.draw_bodypose(canvas, candidate, subset) 28 | canvas = util.draw_handpose(canvas, hands) 29 | canvas = util.draw_facepose(canvas, faces) 30 | 31 | return canvas 32 | 33 | class DWposeDetector: 34 | def __init__(self, det_config=None, det_ckpt=None, pose_config=None, pose_ckpt=None, device="cpu"): 35 | from .wholebody import Wholebody 36 | 37 | self.pose_estimation = Wholebody(det_config, det_ckpt, pose_config, pose_ckpt, device) 38 | 39 | def to(self, device): 40 | self.pose_estimation.to(device) 41 | return self 42 | 43 | def __call__(self, input_image, detect_resolution=512, image_resolution=512, output_type="pil", **kwargs): 44 | 45 | input_image = cv2.cvtColor(np.array(input_image, dtype=np.uint8), cv2.COLOR_RGB2BGR) 46 | 47 | input_image = HWC3(input_image) 48 | input_image = resize_image(input_image, detect_resolution) 49 | H, W, C = input_image.shape 50 | 51 | with torch.no_grad(): 52 | candidate, subset = self.pose_estimation(input_image) 53 | nums, keys, locs = candidate.shape 54 | candidate[..., 0] /= float(W) 55 | candidate[..., 1] /= float(H) 56 | body = candidate[:,:18].copy() 57 | body = body.reshape(nums*18, locs) 58 | score = subset[:,:18] 59 | 60 | for i in range(len(score)): 61 | for j in range(len(score[i])): 62 | if score[i][j] > 0.3: 63 | score[i][j] = int(18*i+j) 64 | else: 65 | score[i][j] = -1 66 | 67 | un_visible = subset<0.3 68 | candidate[un_visible] = -1 69 | 70 | foot = candidate[:,18:24] 71 | 72 | faces = candidate[:,24:92] 73 | 74 | hands = candidate[:,92:113] 75 | hands = np.vstack([hands, candidate[:,113:]]) 76 | 77 | bodies = dict(candidate=body, subset=score) 78 | pose = dict(bodies=bodies, hands=hands, faces=faces) 79 | 80 | detected_map = draw_pose(pose, H, W) 81 | detected_map = HWC3(detected_map) 82 | 83 | img = resize_image(input_image, image_resolution) 84 | H, W, C = img.shape 85 | 86 | detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR) 87 | 88 | if output_type == "pil": 89 | detected_map = Image.fromarray(detected_map) 90 | 91 | return detected_map 92 | -------------------------------------------------------------------------------- /src/controlnet_aux/dwpose/dwpose_config/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/controlnet_aux/b1aac318b15c4e65a21bc516dc236ad4bf593d46/src/controlnet_aux/dwpose/dwpose_config/__init__.py -------------------------------------------------------------------------------- /src/controlnet_aux/dwpose/wholebody.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import os 3 | import numpy as np 4 | import warnings 5 | 6 | try: 7 | import mmcv 8 | except ImportError: 9 | warnings.warn( 10 | "The module 'mmcv' is not installed. The package will have limited functionality. 
Please install it using the command: mim install 'mmcv>=2.0.1'" 11 | ) 12 | 13 | try: 14 | from mmpose.apis import inference_topdown 15 | from mmpose.apis import init_model as init_pose_estimator 16 | from mmpose.evaluation.functional import nms 17 | from mmpose.utils import adapt_mmdet_pipeline 18 | from mmpose.structures import merge_data_samples 19 | except ImportError: 20 | warnings.warn( 21 | "The module 'mmpose' is not installed. The package will have limited functionality. Please install it using the command: mim install 'mmpose>=1.1.0'" 22 | ) 23 | 24 | try: 25 | from mmdet.apis import inference_detector, init_detector 26 | except ImportError: 27 | warnings.warn( 28 | "The module 'mmdet' is not installed. The package will have limited functionality. Please install it using the command: mim install 'mmdet>=3.1.0'" 29 | ) 30 | 31 | 32 | class Wholebody: 33 | def __init__(self, 34 | det_config=None, det_ckpt=None, 35 | pose_config=None, pose_ckpt=None, 36 | device="cpu"): 37 | 38 | if det_config is None: 39 | det_config = os.path.join(os.path.dirname(__file__), "yolox_config/yolox_l_8xb8-300e_coco.py") 40 | 41 | if pose_config is None: 42 | pose_config = os.path.join(os.path.dirname(__file__), "dwpose_config/dwpose-l_384x288.py") 43 | 44 | if det_ckpt is None: 45 | det_ckpt = 'https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_l_8x8_300e_coco/yolox_l_8x8_300e_coco_20211126_140236-d3bd2b23.pth' 46 | 47 | if pose_ckpt is None: 48 | pose_ckpt = "https://huggingface.co/wanghaofan/dw-ll_ucoco_384/resolve/main/dw-ll_ucoco_384.pth" 49 | 50 | # build detector 51 | self.detector = init_detector(det_config, det_ckpt, device=device) 52 | self.detector.cfg = adapt_mmdet_pipeline(self.detector.cfg) 53 | 54 | # build pose estimator 55 | self.pose_estimator = init_pose_estimator( 56 | pose_config, 57 | pose_ckpt, 58 | device=device) 59 | 60 | def to(self, device): 61 | self.detector.to(device) 62 | self.pose_estimator.to(device) 63 | return self 64 | 65 | def __call__(self, oriImg): 66 | # predict bbox 67 | det_result = inference_detector(self.detector, oriImg) 68 | pred_instance = det_result.pred_instances.cpu().numpy() 69 | bboxes = np.concatenate( 70 | (pred_instance.bboxes, pred_instance.scores[:, None]), axis=1) 71 | bboxes = bboxes[np.logical_and(pred_instance.labels == 0, 72 | pred_instance.scores > 0.5)] 73 | 74 | # set NMS threshold 75 | bboxes = bboxes[nms(bboxes, 0.7), :4] 76 | 77 | # predict keypoints 78 | if len(bboxes) == 0: 79 | pose_results = inference_topdown(self.pose_estimator, oriImg) 80 | else: 81 | pose_results = inference_topdown(self.pose_estimator, oriImg, bboxes) 82 | preds = merge_data_samples(pose_results) 83 | preds = preds.pred_instances 84 | 85 | # preds = pose_results[0].pred_instances 86 | keypoints = preds.get('transformed_keypoints', 87 | preds.keypoints) 88 | if 'keypoint_scores' in preds: 89 | scores = preds.keypoint_scores 90 | else: 91 | scores = np.ones(keypoints.shape[:-1]) 92 | 93 | if 'keypoints_visible' in preds: 94 | visible = preds.keypoints_visible 95 | else: 96 | visible = np.ones(keypoints.shape[:-1]) 97 | keypoints_info = np.concatenate( 98 | (keypoints, scores[..., None], visible[..., None]), 99 | axis=-1) 100 | # compute neck joint 101 | neck = np.mean(keypoints_info[:, [5, 6]], axis=1) 102 | # neck score when visualizing pred 103 | neck[:, 2:4] = np.logical_and( 104 | keypoints_info[:, 5, 2:4] > 0.3, 105 | keypoints_info[:, 6, 2:4] > 0.3).astype(int) 106 | new_keypoints_info = np.insert( 107 | keypoints_info, 17, neck, axis=1) 108 | 
mmpose_idx = [ 109 | 17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3 110 | ] 111 | openpose_idx = [ 112 | 1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17 113 | ] 114 | new_keypoints_info[:, openpose_idx] = \ 115 | new_keypoints_info[:, mmpose_idx] 116 | keypoints_info = new_keypoints_info 117 | 118 | keypoints, scores, visible = keypoints_info[ 119 | ..., :2], keypoints_info[..., 2], keypoints_info[..., 3] 120 | 121 | return keypoints, scores 122 | -------------------------------------------------------------------------------- /src/controlnet_aux/dwpose/yolox_config/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/controlnet_aux/b1aac318b15c4e65a21bc516dc236ad4bf593d46/src/controlnet_aux/dwpose/yolox_config/__init__.py -------------------------------------------------------------------------------- /src/controlnet_aux/leres/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import cv2 4 | import numpy as np 5 | import torch 6 | from huggingface_hub import hf_hub_download 7 | from PIL import Image 8 | 9 | from ..util import HWC3, resize_image 10 | from .leres.depthmap import estimateboost, estimateleres 11 | from .leres.multi_depth_model_woauxi import RelDepthModel 12 | from .leres.net_tools import strip_prefix_if_present 13 | from .pix2pix.models.pix2pix4depth_model import Pix2Pix4DepthModel 14 | from .pix2pix.options.test_options import TestOptions 15 | 16 | 17 | class LeresDetector: 18 | def __init__(self, model, pix2pixmodel): 19 | self.model = model 20 | self.pix2pixmodel = pix2pixmodel 21 | 22 | @classmethod 23 | def from_pretrained(cls, pretrained_model_or_path, filename=None, pix2pix_filename=None, cache_dir=None, local_files_only=False): 24 | filename = filename or "res101.pth" 25 | pix2pix_filename = pix2pix_filename or "latest_net_G.pth" 26 | 27 | if os.path.isdir(pretrained_model_or_path): 28 | model_path = os.path.join(pretrained_model_or_path, filename) 29 | else: 30 | model_path = hf_hub_download(pretrained_model_or_path, filename, cache_dir=cache_dir, local_files_only=local_files_only) 31 | 32 | checkpoint = torch.load(model_path, map_location=torch.device('cpu')) 33 | 34 | model = RelDepthModel(backbone='resnext101') 35 | model.load_state_dict(strip_prefix_if_present(checkpoint['depth_model'], "module."), strict=True) 36 | del checkpoint 37 | 38 | if os.path.isdir(pretrained_model_or_path): 39 | model_path = os.path.join(pretrained_model_or_path, pix2pix_filename) 40 | else: 41 | model_path = hf_hub_download(pretrained_model_or_path, pix2pix_filename, cache_dir=cache_dir, local_files_only=local_files_only) 42 | 43 | opt = TestOptions().parse() 44 | if not torch.cuda.is_available(): 45 | opt.gpu_ids = [] # cpu mode 46 | pix2pixmodel = Pix2Pix4DepthModel(opt) 47 | pix2pixmodel.save_dir = os.path.dirname(model_path) 48 | pix2pixmodel.load_networks('latest') 49 | pix2pixmodel.eval() 50 | 51 | return cls(model, pix2pixmodel) 52 | 53 | def to(self, device): 54 | self.model.to(device) 55 | # TODO - refactor pix2pix implementation to support device migration 56 | # self.pix2pixmodel.to(device) 57 | return self 58 | 59 | def __call__(self, input_image, thr_a=0, thr_b=0, boost=False, detect_resolution=512, image_resolution=512, output_type="pil"): 60 | device = next(iter(self.model.parameters())).device 61 | if not isinstance(input_image, np.ndarray): 62 | input_image = np.array(input_image, dtype=np.uint8) 63 | 64 | 
input_image = HWC3(input_image) 65 | input_image = resize_image(input_image, detect_resolution) 66 | 67 | assert input_image.ndim == 3 68 | height, width, dim = input_image.shape 69 | 70 | with torch.no_grad(): 71 | 72 | if boost: 73 | depth = estimateboost(input_image, self.model, 0, self.pix2pixmodel, max(width, height)) 74 | else: 75 | depth = estimateleres(input_image, self.model, width, height) 76 | 77 | numbytes=2 78 | depth_min = depth.min() 79 | depth_max = depth.max() 80 | max_val = (2**(8*numbytes))-1 81 | 82 | # check output before normalizing and mapping to 16 bit 83 | if depth_max - depth_min > np.finfo("float").eps: 84 | out = max_val * (depth - depth_min) / (depth_max - depth_min) 85 | else: 86 | out = np.zeros(depth.shape) 87 | 88 | # single channel, 16 bit image 89 | depth_image = out.astype("uint16") 90 | 91 | # convert to uint8 92 | depth_image = cv2.convertScaleAbs(depth_image, alpha=(255.0/65535.0)) 93 | 94 | # remove near 95 | if thr_a != 0: 96 | thr_a = ((thr_a/100)*255) 97 | depth_image = cv2.threshold(depth_image, thr_a, 255, cv2.THRESH_TOZERO)[1] 98 | 99 | # invert image 100 | depth_image = cv2.bitwise_not(depth_image) 101 | 102 | # remove bg 103 | if thr_b != 0: 104 | thr_b = ((thr_b/100)*255) 105 | depth_image = cv2.threshold(depth_image, thr_b, 255, cv2.THRESH_TOZERO)[1] 106 | 107 | detected_map = depth_image 108 | detected_map = HWC3(detected_map) 109 | 110 | img = resize_image(input_image, image_resolution) 111 | H, W, C = img.shape 112 | 113 | detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR) 114 | 115 | if output_type == "pil": 116 | detected_map = Image.fromarray(detected_map) 117 | 118 | return detected_map 119 | -------------------------------------------------------------------------------- /src/controlnet_aux/leres/leres/LICENSE: -------------------------------------------------------------------------------- 1 | https://github.com/thygate/stable-diffusion-webui-depthmap-script 2 | 3 | MIT License 4 | 5 | Copyright (c) 2023 Bob Thiry 6 | 7 | Permission is hereby granted, free of charge, to any person obtaining a copy 8 | of this software and associated documentation files (the "Software"), to deal 9 | in the Software without restriction, including without limitation the rights 10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | copies of the Software, and to permit persons to whom the Software is 12 | furnished to do so, subject to the following conditions: 13 | 14 | The above copyright notice and this permission notice shall be included in all 15 | copies or substantial portions of the Software. 16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | SOFTWARE. 
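The LeresDetector defined in src/controlnet_aux/leres/__init__.py above can be exercised the same way as the other detectors. A minimal sketch, assuming the default checkpoints (res101.pth and latest_net_G.pth) are hosted in a single Hugging Face repo; the repo id and input path below are assumptions.

```python
import torch
from PIL import Image
from controlnet_aux import LeresDetector

# Assumed repo hosting both res101.pth and latest_net_G.pth; substitute your own checkpoint location.
leres = LeresDetector.from_pretrained("lllyasviel/Annotators")
leres = leres.to("cuda" if torch.cuda.is_available() else "cpu")

image = Image.open("room.jpg")  # hypothetical input image
# thr_a / thr_b (0-100) clip the near and background ends of the depth range;
# boost=True enables the slower pix2pix-merged estimate.
depth = leres(image, thr_a=0, thr_b=0, boost=False, detect_resolution=512, image_resolution=512)
depth.save("leres_depth.png")
```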
-------------------------------------------------------------------------------- /src/controlnet_aux/leres/leres/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/controlnet_aux/b1aac318b15c4e65a21bc516dc236ad4bf593d46/src/controlnet_aux/leres/leres/__init__.py -------------------------------------------------------------------------------- /src/controlnet_aux/leres/leres/multi_depth_model_woauxi.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from . import network_auxi as network 5 | from .net_tools import get_func 6 | 7 | 8 | class RelDepthModel(nn.Module): 9 | def __init__(self, backbone='resnet50'): 10 | super(RelDepthModel, self).__init__() 11 | if backbone == 'resnet50': 12 | encoder = 'resnet50_stride32' 13 | elif backbone == 'resnext101': 14 | encoder = 'resnext101_stride32x8d' 15 | self.depth_model = DepthModel(encoder) 16 | 17 | def inference(self, rgb): 18 | with torch.no_grad(): 19 | input = rgb.to(self.depth_model.device) 20 | depth = self.depth_model(input) 21 | #pred_depth_out = depth - depth.min() + 0.01 22 | return depth #pred_depth_out 23 | 24 | 25 | class DepthModel(nn.Module): 26 | def __init__(self, encoder): 27 | super(DepthModel, self).__init__() 28 | backbone = network.__name__.split('.')[-1] + '.' + encoder 29 | self.encoder_modules = get_func(backbone)() 30 | self.decoder_modules = network.Decoder() 31 | 32 | def forward(self, x): 33 | lateral_out = self.encoder_modules(x) 34 | out_logit = self.decoder_modules(lateral_out) 35 | return out_logit -------------------------------------------------------------------------------- /src/controlnet_aux/leres/leres/net_tools.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import torch 3 | import os 4 | from collections import OrderedDict 5 | 6 | 7 | def get_func(func_name): 8 | """Helper to return a function object by name. func_name must identify a 9 | function in this module or the path to a function relative to the base 10 | 'modeling' module. 11 | """ 12 | if func_name == '': 13 | return None 14 | try: 15 | parts = func_name.split('.') 16 | # Refers to a function in this module 17 | if len(parts) == 1: 18 | return globals()[parts[0]] 19 | # Otherwise, assume we're referencing a module under modeling 20 | module_name = '.' + '.'.join(parts[:-1]) 21 | 22 | # Import module_name, for example ".network_auxi", 23 | # under __package__=="controlnet_aux.leres.leres" 24 | # __package__ resolves to the package namespace above this file 25 | module = importlib.import_module(module_name, package=__package__) 26 | return getattr(module, parts[-1]) 27 | except Exception: 28 | print('Failed to find function: %s' % func_name) 29 | raise 30 | 31 | def load_ckpt(args, depth_model, shift_model, focal_model): 32 | """ 33 | Load checkpoint.
34 | """ 35 | if os.path.isfile(args.load_ckpt): 36 | print("loading checkpoint %s" % args.load_ckpt) 37 | checkpoint = torch.load(args.load_ckpt) 38 | if shift_model is not None: 39 | shift_model.load_state_dict(strip_prefix_if_present(checkpoint['shift_model'], 'module.'), 40 | strict=True) 41 | if focal_model is not None: 42 | focal_model.load_state_dict(strip_prefix_if_present(checkpoint['focal_model'], 'module.'), 43 | strict=True) 44 | depth_model.load_state_dict(strip_prefix_if_present(checkpoint['depth_model'], "module."), 45 | strict=True) 46 | del checkpoint 47 | if torch.cuda.is_available(): 48 | torch.cuda.empty_cache() 49 | 50 | 51 | def strip_prefix_if_present(state_dict, prefix): 52 | keys = sorted(state_dict.keys()) 53 | if not all(key.startswith(prefix) for key in keys): 54 | return state_dict 55 | stripped_state_dict = OrderedDict() 56 | for key, value in state_dict.items(): 57 | stripped_state_dict[key.replace(prefix, "")] = value 58 | return stripped_state_dict -------------------------------------------------------------------------------- /src/controlnet_aux/leres/pix2pix/LICENSE: -------------------------------------------------------------------------------- 1 | https://github.com/compphoto/BoostingMonocularDepth 2 | 3 | Copyright 2021, Seyed Mahdi Hosseini Miangoleh, Sebastian Dille, Computational Photography Laboratory. All rights reserved. 4 | 5 | This software is for academic use only. A redistribution of this 6 | software, with or without modifications, has to be for academic 7 | use only, while giving the appropriate credit to the original 8 | authors of the software. The methods implemented as a part of 9 | this software may be covered under patents or patent applications. 10 | 11 | THIS SOFTWARE IS PROVIDED BY THE AUTHOR ''AS IS'' AND ANY EXPRESS OR IMPLIED 12 | WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND 13 | FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR 14 | CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 15 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 16 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 17 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 18 | NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF 19 | ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /src/controlnet_aux/leres/pix2pix/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/controlnet_aux/b1aac318b15c4e65a21bc516dc236ad4bf593d46/src/controlnet_aux/leres/pix2pix/__init__.py -------------------------------------------------------------------------------- /src/controlnet_aux/leres/pix2pix/models/__init__.py: -------------------------------------------------------------------------------- 1 | """This package contains modules related to objective functions, optimizations, and network architectures. 2 | 3 | To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel. 4 | You need to implement the following five functions: 5 | -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt). 6 | -- : unpack data from dataset and apply preprocessing. 
7 | -- <forward>: produce intermediate results. 8 | -- <optimize_parameters>: calculate loss, gradients, and update network weights. 9 | -- <modify_commandline_options>: (optionally) add model-specific options and set default options. 10 | 11 | In the function <__init__>, you need to define four lists: 12 | -- self.loss_names (str list): specify the training losses that you want to plot and save. 13 | -- self.model_names (str list): define networks used in our training. 14 | -- self.visual_names (str list): specify the images that you want to display and save. 15 | -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for a usage example. 16 | 17 | Now you can use the model class by specifying flag '--model dummy'. 18 | See our template model class 'template_model.py' for more details. 19 | """ 20 | 21 | import importlib 22 | from .base_model import BaseModel 23 | 24 | 25 | def find_model_using_name(model_name): 26 | """Import the module "models/[model_name]_model.py". 27 | 28 | In the file, the class called DatasetNameModel() will 29 | be instantiated. It has to be a subclass of BaseModel, 30 | and it is case-insensitive. 31 | """ 32 | model_filename = "." + model_name + "_model" 33 | 34 | # Import model_filename, for example ".pix2pix4depth_model", 35 | # under __package__=="controlnet_aux.leres.pix2pix.models" 36 | # __package__ resolves to the package namespace above this file 37 | modellib = importlib.import_module(model_filename, package=__package__) 38 | model = None 39 | target_model_name = model_name.replace('_', '') + 'model' 40 | for name, cls in modellib.__dict__.items(): 41 | if name.lower() == target_model_name.lower() \ 42 | and issubclass(cls, BaseModel): 43 | model = cls 44 | 45 | if model is None: 46 | print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name)) 47 | exit(0) 48 | 49 | return model 50 | 51 | 52 | def get_option_setter(model_name): 53 | """Return the static method <modify_commandline_options> of the model class.""" 54 | model_class = find_model_using_name(model_name) 55 | return model_class.modify_commandline_options 56 | 57 | 58 | def create_model(opt): 59 | """Create a model given the option. 60 | 61 | This function wraps the class CustomDatasetDataLoader.
62 | This is the main interface between this package and 'train.py'/'test.py' 63 | 64 | Example: 65 | >>> from models import create_model 66 | >>> model = create_model(opt) 67 | """ 68 | model = find_model_using_name(opt.model) 69 | instance = model(opt) 70 | print("model [%s] was created" % type(instance).__name__) 71 | return instance 72 | -------------------------------------------------------------------------------- /src/controlnet_aux/leres/pix2pix/models/base_model_hg.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | class BaseModelHG(): 5 | def name(self): 6 | return 'BaseModel' 7 | 8 | def initialize(self, opt): 9 | self.opt = opt 10 | self.gpu_ids = opt.gpu_ids 11 | self.isTrain = opt.isTrain 12 | self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor 13 | self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) 14 | 15 | def set_input(self, input): 16 | self.input = input 17 | 18 | def forward(self): 19 | pass 20 | 21 | # used in test time, no backprop 22 | def test(self): 23 | pass 24 | 25 | def get_image_paths(self): 26 | pass 27 | 28 | def optimize_parameters(self): 29 | pass 30 | 31 | def get_current_visuals(self): 32 | return self.input 33 | 34 | def get_current_errors(self): 35 | return {} 36 | 37 | def save(self, label): 38 | pass 39 | 40 | # helper saving function that can be used by subclasses 41 | def save_network(self, network, network_label, epoch_label, gpu_ids): 42 | save_filename = '_%s_net_%s.pth' % (epoch_label, network_label) 43 | save_path = os.path.join(self.save_dir, save_filename) 44 | torch.save(network.cpu().state_dict(), save_path) 45 | if len(gpu_ids) and torch.cuda.is_available(): 46 | network.cuda(device_id=gpu_ids[0]) 47 | 48 | # helper loading function that can be used by subclasses 49 | def load_network(self, network, network_label, epoch_label): 50 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label) 51 | save_path = os.path.join(self.save_dir, save_filename) 52 | print(save_path) 53 | model = torch.load(save_path) 54 | return model 55 | # network.load_state_dict(torch.load(save_path)) 56 | 57 | def update_learning_rate(): 58 | pass 59 | -------------------------------------------------------------------------------- /src/controlnet_aux/leres/pix2pix/options/__init__.py: -------------------------------------------------------------------------------- 1 | """This package includes the option modules: training options, test options, and basic options (used in both training and test).""" 2 | -------------------------------------------------------------------------------- /src/controlnet_aux/leres/pix2pix/options/test_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | 3 | 4 | class TestOptions(BaseOptions): 5 | """This class includes test options. 6 | 7 | It also includes shared options defined in BaseOptions. 8 | """ 9 | 10 | def initialize(self, parser): 11 | parser = BaseOptions.initialize(self, parser) # define shared options 12 | parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images') 13 | parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc') 14 | # Dropout and Batchnorm have different behavior during training and test.
15 | parser.add_argument('--eval', action='store_true', help='use eval mode during test time.') 16 | parser.add_argument('--num_test', type=int, default=50, help='how many test images to run') 17 | # rewrite default values 18 | parser.set_defaults(model='pix2pix4depth') 19 | # To avoid cropping, the load_size should be the same as crop_size 20 | parser.set_defaults(load_size=parser.get_default('crop_size')) 21 | self.isTrain = False 22 | return parser 23 | -------------------------------------------------------------------------------- /src/controlnet_aux/leres/pix2pix/util/__init__.py: -------------------------------------------------------------------------------- 1 | """This package includes a miscellaneous collection of useful helper functions.""" 2 | -------------------------------------------------------------------------------- /src/controlnet_aux/leres/pix2pix/util/util.py: -------------------------------------------------------------------------------- 1 | """This module contains simple helper functions """ 2 | from __future__ import print_function 3 | import torch 4 | import numpy as np 5 | from PIL import Image 6 | import os 7 | 8 | 9 | def tensor2im(input_image, imtype=np.uint16): 10 | """Converts a Tensor array into a numpy image array. 11 | 12 | Parameters: 13 | input_image (tensor) -- the input image tensor array 14 | imtype (type) -- the desired type of the converted numpy array 15 | """ 16 | if not isinstance(input_image, np.ndarray): 17 | if isinstance(input_image, torch.Tensor): # get the data from a variable 18 | image_tensor = input_image.data 19 | else: 20 | return input_image 21 | image_numpy = torch.squeeze(image_tensor).cpu().numpy() # convert it into a numpy array 22 | image_numpy = (image_numpy + 1) / 2.0 * (2**16-1) # 23 | else: # if it is a numpy array, do nothing 24 | image_numpy = input_image 25 | return image_numpy.astype(imtype) 26 | 27 | 28 | def diagnose_network(net, name='network'): 29 | """Calculate and print the mean of average absolute gradients 30 | 31 | Parameters: 32 | net (torch network) -- Torch network 33 | name (str) -- the name of the network 34 | """ 35 | mean = 0.0 36 | count = 0 37 | for param in net.parameters(): 38 | if param.grad is not None: 39 | mean += torch.mean(torch.abs(param.grad.data)) 40 | count += 1 41 | if count > 0: 42 | mean = mean / count 43 | print(name) 44 | print(mean) 45 | 46 | 47 | def save_image(image_numpy, image_path, aspect_ratio=1.0): 48 | """Save a numpy image to the disk 49 | 50 | Parameters: 51 | image_numpy (numpy array) -- input numpy array 52 | image_path (str) -- the path of the image 53 | """ 54 | image_pil = Image.fromarray(image_numpy) 55 | 56 | image_pil = image_pil.convert('I;16') 57 | 58 | # image_pil = Image.fromarray(image_numpy) 59 | # h, w, _ = image_numpy.shape 60 | # 61 | # if aspect_ratio > 1.0: 62 | # image_pil = image_pil.resize((h, int(w * aspect_ratio)), Image.BICUBIC) 63 | # if aspect_ratio < 1.0: 64 | # image_pil = image_pil.resize((int(h / aspect_ratio), w), Image.BICUBIC) 65 | 66 | image_pil.save(image_path) 67 | 68 | 69 | def print_numpy(x, val=True, shp=False): 70 | """Print the mean, min, max, median, std, and size of a numpy array 71 | 72 | Parameters: 73 | val (bool) -- if print the values of the numpy array 74 | shp (bool) -- if print the shape of the numpy array 75 | """ 76 | x = x.astype(np.float64) 77 | if shp: 78 | print('shape,', x.shape) 79 | if val: 80 | x = x.flatten() 81 | print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % ( 82 | np.mean(x),
np.min(x), np.max(x), np.median(x), np.std(x))) 83 | 84 | 85 | def mkdirs(paths): 86 | """create empty directories if they don't exist 87 | 88 | Parameters: 89 | paths (str list) -- a list of directory paths 90 | """ 91 | if isinstance(paths, list) and not isinstance(paths, str): 92 | for path in paths: 93 | mkdir(path) 94 | else: 95 | mkdir(paths) 96 | 97 | 98 | def mkdir(path): 99 | """create a single empty directory if it didn't exist 100 | 101 | Parameters: 102 | path (str) -- a single directory path 103 | """ 104 | if not os.path.exists(path): 105 | os.makedirs(path) 106 | -------------------------------------------------------------------------------- /src/controlnet_aux/lineart/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Caroline Chan 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /src/controlnet_aux/lineart_anime/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Caroline Chan 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
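The tensor2im helper in pix2pix/util/util.py above rescales the generator output from [-1, 1] to the full 16-bit range that save_image then writes as an I;16 image. A small worked sketch of that mapping on a hypothetical 2x2 tensor:

```python
import numpy as np
import torch

# Hypothetical 1x1x2x2 generator output in [-1, 1]
t = torch.tensor([[[[-1.0, 0.0], [0.5, 1.0]]]])

# Same arithmetic as tensor2im: squeeze, shift to [0, 1], scale to the uint16 range
img = (torch.squeeze(t).cpu().numpy() + 1) / 2.0 * (2**16 - 1)
print(img.astype(np.uint16))  # maps -1, 0, 0.5, 1 to 0, 32767, 49151, 65535
```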
-------------------------------------------------------------------------------- /src/controlnet_aux/lineart_standard/__init__.py: -------------------------------------------------------------------------------- 1 | # Code based based from the repository comfyui_controlnet_aux: 2 | # https://github.com/Fannovel16/comfyui_controlnet_aux/blob/main/src/controlnet_aux/lineart_standard/__init__.py 3 | import cv2 4 | import numpy as np 5 | from PIL import Image 6 | 7 | from ..util import HWC3, resize_image 8 | 9 | 10 | class LineartStandardDetector: 11 | def __call__( 12 | self, 13 | input_image=None, 14 | guassian_sigma=6.0, 15 | intensity_threshold=8, 16 | detect_resolution=512, 17 | output_type="pil", 18 | ): 19 | if not isinstance(input_image, np.ndarray): 20 | input_image = np.array(input_image, dtype=np.uint8) 21 | else: 22 | output_type = output_type or "np" 23 | 24 | original_height, original_width, _ = input_image.shape 25 | 26 | input_image = HWC3(input_image) 27 | input_image = resize_image(input_image, detect_resolution) 28 | 29 | x = input_image.astype(np.float32) 30 | g = cv2.GaussianBlur(x, (0, 0), guassian_sigma) 31 | intensity = np.min(g - x, axis=2).clip(0, 255) 32 | intensity /= max(16, np.median(intensity[intensity > intensity_threshold])) 33 | intensity *= 127 34 | detected_map = intensity.clip(0, 255).astype(np.uint8) 35 | 36 | detected_map = HWC3(detected_map) 37 | 38 | detected_map = cv2.resize( 39 | detected_map, 40 | (original_width, original_height), 41 | interpolation=cv2.INTER_CUBIC, 42 | ) 43 | 44 | if output_type == "pil": 45 | detected_map = Image.fromarray(detected_map) 46 | 47 | return detected_map 48 | -------------------------------------------------------------------------------- /src/controlnet_aux/mediapipe_face/__init__.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from typing import Union 3 | 4 | import cv2 5 | import numpy as np 6 | from PIL import Image 7 | 8 | from ..util import HWC3, resize_image 9 | from .mediapipe_face_common import generate_annotation 10 | 11 | 12 | class MediapipeFaceDetector: 13 | def __call__(self, 14 | input_image: Union[np.ndarray, Image.Image] = None, 15 | max_faces: int = 1, 16 | min_confidence: float = 0.5, 17 | output_type: str = "pil", 18 | detect_resolution: int = 512, 19 | image_resolution: int = 512, 20 | **kwargs): 21 | 22 | if "image" in kwargs: 23 | warnings.warn("image is deprecated, please use `input_image=...` instead.", DeprecationWarning) 24 | input_image = kwargs.pop("image") 25 | if input_image is None: 26 | raise ValueError("input_image must be defined.") 27 | 28 | if "return_pil" in kwargs: 29 | warnings.warn("return_pil is deprecated. 
Use output_type instead.", DeprecationWarning) 30 | output_type = "pil" if kwargs["return_pil"] else "np" 31 | if type(output_type) is bool: 32 | warnings.warn("Passing `True` or `False` to `output_type` is deprecated and will raise an error in future versions") 33 | if output_type: 34 | output_type = "pil" 35 | 36 | if not isinstance(input_image, np.ndarray): 37 | input_image = np.array(input_image, dtype=np.uint8) 38 | 39 | input_image = HWC3(input_image) 40 | input_image = resize_image(input_image, detect_resolution) 41 | 42 | detected_map = generate_annotation(input_image, max_faces, min_confidence) 43 | detected_map = HWC3(detected_map) 44 | 45 | img = resize_image(input_image, image_resolution) 46 | H, W, C = img.shape 47 | 48 | detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR) 49 | 50 | if output_type == "pil": 51 | detected_map = Image.fromarray(detected_map) 52 | 53 | return detected_map 54 | -------------------------------------------------------------------------------- /src/controlnet_aux/midas/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Intel ISL (Intel Intelligent Systems Lab) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
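Unlike most detectors in this package, the LineartStandardDetector shown earlier needs no pretrained weights; it is a pure OpenCV/NumPy pipeline (Gaussian blur difference followed by intensity normalization). A minimal usage sketch with a hypothetical input path:

```python
from PIL import Image
from controlnet_aux import LineartStandardDetector

lineart = LineartStandardDetector()  # no from_pretrained step required

image = Image.open("portrait.png").convert("RGB")  # hypothetical input image
# Parameter names follow the source above (including the `guassian_sigma` spelling).
lines = lineart(
    image,
    guassian_sigma=6.0,
    intensity_threshold=8,
    detect_resolution=512,
    output_type="pil",
)
lines.save("lineart_standard.png")
```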
22 | -------------------------------------------------------------------------------- /src/controlnet_aux/midas/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import cv2 4 | import numpy as np 5 | import torch 6 | from einops import rearrange 7 | from huggingface_hub import hf_hub_download 8 | from PIL import Image 9 | 10 | from ..util import HWC3, resize_image 11 | from .api import MiDaSInference 12 | 13 | 14 | class MidasDetector: 15 | def __init__(self, model): 16 | self.model = model 17 | 18 | @classmethod 19 | def from_pretrained(cls, pretrained_model_or_path, model_type="dpt_hybrid", filename=None, cache_dir=None, local_files_only=False): 20 | if pretrained_model_or_path == "lllyasviel/ControlNet": 21 | filename = filename or "annotator/ckpts/dpt_hybrid-midas-501f0c75.pt" 22 | else: 23 | filename = filename or "dpt_hybrid-midas-501f0c75.pt" 24 | 25 | if os.path.isdir(pretrained_model_or_path): 26 | model_path = os.path.join(pretrained_model_or_path, filename) 27 | else: 28 | model_path = hf_hub_download(pretrained_model_or_path, filename, cache_dir=cache_dir, local_files_only=local_files_only) 29 | 30 | model = MiDaSInference(model_type=model_type, model_path=model_path) 31 | 32 | return cls(model) 33 | 34 | 35 | def to(self, device): 36 | self.model.to(device) 37 | return self 38 | 39 | def __call__(self, input_image, a=np.pi * 2.0, bg_th=0.1, depth_and_normal=False, detect_resolution=512, image_resolution=512, output_type=None): 40 | device = next(iter(self.model.parameters())).device 41 | if not isinstance(input_image, np.ndarray): 42 | input_image = np.array(input_image, dtype=np.uint8) 43 | output_type = output_type or "pil" 44 | else: 45 | output_type = output_type or "np" 46 | 47 | input_image = HWC3(input_image) 48 | input_image = resize_image(input_image, detect_resolution) 49 | 50 | assert input_image.ndim == 3 51 | image_depth = input_image 52 | with torch.no_grad(): 53 | image_depth = torch.from_numpy(image_depth).float() 54 | image_depth = image_depth.to(device) 55 | image_depth = image_depth / 127.5 - 1.0 56 | image_depth = rearrange(image_depth, 'h w c -> 1 c h w') 57 | depth = self.model(image_depth)[0] 58 | 59 | depth_pt = depth.clone() 60 | depth_pt -= torch.min(depth_pt) 61 | depth_pt /= torch.max(depth_pt) 62 | depth_pt = depth_pt.cpu().numpy() 63 | depth_image = (depth_pt * 255.0).clip(0, 255).astype(np.uint8) 64 | 65 | if depth_and_normal: 66 | depth_np = depth.cpu().numpy() 67 | x = cv2.Sobel(depth_np, cv2.CV_32F, 1, 0, ksize=3) 68 | y = cv2.Sobel(depth_np, cv2.CV_32F, 0, 1, ksize=3) 69 | z = np.ones_like(x) * a 70 | x[depth_pt < bg_th] = 0 71 | y[depth_pt < bg_th] = 0 72 | normal = np.stack([x, y, z], axis=2) 73 | normal /= np.sum(normal ** 2.0, axis=2, keepdims=True) ** 0.5 74 | normal_image = (normal * 127.5 + 127.5).clip(0, 255).astype(np.uint8)[:, :, ::-1] 75 | 76 | depth_image = HWC3(depth_image) 77 | if depth_and_normal: 78 | normal_image = HWC3(normal_image) 79 | 80 | img = resize_image(input_image, image_resolution) 81 | H, W, C = img.shape 82 | 83 | depth_image = cv2.resize(depth_image, (W, H), interpolation=cv2.INTER_LINEAR) 84 | if depth_and_normal: 85 | normal_image = cv2.resize(normal_image, (W, H), interpolation=cv2.INTER_LINEAR) 86 | 87 | if output_type == "pil": 88 | depth_image = Image.fromarray(depth_image) 89 | if depth_and_normal: 90 | normal_image = Image.fromarray(normal_image) 91 | 92 | if depth_and_normal: 93 | return depth_image, normal_image 94 | else: 95 | return 
depth_image 96 | -------------------------------------------------------------------------------- /src/controlnet_aux/midas/midas/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/controlnet_aux/b1aac318b15c4e65a21bc516dc236ad4bf593d46/src/controlnet_aux/midas/midas/__init__.py -------------------------------------------------------------------------------- /src/controlnet_aux/midas/midas/base_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class BaseModel(torch.nn.Module): 5 | def load(self, path): 6 | """Load model from file. 7 | 8 | Args: 9 | path (str): file path 10 | """ 11 | parameters = torch.load(path, map_location=torch.device('cpu')) 12 | 13 | if "optimizer" in parameters: 14 | parameters = parameters["model"] 15 | 16 | self.load_state_dict(parameters) 17 | -------------------------------------------------------------------------------- /src/controlnet_aux/midas/midas/dpt_depth.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .base_model import BaseModel 6 | from .blocks import ( 7 | FeatureFusionBlock, 8 | FeatureFusionBlock_custom, 9 | Interpolate, 10 | _make_encoder, 11 | forward_vit, 12 | ) 13 | 14 | 15 | def _make_fusion_block(features, use_bn): 16 | return FeatureFusionBlock_custom( 17 | features, 18 | nn.ReLU(False), 19 | deconv=False, 20 | bn=use_bn, 21 | expand=False, 22 | align_corners=True, 23 | ) 24 | 25 | 26 | class DPT(BaseModel): 27 | def __init__( 28 | self, 29 | head, 30 | features=256, 31 | backbone="vitb_rn50_384", 32 | readout="project", 33 | channels_last=False, 34 | use_bn=False, 35 | ): 36 | 37 | super(DPT, self).__init__() 38 | 39 | self.channels_last = channels_last 40 | 41 | hooks = { 42 | "vitb_rn50_384": [0, 1, 8, 11], 43 | "vitb16_384": [2, 5, 8, 11], 44 | "vitl16_384": [5, 11, 17, 23], 45 | } 46 | 47 | # Instantiate backbone and reassemble blocks 48 | self.pretrained, self.scratch = _make_encoder( 49 | backbone, 50 | features, 51 | False, # Set to true of you want to train from scratch, uses ImageNet weights 52 | groups=1, 53 | expand=False, 54 | exportable=False, 55 | hooks=hooks[backbone], 56 | use_readout=readout, 57 | ) 58 | 59 | self.scratch.refinenet1 = _make_fusion_block(features, use_bn) 60 | self.scratch.refinenet2 = _make_fusion_block(features, use_bn) 61 | self.scratch.refinenet3 = _make_fusion_block(features, use_bn) 62 | self.scratch.refinenet4 = _make_fusion_block(features, use_bn) 63 | 64 | self.scratch.output_conv = head 65 | 66 | 67 | def forward(self, x): 68 | if self.channels_last == True: 69 | x.contiguous(memory_format=torch.channels_last) 70 | 71 | layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x) 72 | 73 | layer_1_rn = self.scratch.layer1_rn(layer_1) 74 | layer_2_rn = self.scratch.layer2_rn(layer_2) 75 | layer_3_rn = self.scratch.layer3_rn(layer_3) 76 | layer_4_rn = self.scratch.layer4_rn(layer_4) 77 | 78 | path_4 = self.scratch.refinenet4(layer_4_rn) 79 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 80 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 81 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 82 | 83 | out = self.scratch.output_conv(path_1) 84 | 85 | return out 86 | 87 | 88 | class DPTDepthModel(DPT): 89 | def __init__(self, path=None, non_negative=True, **kwargs): 90 | features = 
kwargs["features"] if "features" in kwargs else 256 91 | 92 | head = nn.Sequential( 93 | nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1), 94 | Interpolate(scale_factor=2, mode="bilinear", align_corners=True), 95 | nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), 96 | nn.ReLU(True), 97 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 98 | nn.ReLU(True) if non_negative else nn.Identity(), 99 | nn.Identity(), 100 | ) 101 | 102 | super().__init__(head, **kwargs) 103 | 104 | if path is not None: 105 | self.load(path) 106 | 107 | def forward(self, x): 108 | return super().forward(x).squeeze(dim=1) 109 | 110 | -------------------------------------------------------------------------------- /src/controlnet_aux/midas/midas/midas_net.py: -------------------------------------------------------------------------------- 1 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets. 2 | This file contains code that is adapted from 3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | 8 | from .base_model import BaseModel 9 | from .blocks import FeatureFusionBlock, Interpolate, _make_encoder 10 | 11 | 12 | class MidasNet(BaseModel): 13 | """Network for monocular depth estimation. 14 | """ 15 | 16 | def __init__(self, path=None, features=256, non_negative=True): 17 | """Init. 18 | 19 | Args: 20 | path (str, optional): Path to saved model. Defaults to None. 21 | features (int, optional): Number of features. Defaults to 256. 22 | backbone (str, optional): Backbone network for encoder. Defaults to resnet50 23 | """ 24 | print("Loading weights: ", path) 25 | 26 | super(MidasNet, self).__init__() 27 | 28 | use_pretrained = False if path is None else True 29 | 30 | self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained) 31 | 32 | self.scratch.refinenet4 = FeatureFusionBlock(features) 33 | self.scratch.refinenet3 = FeatureFusionBlock(features) 34 | self.scratch.refinenet2 = FeatureFusionBlock(features) 35 | self.scratch.refinenet1 = FeatureFusionBlock(features) 36 | 37 | self.scratch.output_conv = nn.Sequential( 38 | nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), 39 | Interpolate(scale_factor=2, mode="bilinear"), 40 | nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), 41 | nn.ReLU(True), 42 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 43 | nn.ReLU(True) if non_negative else nn.Identity(), 44 | ) 45 | 46 | if path: 47 | self.load(path) 48 | 49 | def forward(self, x): 50 | """Forward pass. 
51 | 52 | Args: 53 | x (tensor): input data (image) 54 | 55 | Returns: 56 | tensor: depth 57 | """ 58 | 59 | layer_1 = self.pretrained.layer1(x) 60 | layer_2 = self.pretrained.layer2(layer_1) 61 | layer_3 = self.pretrained.layer3(layer_2) 62 | layer_4 = self.pretrained.layer4(layer_3) 63 | 64 | layer_1_rn = self.scratch.layer1_rn(layer_1) 65 | layer_2_rn = self.scratch.layer2_rn(layer_2) 66 | layer_3_rn = self.scratch.layer3_rn(layer_3) 67 | layer_4_rn = self.scratch.layer4_rn(layer_4) 68 | 69 | path_4 = self.scratch.refinenet4(layer_4_rn) 70 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 71 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 72 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 73 | 74 | out = self.scratch.output_conv(path_1) 75 | 76 | return torch.squeeze(out, dim=1) 77 | -------------------------------------------------------------------------------- /src/controlnet_aux/mlsd/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import warnings 3 | 4 | import cv2 5 | import numpy as np 6 | import torch 7 | from huggingface_hub import hf_hub_download 8 | from PIL import Image 9 | 10 | from ..util import HWC3, resize_image 11 | from .models.mbv2_mlsd_large import MobileV2_MLSD_Large 12 | from .utils import pred_lines 13 | 14 | 15 | class MLSDdetector: 16 | def __init__(self, model): 17 | self.model = model 18 | 19 | @classmethod 20 | def from_pretrained(cls, pretrained_model_or_path, filename=None, cache_dir=None, local_files_only=False): 21 | if pretrained_model_or_path == "lllyasviel/ControlNet": 22 | filename = filename or "annotator/ckpts/mlsd_large_512_fp32.pth" 23 | else: 24 | filename = filename or "mlsd_large_512_fp32.pth" 25 | 26 | if os.path.isdir(pretrained_model_or_path): 27 | model_path = os.path.join(pretrained_model_or_path, filename) 28 | else: 29 | model_path = hf_hub_download(pretrained_model_or_path, filename, cache_dir=cache_dir, local_files_only=local_files_only) 30 | 31 | model = MobileV2_MLSD_Large() 32 | model.load_state_dict(torch.load(model_path), strict=True) 33 | model.eval() 34 | 35 | return cls(model) 36 | 37 | def to(self, device): 38 | self.model.to(device) 39 | return self 40 | 41 | def __call__(self, input_image, thr_v=0.1, thr_d=0.1, detect_resolution=512, image_resolution=512, output_type="pil", **kwargs): 42 | if "return_pil" in kwargs: 43 | warnings.warn("return_pil is deprecated. 
Use output_type instead.", DeprecationWarning) 44 | output_type = "pil" if kwargs["return_pil"] else "np" 45 | if type(output_type) is bool: 46 | warnings.warn("Passing `True` or `False` to `output_type` is deprecated and will raise an error in future versions") 47 | if output_type: 48 | output_type = "pil" 49 | 50 | if not isinstance(input_image, np.ndarray): 51 | input_image = np.array(input_image, dtype=np.uint8) 52 | 53 | input_image = HWC3(input_image) 54 | input_image = resize_image(input_image, detect_resolution) 55 | 56 | assert input_image.ndim == 3 57 | img = input_image 58 | img_output = np.zeros_like(img) 59 | try: 60 | with torch.no_grad(): 61 | lines = pred_lines(img, self.model, [img.shape[0], img.shape[1]], thr_v, thr_d) 62 | for line in lines: 63 | x_start, y_start, x_end, y_end = [int(val) for val in line] 64 | cv2.line(img_output, (x_start, y_start), (x_end, y_end), [255, 255, 255], 1) 65 | except Exception as e: 66 | pass 67 | 68 | detected_map = img_output[:, :, 0] 69 | detected_map = HWC3(detected_map) 70 | 71 | img = resize_image(input_image, image_resolution) 72 | H, W, C = img.shape 73 | 74 | detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR) 75 | 76 | if output_type == "pil": 77 | detected_map = Image.fromarray(detected_map) 78 | 79 | return detected_map 80 | -------------------------------------------------------------------------------- /src/controlnet_aux/mlsd/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/controlnet_aux/b1aac318b15c4e65a21bc516dc236ad4bf593d46/src/controlnet_aux/mlsd/models/__init__.py -------------------------------------------------------------------------------- /src/controlnet_aux/normalbae/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Caroline Chan 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
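The detector modules gathered in this section (MidasDetector and MLSDdetector above, NormalBaeDetector below) expose the same three-step interface: from_pretrained() resolves a checkpoint locally or via hf_hub_download, to() moves the underlying model to a device, and __call__ turns a PIL image or numpy array into a detected map. The snippet below is a minimal usage sketch, assuming the top-level controlnet_aux __init__ re-exports these classes and that the commonly used "lllyasviel/Annotators" checkpoint repository is reachable; the repo id and the "cuda" device are illustrative assumptions, not values fixed by the code shown here.

# Minimal usage sketch (illustrative; see the assumptions noted above).
from PIL import Image
from controlnet_aux import MLSDdetector, NormalBaeDetector

image = Image.open("input.png")  # any RGB image

# M-LSD line detection; thresholds mirror the defaults of MLSDdetector.__call__
mlsd = MLSDdetector.from_pretrained("lllyasviel/Annotators").to("cuda")
line_map = mlsd(image, thr_v=0.1, thr_d=0.1, output_type="pil")

# NormalBae surface-normal estimation; from_pretrained falls back to the "scannet.pt" filename
normal_bae = NormalBaeDetector.from_pretrained("lllyasviel/Annotators").to("cuda")
normal_map = normal_bae(image, detect_resolution=512, image_resolution=512, output_type="pil")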
-------------------------------------------------------------------------------- /src/controlnet_aux/normalbae/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import types 3 | import warnings 4 | 5 | import cv2 6 | import numpy as np 7 | import torch 8 | import torchvision.transforms as transforms 9 | from einops import rearrange 10 | from huggingface_hub import hf_hub_download 11 | from PIL import Image 12 | 13 | from ..util import HWC3, resize_image 14 | from .nets.NNET import NNET 15 | 16 | 17 | # load model 18 | def load_checkpoint(fpath, model): 19 | ckpt = torch.load(fpath, map_location='cpu')['model'] 20 | 21 | load_dict = {} 22 | for k, v in ckpt.items(): 23 | if k.startswith('module.'): 24 | k_ = k.replace('module.', '') 25 | load_dict[k_] = v 26 | else: 27 | load_dict[k] = v 28 | 29 | model.load_state_dict(load_dict) 30 | return model 31 | 32 | class NormalBaeDetector: 33 | def __init__(self, model): 34 | self.model = model 35 | self.norm = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 36 | 37 | @classmethod 38 | def from_pretrained(cls, pretrained_model_or_path, filename=None, cache_dir=None, local_files_only=False): 39 | filename = filename or "scannet.pt" 40 | 41 | if os.path.isdir(pretrained_model_or_path): 42 | model_path = os.path.join(pretrained_model_or_path, filename) 43 | else: 44 | model_path = hf_hub_download(pretrained_model_or_path, filename, cache_dir=cache_dir, local_files_only=local_files_only) 45 | 46 | args = types.SimpleNamespace() 47 | args.mode = 'client' 48 | args.architecture = 'BN' 49 | args.pretrained = 'scannet' 50 | args.sampling_ratio = 0.4 51 | args.importance_ratio = 0.7 52 | model = NNET(args) 53 | model = load_checkpoint(model_path, model) 54 | model.eval() 55 | 56 | return cls(model) 57 | 58 | def to(self, device): 59 | self.model.to(device) 60 | return self 61 | 62 | 63 | def __call__(self, input_image, detect_resolution=512, image_resolution=512, output_type="pil", **kwargs): 64 | if "return_pil" in kwargs: 65 | warnings.warn("return_pil is deprecated. 
Use output_type instead.", DeprecationWarning) 66 | output_type = "pil" if kwargs["return_pil"] else "np" 67 | if type(output_type) is bool: 68 | warnings.warn("Passing `True` or `False` to `output_type` is deprecated and will raise an error in future versions") 69 | if output_type: 70 | output_type = "pil" 71 | 72 | device = next(iter(self.model.parameters())).device 73 | if not isinstance(input_image, np.ndarray): 74 | input_image = np.array(input_image, dtype=np.uint8) 75 | 76 | input_image = HWC3(input_image) 77 | input_image = resize_image(input_image, detect_resolution) 78 | 79 | assert input_image.ndim == 3 80 | image_normal = input_image 81 | with torch.no_grad(): 82 | image_normal = torch.from_numpy(image_normal).float().to(device) 83 | image_normal = image_normal / 255.0 84 | image_normal = rearrange(image_normal, 'h w c -> 1 c h w') 85 | image_normal = self.norm(image_normal) 86 | 87 | normal = self.model(image_normal) 88 | normal = normal[0][-1][:, :3] 89 | # d = torch.sum(normal ** 2.0, dim=1, keepdim=True) ** 0.5 90 | # d = torch.maximum(d, torch.ones_like(d) * 1e-5) 91 | # normal /= d 92 | normal = ((normal + 1) * 0.5).clip(0, 1) 93 | 94 | normal = rearrange(normal[0], 'c h w -> h w c').cpu().numpy() 95 | normal_image = (normal * 255.0).clip(0, 255).astype(np.uint8) 96 | 97 | detected_map = normal_image 98 | detected_map = HWC3(detected_map) 99 | 100 | img = resize_image(input_image, image_resolution) 101 | H, W, C = img.shape 102 | 103 | detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR) 104 | 105 | if output_type == "pil": 106 | detected_map = Image.fromarray(detected_map) 107 | 108 | return detected_map 109 | -------------------------------------------------------------------------------- /src/controlnet_aux/normalbae/nets/NNET.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .submodules.encoder import Encoder 6 | from .submodules.decoder import Decoder 7 | 8 | 9 | class NNET(nn.Module): 10 | def __init__(self, args): 11 | super(NNET, self).__init__() 12 | self.encoder = Encoder() 13 | self.decoder = Decoder(args) 14 | 15 | def get_1x_lr_params(self): # lr/10 learning rate 16 | return self.encoder.parameters() 17 | 18 | def get_10x_lr_params(self): # lr learning rate 19 | return self.decoder.parameters() 20 | 21 | def forward(self, img, **kwargs): 22 | return self.decoder(self.encoder(img), **kwargs) -------------------------------------------------------------------------------- /src/controlnet_aux/normalbae/nets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/controlnet_aux/b1aac318b15c4e65a21bc516dc236ad4bf593d46/src/controlnet_aux/normalbae/nets/__init__.py -------------------------------------------------------------------------------- /src/controlnet_aux/normalbae/nets/baseline.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .submodules.submodules import UpSampleBN, norm_normalize 6 | 7 | 8 | # This is the baseline encoder-decoder we used in the ablation study 9 | class NNET(nn.Module): 10 | def __init__(self, args=None): 11 | super(NNET, self).__init__() 12 | self.encoder = Encoder() 13 | self.decoder = Decoder(num_classes=4) 14 | 15 | def forward(self, x, **kwargs): 16 | out = 
self.decoder(self.encoder(x), **kwargs) 17 | 18 | # Bilinearly upsample the output to match the input resolution 19 | up_out = F.interpolate(out, size=[x.size(2), x.size(3)], mode='bilinear', align_corners=False) 20 | 21 | # L2-normalize the first three channels / ensure positive value for concentration parameters (kappa) 22 | up_out = norm_normalize(up_out) 23 | return up_out 24 | 25 | def get_1x_lr_params(self): # lr/10 learning rate 26 | return self.encoder.parameters() 27 | 28 | def get_10x_lr_params(self): # lr learning rate 29 | modules = [self.decoder] 30 | for m in modules: 31 | yield from m.parameters() 32 | 33 | 34 | # Encoder 35 | class Encoder(nn.Module): 36 | def __init__(self): 37 | super(Encoder, self).__init__() 38 | 39 | basemodel_name = 'tf_efficientnet_b5_ap' 40 | basemodel = torch.hub.load('rwightman/gen-efficientnet-pytorch', basemodel_name, pretrained=True) 41 | 42 | # Remove last layer 43 | basemodel.global_pool = nn.Identity() 44 | basemodel.classifier = nn.Identity() 45 | 46 | self.original_model = basemodel 47 | 48 | def forward(self, x): 49 | features = [x] 50 | for k, v in self.original_model._modules.items(): 51 | if (k == 'blocks'): 52 | for ki, vi in v._modules.items(): 53 | features.append(vi(features[-1])) 54 | else: 55 | features.append(v(features[-1])) 56 | return features 57 | 58 | 59 | # Decoder (no pixel-wise MLP, no uncertainty-guided sampling) 60 | class Decoder(nn.Module): 61 | def __init__(self, num_classes=4): 62 | super(Decoder, self).__init__() 63 | self.conv2 = nn.Conv2d(2048, 2048, kernel_size=1, stride=1, padding=0) 64 | self.up1 = UpSampleBN(skip_input=2048 + 176, output_features=1024) 65 | self.up2 = UpSampleBN(skip_input=1024 + 64, output_features=512) 66 | self.up3 = UpSampleBN(skip_input=512 + 40, output_features=256) 67 | self.up4 = UpSampleBN(skip_input=256 + 24, output_features=128) 68 | self.conv3 = nn.Conv2d(128, num_classes, kernel_size=3, stride=1, padding=1) 69 | 70 | def forward(self, features): 71 | x_block0, x_block1, x_block2, x_block3, x_block4 = features[4], features[5], features[6], features[8], features[11] 72 | x_d0 = self.conv2(x_block4) 73 | x_d1 = self.up1(x_d0, x_block3) 74 | x_d2 = self.up2(x_d1, x_block2) 75 | x_d3 = self.up3(x_d2, x_block1) 76 | x_d4 = self.up4(x_d3, x_block0) 77 | out = self.conv3(x_d4) 78 | return out 79 | 80 | 81 | if __name__ == '__main__': 82 | model = NNET() 83 | x = torch.rand(2, 3, 480, 640) 84 | out = model(x) 85 | print(out.shape) 86 | -------------------------------------------------------------------------------- /src/controlnet_aux/normalbae/nets/submodules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/controlnet_aux/b1aac318b15c4e65a21bc516dc236ad4bf593d46/src/controlnet_aux/normalbae/nets/submodules/__init__.py -------------------------------------------------------------------------------- /src/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a
python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # pytorch stuff 104 | *.pth 105 | *.onnx 106 | *.pb 107 | 108 | trained_models/ 109 | .fuse_hidden* 110 | -------------------------------------------------------------------------------- /src/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/controlnet_aux/b1aac318b15c4e65a21bc516dc236ad4bf593d46/src/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/__init__.py -------------------------------------------------------------------------------- /src/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/caffe2_benchmark.py: -------------------------------------------------------------------------------- 1 | """ Caffe2 validation script 2 | 3 | This script runs Caffe2 benchmark on exported ONNX model. 4 | It is a useful tool for reporting model FLOPS. 
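Illustrative invocation (the prefix is arbitrary; it expects <prefix>.init.pb and <prefix>.predict.pb, e.g. as written by onnx_to_caffe.py):
    python caffe2_benchmark.py --c2-prefix efficientnet_b0 --batch-size 1 --img-size 224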
5 | 6 | Copyright 2020 Ross Wightman 7 | """ 8 | import argparse 9 | from caffe2.python import core, workspace, model_helper 10 | from caffe2.proto import caffe2_pb2 11 | 12 | 13 | parser = argparse.ArgumentParser(description='Caffe2 Model Benchmark') 14 | parser.add_argument('--c2-prefix', default='', type=str, metavar='NAME', 15 | help='caffe2 model pb name prefix') 16 | parser.add_argument('--c2-init', default='', type=str, metavar='PATH', 17 | help='caffe2 model init .pb') 18 | parser.add_argument('--c2-predict', default='', type=str, metavar='PATH', 19 | help='caffe2 model predict .pb') 20 | parser.add_argument('-b', '--batch-size', default=1, type=int, 21 | metavar='N', help='mini-batch size (default: 1)') 22 | parser.add_argument('--img-size', default=224, type=int, 23 | metavar='N', help='Input image dimension, uses model default if empty') 24 | 25 | 26 | def main(): 27 | args = parser.parse_args() 28 | args.gpu_id = 0 29 | if args.c2_prefix: 30 | args.c2_init = args.c2_prefix + '.init.pb' 31 | args.c2_predict = args.c2_prefix + '.predict.pb' 32 | 33 | model = model_helper.ModelHelper(name="le_net", init_params=False) 34 | 35 | # Bring in the init net from init_net.pb 36 | init_net_proto = caffe2_pb2.NetDef() 37 | with open(args.c2_init, "rb") as f: 38 | init_net_proto.ParseFromString(f.read()) 39 | model.param_init_net = core.Net(init_net_proto) 40 | 41 | # bring in the predict net from predict_net.pb 42 | predict_net_proto = caffe2_pb2.NetDef() 43 | with open(args.c2_predict, "rb") as f: 44 | predict_net_proto.ParseFromString(f.read()) 45 | model.net = core.Net(predict_net_proto) 46 | 47 | # CUDA performance not impressive 48 | #device_opts = core.DeviceOption(caffe2_pb2.PROTO_CUDA, args.gpu_id) 49 | #model.net.RunAllOnGPU(gpu_id=args.gpu_id, use_cudnn=True) 50 | #model.param_init_net.RunAllOnGPU(gpu_id=args.gpu_id, use_cudnn=True) 51 | 52 | input_blob = model.net.external_inputs[0] 53 | model.param_init_net.GaussianFill( 54 | [], 55 | input_blob.GetUnscopedName(), 56 | shape=(args.batch_size, 3, args.img_size, args.img_size), 57 | mean=0.0, 58 | std=1.0) 59 | workspace.RunNetOnce(model.param_init_net) 60 | workspace.CreateNet(model.net, overwrite=True) 61 | workspace.BenchmarkNet(model.net.Proto().name, 5, 20, True) 62 | 63 | 64 | if __name__ == '__main__': 65 | main() 66 | -------------------------------------------------------------------------------- /src/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/__init__.py: -------------------------------------------------------------------------------- 1 | from .gen_efficientnet import * 2 | from .mobilenetv3 import * 3 | from .model_factory import create_model 4 | from .config import is_exportable, is_scriptable, set_exportable, set_scriptable 5 | from .activations import * -------------------------------------------------------------------------------- /src/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/activations/__init__.py: -------------------------------------------------------------------------------- 1 | from geffnet import config 2 | from geffnet.activations.activations_me import * 3 | from geffnet.activations.activations_jit import * 4 | from geffnet.activations.activations import * 5 | import torch 6 | 7 | _has_silu = 'silu' in dir(torch.nn.functional) 8 | 9 | _ACT_FN_DEFAULT = dict( 10 | silu=F.silu if _has_silu else swish, 11 | swish=F.silu if _has_silu else swish, 12 | mish=mish, 13 | relu=F.relu, 14 | relu6=F.relu6, 15 | sigmoid=sigmoid, 16 | tanh=tanh, 17 | 
hard_sigmoid=hard_sigmoid, 18 | hard_swish=hard_swish, 19 | ) 20 | 21 | _ACT_FN_JIT = dict( 22 | silu=F.silu if _has_silu else swish_jit, 23 | swish=F.silu if _has_silu else swish_jit, 24 | mish=mish_jit, 25 | ) 26 | 27 | _ACT_FN_ME = dict( 28 | silu=F.silu if _has_silu else swish_me, 29 | swish=F.silu if _has_silu else swish_me, 30 | mish=mish_me, 31 | hard_swish=hard_swish_me, 32 | hard_sigmoid_jit=hard_sigmoid_me, 33 | ) 34 | 35 | _ACT_LAYER_DEFAULT = dict( 36 | silu=nn.SiLU if _has_silu else Swish, 37 | swish=nn.SiLU if _has_silu else Swish, 38 | mish=Mish, 39 | relu=nn.ReLU, 40 | relu6=nn.ReLU6, 41 | sigmoid=Sigmoid, 42 | tanh=Tanh, 43 | hard_sigmoid=HardSigmoid, 44 | hard_swish=HardSwish, 45 | ) 46 | 47 | _ACT_LAYER_JIT = dict( 48 | silu=nn.SiLU if _has_silu else SwishJit, 49 | swish=nn.SiLU if _has_silu else SwishJit, 50 | mish=MishJit, 51 | ) 52 | 53 | _ACT_LAYER_ME = dict( 54 | silu=nn.SiLU if _has_silu else SwishMe, 55 | swish=nn.SiLU if _has_silu else SwishMe, 56 | mish=MishMe, 57 | hard_swish=HardSwishMe, 58 | hard_sigmoid=HardSigmoidMe 59 | ) 60 | 61 | _OVERRIDE_FN = dict() 62 | _OVERRIDE_LAYER = dict() 63 | 64 | 65 | def add_override_act_fn(name, fn): 66 | global _OVERRIDE_FN 67 | _OVERRIDE_FN[name] = fn 68 | 69 | 70 | def update_override_act_fn(overrides): 71 | assert isinstance(overrides, dict) 72 | global _OVERRIDE_FN 73 | _OVERRIDE_FN.update(overrides) 74 | 75 | 76 | def clear_override_act_fn(): 77 | global _OVERRIDE_FN 78 | _OVERRIDE_FN = dict() 79 | 80 | 81 | def add_override_act_layer(name, fn): 82 | _OVERRIDE_LAYER[name] = fn 83 | 84 | 85 | def update_override_act_layer(overrides): 86 | assert isinstance(overrides, dict) 87 | global _OVERRIDE_LAYER 88 | _OVERRIDE_LAYER.update(overrides) 89 | 90 | 91 | def clear_override_act_layer(): 92 | global _OVERRIDE_LAYER 93 | _OVERRIDE_LAYER = dict() 94 | 95 | 96 | def get_act_fn(name='relu'): 97 | """ Activation Function Factory 98 | Fetching activation fns by name with this function allows export or torch script friendly 99 | functions to be returned dynamically based on current config. 100 | """ 101 | if name in _OVERRIDE_FN: 102 | return _OVERRIDE_FN[name] 103 | use_me = not (config.is_exportable() or config.is_scriptable() or config.is_no_jit()) 104 | if use_me and name in _ACT_FN_ME: 105 | # If not exporting or scripting the model, first look for a memory optimized version 106 | # activation with custom autograd, then fallback to jit scripted, then a Python or Torch builtin 107 | return _ACT_FN_ME[name] 108 | if config.is_exportable() and name in ('silu', 'swish'): 109 | # FIXME PyTorch SiLU doesn't ONNX export, this is a temp hack 110 | return swish 111 | use_jit = not (config.is_exportable() or config.is_no_jit()) 112 | # NOTE: export tracing should work with jit scripted components, but I keep running into issues 113 | if use_jit and name in _ACT_FN_JIT: # jit scripted models should be okay for export/scripting 114 | return _ACT_FN_JIT[name] 115 | return _ACT_FN_DEFAULT[name] 116 | 117 | 118 | def get_act_layer(name='relu'): 119 | """ Activation Layer Factory 120 | Fetching activation layers by name with this function allows export or torch script friendly 121 | functions to be returned dynamically based on current config. 
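For example, get_act_layer('mish') resolves to MishMe under the default config, to MishJit when only scripting is enabled, and to the plain Mish module when exporting or no_jit is set.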
122 | """ 123 | if name in _OVERRIDE_LAYER: 124 | return _OVERRIDE_LAYER[name] 125 | use_me = not (config.is_exportable() or config.is_scriptable() or config.is_no_jit()) 126 | if use_me and name in _ACT_LAYER_ME: 127 | return _ACT_LAYER_ME[name] 128 | if config.is_exportable() and name in ('silu', 'swish'): 129 | # FIXME PyTorch SiLU doesn't ONNX export, this is a temp hack 130 | return Swish 131 | use_jit = not (config.is_exportable() or config.is_no_jit()) 132 | # NOTE: export tracing should work with jit scripted components, but I keep running into issues 133 | if use_jit and name in _ACT_FN_JIT: # jit scripted models should be okay for export/scripting 134 | return _ACT_LAYER_JIT[name] 135 | return _ACT_LAYER_DEFAULT[name] 136 | 137 | 138 | -------------------------------------------------------------------------------- /src/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/activations/activations.py: -------------------------------------------------------------------------------- 1 | """ Activations 2 | 3 | A collection of activations fn and modules with a common interface so that they can 4 | easily be swapped. All have an `inplace` arg even if not used. 5 | 6 | Copyright 2020 Ross Wightman 7 | """ 8 | from torch import nn as nn 9 | from torch.nn import functional as F 10 | 11 | 12 | def swish(x, inplace: bool = False): 13 | """Swish - Described originally as SiLU (https://arxiv.org/abs/1702.03118v3) 14 | and also as Swish (https://arxiv.org/abs/1710.05941). 15 | 16 | TODO Rename to SiLU with addition to PyTorch 17 | """ 18 | return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid()) 19 | 20 | 21 | class Swish(nn.Module): 22 | def __init__(self, inplace: bool = False): 23 | super(Swish, self).__init__() 24 | self.inplace = inplace 25 | 26 | def forward(self, x): 27 | return swish(x, self.inplace) 28 | 29 | 30 | def mish(x, inplace: bool = False): 31 | """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 32 | """ 33 | return x.mul(F.softplus(x).tanh()) 34 | 35 | 36 | class Mish(nn.Module): 37 | def __init__(self, inplace: bool = False): 38 | super(Mish, self).__init__() 39 | self.inplace = inplace 40 | 41 | def forward(self, x): 42 | return mish(x, self.inplace) 43 | 44 | 45 | def sigmoid(x, inplace: bool = False): 46 | return x.sigmoid_() if inplace else x.sigmoid() 47 | 48 | 49 | # PyTorch has this, but not with a consistent inplace argmument interface 50 | class Sigmoid(nn.Module): 51 | def __init__(self, inplace: bool = False): 52 | super(Sigmoid, self).__init__() 53 | self.inplace = inplace 54 | 55 | def forward(self, x): 56 | return x.sigmoid_() if self.inplace else x.sigmoid() 57 | 58 | 59 | def tanh(x, inplace: bool = False): 60 | return x.tanh_() if inplace else x.tanh() 61 | 62 | 63 | # PyTorch has this, but not with a consistent inplace argmument interface 64 | class Tanh(nn.Module): 65 | def __init__(self, inplace: bool = False): 66 | super(Tanh, self).__init__() 67 | self.inplace = inplace 68 | 69 | def forward(self, x): 70 | return x.tanh_() if self.inplace else x.tanh() 71 | 72 | 73 | def hard_swish(x, inplace: bool = False): 74 | inner = F.relu6(x + 3.).div_(6.) 
75 | return x.mul_(inner) if inplace else x.mul(inner) 76 | 77 | 78 | class HardSwish(nn.Module): 79 | def __init__(self, inplace: bool = False): 80 | super(HardSwish, self).__init__() 81 | self.inplace = inplace 82 | 83 | def forward(self, x): 84 | return hard_swish(x, self.inplace) 85 | 86 | 87 | def hard_sigmoid(x, inplace: bool = False): 88 | if inplace: 89 | return x.add_(3.).clamp_(0., 6.).div_(6.) 90 | else: 91 | return F.relu6(x + 3.) / 6. 92 | 93 | 94 | class HardSigmoid(nn.Module): 95 | def __init__(self, inplace: bool = False): 96 | super(HardSigmoid, self).__init__() 97 | self.inplace = inplace 98 | 99 | def forward(self, x): 100 | return hard_sigmoid(x, self.inplace) 101 | 102 | 103 | -------------------------------------------------------------------------------- /src/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/activations/activations_jit.py: -------------------------------------------------------------------------------- 1 | """ Activations (jit) 2 | 3 | A collection of jit-scripted activations fn and modules with a common interface so that they can 4 | easily be swapped. All have an `inplace` arg even if not used. 5 | 6 | All jit scripted activations are lacking in-place variations on purpose, scripted kernel fusion does not 7 | currently work across in-place op boundaries, thus performance is equal to or less than the non-scripted 8 | versions if they contain in-place ops. 9 | 10 | Copyright 2020 Ross Wightman 11 | """ 12 | 13 | import torch 14 | from torch import nn as nn 15 | from torch.nn import functional as F 16 | 17 | __all__ = ['swish_jit', 'SwishJit', 'mish_jit', 'MishJit', 18 | 'hard_sigmoid_jit', 'HardSigmoidJit', 'hard_swish_jit', 'HardSwishJit'] 19 | 20 | 21 | @torch.jit.script 22 | def swish_jit(x, inplace: bool = False): 23 | """Swish - Described originally as SiLU (https://arxiv.org/abs/1702.03118v3) 24 | and also as Swish (https://arxiv.org/abs/1710.05941). 25 | 26 | TODO Rename to SiLU with addition to PyTorch 27 | """ 28 | return x.mul(x.sigmoid()) 29 | 30 | 31 | @torch.jit.script 32 | def mish_jit(x, _inplace: bool = False): 33 | """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 34 | """ 35 | return x.mul(F.softplus(x).tanh()) 36 | 37 | 38 | class SwishJit(nn.Module): 39 | def __init__(self, inplace: bool = False): 40 | super(SwishJit, self).__init__() 41 | 42 | def forward(self, x): 43 | return swish_jit(x) 44 | 45 | 46 | class MishJit(nn.Module): 47 | def __init__(self, inplace: bool = False): 48 | super(MishJit, self).__init__() 49 | 50 | def forward(self, x): 51 | return mish_jit(x) 52 | 53 | 54 | @torch.jit.script 55 | def hard_sigmoid_jit(x, inplace: bool = False): 56 | # return F.relu6(x + 3.) / 6. 57 | return (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster? 58 | 59 | 60 | class HardSigmoidJit(nn.Module): 61 | def __init__(self, inplace: bool = False): 62 | super(HardSigmoidJit, self).__init__() 63 | 64 | def forward(self, x): 65 | return hard_sigmoid_jit(x) 66 | 67 | 68 | @torch.jit.script 69 | def hard_swish_jit(x, inplace: bool = False): 70 | # return x * (F.relu6(x + 3.) / 6) 71 | return x * (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster? 
72 | 73 | 74 | class HardSwishJit(nn.Module): 75 | def __init__(self, inplace: bool = False): 76 | super(HardSwishJit, self).__init__() 77 | 78 | def forward(self, x): 79 | return hard_swish_jit(x) 80 | -------------------------------------------------------------------------------- /src/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/activations/activations_me.py: -------------------------------------------------------------------------------- 1 | """ Activations (memory-efficient w/ custom autograd) 2 | 3 | A collection of activations fn and modules with a common interface so that they can 4 | easily be swapped. All have an `inplace` arg even if not used. 5 | 6 | These activations are not compatible with jit scripting or ONNX export of the model, please use either 7 | the JIT or basic versions of the activations. 8 | 9 | Copyright 2020 Ross Wightman 10 | """ 11 | 12 | import torch 13 | from torch import nn as nn 14 | from torch.nn import functional as F 15 | 16 | 17 | __all__ = ['swish_me', 'SwishMe', 'mish_me', 'MishMe', 18 | 'hard_sigmoid_me', 'HardSigmoidMe', 'hard_swish_me', 'HardSwishMe'] 19 | 20 | 21 | @torch.jit.script 22 | def swish_jit_fwd(x): 23 | return x.mul(torch.sigmoid(x)) 24 | 25 | 26 | @torch.jit.script 27 | def swish_jit_bwd(x, grad_output): 28 | x_sigmoid = torch.sigmoid(x) 29 | return grad_output * (x_sigmoid * (1 + x * (1 - x_sigmoid))) 30 | 31 | 32 | class SwishJitAutoFn(torch.autograd.Function): 33 | """ torch.jit.script optimised Swish w/ memory-efficient checkpoint 34 | Inspired by conversation btw Jeremy Howard & Adam Pazske 35 | https://twitter.com/jeremyphoward/status/1188251041835315200 36 | 37 | Swish - Described originally as SiLU (https://arxiv.org/abs/1702.03118v3) 38 | and also as Swish (https://arxiv.org/abs/1710.05941). 
39 | 40 | TODO Rename to SiLU with addition to PyTorch 41 | """ 42 | 43 | @staticmethod 44 | def forward(ctx, x): 45 | ctx.save_for_backward(x) 46 | return swish_jit_fwd(x) 47 | 48 | @staticmethod 49 | def backward(ctx, grad_output): 50 | x = ctx.saved_tensors[0] 51 | return swish_jit_bwd(x, grad_output) 52 | 53 | 54 | def swish_me(x, inplace=False): 55 | return SwishJitAutoFn.apply(x) 56 | 57 | 58 | class SwishMe(nn.Module): 59 | def __init__(self, inplace: bool = False): 60 | super(SwishMe, self).__init__() 61 | 62 | def forward(self, x): 63 | return SwishJitAutoFn.apply(x) 64 | 65 | 66 | @torch.jit.script 67 | def mish_jit_fwd(x): 68 | return x.mul(torch.tanh(F.softplus(x))) 69 | 70 | 71 | @torch.jit.script 72 | def mish_jit_bwd(x, grad_output): 73 | x_sigmoid = torch.sigmoid(x) 74 | x_tanh_sp = F.softplus(x).tanh() 75 | return grad_output.mul(x_tanh_sp + x * x_sigmoid * (1 - x_tanh_sp * x_tanh_sp)) 76 | 77 | 78 | class MishJitAutoFn(torch.autograd.Function): 79 | """ Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 80 | A memory efficient, jit scripted variant of Mish 81 | """ 82 | @staticmethod 83 | def forward(ctx, x): 84 | ctx.save_for_backward(x) 85 | return mish_jit_fwd(x) 86 | 87 | @staticmethod 88 | def backward(ctx, grad_output): 89 | x = ctx.saved_tensors[0] 90 | return mish_jit_bwd(x, grad_output) 91 | 92 | 93 | def mish_me(x, inplace=False): 94 | return MishJitAutoFn.apply(x) 95 | 96 | 97 | class MishMe(nn.Module): 98 | def __init__(self, inplace: bool = False): 99 | super(MishMe, self).__init__() 100 | 101 | def forward(self, x): 102 | return MishJitAutoFn.apply(x) 103 | 104 | 105 | @torch.jit.script 106 | def hard_sigmoid_jit_fwd(x, inplace: bool = False): 107 | return (x + 3).clamp(min=0, max=6).div(6.) 108 | 109 | 110 | @torch.jit.script 111 | def hard_sigmoid_jit_bwd(x, grad_output): 112 | m = torch.ones_like(x) * ((x >= -3.) & (x <= 3.)) / 6. 113 | return grad_output * m 114 | 115 | 116 | class HardSigmoidJitAutoFn(torch.autograd.Function): 117 | @staticmethod 118 | def forward(ctx, x): 119 | ctx.save_for_backward(x) 120 | return hard_sigmoid_jit_fwd(x) 121 | 122 | @staticmethod 123 | def backward(ctx, grad_output): 124 | x = ctx.saved_tensors[0] 125 | return hard_sigmoid_jit_bwd(x, grad_output) 126 | 127 | 128 | def hard_sigmoid_me(x, inplace: bool = False): 129 | return HardSigmoidJitAutoFn.apply(x) 130 | 131 | 132 | class HardSigmoidMe(nn.Module): 133 | def __init__(self, inplace: bool = False): 134 | super(HardSigmoidMe, self).__init__() 135 | 136 | def forward(self, x): 137 | return HardSigmoidJitAutoFn.apply(x) 138 | 139 | 140 | @torch.jit.script 141 | def hard_swish_jit_fwd(x): 142 | return x * (x + 3).clamp(min=0, max=6).div(6.) 143 | 144 | 145 | @torch.jit.script 146 | def hard_swish_jit_bwd(x, grad_output): 147 | m = torch.ones_like(x) * (x >= 3.) 148 | m = torch.where((x >= -3.) & (x <= 3.), x / 3. 
+ .5, m) 149 | return grad_output * m 150 | 151 | 152 | class HardSwishJitAutoFn(torch.autograd.Function): 153 | """A memory efficient, jit-scripted HardSwish activation""" 154 | @staticmethod 155 | def forward(ctx, x): 156 | ctx.save_for_backward(x) 157 | return hard_swish_jit_fwd(x) 158 | 159 | @staticmethod 160 | def backward(ctx, grad_output): 161 | x = ctx.saved_tensors[0] 162 | return hard_swish_jit_bwd(x, grad_output) 163 | 164 | 165 | def hard_swish_me(x, inplace=False): 166 | return HardSwishJitAutoFn.apply(x) 167 | 168 | 169 | class HardSwishMe(nn.Module): 170 | def __init__(self, inplace: bool = False): 171 | super(HardSwishMe, self).__init__() 172 | 173 | def forward(self, x): 174 | return HardSwishJitAutoFn.apply(x) 175 | -------------------------------------------------------------------------------- /src/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/config.py: -------------------------------------------------------------------------------- 1 | """ Global layer config state 2 | """ 3 | from typing import Any, Optional 4 | 5 | __all__ = [ 6 | 'is_exportable', 'is_scriptable', 'is_no_jit', 'layer_config_kwargs', 7 | 'set_exportable', 'set_scriptable', 'set_no_jit', 'set_layer_config' 8 | ] 9 | 10 | # Set to True if prefer to have layers with no jit optimization (includes activations) 11 | _NO_JIT = False 12 | 13 | # Set to True if prefer to have activation layers with no jit optimization 14 | # NOTE not currently used as no difference between no_jit and no_activation jit as only layers obeying 15 | # the jit flags so far are activations. This will change as more layers are updated and/or added. 16 | _NO_ACTIVATION_JIT = False 17 | 18 | # Set to True if exporting a model with Same padding via ONNX 19 | _EXPORTABLE = False 20 | 21 | # Set to True if wanting to use torch.jit.script on a model 22 | _SCRIPTABLE = False 23 | 24 | 25 | def is_no_jit(): 26 | return _NO_JIT 27 | 28 | 29 | class set_no_jit: 30 | def __init__(self, mode: bool) -> None: 31 | global _NO_JIT 32 | self.prev = _NO_JIT 33 | _NO_JIT = mode 34 | 35 | def __enter__(self) -> None: 36 | pass 37 | 38 | def __exit__(self, *args: Any) -> bool: 39 | global _NO_JIT 40 | _NO_JIT = self.prev 41 | return False 42 | 43 | 44 | def is_exportable(): 45 | return _EXPORTABLE 46 | 47 | 48 | class set_exportable: 49 | def __init__(self, mode: bool) -> None: 50 | global _EXPORTABLE 51 | self.prev = _EXPORTABLE 52 | _EXPORTABLE = mode 53 | 54 | def __enter__(self) -> None: 55 | pass 56 | 57 | def __exit__(self, *args: Any) -> bool: 58 | global _EXPORTABLE 59 | _EXPORTABLE = self.prev 60 | return False 61 | 62 | 63 | def is_scriptable(): 64 | return _SCRIPTABLE 65 | 66 | 67 | class set_scriptable: 68 | def __init__(self, mode: bool) -> None: 69 | global _SCRIPTABLE 70 | self.prev = _SCRIPTABLE 71 | _SCRIPTABLE = mode 72 | 73 | def __enter__(self) -> None: 74 | pass 75 | 76 | def __exit__(self, *args: Any) -> bool: 77 | global _SCRIPTABLE 78 | _SCRIPTABLE = self.prev 79 | return False 80 | 81 | 82 | class set_layer_config: 83 | """ Layer config context manager that allows setting all layer config flags at once. 84 | If a flag arg is None, it will not change the current value. 
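Illustrative usage (the model name is an example only):
    from geffnet import create_model
    from geffnet.config import set_layer_config
    with set_layer_config(exportable=True):
        model = create_model('mobilenetv3_large_100', pretrained=False)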
85 | """ 86 | def __init__( 87 | self, 88 | scriptable: Optional[bool] = None, 89 | exportable: Optional[bool] = None, 90 | no_jit: Optional[bool] = None, 91 | no_activation_jit: Optional[bool] = None): 92 | global _SCRIPTABLE 93 | global _EXPORTABLE 94 | global _NO_JIT 95 | global _NO_ACTIVATION_JIT 96 | self.prev = _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT 97 | if scriptable is not None: 98 | _SCRIPTABLE = scriptable 99 | if exportable is not None: 100 | _EXPORTABLE = exportable 101 | if no_jit is not None: 102 | _NO_JIT = no_jit 103 | if no_activation_jit is not None: 104 | _NO_ACTIVATION_JIT = no_activation_jit 105 | 106 | def __enter__(self) -> None: 107 | pass 108 | 109 | def __exit__(self, *args: Any) -> bool: 110 | global _SCRIPTABLE 111 | global _EXPORTABLE 112 | global _NO_JIT 113 | global _NO_ACTIVATION_JIT 114 | _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT = self.prev 115 | return False 116 | 117 | 118 | def layer_config_kwargs(kwargs): 119 | """ Consume config kwargs and return contextmgr obj """ 120 | return set_layer_config( 121 | scriptable=kwargs.pop('scriptable', None), 122 | exportable=kwargs.pop('exportable', None), 123 | no_jit=kwargs.pop('no_jit', None)) 124 | -------------------------------------------------------------------------------- /src/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/helpers.py: -------------------------------------------------------------------------------- 1 | """ Checkpoint loading / state_dict helpers 2 | Copyright 2020 Ross Wightman 3 | """ 4 | import torch 5 | import os 6 | from collections import OrderedDict 7 | try: 8 | from torch.hub import load_state_dict_from_url 9 | except ImportError: 10 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 11 | 12 | 13 | def load_checkpoint(model, checkpoint_path): 14 | if checkpoint_path and os.path.isfile(checkpoint_path): 15 | print("=> Loading checkpoint '{}'".format(checkpoint_path)) 16 | checkpoint = torch.load(checkpoint_path) 17 | if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: 18 | new_state_dict = OrderedDict() 19 | for k, v in checkpoint['state_dict'].items(): 20 | if k.startswith('module'): 21 | name = k[7:] # remove `module.` 22 | else: 23 | name = k 24 | new_state_dict[name] = v 25 | model.load_state_dict(new_state_dict) 26 | else: 27 | model.load_state_dict(checkpoint) 28 | print("=> Loaded checkpoint '{}'".format(checkpoint_path)) 29 | else: 30 | print("=> Error: No checkpoint found at '{}'".format(checkpoint_path)) 31 | raise FileNotFoundError() 32 | 33 | 34 | def load_pretrained(model, url, filter_fn=None, strict=True): 35 | if not url: 36 | print("=> Warning: Pretrained model URL is empty, using random initialization.") 37 | return 38 | 39 | state_dict = load_state_dict_from_url(url, progress=False, map_location='cpu') 40 | 41 | input_conv = 'conv_stem' 42 | classifier = 'classifier' 43 | in_chans = getattr(model, input_conv).weight.shape[1] 44 | num_classes = getattr(model, classifier).weight.shape[0] 45 | 46 | input_conv_weight = input_conv + '.weight' 47 | pretrained_in_chans = state_dict[input_conv_weight].shape[1] 48 | if in_chans != pretrained_in_chans: 49 | if in_chans == 1: 50 | print('=> Converting pretrained input conv {} from {} to 1 channel'.format( 51 | input_conv_weight, pretrained_in_chans)) 52 | conv1_weight = state_dict[input_conv_weight] 53 | state_dict[input_conv_weight] = conv1_weight.sum(dim=1, keepdim=True) 54 | else: 55 | print('=> Discarding pretrained input conv {} since input 
channel count != {}'.format( 56 | input_conv_weight, pretrained_in_chans)) 57 | del state_dict[input_conv_weight] 58 | strict = False 59 | 60 | classifier_weight = classifier + '.weight' 61 | pretrained_num_classes = state_dict[classifier_weight].shape[0] 62 | if num_classes != pretrained_num_classes: 63 | print('=> Discarding pretrained classifier since num_classes != {}'.format(pretrained_num_classes)) 64 | del state_dict[classifier_weight] 65 | del state_dict[classifier + '.bias'] 66 | strict = False 67 | 68 | if filter_fn is not None: 69 | state_dict = filter_fn(state_dict) 70 | 71 | model.load_state_dict(state_dict, strict=strict) 72 | -------------------------------------------------------------------------------- /src/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/model_factory.py: -------------------------------------------------------------------------------- 1 | from .config import set_layer_config 2 | from .helpers import load_checkpoint 3 | 4 | from .gen_efficientnet import * 5 | from .mobilenetv3 import * 6 | 7 | 8 | def create_model( 9 | model_name='mnasnet_100', 10 | pretrained=None, 11 | num_classes=1000, 12 | in_chans=3, 13 | checkpoint_path='', 14 | **kwargs): 15 | 16 | model_kwargs = dict(num_classes=num_classes, in_chans=in_chans, pretrained=pretrained, **kwargs) 17 | 18 | if model_name in globals(): 19 | create_fn = globals()[model_name] 20 | model = create_fn(**model_kwargs) 21 | else: 22 | raise RuntimeError('Unknown model (%s)' % model_name) 23 | 24 | if checkpoint_path and not pretrained: 25 | load_checkpoint(model, checkpoint_path) 26 | 27 | return model 28 | -------------------------------------------------------------------------------- /src/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/version.py: -------------------------------------------------------------------------------- 1 | __version__ = '1.0.2' 2 | -------------------------------------------------------------------------------- /src/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/hubconf.py: -------------------------------------------------------------------------------- 1 | dependencies = ['torch', 'math'] 2 | 3 | from geffnet import efficientnet_b0 4 | from geffnet import efficientnet_b1 5 | from geffnet import efficientnet_b2 6 | from geffnet import efficientnet_b3 7 | 8 | from geffnet import efficientnet_es 9 | 10 | from geffnet import efficientnet_lite0 11 | 12 | from geffnet import mixnet_s 13 | from geffnet import mixnet_m 14 | from geffnet import mixnet_l 15 | from geffnet import mixnet_xl 16 | 17 | from geffnet import mobilenetv2_100 18 | from geffnet import mobilenetv2_110d 19 | from geffnet import mobilenetv2_120d 20 | from geffnet import mobilenetv2_140 21 | 22 | from geffnet import mobilenetv3_large_100 23 | from geffnet import mobilenetv3_rw 24 | from geffnet import mnasnet_a1 25 | from geffnet import mnasnet_b1 26 | from geffnet import fbnetc_100 27 | from geffnet import spnasnet_100 28 | 29 | from geffnet import tf_efficientnet_b0 30 | from geffnet import tf_efficientnet_b1 31 | from geffnet import tf_efficientnet_b2 32 | from geffnet import tf_efficientnet_b3 33 | from geffnet import tf_efficientnet_b4 34 | from geffnet import tf_efficientnet_b5 35 | from geffnet import tf_efficientnet_b6 36 | from geffnet import tf_efficientnet_b7 37 | from geffnet import tf_efficientnet_b8 38 | 39 | from geffnet import tf_efficientnet_b0_ap 40 | from geffnet import tf_efficientnet_b1_ap 41 | from geffnet import tf_efficientnet_b2_ap 42 | from geffnet 
import tf_efficientnet_b3_ap 43 | from geffnet import tf_efficientnet_b4_ap 44 | from geffnet import tf_efficientnet_b5_ap 45 | from geffnet import tf_efficientnet_b6_ap 46 | from geffnet import tf_efficientnet_b7_ap 47 | from geffnet import tf_efficientnet_b8_ap 48 | 49 | from geffnet import tf_efficientnet_b0_ns 50 | from geffnet import tf_efficientnet_b1_ns 51 | from geffnet import tf_efficientnet_b2_ns 52 | from geffnet import tf_efficientnet_b3_ns 53 | from geffnet import tf_efficientnet_b4_ns 54 | from geffnet import tf_efficientnet_b5_ns 55 | from geffnet import tf_efficientnet_b6_ns 56 | from geffnet import tf_efficientnet_b7_ns 57 | from geffnet import tf_efficientnet_l2_ns_475 58 | from geffnet import tf_efficientnet_l2_ns 59 | 60 | from geffnet import tf_efficientnet_es 61 | from geffnet import tf_efficientnet_em 62 | from geffnet import tf_efficientnet_el 63 | 64 | from geffnet import tf_efficientnet_cc_b0_4e 65 | from geffnet import tf_efficientnet_cc_b0_8e 66 | from geffnet import tf_efficientnet_cc_b1_8e 67 | 68 | from geffnet import tf_efficientnet_lite0 69 | from geffnet import tf_efficientnet_lite1 70 | from geffnet import tf_efficientnet_lite2 71 | from geffnet import tf_efficientnet_lite3 72 | from geffnet import tf_efficientnet_lite4 73 | 74 | from geffnet import tf_mixnet_s 75 | from geffnet import tf_mixnet_m 76 | from geffnet import tf_mixnet_l 77 | 78 | from geffnet import tf_mobilenetv3_large_075 79 | from geffnet import tf_mobilenetv3_large_100 80 | from geffnet import tf_mobilenetv3_large_minimal_100 81 | from geffnet import tf_mobilenetv3_small_075 82 | from geffnet import tf_mobilenetv3_small_100 83 | from geffnet import tf_mobilenetv3_small_minimal_100 84 | 85 | -------------------------------------------------------------------------------- /src/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/onnx_optimize.py: -------------------------------------------------------------------------------- 1 | """ ONNX optimization script 2 | 3 | Run ONNX models through the optimizer to prune unneeded nodes, fuse batchnorm layers into conv, etc. 4 | 5 | NOTE: This isn't working consistently in recent PyTorch/ONNX combos (ie PyTorch 1.6 and ONNX 1.7), 6 | it seems time to switch to using the onnxruntime online optimizer (can also be saved for offline). 
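Illustrative invocation (file names are arbitrary): python onnx_optimize.py model.onnx --output model_optimized.onnx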
7 | 8 | Copyright 2020 Ross Wightman 9 | """ 10 | import argparse 11 | import warnings 12 | 13 | import onnx 14 | from onnx import optimizer 15 | 16 | 17 | parser = argparse.ArgumentParser(description="Optimize ONNX model") 18 | 19 | parser.add_argument("model", help="The ONNX model") 20 | parser.add_argument("--output", required=True, help="The optimized model output filename") 21 | 22 | 23 | def traverse_graph(graph, prefix=''): 24 | content = [] 25 | indent = prefix + ' ' 26 | graphs = [] 27 | num_nodes = 0 28 | for node in graph.node: 29 | pn, gs = onnx.helper.printable_node(node, indent, subgraphs=True) 30 | assert isinstance(gs, list) 31 | content.append(pn) 32 | graphs.extend(gs) 33 | num_nodes += 1 34 | for g in graphs: 35 | g_count, g_str = traverse_graph(g) 36 | content.append('\n' + g_str) 37 | num_nodes += g_count 38 | return num_nodes, '\n'.join(content) 39 | 40 | 41 | def main(): 42 | args = parser.parse_args() 43 | onnx_model = onnx.load(args.model) 44 | num_original_nodes, original_graph_str = traverse_graph(onnx_model.graph) 45 | 46 | # Optimizer passes to perform 47 | passes = [ 48 | #'eliminate_deadend', 49 | 'eliminate_identity', 50 | 'eliminate_nop_dropout', 51 | 'eliminate_nop_pad', 52 | 'eliminate_nop_transpose', 53 | 'eliminate_unused_initializer', 54 | 'extract_constant_to_initializer', 55 | 'fuse_add_bias_into_conv', 56 | 'fuse_bn_into_conv', 57 | 'fuse_consecutive_concats', 58 | 'fuse_consecutive_reduce_unsqueeze', 59 | 'fuse_consecutive_squeezes', 60 | 'fuse_consecutive_transposes', 61 | #'fuse_matmul_add_bias_into_gemm', 62 | 'fuse_pad_into_conv', 63 | #'fuse_transpose_into_gemm', 64 | #'lift_lexical_references', 65 | ] 66 | 67 | # Apply the optimization on the original serialized model 68 | # WARNING I've had issues with optimizer in recent versions of PyTorch / ONNX causing 69 | # 'duplicate definition of name' errors, see: https://github.com/onnx/onnx/issues/2401 70 | # It may be better to rely on onnxruntime optimizations, see onnx_validate.py script. 71 | warnings.warn("I've had issues with optimizer in recent versions of PyTorch / ONNX." 72 | "Try onnxruntime optimization if this doesn't work.") 73 | optimized_model = optimizer.optimize(onnx_model, passes) 74 | 75 | num_optimized_nodes, optimzied_graph_str = traverse_graph(optimized_model.graph) 76 | print('==> The model after optimization:\n{}\n'.format(optimzied_graph_str)) 77 | print('==> The optimized model has {} nodes, the original had {}.'.format(num_optimized_nodes, num_original_nodes)) 78 | 79 | # Save the ONNX model 80 | onnx.save(optimized_model, args.output) 81 | 82 | 83 | if __name__ == "__main__": 84 | main() 85 | -------------------------------------------------------------------------------- /src/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/onnx_to_caffe.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import onnx 4 | from caffe2.python.onnx.backend import Caffe2Backend 5 | 6 | 7 | parser = argparse.ArgumentParser(description="Convert ONNX to Caffe2") 8 | 9 | parser.add_argument("model", help="The ONNX model") 10 | parser.add_argument("--c2-prefix", required=True, 11 | help="The output file prefix for the caffe2 model init and predict file. 
") 12 | 13 | 14 | def main(): 15 | args = parser.parse_args() 16 | onnx_model = onnx.load(args.model) 17 | caffe2_init, caffe2_predict = Caffe2Backend.onnx_graph_to_caffe2_net(onnx_model) 18 | caffe2_init_str = caffe2_init.SerializeToString() 19 | with open(args.c2_prefix + '.init.pb', "wb") as f: 20 | f.write(caffe2_init_str) 21 | caffe2_predict_str = caffe2_predict.SerializeToString() 22 | with open(args.c2_prefix + '.predict.pb', "wb") as f: 23 | f.write(caffe2_predict_str) 24 | 25 | 26 | if __name__ == "__main__": 27 | main() 28 | -------------------------------------------------------------------------------- /src/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/onnx_validate.py: -------------------------------------------------------------------------------- 1 | """ ONNX-runtime validation script 2 | 3 | This script was created to verify accuracy and performance of exported ONNX 4 | models running with the onnxruntime. It utilizes the PyTorch dataloader/processing 5 | pipeline for a fair comparison against the originals. 6 | 7 | Copyright 2020 Ross Wightman 8 | """ 9 | import argparse 10 | import numpy as np 11 | import onnxruntime 12 | from data import create_loader, resolve_data_config, Dataset 13 | from utils import AverageMeter 14 | import time 15 | 16 | parser = argparse.ArgumentParser(description='Caffe2 ImageNet Validation') 17 | parser.add_argument('data', metavar='DIR', 18 | help='path to dataset') 19 | parser.add_argument('--onnx-input', default='', type=str, metavar='PATH', 20 | help='path to onnx model/weights file') 21 | parser.add_argument('--onnx-output-opt', default='', type=str, metavar='PATH', 22 | help='path to output optimized onnx graph') 23 | parser.add_argument('--profile', action='store_true', default=False, 24 | help='Enable profiler output.') 25 | parser.add_argument('-j', '--workers', default=2, type=int, metavar='N', 26 | help='number of data loading workers (default: 2)') 27 | parser.add_argument('-b', '--batch-size', default=256, type=int, 28 | metavar='N', help='mini-batch size (default: 256)') 29 | parser.add_argument('--img-size', default=None, type=int, 30 | metavar='N', help='Input image dimension, uses model default if empty') 31 | parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN', 32 | help='Override mean pixel value of dataset') 33 | parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD', 34 | help='Override std deviation of of dataset') 35 | parser.add_argument('--crop-pct', type=float, default=None, metavar='PCT', 36 | help='Override default crop pct of 0.875') 37 | parser.add_argument('--interpolation', default='', type=str, metavar='NAME', 38 | help='Image resize interpolation type (overrides model)') 39 | parser.add_argument('--tf-preprocessing', dest='tf_preprocessing', action='store_true', 40 | help='use tensorflow mnasnet preporcessing') 41 | parser.add_argument('--print-freq', '-p', default=10, type=int, 42 | metavar='N', help='print frequency (default: 10)') 43 | 44 | 45 | def main(): 46 | args = parser.parse_args() 47 | args.gpu_id = 0 48 | 49 | # Set graph optimization level 50 | sess_options = onnxruntime.SessionOptions() 51 | sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL 52 | if args.profile: 53 | sess_options.enable_profiling = True 54 | if args.onnx_output_opt: 55 | sess_options.optimized_model_filepath = args.onnx_output_opt 56 | 57 | session = onnxruntime.InferenceSession(args.onnx_input, sess_options) 58 | 59 | 
data_config = resolve_data_config(None, args) 60 | loader = create_loader( 61 | Dataset(args.data, load_bytes=args.tf_preprocessing), 62 | input_size=data_config['input_size'], 63 | batch_size=args.batch_size, 64 | use_prefetcher=False, 65 | interpolation=data_config['interpolation'], 66 | mean=data_config['mean'], 67 | std=data_config['std'], 68 | num_workers=args.workers, 69 | crop_pct=data_config['crop_pct'], 70 | tensorflow_preprocessing=args.tf_preprocessing) 71 | 72 | input_name = session.get_inputs()[0].name 73 | 74 | batch_time = AverageMeter() 75 | top1 = AverageMeter() 76 | top5 = AverageMeter() 77 | end = time.time() 78 | for i, (input, target) in enumerate(loader): 79 | # run the net and return prediction 80 | output = session.run([], {input_name: input.data.numpy()}) 81 | output = output[0] 82 | 83 | # measure accuracy and record loss 84 | prec1, prec5 = accuracy_np(output, target.numpy()) 85 | top1.update(prec1.item(), input.size(0)) 86 | top5.update(prec5.item(), input.size(0)) 87 | 88 | # measure elapsed time 89 | batch_time.update(time.time() - end) 90 | end = time.time() 91 | 92 | if i % args.print_freq == 0: 93 | print('Test: [{0}/{1}]\t' 94 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f}, {rate_avg:.3f}/s, {ms_avg:.3f} ms/sample) \t' 95 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 96 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( 97 | i, len(loader), batch_time=batch_time, rate_avg=input.size(0) / batch_time.avg, 98 | ms_avg=100 * batch_time.avg / input.size(0), top1=top1, top5=top5)) 99 | 100 | print(' * Prec@1 {top1.avg:.3f} ({top1a:.3f}) Prec@5 {top5.avg:.3f} ({top5a:.3f})'.format( 101 | top1=top1, top1a=100-top1.avg, top5=top5, top5a=100.-top5.avg)) 102 | 103 | 104 | def accuracy_np(output, target): 105 | max_indices = np.argsort(output, axis=1)[:, ::-1] 106 | top5 = 100 * np.equal(max_indices[:, :5], target[:, np.newaxis]).sum(axis=1).mean() 107 | top1 = 100 * np.equal(max_indices[:, 0], target).mean() 108 | return top1, top5 109 | 110 | 111 | if __name__ == '__main__': 112 | main() 113 | -------------------------------------------------------------------------------- /src/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.2.0 2 | torchvision>=0.4.0 3 | -------------------------------------------------------------------------------- /src/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/setup.py: -------------------------------------------------------------------------------- 1 | """ Setup 2 | """ 3 | from setuptools import setup, find_packages 4 | from codecs import open 5 | from os import path 6 | 7 | here = path.abspath(path.dirname(__file__)) 8 | 9 | # Get the long description from the README file 10 | with open(path.join(here, 'README.md'), encoding='utf-8') as f: 11 | long_description = f.read() 12 | 13 | exec(open('geffnet/version.py').read()) 14 | setup( 15 | name='geffnet', 16 | version=__version__, 17 | description='(Generic) EfficientNets for PyTorch', 18 | long_description=long_description, 19 | long_description_content_type='text/markdown', 20 | url='https://github.com/rwightman/gen-efficientnet-pytorch', 21 | author='Ross Wightman', 22 | author_email='hello@rwightman.com', 23 | classifiers=[ 24 | # How mature is this project? 
Common values are 25 | # 3 - Alpha 26 | # 4 - Beta 27 | # 5 - Production/Stable 28 | 'Development Status :: 3 - Alpha', 29 | 'Intended Audience :: Education', 30 | 'Intended Audience :: Science/Research', 31 | 'License :: OSI Approved :: Apache Software License', 32 | 'Programming Language :: Python :: 3.6', 33 | 'Programming Language :: Python :: 3.7', 34 | 'Programming Language :: Python :: 3.8', 35 | 'Topic :: Scientific/Engineering', 36 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 37 | 'Topic :: Software Development', 38 | 'Topic :: Software Development :: Libraries', 39 | 'Topic :: Software Development :: Libraries :: Python Modules', 40 | ], 41 | 42 | # Note that this is a string of words separated by whitespace, not a list. 43 | keywords='pytorch pretrained models efficientnet mixnet mobilenetv3 mnasnet', 44 | packages=find_packages(exclude=['data']), 45 | install_requires=['torch >= 1.4', 'torchvision'], 46 | python_requires='>=3.6', 47 | ) 48 | -------------------------------------------------------------------------------- /src/controlnet_aux/normalbae/nets/submodules/efficientnet_repo/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | class AverageMeter: 5 | """Computes and stores the average and current value""" 6 | def __init__(self): 7 | self.reset() 8 | 9 | def reset(self): 10 | self.val = 0 11 | self.avg = 0 12 | self.sum = 0 13 | self.count = 0 14 | 15 | def update(self, val, n=1): 16 | self.val = val 17 | self.sum += val * n 18 | self.count += n 19 | self.avg = self.sum / self.count 20 | 21 | 22 | def accuracy(output, target, topk=(1,)): 23 | """Computes the precision@k for the specified values of k""" 24 | maxk = max(topk) 25 | batch_size = target.size(0) 26 | 27 | _, pred = output.topk(maxk, 1, True, True) 28 | pred = pred.t() 29 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 30 | 31 | res = [] 32 | for k in topk: 33 | correct_k = correct[:k].reshape(-1).float().sum(0) 34 | res.append(correct_k.mul_(100.0 / batch_size)) 35 | return res 36 | 37 | 38 | def get_outdir(path, *paths, inc=False): 39 | outdir = os.path.join(path, *paths) 40 | if not os.path.exists(outdir): 41 | os.makedirs(outdir) 42 | elif inc: 43 | count = 1 44 | outdir_inc = outdir + '-' + str(count) 45 | while os.path.exists(outdir_inc): 46 | count = count + 1 47 | outdir_inc = outdir + '-' + str(count) 48 | assert count < 100 49 | outdir = outdir_inc 50 | os.makedirs(outdir) 51 | return outdir 52 | 53 | -------------------------------------------------------------------------------- /src/controlnet_aux/normalbae/nets/submodules/encoder.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class Encoder(nn.Module): 8 | def __init__(self): 9 | super(Encoder, self).__init__() 10 | 11 | basemodel_name = 'tf_efficientnet_b5_ap' 12 | print('Loading base model ()...'.format(basemodel_name), end='') 13 | repo_path = os.path.join(os.path.dirname(__file__), 'efficientnet_repo') 14 | basemodel = torch.hub.load(repo_path, basemodel_name, pretrained=False, source='local') 15 | print('Done.') 16 | 17 | # Remove last layer 18 | print('Removing last two layers (global_pool & classifier).') 19 | basemodel.global_pool = nn.Identity() 20 | basemodel.classifier = nn.Identity() 21 | 22 | self.original_model = basemodel 23 | 24 | def forward(self, x): 25 | features = [x] 26 | for k, v in 
self.original_model._modules.items(): 27 | if (k == 'blocks'): 28 | for ki, vi in v._modules.items(): 29 | features.append(vi(features[-1])) 30 | else: 31 | features.append(v(features[-1])) 32 | return features 33 | 34 | 35 | -------------------------------------------------------------------------------- /src/controlnet_aux/open_pose/hand.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import torch 4 | from scipy.ndimage.filters import gaussian_filter 5 | from skimage.measure import label 6 | 7 | from . import util 8 | from .model import handpose_model 9 | 10 | 11 | class Hand(object): 12 | def __init__(self, model_path): 13 | self.model = handpose_model() 14 | model_dict = util.transfer(self.model, torch.load(model_path)) 15 | self.model.load_state_dict(model_dict) 16 | self.model.eval() 17 | 18 | def to(self, device): 19 | self.model.to(device) 20 | return self 21 | 22 | def __call__(self, oriImgRaw): 23 | device = next(iter(self.model.parameters())).device 24 | scale_search = [0.5, 1.0, 1.5, 2.0] 25 | # scale_search = [0.5] 26 | boxsize = 368 27 | stride = 8 28 | padValue = 128 29 | thre = 0.05 30 | multiplier = [x * boxsize for x in scale_search] 31 | 32 | wsize = 128 33 | heatmap_avg = np.zeros((wsize, wsize, 22)) 34 | 35 | Hr, Wr, Cr = oriImgRaw.shape 36 | 37 | oriImg = cv2.GaussianBlur(oriImgRaw, (0, 0), 0.8) 38 | 39 | for m in range(len(multiplier)): 40 | scale = multiplier[m] 41 | imageToTest = util.smart_resize(oriImg, (scale, scale)) 42 | 43 | imageToTest_padded, pad = util.padRightDownCorner(imageToTest, stride, padValue) 44 | im = np.transpose(np.float32(imageToTest_padded[:, :, :, np.newaxis]), (3, 2, 0, 1)) / 256 - 0.5 45 | im = np.ascontiguousarray(im) 46 | 47 | data = torch.from_numpy(im).float() 48 | data = data.to(device) 49 | 50 | with torch.no_grad(): 51 | output = self.model(data).cpu().numpy() 52 | 53 | # extract outputs, resize, and remove padding 54 | heatmap = np.transpose(np.squeeze(output), (1, 2, 0)) # output 1 is heatmaps 55 | heatmap = util.smart_resize_k(heatmap, fx=stride, fy=stride) 56 | heatmap = heatmap[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :] 57 | heatmap = util.smart_resize(heatmap, (wsize, wsize)) 58 | 59 | heatmap_avg += heatmap / len(multiplier) 60 | 61 | all_peaks = [] 62 | for part in range(21): 63 | map_ori = heatmap_avg[:, :, part] 64 | one_heatmap = gaussian_filter(map_ori, sigma=3) 65 | binary = np.ascontiguousarray(one_heatmap > thre, dtype=np.uint8) 66 | 67 | if np.sum(binary) == 0: 68 | all_peaks.append([0, 0]) 69 | continue 70 | label_img, label_numbers = label(binary, return_num=True, connectivity=binary.ndim) 71 | max_index = np.argmax([np.sum(map_ori[label_img == i]) for i in range(1, label_numbers + 1)]) + 1 72 | label_img[label_img != max_index] = 0 73 | map_ori[label_img == 0] = 0 74 | 75 | y, x = util.npmax(map_ori) 76 | y = int(float(y) * float(Hr) / float(wsize)) 77 | x = int(float(x) * float(Wr) / float(wsize)) 78 | all_peaks.append([x, y]) 79 | return np.array(all_peaks) 80 | 81 | if __name__ == "__main__": 82 | hand_estimation = Hand('../model/hand_pose_model.pth') 83 | 84 | # test_image = '../images/hand.jpg' 85 | test_image = '../images/hand.jpg' 86 | oriImg = cv2.imread(test_image) # B,G,R order 87 | peaks = hand_estimation(oriImg) 88 | canvas = util.draw_handpose(oriImg, peaks, True) 89 | cv2.imshow('', canvas) 90 | cv2.waitKey(0) 
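# A minimal, hedged usage sketch for the Hand estimator defined above, offered as an
# alternative to the hard-coded relative checkpoint path in the __main__ demo. The
# "lllyasviel/Annotators" repo id and the "hand_pose_model.pth" filename are assumptions
# about where the weights are hosted, not something this module guarantees.
def example_hand_keypoints(image_path):
    from huggingface_hub import hf_hub_download  # already a dependency elsewhere in this package
    weights = hf_hub_download("lllyasviel/Annotators", "hand_pose_model.pth")
    estimator = Hand(weights).to("cpu")
    image = cv2.imread(image_path)  # B,G,R order, as in the demo above
    peaks = estimator(image)  # (21, 2) array of (x, y) keypoints in original-image pixel coordinates
    return util.draw_handpose(image, peaks, True)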
-------------------------------------------------------------------------------- /src/controlnet_aux/pidi/LICENSE: -------------------------------------------------------------------------------- 1 | It is just for research purpose, and commercial use should be contacted with authors first. 2 | 3 | Copyright (c) 2021 Zhuo Su 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /src/controlnet_aux/pidi/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import warnings 3 | 4 | import cv2 5 | import numpy as np 6 | import torch 7 | from einops import rearrange 8 | from huggingface_hub import hf_hub_download 9 | from PIL import Image 10 | 11 | from ..util import HWC3, nms, resize_image, safe_step 12 | from .model import pidinet 13 | 14 | 15 | class PidiNetDetector: 16 | def __init__(self, netNetwork): 17 | self.netNetwork = netNetwork 18 | 19 | @classmethod 20 | def from_pretrained(cls, pretrained_model_or_path, filename=None, cache_dir=None, local_files_only=False): 21 | filename = filename or "table5_pidinet.pth" 22 | 23 | if os.path.isdir(pretrained_model_or_path): 24 | model_path = os.path.join(pretrained_model_or_path, filename) 25 | else: 26 | model_path = hf_hub_download(pretrained_model_or_path, filename, cache_dir=cache_dir, local_files_only=local_files_only) 27 | 28 | netNetwork = pidinet() 29 | netNetwork.load_state_dict({k.replace('module.', ''): v for k, v in torch.load(model_path)['state_dict'].items()}) 30 | netNetwork.eval() 31 | 32 | return cls(netNetwork) 33 | 34 | def to(self, device): 35 | self.netNetwork.to(device) 36 | return self 37 | 38 | def __call__(self, input_image, detect_resolution=512, image_resolution=512, safe=False, output_type="pil", scribble=False, apply_filter=False, **kwargs): 39 | if "return_pil" in kwargs: 40 | warnings.warn("return_pil is deprecated. 
Use output_type instead.", DeprecationWarning) 41 | output_type = "pil" if kwargs["return_pil"] else "np" 42 | if type(output_type) is bool: 43 | warnings.warn("Passing `True` or `False` to `output_type` is deprecated and will raise an error in future versions") 44 | if output_type: 45 | output_type = "pil" 46 | 47 | device = next(iter(self.netNetwork.parameters())).device 48 | if not isinstance(input_image, np.ndarray): 49 | input_image = np.array(input_image, dtype=np.uint8) 50 | 51 | input_image = HWC3(input_image) 52 | input_image = resize_image(input_image, detect_resolution) 53 | assert input_image.ndim == 3 54 | input_image = input_image[:, :, ::-1].copy() 55 | with torch.no_grad(): 56 | image_pidi = torch.from_numpy(input_image).float().to(device) 57 | image_pidi = image_pidi / 255.0 58 | image_pidi = rearrange(image_pidi, 'h w c -> 1 c h w') 59 | edge = self.netNetwork(image_pidi)[-1] 60 | edge = edge.cpu().numpy() 61 | if apply_filter: 62 | edge = edge > 0.5 63 | if safe: 64 | edge = safe_step(edge) 65 | edge = (edge * 255.0).clip(0, 255).astype(np.uint8) 66 | 67 | detected_map = edge[0, 0] 68 | detected_map = HWC3(detected_map) 69 | 70 | img = resize_image(input_image, image_resolution) 71 | H, W, C = img.shape 72 | 73 | detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR) 74 | 75 | if scribble: 76 | detected_map = nms(detected_map, 127, 3.0) 77 | detected_map = cv2.GaussianBlur(detected_map, (0, 0), 3.0) 78 | detected_map[detected_map > 4] = 255 79 | detected_map[detected_map < 255] = 0 80 | 81 | if output_type == "pil": 82 | detected_map = Image.fromarray(detected_map) 83 | 84 | return detected_map 85 | -------------------------------------------------------------------------------- /src/controlnet_aux/segment_anything/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import os 8 | import warnings 9 | from typing import Union 10 | 11 | import cv2 12 | import numpy as np 13 | import torch 14 | from huggingface_hub import hf_hub_download 15 | from PIL import Image 16 | 17 | from ..util import HWC3, resize_image 18 | from .automatic_mask_generator import SamAutomaticMaskGenerator 19 | from .build_sam import sam_model_registry 20 | 21 | 22 | class SamDetector: 23 | def __init__(self, mask_generator: SamAutomaticMaskGenerator): 24 | self.mask_generator = mask_generator 25 | 26 | @classmethod 27 | def from_pretrained(cls, pretrained_model_or_path, model_type="vit_h", filename="sam_vit_h_4b8939.pth", subfolder=None, cache_dir=None): 28 | """ 29 | Possible model_type : vit_h, vit_l, vit_b, vit_t 30 | download weights from https://github.com/facebookresearch/segment-anything 31 | """ 32 | if os.path.isdir(pretrained_model_or_path): 33 | model_path = os.path.join(pretrained_model_or_path, filename) 34 | else: 35 | model_path = hf_hub_download(pretrained_model_or_path, filename, subfolder=subfolder, cache_dir=cache_dir) 36 | 37 | sam = sam_model_registry[model_type](checkpoint=model_path) 38 | 39 | if torch.cuda.is_available(): 40 | sam.to("cuda") 41 | 42 | mask_generator = SamAutomaticMaskGenerator(sam) 43 | 44 | return cls(mask_generator) 45 | 46 | 47 | def show_anns(self, anns): 48 | if len(anns) == 0: 49 | return 50 | sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True) 51 | h, w = anns[0]['segmentation'].shape 52 | final_img = Image.fromarray(np.zeros((h, w, 3), dtype=np.uint8), mode="RGB") 53 | for ann in sorted_anns: 54 | m = ann['segmentation'] 55 | img = np.empty((m.shape[0], m.shape[1], 3), dtype=np.uint8) 56 | for i in range(3): 57 | img[:,:,i] = np.random.randint(255, dtype=np.uint8) 58 | final_img.paste(Image.fromarray(img, mode="RGB"), (0, 0), Image.fromarray(np.uint8(m*255))) 59 | 60 | return np.array(final_img, dtype=np.uint8) 61 | 62 | def __call__(self, input_image: Union[np.ndarray, Image.Image]=None, detect_resolution=512, image_resolution=512, output_type="pil", **kwargs) -> Image.Image: 63 | if "image" in kwargs: 64 | warnings.warn("image is deprecated, please use `input_image=...` instead.", DeprecationWarning) 65 | input_image = kwargs.pop("image") 66 | 67 | if input_image is None: 68 | raise ValueError("input_image must be defined.") 69 | 70 | if not isinstance(input_image, np.ndarray): 71 | input_image = np.array(input_image, dtype=np.uint8) 72 | 73 | input_image = HWC3(input_image) 74 | input_image = resize_image(input_image, detect_resolution) 75 | 76 | # Generate Masks 77 | masks = self.mask_generator.generate(input_image) 78 | # Create map 79 | map = self.show_anns(masks) 80 | 81 | detected_map = map 82 | detected_map = HWC3(detected_map) 83 | 84 | img = resize_image(input_image, image_resolution) 85 | H, W, C = img.shape 86 | 87 | detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR) 88 | 89 | if output_type == "pil": 90 | detected_map = Image.fromarray(detected_map) 91 | 92 | return detected_map 93 | -------------------------------------------------------------------------------- /src/controlnet_aux/segment_anything/build_sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import torch 8 | 9 | from functools import partial 10 | 11 | from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer, TinyViT 12 | 13 | 14 | def build_sam_vit_h(checkpoint=None): 15 | return _build_sam( 16 | encoder_embed_dim=1280, 17 | encoder_depth=32, 18 | encoder_num_heads=16, 19 | encoder_global_attn_indexes=[7, 15, 23, 31], 20 | checkpoint=checkpoint, 21 | ) 22 | 23 | 24 | build_sam = build_sam_vit_h 25 | 26 | 27 | def build_sam_vit_l(checkpoint=None): 28 | return _build_sam( 29 | encoder_embed_dim=1024, 30 | encoder_depth=24, 31 | encoder_num_heads=16, 32 | encoder_global_attn_indexes=[5, 11, 17, 23], 33 | checkpoint=checkpoint, 34 | ) 35 | 36 | 37 | def build_sam_vit_b(checkpoint=None): 38 | return _build_sam( 39 | encoder_embed_dim=768, 40 | encoder_depth=12, 41 | encoder_num_heads=12, 42 | encoder_global_attn_indexes=[2, 5, 8, 11], 43 | checkpoint=checkpoint, 44 | ) 45 | 46 | 47 | def build_sam_vit_t(checkpoint=None): 48 | prompt_embed_dim = 256 49 | image_size = 1024 50 | vit_patch_size = 16 51 | image_embedding_size = image_size // vit_patch_size 52 | mobile_sam = Sam( 53 | image_encoder=TinyViT(img_size=1024, in_chans=3, num_classes=1000, 54 | embed_dims=[64, 128, 160, 320], 55 | depths=[2, 2, 6, 2], 56 | num_heads=[2, 4, 5, 10], 57 | window_sizes=[7, 7, 14, 7], 58 | mlp_ratio=4., 59 | drop_rate=0., 60 | drop_path_rate=0.0, 61 | use_checkpoint=False, 62 | mbconv_expand_ratio=4.0, 63 | local_conv_size=3, 64 | layer_lr_decay=0.8 65 | ), 66 | prompt_encoder=PromptEncoder( 67 | embed_dim=prompt_embed_dim, 68 | image_embedding_size=(image_embedding_size, image_embedding_size), 69 | input_image_size=(image_size, image_size), 70 | mask_in_chans=16, 71 | ), 72 | mask_decoder=MaskDecoder( 73 | num_multimask_outputs=3, 74 | transformer=TwoWayTransformer( 75 | depth=2, 76 | embedding_dim=prompt_embed_dim, 77 | mlp_dim=2048, 78 | num_heads=8, 79 | ), 80 | transformer_dim=prompt_embed_dim, 81 | iou_head_depth=3, 82 | iou_head_hidden_dim=256, 83 | ), 84 | pixel_mean=[123.675, 116.28, 103.53], 85 | pixel_std=[58.395, 57.12, 57.375], 86 | ) 87 | 88 | mobile_sam.eval() 89 | if checkpoint is not None: 90 | with open(checkpoint, "rb") as f: 91 | state_dict = torch.load(f) 92 | mobile_sam.load_state_dict(state_dict) 93 | return mobile_sam 94 | 95 | 96 | sam_model_registry = { 97 | "default": build_sam_vit_h, 98 | "vit_h": build_sam_vit_h, 99 | "vit_l": build_sam_vit_l, 100 | "vit_b": build_sam_vit_b, 101 | "vit_t": build_sam_vit_t, 102 | } 103 | 104 | 105 | def _build_sam( 106 | encoder_embed_dim, 107 | encoder_depth, 108 | encoder_num_heads, 109 | encoder_global_attn_indexes, 110 | checkpoint=None, 111 | ): 112 | prompt_embed_dim = 256 113 | image_size = 1024 114 | vit_patch_size = 16 115 | image_embedding_size = image_size // vit_patch_size 116 | sam = Sam( 117 | image_encoder=ImageEncoderViT( 118 | depth=encoder_depth, 119 | embed_dim=encoder_embed_dim, 120 | img_size=image_size, 121 | mlp_ratio=4, 122 | norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), 123 | num_heads=encoder_num_heads, 124 | patch_size=vit_patch_size, 125 | qkv_bias=True, 126 | use_rel_pos=True, 127 | global_attn_indexes=encoder_global_attn_indexes, 128 | window_size=14, 129 | out_chans=prompt_embed_dim, 130 | ), 131 | prompt_encoder=PromptEncoder( 132 | embed_dim=prompt_embed_dim, 133 | image_embedding_size=(image_embedding_size, image_embedding_size), 134 | input_image_size=(image_size, image_size), 135 | mask_in_chans=16, 136 | ), 137 | mask_decoder=MaskDecoder( 138 | 
num_multimask_outputs=3, 139 | transformer=TwoWayTransformer( 140 | depth=2, 141 | embedding_dim=prompt_embed_dim, 142 | mlp_dim=2048, 143 | num_heads=8, 144 | ), 145 | transformer_dim=prompt_embed_dim, 146 | iou_head_depth=3, 147 | iou_head_hidden_dim=256, 148 | ), 149 | pixel_mean=[123.675, 116.28, 103.53], 150 | pixel_std=[58.395, 57.12, 57.375], 151 | ) 152 | sam.eval() 153 | if checkpoint is not None: 154 | with open(checkpoint, "rb") as f: 155 | state_dict = torch.load(f) 156 | sam.load_state_dict(state_dict) 157 | return sam 158 | 159 | 160 | -------------------------------------------------------------------------------- /src/controlnet_aux/segment_anything/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .sam import Sam 8 | from .image_encoder import ImageEncoderViT 9 | from .mask_decoder import MaskDecoder 10 | from .prompt_encoder import PromptEncoder 11 | from .transformer import TwoWayTransformer 12 | from .tiny_vit_sam import TinyViT 13 | -------------------------------------------------------------------------------- /src/controlnet_aux/segment_anything/modeling/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | from typing import Type 11 | 12 | 13 | class MLPBlock(nn.Module): 14 | def __init__( 15 | self, 16 | embedding_dim: int, 17 | mlp_dim: int, 18 | act: Type[nn.Module] = nn.GELU, 19 | ) -> None: 20 | super().__init__() 21 | self.lin1 = nn.Linear(embedding_dim, mlp_dim) 22 | self.lin2 = nn.Linear(mlp_dim, embedding_dim) 23 | self.act = act() 24 | 25 | def forward(self, x: torch.Tensor) -> torch.Tensor: 26 | return self.lin2(self.act(self.lin1(x))) 27 | 28 | 29 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa 30 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa 31 | class LayerNorm2d(nn.Module): 32 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None: 33 | super().__init__() 34 | self.weight = nn.Parameter(torch.ones(num_channels)) 35 | self.bias = nn.Parameter(torch.zeros(num_channels)) 36 | self.eps = eps 37 | 38 | def forward(self, x: torch.Tensor) -> torch.Tensor: 39 | u = x.mean(1, keepdim=True) 40 | s = (x - u).pow(2).mean(1, keepdim=True) 41 | x = (x - u) / torch.sqrt(s + self.eps) 42 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 43 | return x 44 | -------------------------------------------------------------------------------- /src/controlnet_aux/segment_anything/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | -------------------------------------------------------------------------------- /src/controlnet_aux/segment_anything/utils/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | from torch.nn import functional as F 10 | from torchvision.transforms.functional import resize, to_pil_image # type: ignore 11 | 12 | from copy import deepcopy 13 | from typing import Tuple 14 | 15 | 16 | class ResizeLongestSide: 17 | """ 18 | Resizes images to the longest side 'target_length', as well as provides 19 | methods for resizing coordinates and boxes. Provides methods for 20 | transforming both numpy array and batched torch tensors. 21 | """ 22 | 23 | def __init__(self, target_length: int) -> None: 24 | self.target_length = target_length 25 | 26 | def apply_image(self, image: np.ndarray) -> np.ndarray: 27 | """ 28 | Expects a numpy array with shape HxWxC in uint8 format. 29 | """ 30 | target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) 31 | return np.array(resize(to_pil_image(image), target_size)) 32 | 33 | def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 34 | """ 35 | Expects a numpy array of length 2 in the final dimension. Requires the 36 | original image size in (H, W) format. 37 | """ 38 | old_h, old_w = original_size 39 | new_h, new_w = self.get_preprocess_shape( 40 | original_size[0], original_size[1], self.target_length 41 | ) 42 | coords = deepcopy(coords).astype(float) 43 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 44 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 45 | return coords 46 | 47 | def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 48 | """ 49 | Expects a numpy array shape Bx4. Requires the original image size 50 | in (H, W) format. 51 | """ 52 | boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size) 53 | return boxes.reshape(-1, 4) 54 | 55 | def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor: 56 | """ 57 | Expects batched images with shape BxCxHxW and float format. This 58 | transformation may not exactly match apply_image. apply_image is 59 | the transformation expected by the model. 60 | """ 61 | # Expects an image in BCHW format. May not exactly match apply_image. 62 | target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length) 63 | return F.interpolate( 64 | image, target_size, mode="bilinear", align_corners=False, antialias=True 65 | ) 66 | 67 | def apply_coords_torch( 68 | self, coords: torch.Tensor, original_size: Tuple[int, ...] 69 | ) -> torch.Tensor: 70 | """ 71 | Expects a torch tensor with length 2 in the last dimension. Requires the 72 | original image size in (H, W) format. 73 | """ 74 | old_h, old_w = original_size 75 | new_h, new_w = self.get_preprocess_shape( 76 | original_size[0], original_size[1], self.target_length 77 | ) 78 | coords = deepcopy(coords).to(torch.float) 79 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 80 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 81 | return coords 82 | 83 | def apply_boxes_torch( 84 | self, boxes: torch.Tensor, original_size: Tuple[int, ...] 85 | ) -> torch.Tensor: 86 | """ 87 | Expects a torch tensor with shape Bx4. 
Requires the original image 88 | size in (H, W) format. 89 | """ 90 | boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size) 91 | return boxes.reshape(-1, 4) 92 | 93 | @staticmethod 94 | def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]: 95 | """ 96 | Compute the output size given input size and target long side length. 97 | """ 98 | scale = long_side_length * 1.0 / max(oldh, oldw) 99 | newh, neww = oldh * scale, oldw * scale 100 | neww = int(neww + 0.5) 101 | newh = int(newh + 0.5) 102 | return (newh, neww) 103 | -------------------------------------------------------------------------------- /src/controlnet_aux/shuffle/__init__.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import cv2 4 | import numpy as np 5 | from PIL import Image 6 | 7 | from ..util import HWC3, img2mask, make_noise_disk, resize_image 8 | 9 | 10 | class ContentShuffleDetector: 11 | def __call__(self, input_image, h=None, w=None, f=None, detect_resolution=512, image_resolution=512, output_type="pil", **kwargs): 12 | if "return_pil" in kwargs: 13 | warnings.warn("return_pil is deprecated. Use output_type instead.", DeprecationWarning) 14 | output_type = "pil" if kwargs["return_pil"] else "np" 15 | if type(output_type) is bool: 16 | warnings.warn("Passing `True` or `False` to `output_type` is deprecated and will raise an error in future versions") 17 | if output_type: 18 | output_type = "pil" 19 | 20 | if not isinstance(input_image, np.ndarray): 21 | input_image = np.array(input_image, dtype=np.uint8) 22 | 23 | input_image = HWC3(input_image) 24 | input_image = resize_image(input_image, detect_resolution) 25 | 26 | H, W, C = input_image.shape 27 | if h is None: 28 | h = H 29 | if w is None: 30 | w = W 31 | if f is None: 32 | f = 256 33 | x = make_noise_disk(h, w, 1, f) * float(W - 1) 34 | y = make_noise_disk(h, w, 1, f) * float(H - 1) 35 | flow = np.concatenate([x, y], axis=2).astype(np.float32) 36 | detected_map = cv2.remap(input_image, flow, None, cv2.INTER_LINEAR) 37 | 38 | img = resize_image(input_image, image_resolution) 39 | H, W, C = img.shape 40 | 41 | detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR) 42 | 43 | if output_type == "pil": 44 | detected_map = Image.fromarray(detected_map) 45 | 46 | return detected_map 47 | 48 | 49 | class ColorShuffleDetector: 50 | def __call__(self, img): 51 | H, W, C = img.shape 52 | F = np.random.randint(64, 384) 53 | A = make_noise_disk(H, W, 3, F) 54 | B = make_noise_disk(H, W, 3, F) 55 | C = (A + B) / 2.0 56 | A = (C + (A - C) * 3.0).clip(0, 1) 57 | B = (C + (B - C) * 3.0).clip(0, 1) 58 | L = img.astype(np.float32) / 255.0 59 | Y = A * L + B * (1 - L) 60 | Y -= np.min(Y, axis=(0, 1), keepdims=True) 61 | Y /= np.maximum(np.max(Y, axis=(0, 1), keepdims=True), 1e-5) 62 | Y *= 255.0 63 | return Y.clip(0, 255).astype(np.uint8) 64 | 65 | 66 | class GrayDetector: 67 | def __call__(self, img): 68 | eps = 1e-5 69 | X = img.astype(np.float32) 70 | r, g, b = X[:, :, 0], X[:, :, 1], X[:, :, 2] 71 | kr, kg, kb = [random.random() + eps for _ in range(3)] 72 | ks = kr + kg + kb 73 | kr /= ks 74 | kg /= ks 75 | kb /= ks 76 | Y = r * kr + g * kg + b * kb 77 | Y = np.stack([Y] * 3, axis=2) 78 | return Y.clip(0, 255).astype(np.uint8) 79 | 80 | 81 | class DownSampleDetector: 82 | def __call__(self, img, level=3, k=16.0): 83 | h = img.astype(np.float32) 84 | for _ in range(level): 85 | h += np.random.normal(loc=0.0, scale=k, size=h.shape) 86 | h = 
cv2.pyrDown(h) 87 | for _ in range(level): 88 | h = cv2.pyrUp(h) 89 | h += np.random.normal(loc=0.0, scale=k, size=h.shape) 90 | return h.clip(0, 255).astype(np.uint8) 91 | 92 | 93 | class Image2MaskShuffleDetector: 94 | def __init__(self, resolution=(640, 512)): 95 | self.H, self.W = resolution 96 | 97 | def __call__(self, img): 98 | m = img2mask(img, self.H, self.W) 99 | m *= 255.0 100 | return m.clip(0, 255).astype(np.uint8) 101 | -------------------------------------------------------------------------------- /src/controlnet_aux/teed/Fsmish.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script based on: 3 | Wang, Xueliang, Honge Ren, and Achuan Wang. 4 | "Smish: A Novel Activation Function for Deep Learning Methods. 5 | " Electronics 11.4 (2022): 540. 6 | """ 7 | 8 | # import pytorch 9 | import torch 10 | 11 | 12 | @torch.jit.script 13 | def smish(input): 14 | """ 15 | Applies the mish function element-wise: 16 | mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(sigmoid(x)))) 17 | See additional documentation for mish class. 18 | """ 19 | return input * torch.tanh(torch.log(1 + torch.sigmoid(input))) 20 | -------------------------------------------------------------------------------- /src/controlnet_aux/teed/LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Xavier Soria Poma 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /src/controlnet_aux/teed/Xsmish.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script based on: 3 | Wang, Xueliang, Honge Ren, and Achuan Wang. 4 | "Smish: A Novel Activation Function for Deep Learning Methods. 5 | " Electronics 11.4 (2022): 540. 
6 | smish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + sigmoid(x))) 7 | """ 8 | 9 | # import pytorch 10 | # import activation functions 11 | from torch import nn 12 | 13 | from .Fsmish import smish 14 | 15 | 16 | class Smish(nn.Module): 17 | """ 18 | Applies the mish function element-wise: 19 | mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x))) 20 | Shape: 21 | - Input: (N, *) where * means, any number of additional 22 | dimensions 23 | - Output: (N, *), same shape as the input 24 | Examples: 25 | >>> m = Mish() 26 | >>> input = torch.randn(2) 27 | >>> output = m(input) 28 | Reference: https://pytorch.org/docs/stable/generated/torch.nn.Mish.html 29 | """ 30 | 31 | def __init__(self): 32 | """ 33 | Init method. 34 | """ 35 | super().__init__() 36 | 37 | def forward(self, input): 38 | """ 39 | Forward pass of the function. 40 | """ 41 | return smish(input) 42 | -------------------------------------------------------------------------------- /src/controlnet_aux/teed/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import cv2 4 | import numpy as np 5 | import torch 6 | from einops import rearrange 7 | from huggingface_hub import hf_hub_download 8 | from PIL import Image 9 | 10 | from ..util import HWC3, resize_image, safe_step 11 | from .ted import TED 12 | 13 | 14 | class TEEDdetector: 15 | def __init__(self, model): 16 | self.model = model 17 | 18 | @classmethod 19 | def from_pretrained(cls, pretrained_model_or_path, filename=None, subfolder=None): 20 | if os.path.isdir(pretrained_model_or_path): 21 | model_path = os.path.join(pretrained_model_or_path, filename) 22 | else: 23 | model_path = hf_hub_download( 24 | pretrained_model_or_path, filename, subfolder=subfolder 25 | ) 26 | 27 | model = TED() 28 | model.load_state_dict(torch.load(model_path, map_location="cpu")) 29 | 30 | return cls(model) 31 | 32 | def to(self, device): 33 | self.model.to(device) 34 | return self 35 | 36 | def __call__( 37 | self, 38 | input_image, 39 | detect_resolution=512, 40 | safe_steps=2, 41 | output_type="pil", 42 | ): 43 | device = next(iter(self.model.parameters())).device 44 | if not isinstance(input_image, np.ndarray): 45 | input_image = np.array(input_image, dtype=np.uint8) 46 | output_type = output_type or "pil" 47 | else: 48 | output_type = output_type or "np" 49 | 50 | original_height, original_width, _ = input_image.shape 51 | 52 | input_image = HWC3(input_image) 53 | input_image = resize_image(input_image, detect_resolution) 54 | 55 | assert input_image.ndim == 3 56 | height, width, _ = input_image.shape 57 | with torch.no_grad(): 58 | image_teed = torch.from_numpy(input_image.copy()).float().to(device) 59 | image_teed = rearrange(image_teed, "h w c -> 1 c h w") 60 | edges = self.model(image_teed) 61 | edges = [e.detach().cpu().numpy().astype(np.float32)[0, 0] for e in edges] 62 | edges = [ 63 | cv2.resize(e, (width, height), interpolation=cv2.INTER_LINEAR) 64 | for e in edges 65 | ] 66 | edges = np.stack(edges, axis=2) 67 | edge = 1 / (1 + np.exp(-np.mean(edges, axis=2).astype(np.float64))) 68 | if safe_steps != 0: 69 | edge = safe_step(edge, safe_steps) 70 | edge = (edge * 255.0).clip(0, 255).astype(np.uint8) 71 | 72 | detected_map = edge 73 | detected_map = HWC3(detected_map) 74 | 75 | detected_map = cv2.resize( 76 | detected_map, 77 | (original_width, original_height), 78 | interpolation=cv2.INTER_LINEAR, 79 | ) 80 | 81 | if output_type == "pil": 82 | detected_map = Image.fromarray(detected_map) 83 | 84 | return detected_map 
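# A minimal, hedged usage sketch for TEEDdetector. The "fal-ai/teed" repo id and the
# "5_model.pth" filename are assumptions about where a TEED checkpoint is published and
# may need to be replaced with whatever checkpoint is actually available.
def example_teed_edges(image_path):
    detector = TEEDdetector.from_pretrained("fal-ai/teed", filename="5_model.pth").to("cpu")
    image = Image.open(image_path).convert("RGB")
    return detector(image, detect_resolution=1024, safe_steps=2, output_type="pil")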
85 | -------------------------------------------------------------------------------- /src/controlnet_aux/tests/requirements.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/controlnet_aux/b1aac318b15c4e65a21bc516dc236ad4bf593d46/src/controlnet_aux/tests/requirements.txt -------------------------------------------------------------------------------- /src/controlnet_aux/tests/test_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/controlnet_aux/b1aac318b15c4e65a21bc516dc236ad4bf593d46/src/controlnet_aux/tests/test_image.png -------------------------------------------------------------------------------- /src/controlnet_aux/tests/test_processor.py: -------------------------------------------------------------------------------- 1 | """Test the Processor class.""" 2 | import unittest 3 | from PIL import Image 4 | 5 | from controlnet_aux.processor import Processor 6 | 7 | 8 | class TestProcessor(unittest.TestCase): 9 | def test_hed(self): 10 | processor = Processor('hed') 11 | image = Image.open('test_image.png') 12 | processed_image = processor(image) 13 | self.assertIsInstance(processed_image, bytes) 14 | 15 | def test_midas(self): 16 | processor = Processor('midas') 17 | image = Image.open('test_image.png') 18 | processed_image = processor(image) 19 | self.assertIsInstance(processed_image, bytes) 20 | 21 | def test_mlsd(self): 22 | processor = Processor('mlsd') 23 | image = Image.open('test_image.png') 24 | processed_image = processor(image) 25 | self.assertIsInstance(processed_image, bytes) 26 | 27 | def test_openpose(self): 28 | processor = Processor('openpose') 29 | image = Image.open('test_image.png') 30 | processed_image = processor(image) 31 | self.assertIsInstance(processed_image, bytes) 32 | 33 | def test_pidinet(self): 34 | processor = Processor('pidinet') 35 | image = Image.open('test_image.png') 36 | processed_image = processor(image) 37 | self.assertIsInstance(processed_image, bytes) 38 | 39 | def test_normalbae(self): 40 | processor = Processor('normalbae') 41 | image = Image.open('test_image.png') 42 | processed_image = processor(image) 43 | self.assertIsInstance(processed_image, bytes) 44 | 45 | def test_lineart(self): 46 | processor = Processor('lineart') 47 | image = Image.open('test_image.png') 48 | processed_image = processor(image) 49 | self.assertIsInstance(processed_image, bytes) 50 | 51 | def test_lineart_coarse(self): 52 | processor = Processor('lineart_coarse') 53 | image = Image.open('test_image.png') 54 | processed_image = processor(image) 55 | self.assertIsInstance(processed_image, bytes) 56 | 57 | def test_lineart_anime(self): 58 | processor = Processor('lineart_anime') 59 | image = Image.open('test_image.png') 60 | processed_image = processor(image) 61 | self.assertIsInstance(processed_image, bytes) 62 | 63 | def test_canny(self): 64 | processor = Processor('canny') 65 | image = Image.open('test_image.png') 66 | processed_image = processor(image) 67 | self.assertIsInstance(processed_image, bytes) 68 | 69 | def test_content_shuffle(self): 70 | processor = Processor('content_shuffle') 71 | image = Image.open('test_image.png') 72 | processed_image = processor(image) 73 | self.assertIsInstance(processed_image, bytes) 74 | 75 | def test_zoe(self): 76 | processor = Processor('zoe') 77 | image = Image.open('test_image.png') 78 | processed_image = processor(image) 79 | 
self.assertIsInstance(processed_image, bytes) 80 | 81 | def test_mediapipe_face(self): 82 | processor = Processor('mediapipe_face') 83 | image = Image.open('test_image.png') 84 | processed_image = processor(image) 85 | self.assertIsInstance(processed_image, bytes) 86 | 87 | 88 | if __name__ == '__main__': 89 | unittest.main() -------------------------------------------------------------------------------- /src/controlnet_aux/tests/test_processor_pytest.py: -------------------------------------------------------------------------------- 1 | import io 2 | 3 | import numpy as np 4 | import pytest 5 | from PIL import Image 6 | 7 | from controlnet_aux.processor import MODELS, Processor 8 | 9 | 10 | @pytest.fixture(params=[ 11 | 'scribble_hed', 12 | 'softedge_hed', 13 | 'scribble_hedsafe', 14 | 'softedge_hedsafe', 15 | 'depth_midas', 16 | 'mlsd', 17 | 'openpose', 18 | 'openpose_hand', 19 | 'openpose_face', 20 | 'openpose_faceonly', 21 | 'openpose_full', 22 | 'scribble_pidinet', 23 | 'softedge_pidinet', 24 | 'scribble_pidsafe', 25 | 'softedge_pidsafe', 26 | 'normal_bae', 27 | 'lineart_coarse', 28 | 'lineart_realistic', 29 | 'lineart_anime', 30 | 'canny', 31 | 'shuffle', 32 | 'depth_zoe', 33 | 'depth_leres', 34 | 'depth_leres++', 35 | 'mediapipe_face' 36 | ]) 37 | def processor(request): 38 | return Processor(request.param) 39 | 40 | 41 | def test_processor_init(processor): 42 | assert isinstance(processor.processor, MODELS[processor.processor_id]['class']) 43 | assert isinstance(processor.params, dict) 44 | 45 | 46 | def test_processor_call(processor): 47 | # Load test image 48 | with open('test_image.png', 'rb') as f: 49 | image_bytes = f.read() 50 | image = Image.open(io.BytesIO(image_bytes)) 51 | 52 | # Output size 53 | resolution = 512 54 | W, H = image.size 55 | H = float(H) 56 | W = float(W) 57 | k = float(resolution) / min(H, W) 58 | H *= k 59 | W *= k 60 | H = int(np.round(H / 64.0)) * 64 61 | W = int(np.round(W / 64.0)) * 64 62 | 63 | # Test processing 64 | processed_image = processor(image) 65 | assert isinstance(processed_image, Image.Image) 66 | assert processed_image.size == (W, H) 67 | 68 | 69 | def test_processor_call_bytes(processor): 70 | # Load test image 71 | with open('test_image.png', 'rb') as f: 72 | image_bytes = f.read() 73 | 74 | # Test processing 75 | processed_image_bytes = processor(image_bytes, to_pil=False) 76 | assert isinstance(processed_image_bytes, bytes) 77 | assert len(processed_image_bytes) > 0 -------------------------------------------------------------------------------- /src/controlnet_aux/zoe/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Intelligent Systems Lab Org 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /src/controlnet_aux/zoe/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import cv2 4 | import numpy as np 5 | import torch 6 | from einops import rearrange 7 | from huggingface_hub import hf_hub_download 8 | from PIL import Image 9 | 10 | from ..util import HWC3, resize_image 11 | from .zoedepth.models.zoedepth.zoedepth_v1 import ZoeDepth 12 | from .zoedepth.models.zoedepth_nk.zoedepth_nk_v1 import ZoeDepthNK 13 | from .zoedepth.utils.config import get_config 14 | 15 | 16 | class ZoeDetector: 17 | def __init__(self, model): 18 | self.model = model 19 | 20 | @classmethod 21 | def from_pretrained(cls, pretrained_model_or_path, model_type="zoedepth", filename=None, cache_dir=None, local_files_only=False): 22 | filename = filename or "ZoeD_M12_N.pt" 23 | 24 | if os.path.isdir(pretrained_model_or_path): 25 | model_path = os.path.join(pretrained_model_or_path, filename) 26 | else: 27 | model_path = hf_hub_download(pretrained_model_or_path, filename, cache_dir=cache_dir, local_files_only=local_files_only) 28 | 29 | conf = get_config(model_type, "infer") 30 | model_cls = ZoeDepth if model_type == "zoedepth" else ZoeDepthNK 31 | model = model_cls.build_from_config(conf) 32 | try: 33 | # Try to load the model with standard approach first 34 | state_dict = torch.load(model_path, map_location=torch.device('cpu'))['model'] 35 | model.load_state_dict(state_dict) 36 | except RuntimeError as e: 37 | # If that fails, try loading with strict=False 38 | print(f"Warning: Standard model loading failed, trying with strict=False: {str(e)}") 39 | state_dict = torch.load(model_path, map_location=torch.device('cpu'))['model'] 40 | model.load_state_dict(state_dict, strict=False) 41 | 42 | model.eval() 43 | 44 | return cls(model) 45 | 46 | def to(self, device): 47 | self.model.to(device) 48 | return self 49 | 50 | def __call__(self, input_image, detect_resolution=512, image_resolution=512, output_type=None, gamma_corrected=False): 51 | device = next(iter(self.model.parameters())).device 52 | if not isinstance(input_image, np.ndarray): 53 | input_image = np.array(input_image, dtype=np.uint8) 54 | output_type = output_type or "pil" 55 | else: 56 | output_type = output_type or "np" 57 | 58 | input_image = HWC3(input_image) 59 | input_image = resize_image(input_image, detect_resolution) 60 | 61 | assert input_image.ndim == 3 62 | image_depth = input_image 63 | with torch.no_grad(): 64 | image_depth = torch.from_numpy(image_depth).float().to(device) 65 | image_depth = image_depth / 255.0 66 | image_depth = rearrange(image_depth, 'h w c -> 1 c h w') 67 | depth = self.model.infer(image_depth) 68 | 69 | depth = depth[0, 0].cpu().numpy() 70 | 71 | vmin = np.percentile(depth, 2) 72 | vmax = np.percentile(depth, 85) 73 | 74 | depth -= vmin 75 | depth /= vmax - vmin 76 | depth = 1.0 - depth 77 | 78 | if gamma_corrected: 79 | depth = np.power(depth, 2.2) 80 | depth_image = (depth * 
255.0).clip(0, 255).astype(np.uint8) 81 | 82 | detected_map = depth_image 83 | detected_map = HWC3(detected_map) 84 | 85 | img = resize_image(input_image, image_resolution) 86 | H, W, C = img.shape 87 | 88 | detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR) 89 | 90 | if output_type == "pil": 91 | detected_map = Image.fromarray(detected_map) 92 | 93 | return detected_map 94 | -------------------------------------------------------------------------------- /src/controlnet_aux/zoe/zoedepth/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/controlnet_aux/b1aac318b15c4e65a21bc516dc236ad4bf593d46/src/controlnet_aux/zoe/zoedepth/__init__.py -------------------------------------------------------------------------------- /src/controlnet_aux/zoe/zoedepth/models/__init__.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | # Copyright (c) 2022 Intelligent Systems Lab Org 4 | 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | 23 | # File author: Shariq Farooq Bhat 24 | 25 | -------------------------------------------------------------------------------- /src/controlnet_aux/zoe/zoedepth/models/base_models/__init__.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | # Copyright (c) 2022 Intelligent Systems Lab Org 4 | 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | 23 | # File author: Shariq Farooq Bhat 24 | 25 | -------------------------------------------------------------------------------- /src/controlnet_aux/zoe/zoedepth/models/base_models/midas_repo/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Intel ISL (Intel Intelligent Systems Lab) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /src/controlnet_aux/zoe/zoedepth/models/base_models/midas_repo/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/controlnet_aux/b1aac318b15c4e65a21bc516dc236ad4bf593d46/src/controlnet_aux/zoe/zoedepth/models/base_models/midas_repo/__init__.py -------------------------------------------------------------------------------- /src/controlnet_aux/zoe/zoedepth/models/base_models/midas_repo/midas/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/controlnet_aux/b1aac318b15c4e65a21bc516dc236ad4bf593d46/src/controlnet_aux/zoe/zoedepth/models/base_models/midas_repo/midas/__init__.py -------------------------------------------------------------------------------- /src/controlnet_aux/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/controlnet_aux/b1aac318b15c4e65a21bc516dc236ad4bf593d46/src/controlnet_aux/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/__init__.py -------------------------------------------------------------------------------- /src/controlnet_aux/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/levit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from .timm_adapter import create_model_adapter 5 | 6 | from .utils import activations, get_activation, Transpose 7 | 8 | 9 | def forward_levit(pretrained, x): 10 | 
pretrained.model.forward_features(x) 11 | 12 | layer_1 = pretrained.activations["1"] 13 | layer_2 = pretrained.activations["2"] 14 | layer_3 = pretrained.activations["3"] 15 | 16 | layer_1 = pretrained.act_postprocess1(layer_1) 17 | layer_2 = pretrained.act_postprocess2(layer_2) 18 | layer_3 = pretrained.act_postprocess3(layer_3) 19 | 20 | return layer_1, layer_2, layer_3 21 | 22 | 23 | def _make_levit_backbone( 24 | model, 25 | hooks=[3, 11, 21], 26 | patch_grid=[14, 14] 27 | ): 28 | pretrained = nn.Module() 29 | 30 | pretrained.model = model 31 | pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) 32 | pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) 33 | pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) 34 | 35 | pretrained.activations = activations 36 | 37 | patch_grid_size = np.array(patch_grid, dtype=int) 38 | 39 | pretrained.act_postprocess1 = nn.Sequential( 40 | Transpose(1, 2), 41 | nn.Unflatten(2, torch.Size(patch_grid_size.tolist())) 42 | ) 43 | pretrained.act_postprocess2 = nn.Sequential( 44 | Transpose(1, 2), 45 | nn.Unflatten(2, torch.Size((np.ceil(patch_grid_size / 2).astype(int)).tolist())) 46 | ) 47 | pretrained.act_postprocess3 = nn.Sequential( 48 | Transpose(1, 2), 49 | nn.Unflatten(2, torch.Size((np.ceil(patch_grid_size / 4).astype(int)).tolist())) 50 | ) 51 | 52 | return pretrained 53 | 54 | 55 | class ConvTransposeNorm(nn.Sequential): 56 | """ 57 | Modification of 58 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/levit.py: ConvNorm 59 | such that ConvTranspose2d is used instead of Conv2d. 60 | """ 61 | 62 | def __init__( 63 | self, in_chs, out_chs, kernel_size=1, stride=1, pad=0, dilation=1, 64 | groups=1, bn_weight_init=1): 65 | super().__init__() 66 | self.add_module('c', 67 | nn.ConvTranspose2d(in_chs, out_chs, kernel_size, stride, pad, dilation, groups, bias=False)) 68 | self.add_module('bn', nn.BatchNorm2d(out_chs)) 69 | 70 | nn.init.constant_(self.bn.weight, bn_weight_init) 71 | 72 | @torch.no_grad() 73 | def fuse(self): 74 | c, bn = self._modules.values() 75 | w = bn.weight / (bn.running_var + bn.eps) ** 0.5 76 | w = c.weight * w[:, None, None, None] 77 | b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5 78 | m = nn.ConvTranspose2d( 79 | w.size(1), w.size(0), w.shape[2:], stride=self.c.stride, 80 | padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups) 81 | m.weight.data.copy_(w) 82 | m.bias.data.copy_(b) 83 | return m 84 | 85 | 86 | def stem_b4_transpose(in_chs, out_chs, activation): 87 | """ 88 | Modification of 89 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/levit.py: stem_b16 90 | such that ConvTranspose2d is used instead of Conv2d and stem is also reduced to the half. 
91 | """ 92 | return nn.Sequential( 93 | ConvTransposeNorm(in_chs, out_chs, 3, 2, 1), 94 | activation(), 95 | ConvTransposeNorm(out_chs, out_chs // 2, 3, 2, 1), 96 | activation()) 97 | 98 | 99 | def _make_pretrained_levit_384(pretrained, hooks=None): 100 | model = create_model_adapter("levit_384", pretrained=pretrained) 101 | 102 | hooks = [3, 11, 21] if hooks == None else hooks 103 | return _make_levit_backbone( 104 | model, 105 | hooks=hooks 106 | ) 107 | -------------------------------------------------------------------------------- /src/controlnet_aux/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/next_vit.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from .timm_adapter import create_model_adapter 3 | 4 | from pathlib import Path 5 | from .utils import activations, forward_default, get_activation 6 | 7 | from ..external.next_vit.classification.nextvit import * 8 | 9 | 10 | def forward_next_vit(pretrained, x): 11 | return forward_default(pretrained, x, "forward") 12 | 13 | 14 | def _make_next_vit_backbone( 15 | model, 16 | hooks=[2, 6, 36, 39], 17 | ): 18 | pretrained = nn.Module() 19 | 20 | pretrained.model = model 21 | pretrained.model.features[hooks[0]].register_forward_hook(get_activation("1")) 22 | pretrained.model.features[hooks[1]].register_forward_hook(get_activation("2")) 23 | pretrained.model.features[hooks[2]].register_forward_hook(get_activation("3")) 24 | pretrained.model.features[hooks[3]].register_forward_hook(get_activation("4")) 25 | 26 | pretrained.activations = activations 27 | 28 | return pretrained 29 | 30 | 31 | def _make_pretrained_next_vit_large_6m(hooks=None): 32 | model = create_model_adapter("nextvit_large") 33 | 34 | hooks = [2, 6, 36, 39] if hooks == None else hooks 35 | return _make_next_vit_backbone( 36 | model, 37 | hooks=hooks, 38 | ) 39 | -------------------------------------------------------------------------------- /src/controlnet_aux/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/swin.py: -------------------------------------------------------------------------------- 1 | from .timm_adapter import create_model_adapter 2 | from .swin_common import _make_swin_backbone 3 | 4 | 5 | def _make_pretrained_swinl12_384(pretrained, hooks=None): 6 | model = create_model_adapter("swin_large_patch4_window12_384", pretrained=pretrained) 7 | 8 | hooks = [1, 1, 17, 1] if hooks == None else hooks 9 | return _make_swin_backbone( 10 | model, 11 | hooks=hooks 12 | ) 13 | -------------------------------------------------------------------------------- /src/controlnet_aux/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/swin2.py: -------------------------------------------------------------------------------- 1 | from .timm_adapter import create_model_adapter 2 | from .swin_common import _make_swin_backbone 3 | 4 | 5 | def _make_pretrained_swin2l24_384(pretrained, hooks=None): 6 | model = create_model_adapter("swinv2_large_window12to24_192to384_22kft1k", pretrained=pretrained) 7 | 8 | hooks = [1, 1, 17, 1] if hooks == None else hooks 9 | return _make_swin_backbone( 10 | model, 11 | hooks=hooks 12 | ) 13 | 14 | 15 | def _make_pretrained_swin2b24_384(pretrained, hooks=None): 16 | model = create_model_adapter("swinv2_base_window12to24_192to384_22kft1k", pretrained=pretrained) 17 | 18 | hooks = [1, 1, 17, 1] if hooks == None else hooks 19 | return _make_swin_backbone( 20 | model, 21 | hooks=hooks 22 | ) 23 | 24 | 25 | def _make_pretrained_swin2t16_256(pretrained, 
hooks=None): 26 | model = create_model_adapter("swinv2_tiny_window16_256", pretrained=pretrained) 27 | 28 | hooks = [1, 1, 5, 1] if hooks == None else hooks 29 | return _make_swin_backbone( 30 | model, 31 | hooks=hooks, 32 | patch_grid=[64, 64] 33 | ) 34 | -------------------------------------------------------------------------------- /src/controlnet_aux/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/swin_common.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import torch.nn as nn 4 | import numpy as np 5 | 6 | from .utils import activations, forward_default, get_activation, Transpose 7 | 8 | 9 | def forward_swin(pretrained, x): 10 | return forward_default(pretrained, x) 11 | 12 | 13 | def _make_swin_backbone( 14 | model, 15 | hooks=[1, 1, 17, 1], 16 | patch_grid=[96, 96] 17 | ): 18 | pretrained = nn.Module() 19 | 20 | pretrained.model = model 21 | pretrained.model.layers[0].blocks[hooks[0]].register_forward_hook(get_activation("1")) 22 | pretrained.model.layers[1].blocks[hooks[1]].register_forward_hook(get_activation("2")) 23 | pretrained.model.layers[2].blocks[hooks[2]].register_forward_hook(get_activation("3")) 24 | pretrained.model.layers[3].blocks[hooks[3]].register_forward_hook(get_activation("4")) 25 | 26 | pretrained.activations = activations 27 | 28 | if hasattr(model, "patch_grid"): 29 | used_patch_grid = model.patch_grid 30 | else: 31 | used_patch_grid = patch_grid 32 | 33 | patch_grid_size = np.array(used_patch_grid, dtype=int) 34 | 35 | pretrained.act_postprocess1 = nn.Sequential( 36 | Transpose(1, 2), 37 | nn.Unflatten(2, torch.Size(patch_grid_size.tolist())) 38 | ) 39 | pretrained.act_postprocess2 = nn.Sequential( 40 | Transpose(1, 2), 41 | nn.Unflatten(2, torch.Size((patch_grid_size // 2).tolist())) 42 | ) 43 | pretrained.act_postprocess3 = nn.Sequential( 44 | Transpose(1, 2), 45 | nn.Unflatten(2, torch.Size((patch_grid_size // 4).tolist())) 46 | ) 47 | pretrained.act_postprocess4 = nn.Sequential( 48 | Transpose(1, 2), 49 | nn.Unflatten(2, torch.Size((patch_grid_size // 8).tolist())) 50 | ) 51 | 52 | return pretrained 53 | -------------------------------------------------------------------------------- /src/controlnet_aux/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/timm_adapter.py: -------------------------------------------------------------------------------- 1 | import importlib.metadata 2 | import timm 3 | 4 | # Check the installed timm version 5 | timm_version = importlib.metadata.version("timm") 6 | is_new_timm = timm_version >= "1.0.0" 7 | 8 | def create_model_adapter(model_name, pretrained=False, **kwargs): 9 | """ 10 | Adapter function for creating models with timm that works with both old (0.6.7) and new (1.0+) versions. 
11 | """ 12 | if is_new_timm: 13 | # In timm 1.0+, 'pretrained' is deprecated in favor of 'pretrained_cfg' or explicit pretrained_cfg_url 14 | if pretrained: 15 | return timm.create_model(model_name, pretrained_cfg='default', **kwargs) 16 | else: 17 | return timm.create_model(model_name, pretrained_cfg=None, **kwargs) 18 | else: 19 | # Old timm behavior 20 | return timm.create_model(model_name, pretrained=pretrained, **kwargs) -------------------------------------------------------------------------------- /src/controlnet_aux/zoe/zoedepth/models/base_models/midas_repo/midas/base_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class BaseModel(torch.nn.Module): 5 | def load(self, path): 6 | """Load model from file. 7 | 8 | Args: 9 | path (str): file path 10 | """ 11 | parameters = torch.load(path, map_location=torch.device('cpu')) 12 | 13 | if "optimizer" in parameters: 14 | parameters = parameters["model"] 15 | 16 | self.load_state_dict(parameters) 17 | -------------------------------------------------------------------------------- /src/controlnet_aux/zoe/zoedepth/models/base_models/midas_repo/midas/midas_net.py: -------------------------------------------------------------------------------- 1 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets. 2 | This file contains code that is adapted from 3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | 8 | from .base_model import BaseModel 9 | from .blocks import FeatureFusionBlock, Interpolate, _make_encoder 10 | 11 | 12 | class MidasNet(BaseModel): 13 | """Network for monocular depth estimation. 14 | """ 15 | 16 | def __init__(self, path=None, features=256, non_negative=True): 17 | """Init. 18 | 19 | Args: 20 | path (str, optional): Path to saved model. Defaults to None. 21 | features (int, optional): Number of features. Defaults to 256. 22 | backbone (str, optional): Backbone network for encoder. Defaults to resnet50 23 | """ 24 | print("Loading weights: ", path) 25 | 26 | super(MidasNet, self).__init__() 27 | 28 | use_pretrained = False if path is None else True 29 | 30 | self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained) 31 | 32 | self.scratch.refinenet4 = FeatureFusionBlock(features) 33 | self.scratch.refinenet3 = FeatureFusionBlock(features) 34 | self.scratch.refinenet2 = FeatureFusionBlock(features) 35 | self.scratch.refinenet1 = FeatureFusionBlock(features) 36 | 37 | self.scratch.output_conv = nn.Sequential( 38 | nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), 39 | Interpolate(scale_factor=2, mode="bilinear"), 40 | nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), 41 | nn.ReLU(True), 42 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 43 | nn.ReLU(True) if non_negative else nn.Identity(), 44 | ) 45 | 46 | if path: 47 | self.load(path) 48 | 49 | def forward(self, x): 50 | """Forward pass. 
51 | 52 | Args: 53 | x (tensor): input data (image) 54 | 55 | Returns: 56 | tensor: depth 57 | """ 58 | 59 | layer_1 = self.pretrained.layer1(x) 60 | layer_2 = self.pretrained.layer2(layer_1) 61 | layer_3 = self.pretrained.layer3(layer_2) 62 | layer_4 = self.pretrained.layer4(layer_3) 63 | 64 | layer_1_rn = self.scratch.layer1_rn(layer_1) 65 | layer_2_rn = self.scratch.layer2_rn(layer_2) 66 | layer_3_rn = self.scratch.layer3_rn(layer_3) 67 | layer_4_rn = self.scratch.layer4_rn(layer_4) 68 | 69 | path_4 = self.scratch.refinenet4(layer_4_rn) 70 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 71 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 72 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 73 | 74 | out = self.scratch.output_conv(path_1) 75 | 76 | return torch.squeeze(out, dim=1) 77 | -------------------------------------------------------------------------------- /src/controlnet_aux/zoe/zoedepth/models/builder.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | # Copyright (c) 2022 Intelligent Systems Lab Org 4 | 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | 23 | # File author: Shariq Farooq Bhat 24 | 25 | from importlib import import_module 26 | from .depth_model import DepthModel 27 | 28 | def build_model(config) -> DepthModel: 29 | """Builds a model from a config. The model is specified by the model name and version in the config. The model is then constructed using the build_from_config function of the model interface. 30 | This function should be used to construct models for training and evaluation. 31 | 32 | Args: 33 | config (dict): Config dict. Config is constructed in utils/config.py. Each model has its own config file(s) saved in its root model folder. 34 | 35 | Returns: 36 | torch.nn.Module: Model corresponding to name and version as specified in config 37 | """ 38 | module_name = f"zoedepth.models.{config.model}" 39 | try: 40 | module = import_module(module_name) 41 | except ModuleNotFoundError as e: 42 | # print the original error message 43 | print(e) 44 | raise ValueError( 45 | f"Model {config.model} not found. 
Refer above error for details.") from e 46 | try: 47 | get_version = getattr(module, "get_version") 48 | except AttributeError as e: 49 | raise ValueError( 50 | f"Model {config.model} has no get_version function.") from e 51 | return get_version(config.version_name).build_from_config(config) 52 | -------------------------------------------------------------------------------- /src/controlnet_aux/zoe/zoedepth/models/layers/__init__.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | # Copyright (c) 2022 Intelligent Systems Lab Org 4 | 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | 23 | # File author: Shariq Farooq Bhat 24 | -------------------------------------------------------------------------------- /src/controlnet_aux/zoe/zoedepth/models/layers/dist_layers.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | # Copyright (c) 2022 Intelligent Systems Lab Org 4 | 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 
22 | 23 | # File author: Shariq Farooq Bhat 24 | 25 | import torch 26 | import torch.nn as nn 27 | 28 | 29 | def log_binom(n, k, eps=1e-7): 30 | """ log(nCk) using stirling approximation """ 31 | n = n + eps 32 | k = k + eps 33 | return n * torch.log(n) - k * torch.log(k) - (n-k) * torch.log(n-k+eps) 34 | 35 | 36 | class LogBinomial(nn.Module): 37 | def __init__(self, n_classes=256, act=torch.softmax): 38 | """Compute log binomial distribution for n_classes 39 | 40 | Args: 41 | n_classes (int, optional): number of output classes. Defaults to 256. 42 | """ 43 | super().__init__() 44 | self.K = n_classes 45 | self.act = act 46 | self.register_buffer('k_idx', torch.arange( 47 | 0, n_classes).view(1, -1, 1, 1)) 48 | self.register_buffer('K_minus_1', torch.Tensor( 49 | [self.K-1]).view(1, -1, 1, 1)) 50 | 51 | def forward(self, x, t=1., eps=1e-4): 52 | """Compute log binomial distribution for x 53 | 54 | Args: 55 | x (torch.Tensor - NCHW): probabilities 56 | t (float, torch.Tensor - NCHW, optional): Temperature of distribution. Defaults to 1.. 57 | eps (float, optional): Small number for numerical stability. Defaults to 1e-4. 58 | 59 | Returns: 60 | torch.Tensor -NCHW: log binomial distribution logbinomial(p;t) 61 | """ 62 | if x.ndim == 3: 63 | x = x.unsqueeze(1) # make it nchw 64 | 65 | one_minus_x = torch.clamp(1 - x, eps, 1) 66 | x = torch.clamp(x, eps, 1) 67 | y = log_binom(self.K_minus_1, self.k_idx) + self.k_idx * \ 68 | torch.log(x) + (self.K - 1 - self.k_idx) * torch.log(one_minus_x) 69 | return self.act(y/t, dim=1) 70 | 71 | 72 | class ConditionalLogBinomial(nn.Module): 73 | def __init__(self, in_features, condition_dim, n_classes=256, bottleneck_factor=2, p_eps=1e-4, max_temp=50, min_temp=1e-7, act=torch.softmax): 74 | """Conditional Log Binomial distribution 75 | 76 | Args: 77 | in_features (int): number of input channels in main feature 78 | condition_dim (int): number of input channels in condition feature 79 | n_classes (int, optional): Number of classes. Defaults to 256. 80 | bottleneck_factor (int, optional): Hidden dim factor. Defaults to 2. 81 | p_eps (float, optional): small eps value. Defaults to 1e-4. 82 | max_temp (float, optional): Maximum temperature of output distribution. Defaults to 50. 83 | min_temp (float, optional): Minimum temperature of output distribution. Defaults to 1e-7. 84 | """ 85 | super().__init__() 86 | self.p_eps = p_eps 87 | self.max_temp = max_temp 88 | self.min_temp = min_temp 89 | self.log_binomial_transform = LogBinomial(n_classes, act=act) 90 | bottleneck = (in_features + condition_dim) // bottleneck_factor 91 | self.mlp = nn.Sequential( 92 | nn.Conv2d(in_features + condition_dim, bottleneck, 93 | kernel_size=1, stride=1, padding=0), 94 | nn.GELU(), 95 | # 2 for p linear norm, 2 for t linear norm 96 | nn.Conv2d(bottleneck, 2+2, kernel_size=1, stride=1, padding=0), 97 | nn.Softplus() 98 | ) 99 | 100 | def forward(self, x, cond): 101 | """Forward pass 102 | 103 | Args: 104 | x (torch.Tensor - NCHW): Main feature 105 | cond (torch.Tensor - NCHW): condition feature 106 | 107 | Returns: 108 | torch.Tensor: Output log binomial distribution 109 | """ 110 | pt = self.mlp(torch.concat((x, cond), dim=1)) 111 | p, t = pt[:, :2, ...], pt[:, 2:, ...] 112 | 113 | p = p + self.p_eps 114 | p = p[:, 0, ...] / (p[:, 0, ...] + p[:, 1, ...]) 115 | 116 | t = t + self.p_eps 117 | t = t[:, 0, ...] / (t[:, 0, ...] 
+ t[:, 1, ...]) 118 | t = t.unsqueeze(1) 119 | t = (self.max_temp - self.min_temp) * t + self.min_temp 120 | 121 | return self.log_binomial_transform(p, t) 122 | -------------------------------------------------------------------------------- /src/controlnet_aux/zoe/zoedepth/models/layers/patch_transformer.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | # Copyright (c) 2022 Intelligent Systems Lab Org 4 | 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | 23 | # File author: Shariq Farooq Bhat 24 | 25 | import torch 26 | import torch.nn as nn 27 | 28 | 29 | class PatchTransformerEncoder(nn.Module): 30 | def __init__(self, in_channels, patch_size=10, embedding_dim=128, num_heads=4, use_class_token=False): 31 | """ViT-like transformer block 32 | 33 | Args: 34 | in_channels (int): Input channels 35 | patch_size (int, optional): patch size. Defaults to 10. 36 | embedding_dim (int, optional): Embedding dimension in transformer model. Defaults to 128. 37 | num_heads (int, optional): number of attention heads. Defaults to 4. 38 | use_class_token (bool, optional): Whether to use extra token at the start for global accumulation (called as "class token"). Defaults to False. 
39 | """ 40 | super(PatchTransformerEncoder, self).__init__() 41 | self.use_class_token = use_class_token 42 | encoder_layers = nn.TransformerEncoderLayer( 43 | embedding_dim, num_heads, dim_feedforward=1024) 44 | self.transformer_encoder = nn.TransformerEncoder( 45 | encoder_layers, num_layers=4) # takes shape S,N,E 46 | 47 | self.embedding_convPxP = nn.Conv2d(in_channels, embedding_dim, 48 | kernel_size=patch_size, stride=patch_size, padding=0) 49 | 50 | def positional_encoding_1d(self, sequence_length, batch_size, embedding_dim, device='cpu'): 51 | """Generate positional encodings 52 | 53 | Args: 54 | sequence_length (int): Sequence length 55 | embedding_dim (int): Embedding dimension 56 | 57 | Returns: 58 | torch.Tensor SBE: Positional encodings 59 | """ 60 | position = torch.arange( 61 | 0, sequence_length, dtype=torch.float32, device=device).unsqueeze(1) 62 | index = torch.arange( 63 | 0, embedding_dim, 2, dtype=torch.float32, device=device).unsqueeze(0) 64 | div_term = torch.exp(index * (-torch.log(torch.tensor(10000.0, device=device)) / embedding_dim)) 65 | pos_encoding = position * div_term 66 | pos_encoding = torch.cat([torch.sin(pos_encoding), torch.cos(pos_encoding)], dim=1) 67 | pos_encoding = pos_encoding.unsqueeze(1).repeat(1, batch_size, 1) 68 | return pos_encoding 69 | 70 | 71 | def forward(self, x): 72 | """Forward pass 73 | 74 | Args: 75 | x (torch.Tensor - NCHW): Input feature tensor 76 | 77 | Returns: 78 | torch.Tensor - SNE: Transformer output embeddings. S - sequence length (=HW/patch_size^2), N - batch size, E - embedding dim 79 | """ 80 | embeddings = self.embedding_convPxP(x).flatten( 81 | 2) # .shape = n,c,s = n, embedding_dim, s 82 | if self.use_class_token: 83 | # extra special token at start ? 84 | embeddings = nn.functional.pad(embeddings, (1, 0)) 85 | 86 | # change to S,N,E format required by transformer 87 | embeddings = embeddings.permute(2, 0, 1) 88 | S, N, E = embeddings.shape 89 | embeddings = embeddings + self.positional_encoding_1d(S, N, E, device=embeddings.device) 90 | x = self.transformer_encoder(embeddings) # .shape = S, N, E 91 | return x 92 | -------------------------------------------------------------------------------- /src/controlnet_aux/zoe/zoedepth/models/model_io.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | # Copyright (c) 2022 Intelligent Systems Lab Org 4 | 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 
22 | 23 | # File author: Shariq Farooq Bhat 24 | 25 | import torch 26 | 27 | def load_state_dict(model, state_dict): 28 | """Load state_dict into model, handling DataParallel and DistributedDataParallel. Also checks for "model" key in state_dict. 29 | 30 | DataParallel prefixes state_dict keys with 'module.' when saving. 31 | If the model is not a DataParallel model but the state_dict is, then prefixes are removed. 32 | If the model is a DataParallel model but the state_dict is not, then prefixes are added. 33 | """ 34 | state_dict = state_dict.get('model', state_dict) 35 | # if model is a DataParallel model, then state_dict keys are prefixed with 'module.' 36 | 37 | do_prefix = isinstance( 38 | model, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)) 39 | state = {} 40 | for k, v in state_dict.items(): 41 | if k.startswith('module.') and not do_prefix: 42 | k = k[7:] 43 | 44 | if not k.startswith('module.') and do_prefix: 45 | k = 'module.' + k 46 | 47 | state[k] = v 48 | 49 | model.load_state_dict(state) 50 | print("Loaded successfully") 51 | return model 52 | 53 | 54 | def load_wts(model, checkpoint_path): 55 | ckpt = torch.load(checkpoint_path, map_location='cpu') 56 | return load_state_dict(model, ckpt) 57 | 58 | 59 | def load_state_dict_from_url(model, url, **kwargs): 60 | state_dict = torch.hub.load_state_dict_from_url(url, map_location='cpu', **kwargs) 61 | return load_state_dict(model, state_dict) 62 | 63 | 64 | def load_state_from_resource(model, resource: str): 65 | """Loads weights to the model from a given resource. A resource can be of following types: 66 | 1. URL. Prefixed with "url::" 67 | e.g. url::http(s)://url.resource.com/ckpt.pt 68 | 69 | 2. Local path. Prefixed with "local::" 70 | e.g. local::/path/to/ckpt.pt 71 | 72 | 73 | Args: 74 | model (torch.nn.Module): Model 75 | resource (str): resource string 76 | 77 | Returns: 78 | torch.nn.Module: Model with loaded weights 79 | """ 80 | print(f"Using pretrained resource {resource}") 81 | 82 | if resource.startswith('url::'): 83 | url = resource.split('url::')[1] 84 | return load_state_dict_from_url(model, url, progress=True) 85 | 86 | elif resource.startswith('local::'): 87 | path = resource.split('local::')[1] 88 | return load_wts(model, path) 89 | 90 | else: 91 | raise ValueError("Invalid resource type, only url:: and local:: are supported") 92 | -------------------------------------------------------------------------------- /src/controlnet_aux/zoe/zoedepth/models/zoedepth/__init__.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | # Copyright (c) 2022 Intelligent Systems Lab Org 4 | 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | 23 | # File author: Shariq Farooq Bhat 24 | 25 | from .zoedepth_v1 import ZoeDepth 26 | 27 | all_versions = { 28 | "v1": ZoeDepth, 29 | } 30 | 31 | get_version = lambda v : all_versions[v] -------------------------------------------------------------------------------- /src/controlnet_aux/zoe/zoedepth/models/zoedepth/config_zoedepth.json: -------------------------------------------------------------------------------- 1 | { 2 | "model": { 3 | "name": "ZoeDepth", 4 | "version_name": "v1", 5 | "n_bins": 64, 6 | "bin_embedding_dim": 128, 7 | "bin_centers_type": "softplus", 8 | "n_attractors":[16, 8, 4, 1], 9 | "attractor_alpha": 1000, 10 | "attractor_gamma": 2, 11 | "attractor_kind" : "mean", 12 | "attractor_type" : "inv", 13 | "midas_model_type" : "DPT_BEiT_L_384", 14 | "min_temp": 0.0212, 15 | "max_temp": 50.0, 16 | "output_distribution": "logbinomial", 17 | "memory_efficient": true, 18 | "inverse_midas": false, 19 | "img_size": [384, 512] 20 | }, 21 | 22 | "train": { 23 | "train_midas": true, 24 | "use_pretrained_midas": true, 25 | "trainer": "zoedepth", 26 | "epochs": 5, 27 | "bs": 16, 28 | "optim_kwargs": {"lr": 0.000161, "wd": 0.01}, 29 | "sched_kwargs": {"div_factor": 1, "final_div_factor": 10000, "pct_start": 0.7, "three_phase":false, "cycle_momentum": true}, 30 | "same_lr": false, 31 | "w_si": 1, 32 | "w_domain": 0.2, 33 | "w_reg": 0, 34 | "w_grad": 0, 35 | "avoid_boundary": false, 36 | "random_crop": false, 37 | "input_width": 640, 38 | "input_height": 480, 39 | "midas_lr_factor": 1, 40 | "encoder_lr_factor":10, 41 | "pos_enc_lr_factor":10, 42 | "freeze_midas_bn": true 43 | 44 | }, 45 | 46 | "infer":{ 47 | "train_midas": false, 48 | "use_pretrained_midas": false, 49 | "pretrained_resource" : null, 50 | "force_keep_ar": true 51 | }, 52 | 53 | "eval":{ 54 | "train_midas": false, 55 | "use_pretrained_midas": false, 56 | "pretrained_resource" : null 57 | } 58 | } -------------------------------------------------------------------------------- /src/controlnet_aux/zoe/zoedepth/models/zoedepth/config_zoedepth_kitti.json: -------------------------------------------------------------------------------- 1 | { 2 | "model": { 3 | "bin_centers_type": "normed", 4 | "img_size": [384, 768] 5 | }, 6 | 7 | "train": { 8 | }, 9 | 10 | "infer":{ 11 | "train_midas": false, 12 | "use_pretrained_midas": false, 13 | "pretrained_resource" : "url::https://github.com/isl-org/ZoeDepth/releases/download/v1.0/ZoeD_M12_K.pt", 14 | "force_keep_ar": true 15 | }, 16 | 17 | "eval":{ 18 | "train_midas": false, 19 | "use_pretrained_midas": false, 20 | "pretrained_resource" : "url::https://github.com/isl-org/ZoeDepth/releases/download/v1.0/ZoeD_M12_K.pt" 21 | } 22 | } -------------------------------------------------------------------------------- /src/controlnet_aux/zoe/zoedepth/models/zoedepth_nk/__init__.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | # Copyright (c) 2022 Intelligent Systems Lab Org 4 | 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, 
modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | 23 | # File author: Shariq Farooq Bhat 24 | 25 | from .zoedepth_nk_v1 import ZoeDepthNK 26 | 27 | all_versions = { 28 | "v1": ZoeDepthNK, 29 | } 30 | 31 | get_version = lambda v : all_versions[v] -------------------------------------------------------------------------------- /src/controlnet_aux/zoe/zoedepth/models/zoedepth_nk/config_zoedepth_nk.json: -------------------------------------------------------------------------------- 1 | { 2 | "model": { 3 | "name": "ZoeDepthNK", 4 | "version_name": "v1", 5 | "bin_conf" : [ 6 | { 7 | "name": "nyu", 8 | "n_bins": 64, 9 | "min_depth": 1e-3, 10 | "max_depth": 10.0 11 | }, 12 | { 13 | "name": "kitti", 14 | "n_bins": 64, 15 | "min_depth": 1e-3, 16 | "max_depth": 80.0 17 | } 18 | ], 19 | "bin_embedding_dim": 128, 20 | "bin_centers_type": "softplus", 21 | "n_attractors":[16, 8, 4, 1], 22 | "attractor_alpha": 1000, 23 | "attractor_gamma": 2, 24 | "attractor_kind" : "mean", 25 | "attractor_type" : "inv", 26 | "min_temp": 0.0212, 27 | "max_temp": 50.0, 28 | "memory_efficient": true, 29 | "midas_model_type" : "DPT_BEiT_L_384", 30 | "img_size": [384, 512] 31 | }, 32 | 33 | "train": { 34 | "train_midas": true, 35 | "use_pretrained_midas": true, 36 | "trainer": "zoedepth_nk", 37 | "epochs": 5, 38 | "bs": 16, 39 | "optim_kwargs": {"lr": 0.0002512, "wd": 0.01}, 40 | "sched_kwargs": {"div_factor": 1, "final_div_factor": 10000, "pct_start": 0.7, "three_phase":false, "cycle_momentum": true}, 41 | "same_lr": false, 42 | "w_si": 1, 43 | "w_domain": 100, 44 | "avoid_boundary": false, 45 | "random_crop": false, 46 | "input_width": 640, 47 | "input_height": 480, 48 | "w_grad": 0, 49 | "w_reg": 0, 50 | "midas_lr_factor": 10, 51 | "encoder_lr_factor":10, 52 | "pos_enc_lr_factor":10 53 | }, 54 | 55 | "infer": { 56 | "train_midas": false, 57 | "pretrained_resource": "url::https://github.com/isl-org/ZoeDepth/releases/download/v1.0/ZoeD_M12_NK.pt", 58 | "use_pretrained_midas": false, 59 | "force_keep_ar": true 60 | }, 61 | 62 | "eval": { 63 | "train_midas": false, 64 | "pretrained_resource": "url::https://github.com/isl-org/ZoeDepth/releases/download/v1.0/ZoeD_M12_NK.pt", 65 | "use_pretrained_midas": false 66 | } 67 | } -------------------------------------------------------------------------------- /src/controlnet_aux/zoe/zoedepth/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | # Copyright (c) 2022 Intelligent Systems Lab Org 4 | 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including 
without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | 23 | # File author: Shariq Farooq Bhat 24 | 25 | -------------------------------------------------------------------------------- /src/controlnet_aux/zoe/zoedepth/utils/arg_utils.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | def infer_type(x): # hacky way to infer type from string args 4 | if not isinstance(x, str): 5 | return x 6 | 7 | try: 8 | x = int(x) 9 | return x 10 | except ValueError: 11 | pass 12 | 13 | try: 14 | x = float(x) 15 | return x 16 | except ValueError: 17 | pass 18 | 19 | return x 20 | 21 | 22 | def parse_unknown(unknown_args): 23 | clean = [] 24 | for a in unknown_args: 25 | if "=" in a: 26 | k, v = a.split("=") 27 | clean.extend([k, v]) 28 | else: 29 | clean.append(a) 30 | 31 | keys = clean[::2] 32 | values = clean[1::2] 33 | return {k.replace("--", ""): infer_type(v) for k, v in zip(keys, values)} 34 | -------------------------------------------------------------------------------- /src/controlnet_aux/zoe/zoedepth/utils/easydict/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | EasyDict 3 | Copy/pasted from https://github.com/makinacorpus/easydict 4 | Original author: Mathieu Leplatre 5 | """ 6 | 7 | class EasyDict(dict): 8 | """ 9 | Get attributes 10 | 11 | >>> d = EasyDict({'foo':3}) 12 | >>> d['foo'] 13 | 3 14 | >>> d.foo 15 | 3 16 | >>> d.bar 17 | Traceback (most recent call last): 18 | ... 
19 | AttributeError: 'EasyDict' object has no attribute 'bar' 20 | 21 | Works recursively 22 | 23 | >>> d = EasyDict({'foo':3, 'bar':{'x':1, 'y':2}}) 24 | >>> isinstance(d.bar, dict) 25 | True 26 | >>> d.bar.x 27 | 1 28 | 29 | Bullet-proof 30 | 31 | >>> EasyDict({}) 32 | {} 33 | >>> EasyDict(d={}) 34 | {} 35 | >>> EasyDict(None) 36 | {} 37 | >>> d = {'a': 1} 38 | >>> EasyDict(**d) 39 | {'a': 1} 40 | >>> EasyDict((('a', 1), ('b', 2))) 41 | {'a': 1, 'b': 2} 42 | 43 | Set attributes 44 | 45 | >>> d = EasyDict() 46 | >>> d.foo = 3 47 | >>> d.foo 48 | 3 49 | >>> d.bar = {'prop': 'value'} 50 | >>> d.bar.prop 51 | 'value' 52 | >>> d 53 | {'foo': 3, 'bar': {'prop': 'value'}} 54 | >>> d.bar.prop = 'newer' 55 | >>> d.bar.prop 56 | 'newer' 57 | 58 | 59 | Values extraction 60 | 61 | >>> d = EasyDict({'foo':0, 'bar':[{'x':1, 'y':2}, {'x':3, 'y':4}]}) 62 | >>> isinstance(d.bar, list) 63 | True 64 | >>> from operator import attrgetter 65 | >>> list(map(attrgetter('x'), d.bar)) 66 | [1, 3] 67 | >>> list(map(attrgetter('y'), d.bar)) 68 | [2, 4] 69 | >>> d = EasyDict() 70 | >>> list(d.keys()) 71 | [] 72 | >>> d = EasyDict(foo=3, bar=dict(x=1, y=2)) 73 | >>> d.foo 74 | 3 75 | >>> d.bar.x 76 | 1 77 | 78 | Still like a dict though 79 | 80 | >>> o = EasyDict({'clean':True}) 81 | >>> list(o.items()) 82 | [('clean', True)] 83 | 84 | And like a class 85 | 86 | >>> class Flower(EasyDict): 87 | ... power = 1 88 | ... 89 | >>> f = Flower() 90 | >>> f.power 91 | 1 92 | >>> f = Flower({'height': 12}) 93 | >>> f.height 94 | 12 95 | >>> f['power'] 96 | 1 97 | >>> sorted(f.keys()) 98 | ['height', 'power'] 99 | 100 | update and pop items 101 | >>> d = EasyDict(a=1, b='2') 102 | >>> e = EasyDict(c=3.0, a=9.0) 103 | >>> d.update(e) 104 | >>> d.c 105 | 3.0 106 | >>> d['c'] 107 | 3.0 108 | >>> d.get('c') 109 | 3.0 110 | >>> d.update(a=4, b=4) 111 | >>> d.b 112 | 4 113 | >>> d.pop('a') 114 | 4 115 | >>> d.a 116 | Traceback (most recent call last): 117 | ... 
118 | AttributeError: 'EasyDict' object has no attribute 'a' 119 | """ 120 | def __init__(self, d=None, **kwargs): 121 | if d is None: 122 | d = {} 123 | else: 124 | d = dict(d) 125 | if kwargs: 126 | d.update(**kwargs) 127 | for k, v in d.items(): 128 | setattr(self, k, v) 129 | # Class attributes 130 | for k in self.__class__.__dict__.keys(): 131 | if not (k.startswith('__') and k.endswith('__')) and not k in ('update', 'pop'): 132 | setattr(self, k, getattr(self, k)) 133 | 134 | def __setattr__(self, name, value): 135 | if isinstance(value, (list, tuple)): 136 | value = [self.__class__(x) 137 | if isinstance(x, dict) else x for x in value] 138 | elif isinstance(value, dict) and not isinstance(value, self.__class__): 139 | value = self.__class__(value) 140 | super(EasyDict, self).__setattr__(name, value) 141 | super(EasyDict, self).__setitem__(name, value) 142 | 143 | __setitem__ = __setattr__ 144 | 145 | def update(self, e=None, **f): 146 | d = e or dict() 147 | d.update(f) 148 | for k in d: 149 | setattr(self, k, d[k]) 150 | 151 | def pop(self, k, d=None): 152 | delattr(self, k) 153 | return super(EasyDict, self).pop(k, d) 154 | 155 | 156 | if __name__ == "__main__": 157 | import doctest 158 | doctest.testmod() -------------------------------------------------------------------------------- /tests/test_controlnet_aux.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | from io import BytesIO 4 | 5 | import numpy as np 6 | import pytest 7 | import requests 8 | from PIL import Image 9 | 10 | from controlnet_aux import (CannyDetector, ContentShuffleDetector, HEDdetector, 11 | LeresDetector, LineartAnimeDetector, 12 | LineartDetector, MediapipeFaceDetector, 13 | MidasDetector, MLSDdetector, NormalBaeDetector, 14 | OpenposeDetector, PidiNetDetector, SamDetector, 15 | ZoeDetector, DWposeDetector) 16 | 17 | OUTPUT_DIR = "tests/outputs" 18 | 19 | def output(name, img): 20 | img.save(os.path.join(OUTPUT_DIR, "{:s}.png".format(name))) 21 | 22 | def common(name, processor, img): 23 | output(name, processor(img)) 24 | output(name + "_pil_np", Image.fromarray(processor(img, output_type="np"))) 25 | output(name + "_np_np", Image.fromarray(processor(np.array(img, dtype=np.uint8), output_type="np"))) 26 | output(name + "_np_pil", processor(np.array(img, dtype=np.uint8), output_type="pil")) 27 | output(name + "_scaled", processor(img, detect_resolution=640, image_resolution=768)) 28 | 29 | def return_pil(name, processor, img): 30 | output(name + "_pil_false", Image.fromarray(processor(img, return_pil=False))) 31 | output(name + "_pil_true", processor(img, return_pil=True)) 32 | 33 | @pytest.fixture(scope="module") 34 | def img(): 35 | if os.path.exists(OUTPUT_DIR): 36 | shutil.rmtree(OUTPUT_DIR) 37 | os.mkdir(OUTPUT_DIR) 38 | url = "https://huggingface.co/lllyasviel/sd-controlnet-openpose/resolve/main/images/pose.png" 39 | response = requests.get(url) 40 | img = Image.open(BytesIO(response.content)).convert("RGB").resize((512, 512)) 41 | return img 42 | 43 | def test_canny(img): 44 | canny = CannyDetector() 45 | common("canny", canny, img) 46 | output("canny_img", canny(img=img)) 47 | 48 | def test_hed(img): 49 | hed = HEDdetector.from_pretrained("lllyasviel/Annotators") 50 | common("hed", hed, img) 51 | return_pil("hed", hed, img) 52 | output("hed_safe", hed(img, safe=True)) 53 | output("hed_scribble", hed(img, scribble=True)) 54 | 55 | def test_leres(img): 56 | leres = LeresDetector.from_pretrained("lllyasviel/Annotators") 57 | 
common("leres", leres, img) 58 | output("leres_boost", leres(img, boost=True)) 59 | 60 | def test_lineart(img): 61 | lineart = LineartDetector.from_pretrained("lllyasviel/Annotators") 62 | common("lineart", lineart, img) 63 | return_pil("lineart", lineart, img) 64 | output("lineart_coarse", lineart(img, coarse=True)) 65 | 66 | def test_lineart_anime(img): 67 | lineart_anime = LineartAnimeDetector.from_pretrained("lllyasviel/Annotators") 68 | common("lineart_anime", lineart_anime, img) 69 | return_pil("lineart_anime", lineart_anime, img) 70 | 71 | def test_mediapipe_face(img): 72 | mediapipe = MediapipeFaceDetector() 73 | common("mediapipe", mediapipe, img) 74 | output("mediapipe_image", mediapipe(image=img)) 75 | 76 | def test_midas(img): 77 | midas = MidasDetector.from_pretrained("lllyasviel/Annotators") 78 | common("midas", midas, img) 79 | output("midas_normal", midas(img, depth_and_normal=True)[1]) 80 | 81 | def test_mlsd(img): 82 | mlsd = MLSDdetector.from_pretrained("lllyasviel/Annotators") 83 | common("mlsd", mlsd, img) 84 | return_pil("mlsd", mlsd, img) 85 | 86 | def test_normalbae(img): 87 | normal_bae = NormalBaeDetector.from_pretrained("lllyasviel/Annotators") 88 | common("normal_bae", normal_bae, img) 89 | return_pil("normal_bae", normal_bae, img) 90 | 91 | def test_openpose(img): 92 | openpose = OpenposeDetector.from_pretrained("lllyasviel/Annotators") 93 | common("openpose", openpose, img) 94 | return_pil("openpose", openpose, img) 95 | output("openpose_hand_and_face_false", openpose(img, hand_and_face=False)) 96 | output("openpose_hand_and_face_true", openpose(img, hand_and_face=True)) 97 | output("openpose_face", openpose(img, include_body=True, include_hand=False, include_face=True)) 98 | output("openpose_faceonly", openpose(img, include_body=False, include_hand=False, include_face=True)) 99 | output("openpose_full", openpose(img, include_body=True, include_hand=True, include_face=True)) 100 | output("openpose_hand", openpose(img, include_body=True, include_hand=True, include_face=False)) 101 | 102 | def test_pidi(img): 103 | pidi = PidiNetDetector.from_pretrained("lllyasviel/Annotators") 104 | common("pidi", pidi, img) 105 | return_pil("pidi", pidi, img) 106 | output("pidi_safe", pidi(img, safe=True)) 107 | output("pidi_scribble", pidi(img, scribble=True)) 108 | 109 | def test_sam(img): 110 | sam = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints") 111 | common("sam", sam, img) 112 | output("sam_image", sam(image=img)) 113 | 114 | def test_shuffle(img): 115 | shuffle = ContentShuffleDetector() 116 | common("shuffle", shuffle, img) 117 | return_pil("shuffle", shuffle, img) 118 | 119 | def test_zoe(img): 120 | zoe = ZoeDetector.from_pretrained("lllyasviel/Annotators") 121 | common("zoe", zoe, img) 122 | 123 | def test_dwpose(img): 124 | dwpose = DWposeDetector() 125 | common("dwpose", dwpose, img) 126 | return_pil("dwpose", dwpose, img) 127 | --------------------------------------------------------------------------------