├── images
    ├── cat.png
    ├── demo.gif
    ├── park.png
    ├── webui.png
    ├── buildings.png
    ├── cat_inpaint.png
    ├── park_inpaint.png
    └── buildings_inpaint.png
├── LICENSE
├── README.md
└── demo.py

/images/cat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qwopqwop200/lama-with-maskdino/HEAD/images/cat.png
--------------------------------------------------------------------------------
/images/demo.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qwopqwop200/lama-with-maskdino/HEAD/images/demo.gif
--------------------------------------------------------------------------------
/images/park.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qwopqwop200/lama-with-maskdino/HEAD/images/park.png
--------------------------------------------------------------------------------
/images/webui.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qwopqwop200/lama-with-maskdino/HEAD/images/webui.png
--------------------------------------------------------------------------------
/images/buildings.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qwopqwop200/lama-with-maskdino/HEAD/images/buildings.png
--------------------------------------------------------------------------------
/images/cat_inpaint.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qwopqwop200/lama-with-maskdino/HEAD/images/cat_inpaint.png
--------------------------------------------------------------------------------
/images/park_inpaint.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qwopqwop200/lama-with-maskdino/HEAD/images/park_inpaint.png
--------------------------------------------------------------------------------
/images/buildings_inpaint.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qwopqwop200/lama-with-maskdino/HEAD/images/buildings_inpaint.png
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2023 qwopqwop200

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Lama-with-MaskDINO
![demo](./images/demo.gif)

This project was inspired by [Auto-LaMa](https://github.com/andy971022/auto-lama#readme).

It differs from Auto-LaMa in two ways:
1. It uses the instance segmentation model [MaskDINO](https://github.com/IDEA-Research/MaskDINO) instead of the object detection model [DETR](https://github.com/facebookresearch/detr).
2. It uses [LaMa with refiner](https://github.com/geomagical/lama-with-refiner) for better results.
## Simple demo with [gradio](https://github.com/gradio-app/gradio)
![webui](./images/webui.png)
## Environment setup
A GPU with at least 12 GB of memory is required.
1. Download the pre-trained weights for [MaskDINO](https://github.com/IDEA-Research/detrex-storage/releases/download/maskdino-v0.1.0/maskdino_swinl_50ep_300q_hid2048_3sd1_instance_maskenhanced_mask52.3ap_box59.0ap.pth) and [LaMa](https://disk.yandex.ru/d/ouP6l8VJ0HpMZg).
2. Arrange the directory like this:
```
.root
├─demo.py
├─ckpt
│  ├──maskdino_swinl_50ep_300q_hid2048_3sd1_instance_maskenhanced_mask52.3ap_box59.0ap.pth
│  └─models
│     ├──config.yaml
│     └─models
│        └─best.ckpt
└─images
   ├──buildings.png
   ├──cat.png
   └──park.png
```
3. Set up the conda environment:
```
conda create --name maskdino python=3.8 -y
conda activate maskdino
conda install pytorch==1.9.0 torchvision==0.10.0 cudatoolkit=11.1 -c pytorch -c nvidia
pip install -U opencv-python

# demo.py looks for MaskDINO and lama-with-refiner under ./repo
mkdir repo
cd repo
git clone https://github.com/facebookresearch/detectron2.git
cd detectron2
pip install -e .
pip install git+https://github.com/cocodataset/panopticapi.git

cd ..
git clone -b quickfix/infer_demo --single-branch https://github.com/MeAmarP/MaskDINO.git
cd MaskDINO
pip install -r requirements.txt
cd maskdino/modeling/pixel_decoder/ops
python setup.py build install
cd ../../../../..

git clone https://github.com/geomagical/lama-with-refiner.git
cd lama-with-refiner
pip install -r requirements.txt
pip install --upgrade numpy==1.23.0
cd ../..
pip install gradio
```
4. Run the demo:
``` bash
# The web UI is served at http://127.0.0.1:7860
python demo.py
```
## Acknowledgments
Many thanks to these excellent open-source projects:
* [LaMa](https://github.com/saic-mdal/lama)
* [LaMa with refiner](https://github.com/geomagical/lama-with-refiner)
* [MaskDINO](https://github.com/IDEA-Research/MaskDINO)
* [MaskDINO inference code](https://github.com/MeAmarP/MaskDINO/tree/quickfix/infer_demo)
* [Detectron2](https://github.com/facebookresearch/detectron2)
* [Auto-LaMa](https://github.com/andy971022/auto-lama)

--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
import os
import sys

# Make the cloned repositories importable (see the README for the expected layout).
sys.path.append('./repo/lama-with-refiner/')
sys.path.append('./repo/MaskDINO/')
sys.path.append('./repo/MaskDINO/demo/')

import torch
import numpy as np
from scipy.ndimage import gaussian_filter
import cv2

import argparse
from omegaconf import OmegaConf

from detectron2.config import get_cfg
from detectron2.projects.deeplab import add_deeplab_config

from maskdino import add_maskdino_config
from predictor import VisualizationDemo

from saicinpainting.evaluation.refinement import refine_predict
from saicinpainting.training.trainers import load_checkpoint

import gradio as gr

# COCO class name -> class index, as predicted by the segmentation model.
className = {'person': 0,
             'bicycle': 1,
             'car': 2,
             'motorcycle': 3,
             'airplane': 4,
             'bus': 5,
             'train': 6,
             'truck': 7,
             'boat': 8,
             'traffic light': 9,
             'fire hydrant': 10,
             'stop sign': 11,
             'parking meter': 12,
             'bench': 13,
             'bird': 14,
             'cat': 15,
             'dog': 16,
             'horse': 17,
             'sheep': 18,
             'cow': 19,
             'elephant': 20,
             'bear': 21,
             'zebra': 22,
             'giraffe': 23,
             'backpack': 24,
             'umbrella': 25,
             'handbag': 26,
             'tie': 27,
             'suitcase': 28,
             'frisbee': 29,
             'skis': 30,
             'snowboard': 31,
             'sports ball': 32,
             'kite': 33,
             'baseball bat': 34,
             'baseball glove': 35,
             'skateboard': 36,
             'surfboard': 37,
             'tennis racket': 38,
             'bottle': 39,
             'wine glass': 40,
             'cup': 41,
             'fork': 42,
             'knife': 43,
             'spoon': 44,
             'bowl': 45,
             'banana': 46,
             'apple': 47,
             'sandwich': 48,
             'orange': 49,
             'broccoli': 50,
             'carrot': 51,
             'hot dog': 52,
             'pizza': 53,
             'donut': 54,
             'cake': 55,
             'chair': 56,
             'couch': 57,
             'potted plant': 58,
             'bed': 59,
             'dining table': 60,
             'toilet': 61,
             'tv': 62,
             'laptop': 63,
             'mouse': 64,
             'remote': 65,
             'keyboard': 66,
             'cell phone': 67,
             'microwave': 68,
             'oven': 69,
             'toaster': 70,
             'sink': 71,
             'refrigerator': 72,
             'book': 73,
             'clock': 74,
             'vase': 75,
             'scissors': 76,
             'teddy bear': 77,
             'hair drier': 78,
             'toothbrush': 79}

def get_parser():
    parser = argparse.ArgumentParser(description="maskdino demo for builtin configs")
    parser.add_argument(
        "--config-file",
        default="./repo/MaskDINO/configs/coco/instance-segmentation/swin/maskdino_R50_bs16_50ep_4s_dowsample1_2048.yaml",
        metavar="FILE",
        help="path to config file",
    )
    return parser

def setup_cfg():
    # Parse an empty argument list so only the defaults above are used.
    args = get_parser().parse_args(args=[])
    cfg = get_cfg()
    add_deeplab_config(cfg)
    add_maskdino_config(cfg)
    cfg.merge_from_file(args.config_file)
    cfg.MODEL.WEIGHTS = './ckpt/maskdino_swinl_50ep_300q_hid2048_3sd1_instance_maskenhanced_mask52.3ap_box59.0ap.pth'
    cfg.freeze()
    return cfg

def get_seg_model():
    cfg = setup_cfg()
    model = VisualizationDemo(cfg)
    return model

def get_inpaint_model():
    predict_config = OmegaConf.load('./repo/lama-with-refiner/configs/prediction/default.yaml')
    predict_config.model.path = './ckpt/models/'
    predict_config.refiner.gpu_ids = '0'

    device = torch.device(predict_config.device)
    train_config_path = os.path.join(predict_config.model.path, 'config.yaml')

    train_config = OmegaConf.load(train_config_path)
    train_config.training_model.predict_only = True
    train_config.visualizer.kind = 'noop'

    checkpoint_path = os.path.join(predict_config.model.path,
                                   'models',
                                   predict_config.model.checkpoint)

    # Load on the CPU first, then move the frozen model to the target device.
    model = load_checkpoint(train_config, checkpoint_path, strict=False, map_location='cpu')
    model.freeze()
    model.to(device)
    return model, predict_config

def ceil_modulo(x, mod):
    # Round x up to the nearest multiple of mod.
    if x % mod == 0:
        return x
    return (x // mod + 1) * mod

def pad_img_to_modulo(img, mod):
    # Pad a CHW array on the bottom and right so height and width are multiples of mod.
    channels, height, width = img.shape
    out_height = ceil_modulo(height, mod)
    out_width = ceil_modulo(width, mod)
    return np.pad(img, ((0, 0), (0, out_height - height), (0, out_width - width)), mode='symmetric')

# Both models are loaded once at start-up and shared by every request.
seg_model = get_seg_model()
inpaint_model, predict_config = get_inpaint_model()

def inference(img, class_name, confidence_score, sigma, mask_threshold):
    original = img  # RGB image as passed in by gradio
    img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)

    predictions, visualized_output = seg_model.run_on_image(img)

    img = img.astype('float32') / 255
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = np.transpose(img, (2, 0, 1))  # HWC -> CHW

    preds = predictions['instances'].get_fields()

    # Keep only instances of the requested class above the confidence threshold.
    keep = torch.logical_and(preds['pred_classes'] == className[class_name],
                             preds['scores'] > confidence_score)
    masks = preds['pred_masks'][keep]
    if masks.shape[0] == 0:
        # Nothing matched: return the input unchanged instead of failing on an empty reduction.
        return original
    masks = torch.max(masks, dim=0)  # union of the selected instance masks
    masks = masks.values.cpu().numpy()
    # Blur, then threshold, which grows the mask slightly beyond the object outline.
    masks = gaussian_filter(masks, sigma=sigma)
    masks = (masks > mask_threshold) * 255

    batch = dict(image=img, mask=masks[None, ...])

    # Remember the original size so the padding can be removed after inpainting.
    batch['unpad_to_size'] = [torch.tensor([batch['image'].shape[1]]), torch.tensor([batch['image'].shape[2]])]
    batch['image'] = torch.tensor(pad_img_to_modulo(batch['image'], predict_config.dataset.pad_out_to_modulo))[None].to(predict_config.device)
    batch['mask'] = torch.tensor(pad_img_to_modulo(batch['mask'], predict_config.dataset.pad_out_to_modulo))[None].float().to(predict_config.device)

    cur_res = refine_predict(batch, inpaint_model, **predict_config.refiner)
    cur_res = cur_res[0].permute(1, 2, 0).detach().cpu().numpy()  # CHW -> HWC

    cur_res = np.clip(cur_res * 255, 0, 255).astype('uint8')

    return cur_res

demo = gr.Interface(fn=inference,
                    inputs=[gr.Image(label='input'),
                            gr.Dropdown(list(className.keys()), value='person', label='class name'),
                            gr.Slider(0, 1, value=0.5, step=0.05, label='confidence score'),
                            gr.Slider(1, 20, value=7, step=1, label='gaussian blur sigma'),
                            gr.Slider(0, 1, value=0.2, step=0.05, label='mask threshold')],
                    outputs="image",
                    examples=[["./images/buildings.png", 'person', 0.5, 7, 0.2],
                              ["./images/park.png", 'person', 0.5, 7, 0.2],
                              ["./images/cat.png", 'remote', 0.5, 7, 0.2]])

demo.launch()
--------------------------------------------------------------------------------
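
The padding step in demo.py (`ceil_modulo` and `pad_img_to_modulo`) can be easier to follow on small inputs. The sketch below is illustrative only and is not part of the repository: it repeats `ceil_modulo` as written in demo.py, while `padded_shape` and `symmetric_pad_row` are hypothetical helpers that mimic, in pure Python and for one row, what `np.pad(..., mode='symmetric')` does on the right edge. The modulo value 8 is an assumed example; the real value comes from `predict_config.dataset.pad_out_to_modulo`.

```python
def ceil_modulo(x, mod):
    """Round x up to the nearest multiple of mod (same logic as demo.py)."""
    if x % mod == 0:
        return x
    return (x // mod + 1) * mod


def padded_shape(height, width, mod):
    """Hypothetical helper: image size after padding both sides up to a multiple of mod."""
    return ceil_modulo(height, mod), ceil_modulo(width, mod)


def symmetric_pad_row(row, out_len):
    """Hypothetical helper: extend a list on the right by mirroring it, edge value included."""
    period = row + row[::-1]  # one full forward-and-back reflection cycle
    return [period[i % len(period)] for i in range(out_len)]


if __name__ == "__main__":
    # A 333 x 500 image with an assumed modulo of 8 is padded to 336 x 504.
    print(padded_shape(333, 500, 8))
    # The padded pixels mirror the pixels nearest the edge.
    print(symmetric_pad_row([1, 2, 3], 5))
```

The original height and width are stored in `batch['unpad_to_size']` before padding, which is presumably how the inpainted result is cropped back to the input size.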