├── .gitignore ├── LICENSE ├── README.md ├── README_DEV.md ├── fast_sam ├── __init__.py └── fast_sam_wrapper.py ├── ia_check_versions.py ├── ia_config.py ├── ia_devices.py ├── ia_file_manager.py ├── ia_get_dataset_colormap.py ├── ia_logging.py ├── ia_sam_manager.py ├── ia_threading.py ├── ia_ui_gradio.py ├── ia_ui_items.py ├── iasam_app.py ├── images ├── inpaint_anything_explanation_image_1.png ├── inpaint_anything_ui_image_1.png ├── sample_input_image.png ├── sample_mask_image.png └── sample_seg_color_image.png ├── inpalib ├── __init__.py ├── masklib.py └── samlib.py ├── javascript └── inpaint-anything.js ├── lama_cleaner ├── __init__.py ├── benchmark.py ├── const.py ├── file_manager │ ├── __init__.py │ ├── file_manager.py │ ├── storage_backends.py │ └── utils.py ├── helper.py ├── installer.py ├── model │ ├── __init__.py │ ├── base.py │ ├── controlnet.py │ ├── ddim_sampler.py │ ├── fcf.py │ ├── instruct_pix2pix.py │ ├── lama.py │ ├── ldm.py │ ├── manga.py │ ├── mat.py │ ├── opencv2.py │ ├── paint_by_example.py │ ├── pipeline │ │ ├── __init__.py │ │ └── pipeline_stable_diffusion_controlnet_inpaint.py │ ├── plms_sampler.py │ ├── sd.py │ ├── utils.py │ └── zits.py ├── model_manager.py ├── parse_args.py ├── plugins │ ├── __init__.py │ ├── anime_seg.py │ ├── base_plugin.py │ ├── gfpgan_plugin.py │ ├── gfpganer.py │ ├── gif.py │ ├── interactive_seg.py │ ├── realesrgan.py │ ├── remove_bg.py │ ├── restoreformer.py │ └── segment_anything │ │ ├── __init__.py │ │ ├── build_sam.py │ │ ├── modeling │ │ ├── __init__.py │ │ ├── common.py │ │ ├── image_encoder.py │ │ ├── mask_decoder.py │ │ ├── prompt_encoder.py │ │ ├── sam.py │ │ └── transformer.py │ │ ├── predictor.py │ │ └── utils │ │ ├── __init__.py │ │ └── transforms.py ├── runtime.py ├── schema.py ├── server.py ├── tests │ ├── __init__.py │ ├── test_controlnet.py │ ├── test_instruct_pix2pix.py │ ├── test_interactive_seg.py │ ├── test_load_img.py │ ├── test_model.py │ ├── test_model_md5.py │ ├── test_paint_by_example.py │ ├── test_plugins.py │ ├── test_save_exif.py │ └── test_sd_model.py └── web_config.py ├── mobile_sam ├── __init__.py ├── automatic_mask_generator.py ├── build_sam.py ├── modeling │ ├── __init__.py │ ├── common.py │ ├── image_encoder.py │ ├── mask_decoder.py │ ├── prompt_encoder.py │ ├── sam.py │ ├── tiny_vit_sam.py │ └── transformer.py ├── predictor.py └── utils │ ├── __init__.py │ ├── amg.py │ ├── onnx.py │ ├── torch_nms.py │ └── transforms.py ├── requirements.txt ├── requirements_cu118.txt ├── requirements_mac.txt ├── sam2 ├── __init__.py ├── automatic_mask_generator.py ├── build_sam.py ├── csrc │ └── connected_components.cu ├── modeling │ ├── __init__.py │ ├── backbones │ │ ├── __init__.py │ │ ├── hieradet.py │ │ ├── image_encoder.py │ │ └── utils.py │ ├── memory_attention.py │ ├── memory_encoder.py │ ├── position_encoding.py │ ├── sam │ │ ├── __init__.py │ │ ├── mask_decoder.py │ │ ├── prompt_encoder.py │ │ └── transformer.py │ ├── sam2_base.py │ └── sam2_utils.py ├── sam2_image_predictor.py ├── sam2_video_predictor.py └── utils │ ├── __init__.py │ ├── amg.py │ ├── misc.py │ ├── torch_nms.py │ └── transforms.py ├── sam2_configs ├── __init__.py ├── sam2_hiera_b+.yaml ├── sam2_hiera_l.yaml ├── sam2_hiera_s.yaml └── sam2_hiera_t.yaml ├── segment_anything_fb ├── __init__.py ├── automatic_mask_generator.py ├── build_sam.py ├── modeling │ ├── __init__.py │ ├── common.py │ ├── image_encoder.py │ ├── mask_decoder.py │ ├── prompt_encoder.py │ ├── sam.py │ └── transformer.py ├── predictor.py └── utils │ ├── __init__.py │ ├── amg.py │ ├── 
onnx.py │ ├── torch_nms.py │ └── transforms.py └── segment_anything_hq ├── __init__.py ├── automatic_mask_generator.py ├── build_sam.py ├── build_sam_baseline.py ├── modeling ├── __init__.py ├── common.py ├── image_encoder.py ├── mask_decoder.py ├── mask_decoder_hq.py ├── prompt_encoder.py ├── sam.py └── transformer.py ├── predictor.py └── utils ├── __init__.py ├── amg.py ├── onnx.py ├── torch_nms.py └── transforms.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pth 2 | *.pt 3 | *.pyc 4 | src/ 5 | outputs/ 6 | models/ 7 | models 8 | .DS_Store 9 | ia_config.ini 10 | .eslintrc 11 | .eslintrc.json 12 | pyproject.toml 13 | 14 | # Byte-compiled / optimized / DLL files 15 | __pycache__/ 16 | *.py[cod] 17 | *$py.class 18 | 19 | # C extensions 20 | *.so 21 | 22 | # Distribution / packaging 23 | .Python 24 | build/ 25 | develop-eggs/ 26 | dist/ 27 | downloads/ 28 | eggs/ 29 | .eggs/ 30 | lib/ 31 | lib64/ 32 | parts/ 33 | sdist/ 34 | var/ 35 | wheels/ 36 | pip-wheel-metadata/ 37 | share/python-wheels/ 38 | *.egg-info/ 39 | .installed.cfg 40 | *.egg 41 | MANIFEST 42 | 43 | # PyInstaller 44 | # Usually these files are written by a python script from a template 45 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 46 | *.manifest 47 | *.spec 48 | 49 | # Installer logs 50 | pip-log.txt 51 | pip-delete-this-directory.txt 52 | 53 | # Unit test / coverage reports 54 | htmlcov/ 55 | .tox/ 56 | .nox/ 57 | .coverage 58 | .coverage.* 59 | .cache 60 | nosetests.xml 61 | coverage.xml 62 | *.cover 63 | *.py,cover 64 | .hypothesis/ 65 | .pytest_cache/ 66 | 67 | # Translations 68 | *.mo 69 | *.pot 70 | 71 | # Django stuff: 72 | *.log 73 | local_settings.py 74 | db.sqlite3 75 | db.sqlite3-journal 76 | 77 | # Flask stuff: 78 | instance/ 79 | .webassets-cache 80 | 81 | # Scrapy stuff: 82 | .scrapy 83 | 84 | # Sphinx documentation 85 | docs/_build/ 86 | 87 | # PyBuilder 88 | target/ 89 | 90 | # Jupyter Notebook 91 | .ipynb_checkpoints 92 | 93 | # IPython 94 | profile_default/ 95 | ipython_config.py 96 | 97 | # pyenv 98 | .python-version 99 | 100 | # pipenv 101 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 102 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 103 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 104 | # install all needed dependencies. 105 | #Pipfile.lock 106 | 107 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 108 | __pypackages__/ 109 | 110 | # Celery stuff 111 | celerybeat-schedule 112 | celerybeat.pid 113 | 114 | # SageMath parsed files 115 | *.sage.py 116 | 117 | # Environments 118 | .env 119 | .venv 120 | env/ 121 | venv/ 122 | ENV/ 123 | env.bak/ 124 | venv.bak/ 125 | 126 | # Spyder project settings 127 | .spyderproject 128 | .spyproject 129 | 130 | # Rope project settings 131 | .ropeproject 132 | 133 | # mkdocs documentation 134 | /site 135 | 136 | # mypy 137 | .mypy_cache/ 138 | .dmypy.json 139 | dmypy.json 140 | 141 | # Pyre type checker 142 | .pyre/ 143 | -------------------------------------------------------------------------------- /README_DEV.md: -------------------------------------------------------------------------------- 1 | # Usage of Inpaint Anything Library 2 | 3 | ## Introduction 4 | 5 | The `inpalib` from the `inpaint-anything` package lets you segment images and create masks using sketches from other applications. 
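In addition to segmentation and mask creation (walked through in the Code Breakdown below), `inpalib` also exports small mask helpers such as `invert_mask`. The following is a minimal, hypothetical sketch (the file path is a placeholder) that flips a black-and-white mask like the one produced later in the Code Breakdown (see "Create Mask from Sketch"):

```python
import importlib

import numpy as np
from PIL import Image

inpalib = importlib.import_module("inpaint-anything.inpalib")

# Load a previously saved 0/255 mask image (placeholder path).
mask = np.array(Image.open("/path/to/mask_image.png"))

# invert_mask applies a bitwise NOT to the uint8 mask (255 -> 0, 0 -> 255).
inverted_mask = inpalib.invert_mask(mask)

Image.fromarray(inverted_mask).save("/path/to/inverted_mask_image.png")
```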
6 | 7 | ## Code Breakdown 8 | 9 | ### Imports and Module Initialization 10 | 11 | ```python 12 | import importlib 13 | 14 | import numpy as np 15 | from PIL import Image, ImageDraw 16 | 17 | inpalib = importlib.import_module("inpaint-anything.inpalib") 18 | ``` 19 | 20 | ### Fetch Model IDs 21 | 22 | ```python 23 | available_sam_ids = inpalib.get_available_sam_ids() 24 | 25 | use_sam_id = "sam_hq_vit_l.pth" 26 | # assert use_sam_id in available_sam_ids, f"Invalid SAM ID: {use_sam_id}" 27 | ``` 28 | 29 | Note: Only the models that have already been downloaded via Inpaint Anything are available. 30 | 31 | ### Generate Segments Image 32 | 33 | ```python 34 | input_image = np.array(Image.open("/path/to/image.png")) 35 | 36 | sam_masks = inpalib.generate_sam_masks(input_image, use_sam_id, anime_style_chk=False) 37 | sam_masks = inpalib.sort_masks_by_area(sam_masks) 38 | 39 | seg_color_image = inpalib.create_seg_color_image(input_image, sam_masks) 40 | 41 | Image.fromarray(seg_color_image).save("/path/to/seg_color_image.png") 42 | ``` 43 | 44 | (Images: sample input image and sample segmentation color image) 45 | 46 | ### Create Mask from Sketch 47 | 48 | ```python 49 | sketch_image = Image.fromarray(np.zeros_like(input_image)) 50 | 51 | draw = ImageDraw.Draw(sketch_image) 52 | draw.point((input_image.shape[1] // 2, input_image.shape[0] // 2), fill=(255, 255, 255)) 53 | 54 | mask_image = inpalib.create_mask_image(np.array(sketch_image), sam_masks, ignore_black_chk=True) 55 | 56 | Image.fromarray(mask_image).save("/path/to/mask_image.png") 57 | ``` 58 | 59 | (Image: sample mask image) 60 | 61 | Note: Ensure you adjust the file paths before executing the code. 62 | -------------------------------------------------------------------------------- /fast_sam/__init__.py: -------------------------------------------------------------------------------- 1 | from .fast_sam_wrapper import FastSAM 2 | from .fast_sam_wrapper import FastSamAutomaticMaskGenerator 3 | 4 | fast_sam_model_registry = { 5 | "FastSAM-x": FastSAM, 6 | "FastSAM-s": FastSAM, 7 | } 8 | 9 | __all__ = ["FastSAM", "FastSamAutomaticMaskGenerator", "fast_sam_model_registry"] 10 | -------------------------------------------------------------------------------- /fast_sam/fast_sam_wrapper.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import math 3 | from typing import Any, Dict, List 4 | 5 | import cv2 6 | import numpy as np 7 | import torch 8 | import ultralytics 9 | 10 | if hasattr(ultralytics, "FastSAM"): 11 | from ultralytics import FastSAM as YOLO 12 | else: 13 | from ultralytics import YOLO 14 | 15 | 16 | class FastSAM: 17 | def __init__( 18 | self, 19 | checkpoint: str, 20 | ) -> None: 21 | self.model_path = checkpoint 22 | self.model = YOLO(self.model_path) 23 | 24 | if not hasattr(torch.nn.Upsample, "recompute_scale_factor"): 25 | torch.nn.Upsample.recompute_scale_factor = None 26 | 27 | def to(self, device) -> None: 28 | self.model.to(device) 29 | 30 | @property 31 | def device(self) -> Any: 32 | return self.model.device 33 | 34 | def __call__(self, source=None, stream=False, **kwargs) -> Any: 35 | return self.model(source=source, stream=stream, **kwargs) 36 | 37 | 38 | class FastSamAutomaticMaskGenerator: 39 | def __init__( 40 | self, 41 | model: FastSAM, 42 | points_per_batch: int = None, 43 | pred_iou_thresh: float = None, 44 | stability_score_thresh: float = None, 45 | ) -> None: 46 | self.model = model 47 | self.points_per_batch = points_per_batch 48 | self.pred_iou_thresh = pred_iou_thresh 49 | self.stability_score_thresh = stability_score_thresh 50 | 
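# Note: in this wrapper, points_per_batch and pred_iou_thresh are stored but not applied as
# per-mask filters; stability_score_thresh only selects the YOLO confidence threshold (conf)
# passed to the model in generate() below.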
self.conf = 0.25 if stability_score_thresh >= 0.95 else 0.15 51 | 52 | def generate(self, image: np.ndarray) -> List[Dict[str, Any]]: 53 | height, width = image.shape[:2] 54 | new_height = math.ceil(height / 32) * 32 55 | new_width = math.ceil(width / 32) * 32 56 | resize_image = cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_CUBIC) 57 | 58 | backup_nn_dict = {} 59 | for key, _ in torch.nn.__dict__.copy().items(): 60 | if not inspect.isclass(torch.nn.__dict__.get(key)) and "Norm" in key: 61 | backup_nn_dict[key] = torch.nn.__dict__.pop(key) 62 | 63 | results = self.model( 64 | source=resize_image, 65 | stream=False, 66 | imgsz=max(new_height, new_width), 67 | device=self.model.device, 68 | retina_masks=True, 69 | iou=0.7, 70 | conf=self.conf, 71 | max_det=256) 72 | 73 | for key, value in backup_nn_dict.items(): 74 | setattr(torch.nn, key, value) 75 | # assert backup_nn_dict[key] == torch.nn.__dict__[key] 76 | 77 | annotations = results[0].masks.data 78 | 79 | if isinstance(annotations[0], torch.Tensor): 80 | annotations = np.array(annotations.cpu()) 81 | 82 | annotations_list = [] 83 | for mask in annotations: 84 | mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8)) 85 | mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((7, 7), np.uint8)) 86 | mask = cv2.resize(mask, (width, height), interpolation=cv2.INTER_AREA) 87 | 88 | annotations_list.append(dict(segmentation=mask.astype(bool))) 89 | 90 | return annotations_list 91 | -------------------------------------------------------------------------------- /ia_check_versions.py: -------------------------------------------------------------------------------- 1 | from functools import cached_property 2 | from importlib.metadata import version 3 | from importlib.util import find_spec 4 | 5 | import torch 6 | from packaging.version import parse 7 | 8 | 9 | def get_module_version(module_name): 10 | try: 11 | module_version = version(module_name) 12 | except Exception: 13 | module_version = None 14 | return module_version 15 | 16 | 17 | def compare_version(version1, version2): 18 | if not isinstance(version1, str) or not isinstance(version2, str): 19 | return None 20 | 21 | if parse(version1) > parse(version2): 22 | return 1 23 | elif parse(version1) < parse(version2): 24 | return -1 25 | else: 26 | return 0 27 | 28 | 29 | def compare_module_version(module_name, version_string): 30 | module_version = get_module_version(module_name) 31 | 32 | result = compare_version(module_version, version_string) 33 | return result if result is not None else -2 34 | 35 | 36 | class IACheckVersions: 37 | @cached_property 38 | def diffusers_enable_cpu_offload(self): 39 | if (find_spec("diffusers") is not None and compare_module_version("diffusers", "0.15.0") >= 0 and 40 | find_spec("accelerate") is not None and compare_module_version("accelerate", "0.17.0") >= 0 and 41 | torch.cuda.is_available()): 42 | return True 43 | else: 44 | return False 45 | 46 | @cached_property 47 | def torch_mps_is_available(self): 48 | if compare_module_version("torch", "2.0.1") < 0: 49 | if not getattr(torch, "has_mps", False): 50 | return False 51 | try: 52 | torch.zeros(1).to(torch.device("mps")) 53 | return True 54 | except Exception: 55 | return False 56 | else: 57 | return torch.backends.mps.is_available() and torch.backends.mps.is_built() 58 | 59 | @cached_property 60 | def torch_on_amd_rocm(self): 61 | if find_spec("torch") is not None and "rocm" in version("torch"): 62 | return True 63 | else: 64 | 
return False 65 | 66 | @cached_property 67 | def gradio_version_is_old(self): 68 | if find_spec("gradio") is not None and compare_module_version("gradio", "3.34.0") <= 0: 69 | return True 70 | else: 71 | return False 72 | 73 | 74 | ia_check_versions = IACheckVersions() 75 | -------------------------------------------------------------------------------- /ia_config.py: -------------------------------------------------------------------------------- 1 | import configparser 2 | # import json 3 | import os 4 | from types import SimpleNamespace 5 | 6 | from ia_ui_items import get_inp_model_ids, get_sam_model_ids 7 | 8 | 9 | class IAConfig: 10 | SECTIONS = SimpleNamespace( 11 | DEFAULT=configparser.DEFAULTSECT, 12 | USER="USER", 13 | ) 14 | 15 | KEYS = SimpleNamespace( 16 | SAM_MODEL_ID="sam_model_id", 17 | INP_MODEL_ID="inp_model_id", 18 | ) 19 | 20 | PATHS = SimpleNamespace( 21 | INI=os.path.join(os.path.dirname(os.path.realpath(__file__)), "ia_config.ini"), 22 | ) 23 | 24 | global_args = {} 25 | 26 | def __init__(self): 27 | self.ids_dict = {} 28 | self.ids_dict[IAConfig.KEYS.SAM_MODEL_ID] = { 29 | "list": get_sam_model_ids(), 30 | "index": 1, 31 | } 32 | self.ids_dict[IAConfig.KEYS.INP_MODEL_ID] = { 33 | "list": get_inp_model_ids(), 34 | "index": 0, 35 | } 36 | 37 | 38 | ia_config = IAConfig() 39 | 40 | 41 | def setup_ia_config_ini(): 42 | ia_config_ini = configparser.ConfigParser(defaults={}) 43 | if os.path.isfile(IAConfig.PATHS.INI): 44 | ia_config_ini.read(IAConfig.PATHS.INI, encoding="utf-8") 45 | 46 | changed = False 47 | for key, ids_info in ia_config.ids_dict.items(): 48 | if not ia_config_ini.has_option(IAConfig.SECTIONS.DEFAULT, key): 49 | if len(ids_info["list"]) > ids_info["index"]: 50 | ia_config_ini[IAConfig.SECTIONS.DEFAULT][key] = ids_info["list"][ids_info["index"]] 51 | changed = True 52 | else: 53 | if len(ids_info["list"]) > ids_info["index"] and ia_config_ini[IAConfig.SECTIONS.DEFAULT][key] != ids_info["list"][ids_info["index"]]: 54 | ia_config_ini[IAConfig.SECTIONS.DEFAULT][key] = ids_info["list"][ids_info["index"]] 55 | changed = True 56 | 57 | if changed: 58 | with open(IAConfig.PATHS.INI, "w", encoding="utf-8") as f: 59 | ia_config_ini.write(f) 60 | 61 | 62 | def get_ia_config(key, section=IAConfig.SECTIONS.DEFAULT): 63 | setup_ia_config_ini() 64 | 65 | ia_config_ini = configparser.ConfigParser(defaults={}) 66 | ia_config_ini.read(IAConfig.PATHS.INI, encoding="utf-8") 67 | 68 | if ia_config_ini.has_option(section, key): 69 | return ia_config_ini[section][key] 70 | 71 | section = IAConfig.SECTIONS.DEFAULT 72 | if ia_config_ini.has_option(section, key): 73 | return ia_config_ini[section][key] 74 | 75 | return None 76 | 77 | 78 | def get_ia_config_index(key, section=IAConfig.SECTIONS.DEFAULT): 79 | value = get_ia_config(key, section) 80 | 81 | ids_dict = ia_config.ids_dict 82 | if value is None: 83 | if key in ids_dict.keys(): 84 | ids_info = ids_dict[key] 85 | return ids_info["index"] 86 | else: 87 | return 0 88 | else: 89 | if key in ids_dict.keys(): 90 | ids_info = ids_dict[key] 91 | return ids_info["list"].index(value) if value in ids_info["list"] else ids_info["index"] 92 | else: 93 | return 0 94 | 95 | 96 | def set_ia_config(key, value, section=IAConfig.SECTIONS.DEFAULT): 97 | setup_ia_config_ini() 98 | 99 | ia_config_ini = configparser.ConfigParser(defaults={}) 100 | ia_config_ini.read(IAConfig.PATHS.INI, encoding="utf-8") 101 | 102 | if ia_config_ini.has_option(section, key) and ia_config_ini[section][key] == value: 103 | return 104 | 105 | if section != 
IAConfig.SECTIONS.DEFAULT and not ia_config_ini.has_section(section): 106 | ia_config_ini[section] = {} 107 | 108 | try: 109 | ia_config_ini[section][key] = value 110 | except Exception: 111 | ia_config_ini[section] = {} 112 | ia_config_ini[section][key] = value 113 | 114 | with open(IAConfig.PATHS.INI, "w", encoding="utf-8") as f: 115 | ia_config_ini.write(f) 116 | -------------------------------------------------------------------------------- /ia_devices.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class TorchDevices: 5 | def __init__(self): 6 | self.cpu = torch.device("cpu") 7 | self.device = torch.device("cuda") if torch.cuda.is_available() else self.cpu 8 | 9 | 10 | devices = TorchDevices() 11 | -------------------------------------------------------------------------------- /ia_file_manager.py: -------------------------------------------------------------------------------- 1 | import os 2 | from datetime import datetime 3 | from huggingface_hub import snapshot_download 4 | from ia_logging import ia_logging 5 | 6 | 7 | class IAFileManager: 8 | DOWNLOAD_COMPLETE = "Download complete" 9 | 10 | def __init__(self) -> None: 11 | self._ia_outputs_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), 12 | "outputs", 13 | datetime.now().strftime("%Y-%m-%d")) 14 | 15 | self._ia_models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models") 16 | 17 | @property 18 | def outputs_dir(self) -> str: 19 | """Get inpaint-anything outputs directory. 20 | 21 | Returns: 22 | str: inpaint-anything outputs directory 23 | """ 24 | if not os.path.isdir(self._ia_outputs_dir): 25 | os.makedirs(self._ia_outputs_dir, exist_ok=True) 26 | return self._ia_outputs_dir 27 | 28 | @property 29 | def models_dir(self) -> str: 30 | """Get inpaint-anything models directory. 31 | 32 | Returns: 33 | str: inpaint-anything models directory 34 | """ 35 | if not os.path.isdir(self._ia_models_dir): 36 | os.makedirs(self._ia_models_dir, exist_ok=True) 37 | return self._ia_models_dir 38 | 39 | @property 40 | def savename_prefix(self) -> str: 41 | """Get inpaint-anything savename prefix. 42 | 43 | Returns: 44 | str: inpaint-anything savename prefix 45 | """ 46 | return datetime.now().strftime("%Y%m%d-%H%M%S") 47 | 48 | 49 | ia_file_manager = IAFileManager() 50 | 51 | 52 | def download_model_from_hf(hf_model_id, local_files_only=False): 53 | """Download model from HuggingFace Hub. 54 | 55 | Args: 56 | sam_model_id (str): HuggingFace model id 57 | local_files_only (bool, optional): If True, use only local files. Defaults to False. 
58 | 59 | Returns: 60 | str: download status 61 | """ 62 | if not local_files_only: 63 | ia_logging.info(f"Downloading {hf_model_id}") 64 | try: 65 | snapshot_download(repo_id=hf_model_id, local_files_only=local_files_only) 66 | except FileNotFoundError: 67 | return f"{hf_model_id} not found, please download" 68 | except Exception as e: 69 | return str(e) 70 | 71 | return IAFileManager.DOWNLOAD_COMPLETE 72 | -------------------------------------------------------------------------------- /ia_logging.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import warnings 3 | 4 | warnings.filterwarnings(action="ignore", category=FutureWarning, module="transformers") 5 | warnings.filterwarnings(action="ignore", category=FutureWarning, module="huggingface_hub") 6 | warnings.filterwarnings(action="ignore", category=FutureWarning, module="timm") 7 | 8 | ia_logging = logging.getLogger("Inpaint Anything") 9 | ia_logging.setLevel(logging.INFO) 10 | ia_logging.propagate = False 11 | 12 | ia_logging_sh = logging.StreamHandler() 13 | ia_logging_sh.setFormatter(logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")) 14 | ia_logging_sh.setLevel(logging.INFO) 15 | ia_logging.addHandler(ia_logging_sh) 16 | -------------------------------------------------------------------------------- /ia_threading.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import inspect 3 | import threading 4 | from functools import wraps 5 | 6 | import torch 7 | 8 | from ia_check_versions import ia_check_versions 9 | 10 | model_access_sem = threading.Semaphore(1) 11 | 12 | 13 | def torch_gc(): 14 | if torch.cuda.is_available(): 15 | torch.cuda.empty_cache() 16 | torch.cuda.ipc_collect() 17 | if ia_check_versions.torch_mps_is_available: 18 | if hasattr(torch, "mps") and hasattr(torch.mps, "empty_cache"): 19 | torch.mps.empty_cache() 20 | 21 | 22 | def clear_cache(): 23 | gc.collect() 24 | torch_gc() 25 | 26 | 27 | def post_clear_cache(sem): 28 | with sem: 29 | gc.collect() 30 | torch_gc() 31 | 32 | 33 | def async_post_clear_cache(): 34 | thread = threading.Thread(target=post_clear_cache, args=(model_access_sem,)) 35 | thread.start() 36 | 37 | 38 | def clear_cache_decorator(func): 39 | @wraps(func) 40 | def yield_wrapper(*args, **kwargs): 41 | clear_cache() 42 | yield from func(*args, **kwargs) 43 | clear_cache() 44 | 45 | @wraps(func) 46 | def wrapper(*args, **kwargs): 47 | clear_cache() 48 | res = func(*args, **kwargs) 49 | clear_cache() 50 | return res 51 | 52 | if inspect.isgeneratorfunction(func): 53 | return yield_wrapper 54 | else: 55 | return wrapper 56 | -------------------------------------------------------------------------------- /ia_ui_gradio.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import gradio as gr 4 | 5 | GradioTemplateResponseOriginal = gr.routes.templates.TemplateResponse 6 | 7 | 8 | def webpath(fn): 9 | web_path = os.path.realpath(fn) 10 | 11 | return f'file={web_path}?{os.path.getmtime(fn)}' 12 | 13 | 14 | def javascript_html(): 15 | script_path = os.path.join(os.path.dirname(__file__), "javascript", "inpaint-anything.js") 16 | head = f'<script type="text/javascript" src="{webpath(script_path)}"></script>\n' 17 | 18 | return head 19 | 20 | 21 | def reload_javascript(): 22 | js = javascript_html() 23 | 24 | def template_response(*args, **kwargs): 25 | res = GradioTemplateResponseOriginal(*args, **kwargs) 26 | res.body = res.body.replace(b'</head>', f'{js}</head>'.encode("utf8")) 27 | res.init_headers() 
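# init_headers() above recomputes the response headers (e.g. Content-Length) to match the modified body.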
28 | return res 29 | 30 | gr.routes.templates.TemplateResponse = template_response 31 | -------------------------------------------------------------------------------- /ia_ui_items.py: -------------------------------------------------------------------------------- 1 | from huggingface_hub import scan_cache_dir 2 | 3 | 4 | def get_sampler_names(): 5 | """Get sampler name list. 6 | 7 | Returns: 8 | list: sampler name list 9 | """ 10 | sampler_names = [ 11 | "DDIM", 12 | "Euler", 13 | "Euler a", 14 | "DPM2 Karras", 15 | "DPM2 a Karras", 16 | ] 17 | return sampler_names 18 | 19 | 20 | def get_sam_model_ids(): 21 | """Get SAM model ids list. 22 | 23 | Returns: 24 | list: SAM model ids list 25 | """ 26 | sam_model_ids = [ 27 | "sam2_hiera_large.pt", 28 | "sam2_hiera_base_plus.pt", 29 | "sam2_hiera_small.pt", 30 | "sam2_hiera_tiny.pt", 31 | "sam_vit_h_4b8939.pth", 32 | "sam_vit_l_0b3195.pth", 33 | "sam_vit_b_01ec64.pth", 34 | "sam_hq_vit_h.pth", 35 | "sam_hq_vit_l.pth", 36 | "sam_hq_vit_b.pth", 37 | "FastSAM-x.pt", 38 | "FastSAM-s.pt", 39 | "mobile_sam.pt", 40 | ] 41 | return sam_model_ids 42 | 43 | 44 | inp_list_from_cache = None 45 | 46 | 47 | def get_inp_model_ids(): 48 | """Get inpainting model ids list. 49 | 50 | Returns: 51 | list: model ids list 52 | """ 53 | global inp_list_from_cache 54 | model_ids = [ 55 | "stabilityai/stable-diffusion-2-inpainting", 56 | "Uminosachi/dreamshaper_8Inpainting", 57 | "Uminosachi/deliberate_v3-inpainting", 58 | "Uminosachi/realisticVisionV51_v51VAE-inpainting", 59 | "Uminosachi/revAnimated_v121Inp-inpainting", 60 | "runwayml/stable-diffusion-inpainting", 61 | ] 62 | if inp_list_from_cache is not None and isinstance(inp_list_from_cache, list): 63 | model_ids.extend(inp_list_from_cache) 64 | return model_ids 65 | try: 66 | hf_cache_info = scan_cache_dir() 67 | inpaint_repos = [] 68 | for repo in hf_cache_info.repos: 69 | if repo.repo_type == "model" and "inpaint" in repo.repo_id.lower() and repo.repo_id not in model_ids: 70 | inpaint_repos.append(repo.repo_id) 71 | inp_list_from_cache = sorted(inpaint_repos, reverse=True, key=lambda x: x.split("/")[-1]) 72 | model_ids.extend(inp_list_from_cache) 73 | return model_ids 74 | except Exception: 75 | return model_ids 76 | 77 | 78 | def get_cleaner_model_ids(): 79 | """Get cleaner model ids list. 80 | 81 | Returns: 82 | list: model ids list 83 | """ 84 | model_ids = [ 85 | "lama", 86 | "ldm", 87 | "zits", 88 | "mat", 89 | "fcf", 90 | "manga", 91 | ] 92 | return model_ids 93 | 94 | 95 | def get_padding_mode_names(): 96 | """Get padding mode name list. 
97 | 98 | Returns: 99 | list: padding mode name list 100 | """ 101 | padding_mode_names = [ 102 | "constant", 103 | "edge", 104 | "reflect", 105 | "mean", 106 | "median", 107 | "maximum", 108 | "minimum", 109 | ] 110 | return padding_mode_names 111 | -------------------------------------------------------------------------------- /images/inpaint_anything_explanation_image_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Uminosachi/inpaint-anything/c2204930c48a4b2b65bc7835ff527ebb72de5be9/images/inpaint_anything_explanation_image_1.png -------------------------------------------------------------------------------- /images/inpaint_anything_ui_image_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Uminosachi/inpaint-anything/c2204930c48a4b2b65bc7835ff527ebb72de5be9/images/inpaint_anything_ui_image_1.png -------------------------------------------------------------------------------- /images/sample_input_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Uminosachi/inpaint-anything/c2204930c48a4b2b65bc7835ff527ebb72de5be9/images/sample_input_image.png -------------------------------------------------------------------------------- /images/sample_mask_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Uminosachi/inpaint-anything/c2204930c48a4b2b65bc7835ff527ebb72de5be9/images/sample_mask_image.png -------------------------------------------------------------------------------- /images/sample_seg_color_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Uminosachi/inpaint-anything/c2204930c48a4b2b65bc7835ff527ebb72de5be9/images/sample_seg_color_image.png -------------------------------------------------------------------------------- /inpalib/__init__.py: -------------------------------------------------------------------------------- 1 | from .masklib import create_mask_image, invert_mask 2 | from .samlib import (create_seg_color_image, generate_sam_masks, get_all_sam_ids, 3 | get_available_sam_ids, get_seg_colormap, insert_mask_to_sam_masks, 4 | sam_file_exists, sam_file_path, sort_masks_by_area) 5 | 6 | __all__ = [ 7 | "create_mask_image", 8 | "invert_mask", 9 | "create_seg_color_image", 10 | "generate_sam_masks", 11 | "get_all_sam_ids", 12 | "get_available_sam_ids", 13 | "get_seg_colormap", 14 | "insert_mask_to_sam_masks", 15 | "sam_file_exists", 16 | "sam_file_path", 17 | "sort_masks_by_area", 18 | ] 19 | -------------------------------------------------------------------------------- /inpalib/masklib.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Union 2 | 3 | import numpy as np 4 | from PIL import Image 5 | 6 | 7 | def invert_mask(mask: np.ndarray) -> np.ndarray: 8 | """Invert mask. 
9 | 10 | Args: 11 | mask (np.ndarray): mask 12 | 13 | Returns: 14 | np.ndarray: inverted mask 15 | """ 16 | if mask is None or not isinstance(mask, np.ndarray): 17 | raise ValueError("Invalid mask") 18 | 19 | # return np.logical_not(mask.astype(bool)).astype(np.uint8) * 255 20 | return np.invert(mask.astype(np.uint8)) 21 | 22 | 23 | def check_inputs_create_mask_image( 24 | mask: Union[np.ndarray, Image.Image], 25 | sam_masks: List[Dict[str, Any]], 26 | ignore_black_chk: bool = True, 27 | ) -> None: 28 | """Check create mask image inputs. 29 | 30 | Args: 31 | mask (Union[np.ndarray, Image.Image]): mask 32 | sam_masks (List[Dict[str, Any]]): SAM masks 33 | ignore_black_chk (bool): ignore black check 34 | 35 | Returns: 36 | None 37 | """ 38 | if mask is None or not isinstance(mask, (np.ndarray, Image.Image)): 39 | raise ValueError("Invalid mask") 40 | 41 | if sam_masks is None or not isinstance(sam_masks, list): 42 | raise ValueError("Invalid SAM masks") 43 | 44 | if ignore_black_chk is None or not isinstance(ignore_black_chk, bool): 45 | raise ValueError("Invalid ignore black check") 46 | 47 | 48 | def convert_mask(mask: Union[np.ndarray, Image.Image]) -> np.ndarray: 49 | """Convert mask. 50 | 51 | Args: 52 | mask (Union[np.ndarray, Image.Image]): mask 53 | 54 | Returns: 55 | np.ndarray: converted mask 56 | """ 57 | if isinstance(mask, Image.Image): 58 | mask = np.array(mask) 59 | 60 | if mask.ndim == 2: 61 | mask = mask[:, :, np.newaxis] 62 | 63 | if mask.shape[2] != 1: 64 | mask = mask[:, :, 0:1] 65 | 66 | return mask 67 | 68 | 69 | def create_mask_image( 70 | mask: Union[np.ndarray, Image.Image], 71 | sam_masks: List[Dict[str, Any]], 72 | ignore_black_chk: bool = True, 73 | ) -> np.ndarray: 74 | """Create mask image. 75 | 76 | Args: 77 | mask (Union[np.ndarray, Image.Image]): mask 78 | sam_masks (List[Dict[str, Any]]): SAM masks 79 | ignore_black_chk (bool): ignore black check 80 | 81 | Returns: 82 | np.ndarray: mask image 83 | """ 84 | check_inputs_create_mask_image(mask, sam_masks, ignore_black_chk) 85 | mask = convert_mask(mask) 86 | 87 | canvas_image = np.zeros(mask.shape, dtype=np.uint8) 88 | mask_region = np.zeros(mask.shape, dtype=np.uint8) 89 | for seg_dict in sam_masks: 90 | seg_mask = np.expand_dims(seg_dict["segmentation"].astype(np.uint8), axis=-1) 91 | canvas_mask = np.logical_not(canvas_image.astype(bool)).astype(np.uint8) 92 | if (seg_mask * canvas_mask * mask).astype(bool).any(): 93 | mask_region = mask_region + (seg_mask * canvas_mask) 94 | seg_color = seg_mask * canvas_mask 95 | canvas_image = canvas_image + seg_color 96 | 97 | if not ignore_black_chk: 98 | canvas_mask = np.logical_not(canvas_image.astype(bool)).astype(np.uint8) 99 | if (canvas_mask * mask).astype(bool).any(): 100 | mask_region = mask_region + (canvas_mask) 101 | 102 | mask_region = np.tile(mask_region * 255, (1, 1, 3)) 103 | 104 | seg_image = mask_region.astype(np.uint8) 105 | 106 | return seg_image 107 | -------------------------------------------------------------------------------- /lama_cleaner/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" 4 | 5 | import warnings # noqa: E402 6 | 7 | warnings.filterwarnings("ignore", category=UserWarning, module="pydantic") 8 | warnings.filterwarnings("ignore", category=UserWarning, module="lama_cleaner") 9 | 10 | from lama_cleaner.parse_args import parse_args # noqa: E402 11 | 12 | 13 | def entry_point(): 14 | args = parse_args() 15 | # To make 
os.environ["XDG_CACHE_HOME"] = args.model_cache_dir works for diffusers 16 | # https://github.com/huggingface/diffusers/blob/be99201a567c1ccd841dc16fb24e88f7f239c187/src/diffusers/utils/constants.py#L18 17 | from lama_cleaner.server import main 18 | 19 | main(args) 20 | -------------------------------------------------------------------------------- /lama_cleaner/benchmark.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import argparse 4 | import os 5 | import time 6 | 7 | import numpy as np 8 | import nvidia_smi 9 | import psutil 10 | import torch 11 | 12 | from lama_cleaner.model_manager import ModelManager 13 | from lama_cleaner.schema import Config, HDStrategy, SDSampler 14 | 15 | try: 16 | torch._C._jit_override_can_fuse_on_cpu(False) 17 | torch._C._jit_override_can_fuse_on_gpu(False) 18 | torch._C._jit_set_texpr_fuser_enabled(False) 19 | torch._C._jit_set_nvfuser_enabled(False) 20 | except: 21 | pass 22 | 23 | NUM_THREADS = str(4) 24 | 25 | os.environ["OMP_NUM_THREADS"] = NUM_THREADS 26 | os.environ["OPENBLAS_NUM_THREADS"] = NUM_THREADS 27 | os.environ["MKL_NUM_THREADS"] = NUM_THREADS 28 | os.environ["VECLIB_MAXIMUM_THREADS"] = NUM_THREADS 29 | os.environ["NUMEXPR_NUM_THREADS"] = NUM_THREADS 30 | if os.environ.get("CACHE_DIR"): 31 | os.environ["TORCH_HOME"] = os.environ["CACHE_DIR"] 32 | 33 | 34 | def run_model(model, size): 35 | # RGB 36 | image = np.random.randint(0, 256, (size[0], size[1], 3)).astype(np.uint8) 37 | mask = np.random.randint(0, 255, size).astype(np.uint8) 38 | 39 | config = Config( 40 | ldm_steps=2, 41 | hd_strategy=HDStrategy.ORIGINAL, 42 | hd_strategy_crop_margin=128, 43 | hd_strategy_crop_trigger_size=128, 44 | hd_strategy_resize_limit=128, 45 | prompt="a fox is sitting on a bench", 46 | sd_steps=5, 47 | sd_sampler=SDSampler.ddim 48 | ) 49 | model(image, mask, config) 50 | 51 | 52 | def benchmark(model, times: int, empty_cache: bool): 53 | sizes = [(512, 512)] 54 | 55 | nvidia_smi.nvmlInit() 56 | device_id = 0 57 | handle = nvidia_smi.nvmlDeviceGetHandleByIndex(device_id) 58 | 59 | def format(metrics): 60 | return f"{np.mean(metrics):.2f} ± {np.std(metrics):.2f}" 61 | 62 | process = psutil.Process(os.getpid()) 63 | # 每个 size 给出显存和内存占用的指标 64 | for size in sizes: 65 | torch.cuda.empty_cache() 66 | time_metrics = [] 67 | cpu_metrics = [] 68 | memory_metrics = [] 69 | gpu_memory_metrics = [] 70 | for _ in range(times): 71 | start = time.time() 72 | run_model(model, size) 73 | torch.cuda.synchronize() 74 | 75 | # cpu_metrics.append(process.cpu_percent()) 76 | time_metrics.append((time.time() - start) * 1000) 77 | memory_metrics.append(process.memory_info().rss / 1024 / 1024) 78 | gpu_memory_metrics.append(nvidia_smi.nvmlDeviceGetMemoryInfo(handle).used / 1024 / 1024) 79 | 80 | print(f"size: {size}".center(80, "-")) 81 | # print(f"cpu: {format(cpu_metrics)}") 82 | print(f"latency: {format(time_metrics)}ms") 83 | print(f"memory: {format(memory_metrics)} MB") 84 | print(f"gpu memory: {format(gpu_memory_metrics)} MB") 85 | 86 | nvidia_smi.nvmlShutdown() 87 | 88 | 89 | def get_args_parser(): 90 | parser = argparse.ArgumentParser() 91 | parser.add_argument("--name") 92 | parser.add_argument("--device", default="cuda", type=str) 93 | parser.add_argument("--times", default=10, type=int) 94 | parser.add_argument("--empty-cache", action="store_true") 95 | return parser.parse_args() 96 | 97 | 98 | if __name__ == "__main__": 99 | args = get_args_parser() 100 | device = torch.device(args.device) 101 | model = 
ModelManager( 102 | name=args.name, 103 | device=device, 104 | sd_run_local=True, 105 | disable_nsfw=True, 106 | sd_cpu_textencoder=True, 107 | hf_access_token="123" 108 | ) 109 | benchmark(model, args.times, args.empty_cache) 110 | -------------------------------------------------------------------------------- /lama_cleaner/const.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from enum import Enum 4 | from pydantic import BaseModel 5 | 6 | 7 | MPS_SUPPORT_MODELS = [ 8 | "instruct_pix2pix", 9 | "sd1.5", 10 | "anything4", 11 | "realisticVision1.4", 12 | "sd2", 13 | "paint_by_example", 14 | "controlnet", 15 | ] 16 | 17 | DEFAULT_MODEL = "lama" 18 | AVAILABLE_MODELS = [ 19 | "lama", 20 | "ldm", 21 | "zits", 22 | "mat", 23 | "fcf", 24 | "sd1.5", 25 | "anything4", 26 | "realisticVision1.4", 27 | "cv2", 28 | "manga", 29 | "sd2", 30 | "paint_by_example", 31 | "instruct_pix2pix", 32 | ] 33 | SD15_MODELS = ["sd1.5", "anything4", "realisticVision1.4"] 34 | 35 | AVAILABLE_DEVICES = ["cuda", "cpu", "mps"] 36 | DEFAULT_DEVICE = "cuda" 37 | 38 | NO_HALF_HELP = """ 39 | Using full precision model. 40 | If your generate result is always black or green, use this argument. (sd/paint_by_exmaple) 41 | """ 42 | 43 | CPU_OFFLOAD_HELP = """ 44 | Offloads all models to CPU, significantly reducing vRAM usage. (sd/paint_by_example) 45 | """ 46 | 47 | DISABLE_NSFW_HELP = """ 48 | Disable NSFW checker. (sd/paint_by_example) 49 | """ 50 | 51 | SD_CPU_TEXTENCODER_HELP = """ 52 | Run Stable Diffusion text encoder model on CPU to save GPU memory. 53 | """ 54 | 55 | SD_CONTROLNET_HELP = """ 56 | Run Stable Diffusion inpainting model with ControlNet. You can switch control method in webui. 57 | """ 58 | DEFAULT_CONTROLNET_METHOD = "control_v11p_sd15_canny" 59 | SD_CONTROLNET_CHOICES = [ 60 | "control_v11p_sd15_canny", 61 | "control_v11p_sd15_openpose", 62 | "control_v11p_sd15_inpaint", 63 | "control_v11f1p_sd15_depth" 64 | ] 65 | 66 | SD_LOCAL_MODEL_HELP = """ 67 | Load Stable Diffusion 1.5 model(ckpt/safetensors) from local path. 68 | """ 69 | 70 | LOCAL_FILES_ONLY_HELP = """ 71 | Use local files only, not connect to Hugging Face server. (sd/paint_by_example) 72 | """ 73 | 74 | ENABLE_XFORMERS_HELP = """ 75 | Enable xFormers optimizations. Requires xformers package has been installed. See: https://github.com/facebookresearch/xformers (sd/paint_by_example) 76 | """ 77 | 78 | DEFAULT_MODEL_DIR = os.getenv( 79 | "XDG_CACHE_HOME", os.path.join(os.path.expanduser("~"), ".cache") 80 | ) 81 | MODEL_DIR_HELP = """ 82 | Model download directory (by setting XDG_CACHE_HOME environment variable), by default model downloaded to ~/.cache 83 | """ 84 | 85 | OUTPUT_DIR_HELP = """ 86 | Result images will be saved to output directory automatically without confirmation. 87 | """ 88 | 89 | INPUT_HELP = """ 90 | If input is image, it will be loaded by default. 91 | If input is directory, you can browse and select image in file manager. 92 | """ 93 | 94 | GUI_HELP = """ 95 | Launch Lama Cleaner as desktop app 96 | """ 97 | 98 | NO_GUI_AUTO_CLOSE_HELP = """ 99 | Prevent backend auto close after the GUI window closed. 100 | """ 101 | 102 | QUALITY_HELP = """ 103 | Quality of image encoding, 0-100. Default is 95, higher quality will generate larger file size. 
104 | """ 105 | 106 | 107 | class RealESRGANModelName(str, Enum): 108 | realesr_general_x4v3 = "realesr-general-x4v3" 109 | RealESRGAN_x4plus = "RealESRGAN_x4plus" 110 | RealESRGAN_x4plus_anime_6B = "RealESRGAN_x4plus_anime_6B" 111 | 112 | 113 | RealESRGANModelNameList = [e.value for e in RealESRGANModelName] 114 | 115 | INTERACTIVE_SEG_HELP = "Enable interactive segmentation using Segment Anything." 116 | INTERACTIVE_SEG_MODEL_HELP = "Model size: vit_b < vit_l < vit_h. Bigger model size means better segmentation but slower speed." 117 | AVAILABLE_INTERACTIVE_SEG_MODELS = ["vit_b", "vit_l", "vit_h"] 118 | AVAILABLE_INTERACTIVE_SEG_DEVICES = ["cuda", "cpu", "mps"] 119 | REMOVE_BG_HELP = "Enable remove background. Always run on CPU" 120 | ANIMESEG_HELP = "Enable anime segmentation. Always run on CPU" 121 | REALESRGAN_HELP = "Enable realesrgan super resolution" 122 | REALESRGAN_AVAILABLE_DEVICES = ["cpu", "cuda", "mps"] 123 | GFPGAN_HELP = ( 124 | "Enable GFPGAN face restore. To enhance background, use with --enable-realesrgan" 125 | ) 126 | GFPGAN_AVAILABLE_DEVICES = ["cpu", "cuda", "mps"] 127 | RESTOREFORMER_HELP = "Enable RestoreFormer face restore. To enhance background, use with --enable-realesrgan" 128 | RESTOREFORMER_AVAILABLE_DEVICES = ["cpu", "cuda", "mps"] 129 | GIF_HELP = "Enable GIF plugin. Make GIF to compare original and cleaned image" 130 | 131 | 132 | class Config(BaseModel): 133 | host: str = "127.0.0.1" 134 | port: int = 8080 135 | model: str = DEFAULT_MODEL 136 | sd_local_model_path: str = None 137 | sd_controlnet: bool = False 138 | sd_controlnet_method: str = DEFAULT_CONTROLNET_METHOD 139 | device: str = DEFAULT_DEVICE 140 | gui: bool = False 141 | no_gui_auto_close: bool = False 142 | no_half: bool = False 143 | cpu_offload: bool = False 144 | disable_nsfw: bool = False 145 | sd_cpu_textencoder: bool = False 146 | enable_xformers: bool = False 147 | local_files_only: bool = False 148 | model_dir: str = DEFAULT_MODEL_DIR 149 | input: str = None 150 | output_dir: str = None 151 | # plugins 152 | enable_interactive_seg: bool = False 153 | interactive_seg_model: str = "vit_l" 154 | interactive_seg_device: str = "cpu" 155 | enable_remove_bg: bool = False 156 | enable_anime_seg: bool = False 157 | enable_realesrgan: bool = False 158 | realesrgan_device: str = "cpu" 159 | realesrgan_model: str = RealESRGANModelName.realesr_general_x4v3.value 160 | realesrgan_no_half: bool = False 161 | enable_gfpgan: bool = False 162 | gfpgan_device: str = "cpu" 163 | enable_restoreformer: bool = False 164 | restoreformer_device: str = "cpu" 165 | enable_gif: bool = False 166 | 167 | 168 | def load_config(installer_config: str): 169 | if os.path.exists(installer_config): 170 | with open(installer_config, "r", encoding="utf-8") as f: 171 | return Config(**json.load(f)) 172 | else: 173 | return Config() 174 | -------------------------------------------------------------------------------- /lama_cleaner/file_manager/__init__.py: -------------------------------------------------------------------------------- 1 | from .file_manager import FileManager 2 | -------------------------------------------------------------------------------- /lama_cleaner/file_manager/storage_backends.py: -------------------------------------------------------------------------------- 1 | # Copy from https://github.com/silentsokolov/flask-thumbnails/blob/master/flask_thumbnails/storage_backends.py 2 | import errno 3 | import os 4 | from abc import ABC, abstractmethod 5 | 6 | 7 | class BaseStorageBackend(ABC): 8 | def 
__init__(self, app=None): 9 | self.app = app 10 | 11 | @abstractmethod 12 | def read(self, filepath, mode="rb", **kwargs): 13 | raise NotImplementedError 14 | 15 | @abstractmethod 16 | def exists(self, filepath): 17 | raise NotImplementedError 18 | 19 | @abstractmethod 20 | def save(self, filepath, data): 21 | raise NotImplementedError 22 | 23 | 24 | class FilesystemStorageBackend(BaseStorageBackend): 25 | def read(self, filepath, mode="rb", **kwargs): 26 | with open(filepath, mode) as f: # pylint: disable=unspecified-encoding 27 | return f.read() 28 | 29 | def exists(self, filepath): 30 | return os.path.exists(filepath) 31 | 32 | def save(self, filepath, data): 33 | directory = os.path.dirname(filepath) 34 | 35 | if not os.path.exists(directory): 36 | try: 37 | os.makedirs(directory) 38 | except OSError as e: 39 | if e.errno != errno.EEXIST: 40 | raise 41 | 42 | if not os.path.isdir(directory): 43 | raise IOError("{} is not a directory".format(directory)) 44 | 45 | with open(filepath, "wb") as f: 46 | f.write(data) 47 | -------------------------------------------------------------------------------- /lama_cleaner/file_manager/utils.py: -------------------------------------------------------------------------------- 1 | # Copy from: https://github.com/silentsokolov/flask-thumbnails/blob/master/flask_thumbnails/utils.py 2 | import importlib 3 | import os 4 | from pathlib import Path 5 | 6 | from typing import Union 7 | 8 | 9 | def generate_filename(original_filename, *options): 10 | name, ext = os.path.splitext(original_filename) 11 | for v in options: 12 | if v: 13 | name += "_%s" % v 14 | name += ext 15 | 16 | return name 17 | 18 | 19 | def parse_size(size): 20 | if isinstance(size, int): 21 | # If the size parameter is a single number, assume square aspect. 22 | return [size, size] 23 | 24 | if isinstance(size, (tuple, list)): 25 | if len(size) == 1: 26 | # If single value tuple/list is provided, exand it to two elements 27 | return size + type(size)(size) 28 | return size 29 | 30 | try: 31 | thumbnail_size = [int(x) for x in size.lower().split("x", 1)] 32 | except ValueError: 33 | raise ValueError( # pylint: disable=raise-missing-from 34 | "Bad thumbnail size format. Valid format is INTxINT." 35 | ) 36 | 37 | if len(thumbnail_size) == 1: 38 | # If the size parameter only contains a single integer, assume square aspect. 
39 | thumbnail_size.append(thumbnail_size[0]) 40 | 41 | return thumbnail_size 42 | 43 | 44 | def aspect_to_string(size): 45 | if isinstance(size, str): 46 | return size 47 | 48 | return "x".join(map(str, size)) 49 | 50 | 51 | IMG_SUFFIX = {'.jpg', '.jpeg', '.png', '.JPG', '.JPEG', '.PNG'} 52 | 53 | 54 | def glob_img(p: Union[Path, str], recursive: bool = False): 55 | p = Path(p) 56 | if p.is_file() and p.suffix in IMG_SUFFIX: 57 | yield p 58 | else: 59 | if recursive: 60 | files = Path(p).glob("**/*.*") 61 | else: 62 | files = Path(p).glob("*.*") 63 | 64 | for it in files: 65 | if it.suffix not in IMG_SUFFIX: 66 | continue 67 | yield it 68 | -------------------------------------------------------------------------------- /lama_cleaner/installer.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import sys 3 | 4 | 5 | def install(package): 6 | subprocess.check_call([sys.executable, "-m", "pip", "install", package]) 7 | 8 | 9 | def install_plugins_package(): 10 | install("rembg") 11 | install("realesrgan") 12 | install("gfpgan") 13 | -------------------------------------------------------------------------------- /lama_cleaner/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Uminosachi/inpaint-anything/c2204930c48a4b2b65bc7835ff527ebb72de5be9/lama_cleaner/model/__init__.py -------------------------------------------------------------------------------- /lama_cleaner/model/instruct_pix2pix.py: -------------------------------------------------------------------------------- 1 | import PIL.Image 2 | import cv2 3 | import torch 4 | from loguru import logger 5 | 6 | from lama_cleaner.model.base import DiffusionInpaintModel 7 | from lama_cleaner.model.utils import set_seed 8 | from lama_cleaner.schema import Config 9 | 10 | 11 | class InstructPix2Pix(DiffusionInpaintModel): 12 | name = "instruct_pix2pix" 13 | pad_mod = 8 14 | min_size = 512 15 | 16 | def init_model(self, device: torch.device, **kwargs): 17 | from diffusers import StableDiffusionInstructPix2PixPipeline 18 | fp16 = not kwargs.get('no_half', False) 19 | 20 | model_kwargs = {"local_files_only": kwargs.get('local_files_only', False)} 21 | if kwargs['disable_nsfw'] or kwargs.get('cpu_offload', False): 22 | logger.info("Disable Stable Diffusion Model NSFW checker") 23 | model_kwargs.update(dict( 24 | safety_checker=None, 25 | feature_extractor=None, 26 | requires_safety_checker=False 27 | )) 28 | 29 | use_gpu = device == torch.device('cuda') and torch.cuda.is_available() 30 | torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32 31 | self.model = StableDiffusionInstructPix2PixPipeline.from_pretrained( 32 | "timbrooks/instruct-pix2pix", 33 | revision="fp16" if use_gpu and fp16 else "main", 34 | torch_dtype=torch_dtype, 35 | **model_kwargs 36 | ) 37 | 38 | self.model.enable_attention_slicing() 39 | if kwargs.get('enable_xformers', False): 40 | self.model.enable_xformers_memory_efficient_attention() 41 | 42 | if kwargs.get('cpu_offload', False) and use_gpu: 43 | logger.info("Enable sequential cpu offload") 44 | self.model.enable_sequential_cpu_offload(gpu_id=0) 45 | else: 46 | self.model = self.model.to(device) 47 | 48 | def forward(self, image, mask, config: Config): 49 | """Input image and output image have same size 50 | image: [H, W, C] RGB 51 | mask: [H, W, 1] 255 means area to repaint 52 | return: BGR IMAGE 53 | edit = pipe(prompt, image=image, num_inference_steps=20, 
image_guidance_scale=1.5, guidance_scale=7).images[0] 54 | """ 55 | output = self.model( 56 | image=PIL.Image.fromarray(image), 57 | prompt=config.prompt, 58 | negative_prompt=config.negative_prompt, 59 | num_inference_steps=config.p2p_steps, 60 | image_guidance_scale=config.p2p_image_guidance_scale, 61 | guidance_scale=config.p2p_guidance_scale, 62 | output_type="np.array", 63 | generator=torch.manual_seed(config.sd_seed) 64 | ).images[0] 65 | 66 | output = (output * 255).round().astype("uint8") 67 | output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR) 68 | return output 69 | 70 | # 71 | # def forward_post_process(self, result, image, mask, config): 72 | # if config.sd_match_histograms: 73 | # result = self._match_histograms(result, image[:, :, ::-1], mask) 74 | # 75 | # if config.sd_mask_blur != 0: 76 | # k = 2 * config.sd_mask_blur + 1 77 | # mask = cv2.GaussianBlur(mask, (k, k), 0) 78 | # return result, image, mask 79 | 80 | @staticmethod 81 | def is_downloaded() -> bool: 82 | # model will be downloaded when app start, and can't switch in frontend settings 83 | return True 84 | -------------------------------------------------------------------------------- /lama_cleaner/model/lama.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import cv2 4 | import numpy as np 5 | import torch 6 | 7 | from lama_cleaner.helper import ( 8 | norm_img, 9 | get_cache_path_by_url, 10 | load_jit_model, 11 | ) 12 | from lama_cleaner.model.base import InpaintModel 13 | from lama_cleaner.schema import Config 14 | 15 | LAMA_MODEL_URL = os.environ.get( 16 | "LAMA_MODEL_URL", 17 | "https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt", 18 | ) 19 | LAMA_MODEL_MD5 = os.environ.get("LAMA_MODEL_MD5", "e3aa4aaa15225a33ec84f9f4bc47e500") 20 | 21 | 22 | class LaMa(InpaintModel): 23 | name = "lama" 24 | pad_mod = 8 25 | 26 | def init_model(self, device, **kwargs): 27 | self.model = load_jit_model(LAMA_MODEL_URL, device, LAMA_MODEL_MD5).eval() 28 | 29 | @staticmethod 30 | def is_downloaded() -> bool: 31 | return os.path.exists(get_cache_path_by_url(LAMA_MODEL_URL)) 32 | 33 | def forward(self, image, mask, config: Config): 34 | """Input image and output image have same size 35 | image: [H, W, C] RGB 36 | mask: [H, W] 37 | return: BGR IMAGE 38 | """ 39 | image = norm_img(image) 40 | mask = norm_img(mask) 41 | 42 | mask = (mask > 0) * 1 43 | image = torch.from_numpy(image).unsqueeze(0).to(self.device) 44 | mask = torch.from_numpy(mask).unsqueeze(0).to(self.device) 45 | 46 | inpainted_image = self.model(image, mask) 47 | 48 | cur_res = inpainted_image[0].permute(1, 2, 0).detach().cpu().numpy() 49 | cur_res = np.clip(cur_res * 255, 0, 255).astype("uint8") 50 | cur_res = cv2.cvtColor(cur_res, cv2.COLOR_RGB2BGR) 51 | return cur_res 52 | -------------------------------------------------------------------------------- /lama_cleaner/model/manga.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | import cv2 5 | import numpy as np 6 | import torch 7 | import time 8 | from loguru import logger 9 | 10 | from lama_cleaner.helper import get_cache_path_by_url, load_jit_model 11 | from lama_cleaner.model.base import InpaintModel 12 | from lama_cleaner.schema import Config 13 | 14 | 15 | MANGA_INPAINTOR_MODEL_URL = os.environ.get( 16 | "MANGA_INPAINTOR_MODEL_URL", 17 | "https://github.com/Sanster/models/releases/download/manga/manga_inpaintor.jit", 18 | ) 19 | MANGA_INPAINTOR_MODEL_MD5 = 
os.environ.get( 20 | "MANGA_INPAINTOR_MODEL_MD5", "7d8b269c4613b6b3768af714610da86c" 21 | ) 22 | 23 | MANGA_LINE_MODEL_URL = os.environ.get( 24 | "MANGA_LINE_MODEL_URL", 25 | "https://github.com/Sanster/models/releases/download/manga/erika.jit", 26 | ) 27 | MANGA_LINE_MODEL_MD5 = os.environ.get( 28 | "MANGA_LINE_MODEL_MD5", "0c926d5a4af8450b0d00bc5b9a095644" 29 | ) 30 | 31 | 32 | class Manga(InpaintModel): 33 | name = "manga" 34 | pad_mod = 16 35 | 36 | def init_model(self, device, **kwargs): 37 | self.inpaintor_model = load_jit_model( 38 | MANGA_INPAINTOR_MODEL_URL, device, MANGA_INPAINTOR_MODEL_MD5 39 | ) 40 | self.line_model = load_jit_model( 41 | MANGA_LINE_MODEL_URL, device, MANGA_LINE_MODEL_MD5 42 | ) 43 | self.seed = 42 44 | 45 | @staticmethod 46 | def is_downloaded() -> bool: 47 | model_paths = [ 48 | get_cache_path_by_url(MANGA_INPAINTOR_MODEL_URL), 49 | get_cache_path_by_url(MANGA_LINE_MODEL_URL), 50 | ] 51 | return all([os.path.exists(it) for it in model_paths]) 52 | 53 | def forward(self, image, mask, config: Config): 54 | """ 55 | image: [H, W, C] RGB 56 | mask: [H, W, 1] 57 | return: BGR IMAGE 58 | """ 59 | seed = self.seed 60 | random.seed(seed) 61 | np.random.seed(seed) 62 | torch.manual_seed(seed) 63 | torch.cuda.manual_seed_all(seed) 64 | 65 | gray_img = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) 66 | gray_img = torch.from_numpy( 67 | gray_img[np.newaxis, np.newaxis, :, :].astype(np.float32) 68 | ).to(self.device) 69 | start = time.time() 70 | lines = self.line_model(gray_img) 71 | torch.cuda.empty_cache() 72 | lines = torch.clamp(lines, 0, 255) 73 | logger.info(f"erika_model time: {time.time() - start}") 74 | 75 | mask = torch.from_numpy(mask[np.newaxis, :, :, :]).to(self.device) 76 | mask = mask.permute(0, 3, 1, 2) 77 | mask = torch.where(mask > 0.5, 1.0, 0.0) 78 | noise = torch.randn_like(mask) 79 | ones = torch.ones_like(mask) 80 | 81 | gray_img = gray_img / 255 * 2 - 1.0 82 | lines = lines / 255 * 2 - 1.0 83 | 84 | start = time.time() 85 | inpainted_image = self.inpaintor_model(gray_img, lines, mask, noise, ones) 86 | logger.info(f"image_inpaintor_model time: {time.time() - start}") 87 | 88 | cur_res = inpainted_image[0].permute(1, 2, 0).detach().cpu().numpy() 89 | cur_res = (cur_res * 127.5 + 127.5).astype(np.uint8) 90 | cur_res = cv2.cvtColor(cur_res, cv2.COLOR_GRAY2BGR) 91 | return cur_res 92 | -------------------------------------------------------------------------------- /lama_cleaner/model/opencv2.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | from lama_cleaner.model.base import InpaintModel 3 | from lama_cleaner.schema import Config 4 | 5 | flag_map = {"INPAINT_NS": cv2.INPAINT_NS, "INPAINT_TELEA": cv2.INPAINT_TELEA} 6 | 7 | 8 | class OpenCV2(InpaintModel): 9 | name = "cv2" 10 | pad_mod = 1 11 | 12 | @staticmethod 13 | def is_downloaded() -> bool: 14 | return True 15 | 16 | def forward(self, image, mask, config: Config): 17 | """Input image and output image have same size 18 | image: [H, W, C] RGB 19 | mask: [H, W, 1] 20 | return: BGR IMAGE 21 | """ 22 | cur_res = cv2.inpaint( 23 | image[:, :, ::-1], 24 | mask, 25 | inpaintRadius=config.cv2_radius, 26 | flags=flag_map[config.cv2_flag], 27 | ) 28 | return cur_res 29 | -------------------------------------------------------------------------------- /lama_cleaner/model/paint_by_example.py: -------------------------------------------------------------------------------- 1 | import PIL 2 | import PIL.Image 3 | import cv2 4 | import torch 5 | from diffusers 
import DiffusionPipeline 6 | from loguru import logger 7 | 8 | from lama_cleaner.model.base import DiffusionInpaintModel 9 | from lama_cleaner.model.utils import set_seed 10 | from lama_cleaner.schema import Config 11 | 12 | 13 | class PaintByExample(DiffusionInpaintModel): 14 | name = "paint_by_example" 15 | pad_mod = 8 16 | min_size = 512 17 | 18 | def init_model(self, device: torch.device, **kwargs): 19 | fp16 = not kwargs.get('no_half', False) 20 | use_gpu = device == torch.device('cuda') and torch.cuda.is_available() 21 | torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32 22 | model_kwargs = {"local_files_only": kwargs.get('local_files_only', False)} 23 | 24 | if kwargs['disable_nsfw'] or kwargs.get('cpu_offload', False): 25 | logger.info("Disable Paint By Example Model NSFW checker") 26 | model_kwargs.update(dict( 27 | safety_checker=None, 28 | requires_safety_checker=False 29 | )) 30 | 31 | self.model = DiffusionPipeline.from_pretrained( 32 | "Fantasy-Studio/Paint-by-Example", 33 | torch_dtype=torch_dtype, 34 | **model_kwargs 35 | ) 36 | 37 | self.model.enable_attention_slicing() 38 | if kwargs.get('enable_xformers', False): 39 | self.model.enable_xformers_memory_efficient_attention() 40 | 41 | # TODO: gpu_id 42 | if kwargs.get('cpu_offload', False) and use_gpu: 43 | self.model.image_encoder = self.model.image_encoder.to(device) 44 | self.model.enable_sequential_cpu_offload(gpu_id=0) 45 | else: 46 | self.model = self.model.to(device) 47 | 48 | def forward(self, image, mask, config: Config): 49 | """Input image and output image have same size 50 | image: [H, W, C] RGB 51 | mask: [H, W, 1] 255 means area to repaint 52 | return: BGR IMAGE 53 | """ 54 | output = self.model( 55 | image=PIL.Image.fromarray(image), 56 | mask_image=PIL.Image.fromarray(mask[:, :, -1], mode="L"), 57 | example_image=config.paint_by_example_example_image, 58 | num_inference_steps=config.paint_by_example_steps, 59 | output_type='np.array', 60 | generator=torch.manual_seed(config.paint_by_example_seed) 61 | ).images[0] 62 | 63 | output = (output * 255).round().astype("uint8") 64 | output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR) 65 | return output 66 | 67 | def forward_post_process(self, result, image, mask, config): 68 | if config.paint_by_example_match_histograms: 69 | result = self._match_histograms(result, image[:, :, ::-1], mask) 70 | 71 | if config.paint_by_example_mask_blur != 0: 72 | k = 2 * config.paint_by_example_mask_blur + 1 73 | mask = cv2.GaussianBlur(mask, (k, k), 0) 74 | return result, image, mask 75 | 76 | @staticmethod 77 | def is_downloaded() -> bool: 78 | # model will be downloaded when app start, and can't switch in frontend settings 79 | return True 80 | -------------------------------------------------------------------------------- /lama_cleaner/model/pipeline/__init__.py: -------------------------------------------------------------------------------- 1 | from .pipeline_stable_diffusion_controlnet_inpaint import ( 2 | StableDiffusionControlNetInpaintPipeline, 3 | ) 4 | -------------------------------------------------------------------------------- /lama_cleaner/model_manager.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import gc 3 | 4 | from loguru import logger 5 | 6 | from lama_cleaner.const import SD15_MODELS 7 | from lama_cleaner.helper import switch_mps_device 8 | from lama_cleaner.model.controlnet import ControlNet 9 | from lama_cleaner.model.fcf import FcF 10 | from lama_cleaner.model.lama import LaMa 11 
| from lama_cleaner.model.ldm import LDM 12 | from lama_cleaner.model.manga import Manga 13 | from lama_cleaner.model.mat import MAT 14 | from lama_cleaner.model.paint_by_example import PaintByExample 15 | from lama_cleaner.model.instruct_pix2pix import InstructPix2Pix 16 | from lama_cleaner.model.sd import SD15, SD2, Anything4, RealisticVision14 17 | from lama_cleaner.model.utils import torch_gc 18 | from lama_cleaner.model.zits import ZITS 19 | from lama_cleaner.model.opencv2 import OpenCV2 20 | from lama_cleaner.schema import Config 21 | 22 | models = { 23 | "lama": LaMa, 24 | "ldm": LDM, 25 | "zits": ZITS, 26 | "mat": MAT, 27 | "fcf": FcF, 28 | SD15.name: SD15, 29 | Anything4.name: Anything4, 30 | RealisticVision14.name: RealisticVision14, 31 | "cv2": OpenCV2, 32 | "manga": Manga, 33 | "sd2": SD2, 34 | "paint_by_example": PaintByExample, 35 | "instruct_pix2pix": InstructPix2Pix, 36 | } 37 | 38 | 39 | class ModelManager: 40 | def __init__(self, name: str, device: torch.device, **kwargs): 41 | self.name = name 42 | self.device = device 43 | self.kwargs = kwargs 44 | self.model = self.init_model(name, device, **kwargs) 45 | 46 | def init_model(self, name: str, device, **kwargs): 47 | if name in SD15_MODELS and kwargs.get("sd_controlnet", False): 48 | return ControlNet(device, **{**kwargs, "name": name}) 49 | 50 | if name in models: 51 | model = models[name](device, **kwargs) 52 | else: 53 | raise NotImplementedError(f"Not supported model: {name}") 54 | return model 55 | 56 | def is_downloaded(self, name: str) -> bool: 57 | if name in models: 58 | return models[name].is_downloaded() 59 | else: 60 | raise NotImplementedError(f"Not supported model: {name}") 61 | 62 | def __call__(self, image, mask, config: Config): 63 | self.switch_controlnet_method(control_method=config.controlnet_method) 64 | return self.model(image, mask, config) 65 | 66 | def switch(self, new_name: str, **kwargs): 67 | if new_name == self.name: 68 | return 69 | try: 70 | if torch.cuda.memory_allocated() > 0: 71 | # Clear current loaded model from memory 72 | torch.cuda.empty_cache() 73 | del self.model 74 | gc.collect() 75 | 76 | self.model = self.init_model( 77 | new_name, switch_mps_device(new_name, self.device), **self.kwargs 78 | ) 79 | self.name = new_name 80 | except NotImplementedError as e: 81 | raise e 82 | 83 | def switch_controlnet_method(self, control_method: str): 84 | if not self.kwargs.get("sd_controlnet"): 85 | return 86 | if self.kwargs["sd_controlnet_method"] == control_method: 87 | return 88 | if not hasattr(self.model, "is_local_sd_model"): 89 | return 90 | 91 | if self.model.is_local_sd_model: 92 | # is_native_control_inpaint means a plain (non-inpainting) SD model was loaded 93 | if ( 94 | self.model.is_native_control_inpaint 95 | and control_method != "control_v11p_sd15_inpaint" 96 | ): 97 | raise RuntimeError( 98 | f"--sd-local-model-path loads a normal SD model, " 99 | f"to use {control_method} you should load an inpainting SD model" 100 | ) 101 | elif ( 102 | not self.model.is_native_control_inpaint 103 | and control_method == "control_v11p_sd15_inpaint" 104 | ): 105 | raise RuntimeError( 106 | f"--sd-local-model-path loads an inpainting SD model, " 107 | f"to use {control_method} you should load a normal SD model" 108 | ) 109 | 110 | del self.model 111 | torch_gc() 112 | 113 | old_method = self.kwargs["sd_controlnet_method"] 114 | self.kwargs["sd_controlnet_method"] = control_method 115 | self.model = self.init_model( 116 | self.name, switch_mps_device(self.name, self.device), **self.kwargs 117 | ) 118 | logger.info(f"Switch
ControlNet method from {old_method} to {control_method}") 119 | -------------------------------------------------------------------------------- /lama_cleaner/plugins/__init__.py: -------------------------------------------------------------------------------- 1 | from .interactive_seg import InteractiveSeg 2 | from .remove_bg import RemoveBG 3 | from .realesrgan import RealESRGANUpscaler 4 | from .gfpgan_plugin import GFPGANPlugin 5 | from .restoreformer import RestoreFormerPlugin 6 | from .gif import MakeGIF 7 | from .anime_seg import AnimeSeg 8 | -------------------------------------------------------------------------------- /lama_cleaner/plugins/base_plugin.py: -------------------------------------------------------------------------------- 1 | from loguru import logger 2 | 3 | 4 | class BasePlugin: 5 | def __init__(self): 6 | err_msg = self.check_dep() 7 | if err_msg: 8 | logger.error(err_msg) 9 | exit(-1) 10 | 11 | def __call__(self, rgb_np_img, files, form): 12 | ... 13 | 14 | def check_dep(self): 15 | ... 16 | -------------------------------------------------------------------------------- /lama_cleaner/plugins/gfpgan_plugin.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | from loguru import logger 3 | 4 | from lama_cleaner.helper import download_model 5 | from lama_cleaner.plugins.base_plugin import BasePlugin 6 | 7 | 8 | class GFPGANPlugin(BasePlugin): 9 | name = "GFPGAN" 10 | 11 | def __init__(self, device, upscaler=None): 12 | super().__init__() 13 | from .gfpganer import MyGFPGANer 14 | 15 | url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth" 16 | model_md5 = "94d735072630ab734561130a47bc44f8" 17 | model_path = download_model(url, model_md5) 18 | logger.info(f"GFPGAN model path: {model_path}") 19 | 20 | import facexlib 21 | 22 | if hasattr(facexlib.detection.retinaface, "device"): 23 | facexlib.detection.retinaface.device = device 24 | 25 | # Use GFPGAN for face enhancement 26 | self.face_enhancer = MyGFPGANer( 27 | model_path=model_path, 28 | upscale=1, 29 | arch="clean", 30 | channel_multiplier=2, 31 | device=device, 32 | bg_upsampler=upscaler.model if upscaler is not None else None, 33 | ) 34 | self.face_enhancer.face_helper.face_det.mean_tensor.to(device) 35 | self.face_enhancer.face_helper.face_det = ( 36 | self.face_enhancer.face_helper.face_det.to(device) 37 | ) 38 | 39 | def __call__(self, rgb_np_img, files, form): 40 | weight = 0.5 41 | bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR) 42 | logger.info(f"GFPGAN input shape: {bgr_np_img.shape}") 43 | _, _, bgr_output = self.face_enhancer.enhance( 44 | bgr_np_img, 45 | has_aligned=False, 46 | only_center_face=False, 47 | paste_back=True, 48 | weight=weight, 49 | ) 50 | logger.info(f"GFPGAN output shape: {bgr_output.shape}") 51 | 52 | # try: 53 | # if scale != 2: 54 | # interpolation = cv2.INTER_AREA if scale < 2 else cv2.INTER_LANCZOS4 55 | # h, w = img.shape[0:2] 56 | # output = cv2.resize( 57 | # output, 58 | # (int(w * scale / 2), int(h * scale / 2)), 59 | # interpolation=interpolation, 60 | # ) 61 | # except Exception as error: 62 | # print("wrong scale input.", error) 63 | return bgr_output 64 | 65 | def check_dep(self): 66 | try: 67 | import gfpgan 68 | except ImportError: 69 | return ( 70 | "gfpgan is not installed, please install it first. 
pip install gfpgan" 71 | ) 72 | -------------------------------------------------------------------------------- /lama_cleaner/plugins/gfpganer.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from facexlib.utils.face_restoration_helper import FaceRestoreHelper 5 | from gfpgan import GFPGANv1Clean, GFPGANer 6 | from torch.hub import get_dir 7 | 8 | 9 | class MyGFPGANer(GFPGANer): 10 | """Helper for restoration with GFPGAN. 11 | 12 | It will detect and crop faces, and then resize the faces to 512x512. 13 | GFPGAN is used to restored the resized faces. 14 | The background is upsampled with the bg_upsampler. 15 | Finally, the faces will be pasted back to the upsample background image. 16 | 17 | Args: 18 | model_path (str): The path to the GFPGAN model. It can be urls (will first download it automatically). 19 | upscale (float): The upscale of the final output. Default: 2. 20 | arch (str): The GFPGAN architecture. Option: clean | original. Default: clean. 21 | channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2. 22 | bg_upsampler (nn.Module): The upsampler for the background. Default: None. 23 | """ 24 | 25 | def __init__( 26 | self, 27 | model_path, 28 | upscale=2, 29 | arch="clean", 30 | channel_multiplier=2, 31 | bg_upsampler=None, 32 | device=None, 33 | ): 34 | self.upscale = upscale 35 | self.bg_upsampler = bg_upsampler 36 | 37 | # initialize model 38 | self.device = ( 39 | torch.device("cuda" if torch.cuda.is_available() else "cpu") 40 | if device is None 41 | else device 42 | ) 43 | # initialize the GFP-GAN 44 | if arch == "clean": 45 | self.gfpgan = GFPGANv1Clean( 46 | out_size=512, 47 | num_style_feat=512, 48 | channel_multiplier=channel_multiplier, 49 | decoder_load_path=None, 50 | fix_decoder=False, 51 | num_mlp=8, 52 | input_is_latent=True, 53 | different_w=True, 54 | narrow=1, 55 | sft_half=True, 56 | ) 57 | elif arch == "RestoreFormer": 58 | from gfpgan.archs.restoreformer_arch import RestoreFormer 59 | 60 | self.gfpgan = RestoreFormer() 61 | 62 | hub_dir = get_dir() 63 | model_dir = os.path.join(hub_dir, "checkpoints") 64 | 65 | # initialize face helper 66 | self.face_helper = FaceRestoreHelper( 67 | upscale, 68 | face_size=512, 69 | crop_ratio=(1, 1), 70 | det_model="retinaface_resnet50", 71 | save_ext="png", 72 | use_parse=True, 73 | device=self.device, 74 | model_rootpath=model_dir, 75 | ) 76 | 77 | loadnet = torch.load(model_path) 78 | if "params_ema" in loadnet: 79 | keyname = "params_ema" 80 | else: 81 | keyname = "params" 82 | self.gfpgan.load_state_dict(loadnet[keyname], strict=True) 83 | self.gfpgan.eval() 84 | self.gfpgan = self.gfpgan.to(self.device) 85 | -------------------------------------------------------------------------------- /lama_cleaner/plugins/gif.py: -------------------------------------------------------------------------------- 1 | import io 2 | import math 3 | 4 | from PIL import Image, ImageDraw 5 | 6 | from lama_cleaner.helper import load_img 7 | from lama_cleaner.plugins.base_plugin import BasePlugin 8 | 9 | 10 | def keep_ratio_resize(img, size, resample=Image.BILINEAR): 11 | if img.width > img.height: 12 | w = size 13 | h = int(img.height * size / img.width) 14 | else: 15 | h = size 16 | w = int(img.width * size / img.height) 17 | return img.resize((w, h), resample) 18 | 19 | 20 | def cubic_bezier(p1, p2, duration: int, frames: int): 21 | """ 22 | 23 | Args: 24 | p1: 25 | p2: 26 | duration: Total duration of the curve 27 | frames: 28 | 
29 | Returns: 30 | 31 | """ 32 | x0, y0 = (0, 0) 33 | x1, y1 = p1 34 | x2, y2 = p2 35 | x3, y3 = (1, 1) 36 | 37 | def cal_y(t): 38 | return ( 39 | math.pow(1 - t, 3) * y0 40 | + 3 * math.pow(1 - t, 2) * t * y1 41 | + 3 * (1 - t) * math.pow(t, 2) * y2 42 | + math.pow(t, 3) * y3 43 | ) 44 | 45 | def cal_x(t): 46 | return ( 47 | math.pow(1 - t, 3) * x0 48 | + 3 * math.pow(1 - t, 2) * t * x1 49 | + 3 * (1 - t) * math.pow(t, 2) * x2 50 | + math.pow(t, 3) * x3 51 | ) 52 | 53 | res = [] 54 | for t in range(0, 1 * frames, duration): 55 | t = t / frames 56 | res.append((cal_x(t), cal_y(t))) 57 | 58 | res.append((1, 0)) 59 | return res 60 | 61 | 62 | def make_compare_gif( 63 | clean_img: Image.Image, 64 | src_img: Image.Image, 65 | max_side_length: int = 600, 66 | splitter_width: int = 5, 67 | splitter_color=(255, 203, 0, int(255 * 0.73)), 68 | ): 69 | if clean_img.size != src_img.size: 70 | clean_img = clean_img.resize(src_img.size, Image.BILINEAR) 71 | 72 | duration_per_frame = 20 73 | num_frames = 50 74 | # erase-in-out 75 | cubic_bezier_points = cubic_bezier((0.33, 0), (0.66, 1), 1, num_frames) 76 | cubic_bezier_points.reverse() 77 | 78 | max_side_length = min(max_side_length, max(clean_img.size)) 79 | 80 | src_img = keep_ratio_resize(src_img, max_side_length) 81 | clean_img = keep_ratio_resize(clean_img, max_side_length) 82 | width, height = src_img.size 83 | 84 | # Generate images to make Gif from right to left 85 | images = [] 86 | 87 | for i in range(num_frames): 88 | new_frame = Image.new("RGB", (width, height)) 89 | new_frame.paste(clean_img, (0, 0)) 90 | 91 | left = int(cubic_bezier_points[i][0] * width) 92 | cropped_src_img = src_img.crop((left, 0, width, height)) 93 | new_frame.paste(cropped_src_img, (left, 0, width, height)) 94 | if i != num_frames - 1: 95 | # draw a yellow splitter on the edge of the cropped image 96 | draw = ImageDraw.Draw(new_frame) 97 | draw.line( 98 | [(left, 0), (left, height)], width=splitter_width, fill=splitter_color 99 | ) 100 | images.append(new_frame) 101 | 102 | for i in range(30): 103 | images.append(src_img) 104 | 105 | cubic_bezier_points.reverse() 106 | # Generate images to make Gif from left to right 107 | for i in range(num_frames): 108 | new_frame = Image.new("RGB", (width, height)) 109 | new_frame.paste(src_img, (0, 0)) 110 | 111 | right = int(cubic_bezier_points[i][0] * width) 112 | cropped_src_img = clean_img.crop((0, 0, right, height)) 113 | new_frame.paste(cropped_src_img, (0, 0, right, height)) 114 | if i != num_frames - 1: 115 | # draw a yellow splitter on the edge of the cropped image 116 | draw = ImageDraw.Draw(new_frame) 117 | draw.line( 118 | [(right, 0), (right, height)], width=splitter_width, fill=splitter_color 119 | ) 120 | images.append(new_frame) 121 | 122 | for _ in range(30): 123 | images.append(clean_img) 124 | 125 | img_byte_arr = io.BytesIO() 126 | clean_img.save( 127 | img_byte_arr, 128 | format="GIF", 129 | save_all=True, 130 | include_color_table=True, 131 | append_images=images, 132 | optimize=False, 133 | duration=duration_per_frame, 134 | loop=0, 135 | ) 136 | return img_byte_arr.getvalue() 137 | 138 | 139 | class MakeGIF(BasePlugin): 140 | name = "MakeGIF" 141 | 142 | def __call__(self, rgb_np_img, files, form): 143 | origin_image = rgb_np_img 144 | clean_image_bytes = files["clean_img"].read() 145 | clean_image, _ = load_img(clean_image_bytes) 146 | gif_bytes = make_compare_gif( 147 | Image.fromarray(origin_image), Image.fromarray(clean_image) 148 | ) 149 | return gif_bytes 150 | 
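Illustrative usage sketch (not part of the repository): assuming two hypothetical local files source.png and cleaned.png and an installed Pillow, the make_compare_gif helper above could be driven directly like this:

from PIL import Image

from lama_cleaner.plugins.gif import make_compare_gif

src_img = Image.open("source.png").convert("RGB")      # original input image
clean_img = Image.open("cleaned.png").convert("RGB")   # inpainted result

# make_compare_gif returns the encoded GIF as raw bytes, ready to write to disk.
gif_bytes = make_compare_gif(clean_img=clean_img, src_img=src_img)
with open("compare.gif", "wb") as f:
    f.write(gif_bytes)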
-------------------------------------------------------------------------------- /lama_cleaner/plugins/interactive_seg.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import cv2 4 | import numpy as np 5 | from loguru import logger 6 | 7 | from lama_cleaner.helper import download_model 8 | from lama_cleaner.plugins.base_plugin import BasePlugin 9 | from lama_cleaner.plugins.segment_anything import SamPredictor, sam_model_registry 10 | 11 | # 从小到大 12 | SEGMENT_ANYTHING_MODELS = { 13 | "vit_b": { 14 | "url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth", 15 | "md5": "01ec64d29a2fca3f0661936605ae66f8", 16 | }, 17 | "vit_l": { 18 | "url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth", 19 | "md5": "0b3195507c641ddb6910d2bb5adee89c", 20 | }, 21 | "vit_h": { 22 | "url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", 23 | "md5": "4b8939a88964f0f4ff5f5b2642c598a6", 24 | }, 25 | } 26 | 27 | 28 | class InteractiveSeg(BasePlugin): 29 | name = "InteractiveSeg" 30 | 31 | def __init__(self, model_name, device): 32 | super().__init__() 33 | model_path = download_model( 34 | SEGMENT_ANYTHING_MODELS[model_name]["url"], 35 | SEGMENT_ANYTHING_MODELS[model_name]["md5"], 36 | ) 37 | logger.info(f"SegmentAnything model path: {model_path}") 38 | self.predictor = SamPredictor( 39 | sam_model_registry[model_name](checkpoint=model_path).to(device) 40 | ) 41 | self.prev_img_md5 = None 42 | 43 | def __call__(self, rgb_np_img, files, form): 44 | clicks = json.loads(form["clicks"]) 45 | return self.forward(rgb_np_img, clicks, form["img_md5"]) 46 | 47 | def forward(self, rgb_np_img, clicks, img_md5): 48 | input_point = [] 49 | input_label = [] 50 | for click in clicks: 51 | x = click[0] 52 | y = click[1] 53 | input_point.append([x, y]) 54 | input_label.append(click[2]) 55 | 56 | if img_md5 and img_md5 != self.prev_img_md5: 57 | self.prev_img_md5 = img_md5 58 | self.predictor.set_image(rgb_np_img) 59 | 60 | masks, scores, _ = self.predictor.predict( 61 | point_coords=np.array(input_point), 62 | point_labels=np.array(input_label), 63 | multimask_output=False, 64 | ) 65 | mask = masks[0].astype(np.uint8) * 255 66 | # TODO: how to set kernel size? 
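# (Hedged note on the TODO above: the fixed 9x9 kernel used below is a heuristic;
# dilating the predicted mask slightly enlarges the selection, presumably so the
# brush-colored overlay built just after covers object edges more generously.
# The kernel size is not scaled with the image resolution.)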
67 | kernel_size = 9 68 | mask = cv2.dilate( 69 | mask, np.ones((kernel_size, kernel_size), np.uint8), iterations=1 70 | ) 71 | # fronted brush color "ffcc00bb" 72 | res_mask = np.zeros((mask.shape[0], mask.shape[1], 4), dtype=np.uint8) 73 | res_mask[mask == 255] = [255, 203, 0, int(255 * 0.73)] 74 | res_mask = cv2.cvtColor(res_mask, cv2.COLOR_BGRA2RGBA) 75 | return res_mask 76 | -------------------------------------------------------------------------------- /lama_cleaner/plugins/realesrgan.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | import cv2 4 | from loguru import logger 5 | 6 | from lama_cleaner.const import RealESRGANModelName 7 | from lama_cleaner.helper import download_model 8 | from lama_cleaner.plugins.base_plugin import BasePlugin 9 | 10 | 11 | class RealESRGANUpscaler(BasePlugin): 12 | name = "RealESRGAN" 13 | 14 | def __init__(self, name, device, no_half=False): 15 | super().__init__() 16 | from basicsr.archs.rrdbnet_arch import RRDBNet 17 | from realesrgan import RealESRGANer 18 | from realesrgan.archs.srvgg_arch import SRVGGNetCompact 19 | 20 | REAL_ESRGAN_MODELS = { 21 | RealESRGANModelName.realesr_general_x4v3: { 22 | "url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth", 23 | "scale": 4, 24 | "model": lambda: SRVGGNetCompact( 25 | num_in_ch=3, 26 | num_out_ch=3, 27 | num_feat=64, 28 | num_conv=32, 29 | upscale=4, 30 | act_type="prelu", 31 | ), 32 | "model_md5": "91a7644643c884ee00737db24e478156", 33 | }, 34 | RealESRGANModelName.RealESRGAN_x4plus: { 35 | "url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth", 36 | "scale": 4, 37 | "model": lambda: RRDBNet( 38 | num_in_ch=3, 39 | num_out_ch=3, 40 | num_feat=64, 41 | num_block=23, 42 | num_grow_ch=32, 43 | scale=4, 44 | ), 45 | "model_md5": "99ec365d4afad750833258a1a24f44ca", 46 | }, 47 | RealESRGANModelName.RealESRGAN_x4plus_anime_6B: { 48 | "url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth", 49 | "scale": 4, 50 | "model": lambda: RRDBNet( 51 | num_in_ch=3, 52 | num_out_ch=3, 53 | num_feat=64, 54 | num_block=6, 55 | num_grow_ch=32, 56 | scale=4, 57 | ), 58 | "model_md5": "d58ce384064ec1591c2ea7b79dbf47ba", 59 | }, 60 | } 61 | if name not in REAL_ESRGAN_MODELS: 62 | raise ValueError(f"Unknown RealESRGAN model name: {name}") 63 | model_info = REAL_ESRGAN_MODELS[name] 64 | 65 | model_path = download_model(model_info["url"], model_info["model_md5"]) 66 | logger.info(f"RealESRGAN model path: {model_path}") 67 | 68 | self.model = RealESRGANer( 69 | scale=model_info["scale"], 70 | model_path=model_path, 71 | model=model_info["model"](), 72 | half=True if "cuda" in str(device) and not no_half else False, 73 | tile=512, 74 | tile_pad=10, 75 | pre_pad=10, 76 | device=device, 77 | ) 78 | 79 | def __call__(self, rgb_np_img, files, form): 80 | bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR) 81 | scale = float(form["upscale"]) 82 | logger.info(f"RealESRGAN input shape: {bgr_np_img.shape}, scale: {scale}") 83 | result = self.forward(bgr_np_img, scale) 84 | logger.info(f"RealESRGAN output shape: {result.shape}") 85 | return result 86 | 87 | def forward(self, bgr_np_img, scale: float): 88 | # 输出是 BGR 89 | upsampled = self.model.enhance(bgr_np_img, outscale=scale)[0] 90 | return upsampled 91 | 92 | def check_dep(self): 93 | try: 94 | import realesrgan 95 | except ImportError: 96 | return "RealESRGAN is not installed, 
please install it first. pip install realesrgan" 97 | -------------------------------------------------------------------------------- /lama_cleaner/plugins/remove_bg.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | from torch.hub import get_dir 5 | 6 | from lama_cleaner.plugins.base_plugin import BasePlugin 7 | 8 | 9 | class RemoveBG(BasePlugin): 10 | name = "RemoveBG" 11 | 12 | def __init__(self): 13 | super().__init__() 14 | from rembg import new_session 15 | 16 | hub_dir = get_dir() 17 | model_dir = os.path.join(hub_dir, "checkpoints") 18 | os.environ["U2NET_HOME"] = model_dir 19 | 20 | self.session = new_session(model_name="u2net") 21 | 22 | def __call__(self, rgb_np_img, files, form): 23 | bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR) 24 | return self.forward(bgr_np_img) 25 | 26 | def forward(self, bgr_np_img) -> np.ndarray: 27 | from rembg import remove 28 | 29 | # return BGRA image 30 | output = remove(bgr_np_img, session=self.session) 31 | return cv2.cvtColor(output, cv2.COLOR_BGRA2RGBA) 32 | 33 | def check_dep(self): 34 | try: 35 | import rembg 36 | except ImportError: 37 | return ( 38 | "RemoveBG is not installed, please install it first. pip install rembg" 39 | ) 40 | -------------------------------------------------------------------------------- /lama_cleaner/plugins/restoreformer.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | from loguru import logger 3 | 4 | from lama_cleaner.helper import download_model 5 | from lama_cleaner.plugins.base_plugin import BasePlugin 6 | 7 | 8 | class RestoreFormerPlugin(BasePlugin): 9 | name = "RestoreFormer" 10 | 11 | def __init__(self, device, upscaler=None): 12 | super().__init__() 13 | from .gfpganer import MyGFPGANer 14 | 15 | url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth" 16 | model_md5 = "eaeeff6c4a1caa1673977cb374e6f699" 17 | model_path = download_model(url, model_md5) 18 | logger.info(f"RestoreFormer model path: {model_path}") 19 | 20 | import facexlib 21 | 22 | if hasattr(facexlib.detection.retinaface, "device"): 23 | facexlib.detection.retinaface.device = device 24 | 25 | self.face_enhancer = MyGFPGANer( 26 | model_path=model_path, 27 | upscale=1, 28 | arch="RestoreFormer", 29 | channel_multiplier=2, 30 | device=device, 31 | bg_upsampler=upscaler.model if upscaler is not None else None, 32 | ) 33 | 34 | def __call__(self, rgb_np_img, files, form): 35 | weight = 0.5 36 | bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR) 37 | logger.info(f"RestoreFormer input shape: {bgr_np_img.shape}") 38 | _, _, bgr_output = self.face_enhancer.enhance( 39 | bgr_np_img, 40 | has_aligned=False, 41 | only_center_face=False, 42 | paste_back=True, 43 | weight=weight, 44 | ) 45 | logger.info(f"RestoreFormer output shape: {bgr_output.shape}") 46 | return bgr_output 47 | 48 | def check_dep(self): 49 | try: 50 | import gfpgan 51 | except ImportError: 52 | return ( 53 | "gfpgan is not installed, please install it first. pip install gfpgan" 54 | ) 55 | -------------------------------------------------------------------------------- /lama_cleaner/plugins/segment_anything/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .build_sam import ( 8 | build_sam, 9 | build_sam_vit_h, 10 | build_sam_vit_l, 11 | build_sam_vit_b, 12 | sam_model_registry, 13 | ) 14 | from .predictor import SamPredictor 15 | -------------------------------------------------------------------------------- /lama_cleaner/plugins/segment_anything/build_sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | 9 | from functools import partial 10 | 11 | from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer 12 | 13 | 14 | def build_sam_vit_h(checkpoint=None): 15 | return _build_sam( 16 | encoder_embed_dim=1280, 17 | encoder_depth=32, 18 | encoder_num_heads=16, 19 | encoder_global_attn_indexes=[7, 15, 23, 31], 20 | checkpoint=checkpoint, 21 | ) 22 | 23 | 24 | build_sam = build_sam_vit_h 25 | 26 | 27 | def build_sam_vit_l(checkpoint=None): 28 | return _build_sam( 29 | encoder_embed_dim=1024, 30 | encoder_depth=24, 31 | encoder_num_heads=16, 32 | encoder_global_attn_indexes=[5, 11, 17, 23], 33 | checkpoint=checkpoint, 34 | ) 35 | 36 | 37 | def build_sam_vit_b(checkpoint=None): 38 | return _build_sam( 39 | encoder_embed_dim=768, 40 | encoder_depth=12, 41 | encoder_num_heads=12, 42 | encoder_global_attn_indexes=[2, 5, 8, 11], 43 | checkpoint=checkpoint, 44 | ) 45 | 46 | 47 | sam_model_registry = { 48 | "default": build_sam, 49 | "vit_h": build_sam, 50 | "vit_l": build_sam_vit_l, 51 | "vit_b": build_sam_vit_b, 52 | } 53 | 54 | 55 | def _build_sam( 56 | encoder_embed_dim, 57 | encoder_depth, 58 | encoder_num_heads, 59 | encoder_global_attn_indexes, 60 | checkpoint=None, 61 | ): 62 | prompt_embed_dim = 256 63 | image_size = 1024 64 | vit_patch_size = 16 65 | image_embedding_size = image_size // vit_patch_size 66 | sam = Sam( 67 | image_encoder=ImageEncoderViT( 68 | depth=encoder_depth, 69 | embed_dim=encoder_embed_dim, 70 | img_size=image_size, 71 | mlp_ratio=4, 72 | norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), 73 | num_heads=encoder_num_heads, 74 | patch_size=vit_patch_size, 75 | qkv_bias=True, 76 | use_rel_pos=True, 77 | global_attn_indexes=encoder_global_attn_indexes, 78 | window_size=14, 79 | out_chans=prompt_embed_dim, 80 | ), 81 | prompt_encoder=PromptEncoder( 82 | embed_dim=prompt_embed_dim, 83 | image_embedding_size=(image_embedding_size, image_embedding_size), 84 | input_image_size=(image_size, image_size), 85 | mask_in_chans=16, 86 | ), 87 | mask_decoder=MaskDecoder( 88 | num_multimask_outputs=3, 89 | transformer=TwoWayTransformer( 90 | depth=2, 91 | embedding_dim=prompt_embed_dim, 92 | mlp_dim=2048, 93 | num_heads=8, 94 | ), 95 | transformer_dim=prompt_embed_dim, 96 | iou_head_depth=3, 97 | iou_head_hidden_dim=256, 98 | ), 99 | pixel_mean=[123.675, 116.28, 103.53], 100 | pixel_std=[58.395, 57.12, 57.375], 101 | ) 102 | sam.eval() 103 | if checkpoint is not None: 104 | with open(checkpoint, "rb") as f: 105 | state_dict = torch.load(f) 106 | sam.load_state_dict(state_dict) 107 | return sam 108 | -------------------------------------------------------------------------------- /lama_cleaner/plugins/segment_anything/modeling/__init__.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .sam import Sam 8 | from .image_encoder import ImageEncoderViT 9 | from .mask_decoder import MaskDecoder 10 | from .prompt_encoder import PromptEncoder 11 | from .transformer import TwoWayTransformer 12 | -------------------------------------------------------------------------------- /lama_cleaner/plugins/segment_anything/modeling/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | from typing import Type 11 | 12 | 13 | class MLPBlock(nn.Module): 14 | def __init__( 15 | self, 16 | embedding_dim: int, 17 | mlp_dim: int, 18 | act: Type[nn.Module] = nn.GELU, 19 | ) -> None: 20 | super().__init__() 21 | self.lin1 = nn.Linear(embedding_dim, mlp_dim) 22 | self.lin2 = nn.Linear(mlp_dim, embedding_dim) 23 | self.act = act() 24 | 25 | def forward(self, x: torch.Tensor) -> torch.Tensor: 26 | return self.lin2(self.act(self.lin1(x))) 27 | 28 | 29 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa 30 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa 31 | class LayerNorm2d(nn.Module): 32 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None: 33 | super().__init__() 34 | self.weight = nn.Parameter(torch.ones(num_channels)) 35 | self.bias = nn.Parameter(torch.zeros(num_channels)) 36 | self.eps = eps 37 | 38 | def forward(self, x: torch.Tensor) -> torch.Tensor: 39 | u = x.mean(1, keepdim=True) 40 | s = (x - u).pow(2).mean(1, keepdim=True) 41 | x = (x - u) / torch.sqrt(s + self.eps) 42 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 43 | return x 44 | -------------------------------------------------------------------------------- /lama_cleaner/plugins/segment_anything/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /lama_cleaner/plugins/segment_anything/utils/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | from torch.nn import functional as F 10 | from torchvision.transforms.functional import resize, to_pil_image # type: ignore 11 | 12 | from copy import deepcopy 13 | from typing import Tuple 14 | 15 | 16 | class ResizeLongestSide: 17 | """ 18 | Resizes images to longest side 'target_length', as well as provides 19 | methods for resizing coordinates and boxes. 
Provides methods for 20 | transforming both numpy array and batched torch tensors. 21 | """ 22 | 23 | def __init__(self, target_length: int) -> None: 24 | self.target_length = target_length 25 | 26 | def apply_image(self, image: np.ndarray) -> np.ndarray: 27 | """ 28 | Expects a numpy array with shape HxWxC in uint8 format. 29 | """ 30 | target_size = self.get_preprocess_shape( 31 | image.shape[0], image.shape[1], self.target_length 32 | ) 33 | return np.array(resize(to_pil_image(image), target_size)) 34 | 35 | def apply_coords( 36 | self, coords: np.ndarray, original_size: Tuple[int, ...] 37 | ) -> np.ndarray: 38 | """ 39 | Expects a numpy array of length 2 in the final dimension. Requires the 40 | original image size in (H, W) format. 41 | """ 42 | old_h, old_w = original_size 43 | new_h, new_w = self.get_preprocess_shape( 44 | original_size[0], original_size[1], self.target_length 45 | ) 46 | coords = deepcopy(coords).astype(float) 47 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 48 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 49 | return coords 50 | 51 | def apply_boxes( 52 | self, boxes: np.ndarray, original_size: Tuple[int, ...] 53 | ) -> np.ndarray: 54 | """ 55 | Expects a numpy array shape Bx4. Requires the original image size 56 | in (H, W) format. 57 | """ 58 | boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size) 59 | return boxes.reshape(-1, 4) 60 | 61 | def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor: 62 | """ 63 | Expects batched images with shape BxCxHxW and float format. This 64 | transformation may not exactly match apply_image. apply_image is 65 | the transformation expected by the model. 66 | """ 67 | # Expects an image in BCHW format. May not exactly match apply_image. 68 | target_size = self.get_preprocess_shape( 69 | image.shape[0], image.shape[1], self.target_length 70 | ) 71 | return F.interpolate( 72 | image, target_size, mode="bilinear", align_corners=False, antialias=True 73 | ) 74 | 75 | def apply_coords_torch( 76 | self, coords: torch.Tensor, original_size: Tuple[int, ...] 77 | ) -> torch.Tensor: 78 | """ 79 | Expects a torch tensor with length 2 in the last dimension. Requires the 80 | original image size in (H, W) format. 81 | """ 82 | old_h, old_w = original_size 83 | new_h, new_w = self.get_preprocess_shape( 84 | original_size[0], original_size[1], self.target_length 85 | ) 86 | coords = deepcopy(coords).to(torch.float) 87 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 88 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 89 | return coords 90 | 91 | def apply_boxes_torch( 92 | self, boxes: torch.Tensor, original_size: Tuple[int, ...] 93 | ) -> torch.Tensor: 94 | """ 95 | Expects a torch tensor with shape Bx4. Requires the original image 96 | size in (H, W) format. 97 | """ 98 | boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size) 99 | return boxes.reshape(-1, 4) 100 | 101 | @staticmethod 102 | def get_preprocess_shape( 103 | oldh: int, oldw: int, long_side_length: int 104 | ) -> Tuple[int, int]: 105 | """ 106 | Compute the output size given input size and target long side length. 
107 | """ 108 | scale = long_side_length * 1.0 / max(oldh, oldw) 109 | newh, neww = oldh * scale, oldw * scale 110 | neww = int(neww + 0.5) 111 | newh = int(newh + 0.5) 112 | return (newh, neww) 113 | -------------------------------------------------------------------------------- /lama_cleaner/runtime.py: -------------------------------------------------------------------------------- 1 | # https://github.com/huggingface/huggingface_hub/blob/5a12851f54bf614be39614034ed3a9031922d297/src/huggingface_hub/utils/_runtime.py 2 | import platform 3 | import sys 4 | import packaging.version 5 | from rich import print 6 | from typing import Dict, Any 7 | 8 | _PY_VERSION: str = sys.version.split()[0].rstrip("+") 9 | 10 | if packaging.version.Version(_PY_VERSION) < packaging.version.Version("3.8.0"): 11 | import importlib_metadata # type: ignore 12 | else: 13 | import importlib.metadata as importlib_metadata # type: ignore 14 | 15 | _package_versions = {} 16 | 17 | _CANDIDATES = [ 18 | "torch", 19 | "torchvision", 20 | "Pillow", 21 | "diffusers", 22 | "transformers", 23 | "opencv-python", 24 | "xformers", 25 | "accelerate", 26 | "lama-cleaner", 27 | "rembg", 28 | "realesrgan", 29 | "gfpgan", 30 | ] 31 | # Check once at runtime 32 | for name in _CANDIDATES: 33 | _package_versions[name] = "N/A" 34 | try: 35 | _package_versions[name] = importlib_metadata.version(name) 36 | except importlib_metadata.PackageNotFoundError: 37 | pass 38 | 39 | 40 | def dump_environment_info() -> Dict[str, str]: 41 | """Dump information about the machine to help debugging issues. """ 42 | 43 | # Generic machine info 44 | info: Dict[str, Any] = { 45 | "Platform": platform.platform(), 46 | "Python version": platform.python_version(), 47 | } 48 | info.update(_package_versions) 49 | print("\n".join([f"- {prop}: {val}" for prop, val in info.items()]) + "\n") 50 | return info 51 | -------------------------------------------------------------------------------- /lama_cleaner/schema.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | from enum import Enum 3 | 4 | from PIL.Image import Image 5 | from pydantic import BaseModel 6 | 7 | 8 | class HDStrategy(str, Enum): 9 | # Use original image size 10 | ORIGINAL = "Original" 11 | # Resize the longer side of the image to a specific size(hd_strategy_resize_limit), 12 | # then do inpainting on the resized image. Finally, resize the inpainting result to the original size. 13 | # The area outside the mask will not lose quality. 
14 | RESIZE = "Resize" 15 | # Crop masking area(with a margin controlled by hd_strategy_crop_margin) from the original image to do inpainting 16 | CROP = "Crop" 17 | 18 | 19 | class LDMSampler(str, Enum): 20 | ddim = "ddim" 21 | plms = "plms" 22 | 23 | 24 | class SDSampler(str, Enum): 25 | ddim = "ddim" 26 | pndm = "pndm" 27 | k_lms = "k_lms" 28 | k_euler = "k_euler" 29 | k_euler_a = "k_euler_a" 30 | dpm_plus_plus = "dpm++" 31 | uni_pc = "uni_pc" 32 | 33 | 34 | class Config(BaseModel): 35 | class Config: 36 | arbitrary_types_allowed = True 37 | 38 | # Configs for ldm model 39 | ldm_steps: int 40 | ldm_sampler: str = LDMSampler.plms 41 | 42 | # Configs for zits model 43 | zits_wireframe: bool = True 44 | 45 | # Configs for High Resolution Strategy(different way to preprocess image) 46 | hd_strategy: str # See HDStrategy Enum 47 | hd_strategy_crop_margin: int 48 | # If the longer side of the image is larger than this value, use crop strategy 49 | hd_strategy_crop_trigger_size: int 50 | hd_strategy_resize_limit: int 51 | 52 | # Configs for Stable Diffusion 1.5 53 | prompt: str = "" 54 | negative_prompt: str = "" 55 | # Crop image to this size before doing sd inpainting 56 | # The value is always on the original image scale 57 | use_croper: bool = False 58 | croper_x: int = None 59 | croper_y: int = None 60 | croper_height: int = None 61 | croper_width: int = None 62 | 63 | # Resize the image before doing sd inpainting, the area outside the mask will not lose quality. 64 | # Used by sd models and paint_by_example model 65 | sd_scale: float = 1.0 66 | # Blur the edge of mask area. The higher the number the smoother blend with the original image 67 | sd_mask_blur: int = 0 68 | # Ignore this value, it's useless for inpainting 69 | sd_strength: float = 0.75 70 | # The number of denoising steps. More denoising steps usually lead to a 71 | # higher quality image at the expense of slower inference. 72 | sd_steps: int = 50 73 | # Higher guidance scale encourages to generate images that are closely linked 74 | # to the text prompt, usually at the expense of lower image quality. 
75 | sd_guidance_scale: float = 7.5 76 | sd_sampler: str = SDSampler.uni_pc 77 | # -1 mean random seed 78 | sd_seed: int = 42 79 | sd_match_histograms: bool = False 80 | 81 | # Configs for opencv inpainting 82 | # opencv document https://docs.opencv.org/4.6.0/d7/d8b/group__photo__inpaint.html#gga8002a65f5a3328fbf15df81b842d3c3ca05e763003a805e6c11c673a9f4ba7d07 83 | cv2_flag: str = "INPAINT_NS" 84 | cv2_radius: int = 4 85 | 86 | # Paint by Example 87 | paint_by_example_steps: int = 50 88 | paint_by_example_guidance_scale: float = 7.5 89 | paint_by_example_mask_blur: int = 0 90 | paint_by_example_seed: int = 42 91 | paint_by_example_match_histograms: bool = False 92 | paint_by_example_example_image: Optional[Image] = None 93 | 94 | # InstructPix2Pix 95 | p2p_steps: int = 50 96 | p2p_image_guidance_scale: float = 1.5 97 | p2p_guidance_scale: float = 7.5 98 | 99 | # ControlNet 100 | controlnet_conditioning_scale: float = 0.4 101 | controlnet_method: str = "control_v11p_sd15_canny" 102 | -------------------------------------------------------------------------------- /lama_cleaner/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Uminosachi/inpaint-anything/c2204930c48a4b2b65bc7835ff527ebb72de5be9/lama_cleaner/tests/__init__.py -------------------------------------------------------------------------------- /lama_cleaner/tests/test_instruct_pix2pix.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import pytest 4 | import torch 5 | 6 | from lama_cleaner.model_manager import ModelManager 7 | from lama_cleaner.tests.test_model import get_config, assert_equal 8 | from lama_cleaner.schema import HDStrategy 9 | 10 | current_dir = Path(__file__).parent.absolute().resolve() 11 | save_dir = current_dir / 'result' 12 | save_dir.mkdir(exist_ok=True, parents=True) 13 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 14 | 15 | 16 | @pytest.mark.parametrize("disable_nsfw", [True, False]) 17 | @pytest.mark.parametrize("cpu_offload", [False, True]) 18 | def test_instruct_pix2pix(disable_nsfw, cpu_offload): 19 | sd_steps = 50 if device == 'cuda' else 1 20 | model = ModelManager(name="instruct_pix2pix", 21 | device=torch.device(device), 22 | hf_access_token="", 23 | sd_run_local=False, 24 | disable_nsfw=disable_nsfw, 25 | sd_cpu_textencoder=False, 26 | cpu_offload=cpu_offload) 27 | cfg = get_config(strategy=HDStrategy.ORIGINAL, prompt='What if it were snowing?', p2p_steps=sd_steps, sd_scale=1.1) 28 | 29 | name = f"device_{device}_disnsfw_{disable_nsfw}_cpu_offload_{cpu_offload}" 30 | 31 | assert_equal( 32 | model, 33 | cfg, 34 | f"instruct_pix2pix_{name}.png", 35 | img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", 36 | mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", 37 | fx=1.3 38 | ) 39 | 40 | 41 | @pytest.mark.parametrize("disable_nsfw", [False]) 42 | @pytest.mark.parametrize("cpu_offload", [False]) 43 | def test_instruct_pix2pix_snow(disable_nsfw, cpu_offload): 44 | sd_steps = 50 if device == 'cuda' else 1 45 | model = ModelManager(name="instruct_pix2pix", 46 | device=torch.device(device), 47 | hf_access_token="", 48 | sd_run_local=False, 49 | disable_nsfw=disable_nsfw, 50 | sd_cpu_textencoder=False, 51 | cpu_offload=cpu_offload) 52 | cfg = get_config(strategy=HDStrategy.ORIGINAL, prompt='What if it were snowing?', p2p_steps=sd_steps) 53 | 54 | name = f"snow" 55 | 56 | assert_equal( 57 | model, 58 | cfg, 59 | 
f"instruct_pix2pix_{name}.png", 60 | img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", 61 | mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", 62 | ) 63 | -------------------------------------------------------------------------------- /lama_cleaner/tests/test_interactive_seg.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import cv2 4 | import numpy as np 5 | 6 | from lama_cleaner.plugins import InteractiveSeg, Click 7 | 8 | current_dir = Path(__file__).parent.absolute().resolve() 9 | save_dir = current_dir / "result" 10 | save_dir.mkdir(exist_ok=True, parents=True) 11 | img_p = current_dir / "overture-creations-5sI6fQgYIuo.png" 12 | 13 | 14 | def test_interactive_seg(): 15 | interactive_seg_model = InteractiveSeg() 16 | img = cv2.imread(str(img_p)) 17 | pred = interactive_seg_model.forward( 18 | img, clicks=[Click(coords=(256, 256), indx=0, is_positive=True)] 19 | ) 20 | cv2.imwrite(str(save_dir / "test_interactive_seg.png"), pred) 21 | 22 | 23 | def test_interactive_seg_with_negative_click(): 24 | interactive_seg_model = InteractiveSeg() 25 | img = cv2.imread(str(img_p)) 26 | pred = interactive_seg_model.forward( 27 | img, 28 | clicks=[ 29 | Click(coords=(256, 256), indx=0, is_positive=True), 30 | Click(coords=(384, 256), indx=1, is_positive=False), 31 | ], 32 | ) 33 | cv2.imwrite(str(save_dir / "test_interactive_seg_negative.png"), pred) 34 | 35 | 36 | def test_interactive_seg_with_prev_mask(): 37 | interactive_seg_model = InteractiveSeg() 38 | img = cv2.imread(str(img_p)) 39 | mask = np.zeros_like(img)[:, :, 0] 40 | pred = interactive_seg_model.forward( 41 | img, clicks=[Click(coords=(256, 256), indx=0, is_positive=True)], prev_mask=mask 42 | ) 43 | cv2.imwrite(str(save_dir / "test_interactive_seg_with_mask.png"), pred) 44 | -------------------------------------------------------------------------------- /lama_cleaner/tests/test_load_img.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from lama_cleaner.helper import load_img 4 | 5 | current_dir = Path(__file__).parent.absolute().resolve() 6 | png_img_p = current_dir / "image.png" 7 | jpg_img_p = current_dir / "bunny.jpeg" 8 | 9 | 10 | def test_load_png_image(): 11 | with open(png_img_p, "rb") as f: 12 | np_img, alpha_channel = load_img(f.read()) 13 | assert np_img.shape == (256, 256, 3) 14 | assert alpha_channel.shape == (256, 256) 15 | 16 | 17 | def test_load_jpg_image(): 18 | with open(jpg_img_p, "rb") as f: 19 | np_img, alpha_channel = load_img(f.read()) 20 | assert np_img.shape == (394, 448, 3) 21 | assert alpha_channel is None 22 | -------------------------------------------------------------------------------- /lama_cleaner/tests/test_model_md5.py: -------------------------------------------------------------------------------- 1 | def test_load_model(): 2 | from lama_cleaner.plugins import InteractiveSeg 3 | from lama_cleaner.model_manager import ModelManager 4 | 5 | interactive_seg_model = InteractiveSeg('vit_l', 'cpu') 6 | 7 | models = [ 8 | "lama", 9 | "ldm", 10 | "zits", 11 | "mat", 12 | "fcf", 13 | "manga", 14 | ] 15 | for m in models: 16 | ModelManager( 17 | name=m, 18 | device="cpu", 19 | no_half=False, 20 | hf_access_token="", 21 | disable_nsfw=False, 22 | sd_cpu_textencoder=True, 23 | sd_run_local=True, 24 | local_files_only=True, 25 | cpu_offload=True, 26 | enable_xformers=False, 27 | ) 28 | 29 | 30 | # def create_empty_file(tmp_dir, name): 31 | # 
tmp_model_dir = os.path.join(tmp_dir, "torch", "hub", "checkpoints") 32 | # Path(tmp_model_dir).mkdir(exist_ok=True, parents=True) 33 | # path = os.path.join(tmp_model_dir, name) 34 | # with open(path, "w") as f: 35 | # f.write("1") 36 | # 37 | # 38 | # def test_load_model_error(): 39 | # MODELS = [ 40 | # ("big-lama.pt", "e3aa4aaa15225a33ec84f9f4bc47e500"), 41 | # ("cond_stage_model_encode.pt", "23239fc9081956a3e70de56472b3f296"), 42 | # ("cond_stage_model_decode.pt", "fe419cd15a750d37a4733589d0d3585c"), 43 | # ("diffusion.pt", "b0afda12bf790c03aba2a7431f11d22d"), 44 | # ] 45 | # with tempfile.TemporaryDirectory() as tmp_dir: 46 | # os.environ["XDG_CACHE_HOME"] = tmp_dir 47 | # for name, md5 in MODELS: 48 | # create_empty_file(tmp_dir, name) 49 | # test_load_model() 50 | -------------------------------------------------------------------------------- /lama_cleaner/tests/test_paint_by_example.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import cv2 4 | import pytest 5 | import torch 6 | from PIL import Image 7 | 8 | from lama_cleaner.model_manager import ModelManager 9 | from lama_cleaner.schema import HDStrategy 10 | from lama_cleaner.tests.test_model import get_config, get_data 11 | 12 | current_dir = Path(__file__).parent.absolute().resolve() 13 | save_dir = current_dir / 'result' 14 | save_dir.mkdir(exist_ok=True, parents=True) 15 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 16 | device = torch.device(device) 17 | 18 | 19 | def assert_equal( 20 | model, config, gt_name, 21 | fx: float = 1, fy: float = 1, 22 | img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", 23 | mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", 24 | example_p=current_dir / "bunny.jpeg", 25 | ): 26 | img, mask = get_data(fx=fx, fy=fy, img_p=img_p, mask_p=mask_p) 27 | 28 | example_image = cv2.imread(str(example_p)) 29 | example_image = cv2.cvtColor(example_image, cv2.COLOR_BGRA2RGB) 30 | example_image = cv2.resize(example_image, None, fx=fx, fy=fy, interpolation=cv2.INTER_AREA) 31 | 32 | print(f"Input image shape: {img.shape}, example_image: {example_image.shape}") 33 | config.paint_by_example_example_image = Image.fromarray(example_image) 34 | res = model(img, mask, config) 35 | cv2.imwrite(str(save_dir / gt_name), res) 36 | 37 | 38 | @pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL]) 39 | def test_paint_by_example(strategy): 40 | model = ModelManager(name="paint_by_example", device=device, disable_nsfw=True) 41 | cfg = get_config(strategy, paint_by_example_steps=30) 42 | assert_equal( 43 | model, 44 | cfg, 45 | f"paint_by_example_{strategy.capitalize()}.png", 46 | img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", 47 | mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", 48 | fy=0.9, 49 | fx=1.3, 50 | ) 51 | 52 | 53 | @pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL]) 54 | def test_paint_by_example_disable_nsfw(strategy): 55 | model = ModelManager(name="paint_by_example", device=device, disable_nsfw=False) 56 | cfg = get_config(strategy, paint_by_example_steps=30) 57 | assert_equal( 58 | model, 59 | cfg, 60 | f"paint_by_example_{strategy.capitalize()}_disable_nsfw.png", 61 | img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", 62 | mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", 63 | ) 64 | 65 | 66 | @pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL]) 67 | def test_paint_by_example_sd_scale(strategy): 68 | model = 
ModelManager(name="paint_by_example", device=device, disable_nsfw=True) 69 | cfg = get_config(strategy, paint_by_example_steps=30, sd_scale=0.85) 70 | assert_equal( 71 | model, 72 | cfg, 73 | f"paint_by_example_{strategy.capitalize()}_sdscale.png", 74 | img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", 75 | mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", 76 | fy=0.9, 77 | fx=1.3 78 | ) 79 | 80 | 81 | @pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL]) 82 | def test_paint_by_example_cpu_offload(strategy): 83 | model = ModelManager(name="paint_by_example", device=device, cpu_offload=True, disable_nsfw=False) 84 | cfg = get_config(strategy, paint_by_example_steps=30, sd_scale=0.85) 85 | assert_equal( 86 | model, 87 | cfg, 88 | f"paint_by_example_{strategy.capitalize()}_cpu_offload.png", 89 | img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", 90 | mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", 91 | ) 92 | 93 | 94 | @pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL]) 95 | def test_paint_by_example_cpu_offload_cpu_device(strategy): 96 | model = ModelManager(name="paint_by_example", device=torch.device('cpu'), cpu_offload=True, disable_nsfw=True) 97 | cfg = get_config(strategy, paint_by_example_steps=1, sd_scale=0.85) 98 | assert_equal( 99 | model, 100 | cfg, 101 | f"paint_by_example_{strategy.capitalize()}_cpu_offload_cpu_device.png", 102 | img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", 103 | mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", 104 | fy=0.9, 105 | fx=1.3 106 | ) 107 | -------------------------------------------------------------------------------- /lama_cleaner/tests/test_plugins.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import time 4 | 5 | from lama_cleaner.plugins.anime_seg import AnimeSeg 6 | 7 | os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" 8 | from pathlib import Path 9 | 10 | import cv2 11 | import pytest 12 | import torch.cuda 13 | 14 | from lama_cleaner.plugins import ( 15 | RemoveBG, 16 | RealESRGANUpscaler, 17 | GFPGANPlugin, 18 | RestoreFormerPlugin, 19 | InteractiveSeg, 20 | ) 21 | 22 | current_dir = Path(__file__).parent.absolute().resolve() 23 | save_dir = current_dir / "result" 24 | save_dir.mkdir(exist_ok=True, parents=True) 25 | img_p = current_dir / "bunny.jpeg" 26 | img_bytes = open(img_p, "rb").read() 27 | bgr_img = cv2.imread(str(img_p)) 28 | rgb_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB) 29 | 30 | 31 | def _save(img, name): 32 | cv2.imwrite(str(save_dir / name), img) 33 | 34 | 35 | def test_remove_bg(): 36 | model = RemoveBG() 37 | res = model.forward(bgr_img) 38 | res = cv2.cvtColor(res, cv2.COLOR_RGBA2BGRA) 39 | _save(res, "test_remove_bg.png") 40 | 41 | 42 | def test_anime_seg(): 43 | model = AnimeSeg() 44 | img = cv2.imread(str(current_dir / "anime_test.png")) 45 | res = model.forward(img) 46 | assert len(res.shape) == 3 47 | assert res.shape[-1] == 4 48 | _save(res, "test_anime_seg.png") 49 | 50 | 51 | @pytest.mark.parametrize("device", ["cuda", "cpu", "mps"]) 52 | def test_upscale(device): 53 | if device == "cuda" and not torch.cuda.is_available(): 54 | return 55 | if device == "mps" and not torch.backends.mps.is_available(): 56 | return 57 | 58 | model = RealESRGANUpscaler("realesr-general-x4v3", device) 59 | res = model.forward(bgr_img, 2) 60 | _save(res, f"test_upscale_x2_{device}.png") 61 | 62 | res = model.forward(bgr_img, 4) 63 | _save(res, f"test_upscale_x4_{device}.png") 64 
| 65 | 66 | @pytest.mark.parametrize("device", ["cuda", "cpu", "mps"]) 67 | def test_gfpgan(device): 68 | if device == "cuda" and not torch.cuda.is_available(): 69 | return 70 | if device == "mps" and not torch.backends.mps.is_available(): 71 | return 72 | model = GFPGANPlugin(device) 73 | res = model(rgb_img, None, None) 74 | _save(res, f"test_gfpgan_{device}.png") 75 | 76 | 77 | @pytest.mark.parametrize("device", ["cuda", "cpu", "mps"]) 78 | def test_restoreformer(device): 79 | if device == "cuda" and not torch.cuda.is_available(): 80 | return 81 | if device == "mps" and not torch.backends.mps.is_available(): 82 | return 83 | model = RestoreFormerPlugin(device) 84 | res = model(rgb_img, None, None) 85 | _save(res, f"test_restoreformer_{device}.png") 86 | 87 | 88 | @pytest.mark.parametrize("device", ["cuda", "cpu", "mps"]) 89 | def test_segment_anything(device): 90 | if device == "cuda" and not torch.cuda.is_available(): 91 | return 92 | if device == "mps" and not torch.backends.mps.is_available(): 93 | return 94 | img_md5 = hashlib.md5(img_bytes).hexdigest() 95 | model = InteractiveSeg("vit_l", device) 96 | new_mask = model.forward(rgb_img, [[448 // 2, 394 // 2, 1]], img_md5) 97 | 98 | save_name = f"test_segment_anything_{device}.png" 99 | _save(new_mask, save_name) 100 | 101 | start = time.time() 102 | model.forward(rgb_img, [[448 // 2, 394 // 2, 1]], img_md5) 103 | print(f"Time for {save_name}: {time.time() - start:.2f}s") 104 | -------------------------------------------------------------------------------- /lama_cleaner/tests/test_save_exif.py: -------------------------------------------------------------------------------- 1 | import io 2 | from pathlib import Path 3 | 4 | from PIL import Image 5 | 6 | from lama_cleaner.helper import pil_to_bytes, load_img 7 | 8 | current_dir = Path(__file__).parent.absolute().resolve() 9 | 10 | 11 | def print_exif(exif): 12 | for k, v in exif.items(): 13 | print(f"{k}: {v}") 14 | 15 | 16 | def run_test(img_p: Path): 17 | print(img_p) 18 | ext = img_p.suffix.strip(".") 19 | img_bytes = img_p.read_bytes() 20 | np_img, _, exif_infos = load_img(img_bytes, False, True) 21 | print(exif_infos) 22 | print("Original exif_infos") 23 | print_exif(exif_infos["exif"]) 24 | 25 | pil_to_bytes(Image.fromarray(np_img), ext=ext, exif_infos={}) 26 | 27 | pil_bytes = pil_to_bytes(Image.fromarray(np_img), ext=ext, exif_infos=exif_infos) 28 | res_img = Image.open(io.BytesIO(pil_bytes)) 29 | print(f"Result img info: {res_img.info}") 30 | res_exif = res_img.getexif() 31 | print_exif(res_exif) 32 | assert res_exif == exif_infos["exif"] 33 | assert exif_infos["parameters"] == res_img.info.get("parameters") 34 | 35 | 36 | def test_png(): 37 | run_test(current_dir / "image.png") 38 | run_test(current_dir / "pnginfo_test.png") 39 | 40 | 41 | def test_jpeg(): 42 | jpg_img_p = current_dir / "bunny.jpeg" 43 | run_test(jpg_img_p) 44 | -------------------------------------------------------------------------------- /mobile_sam/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import warnings 8 | 9 | warnings.filterwarnings("ignore", category=UserWarning, module="mobile_sam") 10 | 11 | from .automatic_mask_generator import SamAutomaticMaskGenerator # noqa: E402 12 | from .build_sam import (build_sam, build_sam_vit_b, build_sam_vit_h, build_sam_vit_l, # noqa: E402 13 | build_sam_vit_t, sam_model_registry) 14 | from .predictor import SamPredictor # noqa: E402 15 | 16 | __all__ = [ 17 | "build_sam", 18 | "build_sam_vit_h", 19 | "build_sam_vit_l", 20 | "build_sam_vit_b", 21 | "build_sam_vit_t", 22 | "sam_model_registry", 23 | "SamPredictor", 24 | "SamAutomaticMaskGenerator", 25 | ] 26 | -------------------------------------------------------------------------------- /mobile_sam/build_sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | 9 | from functools import partial 10 | 11 | from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer, TinyViT 12 | 13 | 14 | def build_sam_vit_h(checkpoint=None): 15 | return _build_sam( 16 | encoder_embed_dim=1280, 17 | encoder_depth=32, 18 | encoder_num_heads=16, 19 | encoder_global_attn_indexes=[7, 15, 23, 31], 20 | checkpoint=checkpoint, 21 | ) 22 | 23 | 24 | build_sam = build_sam_vit_h 25 | 26 | 27 | def build_sam_vit_l(checkpoint=None): 28 | return _build_sam( 29 | encoder_embed_dim=1024, 30 | encoder_depth=24, 31 | encoder_num_heads=16, 32 | encoder_global_attn_indexes=[5, 11, 17, 23], 33 | checkpoint=checkpoint, 34 | ) 35 | 36 | 37 | def build_sam_vit_b(checkpoint=None): 38 | return _build_sam( 39 | encoder_embed_dim=768, 40 | encoder_depth=12, 41 | encoder_num_heads=12, 42 | encoder_global_attn_indexes=[2, 5, 8, 11], 43 | checkpoint=checkpoint, 44 | ) 45 | 46 | 47 | def build_sam_vit_t(checkpoint=None): 48 | prompt_embed_dim = 256 49 | image_size = 1024 50 | vit_patch_size = 16 51 | image_embedding_size = image_size // vit_patch_size 52 | mobile_sam = Sam( 53 | image_encoder=TinyViT( 54 | img_size=1024, in_chans=3, num_classes=1000, 55 | embed_dims=[64, 128, 160, 320], 56 | depths=[2, 2, 6, 2], 57 | num_heads=[2, 4, 5, 10], 58 | window_sizes=[7, 7, 14, 7], 59 | mlp_ratio=4., 60 | drop_rate=0., 61 | drop_path_rate=0.0, 62 | use_checkpoint=False, 63 | mbconv_expand_ratio=4.0, 64 | local_conv_size=3, 65 | layer_lr_decay=0.8 66 | ), 67 | prompt_encoder=PromptEncoder( 68 | embed_dim=prompt_embed_dim, 69 | image_embedding_size=(image_embedding_size, image_embedding_size), 70 | input_image_size=(image_size, image_size), 71 | mask_in_chans=16, 72 | ), 73 | mask_decoder=MaskDecoder( 74 | num_multimask_outputs=3, 75 | transformer=TwoWayTransformer( 76 | depth=2, 77 | embedding_dim=prompt_embed_dim, 78 | mlp_dim=2048, 79 | num_heads=8, 80 | ), 81 | transformer_dim=prompt_embed_dim, 82 | iou_head_depth=3, 83 | iou_head_hidden_dim=256, 84 | ), 85 | pixel_mean=[123.675, 116.28, 103.53], 86 | pixel_std=[58.395, 57.12, 57.375], 87 | ) 88 | 89 | mobile_sam.eval() 90 | if checkpoint is not None: 91 | with open(checkpoint, "rb") as f: 92 | state_dict = torch.load(f) 93 | mobile_sam.load_state_dict(state_dict) 94 | return mobile_sam 95 | 96 | 97 | sam_model_registry = { 98 | "default": build_sam_vit_h, 99 | "vit_h": build_sam_vit_h, 100 | "vit_l": build_sam_vit_l, 101 | "vit_b": build_sam_vit_b, 102 | "vit_t": build_sam_vit_t, 103 | 
} 104 | 105 | 106 | def _build_sam( 107 | encoder_embed_dim, 108 | encoder_depth, 109 | encoder_num_heads, 110 | encoder_global_attn_indexes, 111 | checkpoint=None, 112 | ): 113 | prompt_embed_dim = 256 114 | image_size = 1024 115 | vit_patch_size = 16 116 | image_embedding_size = image_size // vit_patch_size 117 | sam = Sam( 118 | image_encoder=ImageEncoderViT( 119 | depth=encoder_depth, 120 | embed_dim=encoder_embed_dim, 121 | img_size=image_size, 122 | mlp_ratio=4, 123 | norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), 124 | num_heads=encoder_num_heads, 125 | patch_size=vit_patch_size, 126 | qkv_bias=True, 127 | use_rel_pos=True, 128 | global_attn_indexes=encoder_global_attn_indexes, 129 | window_size=14, 130 | out_chans=prompt_embed_dim, 131 | ), 132 | prompt_encoder=PromptEncoder( 133 | embed_dim=prompt_embed_dim, 134 | image_embedding_size=(image_embedding_size, image_embedding_size), 135 | input_image_size=(image_size, image_size), 136 | mask_in_chans=16, 137 | ), 138 | mask_decoder=MaskDecoder( 139 | num_multimask_outputs=3, 140 | transformer=TwoWayTransformer( 141 | depth=2, 142 | embedding_dim=prompt_embed_dim, 143 | mlp_dim=2048, 144 | num_heads=8, 145 | ), 146 | transformer_dim=prompt_embed_dim, 147 | iou_head_depth=3, 148 | iou_head_hidden_dim=256, 149 | ), 150 | pixel_mean=[123.675, 116.28, 103.53], 151 | pixel_std=[58.395, 57.12, 57.375], 152 | ) 153 | sam.eval() 154 | if checkpoint is not None: 155 | with open(checkpoint, "rb") as f: 156 | state_dict = torch.load(f) 157 | sam.load_state_dict(state_dict) 158 | return sam 159 | -------------------------------------------------------------------------------- /mobile_sam/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .sam import Sam 8 | from .image_encoder import ImageEncoderViT 9 | from .mask_decoder import MaskDecoder 10 | from .prompt_encoder import PromptEncoder 11 | from .transformer import TwoWayTransformer 12 | from .tiny_vit_sam import TinyViT 13 | 14 | __all__ = [ 15 | "Sam", 16 | "ImageEncoderViT", 17 | "MaskDecoder", 18 | "PromptEncoder", 19 | "TwoWayTransformer", 20 | "TinyViT", 21 | ] 22 | -------------------------------------------------------------------------------- /mobile_sam/modeling/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
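# ---------------------------------------------------------------------------
# A minimal usage sketch for the MobileSAM builders defined in
# mobile_sam/build_sam.py above. The checkpoint filename "mobile_sam.pt" and
# the random dummy image are placeholders; substitute a real checkpoint path
# and input image.
if __name__ == "__main__":
    import numpy as np

    from mobile_sam import SamPredictor, sam_model_registry

    # "vit_t" selects the TinyViT-based MobileSAM encoder built above.
    sam = sam_model_registry["vit_t"](checkpoint="mobile_sam.pt")
    sam.eval()

    predictor = SamPredictor(sam)
    image = np.random.randint(0, 255, (512, 512, 3), dtype=np.uint8)  # HxWxC uint8 RGB
    predictor.set_image(image)

    # Predict masks from a single positive click at the image center.
    masks, scores, low_res_logits = predictor.predict(
        point_coords=np.array([[256, 256]]),
        point_labels=np.array([1]),
        multimask_output=True,
    )
    print(masks.shape, scores)
# ---------------------------------------------------------------------------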
6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | from typing import Type 11 | 12 | 13 | class MLPBlock(nn.Module): 14 | def __init__( 15 | self, 16 | embedding_dim: int, 17 | mlp_dim: int, 18 | act: Type[nn.Module] = nn.GELU, 19 | ) -> None: 20 | super().__init__() 21 | self.lin1 = nn.Linear(embedding_dim, mlp_dim) 22 | self.lin2 = nn.Linear(mlp_dim, embedding_dim) 23 | self.act = act() 24 | 25 | def forward(self, x: torch.Tensor) -> torch.Tensor: 26 | return self.lin2(self.act(self.lin1(x))) 27 | 28 | 29 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa 30 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa 31 | class LayerNorm2d(nn.Module): 32 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None: 33 | super().__init__() 34 | self.weight = nn.Parameter(torch.ones(num_channels)) 35 | self.bias = nn.Parameter(torch.zeros(num_channels)) 36 | self.eps = eps 37 | 38 | def forward(self, x: torch.Tensor) -> torch.Tensor: 39 | u = x.mean(1, keepdim=True) 40 | s = (x - u).pow(2).mean(1, keepdim=True) 41 | x = (x - u) / torch.sqrt(s + self.eps) 42 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 43 | return x 44 | -------------------------------------------------------------------------------- /mobile_sam/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /mobile_sam/utils/onnx.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch.nn import functional as F 10 | 11 | from typing import Tuple 12 | 13 | from ..modeling import Sam 14 | from .amg import calculate_stability_score 15 | 16 | 17 | class SamOnnxModel(nn.Module): 18 | """ 19 | This model should not be called directly, but is used in ONNX export. 20 | It combines the prompt encoder, mask decoder, and mask postprocessing of Sam, 21 | with some functions modified to enable model tracing. Also supports extra 22 | options controlling what information is returned. See the ONNX export script for details.
23 | """ 24 | 25 | def __init__( 26 | self, 27 | model: Sam, 28 | return_single_mask: bool, 29 | use_stability_score: bool = False, 30 | return_extra_metrics: bool = False, 31 | ) -> None: 32 | super().__init__() 33 | self.mask_decoder = model.mask_decoder 34 | self.model = model 35 | self.img_size = model.image_encoder.img_size 36 | self.return_single_mask = return_single_mask 37 | self.use_stability_score = use_stability_score 38 | self.stability_score_offset = 1.0 39 | self.return_extra_metrics = return_extra_metrics 40 | 41 | @staticmethod 42 | def resize_longest_image_size( 43 | input_image_size: torch.Tensor, longest_side: int 44 | ) -> torch.Tensor: 45 | input_image_size = input_image_size.to(torch.float32) 46 | scale = longest_side / torch.max(input_image_size) 47 | transformed_size = scale * input_image_size 48 | transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64) 49 | return transformed_size 50 | 51 | def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor: 52 | point_coords = point_coords + 0.5 53 | point_coords = point_coords / self.img_size 54 | point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords) 55 | point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding) 56 | 57 | point_embedding = point_embedding * (point_labels != -1) 58 | point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * ( 59 | point_labels == -1 60 | ) 61 | 62 | for i in range(self.model.prompt_encoder.num_point_embeddings): 63 | point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[ 64 | i 65 | ].weight * (point_labels == i) 66 | 67 | return point_embedding 68 | 69 | def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor: 70 | mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask) 71 | mask_embedding = mask_embedding + ( 72 | 1 - has_mask_input 73 | ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1) 74 | return mask_embedding 75 | 76 | def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor: 77 | masks = F.interpolate( 78 | masks, 79 | size=(self.img_size, self.img_size), 80 | mode="bilinear", 81 | align_corners=False, 82 | ) 83 | 84 | prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(torch.int64) 85 | masks = masks[..., : prepadded_size[0], : prepadded_size[1]] # type: ignore 86 | 87 | orig_im_size = orig_im_size.to(torch.int64) 88 | h, w = orig_im_size[0], orig_im_size[1] 89 | masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False) 90 | return masks 91 | 92 | def select_masks( 93 | self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int 94 | ) -> Tuple[torch.Tensor, torch.Tensor]: 95 | # Determine if we should return the multiclick mask or not from the number of points. 96 | # The reweighting is used to avoid control flow. 
97 | score_reweight = torch.tensor( 98 | [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)] 99 | ).to(iou_preds.device) 100 | score = iou_preds + (num_points - 2.5) * score_reweight 101 | best_idx = torch.argmax(score, dim=1) 102 | masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1) 103 | iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1) 104 | 105 | return masks, iou_preds 106 | 107 | @torch.no_grad() 108 | def forward( 109 | self, 110 | image_embeddings: torch.Tensor, 111 | point_coords: torch.Tensor, 112 | point_labels: torch.Tensor, 113 | mask_input: torch.Tensor, 114 | has_mask_input: torch.Tensor, 115 | orig_im_size: torch.Tensor, 116 | ): 117 | sparse_embedding = self._embed_points(point_coords, point_labels) 118 | dense_embedding = self._embed_masks(mask_input, has_mask_input) 119 | 120 | masks, scores = self.model.mask_decoder.predict_masks( 121 | image_embeddings=image_embeddings, 122 | image_pe=self.model.prompt_encoder.get_dense_pe(), 123 | sparse_prompt_embeddings=sparse_embedding, 124 | dense_prompt_embeddings=dense_embedding, 125 | ) 126 | 127 | if self.use_stability_score: 128 | scores = calculate_stability_score( 129 | masks, self.model.mask_threshold, self.stability_score_offset 130 | ) 131 | 132 | if self.return_single_mask: 133 | masks, scores = self.select_masks(masks, scores, point_coords.shape[1]) 134 | 135 | upscaled_masks = self.mask_postprocessing(masks, orig_im_size) 136 | 137 | if self.return_extra_metrics: 138 | stability_scores = calculate_stability_score( 139 | upscaled_masks, self.model.mask_threshold, self.stability_score_offset 140 | ) 141 | areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1) 142 | return upscaled_masks, scores, stability_scores, areas, masks 143 | 144 | return upscaled_masks, scores, masks 145 | -------------------------------------------------------------------------------- /mobile_sam/utils/torch_nms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision.ops.boxes import box_iou 3 | 4 | 5 | def nms(bboxes: torch.Tensor, scores: torch.Tensor, iou_threshold: float) -> torch.Tensor: 6 | order = torch.argsort(-scores) 7 | keep = [] 8 | 9 | while order.numel() > 0: 10 | i = order[0] 11 | keep.append(i.item()) 12 | 13 | if order.numel() == 1: 14 | break 15 | 16 | ious = box_iou(bboxes[i].unsqueeze(0), bboxes[order[1:]])[0] 17 | mask = ious <= iou_threshold 18 | order = order[1:][mask] 19 | 20 | return torch.tensor(keep, device=bboxes.device) 21 | -------------------------------------------------------------------------------- /mobile_sam/utils/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | from torch.nn import functional as F 10 | from torchvision.transforms.functional import resize, to_pil_image # type: ignore 11 | 12 | from copy import deepcopy 13 | from typing import Tuple 14 | 15 | 16 | class ResizeLongestSide: 17 | """ 18 | Resizes images to the longest side 'target_length', as well as provides 19 | methods for resizing coordinates and boxes. Provides methods for 20 | transforming both numpy array and batched torch tensors. 
21 | """ 22 | 23 | def __init__(self, target_length: int) -> None: 24 | self.target_length = target_length 25 | 26 | def apply_image(self, image: np.ndarray) -> np.ndarray: 27 | """ 28 | Expects a numpy array with shape HxWxC in uint8 format. 29 | """ 30 | target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) 31 | return np.array(resize(to_pil_image(image), target_size)) 32 | 33 | def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 34 | """ 35 | Expects a numpy array of length 2 in the final dimension. Requires the 36 | original image size in (H, W) format. 37 | """ 38 | old_h, old_w = original_size 39 | new_h, new_w = self.get_preprocess_shape( 40 | original_size[0], original_size[1], self.target_length 41 | ) 42 | coords = deepcopy(coords).astype(float) 43 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 44 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 45 | return coords 46 | 47 | def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 48 | """ 49 | Expects a numpy array shape Bx4. Requires the original image size 50 | in (H, W) format. 51 | """ 52 | boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size) 53 | return boxes.reshape(-1, 4) 54 | 55 | def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor: 56 | """ 57 | Expects batched images with shape BxCxHxW and float format. This 58 | transformation may not exactly match apply_image. apply_image is 59 | the transformation expected by the model. 60 | """ 61 | # Expects an image in BCHW format. May not exactly match apply_image. 62 | target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length) 63 | return F.interpolate( 64 | image, target_size, mode="bilinear", align_corners=False, antialias=True 65 | ) 66 | 67 | def apply_coords_torch( 68 | self, coords: torch.Tensor, original_size: Tuple[int, ...] 69 | ) -> torch.Tensor: 70 | """ 71 | Expects a torch tensor with length 2 in the last dimension. Requires the 72 | original image size in (H, W) format. 73 | """ 74 | old_h, old_w = original_size 75 | new_h, new_w = self.get_preprocess_shape( 76 | original_size[0], original_size[1], self.target_length 77 | ) 78 | coords = deepcopy(coords).to(torch.float) 79 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 80 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 81 | return coords 82 | 83 | def apply_boxes_torch( 84 | self, boxes: torch.Tensor, original_size: Tuple[int, ...] 85 | ) -> torch.Tensor: 86 | """ 87 | Expects a torch tensor with shape Bx4. Requires the original image 88 | size in (H, W) format. 89 | """ 90 | boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size) 91 | return boxes.reshape(-1, 4) 92 | 93 | @staticmethod 94 | def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]: 95 | """ 96 | Compute the output size given input size and target long side length. 
97 | """ 98 | scale = long_side_length * 1.0 / max(oldh, oldw) 99 | newh, neww = oldh * scale, oldw * scale 100 | neww = int(neww + 0.5) 101 | newh = int(newh + 0.5) 102 | return (newh, neww) 103 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | --extra-index-url https://download.pytorch.org/whl/cu121 2 | torch==2.3.1 3 | torchvision 4 | accelerate 5 | diffusers<0.32.0 6 | gradio<4.0.0 7 | huggingface-hub 8 | numpy<2.0.0 9 | opencv-python 10 | pillow 11 | segment-anything 12 | transformers<5.0.0 13 | xformers==0.0.27 14 | # lama-cleaner 15 | ultralytics 16 | tqdm 17 | packaging 18 | loguru 19 | rich 20 | pydantic 21 | timm 22 | onnxruntime 23 | hydra-core 24 | iopath 25 | -------------------------------------------------------------------------------- /requirements_cu118.txt: -------------------------------------------------------------------------------- 1 | --extra-index-url https://download.pytorch.org/whl/cu118 2 | torch==2.3.1 3 | torchvision 4 | accelerate 5 | diffusers<0.32.0 6 | gradio<4.0.0 7 | huggingface-hub 8 | numpy<2.0.0 9 | opencv-python 10 | pillow 11 | segment-anything 12 | transformers<5.0.0 13 | xformers==0.0.27 14 | # lama-cleaner 15 | ultralytics 16 | tqdm 17 | packaging 18 | loguru 19 | rich 20 | pydantic 21 | timm 22 | onnxruntime 23 | hydra-core 24 | iopath 25 | -------------------------------------------------------------------------------- /requirements_mac.txt: -------------------------------------------------------------------------------- 1 | torch==2.2.2 2 | torchvision 3 | accelerate 4 | diffusers<0.32.0 5 | gradio<4.0.0 6 | huggingface-hub 7 | numpy<2.0.0 8 | opencv-python 9 | pillow 10 | segment-anything 11 | transformers<5.0.0 12 | # xformers 13 | # lama-cleaner 14 | ultralytics 15 | tqdm 16 | packaging 17 | loguru 18 | rich 19 | pydantic 20 | timm 21 | onnxruntime 22 | hydra-core 23 | iopath 24 | -------------------------------------------------------------------------------- /sam2/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import os 8 | import warnings 9 | 10 | from hydra import initialize_config_dir, initialize_config_module # noqa: F401 11 | 12 | warnings.filterwarnings("ignore", category=UserWarning, module="sam2") 13 | 14 | inpa_basedir = os.path.abspath(os.path.normpath(os.path.join(os.path.dirname(__file__), ".."))) 15 | configs_path = os.path.join(inpa_basedir, "sam2_configs") 16 | 17 | try: 18 | initialize_config_dir(configs_path, version_base="1.2") 19 | except TypeError: 20 | initialize_config_dir(configs_path) 21 | # initialize_config_module("sam2_configs", version_base="1.2") 22 | -------------------------------------------------------------------------------- /sam2/build_sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
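# ---------------------------------------------------------------------------
# A minimal usage sketch for the ResizeLongestSide helper from
# mobile_sam/utils/transforms.py above. The 1024 target length matches the
# image size used by the SAM builders; the dummy image and point are
# placeholders.
if __name__ == "__main__":
    import numpy as np

    from mobile_sam.utils.transforms import ResizeLongestSide

    transform = ResizeLongestSide(target_length=1024)

    image = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)  # HxWxC uint8
    resized = transform.apply_image(image)  # longest side is resized to 1024
    coords = np.array([[320.0, 240.0]])  # (x, y) points in the original image
    scaled = transform.apply_coords(coords, image.shape[:2])  # rescaled to match
    print(resized.shape, scaled)
# ---------------------------------------------------------------------------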
6 | 7 | import logging 8 | 9 | import torch 10 | from hydra import compose 11 | from hydra.utils import instantiate 12 | from omegaconf import OmegaConf 13 | 14 | 15 | def build_sam2( 16 | config_file, 17 | ckpt_path=None, 18 | device="cuda", 19 | mode="eval", 20 | hydra_overrides_extra=[], 21 | apply_postprocessing=True, 22 | ): 23 | 24 | if apply_postprocessing: 25 | hydra_overrides_extra = hydra_overrides_extra.copy() 26 | hydra_overrides_extra += [ 27 | # dynamically fall back to multi-mask if the single mask is not stable 28 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true", 29 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05", 30 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98", 31 | ] 32 | # Read config and init model 33 | cfg = compose(config_name=config_file, overrides=hydra_overrides_extra) 34 | OmegaConf.resolve(cfg) 35 | model = instantiate(cfg.model, _recursive_=True) 36 | _load_checkpoint(model, ckpt_path) 37 | model = model.to(device) 38 | if mode == "eval": 39 | model.eval() 40 | return model 41 | 42 | 43 | def build_sam2_video_predictor( 44 | config_file, 45 | ckpt_path=None, 46 | device="cuda", 47 | mode="eval", 48 | hydra_overrides_extra=[], 49 | apply_postprocessing=True, 50 | ): 51 | hydra_overrides = [ 52 | "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor", 53 | ] 54 | if apply_postprocessing: 55 | hydra_overrides_extra = hydra_overrides_extra.copy() 56 | hydra_overrides_extra += [ 57 | # dynamically fall back to multi-mask if the single mask is not stable 58 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true", 59 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05", 60 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98", 61 | # binarize the sigmoid mask logits on interacted frames with clicks in the memory encoder so that the encoded masks are exactly as what users see from clicking 62 | "++model.binarize_mask_from_pts_for_mem_enc=true", 63 | # fill small holes in the low-res masks up to `fill_hole_area` (before resizing them to the original video resolution) 64 | "++model.fill_hole_area=8", 65 | ] 66 | hydra_overrides.extend(hydra_overrides_extra) 67 | 68 | # Read config and init model 69 | cfg = compose(config_name=config_file, overrides=hydra_overrides) 70 | OmegaConf.resolve(cfg) 71 | model = instantiate(cfg.model, _recursive_=True) 72 | _load_checkpoint(model, ckpt_path) 73 | model = model.to(device) 74 | if mode == "eval": 75 | model.eval() 76 | return model 77 | 78 | 79 | def _load_checkpoint(model, ckpt_path): 80 | if ckpt_path is not None: 81 | sd = torch.load(ckpt_path, map_location="cpu")["model"] 82 | missing_keys, unexpected_keys = model.load_state_dict(sd) 83 | if missing_keys: 84 | logging.error(missing_keys) 85 | raise RuntimeError("Missing keys in checkpoint state_dict.") 86 | if unexpected_keys: 87 | logging.error(unexpected_keys) 88 | raise RuntimeError("Unexpected keys in checkpoint state_dict.") 89 | logging.info("Loaded checkpoint successfully") 90 | -------------------------------------------------------------------------------- /sam2/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree.
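# ---------------------------------------------------------------------------
# A minimal usage sketch for build_sam2() defined in sam2/build_sam.py above,
# assuming the SAM2ImagePredictor class provided by sam2/sam2_image_predictor.py.
# The config name refers to one of the YAML files under sam2_configs/; the
# checkpoint path "sam2_hiera_tiny.pt" and the dummy image are placeholders.
if __name__ == "__main__":
    import numpy as np

    from sam2.build_sam import build_sam2
    from sam2.sam2_image_predictor import SAM2ImagePredictor

    # device="cpu" keeps the sketch runnable without a GPU.
    model = build_sam2(
        config_file="sam2_hiera_t.yaml",
        ckpt_path="sam2_hiera_tiny.pt",
        device="cpu",
    )
    predictor = SAM2ImagePredictor(model)

    image = np.random.randint(0, 255, (512, 512, 3), dtype=np.uint8)  # HxWxC uint8 RGB
    predictor.set_image(image)
    masks, scores, low_res_logits = predictor.predict(
        point_coords=np.array([[256, 256]]),
        point_labels=np.array([1]),
        multimask_output=True,
    )
    print(masks.shape, scores)
# ---------------------------------------------------------------------------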
6 | -------------------------------------------------------------------------------- /sam2/modeling/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /sam2/modeling/backbones/image_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import List, Optional 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | 14 | class ImageEncoder(nn.Module): 15 | def __init__( 16 | self, 17 | trunk: nn.Module, 18 | neck: nn.Module, 19 | scalp: int = 0, 20 | ): 21 | super().__init__() 22 | self.trunk = trunk 23 | self.neck = neck 24 | self.scalp = scalp 25 | assert ( 26 | self.trunk.channel_list == self.neck.backbone_channel_list 27 | ), f"Channel dims of trunk and neck do not match. Trunk: {self.trunk.channel_list}, neck: {self.neck.backbone_channel_list}" 28 | 29 | def forward(self, sample: torch.Tensor): 30 | # Forward through backbone 31 | features, pos = self.neck(self.trunk(sample)) 32 | if self.scalp > 0: 33 | # Discard the lowest resolution features 34 | features, pos = features[: -self.scalp], pos[: -self.scalp] 35 | 36 | src = features[-1] 37 | output = { 38 | "vision_features": src, 39 | "vision_pos_enc": pos, 40 | "backbone_fpn": features, 41 | } 42 | return output 43 | 44 | 45 | class FpnNeck(nn.Module): 46 | """ 47 | A modified variant of Feature Pyramid Network (FPN) neck 48 | (we remove output conv and also do bicubic interpolation similar to ViT 49 | pos embed interpolation) 50 | """ 51 | 52 | def __init__( 53 | self, 54 | position_encoding: nn.Module, 55 | d_model: int, 56 | backbone_channel_list: List[int], 57 | kernel_size: int = 1, 58 | stride: int = 1, 59 | padding: int = 0, 60 | fpn_interp_model: str = "bilinear", 61 | fuse_type: str = "sum", 62 | fpn_top_down_levels: Optional[List[int]] = None, 63 | ): 64 | """Initialize the neck 65 | :param trunk: the backbone 66 | :param position_encoding: the positional encoding to use 67 | :param d_model: the dimension of the model 68 | :param neck_norm: the normalization to use 69 | """ 70 | super().__init__() 71 | self.position_encoding = position_encoding 72 | self.convs = nn.ModuleList() 73 | self.backbone_channel_list = backbone_channel_list 74 | for dim in backbone_channel_list: 75 | current = nn.Sequential() 76 | current.add_module( 77 | "conv", 78 | nn.Conv2d( 79 | in_channels=dim, 80 | out_channels=d_model, 81 | kernel_size=kernel_size, 82 | stride=stride, 83 | padding=padding, 84 | ), 85 | ) 86 | 87 | self.convs.append(current) 88 | self.fpn_interp_model = fpn_interp_model 89 | assert fuse_type in ["sum", "avg"] 90 | self.fuse_type = fuse_type 91 | 92 | # levels to have top-down features in its outputs 93 | # e.g. if fpn_top_down_levels is [2, 3], then only outputs of level 2 and 3 94 | # have top-down propagation, while outputs of level 0 and level 1 have only 95 | # lateral features from the same backbone level. 
96 | if fpn_top_down_levels is None: 97 | # default is to have top-down features on all levels 98 | fpn_top_down_levels = range(len(self.convs)) 99 | self.fpn_top_down_levels = list(fpn_top_down_levels) 100 | 101 | def forward(self, xs: List[torch.Tensor]): 102 | 103 | out = [None] * len(self.convs) 104 | pos = [None] * len(self.convs) 105 | assert len(xs) == len(self.convs) 106 | # fpn forward pass 107 | # see https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/fpn.py 108 | prev_features = None 109 | # forward in top-down order (from low to high resolution) 110 | n = len(self.convs) - 1 111 | for i in range(n, -1, -1): 112 | x = xs[i] 113 | lateral_features = self.convs[n - i](x) 114 | if i in self.fpn_top_down_levels and prev_features is not None: 115 | top_down_features = F.interpolate( 116 | prev_features.to(dtype=torch.float32), 117 | scale_factor=2.0, 118 | mode=self.fpn_interp_model, 119 | align_corners=( 120 | None if self.fpn_interp_model == "nearest" else False 121 | ), 122 | antialias=False, 123 | ) 124 | prev_features = lateral_features + top_down_features 125 | if self.fuse_type == "avg": 126 | prev_features /= 2 127 | else: 128 | prev_features = lateral_features 129 | x_out = prev_features 130 | out[i] = x_out 131 | pos[i] = self.position_encoding(x_out).to(x_out.dtype) 132 | 133 | return out, pos 134 | -------------------------------------------------------------------------------- /sam2/modeling/backbones/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """Some utilities for backbones, in particular for windowing""" 8 | 9 | from typing import Tuple 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | 15 | 16 | def window_partition(x, window_size): 17 | """ 18 | Partition into non-overlapping windows with padding if needed. 19 | Args: 20 | x (tensor): input tokens with [B, H, W, C]. 21 | window_size (int): window size. 22 | Returns: 23 | windows: windows after partition with [B * num_windows, window_size, window_size, C]. 24 | (Hp, Wp): padded height and width before partition 25 | """ 26 | B, H, W, C = x.shape 27 | 28 | pad_h = (window_size - H % window_size) % window_size 29 | pad_w = (window_size - W % window_size) % window_size 30 | if pad_h > 0 or pad_w > 0: 31 | x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) 32 | Hp, Wp = H + pad_h, W + pad_w 33 | 34 | x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) 35 | windows = ( 36 | x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 37 | ) 38 | return windows, (Hp, Wp) 39 | 40 | 41 | def window_unpartition(windows, window_size, pad_hw, hw): 42 | """ 43 | Window unpartition into original sequences and removing padding. 44 | Args: 45 | x (tensor): input tokens with [B * num_windows, window_size, window_size, C]. 46 | window_size (int): window size. 47 | pad_hw (Tuple): padded height and width (Hp, Wp). 48 | hw (Tuple): original height and width (H, W) before padding. 49 | Returns: 50 | x: unpartitioned sequences with [B, H, W, C]. 
51 | """ 52 | Hp, Wp = pad_hw 53 | H, W = hw 54 | B = windows.shape[0] // (Hp * Wp // window_size // window_size) 55 | x = windows.view( 56 | B, Hp // window_size, Wp // window_size, window_size, window_size, -1 57 | ) 58 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) 59 | 60 | if Hp > H or Wp > W: 61 | x = x[:, :H, :W, :].contiguous() 62 | return x 63 | 64 | 65 | class PatchEmbed(nn.Module): 66 | """ 67 | Image to Patch Embedding. 68 | """ 69 | 70 | def __init__( 71 | self, 72 | kernel_size: Tuple[int, ...] = (7, 7), 73 | stride: Tuple[int, ...] = (4, 4), 74 | padding: Tuple[int, ...] = (3, 3), 75 | in_chans: int = 3, 76 | embed_dim: int = 768, 77 | ): 78 | """ 79 | Args: 80 | kernel_size (Tuple): kernel size of the projection layer. 81 | stride (Tuple): stride of the projection layer. 82 | padding (Tuple): padding size of the projection layer. 83 | in_chans (int): Number of input image channels. 84 | embed_dim (int): Patch embedding dimension. 85 | """ 86 | super().__init__() 87 | self.proj = nn.Conv2d( 88 | in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding 89 | ) 90 | 91 | def forward(self, x: torch.Tensor) -> torch.Tensor: 92 | x = self.proj(x) 93 | # B C H W -> B H W C 94 | x = x.permute(0, 2, 3, 1) 95 | return x 96 | -------------------------------------------------------------------------------- /sam2/modeling/memory_attention.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import Optional 8 | 9 | import torch 10 | from torch import nn, Tensor 11 | 12 | from sam2.modeling.sam.transformer import RoPEAttention 13 | 14 | from sam2.modeling.sam2_utils import get_activation_fn, get_clones 15 | 16 | 17 | class MemoryAttentionLayer(nn.Module): 18 | 19 | def __init__( 20 | self, 21 | activation: str, 22 | cross_attention: nn.Module, 23 | d_model: int, 24 | dim_feedforward: int, 25 | dropout: float, 26 | pos_enc_at_attn: bool, 27 | pos_enc_at_cross_attn_keys: bool, 28 | pos_enc_at_cross_attn_queries: bool, 29 | self_attention: nn.Module, 30 | ): 31 | super().__init__() 32 | self.d_model = d_model 33 | self.dim_feedforward = dim_feedforward 34 | self.dropout_value = dropout 35 | self.self_attn = self_attention 36 | self.cross_attn_image = cross_attention 37 | 38 | # Implementation of Feedforward model 39 | self.linear1 = nn.Linear(d_model, dim_feedforward) 40 | self.dropout = nn.Dropout(dropout) 41 | self.linear2 = nn.Linear(dim_feedforward, d_model) 42 | 43 | self.norm1 = nn.LayerNorm(d_model) 44 | self.norm2 = nn.LayerNorm(d_model) 45 | self.norm3 = nn.LayerNorm(d_model) 46 | self.dropout1 = nn.Dropout(dropout) 47 | self.dropout2 = nn.Dropout(dropout) 48 | self.dropout3 = nn.Dropout(dropout) 49 | 50 | self.activation_str = activation 51 | self.activation = get_activation_fn(activation) 52 | 53 | # Where to add pos enc 54 | self.pos_enc_at_attn = pos_enc_at_attn 55 | self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries 56 | self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys 57 | 58 | def _forward_sa(self, tgt, query_pos): 59 | # Self-Attention 60 | tgt2 = self.norm1(tgt) 61 | q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2 62 | tgt2 = self.self_attn(q, k, v=tgt2) 63 | tgt = tgt + self.dropout1(tgt2) 64 | return tgt 65 | 66 |
def _forward_ca(self, tgt, memory, query_pos, pos, num_k_exclude_rope=0): 67 | kwds = {} 68 | if num_k_exclude_rope > 0: 69 | assert isinstance(self.cross_attn_image, RoPEAttention) 70 | kwds = {"num_k_exclude_rope": num_k_exclude_rope} 71 | 72 | # Cross-Attention 73 | tgt2 = self.norm2(tgt) 74 | tgt2 = self.cross_attn_image( 75 | q=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2, 76 | k=memory + pos if self.pos_enc_at_cross_attn_keys else memory, 77 | v=memory, 78 | **kwds, 79 | ) 80 | tgt = tgt + self.dropout2(tgt2) 81 | return tgt 82 | 83 | def forward( 84 | self, 85 | tgt, 86 | memory, 87 | pos: Optional[Tensor] = None, 88 | query_pos: Optional[Tensor] = None, 89 | num_k_exclude_rope: int = 0, 90 | ) -> torch.Tensor: 91 | 92 | # Self-Attn, Cross-Attn 93 | tgt = self._forward_sa(tgt, query_pos) 94 | tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope) 95 | # MLP 96 | tgt2 = self.norm3(tgt) 97 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) 98 | tgt = tgt + self.dropout3(tgt2) 99 | return tgt 100 | 101 | 102 | class MemoryAttention(nn.Module): 103 | def __init__( 104 | self, 105 | d_model: int, 106 | pos_enc_at_input: bool, 107 | layer: nn.Module, 108 | num_layers: int, 109 | batch_first: bool = True, # Do layers expect batch first input? 110 | ): 111 | super().__init__() 112 | self.d_model = d_model 113 | self.layers = get_clones(layer, num_layers) 114 | self.num_layers = num_layers 115 | self.norm = nn.LayerNorm(d_model) 116 | self.pos_enc_at_input = pos_enc_at_input 117 | self.batch_first = batch_first 118 | 119 | def forward( 120 | self, 121 | curr: torch.Tensor, # self-attention inputs 122 | memory: torch.Tensor, # cross-attention inputs 123 | curr_pos: Optional[Tensor] = None, # pos_enc for self-attention inputs 124 | memory_pos: Optional[Tensor] = None, # pos_enc for cross-attention inputs 125 | num_obj_ptr_tokens: int = 0, # number of object pointer *tokens* 126 | ): 127 | if isinstance(curr, list): 128 | assert isinstance(curr_pos, list) 129 | assert len(curr) == len(curr_pos) == 1 130 | curr, curr_pos = ( 131 | curr[0], 132 | curr_pos[0], 133 | ) 134 | 135 | assert ( 136 | curr.shape[1] == memory.shape[1] 137 | ), "Batch size must be the same for curr and memory" 138 | 139 | output = curr 140 | if self.pos_enc_at_input and curr_pos is not None: 141 | output = output + 0.1 * curr_pos 142 | 143 | if self.batch_first: 144 | # Convert to batch first 145 | output = output.transpose(0, 1) 146 | curr_pos = curr_pos.transpose(0, 1) 147 | memory = memory.transpose(0, 1) 148 | memory_pos = memory_pos.transpose(0, 1) 149 | 150 | for layer in self.layers: 151 | kwds = {} 152 | if isinstance(layer.cross_attn_image, RoPEAttention): 153 | kwds = {"num_k_exclude_rope": num_obj_ptr_tokens} 154 | 155 | output = layer( 156 | tgt=output, 157 | memory=memory, 158 | pos=memory_pos, 159 | query_pos=curr_pos, 160 | **kwds, 161 | ) 162 | normed_output = self.norm(output) 163 | 164 | if self.batch_first: 165 | # Convert back to seq first 166 | normed_output = normed_output.transpose(0, 1) 167 | curr_pos = curr_pos.transpose(0, 1) 168 | 169 | return normed_output 170 | -------------------------------------------------------------------------------- /sam2/modeling/sam/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /sam2/modeling/sam2_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | import copy 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | 14 | 15 | def select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num): 16 | """ 17 | Select up to `max_cond_frame_num` conditioning frames from `cond_frame_outputs` 18 | that are temporally closest to the current frame at `frame_idx`. Here, we take 19 | - a) the closest conditioning frame before `frame_idx` (if any); 20 | - b) the closest conditioning frame after `frame_idx` (if any); 21 | - c) any other temporally closest conditioning frames until reaching a total 22 | of `max_cond_frame_num` conditioning frames. 23 | 24 | Outputs: 25 | - selected_outputs: selected items (keys & values) from `cond_frame_outputs`. 26 | - unselected_outputs: items (keys & values) not selected in `cond_frame_outputs`. 27 | """ 28 | if max_cond_frame_num == -1 or len(cond_frame_outputs) <= max_cond_frame_num: 29 | selected_outputs = cond_frame_outputs 30 | unselected_outputs = {} 31 | else: 32 | assert max_cond_frame_num >= 2, "we should allow using 2+ conditioning frames" 33 | selected_outputs = {} 34 | 35 | # the closest conditioning frame before `frame_idx` (if any) 36 | idx_before = max((t for t in cond_frame_outputs if t < frame_idx), default=None) 37 | if idx_before is not None: 38 | selected_outputs[idx_before] = cond_frame_outputs[idx_before] 39 | 40 | # the closest conditioning frame after `frame_idx` (if any) 41 | idx_after = min((t for t in cond_frame_outputs if t >= frame_idx), default=None) 42 | if idx_after is not None: 43 | selected_outputs[idx_after] = cond_frame_outputs[idx_after] 44 | 45 | # add other temporally closest conditioning frames until reaching a total 46 | # of `max_cond_frame_num` conditioning frames. 47 | num_remain = max_cond_frame_num - len(selected_outputs) 48 | inds_remain = sorted( 49 | (t for t in cond_frame_outputs if t not in selected_outputs), 50 | key=lambda x: abs(x - frame_idx), 51 | )[:num_remain] 52 | selected_outputs.update((t, cond_frame_outputs[t]) for t in inds_remain) 53 | unselected_outputs = { 54 | t: v for t, v in cond_frame_outputs.items() if t not in selected_outputs 55 | } 56 | 57 | return selected_outputs, unselected_outputs 58 | 59 | 60 | def get_1d_sine_pe(pos_inds, dim, temperature=10000): 61 | """ 62 | Get 1D sine positional embedding as in the original Transformer paper. 
63 | """ 64 | pe_dim = dim // 2 65 | dim_t = torch.arange(pe_dim, dtype=torch.float32, device=pos_inds.device) 66 | dim_t = temperature ** (2 * (dim_t // 2) / pe_dim) 67 | 68 | pos_embed = pos_inds.unsqueeze(-1) / dim_t 69 | pos_embed = torch.cat([pos_embed.sin(), pos_embed.cos()], dim=-1) 70 | return pos_embed 71 | 72 | 73 | def get_activation_fn(activation): 74 | """Return an activation function given a string""" 75 | if activation == "relu": 76 | return F.relu 77 | if activation == "gelu": 78 | return F.gelu 79 | if activation == "glu": 80 | return F.glu 81 | raise RuntimeError(f"activation should be relu/gelu, not {activation}.") 82 | 83 | 84 | def get_clones(module, N): 85 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 86 | 87 | 88 | class DropPath(nn.Module): 89 | # adapted from https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py 90 | def __init__(self, drop_prob=0.0, scale_by_keep=True): 91 | super(DropPath, self).__init__() 92 | self.drop_prob = drop_prob 93 | self.scale_by_keep = scale_by_keep 94 | 95 | def forward(self, x): 96 | if self.drop_prob == 0.0 or not self.training: 97 | return x 98 | keep_prob = 1 - self.drop_prob 99 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) 100 | random_tensor = x.new_empty(shape).bernoulli_(keep_prob) 101 | if keep_prob > 0.0 and self.scale_by_keep: 102 | random_tensor.div_(keep_prob) 103 | return x * random_tensor 104 | 105 | 106 | # Lightly adapted from 107 | # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa 108 | class MLP(nn.Module): 109 | def __init__( 110 | self, 111 | input_dim: int, 112 | hidden_dim: int, 113 | output_dim: int, 114 | num_layers: int, 115 | activation: nn.Module = nn.ReLU, 116 | sigmoid_output: bool = False, 117 | ) -> None: 118 | super().__init__() 119 | self.num_layers = num_layers 120 | h = [hidden_dim] * (num_layers - 1) 121 | self.layers = nn.ModuleList( 122 | nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) 123 | ) 124 | self.sigmoid_output = sigmoid_output 125 | self.act = activation() 126 | 127 | def forward(self, x): 128 | for i, layer in enumerate(self.layers): 129 | x = self.act(layer(x)) if i < self.num_layers - 1 else layer(x) 130 | if self.sigmoid_output: 131 | x = F.sigmoid(x) 132 | return x 133 | 134 | 135 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa 136 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa 137 | class LayerNorm2d(nn.Module): 138 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None: 139 | super().__init__() 140 | self.weight = nn.Parameter(torch.ones(num_channels)) 141 | self.bias = nn.Parameter(torch.zeros(num_channels)) 142 | self.eps = eps 143 | 144 | def forward(self, x: torch.Tensor) -> torch.Tensor: 145 | u = x.mean(1, keepdim=True) 146 | s = (x - u).pow(2).mean(1, keepdim=True) 147 | x = (x - u) / torch.sqrt(s + self.eps) 148 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 149 | return x 150 | -------------------------------------------------------------------------------- /sam2/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /sam2/utils/torch_nms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision.ops.boxes import box_iou 3 | 4 | 5 | def nms(bboxes: torch.Tensor, scores: torch.Tensor, iou_threshold: float) -> torch.Tensor: 6 | order = torch.argsort(-scores) 7 | keep = [] 8 | 9 | while order.numel() > 0: 10 | i = order[0] 11 | keep.append(i.item()) 12 | 13 | if order.numel() == 1: 14 | break 15 | 16 | ious = box_iou(bboxes[i].unsqueeze(0), bboxes[order[1:]])[0] 17 | mask = ious <= iou_threshold 18 | order = order[1:][mask] 19 | 20 | return torch.tensor(keep, device=bboxes.device) 21 | -------------------------------------------------------------------------------- /sam2/utils/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from torchvision.transforms import Normalize, Resize, ToTensor 11 | 12 | 13 | class SAM2Transforms(nn.Module): 14 | def __init__( 15 | self, resolution, mask_threshold, max_hole_area=0.0, max_sprinkle_area=0.0 16 | ): 17 | """ 18 | Transforms for SAM2. 19 | """ 20 | super().__init__() 21 | self.resolution = resolution 22 | self.mask_threshold = mask_threshold 23 | self.max_hole_area = max_hole_area 24 | self.max_sprinkle_area = max_sprinkle_area 25 | self.mean = [0.485, 0.456, 0.406] 26 | self.std = [0.229, 0.224, 0.225] 27 | self.to_tensor = ToTensor() 28 | self.transforms = nn.Sequential( 29 | Resize((self.resolution, self.resolution), antialias=True), 30 | Normalize(self.mean, self.std), 31 | ) 32 | 33 | def __call__(self, x): 34 | x = self.to_tensor(x) 35 | return self.transforms(x) 36 | 37 | def forward_batch(self, img_list): 38 | img_batch = [self.transforms(self.to_tensor(img)) for img in img_list] 39 | img_batch = torch.stack(img_batch, dim=0) 40 | return img_batch 41 | 42 | def transform_coords( 43 | self, coords: torch.Tensor, normalize=False, orig_hw=None 44 | ) -> torch.Tensor: 45 | """ 46 | Expects a torch tensor with length 2 in the last dimension. The coordinates can be in absolute image or normalized coordinates, 47 | If the coords are in absolute image coordinates, normalize should be set to True and original image size is required. 48 | 49 | Returns 50 | Un-normalized coordinates in the range of [0, 1] which is expected by the SAM2 model. 51 | """ 52 | if normalize: 53 | assert orig_hw is not None 54 | h, w = orig_hw 55 | coords = coords.clone() 56 | coords[..., 0] = coords[..., 0] / w 57 | coords[..., 1] = coords[..., 1] / h 58 | 59 | coords = coords * self.resolution # unnormalize coords 60 | return coords 61 | 62 | def transform_boxes( 63 | self, boxes: torch.Tensor, normalize=False, orig_hw=None 64 | ) -> torch.Tensor: 65 | """ 66 | Expects a tensor of shape Bx4. The coordinates can be in absolute image or normalized coordinates, 67 | if the coords are in absolute image coordinates, normalize should be set to True and original image size is required. 
68 | """ 69 | boxes = self.transform_coords(boxes.reshape(-1, 2, 2), normalize, orig_hw) 70 | return boxes 71 | 72 | def postprocess_masks(self, masks: torch.Tensor, orig_hw) -> torch.Tensor: 73 | """ 74 | Perform PostProcessing on output masks. 75 | """ 76 | from sam2.utils.misc import get_connected_components 77 | 78 | masks = masks.float() 79 | mask_flat = masks.flatten(0, 1).unsqueeze(1) # flatten as 1-channel image; needed by both branches below 80 | if self.max_hole_area > 0: 81 | # Holes are those connected components in background with area <= self.max_hole_area 82 | # (background regions are those with mask scores <= self.mask_threshold) 83 | labels, areas = get_connected_components(mask_flat <= self.mask_threshold) 84 | is_hole = (labels > 0) & (areas <= self.max_hole_area) 85 | is_hole = is_hole.reshape_as(masks) 86 | # We fill holes with a small positive mask score (10.0) to change them to foreground. 87 | masks = torch.where(is_hole, self.mask_threshold + 10.0, masks) 88 | 89 | if self.max_sprinkle_area > 0: 90 | labels, areas = get_connected_components(mask_flat > self.mask_threshold) 91 | is_hole = (labels > 0) & (areas <= self.max_sprinkle_area) 92 | is_hole = is_hole.reshape_as(masks) 93 | # We remove small foreground sprinkles with a negative mask score (-10.0) to change them to background. 94 | masks = torch.where(is_hole, self.mask_threshold - 10.0, masks) 95 | 96 | masks = F.interpolate(masks, orig_hw, mode="bilinear", align_corners=False) 97 | return masks 98 | -------------------------------------------------------------------------------- /sam2_configs/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree.
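# ---------------------------------------------------------------------------
# A minimal usage sketch for the SAM2Transforms class above. The resolution
# (1024) and mask_threshold (0.0) mirror the image_size and default threshold
# used elsewhere in this repository; the dummy image and point are placeholders.
if __name__ == "__main__":
    import numpy as np
    import torch

    from sam2.utils.transforms import SAM2Transforms

    transforms = SAM2Transforms(resolution=1024, mask_threshold=0.0)

    # Image preprocessing: HxWxC uint8 -> normalized 3x1024x1024 float tensor.
    image = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
    batch = transforms(image).unsqueeze(0)  # add a batch dimension

    # Prompt preprocessing: absolute pixel (x, y) coords -> model input space.
    points = torch.tensor([[[320.0, 240.0]]])
    points_model = transforms.transform_coords(points, normalize=True, orig_hw=(480, 640))
    print(batch.shape, points_model)
# ---------------------------------------------------------------------------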
6 | -------------------------------------------------------------------------------- /sam2_configs/sam2_hiera_b+.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Model 4 | model: 5 | _target_: sam2.modeling.sam2_base.SAM2Base 6 | image_encoder: 7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder 8 | scalp: 1 9 | trunk: 10 | _target_: sam2.modeling.backbones.hieradet.Hiera 11 | embed_dim: 112 12 | num_heads: 2 13 | neck: 14 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck 15 | position_encoding: 16 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 17 | num_pos_feats: 256 18 | normalize: true 19 | scale: null 20 | temperature: 10000 21 | d_model: 256 22 | backbone_channel_list: [896, 448, 224, 112] 23 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features 24 | fpn_interp_model: nearest 25 | 26 | memory_attention: 27 | _target_: sam2.modeling.memory_attention.MemoryAttention 28 | d_model: 256 29 | pos_enc_at_input: true 30 | layer: 31 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer 32 | activation: relu 33 | dim_feedforward: 2048 34 | dropout: 0.1 35 | pos_enc_at_attn: false 36 | self_attention: 37 | _target_: sam2.modeling.sam.transformer.RoPEAttention 38 | rope_theta: 10000.0 39 | feat_sizes: [32, 32] 40 | embedding_dim: 256 41 | num_heads: 1 42 | downsample_rate: 1 43 | dropout: 0.1 44 | d_model: 256 45 | pos_enc_at_cross_attn_keys: true 46 | pos_enc_at_cross_attn_queries: false 47 | cross_attention: 48 | _target_: sam2.modeling.sam.transformer.RoPEAttention 49 | rope_theta: 10000.0 50 | feat_sizes: [32, 32] 51 | rope_k_repeat: True 52 | embedding_dim: 256 53 | num_heads: 1 54 | downsample_rate: 1 55 | dropout: 0.1 56 | kv_in_dim: 64 57 | num_layers: 4 58 | 59 | memory_encoder: 60 | _target_: sam2.modeling.memory_encoder.MemoryEncoder 61 | out_dim: 64 62 | position_encoding: 63 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 64 | num_pos_feats: 64 65 | normalize: true 66 | scale: null 67 | temperature: 10000 68 | mask_downsampler: 69 | _target_: sam2.modeling.memory_encoder.MaskDownSampler 70 | kernel_size: 3 71 | stride: 2 72 | padding: 1 73 | fuser: 74 | _target_: sam2.modeling.memory_encoder.Fuser 75 | layer: 76 | _target_: sam2.modeling.memory_encoder.CXBlock 77 | dim: 256 78 | kernel_size: 7 79 | padding: 3 80 | layer_scale_init_value: 1e-6 81 | use_dwconv: True # depth-wise convs 82 | num_layers: 2 83 | 84 | num_maskmem: 7 85 | image_size: 1024 86 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask 87 | sigmoid_scale_for_mem_enc: 20.0 88 | sigmoid_bias_for_mem_enc: -10.0 89 | use_mask_input_as_output_without_sam: true 90 | # Memory 91 | directly_add_no_mem_embed: true 92 | # use high-resolution feature map in the SAM mask decoder 93 | use_high_res_features_in_sam: true 94 | # output 3 masks on the first click on initial conditioning frames 95 | multimask_output_in_sam: true 96 | # SAM heads 97 | iou_prediction_use_sigmoid: True 98 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder 99 | use_obj_ptrs_in_encoder: true 100 | add_tpos_enc_to_obj_ptrs: false 101 | only_obj_ptrs_in_the_past_for_eval: true 102 | # object occlusion prediction 103 | pred_obj_scores: true 104 | pred_obj_scores_mlp: true 105 | fixed_no_obj_ptr: true 106 | # multimask tracking settings 107 | multimask_output_for_tracking: true 108 | 
use_multimask_token_for_obj_ptr: true 109 | multimask_min_pt_num: 0 110 | multimask_max_pt_num: 1 111 | use_mlp_for_obj_ptr_proj: true 112 | # Compilation flag 113 | compile_image_encoder: False 114 | -------------------------------------------------------------------------------- /sam2_configs/sam2_hiera_l.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Model 4 | model: 5 | _target_: sam2.modeling.sam2_base.SAM2Base 6 | image_encoder: 7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder 8 | scalp: 1 9 | trunk: 10 | _target_: sam2.modeling.backbones.hieradet.Hiera 11 | embed_dim: 144 12 | num_heads: 2 13 | stages: [2, 6, 36, 4] 14 | global_att_blocks: [23, 33, 43] 15 | window_pos_embed_bkg_spatial_size: [7, 7] 16 | window_spec: [8, 4, 16, 8] 17 | neck: 18 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck 19 | position_encoding: 20 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 21 | num_pos_feats: 256 22 | normalize: true 23 | scale: null 24 | temperature: 10000 25 | d_model: 256 26 | backbone_channel_list: [1152, 576, 288, 144] 27 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features 28 | fpn_interp_model: nearest 29 | 30 | memory_attention: 31 | _target_: sam2.modeling.memory_attention.MemoryAttention 32 | d_model: 256 33 | pos_enc_at_input: true 34 | layer: 35 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer 36 | activation: relu 37 | dim_feedforward: 2048 38 | dropout: 0.1 39 | pos_enc_at_attn: false 40 | self_attention: 41 | _target_: sam2.modeling.sam.transformer.RoPEAttention 42 | rope_theta: 10000.0 43 | feat_sizes: [32, 32] 44 | embedding_dim: 256 45 | num_heads: 1 46 | downsample_rate: 1 47 | dropout: 0.1 48 | d_model: 256 49 | pos_enc_at_cross_attn_keys: true 50 | pos_enc_at_cross_attn_queries: false 51 | cross_attention: 52 | _target_: sam2.modeling.sam.transformer.RoPEAttention 53 | rope_theta: 10000.0 54 | feat_sizes: [32, 32] 55 | rope_k_repeat: True 56 | embedding_dim: 256 57 | num_heads: 1 58 | downsample_rate: 1 59 | dropout: 0.1 60 | kv_in_dim: 64 61 | num_layers: 4 62 | 63 | memory_encoder: 64 | _target_: sam2.modeling.memory_encoder.MemoryEncoder 65 | out_dim: 64 66 | position_encoding: 67 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 68 | num_pos_feats: 64 69 | normalize: true 70 | scale: null 71 | temperature: 10000 72 | mask_downsampler: 73 | _target_: sam2.modeling.memory_encoder.MaskDownSampler 74 | kernel_size: 3 75 | stride: 2 76 | padding: 1 77 | fuser: 78 | _target_: sam2.modeling.memory_encoder.Fuser 79 | layer: 80 | _target_: sam2.modeling.memory_encoder.CXBlock 81 | dim: 256 82 | kernel_size: 7 83 | padding: 3 84 | layer_scale_init_value: 1e-6 85 | use_dwconv: True # depth-wise convs 86 | num_layers: 2 87 | 88 | num_maskmem: 7 89 | image_size: 1024 90 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask 91 | sigmoid_scale_for_mem_enc: 20.0 92 | sigmoid_bias_for_mem_enc: -10.0 93 | use_mask_input_as_output_without_sam: true 94 | # Memory 95 | directly_add_no_mem_embed: true 96 | # use high-resolution feature map in the SAM mask decoder 97 | use_high_res_features_in_sam: true 98 | # output 3 masks on the first click on initial conditioning frames 99 | multimask_output_in_sam: true 100 | # SAM heads 101 | iou_prediction_use_sigmoid: True 102 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the 
encoder 103 | use_obj_ptrs_in_encoder: true 104 | add_tpos_enc_to_obj_ptrs: false 105 | only_obj_ptrs_in_the_past_for_eval: true 106 | # object occlusion prediction 107 | pred_obj_scores: true 108 | pred_obj_scores_mlp: true 109 | fixed_no_obj_ptr: true 110 | # multimask tracking settings 111 | multimask_output_for_tracking: true 112 | use_multimask_token_for_obj_ptr: true 113 | multimask_min_pt_num: 0 114 | multimask_max_pt_num: 1 115 | use_mlp_for_obj_ptr_proj: true 116 | # Compilation flag 117 | compile_image_encoder: False 118 | -------------------------------------------------------------------------------- /sam2_configs/sam2_hiera_s.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Model 4 | model: 5 | _target_: sam2.modeling.sam2_base.SAM2Base 6 | image_encoder: 7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder 8 | scalp: 1 9 | trunk: 10 | _target_: sam2.modeling.backbones.hieradet.Hiera 11 | embed_dim: 96 12 | num_heads: 1 13 | stages: [1, 2, 11, 2] 14 | global_att_blocks: [7, 10, 13] 15 | window_pos_embed_bkg_spatial_size: [7, 7] 16 | neck: 17 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck 18 | position_encoding: 19 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 20 | num_pos_feats: 256 21 | normalize: true 22 | scale: null 23 | temperature: 10000 24 | d_model: 256 25 | backbone_channel_list: [768, 384, 192, 96] 26 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features 27 | fpn_interp_model: nearest 28 | 29 | memory_attention: 30 | _target_: sam2.modeling.memory_attention.MemoryAttention 31 | d_model: 256 32 | pos_enc_at_input: true 33 | layer: 34 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer 35 | activation: relu 36 | dim_feedforward: 2048 37 | dropout: 0.1 38 | pos_enc_at_attn: false 39 | self_attention: 40 | _target_: sam2.modeling.sam.transformer.RoPEAttention 41 | rope_theta: 10000.0 42 | feat_sizes: [32, 32] 43 | embedding_dim: 256 44 | num_heads: 1 45 | downsample_rate: 1 46 | dropout: 0.1 47 | d_model: 256 48 | pos_enc_at_cross_attn_keys: true 49 | pos_enc_at_cross_attn_queries: false 50 | cross_attention: 51 | _target_: sam2.modeling.sam.transformer.RoPEAttention 52 | rope_theta: 10000.0 53 | feat_sizes: [32, 32] 54 | rope_k_repeat: True 55 | embedding_dim: 256 56 | num_heads: 1 57 | downsample_rate: 1 58 | dropout: 0.1 59 | kv_in_dim: 64 60 | num_layers: 4 61 | 62 | memory_encoder: 63 | _target_: sam2.modeling.memory_encoder.MemoryEncoder 64 | out_dim: 64 65 | position_encoding: 66 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 67 | num_pos_feats: 64 68 | normalize: true 69 | scale: null 70 | temperature: 10000 71 | mask_downsampler: 72 | _target_: sam2.modeling.memory_encoder.MaskDownSampler 73 | kernel_size: 3 74 | stride: 2 75 | padding: 1 76 | fuser: 77 | _target_: sam2.modeling.memory_encoder.Fuser 78 | layer: 79 | _target_: sam2.modeling.memory_encoder.CXBlock 80 | dim: 256 81 | kernel_size: 7 82 | padding: 3 83 | layer_scale_init_value: 1e-6 84 | use_dwconv: True # depth-wise convs 85 | num_layers: 2 86 | 87 | num_maskmem: 7 88 | image_size: 1024 89 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask 90 | sigmoid_scale_for_mem_enc: 20.0 91 | sigmoid_bias_for_mem_enc: -10.0 92 | use_mask_input_as_output_without_sam: true 93 | # Memory 94 | directly_add_no_mem_embed: true 95 | # use high-resolution feature map in the SAM mask decoder 
96 | use_high_res_features_in_sam: true 97 | # output 3 masks on the first click on initial conditioning frames 98 | multimask_output_in_sam: true 99 | # SAM heads 100 | iou_prediction_use_sigmoid: True 101 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder 102 | use_obj_ptrs_in_encoder: true 103 | add_tpos_enc_to_obj_ptrs: false 104 | only_obj_ptrs_in_the_past_for_eval: true 105 | # object occlusion prediction 106 | pred_obj_scores: true 107 | pred_obj_scores_mlp: true 108 | fixed_no_obj_ptr: true 109 | # multimask tracking settings 110 | multimask_output_for_tracking: true 111 | use_multimask_token_for_obj_ptr: true 112 | multimask_min_pt_num: 0 113 | multimask_max_pt_num: 1 114 | use_mlp_for_obj_ptr_proj: true 115 | # Compilation flag 116 | compile_image_encoder: False 117 | -------------------------------------------------------------------------------- /sam2_configs/sam2_hiera_t.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Model 4 | model: 5 | _target_: sam2.modeling.sam2_base.SAM2Base 6 | image_encoder: 7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder 8 | scalp: 1 9 | trunk: 10 | _target_: sam2.modeling.backbones.hieradet.Hiera 11 | embed_dim: 96 12 | num_heads: 1 13 | stages: [1, 2, 7, 2] 14 | global_att_blocks: [5, 7, 9] 15 | window_pos_embed_bkg_spatial_size: [7, 7] 16 | neck: 17 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck 18 | position_encoding: 19 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 20 | num_pos_feats: 256 21 | normalize: true 22 | scale: null 23 | temperature: 10000 24 | d_model: 256 25 | backbone_channel_list: [768, 384, 192, 96] 26 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features 27 | fpn_interp_model: nearest 28 | 29 | memory_attention: 30 | _target_: sam2.modeling.memory_attention.MemoryAttention 31 | d_model: 256 32 | pos_enc_at_input: true 33 | layer: 34 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer 35 | activation: relu 36 | dim_feedforward: 2048 37 | dropout: 0.1 38 | pos_enc_at_attn: false 39 | self_attention: 40 | _target_: sam2.modeling.sam.transformer.RoPEAttention 41 | rope_theta: 10000.0 42 | feat_sizes: [32, 32] 43 | embedding_dim: 256 44 | num_heads: 1 45 | downsample_rate: 1 46 | dropout: 0.1 47 | d_model: 256 48 | pos_enc_at_cross_attn_keys: true 49 | pos_enc_at_cross_attn_queries: false 50 | cross_attention: 51 | _target_: sam2.modeling.sam.transformer.RoPEAttention 52 | rope_theta: 10000.0 53 | feat_sizes: [32, 32] 54 | rope_k_repeat: True 55 | embedding_dim: 256 56 | num_heads: 1 57 | downsample_rate: 1 58 | dropout: 0.1 59 | kv_in_dim: 64 60 | num_layers: 4 61 | 62 | memory_encoder: 63 | _target_: sam2.modeling.memory_encoder.MemoryEncoder 64 | out_dim: 64 65 | position_encoding: 66 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 67 | num_pos_feats: 64 68 | normalize: true 69 | scale: null 70 | temperature: 10000 71 | mask_downsampler: 72 | _target_: sam2.modeling.memory_encoder.MaskDownSampler 73 | kernel_size: 3 74 | stride: 2 75 | padding: 1 76 | fuser: 77 | _target_: sam2.modeling.memory_encoder.Fuser 78 | layer: 79 | _target_: sam2.modeling.memory_encoder.CXBlock 80 | dim: 256 81 | kernel_size: 7 82 | padding: 3 83 | layer_scale_init_value: 1e-6 84 | use_dwconv: True # depth-wise convs 85 | num_layers: 2 86 | 87 | num_maskmem: 7 88 | image_size: 1024 89 | # apply scaled sigmoid on mask logits for 
memory encoder, and directly feed input mask as output mask 90 | # SAM decoder 91 | sigmoid_scale_for_mem_enc: 20.0 92 | sigmoid_bias_for_mem_enc: -10.0 93 | use_mask_input_as_output_without_sam: true 94 | # Memory 95 | directly_add_no_mem_embed: true 96 | # use high-resolution feature map in the SAM mask decoder 97 | use_high_res_features_in_sam: true 98 | # output 3 masks on the first click on initial conditioning frames 99 | multimask_output_in_sam: true 100 | # SAM heads 101 | iou_prediction_use_sigmoid: True 102 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder 103 | use_obj_ptrs_in_encoder: true 104 | add_tpos_enc_to_obj_ptrs: false 105 | only_obj_ptrs_in_the_past_for_eval: true 106 | # object occlusion prediction 107 | pred_obj_scores: true 108 | pred_obj_scores_mlp: true 109 | fixed_no_obj_ptr: true 110 | # multimask tracking settings 111 | multimask_output_for_tracking: true 112 | use_multimask_token_for_obj_ptr: true 113 | multimask_min_pt_num: 0 114 | multimask_max_pt_num: 1 115 | use_mlp_for_obj_ptr_proj: true 116 | # Compilation flag 117 | # HieraT does not currently support compilation, should always be set to False 118 | compile_image_encoder: False 119 | -------------------------------------------------------------------------------- /segment_anything_fb/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .build_sam import ( 8 | build_sam, 9 | build_sam_vit_h, 10 | build_sam_vit_l, 11 | build_sam_vit_b, 12 | sam_model_registry, 13 | ) 14 | from .predictor import SamPredictor 15 | from .automatic_mask_generator import SamAutomaticMaskGenerator 16 | 17 | __all__ = [ 18 | "build_sam", 19 | "build_sam_vit_h", 20 | "build_sam_vit_l", 21 | "build_sam_vit_b", 22 | "sam_model_registry", 23 | "SamPredictor", 24 | "SamAutomaticMaskGenerator", 25 | ] 26 | -------------------------------------------------------------------------------- /segment_anything_fb/build_sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
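# A minimal usage sketch for the builders defined below; the checkpoint
# filename is a hypothetical local path and is not shipped with this
# repository:
#
#     from segment_anything_fb import sam_model_registry, SamPredictor
#
#     sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
#     predictor = SamPredictor(sam)
#     # predictor.set_image(...) can then be called with an HxWxC uint8 RGB array.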
6 | 7 | import torch 8 | 9 | from functools import partial 10 | 11 | from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer 12 | 13 | 14 | def build_sam_vit_h(checkpoint=None): 15 | return _build_sam( 16 | encoder_embed_dim=1280, 17 | encoder_depth=32, 18 | encoder_num_heads=16, 19 | encoder_global_attn_indexes=[7, 15, 23, 31], 20 | checkpoint=checkpoint, 21 | ) 22 | 23 | 24 | build_sam = build_sam_vit_h 25 | 26 | 27 | def build_sam_vit_l(checkpoint=None): 28 | return _build_sam( 29 | encoder_embed_dim=1024, 30 | encoder_depth=24, 31 | encoder_num_heads=16, 32 | encoder_global_attn_indexes=[5, 11, 17, 23], 33 | checkpoint=checkpoint, 34 | ) 35 | 36 | 37 | def build_sam_vit_b(checkpoint=None): 38 | return _build_sam( 39 | encoder_embed_dim=768, 40 | encoder_depth=12, 41 | encoder_num_heads=12, 42 | encoder_global_attn_indexes=[2, 5, 8, 11], 43 | checkpoint=checkpoint, 44 | ) 45 | 46 | 47 | sam_model_registry = { 48 | "default": build_sam_vit_h, 49 | "vit_h": build_sam_vit_h, 50 | "vit_l": build_sam_vit_l, 51 | "vit_b": build_sam_vit_b, 52 | } 53 | 54 | 55 | def _build_sam( 56 | encoder_embed_dim, 57 | encoder_depth, 58 | encoder_num_heads, 59 | encoder_global_attn_indexes, 60 | checkpoint=None, 61 | ): 62 | prompt_embed_dim = 256 63 | image_size = 1024 64 | vit_patch_size = 16 65 | image_embedding_size = image_size // vit_patch_size 66 | sam = Sam( 67 | image_encoder=ImageEncoderViT( 68 | depth=encoder_depth, 69 | embed_dim=encoder_embed_dim, 70 | img_size=image_size, 71 | mlp_ratio=4, 72 | norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), 73 | num_heads=encoder_num_heads, 74 | patch_size=vit_patch_size, 75 | qkv_bias=True, 76 | use_rel_pos=True, 77 | global_attn_indexes=encoder_global_attn_indexes, 78 | window_size=14, 79 | out_chans=prompt_embed_dim, 80 | ), 81 | prompt_encoder=PromptEncoder( 82 | embed_dim=prompt_embed_dim, 83 | image_embedding_size=(image_embedding_size, image_embedding_size), 84 | input_image_size=(image_size, image_size), 85 | mask_in_chans=16, 86 | ), 87 | mask_decoder=MaskDecoder( 88 | num_multimask_outputs=3, 89 | transformer=TwoWayTransformer( 90 | depth=2, 91 | embedding_dim=prompt_embed_dim, 92 | mlp_dim=2048, 93 | num_heads=8, 94 | ), 95 | transformer_dim=prompt_embed_dim, 96 | iou_head_depth=3, 97 | iou_head_hidden_dim=256, 98 | ), 99 | pixel_mean=[123.675, 116.28, 103.53], 100 | pixel_std=[58.395, 57.12, 57.375], 101 | ) 102 | sam.eval() 103 | if checkpoint is not None: 104 | with open(checkpoint, "rb") as f: 105 | state_dict = torch.load(f) 106 | sam.load_state_dict(state_dict) 107 | return sam 108 | -------------------------------------------------------------------------------- /segment_anything_fb/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from .sam import Sam 8 | from .image_encoder import ImageEncoderViT 9 | from .mask_decoder import MaskDecoder 10 | from .prompt_encoder import PromptEncoder 11 | from .transformer import TwoWayTransformer 12 | 13 | __all__ = [ 14 | "Sam", 15 | "ImageEncoderViT", 16 | "MaskDecoder", 17 | "PromptEncoder", 18 | "TwoWayTransformer", 19 | ] 20 | -------------------------------------------------------------------------------- /segment_anything_fb/modeling/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | from typing import Type 11 | 12 | 13 | class MLPBlock(nn.Module): 14 | def __init__( 15 | self, 16 | embedding_dim: int, 17 | mlp_dim: int, 18 | act: Type[nn.Module] = nn.GELU, 19 | ) -> None: 20 | super().__init__() 21 | self.lin1 = nn.Linear(embedding_dim, mlp_dim) 22 | self.lin2 = nn.Linear(mlp_dim, embedding_dim) 23 | self.act = act() 24 | 25 | def forward(self, x: torch.Tensor) -> torch.Tensor: 26 | return self.lin2(self.act(self.lin1(x))) 27 | 28 | 29 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa 30 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa 31 | class LayerNorm2d(nn.Module): 32 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None: 33 | super().__init__() 34 | self.weight = nn.Parameter(torch.ones(num_channels)) 35 | self.bias = nn.Parameter(torch.zeros(num_channels)) 36 | self.eps = eps 37 | 38 | def forward(self, x: torch.Tensor) -> torch.Tensor: 39 | u = x.mean(1, keepdim=True) 40 | s = (x - u).pow(2).mean(1, keepdim=True) 41 | x = (x - u) / torch.sqrt(s + self.eps) 42 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 43 | return x 44 | -------------------------------------------------------------------------------- /segment_anything_fb/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /segment_anything_fb/utils/onnx.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch.nn import functional as F 10 | 11 | from typing import Tuple 12 | 13 | from ..modeling import Sam 14 | from .amg import calculate_stability_score 15 | 16 | 17 | class SamOnnxModel(nn.Module): 18 | """ 19 | This model should not be called directly, but is used in ONNX export. 20 | It combines the prompt encoder, mask decoder, and mask postprocessing of Sam, 21 | with some functions modified to enable model tracing. Also supports extra 22 | options controlling what information. See the ONNX export script for details. 
23 | """ 24 | 25 | def __init__( 26 | self, 27 | model: Sam, 28 | return_single_mask: bool, 29 | use_stability_score: bool = False, 30 | return_extra_metrics: bool = False, 31 | ) -> None: 32 | super().__init__() 33 | self.mask_decoder = model.mask_decoder 34 | self.model = model 35 | self.img_size = model.image_encoder.img_size 36 | self.return_single_mask = return_single_mask 37 | self.use_stability_score = use_stability_score 38 | self.stability_score_offset = 1.0 39 | self.return_extra_metrics = return_extra_metrics 40 | 41 | @staticmethod 42 | def resize_longest_image_size( 43 | input_image_size: torch.Tensor, longest_side: int 44 | ) -> torch.Tensor: 45 | input_image_size = input_image_size.to(torch.float32) 46 | scale = longest_side / torch.max(input_image_size) 47 | transformed_size = scale * input_image_size 48 | transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64) 49 | return transformed_size 50 | 51 | def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor: 52 | point_coords = point_coords + 0.5 53 | point_coords = point_coords / self.img_size 54 | point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords) 55 | point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding) 56 | 57 | point_embedding = point_embedding * (point_labels != -1) 58 | point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * ( 59 | point_labels == -1 60 | ) 61 | 62 | for i in range(self.model.prompt_encoder.num_point_embeddings): 63 | point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[ 64 | i 65 | ].weight * (point_labels == i) 66 | 67 | return point_embedding 68 | 69 | def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor: 70 | mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask) 71 | mask_embedding = mask_embedding + ( 72 | 1 - has_mask_input 73 | ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1) 74 | return mask_embedding 75 | 76 | def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor: 77 | masks = F.interpolate( 78 | masks, 79 | size=(self.img_size, self.img_size), 80 | mode="bilinear", 81 | align_corners=False, 82 | ) 83 | 84 | prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(torch.int64) 85 | masks = masks[..., : prepadded_size[0], : prepadded_size[1]] # type: ignore 86 | 87 | orig_im_size = orig_im_size.to(torch.int64) 88 | h, w = orig_im_size[0], orig_im_size[1] 89 | masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False) 90 | return masks 91 | 92 | def select_masks( 93 | self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int 94 | ) -> Tuple[torch.Tensor, torch.Tensor]: 95 | # Determine if we should return the multiclick mask or not from the number of points. 96 | # The reweighting is used to avoid control flow. 
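# score_reweight below is [1000, 0, ..., 0]: a large constant on the
# single-mask token (index 0) and zeros on the multimask tokens. Because
# score = iou_preds + (num_points - 2.5) * score_reweight, argmax lands on
# token 0 whenever three or more points are supplied and on the best
# multimask token otherwise, without an if/else branch that ONNX tracing
# could not record.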
97 | score_reweight = torch.tensor( 98 | [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)] 99 | ).to(iou_preds.device) 100 | score = iou_preds + (num_points - 2.5) * score_reweight 101 | best_idx = torch.argmax(score, dim=1) 102 | masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1) 103 | iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1) 104 | 105 | return masks, iou_preds 106 | 107 | @torch.no_grad() 108 | def forward( 109 | self, 110 | image_embeddings: torch.Tensor, 111 | point_coords: torch.Tensor, 112 | point_labels: torch.Tensor, 113 | mask_input: torch.Tensor, 114 | has_mask_input: torch.Tensor, 115 | orig_im_size: torch.Tensor, 116 | ): 117 | sparse_embedding = self._embed_points(point_coords, point_labels) 118 | dense_embedding = self._embed_masks(mask_input, has_mask_input) 119 | 120 | masks, scores = self.model.mask_decoder.predict_masks( 121 | image_embeddings=image_embeddings, 122 | image_pe=self.model.prompt_encoder.get_dense_pe(), 123 | sparse_prompt_embeddings=sparse_embedding, 124 | dense_prompt_embeddings=dense_embedding, 125 | ) 126 | 127 | if self.use_stability_score: 128 | scores = calculate_stability_score( 129 | masks, self.model.mask_threshold, self.stability_score_offset 130 | ) 131 | 132 | if self.return_single_mask: 133 | masks, scores = self.select_masks(masks, scores, point_coords.shape[1]) 134 | 135 | upscaled_masks = self.mask_postprocessing(masks, orig_im_size) 136 | 137 | if self.return_extra_metrics: 138 | stability_scores = calculate_stability_score( 139 | upscaled_masks, self.model.mask_threshold, self.stability_score_offset 140 | ) 141 | areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1) 142 | return upscaled_masks, scores, stability_scores, areas, masks 143 | 144 | return upscaled_masks, scores, masks 145 | -------------------------------------------------------------------------------- /segment_anything_fb/utils/torch_nms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision.ops.boxes import box_iou 3 | 4 | 5 | def nms(bboxes: torch.Tensor, scores: torch.Tensor, iou_threshold: float) -> torch.Tensor: 6 | order = torch.argsort(-scores) 7 | keep = [] 8 | 9 | while order.numel() > 0: 10 | i = order[0] 11 | keep.append(i.item()) 12 | 13 | if order.numel() == 1: 14 | break 15 | 16 | ious = box_iou(bboxes[i].unsqueeze(0), bboxes[order[1:]])[0] 17 | mask = ious <= iou_threshold 18 | order = order[1:][mask] 19 | 20 | return torch.tensor(keep, device=bboxes.device) 21 | -------------------------------------------------------------------------------- /segment_anything_fb/utils/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | from torch.nn import functional as F 10 | from torchvision.transforms.functional import resize, to_pil_image # type: ignore 11 | 12 | from copy import deepcopy 13 | from typing import Tuple 14 | 15 | 16 | class ResizeLongestSide: 17 | """ 18 | Resizes images to the longest side 'target_length', as well as provides 19 | methods for resizing coordinates and boxes. Provides methods for 20 | transforming both numpy array and batched torch tensors. 
21 | """ 22 | 23 | def __init__(self, target_length: int) -> None: 24 | self.target_length = target_length 25 | 26 | def apply_image(self, image: np.ndarray) -> np.ndarray: 27 | """ 28 | Expects a numpy array with shape HxWxC in uint8 format. 29 | """ 30 | target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) 31 | return np.array(resize(to_pil_image(image), target_size)) 32 | 33 | def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 34 | """ 35 | Expects a numpy array of length 2 in the final dimension. Requires the 36 | original image size in (H, W) format. 37 | """ 38 | old_h, old_w = original_size 39 | new_h, new_w = self.get_preprocess_shape( 40 | original_size[0], original_size[1], self.target_length 41 | ) 42 | coords = deepcopy(coords).astype(float) 43 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 44 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 45 | return coords 46 | 47 | def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 48 | """ 49 | Expects a numpy array shape Bx4. Requires the original image size 50 | in (H, W) format. 51 | """ 52 | boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size) 53 | return boxes.reshape(-1, 4) 54 | 55 | def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor: 56 | """ 57 | Expects batched images with shape BxCxHxW and float format. This 58 | transformation may not exactly match apply_image. apply_image is 59 | the transformation expected by the model. 60 | """ 61 | # Expects an image in BCHW format. May not exactly match apply_image. 62 | target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length) 63 | return F.interpolate( 64 | image, target_size, mode="bilinear", align_corners=False, antialias=True 65 | ) 66 | 67 | def apply_coords_torch( 68 | self, coords: torch.Tensor, original_size: Tuple[int, ...] 69 | ) -> torch.Tensor: 70 | """ 71 | Expects a torch tensor with length 2 in the last dimension. Requires the 72 | original image size in (H, W) format. 73 | """ 74 | old_h, old_w = original_size 75 | new_h, new_w = self.get_preprocess_shape( 76 | original_size[0], original_size[1], self.target_length 77 | ) 78 | coords = deepcopy(coords).to(torch.float) 79 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 80 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 81 | return coords 82 | 83 | def apply_boxes_torch( 84 | self, boxes: torch.Tensor, original_size: Tuple[int, ...] 85 | ) -> torch.Tensor: 86 | """ 87 | Expects a torch tensor with shape Bx4. Requires the original image 88 | size in (H, W) format. 89 | """ 90 | boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size) 91 | return boxes.reshape(-1, 4) 92 | 93 | @staticmethod 94 | def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]: 95 | """ 96 | Compute the output size given input size and target long side length. 97 | """ 98 | scale = long_side_length * 1.0 / max(oldh, oldw) 99 | newh, neww = oldh * scale, oldw * scale 100 | neww = int(neww + 0.5) 101 | newh = int(newh + 0.5) 102 | return (newh, neww) 103 | -------------------------------------------------------------------------------- /segment_anything_hq/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .build_sam import ( 8 | build_sam, 9 | build_sam_vit_h, 10 | build_sam_vit_l, 11 | build_sam_vit_b, 12 | sam_model_registry, 13 | ) 14 | from .build_sam_baseline import sam_model_registry_baseline 15 | from .predictor import SamPredictor 16 | from .automatic_mask_generator import SamAutomaticMaskGenerator 17 | 18 | __all__ = [ 19 | "build_sam", 20 | "build_sam_vit_h", 21 | "build_sam_vit_l", 22 | "build_sam_vit_b", 23 | "sam_model_registry", 24 | "sam_model_registry_baseline", 25 | "SamPredictor", 26 | "SamAutomaticMaskGenerator", 27 | ] 28 | -------------------------------------------------------------------------------- /segment_anything_hq/build_sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | 9 | from functools import partial 10 | 11 | from .modeling import ImageEncoderViT, MaskDecoderHQ, PromptEncoder, Sam, TwoWayTransformer 12 | import platform 13 | 14 | 15 | def build_sam_vit_h(checkpoint=None): 16 | return _build_sam( 17 | encoder_embed_dim=1280, 18 | encoder_depth=32, 19 | encoder_num_heads=16, 20 | encoder_global_attn_indexes=[7, 15, 23, 31], 21 | checkpoint=checkpoint, 22 | ) 23 | 24 | 25 | build_sam = build_sam_vit_h 26 | 27 | 28 | def build_sam_vit_l(checkpoint=None): 29 | return _build_sam( 30 | encoder_embed_dim=1024, 31 | encoder_depth=24, 32 | encoder_num_heads=16, 33 | encoder_global_attn_indexes=[5, 11, 17, 23], 34 | checkpoint=checkpoint, 35 | ) 36 | 37 | 38 | def build_sam_vit_b(checkpoint=None): 39 | return _build_sam( 40 | encoder_embed_dim=768, 41 | encoder_depth=12, 42 | encoder_num_heads=12, 43 | encoder_global_attn_indexes=[2, 5, 8, 11], 44 | checkpoint=checkpoint, 45 | ) 46 | 47 | 48 | sam_model_registry = { 49 | "default": build_sam_vit_h, 50 | "vit_h": build_sam_vit_h, 51 | "vit_l": build_sam_vit_l, 52 | "vit_b": build_sam_vit_b, 53 | } 54 | 55 | 56 | def _build_sam( 57 | encoder_embed_dim, 58 | encoder_depth, 59 | encoder_num_heads, 60 | encoder_global_attn_indexes, 61 | checkpoint=None, 62 | ): 63 | prompt_embed_dim = 256 64 | image_size = 1024 65 | vit_patch_size = 16 66 | image_embedding_size = image_size // vit_patch_size 67 | sam = Sam( 68 | image_encoder=ImageEncoderViT( 69 | depth=encoder_depth, 70 | embed_dim=encoder_embed_dim, 71 | img_size=image_size, 72 | mlp_ratio=4, 73 | norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), 74 | num_heads=encoder_num_heads, 75 | patch_size=vit_patch_size, 76 | qkv_bias=True, 77 | use_rel_pos=True, 78 | global_attn_indexes=encoder_global_attn_indexes, 79 | window_size=14, 80 | out_chans=prompt_embed_dim, 81 | ), 82 | prompt_encoder=PromptEncoder( 83 | embed_dim=prompt_embed_dim, 84 | image_embedding_size=(image_embedding_size, image_embedding_size), 85 | input_image_size=(image_size, image_size), 86 | mask_in_chans=16, 87 | ), 88 | mask_decoder=MaskDecoderHQ( 89 | num_multimask_outputs=3, 90 | transformer=TwoWayTransformer( 91 | depth=2, 92 | embedding_dim=prompt_embed_dim, 93 | mlp_dim=2048, 94 | num_heads=8, 95 | ), 96 | transformer_dim=prompt_embed_dim, 97 | iou_head_depth=3, 98 | iou_head_hidden_dim=256, 99 | vit_dim=encoder_embed_dim, 100 | ), 101 | pixel_mean=[123.675, 116.28, 103.53], 102 | 
pixel_std=[58.395, 57.12, 57.375], 103 | ) 104 | sam.eval() 105 | if checkpoint is not None: 106 | with open(checkpoint, "rb") as f: 107 | if platform.system() == "Darwin": 108 | if torch.backends.mps.is_available() and torch.backends.mps.is_built(): 109 | state_dict = torch.load(f, map_location=torch.device("mps")) 110 | else: 111 | state_dict = torch.load(f, map_location=torch.device("cpu")) 112 | else: 113 | if torch.cuda.is_available(): 114 | state_dict = torch.load(f) 115 | else: 116 | state_dict = torch.load(f, map_location=torch.device("cpu")) 117 | # info = sam.load_state_dict(state_dict, strict=False) 118 | # print(info) 119 | sam.load_state_dict(state_dict, strict=False) 120 | for n, p in sam.named_parameters(): 121 | if 'hf_token' not in n and 'hf_mlp' not in n and 'compress_vit_feat' not in n and 'embedding_encoder' not in n and 'embedding_maskfeature' not in n: 122 | p.requires_grad = False 123 | 124 | return sam 125 | -------------------------------------------------------------------------------- /segment_anything_hq/build_sam_baseline.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | 9 | from functools import partial 10 | 11 | from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer 12 | 13 | 14 | def build_sam_vit_h(checkpoint=None): 15 | return _build_sam( 16 | encoder_embed_dim=1280, 17 | encoder_depth=32, 18 | encoder_num_heads=16, 19 | encoder_global_attn_indexes=[7, 15, 23, 31], 20 | checkpoint=checkpoint, 21 | ) 22 | 23 | 24 | build_sam = build_sam_vit_h 25 | 26 | 27 | def build_sam_vit_l(checkpoint=None): 28 | return _build_sam( 29 | encoder_embed_dim=1024, 30 | encoder_depth=24, 31 | encoder_num_heads=16, 32 | encoder_global_attn_indexes=[5, 11, 17, 23], 33 | checkpoint=checkpoint, 34 | ) 35 | 36 | 37 | def build_sam_vit_b(checkpoint=None): 38 | return _build_sam( 39 | encoder_embed_dim=768, 40 | encoder_depth=12, 41 | encoder_num_heads=12, 42 | encoder_global_attn_indexes=[2, 5, 8, 11], 43 | checkpoint=checkpoint, 44 | ) 45 | 46 | 47 | sam_model_registry_baseline = { 48 | "default": build_sam_vit_h, 49 | "vit_h": build_sam_vit_h, 50 | "vit_l": build_sam_vit_l, 51 | "vit_b": build_sam_vit_b, 52 | } 53 | 54 | 55 | def _build_sam( 56 | encoder_embed_dim, 57 | encoder_depth, 58 | encoder_num_heads, 59 | encoder_global_attn_indexes, 60 | checkpoint=None, 61 | ): 62 | prompt_embed_dim = 256 63 | image_size = 1024 64 | vit_patch_size = 16 65 | image_embedding_size = image_size // vit_patch_size 66 | sam = Sam( 67 | image_encoder=ImageEncoderViT( 68 | depth=encoder_depth, 69 | embed_dim=encoder_embed_dim, 70 | img_size=image_size, 71 | mlp_ratio=4, 72 | norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), 73 | num_heads=encoder_num_heads, 74 | patch_size=vit_patch_size, 75 | qkv_bias=True, 76 | use_rel_pos=True, 77 | global_attn_indexes=encoder_global_attn_indexes, 78 | window_size=14, 79 | out_chans=prompt_embed_dim, 80 | ), 81 | prompt_encoder=PromptEncoder( 82 | embed_dim=prompt_embed_dim, 83 | image_embedding_size=(image_embedding_size, image_embedding_size), 84 | input_image_size=(image_size, image_size), 85 | mask_in_chans=16, 86 | ), 87 | mask_decoder=MaskDecoder( 88 | num_multimask_outputs=3, 89 | transformer=TwoWayTransformer( 90 | depth=2, 91 | 
embedding_dim=prompt_embed_dim, 92 | mlp_dim=2048, 93 | num_heads=8, 94 | ), 95 | transformer_dim=prompt_embed_dim, 96 | iou_head_depth=3, 97 | iou_head_hidden_dim=256, 98 | ), 99 | pixel_mean=[123.675, 116.28, 103.53], 100 | pixel_std=[58.395, 57.12, 57.375], 101 | ) 102 | sam.eval() 103 | if checkpoint is not None: 104 | with open(checkpoint, "rb") as f: 105 | state_dict = torch.load(f) 106 | sam.load_state_dict(state_dict) 107 | return sam 108 | -------------------------------------------------------------------------------- /segment_anything_hq/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .sam import Sam 8 | from .image_encoder import ImageEncoderViT 9 | from .mask_decoder_hq import MaskDecoderHQ 10 | from .mask_decoder import MaskDecoder 11 | from .prompt_encoder import PromptEncoder 12 | from .transformer import TwoWayTransformer 13 | 14 | __all__ = [ 15 | "Sam", 16 | "ImageEncoderViT", 17 | "MaskDecoderHQ", 18 | "MaskDecoder", 19 | "PromptEncoder", 20 | "TwoWayTransformer", 21 | ] 22 | -------------------------------------------------------------------------------- /segment_anything_hq/modeling/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | from typing import Type 11 | 12 | 13 | class MLPBlock(nn.Module): 14 | def __init__( 15 | self, 16 | embedding_dim: int, 17 | mlp_dim: int, 18 | act: Type[nn.Module] = nn.GELU, 19 | ) -> None: 20 | super().__init__() 21 | self.lin1 = nn.Linear(embedding_dim, mlp_dim) 22 | self.lin2 = nn.Linear(mlp_dim, embedding_dim) 23 | self.act = act() 24 | 25 | def forward(self, x: torch.Tensor) -> torch.Tensor: 26 | return self.lin2(self.act(self.lin1(x))) 27 | 28 | 29 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa 30 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa 31 | class LayerNorm2d(nn.Module): 32 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None: 33 | super().__init__() 34 | self.weight = nn.Parameter(torch.ones(num_channels)) 35 | self.bias = nn.Parameter(torch.zeros(num_channels)) 36 | self.eps = eps 37 | 38 | def forward(self, x: torch.Tensor) -> torch.Tensor: 39 | u = x.mean(1, keepdim=True) 40 | s = (x - u).pow(2).mean(1, keepdim=True) 41 | x = (x - u) / torch.sqrt(s + self.eps) 42 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 43 | return x 44 | -------------------------------------------------------------------------------- /segment_anything_hq/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
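# Illustrative sketch only: LayerNorm2d from ..modeling.common (defined
# above) normalizes an NCHW tensor across its channel dimension.
#
#     import torch
#     from segment_anything_hq.modeling.common import LayerNorm2d
#
#     ln = LayerNorm2d(num_channels=256)
#     x = torch.randn(1, 256, 64, 64)
#     y = ln(x)  # same shape; each spatial position normalized over its 256 channels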
6 | -------------------------------------------------------------------------------- /segment_anything_hq/utils/torch_nms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision.ops.boxes import box_iou 3 | 4 | 5 | def nms(bboxes: torch.Tensor, scores: torch.Tensor, iou_threshold: float) -> torch.Tensor: 6 | order = torch.argsort(-scores) 7 | keep = [] 8 | 9 | while order.numel() > 0: 10 | i = order[0] 11 | keep.append(i.item()) 12 | 13 | if order.numel() == 1: 14 | break 15 | 16 | ious = box_iou(bboxes[i].unsqueeze(0), bboxes[order[1:]])[0] 17 | mask = ious <= iou_threshold 18 | order = order[1:][mask] 19 | 20 | return torch.tensor(keep, device=bboxes.device) 21 | -------------------------------------------------------------------------------- /segment_anything_hq/utils/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | from torch.nn import functional as F 10 | from torchvision.transforms.functional import resize, to_pil_image # type: ignore 11 | 12 | from copy import deepcopy 13 | from typing import Tuple 14 | 15 | 16 | class ResizeLongestSide: 17 | """ 18 | Resizes images to the longest side 'target_length', as well as provides 19 | methods for resizing coordinates and boxes. Provides methods for 20 | transforming both numpy array and batched torch tensors. 21 | """ 22 | 23 | def __init__(self, target_length: int) -> None: 24 | self.target_length = target_length 25 | 26 | def apply_image(self, image: np.ndarray) -> np.ndarray: 27 | """ 28 | Expects a numpy array with shape HxWxC in uint8 format. 29 | """ 30 | target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) 31 | return np.array(resize(to_pil_image(image), target_size)) 32 | 33 | def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 34 | """ 35 | Expects a numpy array of length 2 in the final dimension. Requires the 36 | original image size in (H, W) format. 37 | """ 38 | old_h, old_w = original_size 39 | new_h, new_w = self.get_preprocess_shape( 40 | original_size[0], original_size[1], self.target_length 41 | ) 42 | coords = deepcopy(coords).astype(float) 43 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 44 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 45 | return coords 46 | 47 | def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 48 | """ 49 | Expects a numpy array shape Bx4. Requires the original image size 50 | in (H, W) format. 51 | """ 52 | boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size) 53 | return boxes.reshape(-1, 4) 54 | 55 | def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor: 56 | """ 57 | Expects batched images with shape BxCxHxW and float format. This 58 | transformation may not exactly match apply_image. apply_image is 59 | the transformation expected by the model. 60 | """ 61 | # Expects an image in BCHW format. May not exactly match apply_image. 
62 | target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length) 63 | return F.interpolate( 64 | image, target_size, mode="bilinear", align_corners=False, antialias=True 65 | ) 66 | 67 | def apply_coords_torch( 68 | self, coords: torch.Tensor, original_size: Tuple[int, ...] 69 | ) -> torch.Tensor: 70 | """ 71 | Expects a torch tensor with length 2 in the last dimension. Requires the 72 | original image size in (H, W) format. 73 | """ 74 | old_h, old_w = original_size 75 | new_h, new_w = self.get_preprocess_shape( 76 | original_size[0], original_size[1], self.target_length 77 | ) 78 | coords = deepcopy(coords).to(torch.float) 79 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 80 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 81 | return coords 82 | 83 | def apply_boxes_torch( 84 | self, boxes: torch.Tensor, original_size: Tuple[int, ...] 85 | ) -> torch.Tensor: 86 | """ 87 | Expects a torch tensor with shape Bx4. Requires the original image 88 | size in (H, W) format. 89 | """ 90 | boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size) 91 | return boxes.reshape(-1, 4) 92 | 93 | @staticmethod 94 | def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]: 95 | """ 96 | Compute the output size given input size and target long side length. 97 | """ 98 | scale = long_side_length * 1.0 / max(oldh, oldw) 99 | newh, neww = oldh * scale, oldw * scale 100 | neww = int(neww + 0.5) 101 | newh = int(newh + 0.5) 102 | return (newh, neww) 103 | --------------------------------------------------------------------------------
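The ResizeLongestSide transform above, shipped identically in segment_anything_fb and segment_anything_hq, prepares images and prompt coordinates for the predictors in those packages. A minimal sketch of driving it directly, using a synthetic placeholder frame rather than a real input image:

import numpy as np
from segment_anything_hq.utils.transforms import ResizeLongestSide

transform = ResizeLongestSide(target_length=1024)
image = np.zeros((768, 1024, 3), dtype=np.uint8)         # placeholder HxWxC uint8 RGB frame
resized = transform.apply_image(image)                    # longest side rescaled to 1024
point = np.array([[100.0, 200.0]])                        # (x, y) in the original frame
scaled = transform.apply_coords(point, image.shape[:2])   # same point in the resized frame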