├── .gitignore
├── LICENSE
├── README.md
├── README_DEV.md
├── fast_sam
│   ├── __init__.py
│   └── fast_sam_wrapper.py
├── ia_check_versions.py
├── ia_config.py
├── ia_devices.py
├── ia_file_manager.py
├── ia_get_dataset_colormap.py
├── ia_logging.py
├── ia_sam_manager.py
├── ia_threading.py
├── ia_ui_gradio.py
├── ia_ui_items.py
├── iasam_app.py
├── images
│   ├── inpaint_anything_explanation_image_1.png
│   ├── inpaint_anything_ui_image_1.png
│   ├── sample_input_image.png
│   ├── sample_mask_image.png
│   └── sample_seg_color_image.png
├── inpalib
│   ├── __init__.py
│   ├── masklib.py
│   └── samlib.py
├── javascript
│   └── inpaint-anything.js
├── lama_cleaner
│   ├── __init__.py
│   ├── benchmark.py
│   ├── const.py
│   ├── file_manager
│   │   ├── __init__.py
│   │   ├── file_manager.py
│   │   ├── storage_backends.py
│   │   └── utils.py
│   ├── helper.py
│   ├── installer.py
│   ├── model
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── controlnet.py
│   │   ├── ddim_sampler.py
│   │   ├── fcf.py
│   │   ├── instruct_pix2pix.py
│   │   ├── lama.py
│   │   ├── ldm.py
│   │   ├── manga.py
│   │   ├── mat.py
│   │   ├── opencv2.py
│   │   ├── paint_by_example.py
│   │   ├── pipeline
│   │   │   ├── __init__.py
│   │   │   └── pipeline_stable_diffusion_controlnet_inpaint.py
│   │   ├── plms_sampler.py
│   │   ├── sd.py
│   │   ├── utils.py
│   │   └── zits.py
│   ├── model_manager.py
│   ├── parse_args.py
│   ├── plugins
│   │   ├── __init__.py
│   │   ├── anime_seg.py
│   │   ├── base_plugin.py
│   │   ├── gfpgan_plugin.py
│   │   ├── gfpganer.py
│   │   ├── gif.py
│   │   ├── interactive_seg.py
│   │   ├── realesrgan.py
│   │   ├── remove_bg.py
│   │   ├── restoreformer.py
│   │   └── segment_anything
│   │       ├── __init__.py
│   │       ├── build_sam.py
│   │       ├── modeling
│   │       │   ├── __init__.py
│   │       │   ├── common.py
│   │       │   ├── image_encoder.py
│   │       │   ├── mask_decoder.py
│   │       │   ├── prompt_encoder.py
│   │       │   ├── sam.py
│   │       │   └── transformer.py
│   │       ├── predictor.py
│   │       └── utils
│   │           ├── __init__.py
│   │           └── transforms.py
│   ├── runtime.py
│   ├── schema.py
│   ├── server.py
│   ├── tests
│   │   ├── __init__.py
│   │   ├── test_controlnet.py
│   │   ├── test_instruct_pix2pix.py
│   │   ├── test_interactive_seg.py
│   │   ├── test_load_img.py
│   │   ├── test_model.py
│   │   ├── test_model_md5.py
│   │   ├── test_paint_by_example.py
│   │   ├── test_plugins.py
│   │   ├── test_save_exif.py
│   │   └── test_sd_model.py
│   └── web_config.py
├── mobile_sam
│   ├── __init__.py
│   ├── automatic_mask_generator.py
│   ├── build_sam.py
│   ├── modeling
│   │   ├── __init__.py
│   │   ├── common.py
│   │   ├── image_encoder.py
│   │   ├── mask_decoder.py
│   │   ├── prompt_encoder.py
│   │   ├── sam.py
│   │   ├── tiny_vit_sam.py
│   │   └── transformer.py
│   ├── predictor.py
│   └── utils
│       ├── __init__.py
│       ├── amg.py
│       ├── onnx.py
│       ├── torch_nms.py
│       └── transforms.py
├── requirements.txt
├── requirements_cu118.txt
├── requirements_mac.txt
├── sam2
│   ├── __init__.py
│   ├── automatic_mask_generator.py
│   ├── build_sam.py
│   ├── csrc
│   │   └── connected_components.cu
│   ├── modeling
│   │   ├── __init__.py
│   │   ├── backbones
│   │   │   ├── __init__.py
│   │   │   ├── hieradet.py
│   │   │   ├── image_encoder.py
│   │   │   └── utils.py
│   │   ├── memory_attention.py
│   │   ├── memory_encoder.py
│   │   ├── position_encoding.py
│   │   ├── sam
│   │   │   ├── __init__.py
│   │   │   ├── mask_decoder.py
│   │   │   ├── prompt_encoder.py
│   │   │   └── transformer.py
│   │   ├── sam2_base.py
│   │   └── sam2_utils.py
│   ├── sam2_image_predictor.py
│   ├── sam2_video_predictor.py
│   └── utils
│       ├── __init__.py
│       ├── amg.py
│       ├── misc.py
│       ├── torch_nms.py
│       └── transforms.py
├── sam2_configs
│   ├── __init__.py
│   ├── sam2_hiera_b+.yaml
│   ├── sam2_hiera_l.yaml
│   ├── sam2_hiera_s.yaml
│   └── sam2_hiera_t.yaml
├── segment_anything_fb
│   ├── __init__.py
│   ├── automatic_mask_generator.py
│   ├── build_sam.py
│   ├── modeling
│   │   ├── __init__.py
│   │   ├── common.py
│   │   ├── image_encoder.py
│   │   ├── mask_decoder.py
│   │   ├── prompt_encoder.py
│   │   ├── sam.py
│   │   └── transformer.py
│   ├── predictor.py
│   └── utils
│       ├── __init__.py
│       ├── amg.py
│       ├── onnx.py
│       ├── torch_nms.py
│       └── transforms.py
└── segment_anything_hq
    ├── __init__.py
    ├── automatic_mask_generator.py
    ├── build_sam.py
    ├── build_sam_baseline.py
    ├── modeling
    │   ├── __init__.py
    │   ├── common.py
    │   ├── image_encoder.py
    │   ├── mask_decoder.py
    │   ├── mask_decoder_hq.py
    │   ├── prompt_encoder.py
    │   ├── sam.py
    │   └── transformer.py
    ├── predictor.py
    └── utils
        ├── __init__.py
        ├── amg.py
        ├── onnx.py
        ├── torch_nms.py
        └── transforms.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pth
2 | *.pt
3 | *.pyc
4 | src/
5 | outputs/
6 | models/
7 | models
8 | .DS_Store
9 | ia_config.ini
10 | .eslintrc
11 | .eslintrc.json
12 | pyproject.toml
13 |
14 | # Byte-compiled / optimized / DLL files
15 | __pycache__/
16 | *.py[cod]
17 | *$py.class
18 |
19 | # C extensions
20 | *.so
21 |
22 | # Distribution / packaging
23 | .Python
24 | build/
25 | develop-eggs/
26 | dist/
27 | downloads/
28 | eggs/
29 | .eggs/
30 | lib/
31 | lib64/
32 | parts/
33 | sdist/
34 | var/
35 | wheels/
36 | pip-wheel-metadata/
37 | share/python-wheels/
38 | *.egg-info/
39 | .installed.cfg
40 | *.egg
41 | MANIFEST
42 |
43 | # PyInstaller
44 | # Usually these files are written by a python script from a template
45 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
46 | *.manifest
47 | *.spec
48 |
49 | # Installer logs
50 | pip-log.txt
51 | pip-delete-this-directory.txt
52 |
53 | # Unit test / coverage reports
54 | htmlcov/
55 | .tox/
56 | .nox/
57 | .coverage
58 | .coverage.*
59 | .cache
60 | nosetests.xml
61 | coverage.xml
62 | *.cover
63 | *.py,cover
64 | .hypothesis/
65 | .pytest_cache/
66 |
67 | # Translations
68 | *.mo
69 | *.pot
70 |
71 | # Django stuff:
72 | *.log
73 | local_settings.py
74 | db.sqlite3
75 | db.sqlite3-journal
76 |
77 | # Flask stuff:
78 | instance/
79 | .webassets-cache
80 |
81 | # Scrapy stuff:
82 | .scrapy
83 |
84 | # Sphinx documentation
85 | docs/_build/
86 |
87 | # PyBuilder
88 | target/
89 |
90 | # Jupyter Notebook
91 | .ipynb_checkpoints
92 |
93 | # IPython
94 | profile_default/
95 | ipython_config.py
96 |
97 | # pyenv
98 | .python-version
99 |
100 | # pipenv
101 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
102 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
103 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
104 | # install all needed dependencies.
105 | #Pipfile.lock
106 |
107 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
108 | __pypackages__/
109 |
110 | # Celery stuff
111 | celerybeat-schedule
112 | celerybeat.pid
113 |
114 | # SageMath parsed files
115 | *.sage.py
116 |
117 | # Environments
118 | .env
119 | .venv
120 | env/
121 | venv/
122 | ENV/
123 | env.bak/
124 | venv.bak/
125 |
126 | # Spyder project settings
127 | .spyderproject
128 | .spyproject
129 |
130 | # Rope project settings
131 | .ropeproject
132 |
133 | # mkdocs documentation
134 | /site
135 |
136 | # mypy
137 | .mypy_cache/
138 | .dmypy.json
139 | dmypy.json
140 |
141 | # Pyre type checker
142 | .pyre/
143 |
--------------------------------------------------------------------------------
/README_DEV.md:
--------------------------------------------------------------------------------
1 | # Usage of Inpaint Anything Library
2 |
3 | ## Introduction
4 |
5 | The `inpalib` library from the `inpaint-anything` package lets you segment images and create masks from sketches drawn in other applications.
6 |
7 | ## Code Breakdown
8 |
9 | ### Imports and Module Initialization
10 |
11 | ```python
12 | import importlib
13 |
14 | import numpy as np
15 | from PIL import Image, ImageDraw
16 |
17 | inpalib = importlib.import_module("inpaint-anything.inpalib")
18 | ```
19 |
20 | ### Fetch Model IDs
21 |
22 | ```python
23 | available_sam_ids = inpalib.get_available_sam_ids()
24 |
25 | use_sam_id = "sam_hq_vit_l.pth"
26 | # assert use_sam_id in available_sam_ids, f"Invalid SAM ID: {use_sam_id}"
27 | ```
28 |
29 | Note: Only models that have already been downloaded via Inpaint Anything are available.
30 |
31 | ### Generate Segments Image
32 |
33 | ```python
34 | input_image = np.array(Image.open("/path/to/image.png"))
35 |
36 | sam_masks = inpalib.generate_sam_masks(input_image, use_sam_id, anime_style_chk=False)
37 | sam_masks = inpalib.sort_masks_by_area(sam_masks)
38 |
39 | seg_color_image = inpalib.create_seg_color_image(input_image, sam_masks)
40 |
41 | Image.fromarray(seg_color_image).save("/path/to/seg_color_image.png")
42 | ```
43 |
44 | ![sample_seg_color_image](images/sample_seg_color_image.png)
45 |
46 | ### Create Mask from Sketch
47 |
48 | ```python
49 | sketch_image = Image.fromarray(np.zeros_like(input_image))
50 |
51 | draw = ImageDraw.Draw(sketch_image)
52 | draw.point((input_image.shape[1] // 2, input_image.shape[0] // 2), fill=(255, 255, 255))
53 |
54 | mask_image = inpalib.create_mask_image(np.array(sketch_image), sam_masks, ignore_black_chk=True)
55 |
56 | Image.fromarray(mask_image).save("/path/to/mask_image.png")
57 | ```
58 |
59 | ![sample_mask_image](images/sample_mask_image.png)
60 |
61 | Note: Ensure you adjust the file paths before executing the code.
62 |
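63 | ### Invert Mask (Optional)
64 |
65 | If the complementary region is needed, the mask can be inverted with the `invert_mask` helper that `inpalib` also exports. A minimal sketch that reuses the `mask_image` created above:
66 |
67 | ```python
68 | inverted_mask_image = inpalib.invert_mask(mask_image)
69 |
70 | Image.fromarray(inverted_mask_image).save("/path/to/inverted_mask_image.png")
71 | ```
72 |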
--------------------------------------------------------------------------------
/fast_sam/__init__.py:
--------------------------------------------------------------------------------
1 | from .fast_sam_wrapper import FastSAM
2 | from .fast_sam_wrapper import FastSamAutomaticMaskGenerator
3 |
4 | fast_sam_model_registry = {
5 | "FastSAM-x": FastSAM,
6 | "FastSAM-s": FastSAM,
7 | }
8 |
9 | __all__ = ["FastSAM", "FastSamAutomaticMaskGenerator", "fast_sam_model_registry"]
10 |
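11 | # Usage sketch (illustrative): build a model from the registry and wrap it with the
12 | # automatic mask generator; "./FastSAM-x.pt" is an assumed local checkpoint path.
13 | #
14 | # model = fast_sam_model_registry["FastSAM-x"]("./FastSAM-x.pt")
15 | # mask_generator = FastSamAutomaticMaskGenerator(
16 | #     model, points_per_batch=64, pred_iou_thresh=0.88, stability_score_thresh=0.95)
17 | # masks = mask_generator.generate(image)  # image: np.ndarray of shape (H, W, 3)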
--------------------------------------------------------------------------------
/fast_sam/fast_sam_wrapper.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | import math
3 | from typing import Any, Dict, List
4 |
5 | import cv2
6 | import numpy as np
7 | import torch
8 | import ultralytics
9 |
10 | if hasattr(ultralytics, "FastSAM"):
11 | from ultralytics import FastSAM as YOLO
12 | else:
13 | from ultralytics import YOLO
14 |
15 |
16 | class FastSAM:
17 | def __init__(
18 | self,
19 | checkpoint: str,
20 | ) -> None:
21 | self.model_path = checkpoint
22 | self.model = YOLO(self.model_path)
23 |
24 | if not hasattr(torch.nn.Upsample, "recompute_scale_factor"):
25 | torch.nn.Upsample.recompute_scale_factor = None
26 |
27 | def to(self, device) -> None:
28 | self.model.to(device)
29 |
30 | @property
31 | def device(self) -> Any:
32 | return self.model.device
33 |
34 | def __call__(self, source=None, stream=False, **kwargs) -> Any:
35 | return self.model(source=source, stream=stream, **kwargs)
36 |
37 |
38 | class FastSamAutomaticMaskGenerator:
39 | def __init__(
40 | self,
41 | model: FastSAM,
42 | points_per_batch: int = None,
43 | pred_iou_thresh: float = None,
44 | stability_score_thresh: float = None,
45 | ) -> None:
46 | self.model = model
47 | self.points_per_batch = points_per_batch
48 | self.pred_iou_thresh = pred_iou_thresh
49 | self.stability_score_thresh = stability_score_thresh
50 | self.conf = 0.25 if stability_score_thresh >= 0.95 else 0.15
51 |
52 | def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
53 | height, width = image.shape[:2]
54 | new_height = math.ceil(height / 32) * 32
55 | new_width = math.ceil(width / 32) * 32
56 | resize_image = cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_CUBIC)
57 |
58 | backup_nn_dict = {}
59 | for key, _ in torch.nn.__dict__.copy().items():
60 | if not inspect.isclass(torch.nn.__dict__.get(key)) and "Norm" in key:
61 | backup_nn_dict[key] = torch.nn.__dict__.pop(key)
62 |
63 | results = self.model(
64 | source=resize_image,
65 | stream=False,
66 | imgsz=max(new_height, new_width),
67 | device=self.model.device,
68 | retina_masks=True,
69 | iou=0.7,
70 | conf=self.conf,
71 | max_det=256)
72 |
73 | for key, value in backup_nn_dict.items():
74 | setattr(torch.nn, key, value)
75 | # assert backup_nn_dict[key] == torch.nn.__dict__[key]
76 |
77 | annotations = results[0].masks.data
78 |
79 | if isinstance(annotations[0], torch.Tensor):
80 | annotations = np.array(annotations.cpu())
81 |
82 | annotations_list = []
83 | for mask in annotations:
84 | mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8))
85 | mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((7, 7), np.uint8))
86 | mask = cv2.resize(mask, (width, height), interpolation=cv2.INTER_AREA)
87 |
88 | annotations_list.append(dict(segmentation=mask.astype(bool)))
89 |
90 | return annotations_list
91 |
--------------------------------------------------------------------------------
/ia_check_versions.py:
--------------------------------------------------------------------------------
1 | from functools import cached_property
2 | from importlib.metadata import version
3 | from importlib.util import find_spec
4 |
5 | import torch
6 | from packaging.version import parse
7 |
8 |
9 | def get_module_version(module_name):
10 | try:
11 | module_version = version(module_name)
12 | except Exception:
13 | module_version = None
14 | return module_version
15 |
16 |
17 | def compare_version(version1, version2):
18 | if not isinstance(version1, str) or not isinstance(version2, str):
19 | return None
20 |
21 | if parse(version1) > parse(version2):
22 | return 1
23 | elif parse(version1) < parse(version2):
24 | return -1
25 | else:
26 | return 0
27 |
28 |
29 | def compare_module_version(module_name, version_string):
30 | module_version = get_module_version(module_name)
31 |
32 | result = compare_version(module_version, version_string)
33 | return result if result is not None else -2
34 |
35 |
36 | class IACheckVersions:
37 | @cached_property
38 | def diffusers_enable_cpu_offload(self):
39 | if (find_spec("diffusers") is not None and compare_module_version("diffusers", "0.15.0") >= 0 and
40 | find_spec("accelerate") is not None and compare_module_version("accelerate", "0.17.0") >= 0 and
41 | torch.cuda.is_available()):
42 | return True
43 | else:
44 | return False
45 |
46 | @cached_property
47 | def torch_mps_is_available(self):
48 | if compare_module_version("torch", "2.0.1") < 0:
49 | if not getattr(torch, "has_mps", False):
50 | return False
51 | try:
52 | torch.zeros(1).to(torch.device("mps"))
53 | return True
54 | except Exception:
55 | return False
56 | else:
57 | return torch.backends.mps.is_available() and torch.backends.mps.is_built()
58 |
59 | @cached_property
60 | def torch_on_amd_rocm(self):
61 | if find_spec("torch") is not None and "rocm" in version("torch"):
62 | return True
63 | else:
64 | return False
65 |
66 | @cached_property
67 | def gradio_version_is_old(self):
68 | if find_spec("gradio") is not None and compare_module_version("gradio", "3.34.0") <= 0:
69 | return True
70 | else:
71 | return False
72 |
73 |
74 | ia_check_versions = IACheckVersions()
75 |
--------------------------------------------------------------------------------
/ia_config.py:
--------------------------------------------------------------------------------
1 | import configparser
2 | # import json
3 | import os
4 | from types import SimpleNamespace
5 |
6 | from ia_ui_items import get_inp_model_ids, get_sam_model_ids
7 |
8 |
9 | class IAConfig:
10 | SECTIONS = SimpleNamespace(
11 | DEFAULT=configparser.DEFAULTSECT,
12 | USER="USER",
13 | )
14 |
15 | KEYS = SimpleNamespace(
16 | SAM_MODEL_ID="sam_model_id",
17 | INP_MODEL_ID="inp_model_id",
18 | )
19 |
20 | PATHS = SimpleNamespace(
21 | INI=os.path.join(os.path.dirname(os.path.realpath(__file__)), "ia_config.ini"),
22 | )
23 |
24 | global_args = {}
25 |
26 | def __init__(self):
27 | self.ids_dict = {}
28 | self.ids_dict[IAConfig.KEYS.SAM_MODEL_ID] = {
29 | "list": get_sam_model_ids(),
30 | "index": 1,
31 | }
32 | self.ids_dict[IAConfig.KEYS.INP_MODEL_ID] = {
33 | "list": get_inp_model_ids(),
34 | "index": 0,
35 | }
36 |
37 |
38 | ia_config = IAConfig()
39 |
40 |
41 | def setup_ia_config_ini():
42 | ia_config_ini = configparser.ConfigParser(defaults={})
43 | if os.path.isfile(IAConfig.PATHS.INI):
44 | ia_config_ini.read(IAConfig.PATHS.INI, encoding="utf-8")
45 |
46 | changed = False
47 | for key, ids_info in ia_config.ids_dict.items():
48 | if not ia_config_ini.has_option(IAConfig.SECTIONS.DEFAULT, key):
49 | if len(ids_info["list"]) > ids_info["index"]:
50 | ia_config_ini[IAConfig.SECTIONS.DEFAULT][key] = ids_info["list"][ids_info["index"]]
51 | changed = True
52 | else:
53 | if len(ids_info["list"]) > ids_info["index"] and ia_config_ini[IAConfig.SECTIONS.DEFAULT][key] != ids_info["list"][ids_info["index"]]:
54 | ia_config_ini[IAConfig.SECTIONS.DEFAULT][key] = ids_info["list"][ids_info["index"]]
55 | changed = True
56 |
57 | if changed:
58 | with open(IAConfig.PATHS.INI, "w", encoding="utf-8") as f:
59 | ia_config_ini.write(f)
60 |
61 |
62 | def get_ia_config(key, section=IAConfig.SECTIONS.DEFAULT):
63 | setup_ia_config_ini()
64 |
65 | ia_config_ini = configparser.ConfigParser(defaults={})
66 | ia_config_ini.read(IAConfig.PATHS.INI, encoding="utf-8")
67 |
68 | if ia_config_ini.has_option(section, key):
69 | return ia_config_ini[section][key]
70 |
71 | section = IAConfig.SECTIONS.DEFAULT
72 | if ia_config_ini.has_option(section, key):
73 | return ia_config_ini[section][key]
74 |
75 | return None
76 |
77 |
78 | def get_ia_config_index(key, section=IAConfig.SECTIONS.DEFAULT):
79 | value = get_ia_config(key, section)
80 |
81 | ids_dict = ia_config.ids_dict
82 | if value is None:
83 | if key in ids_dict.keys():
84 | ids_info = ids_dict[key]
85 | return ids_info["index"]
86 | else:
87 | return 0
88 | else:
89 | if key in ids_dict.keys():
90 | ids_info = ids_dict[key]
91 | return ids_info["list"].index(value) if value in ids_info["list"] else ids_info["index"]
92 | else:
93 | return 0
94 |
95 |
96 | def set_ia_config(key, value, section=IAConfig.SECTIONS.DEFAULT):
97 | setup_ia_config_ini()
98 |
99 | ia_config_ini = configparser.ConfigParser(defaults={})
100 | ia_config_ini.read(IAConfig.PATHS.INI, encoding="utf-8")
101 |
102 | if ia_config_ini.has_option(section, key) and ia_config_ini[section][key] == value:
103 | return
104 |
105 | if section != IAConfig.SECTIONS.DEFAULT and not ia_config_ini.has_section(section):
106 | ia_config_ini[section] = {}
107 |
108 | try:
109 | ia_config_ini[section][key] = value
110 | except Exception:
111 | ia_config_ini[section] = {}
112 | ia_config_ini[section][key] = value
113 |
114 | with open(IAConfig.PATHS.INI, "w", encoding="utf-8") as f:
115 | ia_config_ini.write(f)
116 |
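117 | # Usage sketch (illustrative): persist a model choice to the USER section and read
118 | # back its index in the dropdown list; "sam_vit_b_01ec64.pth" is one of the ids
119 | # returned by get_sam_model_ids().
120 | #
121 | # set_ia_config(IAConfig.KEYS.SAM_MODEL_ID, "sam_vit_b_01ec64.pth", IAConfig.SECTIONS.USER)
122 | # sam_index = get_ia_config_index(IAConfig.KEYS.SAM_MODEL_ID, IAConfig.SECTIONS.USER)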
--------------------------------------------------------------------------------
/ia_devices.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class TorchDevices:
5 | def __init__(self):
6 | self.cpu = torch.device("cpu")
7 | self.device = torch.device("cuda") if torch.cuda.is_available() else self.cpu
8 |
9 |
10 | devices = TorchDevices()
11 |
--------------------------------------------------------------------------------
/ia_file_manager.py:
--------------------------------------------------------------------------------
1 | import os
2 | from datetime import datetime
3 | from huggingface_hub import snapshot_download
4 | from ia_logging import ia_logging
5 |
6 |
7 | class IAFileManager:
8 | DOWNLOAD_COMPLETE = "Download complete"
9 |
10 | def __init__(self) -> None:
11 | self._ia_outputs_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)),
12 | "outputs",
13 | datetime.now().strftime("%Y-%m-%d"))
14 |
15 | self._ia_models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
16 |
17 | @property
18 | def outputs_dir(self) -> str:
19 | """Get inpaint-anything outputs directory.
20 |
21 | Returns:
22 | str: inpaint-anything outputs directory
23 | """
24 | if not os.path.isdir(self._ia_outputs_dir):
25 | os.makedirs(self._ia_outputs_dir, exist_ok=True)
26 | return self._ia_outputs_dir
27 |
28 | @property
29 | def models_dir(self) -> str:
30 | """Get inpaint-anything models directory.
31 |
32 | Returns:
33 | str: inpaint-anything models directory
34 | """
35 | if not os.path.isdir(self._ia_models_dir):
36 | os.makedirs(self._ia_models_dir, exist_ok=True)
37 | return self._ia_models_dir
38 |
39 | @property
40 | def savename_prefix(self) -> str:
41 | """Get inpaint-anything savename prefix.
42 |
43 | Returns:
44 | str: inpaint-anything savename prefix
45 | """
46 | return datetime.now().strftime("%Y%m%d-%H%M%S")
47 |
48 |
49 | ia_file_manager = IAFileManager()
50 |
51 |
52 | def download_model_from_hf(hf_model_id, local_files_only=False):
53 | """Download model from HuggingFace Hub.
54 |
55 | Args:
56 | hf_model_id (str): HuggingFace model id
57 | local_files_only (bool, optional): If True, use only local files. Defaults to False.
58 |
59 | Returns:
60 | str: download status
61 | """
62 | if not local_files_only:
63 | ia_logging.info(f"Downloading {hf_model_id}")
64 | try:
65 | snapshot_download(repo_id=hf_model_id, local_files_only=local_files_only)
66 | except FileNotFoundError:
67 | return f"{hf_model_id} not found, please download"
68 | except Exception as e:
69 | return str(e)
70 |
71 | return IAFileManager.DOWNLOAD_COMPLETE
72 |
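73 | # Usage sketch (illustrative): check the local Hugging Face cache first, then fall
74 | # back to downloading; the repo id is one of those listed in ia_ui_items.py.
75 | #
76 | # status = download_model_from_hf("Uminosachi/dreamshaper_8Inpainting", local_files_only=True)
77 | # if status != IAFileManager.DOWNLOAD_COMPLETE:
78 | #     status = download_model_from_hf("Uminosachi/dreamshaper_8Inpainting")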
--------------------------------------------------------------------------------
/ia_logging.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import warnings
3 |
4 | warnings.filterwarnings(action="ignore", category=FutureWarning, module="transformers")
5 | warnings.filterwarnings(action="ignore", category=FutureWarning, module="huggingface_hub")
6 | warnings.filterwarnings(action="ignore", category=FutureWarning, module="timm")
7 |
8 | ia_logging = logging.getLogger("Inpaint Anything")
9 | ia_logging.setLevel(logging.INFO)
10 | ia_logging.propagate = False
11 |
12 | ia_logging_sh = logging.StreamHandler()
13 | ia_logging_sh.setFormatter(logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s"))
14 | ia_logging_sh.setLevel(logging.INFO)
15 | ia_logging.addHandler(ia_logging_sh)
16 |
--------------------------------------------------------------------------------
/ia_threading.py:
--------------------------------------------------------------------------------
1 | import gc
2 | import inspect
3 | import threading
4 | from functools import wraps
5 |
6 | import torch
7 |
8 | from ia_check_versions import ia_check_versions
9 |
10 | model_access_sem = threading.Semaphore(1)
11 |
12 |
13 | def torch_gc():
14 | if torch.cuda.is_available():
15 | torch.cuda.empty_cache()
16 | torch.cuda.ipc_collect()
17 | if ia_check_versions.torch_mps_is_available:
18 | if hasattr(torch, "mps") and hasattr(torch.mps, "empty_cache"):
19 | torch.mps.empty_cache()
20 |
21 |
22 | def clear_cache():
23 | gc.collect()
24 | torch_gc()
25 |
26 |
27 | def post_clear_cache(sem):
28 | with sem:
29 | gc.collect()
30 | torch_gc()
31 |
32 |
33 | def async_post_clear_cache():
34 | thread = threading.Thread(target=post_clear_cache, args=(model_access_sem,))
35 | thread.start()
36 |
37 |
38 | def clear_cache_decorator(func):
39 | @wraps(func)
40 | def yield_wrapper(*args, **kwargs):
41 | clear_cache()
42 | yield from func(*args, **kwargs)
43 | clear_cache()
44 |
45 | @wraps(func)
46 | def wrapper(*args, **kwargs):
47 | clear_cache()
48 | res = func(*args, **kwargs)
49 | clear_cache()
50 | return res
51 |
52 | if inspect.isgeneratorfunction(func):
53 | return yield_wrapper
54 | else:
55 | return wrapper
56 |
--------------------------------------------------------------------------------
/ia_ui_gradio.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import gradio as gr
4 |
5 | GradioTemplateResponseOriginal = gr.routes.templates.TemplateResponse
6 |
7 |
8 | def webpath(fn):
9 | web_path = os.path.realpath(fn)
10 |
11 | return f'file={web_path}?{os.path.getmtime(fn)}'
12 |
13 |
14 | def javascript_html():
15 | script_path = os.path.join(os.path.dirname(__file__), "javascript", "inpaint-anything.js")
16 | head = f'<script type="text/javascript" src="{webpath(script_path)}"></script>\n'
17 |
18 | return head
19 |
20 |
21 | def reload_javascript():
22 | js = javascript_html()
23 |
24 | def template_response(*args, **kwargs):
25 | res = GradioTemplateResponseOriginal(*args, **kwargs)
26 | res.body = res.body.replace(b'</head>', f'{js}</head>'.encode("utf8"))
27 | res.init_headers()
28 | return res
29 |
30 | gr.routes.templates.TemplateResponse = template_response
31 |
--------------------------------------------------------------------------------
/ia_ui_items.py:
--------------------------------------------------------------------------------
1 | from huggingface_hub import scan_cache_dir
2 |
3 |
4 | def get_sampler_names():
5 | """Get sampler name list.
6 |
7 | Returns:
8 | list: sampler name list
9 | """
10 | sampler_names = [
11 | "DDIM",
12 | "Euler",
13 | "Euler a",
14 | "DPM2 Karras",
15 | "DPM2 a Karras",
16 | ]
17 | return sampler_names
18 |
19 |
20 | def get_sam_model_ids():
21 | """Get SAM model ids list.
22 |
23 | Returns:
24 | list: SAM model ids list
25 | """
26 | sam_model_ids = [
27 | "sam2_hiera_large.pt",
28 | "sam2_hiera_base_plus.pt",
29 | "sam2_hiera_small.pt",
30 | "sam2_hiera_tiny.pt",
31 | "sam_vit_h_4b8939.pth",
32 | "sam_vit_l_0b3195.pth",
33 | "sam_vit_b_01ec64.pth",
34 | "sam_hq_vit_h.pth",
35 | "sam_hq_vit_l.pth",
36 | "sam_hq_vit_b.pth",
37 | "FastSAM-x.pt",
38 | "FastSAM-s.pt",
39 | "mobile_sam.pt",
40 | ]
41 | return sam_model_ids
42 |
43 |
44 | inp_list_from_cache = None
45 |
46 |
47 | def get_inp_model_ids():
48 | """Get inpainting model ids list.
49 |
50 | Returns:
51 | list: model ids list
52 | """
53 | global inp_list_from_cache
54 | model_ids = [
55 | "stabilityai/stable-diffusion-2-inpainting",
56 | "Uminosachi/dreamshaper_8Inpainting",
57 | "Uminosachi/deliberate_v3-inpainting",
58 | "Uminosachi/realisticVisionV51_v51VAE-inpainting",
59 | "Uminosachi/revAnimated_v121Inp-inpainting",
60 | "runwayml/stable-diffusion-inpainting",
61 | ]
62 | if inp_list_from_cache is not None and isinstance(inp_list_from_cache, list):
63 | model_ids.extend(inp_list_from_cache)
64 | return model_ids
65 | try:
66 | hf_cache_info = scan_cache_dir()
67 | inpaint_repos = []
68 | for repo in hf_cache_info.repos:
69 | if repo.repo_type == "model" and "inpaint" in repo.repo_id.lower() and repo.repo_id not in model_ids:
70 | inpaint_repos.append(repo.repo_id)
71 | inp_list_from_cache = sorted(inpaint_repos, reverse=True, key=lambda x: x.split("/")[-1])
72 | model_ids.extend(inp_list_from_cache)
73 | return model_ids
74 | except Exception:
75 | return model_ids
76 |
77 |
78 | def get_cleaner_model_ids():
79 | """Get cleaner model ids list.
80 |
81 | Returns:
82 | list: model ids list
83 | """
84 | model_ids = [
85 | "lama",
86 | "ldm",
87 | "zits",
88 | "mat",
89 | "fcf",
90 | "manga",
91 | ]
92 | return model_ids
93 |
94 |
95 | def get_padding_mode_names():
96 | """Get padding mode name list.
97 |
98 | Returns:
99 | list: padding mode name list
100 | """
101 | padding_mode_names = [
102 | "constant",
103 | "edge",
104 | "reflect",
105 | "mean",
106 | "median",
107 | "maximum",
108 | "minimum",
109 | ]
110 | return padding_mode_names
111 |
--------------------------------------------------------------------------------
/images/inpaint_anything_explanation_image_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Uminosachi/inpaint-anything/c2204930c48a4b2b65bc7835ff527ebb72de5be9/images/inpaint_anything_explanation_image_1.png
--------------------------------------------------------------------------------
/images/inpaint_anything_ui_image_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Uminosachi/inpaint-anything/c2204930c48a4b2b65bc7835ff527ebb72de5be9/images/inpaint_anything_ui_image_1.png
--------------------------------------------------------------------------------
/images/sample_input_image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Uminosachi/inpaint-anything/c2204930c48a4b2b65bc7835ff527ebb72de5be9/images/sample_input_image.png
--------------------------------------------------------------------------------
/images/sample_mask_image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Uminosachi/inpaint-anything/c2204930c48a4b2b65bc7835ff527ebb72de5be9/images/sample_mask_image.png
--------------------------------------------------------------------------------
/images/sample_seg_color_image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Uminosachi/inpaint-anything/c2204930c48a4b2b65bc7835ff527ebb72de5be9/images/sample_seg_color_image.png
--------------------------------------------------------------------------------
/inpalib/__init__.py:
--------------------------------------------------------------------------------
1 | from .masklib import create_mask_image, invert_mask
2 | from .samlib import (create_seg_color_image, generate_sam_masks, get_all_sam_ids,
3 | get_available_sam_ids, get_seg_colormap, insert_mask_to_sam_masks,
4 | sam_file_exists, sam_file_path, sort_masks_by_area)
5 |
6 | __all__ = [
7 | "create_mask_image",
8 | "invert_mask",
9 | "create_seg_color_image",
10 | "generate_sam_masks",
11 | "get_all_sam_ids",
12 | "get_available_sam_ids",
13 | "get_seg_colormap",
14 | "insert_mask_to_sam_masks",
15 | "sam_file_exists",
16 | "sam_file_path",
17 | "sort_masks_by_area",
18 | ]
19 |
--------------------------------------------------------------------------------
/inpalib/masklib.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, List, Union
2 |
3 | import numpy as np
4 | from PIL import Image
5 |
6 |
7 | def invert_mask(mask: np.ndarray) -> np.ndarray:
8 | """Invert mask.
9 |
10 | Args:
11 | mask (np.ndarray): mask
12 |
13 | Returns:
14 | np.ndarray: inverted mask
15 | """
16 | if mask is None or not isinstance(mask, np.ndarray):
17 | raise ValueError("Invalid mask")
18 |
19 | # return np.logical_not(mask.astype(bool)).astype(np.uint8) * 255
20 | return np.invert(mask.astype(np.uint8))
21 |
22 |
23 | def check_inputs_create_mask_image(
24 | mask: Union[np.ndarray, Image.Image],
25 | sam_masks: List[Dict[str, Any]],
26 | ignore_black_chk: bool = True,
27 | ) -> None:
28 | """Check create mask image inputs.
29 |
30 | Args:
31 | mask (Union[np.ndarray, Image.Image]): mask
32 | sam_masks (List[Dict[str, Any]]): SAM masks
33 | ignore_black_chk (bool): ignore black check
34 |
35 | Returns:
36 | None
37 | """
38 | if mask is None or not isinstance(mask, (np.ndarray, Image.Image)):
39 | raise ValueError("Invalid mask")
40 |
41 | if sam_masks is None or not isinstance(sam_masks, list):
42 | raise ValueError("Invalid SAM masks")
43 |
44 | if ignore_black_chk is None or not isinstance(ignore_black_chk, bool):
45 | raise ValueError("Invalid ignore black check")
46 |
47 |
48 | def convert_mask(mask: Union[np.ndarray, Image.Image]) -> np.ndarray:
49 | """Convert mask.
50 |
51 | Args:
52 | mask (Union[np.ndarray, Image.Image]): mask
53 |
54 | Returns:
55 | np.ndarray: converted mask
56 | """
57 | if isinstance(mask, Image.Image):
58 | mask = np.array(mask)
59 |
60 | if mask.ndim == 2:
61 | mask = mask[:, :, np.newaxis]
62 |
63 | if mask.shape[2] != 1:
64 | mask = mask[:, :, 0:1]
65 |
66 | return mask
67 |
68 |
69 | def create_mask_image(
70 | mask: Union[np.ndarray, Image.Image],
71 | sam_masks: List[Dict[str, Any]],
72 | ignore_black_chk: bool = True,
73 | ) -> np.ndarray:
74 | """Create mask image.
75 |
76 | Args:
77 | mask (Union[np.ndarray, Image.Image]): mask
78 | sam_masks (List[Dict[str, Any]]): SAM masks
79 | ignore_black_chk (bool): ignore black check
80 |
81 | Returns:
82 | np.ndarray: mask image
83 | """
84 | check_inputs_create_mask_image(mask, sam_masks, ignore_black_chk)
85 | mask = convert_mask(mask)
86 |
87 | canvas_image = np.zeros(mask.shape, dtype=np.uint8)
88 | mask_region = np.zeros(mask.shape, dtype=np.uint8)
89 | for seg_dict in sam_masks:
90 | seg_mask = np.expand_dims(seg_dict["segmentation"].astype(np.uint8), axis=-1)
91 | canvas_mask = np.logical_not(canvas_image.astype(bool)).astype(np.uint8)
92 | if (seg_mask * canvas_mask * mask).astype(bool).any():
93 | mask_region = mask_region + (seg_mask * canvas_mask)
94 | seg_color = seg_mask * canvas_mask
95 | canvas_image = canvas_image + seg_color
96 |
97 | if not ignore_black_chk:
98 | canvas_mask = np.logical_not(canvas_image.astype(bool)).astype(np.uint8)
99 | if (canvas_mask * mask).astype(bool).any():
100 | mask_region = mask_region + (canvas_mask)
101 |
102 | mask_region = np.tile(mask_region * 255, (1, 1, 3))
103 |
104 | seg_image = mask_region.astype(np.uint8)
105 |
106 | return seg_image
107 |
--------------------------------------------------------------------------------
/lama_cleaner/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
4 |
5 | import warnings # noqa: E402
6 |
7 | warnings.filterwarnings("ignore", category=UserWarning, module="pydantic")
8 | warnings.filterwarnings("ignore", category=UserWarning, module="lama_cleaner")
9 |
10 | from lama_cleaner.parse_args import parse_args # noqa: E402
11 |
12 |
13 | def entry_point():
14 | args = parse_args()
15 | # To make os.environ["XDG_CACHE_HOME"] = args.model_cache_dir work for diffusers
16 | # https://github.com/huggingface/diffusers/blob/be99201a567c1ccd841dc16fb24e88f7f239c187/src/diffusers/utils/constants.py#L18
17 | from lama_cleaner.server import main
18 |
19 | main(args)
20 |
--------------------------------------------------------------------------------
/lama_cleaner/benchmark.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import argparse
4 | import os
5 | import time
6 |
7 | import numpy as np
8 | import nvidia_smi
9 | import psutil
10 | import torch
11 |
12 | from lama_cleaner.model_manager import ModelManager
13 | from lama_cleaner.schema import Config, HDStrategy, SDSampler
14 |
15 | try:
16 | torch._C._jit_override_can_fuse_on_cpu(False)
17 | torch._C._jit_override_can_fuse_on_gpu(False)
18 | torch._C._jit_set_texpr_fuser_enabled(False)
19 | torch._C._jit_set_nvfuser_enabled(False)
20 | except:
21 | pass
22 |
23 | NUM_THREADS = str(4)
24 |
25 | os.environ["OMP_NUM_THREADS"] = NUM_THREADS
26 | os.environ["OPENBLAS_NUM_THREADS"] = NUM_THREADS
27 | os.environ["MKL_NUM_THREADS"] = NUM_THREADS
28 | os.environ["VECLIB_MAXIMUM_THREADS"] = NUM_THREADS
29 | os.environ["NUMEXPR_NUM_THREADS"] = NUM_THREADS
30 | if os.environ.get("CACHE_DIR"):
31 | os.environ["TORCH_HOME"] = os.environ["CACHE_DIR"]
32 |
33 |
34 | def run_model(model, size):
35 | # RGB
36 | image = np.random.randint(0, 256, (size[0], size[1], 3)).astype(np.uint8)
37 | mask = np.random.randint(0, 255, size).astype(np.uint8)
38 |
39 | config = Config(
40 | ldm_steps=2,
41 | hd_strategy=HDStrategy.ORIGINAL,
42 | hd_strategy_crop_margin=128,
43 | hd_strategy_crop_trigger_size=128,
44 | hd_strategy_resize_limit=128,
45 | prompt="a fox is sitting on a bench",
46 | sd_steps=5,
47 | sd_sampler=SDSampler.ddim
48 | )
49 | model(image, mask, config)
50 |
51 |
52 | def benchmark(model, times: int, empty_cache: bool):
53 | sizes = [(512, 512)]
54 |
55 | nvidia_smi.nvmlInit()
56 | device_id = 0
57 | handle = nvidia_smi.nvmlDeviceGetHandleByIndex(device_id)
58 |
59 | def format(metrics):
60 | return f"{np.mean(metrics):.2f} ± {np.std(metrics):.2f}"
61 |
62 | process = psutil.Process(os.getpid())
63 | # For each size, report GPU memory and RAM usage metrics
64 | for size in sizes:
65 | torch.cuda.empty_cache()
66 | time_metrics = []
67 | cpu_metrics = []
68 | memory_metrics = []
69 | gpu_memory_metrics = []
70 | for _ in range(times):
71 | start = time.time()
72 | run_model(model, size)
73 | torch.cuda.synchronize()
74 |
75 | # cpu_metrics.append(process.cpu_percent())
76 | time_metrics.append((time.time() - start) * 1000)
77 | memory_metrics.append(process.memory_info().rss / 1024 / 1024)
78 | gpu_memory_metrics.append(nvidia_smi.nvmlDeviceGetMemoryInfo(handle).used / 1024 / 1024)
79 |
80 | print(f"size: {size}".center(80, "-"))
81 | # print(f"cpu: {format(cpu_metrics)}")
82 | print(f"latency: {format(time_metrics)}ms")
83 | print(f"memory: {format(memory_metrics)} MB")
84 | print(f"gpu memory: {format(gpu_memory_metrics)} MB")
85 |
86 | nvidia_smi.nvmlShutdown()
87 |
88 |
89 | def get_args_parser():
90 | parser = argparse.ArgumentParser()
91 | parser.add_argument("--name")
92 | parser.add_argument("--device", default="cuda", type=str)
93 | parser.add_argument("--times", default=10, type=int)
94 | parser.add_argument("--empty-cache", action="store_true")
95 | return parser.parse_args()
96 |
97 |
98 | if __name__ == "__main__":
99 | args = get_args_parser()
100 | device = torch.device(args.device)
101 | model = ModelManager(
102 | name=args.name,
103 | device=device,
104 | sd_run_local=True,
105 | disable_nsfw=True,
106 | sd_cpu_textencoder=True,
107 | hf_access_token="123"
108 | )
109 | benchmark(model, args.times, args.empty_cache)
110 |
--------------------------------------------------------------------------------
/lama_cleaner/const.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from enum import Enum
4 | from pydantic import BaseModel
5 |
6 |
7 | MPS_SUPPORT_MODELS = [
8 | "instruct_pix2pix",
9 | "sd1.5",
10 | "anything4",
11 | "realisticVision1.4",
12 | "sd2",
13 | "paint_by_example",
14 | "controlnet",
15 | ]
16 |
17 | DEFAULT_MODEL = "lama"
18 | AVAILABLE_MODELS = [
19 | "lama",
20 | "ldm",
21 | "zits",
22 | "mat",
23 | "fcf",
24 | "sd1.5",
25 | "anything4",
26 | "realisticVision1.4",
27 | "cv2",
28 | "manga",
29 | "sd2",
30 | "paint_by_example",
31 | "instruct_pix2pix",
32 | ]
33 | SD15_MODELS = ["sd1.5", "anything4", "realisticVision1.4"]
34 |
35 | AVAILABLE_DEVICES = ["cuda", "cpu", "mps"]
36 | DEFAULT_DEVICE = "cuda"
37 |
38 | NO_HALF_HELP = """
39 | Using full precision model.
40 | If your generated result is always black or green, use this argument. (sd/paint_by_example)
41 | """
42 |
43 | CPU_OFFLOAD_HELP = """
44 | Offloads all models to CPU, significantly reducing vRAM usage. (sd/paint_by_example)
45 | """
46 |
47 | DISABLE_NSFW_HELP = """
48 | Disable NSFW checker. (sd/paint_by_example)
49 | """
50 |
51 | SD_CPU_TEXTENCODER_HELP = """
52 | Run Stable Diffusion text encoder model on CPU to save GPU memory.
53 | """
54 |
55 | SD_CONTROLNET_HELP = """
56 | Run Stable Diffusion inpainting model with ControlNet. You can switch control method in webui.
57 | """
58 | DEFAULT_CONTROLNET_METHOD = "control_v11p_sd15_canny"
59 | SD_CONTROLNET_CHOICES = [
60 | "control_v11p_sd15_canny",
61 | "control_v11p_sd15_openpose",
62 | "control_v11p_sd15_inpaint",
63 | "control_v11f1p_sd15_depth"
64 | ]
65 |
66 | SD_LOCAL_MODEL_HELP = """
67 | Load Stable Diffusion 1.5 model(ckpt/safetensors) from local path.
68 | """
69 |
70 | LOCAL_FILES_ONLY_HELP = """
71 | Use local files only, not connect to Hugging Face server. (sd/paint_by_example)
72 | """
73 |
74 | ENABLE_XFORMERS_HELP = """
75 | Enable xFormers optimizations. Requires xformers package has been installed. See: https://github.com/facebookresearch/xformers (sd/paint_by_example)
76 | """
77 |
78 | DEFAULT_MODEL_DIR = os.getenv(
79 | "XDG_CACHE_HOME", os.path.join(os.path.expanduser("~"), ".cache")
80 | )
81 | MODEL_DIR_HELP = """
82 | Model download directory (set via the XDG_CACHE_HOME environment variable); by default, models are downloaded to ~/.cache
83 | """
84 |
85 | OUTPUT_DIR_HELP = """
86 | Result images will be saved to output directory automatically without confirmation.
87 | """
88 |
89 | INPUT_HELP = """
90 | If the input is an image, it will be loaded by default.
91 | If the input is a directory, you can browse and select an image in the file manager.
92 | """
93 |
94 | GUI_HELP = """
95 | Launch Lama Cleaner as desktop app
96 | """
97 |
98 | NO_GUI_AUTO_CLOSE_HELP = """
99 | Prevent the backend from closing automatically after the GUI window is closed.
100 | """
101 |
102 | QUALITY_HELP = """
103 | Quality of image encoding, 0-100. Default is 95, higher quality will generate larger file size.
104 | """
105 |
106 |
107 | class RealESRGANModelName(str, Enum):
108 | realesr_general_x4v3 = "realesr-general-x4v3"
109 | RealESRGAN_x4plus = "RealESRGAN_x4plus"
110 | RealESRGAN_x4plus_anime_6B = "RealESRGAN_x4plus_anime_6B"
111 |
112 |
113 | RealESRGANModelNameList = [e.value for e in RealESRGANModelName]
114 |
115 | INTERACTIVE_SEG_HELP = "Enable interactive segmentation using Segment Anything."
116 | INTERACTIVE_SEG_MODEL_HELP = "Model size: vit_b < vit_l < vit_h. Bigger model size means better segmentation but slower speed."
117 | AVAILABLE_INTERACTIVE_SEG_MODELS = ["vit_b", "vit_l", "vit_h"]
118 | AVAILABLE_INTERACTIVE_SEG_DEVICES = ["cuda", "cpu", "mps"]
119 | REMOVE_BG_HELP = "Enable remove background. Always run on CPU"
120 | ANIMESEG_HELP = "Enable anime segmentation. Always run on CPU"
121 | REALESRGAN_HELP = "Enable realesrgan super resolution"
122 | REALESRGAN_AVAILABLE_DEVICES = ["cpu", "cuda", "mps"]
123 | GFPGAN_HELP = (
124 | "Enable GFPGAN face restore. To enhance background, use with --enable-realesrgan"
125 | )
126 | GFPGAN_AVAILABLE_DEVICES = ["cpu", "cuda", "mps"]
127 | RESTOREFORMER_HELP = "Enable RestoreFormer face restore. To enhance background, use with --enable-realesrgan"
128 | RESTOREFORMER_AVAILABLE_DEVICES = ["cpu", "cuda", "mps"]
129 | GIF_HELP = "Enable GIF plugin. Make GIF to compare original and cleaned image"
130 |
131 |
132 | class Config(BaseModel):
133 | host: str = "127.0.0.1"
134 | port: int = 8080
135 | model: str = DEFAULT_MODEL
136 | sd_local_model_path: str = None
137 | sd_controlnet: bool = False
138 | sd_controlnet_method: str = DEFAULT_CONTROLNET_METHOD
139 | device: str = DEFAULT_DEVICE
140 | gui: bool = False
141 | no_gui_auto_close: bool = False
142 | no_half: bool = False
143 | cpu_offload: bool = False
144 | disable_nsfw: bool = False
145 | sd_cpu_textencoder: bool = False
146 | enable_xformers: bool = False
147 | local_files_only: bool = False
148 | model_dir: str = DEFAULT_MODEL_DIR
149 | input: str = None
150 | output_dir: str = None
151 | # plugins
152 | enable_interactive_seg: bool = False
153 | interactive_seg_model: str = "vit_l"
154 | interactive_seg_device: str = "cpu"
155 | enable_remove_bg: bool = False
156 | enable_anime_seg: bool = False
157 | enable_realesrgan: bool = False
158 | realesrgan_device: str = "cpu"
159 | realesrgan_model: str = RealESRGANModelName.realesr_general_x4v3.value
160 | realesrgan_no_half: bool = False
161 | enable_gfpgan: bool = False
162 | gfpgan_device: str = "cpu"
163 | enable_restoreformer: bool = False
164 | restoreformer_device: str = "cpu"
165 | enable_gif: bool = False
166 |
167 |
168 | def load_config(installer_config: str):
169 | if os.path.exists(installer_config):
170 | with open(installer_config, "r", encoding="utf-8") as f:
171 | return Config(**json.load(f))
172 | else:
173 | return Config()
174 |
--------------------------------------------------------------------------------
/lama_cleaner/file_manager/__init__.py:
--------------------------------------------------------------------------------
1 | from .file_manager import FileManager
2 |
--------------------------------------------------------------------------------
/lama_cleaner/file_manager/storage_backends.py:
--------------------------------------------------------------------------------
1 | # Copy from https://github.com/silentsokolov/flask-thumbnails/blob/master/flask_thumbnails/storage_backends.py
2 | import errno
3 | import os
4 | from abc import ABC, abstractmethod
5 |
6 |
7 | class BaseStorageBackend(ABC):
8 | def __init__(self, app=None):
9 | self.app = app
10 |
11 | @abstractmethod
12 | def read(self, filepath, mode="rb", **kwargs):
13 | raise NotImplementedError
14 |
15 | @abstractmethod
16 | def exists(self, filepath):
17 | raise NotImplementedError
18 |
19 | @abstractmethod
20 | def save(self, filepath, data):
21 | raise NotImplementedError
22 |
23 |
24 | class FilesystemStorageBackend(BaseStorageBackend):
25 | def read(self, filepath, mode="rb", **kwargs):
26 | with open(filepath, mode) as f: # pylint: disable=unspecified-encoding
27 | return f.read()
28 |
29 | def exists(self, filepath):
30 | return os.path.exists(filepath)
31 |
32 | def save(self, filepath, data):
33 | directory = os.path.dirname(filepath)
34 |
35 | if not os.path.exists(directory):
36 | try:
37 | os.makedirs(directory)
38 | except OSError as e:
39 | if e.errno != errno.EEXIST:
40 | raise
41 |
42 | if not os.path.isdir(directory):
43 | raise IOError("{} is not a directory".format(directory))
44 |
45 | with open(filepath, "wb") as f:
46 | f.write(data)
47 |
--------------------------------------------------------------------------------
/lama_cleaner/file_manager/utils.py:
--------------------------------------------------------------------------------
1 | # Copy from: https://github.com/silentsokolov/flask-thumbnails/blob/master/flask_thumbnails/utils.py
2 | import importlib
3 | import os
4 | from pathlib import Path
5 |
6 | from typing import Union
7 |
8 |
9 | def generate_filename(original_filename, *options):
10 | name, ext = os.path.splitext(original_filename)
11 | for v in options:
12 | if v:
13 | name += "_%s" % v
14 | name += ext
15 |
16 | return name
17 |
18 |
19 | def parse_size(size):
20 | if isinstance(size, int):
21 | # If the size parameter is a single number, assume square aspect.
22 | return [size, size]
23 |
24 | if isinstance(size, (tuple, list)):
25 | if len(size) == 1:
26 | # If a single-value tuple/list is provided, expand it to two elements
27 | return size + type(size)(size)
28 | return size
29 |
30 | try:
31 | thumbnail_size = [int(x) for x in size.lower().split("x", 1)]
32 | except ValueError:
33 | raise ValueError( # pylint: disable=raise-missing-from
34 | "Bad thumbnail size format. Valid format is INTxINT."
35 | )
36 |
37 | if len(thumbnail_size) == 1:
38 | # If the size parameter only contains a single integer, assume square aspect.
39 | thumbnail_size.append(thumbnail_size[0])
40 |
41 | return thumbnail_size
42 |
43 |
44 | def aspect_to_string(size):
45 | if isinstance(size, str):
46 | return size
47 |
48 | return "x".join(map(str, size))
49 |
50 |
51 | IMG_SUFFIX = {'.jpg', '.jpeg', '.png', '.JPG', '.JPEG', '.PNG'}
52 |
53 |
54 | def glob_img(p: Union[Path, str], recursive: bool = False):
55 | p = Path(p)
56 | if p.is_file() and p.suffix in IMG_SUFFIX:
57 | yield p
58 | else:
59 | if recursive:
60 | files = Path(p).glob("**/*.*")
61 | else:
62 | files = Path(p).glob("*.*")
63 |
64 | for it in files:
65 | if it.suffix not in IMG_SUFFIX:
66 | continue
67 | yield it
68 |
--------------------------------------------------------------------------------
/lama_cleaner/installer.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | import sys
3 |
4 |
5 | def install(package):
6 | subprocess.check_call([sys.executable, "-m", "pip", "install", package])
7 |
8 |
9 | def install_plugins_package():
10 | install("rembg")
11 | install("realesrgan")
12 | install("gfpgan")
13 |
--------------------------------------------------------------------------------
/lama_cleaner/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Uminosachi/inpaint-anything/c2204930c48a4b2b65bc7835ff527ebb72de5be9/lama_cleaner/model/__init__.py
--------------------------------------------------------------------------------
/lama_cleaner/model/instruct_pix2pix.py:
--------------------------------------------------------------------------------
1 | import PIL.Image
2 | import cv2
3 | import torch
4 | from loguru import logger
5 |
6 | from lama_cleaner.model.base import DiffusionInpaintModel
7 | from lama_cleaner.model.utils import set_seed
8 | from lama_cleaner.schema import Config
9 |
10 |
11 | class InstructPix2Pix(DiffusionInpaintModel):
12 | name = "instruct_pix2pix"
13 | pad_mod = 8
14 | min_size = 512
15 |
16 | def init_model(self, device: torch.device, **kwargs):
17 | from diffusers import StableDiffusionInstructPix2PixPipeline
18 | fp16 = not kwargs.get('no_half', False)
19 |
20 | model_kwargs = {"local_files_only": kwargs.get('local_files_only', False)}
21 | if kwargs['disable_nsfw'] or kwargs.get('cpu_offload', False):
22 | logger.info("Disable Stable Diffusion Model NSFW checker")
23 | model_kwargs.update(dict(
24 | safety_checker=None,
25 | feature_extractor=None,
26 | requires_safety_checker=False
27 | ))
28 |
29 | use_gpu = device == torch.device('cuda') and torch.cuda.is_available()
30 | torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32
31 | self.model = StableDiffusionInstructPix2PixPipeline.from_pretrained(
32 | "timbrooks/instruct-pix2pix",
33 | revision="fp16" if use_gpu and fp16 else "main",
34 | torch_dtype=torch_dtype,
35 | **model_kwargs
36 | )
37 |
38 | self.model.enable_attention_slicing()
39 | if kwargs.get('enable_xformers', False):
40 | self.model.enable_xformers_memory_efficient_attention()
41 |
42 | if kwargs.get('cpu_offload', False) and use_gpu:
43 | logger.info("Enable sequential cpu offload")
44 | self.model.enable_sequential_cpu_offload(gpu_id=0)
45 | else:
46 | self.model = self.model.to(device)
47 |
48 | def forward(self, image, mask, config: Config):
49 | """Input image and output image have same size
50 | image: [H, W, C] RGB
51 | mask: [H, W, 1] 255 means area to repaint
52 | return: BGR IMAGE
53 | edit = pipe(prompt, image=image, num_inference_steps=20, image_guidance_scale=1.5, guidance_scale=7).images[0]
54 | """
55 | output = self.model(
56 | image=PIL.Image.fromarray(image),
57 | prompt=config.prompt,
58 | negative_prompt=config.negative_prompt,
59 | num_inference_steps=config.p2p_steps,
60 | image_guidance_scale=config.p2p_image_guidance_scale,
61 | guidance_scale=config.p2p_guidance_scale,
62 | output_type="np.array",
63 | generator=torch.manual_seed(config.sd_seed)
64 | ).images[0]
65 |
66 | output = (output * 255).round().astype("uint8")
67 | output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
68 | return output
69 |
70 | #
71 | # def forward_post_process(self, result, image, mask, config):
72 | # if config.sd_match_histograms:
73 | # result = self._match_histograms(result, image[:, :, ::-1], mask)
74 | #
75 | # if config.sd_mask_blur != 0:
76 | # k = 2 * config.sd_mask_blur + 1
77 | # mask = cv2.GaussianBlur(mask, (k, k), 0)
78 | # return result, image, mask
79 |
80 | @staticmethod
81 | def is_downloaded() -> bool:
82 | # model will be downloaded when app start, and can't switch in frontend settings
83 | return True
84 |
--------------------------------------------------------------------------------
/lama_cleaner/model/lama.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import cv2
4 | import numpy as np
5 | import torch
6 |
7 | from lama_cleaner.helper import (
8 | norm_img,
9 | get_cache_path_by_url,
10 | load_jit_model,
11 | )
12 | from lama_cleaner.model.base import InpaintModel
13 | from lama_cleaner.schema import Config
14 |
15 | LAMA_MODEL_URL = os.environ.get(
16 | "LAMA_MODEL_URL",
17 | "https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
18 | )
19 | LAMA_MODEL_MD5 = os.environ.get("LAMA_MODEL_MD5", "e3aa4aaa15225a33ec84f9f4bc47e500")
20 |
21 |
22 | class LaMa(InpaintModel):
23 | name = "lama"
24 | pad_mod = 8
25 |
26 | def init_model(self, device, **kwargs):
27 | self.model = load_jit_model(LAMA_MODEL_URL, device, LAMA_MODEL_MD5).eval()
28 |
29 | @staticmethod
30 | def is_downloaded() -> bool:
31 | return os.path.exists(get_cache_path_by_url(LAMA_MODEL_URL))
32 |
33 | def forward(self, image, mask, config: Config):
34 | """Input image and output image have same size
35 | image: [H, W, C] RGB
36 | mask: [H, W]
37 | return: BGR IMAGE
38 | """
39 | image = norm_img(image)
40 | mask = norm_img(mask)
41 |
42 | mask = (mask > 0) * 1
43 | image = torch.from_numpy(image).unsqueeze(0).to(self.device)
44 | mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)
45 |
46 | inpainted_image = self.model(image, mask)
47 |
48 | cur_res = inpainted_image[0].permute(1, 2, 0).detach().cpu().numpy()
49 | cur_res = np.clip(cur_res * 255, 0, 255).astype("uint8")
50 | cur_res = cv2.cvtColor(cur_res, cv2.COLOR_RGB2BGR)
51 | return cur_res
52 |
--------------------------------------------------------------------------------
/lama_cleaner/model/manga.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 |
4 | import cv2
5 | import numpy as np
6 | import torch
7 | import time
8 | from loguru import logger
9 |
10 | from lama_cleaner.helper import get_cache_path_by_url, load_jit_model
11 | from lama_cleaner.model.base import InpaintModel
12 | from lama_cleaner.schema import Config
13 |
14 |
15 | MANGA_INPAINTOR_MODEL_URL = os.environ.get(
16 | "MANGA_INPAINTOR_MODEL_URL",
17 | "https://github.com/Sanster/models/releases/download/manga/manga_inpaintor.jit",
18 | )
19 | MANGA_INPAINTOR_MODEL_MD5 = os.environ.get(
20 | "MANGA_INPAINTOR_MODEL_MD5", "7d8b269c4613b6b3768af714610da86c"
21 | )
22 |
23 | MANGA_LINE_MODEL_URL = os.environ.get(
24 | "MANGA_LINE_MODEL_URL",
25 | "https://github.com/Sanster/models/releases/download/manga/erika.jit",
26 | )
27 | MANGA_LINE_MODEL_MD5 = os.environ.get(
28 | "MANGA_LINE_MODEL_MD5", "0c926d5a4af8450b0d00bc5b9a095644"
29 | )
30 |
31 |
32 | class Manga(InpaintModel):
33 | name = "manga"
34 | pad_mod = 16
35 |
36 | def init_model(self, device, **kwargs):
37 | self.inpaintor_model = load_jit_model(
38 | MANGA_INPAINTOR_MODEL_URL, device, MANGA_INPAINTOR_MODEL_MD5
39 | )
40 | self.line_model = load_jit_model(
41 | MANGA_LINE_MODEL_URL, device, MANGA_LINE_MODEL_MD5
42 | )
43 | self.seed = 42
44 |
45 | @staticmethod
46 | def is_downloaded() -> bool:
47 | model_paths = [
48 | get_cache_path_by_url(MANGA_INPAINTOR_MODEL_URL),
49 | get_cache_path_by_url(MANGA_LINE_MODEL_URL),
50 | ]
51 | return all([os.path.exists(it) for it in model_paths])
52 |
53 | def forward(self, image, mask, config: Config):
54 | """
55 | image: [H, W, C] RGB
56 | mask: [H, W, 1]
57 | return: BGR IMAGE
58 | """
59 | seed = self.seed
60 | random.seed(seed)
61 | np.random.seed(seed)
62 | torch.manual_seed(seed)
63 | torch.cuda.manual_seed_all(seed)
64 |
65 | gray_img = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
66 | gray_img = torch.from_numpy(
67 | gray_img[np.newaxis, np.newaxis, :, :].astype(np.float32)
68 | ).to(self.device)
69 | start = time.time()
70 | lines = self.line_model(gray_img)
71 | torch.cuda.empty_cache()
72 | lines = torch.clamp(lines, 0, 255)
73 | logger.info(f"erika_model time: {time.time() - start}")
74 |
75 | mask = torch.from_numpy(mask[np.newaxis, :, :, :]).to(self.device)
76 | mask = mask.permute(0, 3, 1, 2)
77 | mask = torch.where(mask > 0.5, 1.0, 0.0)
78 | noise = torch.randn_like(mask)
79 | ones = torch.ones_like(mask)
80 |
81 | gray_img = gray_img / 255 * 2 - 1.0
82 | lines = lines / 255 * 2 - 1.0
83 |
84 | start = time.time()
85 | inpainted_image = self.inpaintor_model(gray_img, lines, mask, noise, ones)
86 | logger.info(f"image_inpaintor_model time: {time.time() - start}")
87 |
88 | cur_res = inpainted_image[0].permute(1, 2, 0).detach().cpu().numpy()
89 | cur_res = (cur_res * 127.5 + 127.5).astype(np.uint8)
90 | cur_res = cv2.cvtColor(cur_res, cv2.COLOR_GRAY2BGR)
91 | return cur_res
92 |
--------------------------------------------------------------------------------
/lama_cleaner/model/opencv2.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | from lama_cleaner.model.base import InpaintModel
3 | from lama_cleaner.schema import Config
4 |
5 | flag_map = {"INPAINT_NS": cv2.INPAINT_NS, "INPAINT_TELEA": cv2.INPAINT_TELEA}
6 |
7 |
8 | class OpenCV2(InpaintModel):
9 | name = "cv2"
10 | pad_mod = 1
11 |
12 | @staticmethod
13 | def is_downloaded() -> bool:
14 | return True
15 |
16 | def forward(self, image, mask, config: Config):
17 | """Input image and output image have same size
18 | image: [H, W, C] RGB
19 | mask: [H, W, 1]
20 | return: BGR IMAGE
21 | """
22 | cur_res = cv2.inpaint(
23 | image[:, :, ::-1],
24 | mask,
25 | inpaintRadius=config.cv2_radius,
26 | flags=flag_map[config.cv2_flag],
27 | )
28 | return cur_res
29 |
--------------------------------------------------------------------------------
/lama_cleaner/model/paint_by_example.py:
--------------------------------------------------------------------------------
1 | import PIL
2 | import PIL.Image
3 | import cv2
4 | import torch
5 | from diffusers import DiffusionPipeline
6 | from loguru import logger
7 |
8 | from lama_cleaner.model.base import DiffusionInpaintModel
9 | from lama_cleaner.model.utils import set_seed
10 | from lama_cleaner.schema import Config
11 |
12 |
13 | class PaintByExample(DiffusionInpaintModel):
14 | name = "paint_by_example"
15 | pad_mod = 8
16 | min_size = 512
17 |
18 | def init_model(self, device: torch.device, **kwargs):
19 | fp16 = not kwargs.get('no_half', False)
20 | use_gpu = device == torch.device('cuda') and torch.cuda.is_available()
21 | torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32
22 | model_kwargs = {"local_files_only": kwargs.get('local_files_only', False)}
23 |
24 | if kwargs['disable_nsfw'] or kwargs.get('cpu_offload', False):
25 | logger.info("Disable Paint By Example Model NSFW checker")
26 | model_kwargs.update(dict(
27 | safety_checker=None,
28 | requires_safety_checker=False
29 | ))
30 |
31 | self.model = DiffusionPipeline.from_pretrained(
32 | "Fantasy-Studio/Paint-by-Example",
33 | torch_dtype=torch_dtype,
34 | **model_kwargs
35 | )
36 |
37 | self.model.enable_attention_slicing()
38 | if kwargs.get('enable_xformers', False):
39 | self.model.enable_xformers_memory_efficient_attention()
40 |
41 | # TODO: gpu_id
42 | if kwargs.get('cpu_offload', False) and use_gpu:
43 | self.model.image_encoder = self.model.image_encoder.to(device)
44 | self.model.enable_sequential_cpu_offload(gpu_id=0)
45 | else:
46 | self.model = self.model.to(device)
47 |
48 | def forward(self, image, mask, config: Config):
49 | """Input image and output image have same size
50 | image: [H, W, C] RGB
51 | mask: [H, W, 1] 255 means area to repaint
52 | return: BGR IMAGE
53 | """
54 | output = self.model(
55 | image=PIL.Image.fromarray(image),
56 | mask_image=PIL.Image.fromarray(mask[:, :, -1], mode="L"),
57 | example_image=config.paint_by_example_example_image,
58 | num_inference_steps=config.paint_by_example_steps,
59 | output_type='np.array',
60 | generator=torch.manual_seed(config.paint_by_example_seed)
61 | ).images[0]
62 |
63 | output = (output * 255).round().astype("uint8")
64 | output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
65 | return output
66 |
67 | def forward_post_process(self, result, image, mask, config):
68 | if config.paint_by_example_match_histograms:
69 | result = self._match_histograms(result, image[:, :, ::-1], mask)
70 |
71 | if config.paint_by_example_mask_blur != 0:
72 | k = 2 * config.paint_by_example_mask_blur + 1
73 | mask = cv2.GaussianBlur(mask, (k, k), 0)
74 | return result, image, mask
75 |
76 | @staticmethod
77 | def is_downloaded() -> bool:
78 |         # the model is downloaded when the app starts and can't be switched in the frontend settings
79 | return True
80 |
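A rough usage sketch for this model through ModelManager, assuming the Paint-by-Example weights can be downloaded from the Hugging Face Hub; file names are placeholders and the required Config fields come from lama_cleaner/schema.py:

import cv2
import torch
from PIL import Image

from lama_cleaner.model_manager import ModelManager
from lama_cleaner.schema import Config, HDStrategy

model = ModelManager(name="paint_by_example", device=torch.device("cpu"), disable_nsfw=True)

image = cv2.cvtColor(cv2.imread("scene.png"), cv2.COLOR_BGR2RGB)   # [H, W, 3] RGB
mask = cv2.imread("mask.png", cv2.IMREAD_GRAYSCALE)[:, :, None]    # [H, W, 1], 255 = repaint

cfg = Config(
    ldm_steps=20,
    hd_strategy=HDStrategy.ORIGINAL,
    hd_strategy_crop_margin=128,
    hd_strategy_crop_trigger_size=800,
    hd_strategy_resize_limit=1280,
    paint_by_example_steps=30,
    paint_by_example_example_image=Image.open("example.jpg"),
)
result_bgr = model(image, mask, cfg)   # models return BGR
cv2.imwrite("result.png", result_bgr)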
--------------------------------------------------------------------------------
/lama_cleaner/model/pipeline/__init__.py:
--------------------------------------------------------------------------------
1 | from .pipeline_stable_diffusion_controlnet_inpaint import (
2 | StableDiffusionControlNetInpaintPipeline,
3 | )
4 |
--------------------------------------------------------------------------------
/lama_cleaner/model_manager.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import gc
3 |
4 | from loguru import logger
5 |
6 | from lama_cleaner.const import SD15_MODELS
7 | from lama_cleaner.helper import switch_mps_device
8 | from lama_cleaner.model.controlnet import ControlNet
9 | from lama_cleaner.model.fcf import FcF
10 | from lama_cleaner.model.lama import LaMa
11 | from lama_cleaner.model.ldm import LDM
12 | from lama_cleaner.model.manga import Manga
13 | from lama_cleaner.model.mat import MAT
14 | from lama_cleaner.model.paint_by_example import PaintByExample
15 | from lama_cleaner.model.instruct_pix2pix import InstructPix2Pix
16 | from lama_cleaner.model.sd import SD15, SD2, Anything4, RealisticVision14
17 | from lama_cleaner.model.utils import torch_gc
18 | from lama_cleaner.model.zits import ZITS
19 | from lama_cleaner.model.opencv2 import OpenCV2
20 | from lama_cleaner.schema import Config
21 |
22 | models = {
23 | "lama": LaMa,
24 | "ldm": LDM,
25 | "zits": ZITS,
26 | "mat": MAT,
27 | "fcf": FcF,
28 | SD15.name: SD15,
29 | Anything4.name: Anything4,
30 | RealisticVision14.name: RealisticVision14,
31 | "cv2": OpenCV2,
32 | "manga": Manga,
33 | "sd2": SD2,
34 | "paint_by_example": PaintByExample,
35 | "instruct_pix2pix": InstructPix2Pix,
36 | }
37 |
38 |
39 | class ModelManager:
40 | def __init__(self, name: str, device: torch.device, **kwargs):
41 | self.name = name
42 | self.device = device
43 | self.kwargs = kwargs
44 | self.model = self.init_model(name, device, **kwargs)
45 |
46 | def init_model(self, name: str, device, **kwargs):
47 | if name in SD15_MODELS and kwargs.get("sd_controlnet", False):
48 | return ControlNet(device, **{**kwargs, "name": name})
49 |
50 | if name in models:
51 | model = models[name](device, **kwargs)
52 | else:
53 | raise NotImplementedError(f"Not supported model: {name}")
54 | return model
55 |
56 | def is_downloaded(self, name: str) -> bool:
57 | if name in models:
58 | return models[name].is_downloaded()
59 | else:
60 | raise NotImplementedError(f"Not supported model: {name}")
61 |
62 | def __call__(self, image, mask, config: Config):
63 | self.switch_controlnet_method(control_method=config.controlnet_method)
64 | return self.model(image, mask, config)
65 |
66 | def switch(self, new_name: str, **kwargs):
67 | if new_name == self.name:
68 | return
69 | try:
70 | if torch.cuda.memory_allocated() > 0:
71 | # Clear current loaded model from memory
72 | torch.cuda.empty_cache()
73 | del self.model
74 | gc.collect()
75 |
76 | self.model = self.init_model(
77 | new_name, switch_mps_device(new_name, self.device), **self.kwargs
78 | )
79 | self.name = new_name
80 | except NotImplementedError as e:
81 | raise e
82 |
83 | def switch_controlnet_method(self, control_method: str):
84 | if not self.kwargs.get("sd_controlnet"):
85 | return
86 | if self.kwargs["sd_controlnet_method"] == control_method:
87 | return
88 | if not hasattr(self.model, "is_local_sd_model"):
89 | return
90 |
91 | if self.model.is_local_sd_model:
92 |             # is_native_control_inpaint means a plain (non-inpainting) SD model was loaded
93 | if (
94 | self.model.is_native_control_inpaint
95 | and control_method != "control_v11p_sd15_inpaint"
96 | ):
97 | raise RuntimeError(
98 | f"--sd-local-model-path load a normal SD model, "
99 | f"to use {control_method} you should load an inpainting SD model"
100 | )
101 | elif (
102 | not self.model.is_native_control_inpaint
103 | and control_method == "control_v11p_sd15_inpaint"
104 | ):
105 | raise RuntimeError(
106 | f"--sd-local-model-path load an inpainting SD model, "
107 | f"to use {control_method} you should load a norml SD model"
108 | )
109 |
110 | del self.model
111 | torch_gc()
112 |
113 | old_method = self.kwargs["sd_controlnet_method"]
114 | self.kwargs["sd_controlnet_method"] = control_method
115 | self.model = self.init_model(
116 | self.name, switch_mps_device(self.name, self.device), **self.kwargs
117 | )
118 | logger.info(f"Switch ControlNet method from {old_method} to {control_method}")
119 |
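A minimal sketch of driving the registry above, assuming the chosen checkpoint can be downloaded on first use and that "lama" needs no extra keyword arguments (the tests pass a fuller set); file names are placeholders:

import cv2
import torch

from lama_cleaner.model_manager import ModelManager
from lama_cleaner.schema import Config, HDStrategy

manager = ModelManager(name="lama", device=torch.device("cpu"))

cfg = Config(
    ldm_steps=20,
    hd_strategy=HDStrategy.ORIGINAL,
    hd_strategy_crop_margin=128,
    hd_strategy_crop_trigger_size=800,
    hd_strategy_resize_limit=1280,
)
rgb = cv2.cvtColor(cv2.imread("photo.png"), cv2.COLOR_BGR2RGB)
mask = cv2.imread("mask.png", cv2.IMREAD_GRAYSCALE)[:, :, None]   # non-zero = area to erase
result_bgr = manager(rgb, mask, cfg)
cv2.imwrite("result.png", result_bgr)

# manager.switch("mat")   # frees the current model and loads another entry from `models`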
--------------------------------------------------------------------------------
/lama_cleaner/plugins/__init__.py:
--------------------------------------------------------------------------------
1 | from .interactive_seg import InteractiveSeg
2 | from .remove_bg import RemoveBG
3 | from .realesrgan import RealESRGANUpscaler
4 | from .gfpgan_plugin import GFPGANPlugin
5 | from .restoreformer import RestoreFormerPlugin
6 | from .gif import MakeGIF
7 | from .anime_seg import AnimeSeg
8 |
--------------------------------------------------------------------------------
/lama_cleaner/plugins/base_plugin.py:
--------------------------------------------------------------------------------
1 | from loguru import logger
2 |
3 |
4 | class BasePlugin:
5 | def __init__(self):
6 | err_msg = self.check_dep()
7 | if err_msg:
8 | logger.error(err_msg)
9 | exit(-1)
10 |
11 | def __call__(self, rgb_np_img, files, form):
12 | ...
13 |
14 | def check_dep(self):
15 | ...
16 |
--------------------------------------------------------------------------------
/lama_cleaner/plugins/gfpgan_plugin.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | from loguru import logger
3 |
4 | from lama_cleaner.helper import download_model
5 | from lama_cleaner.plugins.base_plugin import BasePlugin
6 |
7 |
8 | class GFPGANPlugin(BasePlugin):
9 | name = "GFPGAN"
10 |
11 | def __init__(self, device, upscaler=None):
12 | super().__init__()
13 | from .gfpganer import MyGFPGANer
14 |
15 | url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
16 | model_md5 = "94d735072630ab734561130a47bc44f8"
17 | model_path = download_model(url, model_md5)
18 | logger.info(f"GFPGAN model path: {model_path}")
19 |
20 | import facexlib
21 |
22 | if hasattr(facexlib.detection.retinaface, "device"):
23 | facexlib.detection.retinaface.device = device
24 |
25 | # Use GFPGAN for face enhancement
26 | self.face_enhancer = MyGFPGANer(
27 | model_path=model_path,
28 | upscale=1,
29 | arch="clean",
30 | channel_multiplier=2,
31 | device=device,
32 | bg_upsampler=upscaler.model if upscaler is not None else None,
33 | )
34 | self.face_enhancer.face_helper.face_det.mean_tensor.to(device)
35 | self.face_enhancer.face_helper.face_det = (
36 | self.face_enhancer.face_helper.face_det.to(device)
37 | )
38 |
39 | def __call__(self, rgb_np_img, files, form):
40 | weight = 0.5
41 | bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
42 | logger.info(f"GFPGAN input shape: {bgr_np_img.shape}")
43 | _, _, bgr_output = self.face_enhancer.enhance(
44 | bgr_np_img,
45 | has_aligned=False,
46 | only_center_face=False,
47 | paste_back=True,
48 | weight=weight,
49 | )
50 | logger.info(f"GFPGAN output shape: {bgr_output.shape}")
51 |
52 | # try:
53 | # if scale != 2:
54 | # interpolation = cv2.INTER_AREA if scale < 2 else cv2.INTER_LANCZOS4
55 | # h, w = img.shape[0:2]
56 | # output = cv2.resize(
57 | # output,
58 | # (int(w * scale / 2), int(h * scale / 2)),
59 | # interpolation=interpolation,
60 | # )
61 | # except Exception as error:
62 | # print("wrong scale input.", error)
63 | return bgr_output
64 |
65 | def check_dep(self):
66 | try:
67 | import gfpgan
68 | except ImportError:
69 | return (
70 | "gfpgan is not installed, please install it first. pip install gfpgan"
71 | )
72 |
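A short usage sketch mirroring lama_cleaner/tests/test_plugins.py; the input file name is a placeholder:

import cv2
from lama_cleaner.plugins import GFPGANPlugin

plugin = GFPGANPlugin("cpu")                    # downloads GFPGANv1.4.pth on first run
rgb = cv2.cvtColor(cv2.imread("face.jpg"), cv2.COLOR_BGR2RGB)
restored_bgr = plugin(rgb, None, None)          # plugins take (rgb_np_img, files, form)
cv2.imwrite("face_restored.png", restored_bgr)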
--------------------------------------------------------------------------------
/lama_cleaner/plugins/gfpganer.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | from facexlib.utils.face_restoration_helper import FaceRestoreHelper
5 | from gfpgan import GFPGANv1Clean, GFPGANer
6 | from torch.hub import get_dir
7 |
8 |
9 | class MyGFPGANer(GFPGANer):
10 | """Helper for restoration with GFPGAN.
11 |
12 | It will detect and crop faces, and then resize the faces to 512x512.
13 |     GFPGAN is used to restore the resized faces.
14 |     The background is upsampled with the bg_upsampler.
15 |     Finally, the faces will be pasted back onto the upsampled background image.
16 |
17 | Args:
18 | model_path (str): The path to the GFPGAN model. It can be urls (will first download it automatically).
19 | upscale (float): The upscale of the final output. Default: 2.
20 | arch (str): The GFPGAN architecture. Option: clean | original. Default: clean.
21 | channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
22 | bg_upsampler (nn.Module): The upsampler for the background. Default: None.
23 | """
24 |
25 | def __init__(
26 | self,
27 | model_path,
28 | upscale=2,
29 | arch="clean",
30 | channel_multiplier=2,
31 | bg_upsampler=None,
32 | device=None,
33 | ):
34 | self.upscale = upscale
35 | self.bg_upsampler = bg_upsampler
36 |
37 | # initialize model
38 | self.device = (
39 | torch.device("cuda" if torch.cuda.is_available() else "cpu")
40 | if device is None
41 | else device
42 | )
43 | # initialize the GFP-GAN
44 | if arch == "clean":
45 | self.gfpgan = GFPGANv1Clean(
46 | out_size=512,
47 | num_style_feat=512,
48 | channel_multiplier=channel_multiplier,
49 | decoder_load_path=None,
50 | fix_decoder=False,
51 | num_mlp=8,
52 | input_is_latent=True,
53 | different_w=True,
54 | narrow=1,
55 | sft_half=True,
56 | )
57 | elif arch == "RestoreFormer":
58 | from gfpgan.archs.restoreformer_arch import RestoreFormer
59 |
60 | self.gfpgan = RestoreFormer()
61 |
62 | hub_dir = get_dir()
63 | model_dir = os.path.join(hub_dir, "checkpoints")
64 |
65 | # initialize face helper
66 | self.face_helper = FaceRestoreHelper(
67 | upscale,
68 | face_size=512,
69 | crop_ratio=(1, 1),
70 | det_model="retinaface_resnet50",
71 | save_ext="png",
72 | use_parse=True,
73 | device=self.device,
74 | model_rootpath=model_dir,
75 | )
76 |
77 | loadnet = torch.load(model_path)
78 | if "params_ema" in loadnet:
79 | keyname = "params_ema"
80 | else:
81 | keyname = "params"
82 | self.gfpgan.load_state_dict(loadnet[keyname], strict=True)
83 | self.gfpgan.eval()
84 | self.gfpgan = self.gfpgan.to(self.device)
85 |
--------------------------------------------------------------------------------
/lama_cleaner/plugins/gif.py:
--------------------------------------------------------------------------------
1 | import io
2 | import math
3 |
4 | from PIL import Image, ImageDraw
5 |
6 | from lama_cleaner.helper import load_img
7 | from lama_cleaner.plugins.base_plugin import BasePlugin
8 |
9 |
10 | def keep_ratio_resize(img, size, resample=Image.BILINEAR):
11 | if img.width > img.height:
12 | w = size
13 | h = int(img.height * size / img.width)
14 | else:
15 | h = size
16 | w = int(img.width * size / img.height)
17 | return img.resize((w, h), resample)
18 |
19 |
20 | def cubic_bezier(p1, p2, duration: int, frames: int):
21 | """
22 |     Sample a cubic Bézier curve that runs from (0, 0) to (1, 1).
23 |     Args:
24 |         p1: first control point (x1, y1)
25 |         p2: second control point (x2, y2)
26 |         duration: step size used when sampling t over the frame range
27 |         frames: number of samples taken along the curve
28 |
29 |     Returns:
30 |         A list of (x, y) points along the curve.
31 |     """
32 | x0, y0 = (0, 0)
33 | x1, y1 = p1
34 | x2, y2 = p2
35 | x3, y3 = (1, 1)
36 |
37 | def cal_y(t):
38 | return (
39 | math.pow(1 - t, 3) * y0
40 | + 3 * math.pow(1 - t, 2) * t * y1
41 | + 3 * (1 - t) * math.pow(t, 2) * y2
42 | + math.pow(t, 3) * y3
43 | )
44 |
45 | def cal_x(t):
46 | return (
47 | math.pow(1 - t, 3) * x0
48 | + 3 * math.pow(1 - t, 2) * t * x1
49 | + 3 * (1 - t) * math.pow(t, 2) * x2
50 | + math.pow(t, 3) * x3
51 | )
52 |
53 | res = []
54 | for t in range(0, 1 * frames, duration):
55 | t = t / frames
56 | res.append((cal_x(t), cal_y(t)))
57 |
58 | res.append((1, 0))
59 | return res
60 |
61 |
62 | def make_compare_gif(
63 | clean_img: Image.Image,
64 | src_img: Image.Image,
65 | max_side_length: int = 600,
66 | splitter_width: int = 5,
67 | splitter_color=(255, 203, 0, int(255 * 0.73)),
68 | ):
69 | if clean_img.size != src_img.size:
70 | clean_img = clean_img.resize(src_img.size, Image.BILINEAR)
71 |
72 | duration_per_frame = 20
73 | num_frames = 50
74 | # erase-in-out
75 | cubic_bezier_points = cubic_bezier((0.33, 0), (0.66, 1), 1, num_frames)
76 | cubic_bezier_points.reverse()
77 |
78 | max_side_length = min(max_side_length, max(clean_img.size))
79 |
80 | src_img = keep_ratio_resize(src_img, max_side_length)
81 | clean_img = keep_ratio_resize(clean_img, max_side_length)
82 | width, height = src_img.size
83 |
84 | # Generate images to make Gif from right to left
85 | images = []
86 |
87 | for i in range(num_frames):
88 | new_frame = Image.new("RGB", (width, height))
89 | new_frame.paste(clean_img, (0, 0))
90 |
91 | left = int(cubic_bezier_points[i][0] * width)
92 | cropped_src_img = src_img.crop((left, 0, width, height))
93 | new_frame.paste(cropped_src_img, (left, 0, width, height))
94 | if i != num_frames - 1:
95 | # draw a yellow splitter on the edge of the cropped image
96 | draw = ImageDraw.Draw(new_frame)
97 | draw.line(
98 | [(left, 0), (left, height)], width=splitter_width, fill=splitter_color
99 | )
100 | images.append(new_frame)
101 |
102 | for i in range(30):
103 | images.append(src_img)
104 |
105 | cubic_bezier_points.reverse()
106 | # Generate images to make Gif from left to right
107 | for i in range(num_frames):
108 | new_frame = Image.new("RGB", (width, height))
109 | new_frame.paste(src_img, (0, 0))
110 |
111 | right = int(cubic_bezier_points[i][0] * width)
112 | cropped_src_img = clean_img.crop((0, 0, right, height))
113 | new_frame.paste(cropped_src_img, (0, 0, right, height))
114 | if i != num_frames - 1:
115 | # draw a yellow splitter on the edge of the cropped image
116 | draw = ImageDraw.Draw(new_frame)
117 | draw.line(
118 | [(right, 0), (right, height)], width=splitter_width, fill=splitter_color
119 | )
120 | images.append(new_frame)
121 |
122 | for _ in range(30):
123 | images.append(clean_img)
124 |
125 | img_byte_arr = io.BytesIO()
126 | clean_img.save(
127 | img_byte_arr,
128 | format="GIF",
129 | save_all=True,
130 | include_color_table=True,
131 | append_images=images,
132 | optimize=False,
133 | duration=duration_per_frame,
134 | loop=0,
135 | )
136 | return img_byte_arr.getvalue()
137 |
138 |
139 | class MakeGIF(BasePlugin):
140 | name = "MakeGIF"
141 |
142 | def __call__(self, rgb_np_img, files, form):
143 | origin_image = rgb_np_img
144 | clean_image_bytes = files["clean_img"].read()
145 | clean_image, _ = load_img(clean_image_bytes)
146 | gif_bytes = make_compare_gif(
147 | Image.fromarray(origin_image), Image.fromarray(clean_image)
148 | )
149 | return gif_bytes
150 |
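make_compare_gif can also be called directly; a small sketch with placeholder file names:

from PIL import Image
from lama_cleaner.plugins.gif import make_compare_gif

src = Image.open("original.png").convert("RGB")
clean = Image.open("inpainted.png").convert("RGB")
gif_bytes = make_compare_gif(clean_img=clean, src_img=src)   # sliding before/after animation
with open("compare.gif", "wb") as f:
    f.write(gif_bytes)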
--------------------------------------------------------------------------------
/lama_cleaner/plugins/interactive_seg.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | import cv2
4 | import numpy as np
5 | from loguru import logger
6 |
7 | from lama_cleaner.helper import download_model
8 | from lama_cleaner.plugins.base_plugin import BasePlugin
9 | from lama_cleaner.plugins.segment_anything import SamPredictor, sam_model_registry
10 |
11 | # Models listed from smallest to largest
12 | SEGMENT_ANYTHING_MODELS = {
13 | "vit_b": {
14 | "url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
15 | "md5": "01ec64d29a2fca3f0661936605ae66f8",
16 | },
17 | "vit_l": {
18 | "url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
19 | "md5": "0b3195507c641ddb6910d2bb5adee89c",
20 | },
21 | "vit_h": {
22 | "url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
23 | "md5": "4b8939a88964f0f4ff5f5b2642c598a6",
24 | },
25 | }
26 |
27 |
28 | class InteractiveSeg(BasePlugin):
29 | name = "InteractiveSeg"
30 |
31 | def __init__(self, model_name, device):
32 | super().__init__()
33 | model_path = download_model(
34 | SEGMENT_ANYTHING_MODELS[model_name]["url"],
35 | SEGMENT_ANYTHING_MODELS[model_name]["md5"],
36 | )
37 | logger.info(f"SegmentAnything model path: {model_path}")
38 | self.predictor = SamPredictor(
39 | sam_model_registry[model_name](checkpoint=model_path).to(device)
40 | )
41 | self.prev_img_md5 = None
42 |
43 | def __call__(self, rgb_np_img, files, form):
44 | clicks = json.loads(form["clicks"])
45 | return self.forward(rgb_np_img, clicks, form["img_md5"])
46 |
47 | def forward(self, rgb_np_img, clicks, img_md5):
48 | input_point = []
49 | input_label = []
50 | for click in clicks:
51 | x = click[0]
52 | y = click[1]
53 | input_point.append([x, y])
54 | input_label.append(click[2])
55 |
56 | if img_md5 and img_md5 != self.prev_img_md5:
57 | self.prev_img_md5 = img_md5
58 | self.predictor.set_image(rgb_np_img)
59 |
60 | masks, scores, _ = self.predictor.predict(
61 | point_coords=np.array(input_point),
62 | point_labels=np.array(input_label),
63 | multimask_output=False,
64 | )
65 | mask = masks[0].astype(np.uint8) * 255
66 | # TODO: how to set kernel size?
67 | kernel_size = 9
68 | mask = cv2.dilate(
69 | mask, np.ones((kernel_size, kernel_size), np.uint8), iterations=1
70 | )
71 |         # frontend brush color "ffcc00bb"
72 | res_mask = np.zeros((mask.shape[0], mask.shape[1], 4), dtype=np.uint8)
73 | res_mask[mask == 255] = [255, 203, 0, int(255 * 0.73)]
74 | res_mask = cv2.cvtColor(res_mask, cv2.COLOR_BGRA2RGBA)
75 | return res_mask
76 |
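A usage sketch mirroring the plugin tests: each click is [x, y, label] with label 1 for a positive point; the file name is a placeholder:

import hashlib
import cv2
from lama_cleaner.plugins import InteractiveSeg

seg = InteractiveSeg("vit_b", "cpu")            # downloads the SAM checkpoint on first run
img_bytes = open("photo.png", "rb").read()
rgb = cv2.cvtColor(cv2.imread("photo.png"), cv2.COLOR_BGR2RGB)
rgba_mask = seg.forward(
    rgb,
    clicks=[[rgb.shape[1] // 2, rgb.shape[0] // 2, 1]],
    img_md5=hashlib.md5(img_bytes).hexdigest(),
)
cv2.imwrite("seg_mask.png", rgba_mask)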
--------------------------------------------------------------------------------
/lama_cleaner/plugins/realesrgan.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 |
3 | import cv2
4 | from loguru import logger
5 |
6 | from lama_cleaner.const import RealESRGANModelName
7 | from lama_cleaner.helper import download_model
8 | from lama_cleaner.plugins.base_plugin import BasePlugin
9 |
10 |
11 | class RealESRGANUpscaler(BasePlugin):
12 | name = "RealESRGAN"
13 |
14 | def __init__(self, name, device, no_half=False):
15 | super().__init__()
16 | from basicsr.archs.rrdbnet_arch import RRDBNet
17 | from realesrgan import RealESRGANer
18 | from realesrgan.archs.srvgg_arch import SRVGGNetCompact
19 |
20 | REAL_ESRGAN_MODELS = {
21 | RealESRGANModelName.realesr_general_x4v3: {
22 | "url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
23 | "scale": 4,
24 | "model": lambda: SRVGGNetCompact(
25 | num_in_ch=3,
26 | num_out_ch=3,
27 | num_feat=64,
28 | num_conv=32,
29 | upscale=4,
30 | act_type="prelu",
31 | ),
32 | "model_md5": "91a7644643c884ee00737db24e478156",
33 | },
34 | RealESRGANModelName.RealESRGAN_x4plus: {
35 | "url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
36 | "scale": 4,
37 | "model": lambda: RRDBNet(
38 | num_in_ch=3,
39 | num_out_ch=3,
40 | num_feat=64,
41 | num_block=23,
42 | num_grow_ch=32,
43 | scale=4,
44 | ),
45 | "model_md5": "99ec365d4afad750833258a1a24f44ca",
46 | },
47 | RealESRGANModelName.RealESRGAN_x4plus_anime_6B: {
48 | "url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
49 | "scale": 4,
50 | "model": lambda: RRDBNet(
51 | num_in_ch=3,
52 | num_out_ch=3,
53 | num_feat=64,
54 | num_block=6,
55 | num_grow_ch=32,
56 | scale=4,
57 | ),
58 | "model_md5": "d58ce384064ec1591c2ea7b79dbf47ba",
59 | },
60 | }
61 | if name not in REAL_ESRGAN_MODELS:
62 | raise ValueError(f"Unknown RealESRGAN model name: {name}")
63 | model_info = REAL_ESRGAN_MODELS[name]
64 |
65 | model_path = download_model(model_info["url"], model_info["model_md5"])
66 | logger.info(f"RealESRGAN model path: {model_path}")
67 |
68 | self.model = RealESRGANer(
69 | scale=model_info["scale"],
70 | model_path=model_path,
71 | model=model_info["model"](),
72 | half=True if "cuda" in str(device) and not no_half else False,
73 | tile=512,
74 | tile_pad=10,
75 | pre_pad=10,
76 | device=device,
77 | )
78 |
79 | def __call__(self, rgb_np_img, files, form):
80 | bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
81 | scale = float(form["upscale"])
82 | logger.info(f"RealESRGAN input shape: {bgr_np_img.shape}, scale: {scale}")
83 | result = self.forward(bgr_np_img, scale)
84 | logger.info(f"RealESRGAN output shape: {result.shape}")
85 | return result
86 |
87 | def forward(self, bgr_np_img, scale: float):
88 |         # The output is BGR
89 | upsampled = self.model.enhance(bgr_np_img, outscale=scale)[0]
90 | return upsampled
91 |
92 | def check_dep(self):
93 | try:
94 | import realesrgan
95 | except ImportError:
96 | return "RealESRGAN is not installed, please install it first. pip install realesrgan"
97 |
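A usage sketch mirroring lama_cleaner/tests/test_plugins.py; the file name is a placeholder:

import cv2
from lama_cleaner.plugins import RealESRGANUpscaler

upscaler = RealESRGANUpscaler("realesr-general-x4v3", "cpu")
bgr = cv2.imread("small.png")
bgr_x2 = upscaler.forward(bgr, 2)               # BGR in, BGR out
cv2.imwrite("small_x2.png", bgr_x2)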
--------------------------------------------------------------------------------
/lama_cleaner/plugins/remove_bg.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import numpy as np
4 | from torch.hub import get_dir
5 |
6 | from lama_cleaner.plugins.base_plugin import BasePlugin
7 |
8 |
9 | class RemoveBG(BasePlugin):
10 | name = "RemoveBG"
11 |
12 | def __init__(self):
13 | super().__init__()
14 | from rembg import new_session
15 |
16 | hub_dir = get_dir()
17 | model_dir = os.path.join(hub_dir, "checkpoints")
18 | os.environ["U2NET_HOME"] = model_dir
19 |
20 | self.session = new_session(model_name="u2net")
21 |
22 | def __call__(self, rgb_np_img, files, form):
23 | bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
24 | return self.forward(bgr_np_img)
25 |
26 | def forward(self, bgr_np_img) -> np.ndarray:
27 | from rembg import remove
28 |
29 | # return BGRA image
30 | output = remove(bgr_np_img, session=self.session)
31 | return cv2.cvtColor(output, cv2.COLOR_BGRA2RGBA)
32 |
33 | def check_dep(self):
34 | try:
35 | import rembg
36 | except ImportError:
37 | return (
38 | "RemoveBG is not installed, please install it first. pip install rembg"
39 | )
40 |
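A similar sketch for background removal, again following the plugin tests; the file name is a placeholder:

import cv2
from lama_cleaner.plugins import RemoveBG

remover = RemoveBG()                            # downloads the u2net weights on first run
bgr = cv2.imread("photo.jpg")
rgba = remover.forward(bgr)                     # RGBA with a transparent background
cv2.imwrite("photo_nobg.png", cv2.cvtColor(rgba, cv2.COLOR_RGBA2BGRA))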
--------------------------------------------------------------------------------
/lama_cleaner/plugins/restoreformer.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | from loguru import logger
3 |
4 | from lama_cleaner.helper import download_model
5 | from lama_cleaner.plugins.base_plugin import BasePlugin
6 |
7 |
8 | class RestoreFormerPlugin(BasePlugin):
9 | name = "RestoreFormer"
10 |
11 | def __init__(self, device, upscaler=None):
12 | super().__init__()
13 | from .gfpganer import MyGFPGANer
14 |
15 | url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth"
16 | model_md5 = "eaeeff6c4a1caa1673977cb374e6f699"
17 | model_path = download_model(url, model_md5)
18 | logger.info(f"RestoreFormer model path: {model_path}")
19 |
20 | import facexlib
21 |
22 | if hasattr(facexlib.detection.retinaface, "device"):
23 | facexlib.detection.retinaface.device = device
24 |
25 | self.face_enhancer = MyGFPGANer(
26 | model_path=model_path,
27 | upscale=1,
28 | arch="RestoreFormer",
29 | channel_multiplier=2,
30 | device=device,
31 | bg_upsampler=upscaler.model if upscaler is not None else None,
32 | )
33 |
34 | def __call__(self, rgb_np_img, files, form):
35 | weight = 0.5
36 | bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
37 | logger.info(f"RestoreFormer input shape: {bgr_np_img.shape}")
38 | _, _, bgr_output = self.face_enhancer.enhance(
39 | bgr_np_img,
40 | has_aligned=False,
41 | only_center_face=False,
42 | paste_back=True,
43 | weight=weight,
44 | )
45 | logger.info(f"RestoreFormer output shape: {bgr_output.shape}")
46 | return bgr_output
47 |
48 | def check_dep(self):
49 | try:
50 | import gfpgan
51 | except ImportError:
52 | return (
53 | "gfpgan is not installed, please install it first. pip install gfpgan"
54 | )
55 |
--------------------------------------------------------------------------------
/lama_cleaner/plugins/segment_anything/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from .build_sam import (
8 | build_sam,
9 | build_sam_vit_h,
10 | build_sam_vit_l,
11 | build_sam_vit_b,
12 | sam_model_registry,
13 | )
14 | from .predictor import SamPredictor
15 |
--------------------------------------------------------------------------------
/lama_cleaner/plugins/segment_anything/build_sam.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 |
9 | from functools import partial
10 |
11 | from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer
12 |
13 |
14 | def build_sam_vit_h(checkpoint=None):
15 | return _build_sam(
16 | encoder_embed_dim=1280,
17 | encoder_depth=32,
18 | encoder_num_heads=16,
19 | encoder_global_attn_indexes=[7, 15, 23, 31],
20 | checkpoint=checkpoint,
21 | )
22 |
23 |
24 | build_sam = build_sam_vit_h
25 |
26 |
27 | def build_sam_vit_l(checkpoint=None):
28 | return _build_sam(
29 | encoder_embed_dim=1024,
30 | encoder_depth=24,
31 | encoder_num_heads=16,
32 | encoder_global_attn_indexes=[5, 11, 17, 23],
33 | checkpoint=checkpoint,
34 | )
35 |
36 |
37 | def build_sam_vit_b(checkpoint=None):
38 | return _build_sam(
39 | encoder_embed_dim=768,
40 | encoder_depth=12,
41 | encoder_num_heads=12,
42 | encoder_global_attn_indexes=[2, 5, 8, 11],
43 | checkpoint=checkpoint,
44 | )
45 |
46 |
47 | sam_model_registry = {
48 | "default": build_sam,
49 | "vit_h": build_sam,
50 | "vit_l": build_sam_vit_l,
51 | "vit_b": build_sam_vit_b,
52 | }
53 |
54 |
55 | def _build_sam(
56 | encoder_embed_dim,
57 | encoder_depth,
58 | encoder_num_heads,
59 | encoder_global_attn_indexes,
60 | checkpoint=None,
61 | ):
62 | prompt_embed_dim = 256
63 | image_size = 1024
64 | vit_patch_size = 16
65 | image_embedding_size = image_size // vit_patch_size
66 | sam = Sam(
67 | image_encoder=ImageEncoderViT(
68 | depth=encoder_depth,
69 | embed_dim=encoder_embed_dim,
70 | img_size=image_size,
71 | mlp_ratio=4,
72 | norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
73 | num_heads=encoder_num_heads,
74 | patch_size=vit_patch_size,
75 | qkv_bias=True,
76 | use_rel_pos=True,
77 | global_attn_indexes=encoder_global_attn_indexes,
78 | window_size=14,
79 | out_chans=prompt_embed_dim,
80 | ),
81 | prompt_encoder=PromptEncoder(
82 | embed_dim=prompt_embed_dim,
83 | image_embedding_size=(image_embedding_size, image_embedding_size),
84 | input_image_size=(image_size, image_size),
85 | mask_in_chans=16,
86 | ),
87 | mask_decoder=MaskDecoder(
88 | num_multimask_outputs=3,
89 | transformer=TwoWayTransformer(
90 | depth=2,
91 | embedding_dim=prompt_embed_dim,
92 | mlp_dim=2048,
93 | num_heads=8,
94 | ),
95 | transformer_dim=prompt_embed_dim,
96 | iou_head_depth=3,
97 | iou_head_hidden_dim=256,
98 | ),
99 | pixel_mean=[123.675, 116.28, 103.53],
100 | pixel_std=[58.395, 57.12, 57.375],
101 | )
102 | sam.eval()
103 | if checkpoint is not None:
104 | with open(checkpoint, "rb") as f:
105 | state_dict = torch.load(f)
106 | sam.load_state_dict(state_dict)
107 | return sam
108 |
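A usage sketch for the registry above, assuming a local vit_b checkpoint (downloadable from the URL listed in lama_cleaner/plugins/interactive_seg.py); paths are placeholders:

import cv2
import numpy as np
from lama_cleaner.plugins.segment_anything import SamPredictor, sam_model_registry

sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
predictor = SamPredictor(sam.to("cpu"))

rgb = cv2.cvtColor(cv2.imread("photo.png"), cv2.COLOR_BGR2RGB)
predictor.set_image(rgb)
masks, scores, _ = predictor.predict(
    point_coords=np.array([[256, 256]]),
    point_labels=np.array([1]),
    multimask_output=False,
)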
--------------------------------------------------------------------------------
/lama_cleaner/plugins/segment_anything/modeling/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from .sam import Sam
8 | from .image_encoder import ImageEncoderViT
9 | from .mask_decoder import MaskDecoder
10 | from .prompt_encoder import PromptEncoder
11 | from .transformer import TwoWayTransformer
12 |
--------------------------------------------------------------------------------
/lama_cleaner/plugins/segment_anything/modeling/common.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | import torch.nn as nn
9 |
10 | from typing import Type
11 |
12 |
13 | class MLPBlock(nn.Module):
14 | def __init__(
15 | self,
16 | embedding_dim: int,
17 | mlp_dim: int,
18 | act: Type[nn.Module] = nn.GELU,
19 | ) -> None:
20 | super().__init__()
21 | self.lin1 = nn.Linear(embedding_dim, mlp_dim)
22 | self.lin2 = nn.Linear(mlp_dim, embedding_dim)
23 | self.act = act()
24 |
25 | def forward(self, x: torch.Tensor) -> torch.Tensor:
26 | return self.lin2(self.act(self.lin1(x)))
27 |
28 |
29 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
30 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
31 | class LayerNorm2d(nn.Module):
32 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
33 | super().__init__()
34 | self.weight = nn.Parameter(torch.ones(num_channels))
35 | self.bias = nn.Parameter(torch.zeros(num_channels))
36 | self.eps = eps
37 |
38 | def forward(self, x: torch.Tensor) -> torch.Tensor:
39 | u = x.mean(1, keepdim=True)
40 | s = (x - u).pow(2).mean(1, keepdim=True)
41 | x = (x - u) / torch.sqrt(s + self.eps)
42 | x = self.weight[:, None, None] * x + self.bias[:, None, None]
43 | return x
44 |
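Both blocks are small, self-contained torch modules; a quick shape check as a sketch:

import torch
from lama_cleaner.plugins.segment_anything.modeling.common import MLPBlock, LayerNorm2d

x = torch.randn(2, 16, 256)
assert MLPBlock(embedding_dim=256, mlp_dim=1024)(x).shape == (2, 16, 256)

y = torch.randn(2, 64, 32, 32)
assert LayerNorm2d(64)(y).shape == (2, 64, 32, 32)   # normalizes over the channel dimension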
--------------------------------------------------------------------------------
/lama_cleaner/plugins/segment_anything/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/lama_cleaner/plugins/segment_anything/utils/transforms.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import numpy as np
8 | import torch
9 | from torch.nn import functional as F
10 | from torchvision.transforms.functional import resize, to_pil_image # type: ignore
11 |
12 | from copy import deepcopy
13 | from typing import Tuple
14 |
15 |
16 | class ResizeLongestSide:
17 | """
18 |     Resizes images to the longest side 'target_length', and provides
19 |     methods for resizing coordinates and boxes. Provides methods for
20 |     transforming both numpy arrays and batched torch tensors.
21 | """
22 |
23 | def __init__(self, target_length: int) -> None:
24 | self.target_length = target_length
25 |
26 | def apply_image(self, image: np.ndarray) -> np.ndarray:
27 | """
28 | Expects a numpy array with shape HxWxC in uint8 format.
29 | """
30 | target_size = self.get_preprocess_shape(
31 | image.shape[0], image.shape[1], self.target_length
32 | )
33 | return np.array(resize(to_pil_image(image), target_size))
34 |
35 | def apply_coords(
36 | self, coords: np.ndarray, original_size: Tuple[int, ...]
37 | ) -> np.ndarray:
38 | """
39 | Expects a numpy array of length 2 in the final dimension. Requires the
40 | original image size in (H, W) format.
41 | """
42 | old_h, old_w = original_size
43 | new_h, new_w = self.get_preprocess_shape(
44 | original_size[0], original_size[1], self.target_length
45 | )
46 | coords = deepcopy(coords).astype(float)
47 | coords[..., 0] = coords[..., 0] * (new_w / old_w)
48 | coords[..., 1] = coords[..., 1] * (new_h / old_h)
49 | return coords
50 |
51 | def apply_boxes(
52 | self, boxes: np.ndarray, original_size: Tuple[int, ...]
53 | ) -> np.ndarray:
54 | """
55 | Expects a numpy array shape Bx4. Requires the original image size
56 | in (H, W) format.
57 | """
58 | boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size)
59 | return boxes.reshape(-1, 4)
60 |
61 | def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor:
62 | """
63 | Expects batched images with shape BxCxHxW and float format. This
64 | transformation may not exactly match apply_image. apply_image is
65 | the transformation expected by the model.
66 | """
67 | # Expects an image in BCHW format. May not exactly match apply_image.
68 | target_size = self.get_preprocess_shape(
69 | image.shape[0], image.shape[1], self.target_length
70 | )
71 | return F.interpolate(
72 | image, target_size, mode="bilinear", align_corners=False, antialias=True
73 | )
74 |
75 | def apply_coords_torch(
76 | self, coords: torch.Tensor, original_size: Tuple[int, ...]
77 | ) -> torch.Tensor:
78 | """
79 | Expects a torch tensor with length 2 in the last dimension. Requires the
80 | original image size in (H, W) format.
81 | """
82 | old_h, old_w = original_size
83 | new_h, new_w = self.get_preprocess_shape(
84 | original_size[0], original_size[1], self.target_length
85 | )
86 | coords = deepcopy(coords).to(torch.float)
87 | coords[..., 0] = coords[..., 0] * (new_w / old_w)
88 | coords[..., 1] = coords[..., 1] * (new_h / old_h)
89 | return coords
90 |
91 | def apply_boxes_torch(
92 | self, boxes: torch.Tensor, original_size: Tuple[int, ...]
93 | ) -> torch.Tensor:
94 | """
95 | Expects a torch tensor with shape Bx4. Requires the original image
96 | size in (H, W) format.
97 | """
98 | boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size)
99 | return boxes.reshape(-1, 4)
100 |
101 | @staticmethod
102 | def get_preprocess_shape(
103 | oldh: int, oldw: int, long_side_length: int
104 | ) -> Tuple[int, int]:
105 | """
106 | Compute the output size given input size and target long side length.
107 | """
108 | scale = long_side_length * 1.0 / max(oldh, oldw)
109 | newh, neww = oldh * scale, oldw * scale
110 | neww = int(neww + 0.5)
111 | newh = int(newh + 0.5)
112 | return (newh, neww)
113 |
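A short sketch of the transform applied to an image and its point prompts:

import numpy as np
from lama_cleaner.plugins.segment_anything.utils.transforms import ResizeLongestSide

transform = ResizeLongestSide(target_length=1024)
image = np.zeros((480, 640, 3), dtype=np.uint8)                 # H=480, W=640

resized = transform.apply_image(image)
print(resized.shape)                                            # (768, 1024, 3): longest side is now 1024

coords = transform.apply_coords(np.array([[320.0, 240.0]]), original_size=(480, 640))
print(coords)                                                   # [[512. 384.]], scaled by 1024 / 640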
--------------------------------------------------------------------------------
/lama_cleaner/runtime.py:
--------------------------------------------------------------------------------
1 | # https://github.com/huggingface/huggingface_hub/blob/5a12851f54bf614be39614034ed3a9031922d297/src/huggingface_hub/utils/_runtime.py
2 | import platform
3 | import sys
4 | import packaging.version
5 | from rich import print
6 | from typing import Dict, Any
7 |
8 | _PY_VERSION: str = sys.version.split()[0].rstrip("+")
9 |
10 | if packaging.version.Version(_PY_VERSION) < packaging.version.Version("3.8.0"):
11 | import importlib_metadata # type: ignore
12 | else:
13 | import importlib.metadata as importlib_metadata # type: ignore
14 |
15 | _package_versions = {}
16 |
17 | _CANDIDATES = [
18 | "torch",
19 | "torchvision",
20 | "Pillow",
21 | "diffusers",
22 | "transformers",
23 | "opencv-python",
24 | "xformers",
25 | "accelerate",
26 | "lama-cleaner",
27 | "rembg",
28 | "realesrgan",
29 | "gfpgan",
30 | ]
31 | # Check once at runtime
32 | for name in _CANDIDATES:
33 | _package_versions[name] = "N/A"
34 | try:
35 | _package_versions[name] = importlib_metadata.version(name)
36 | except importlib_metadata.PackageNotFoundError:
37 | pass
38 |
39 |
40 | def dump_environment_info() -> Dict[str, str]:
41 | """Dump information about the machine to help debugging issues. """
42 |
43 | # Generic machine info
44 | info: Dict[str, Any] = {
45 | "Platform": platform.platform(),
46 | "Python version": platform.python_version(),
47 | }
48 | info.update(_package_versions)
49 | print("\n".join([f"- {prop}: {val}" for prop, val in info.items()]) + "\n")
50 | return info
51 |
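A usage sketch; dump_environment_info prints the collected versions and also returns them:

from lama_cleaner.runtime import dump_environment_info

info = dump_environment_info()
print(info["Platform"], info.get("torch"))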
--------------------------------------------------------------------------------
/lama_cleaner/schema.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 | from enum import Enum
3 |
4 | from PIL.Image import Image
5 | from pydantic import BaseModel
6 |
7 |
8 | class HDStrategy(str, Enum):
9 | # Use original image size
10 | ORIGINAL = "Original"
11 |     # Resize the longer side of the image to a specific size (hd_strategy_resize_limit),
12 | # then do inpainting on the resized image. Finally, resize the inpainting result to the original size.
13 | # The area outside the mask will not lose quality.
14 | RESIZE = "Resize"
15 |     # Crop the masked area (with a margin controlled by hd_strategy_crop_margin) from the original image to do inpainting
16 | CROP = "Crop"
17 |
18 |
19 | class LDMSampler(str, Enum):
20 | ddim = "ddim"
21 | plms = "plms"
22 |
23 |
24 | class SDSampler(str, Enum):
25 | ddim = "ddim"
26 | pndm = "pndm"
27 | k_lms = "k_lms"
28 | k_euler = "k_euler"
29 | k_euler_a = "k_euler_a"
30 | dpm_plus_plus = "dpm++"
31 | uni_pc = "uni_pc"
32 |
33 |
34 | class Config(BaseModel):
35 | class Config:
36 | arbitrary_types_allowed = True
37 |
38 | # Configs for ldm model
39 | ldm_steps: int
40 | ldm_sampler: str = LDMSampler.plms
41 |
42 | # Configs for zits model
43 | zits_wireframe: bool = True
44 |
45 |     # Configs for the High Resolution Strategy (different ways to preprocess the image)
46 | hd_strategy: str # See HDStrategy Enum
47 | hd_strategy_crop_margin: int
48 | # If the longer side of the image is larger than this value, use crop strategy
49 | hd_strategy_crop_trigger_size: int
50 | hd_strategy_resize_limit: int
51 |
52 | # Configs for Stable Diffusion 1.5
53 | prompt: str = ""
54 | negative_prompt: str = ""
55 | # Crop image to this size before doing sd inpainting
56 | # The value is always on the original image scale
57 | use_croper: bool = False
58 | croper_x: int = None
59 | croper_y: int = None
60 | croper_height: int = None
61 | croper_width: int = None
62 |
63 |     # Resize the image before doing sd inpainting; the area outside the mask will not lose quality.
64 |     # Used by sd models and the paint_by_example model
65 |     sd_scale: float = 1.0
66 |     # Blur the edge of the mask area. The higher the number, the smoother the blend with the original image
67 | sd_mask_blur: int = 0
68 | # Ignore this value, it's useless for inpainting
69 | sd_strength: float = 0.75
70 | # The number of denoising steps. More denoising steps usually lead to a
71 | # higher quality image at the expense of slower inference.
72 | sd_steps: int = 50
73 |     # A higher guidance scale encourages the model to generate images closely linked
74 |     # to the text prompt, usually at the expense of lower image quality.
75 | sd_guidance_scale: float = 7.5
76 | sd_sampler: str = SDSampler.uni_pc
77 |     # -1 means a random seed
78 | sd_seed: int = 42
79 | sd_match_histograms: bool = False
80 |
81 | # Configs for opencv inpainting
82 | # opencv document https://docs.opencv.org/4.6.0/d7/d8b/group__photo__inpaint.html#gga8002a65f5a3328fbf15df81b842d3c3ca05e763003a805e6c11c673a9f4ba7d07
83 | cv2_flag: str = "INPAINT_NS"
84 | cv2_radius: int = 4
85 |
86 | # Paint by Example
87 | paint_by_example_steps: int = 50
88 | paint_by_example_guidance_scale: float = 7.5
89 | paint_by_example_mask_blur: int = 0
90 | paint_by_example_seed: int = 42
91 | paint_by_example_match_histograms: bool = False
92 | paint_by_example_example_image: Optional[Image] = None
93 |
94 | # InstructPix2Pix
95 | p2p_steps: int = 50
96 | p2p_image_guidance_scale: float = 1.5
97 | p2p_guidance_scale: float = 7.5
98 |
99 | # ControlNet
100 | controlnet_conditioning_scale: float = 0.4
101 | controlnet_method: str = "control_v11p_sd15_canny"
102 |
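Only the fields without defaults have to be supplied; a minimal instance as a sketch:

from lama_cleaner.schema import Config, HDStrategy

cfg = Config(
    ldm_steps=20,
    hd_strategy=HDStrategy.RESIZE,
    hd_strategy_crop_margin=128,
    hd_strategy_crop_trigger_size=800,
    hd_strategy_resize_limit=1280,
)
print(cfg.sd_steps, cfg.cv2_flag)   # 50 INPAINT_NS -- everything else falls back to the defaults above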
--------------------------------------------------------------------------------
/lama_cleaner/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Uminosachi/inpaint-anything/c2204930c48a4b2b65bc7835ff527ebb72de5be9/lama_cleaner/tests/__init__.py
--------------------------------------------------------------------------------
/lama_cleaner/tests/test_instruct_pix2pix.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import pytest
4 | import torch
5 |
6 | from lama_cleaner.model_manager import ModelManager
7 | from lama_cleaner.tests.test_model import get_config, assert_equal
8 | from lama_cleaner.schema import HDStrategy
9 |
10 | current_dir = Path(__file__).parent.absolute().resolve()
11 | save_dir = current_dir / 'result'
12 | save_dir.mkdir(exist_ok=True, parents=True)
13 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
14 |
15 |
16 | @pytest.mark.parametrize("disable_nsfw", [True, False])
17 | @pytest.mark.parametrize("cpu_offload", [False, True])
18 | def test_instruct_pix2pix(disable_nsfw, cpu_offload):
19 | sd_steps = 50 if device == 'cuda' else 1
20 | model = ModelManager(name="instruct_pix2pix",
21 | device=torch.device(device),
22 | hf_access_token="",
23 | sd_run_local=False,
24 | disable_nsfw=disable_nsfw,
25 | sd_cpu_textencoder=False,
26 | cpu_offload=cpu_offload)
27 | cfg = get_config(strategy=HDStrategy.ORIGINAL, prompt='What if it were snowing?', p2p_steps=sd_steps, sd_scale=1.1)
28 |
29 | name = f"device_{device}_disnsfw_{disable_nsfw}_cpu_offload_{cpu_offload}"
30 |
31 | assert_equal(
32 | model,
33 | cfg,
34 | f"instruct_pix2pix_{name}.png",
35 | img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
36 | mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
37 | fx=1.3
38 | )
39 |
40 |
41 | @pytest.mark.parametrize("disable_nsfw", [False])
42 | @pytest.mark.parametrize("cpu_offload", [False])
43 | def test_instruct_pix2pix_snow(disable_nsfw, cpu_offload):
44 | sd_steps = 50 if device == 'cuda' else 1
45 | model = ModelManager(name="instruct_pix2pix",
46 | device=torch.device(device),
47 | hf_access_token="",
48 | sd_run_local=False,
49 | disable_nsfw=disable_nsfw,
50 | sd_cpu_textencoder=False,
51 | cpu_offload=cpu_offload)
52 | cfg = get_config(strategy=HDStrategy.ORIGINAL, prompt='What if it were snowing?', p2p_steps=sd_steps)
53 |
54 | name = f"snow"
55 |
56 | assert_equal(
57 | model,
58 | cfg,
59 | f"instruct_pix2pix_{name}.png",
60 | img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
61 | mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
62 | )
63 |
--------------------------------------------------------------------------------
/lama_cleaner/tests/test_interactive_seg.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import cv2
4 | import numpy as np
5 |
6 | from lama_cleaner.plugins import InteractiveSeg, Click
7 |
8 | current_dir = Path(__file__).parent.absolute().resolve()
9 | save_dir = current_dir / "result"
10 | save_dir.mkdir(exist_ok=True, parents=True)
11 | img_p = current_dir / "overture-creations-5sI6fQgYIuo.png"
12 |
13 |
14 | def test_interactive_seg():
15 | interactive_seg_model = InteractiveSeg()
16 | img = cv2.imread(str(img_p))
17 | pred = interactive_seg_model.forward(
18 | img, clicks=[Click(coords=(256, 256), indx=0, is_positive=True)]
19 | )
20 | cv2.imwrite(str(save_dir / "test_interactive_seg.png"), pred)
21 |
22 |
23 | def test_interactive_seg_with_negative_click():
24 | interactive_seg_model = InteractiveSeg()
25 | img = cv2.imread(str(img_p))
26 | pred = interactive_seg_model.forward(
27 | img,
28 | clicks=[
29 | Click(coords=(256, 256), indx=0, is_positive=True),
30 | Click(coords=(384, 256), indx=1, is_positive=False),
31 | ],
32 | )
33 | cv2.imwrite(str(save_dir / "test_interactive_seg_negative.png"), pred)
34 |
35 |
36 | def test_interactive_seg_with_prev_mask():
37 | interactive_seg_model = InteractiveSeg()
38 | img = cv2.imread(str(img_p))
39 | mask = np.zeros_like(img)[:, :, 0]
40 | pred = interactive_seg_model.forward(
41 | img, clicks=[Click(coords=(256, 256), indx=0, is_positive=True)], prev_mask=mask
42 | )
43 | cv2.imwrite(str(save_dir / "test_interactive_seg_with_mask.png"), pred)
44 |
--------------------------------------------------------------------------------
/lama_cleaner/tests/test_load_img.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | from lama_cleaner.helper import load_img
4 |
5 | current_dir = Path(__file__).parent.absolute().resolve()
6 | png_img_p = current_dir / "image.png"
7 | jpg_img_p = current_dir / "bunny.jpeg"
8 |
9 |
10 | def test_load_png_image():
11 | with open(png_img_p, "rb") as f:
12 | np_img, alpha_channel = load_img(f.read())
13 | assert np_img.shape == (256, 256, 3)
14 | assert alpha_channel.shape == (256, 256)
15 |
16 |
17 | def test_load_jpg_image():
18 | with open(jpg_img_p, "rb") as f:
19 | np_img, alpha_channel = load_img(f.read())
20 | assert np_img.shape == (394, 448, 3)
21 | assert alpha_channel is None
22 |
--------------------------------------------------------------------------------
/lama_cleaner/tests/test_model_md5.py:
--------------------------------------------------------------------------------
1 | def test_load_model():
2 | from lama_cleaner.plugins import InteractiveSeg
3 | from lama_cleaner.model_manager import ModelManager
4 |
5 | interactive_seg_model = InteractiveSeg('vit_l', 'cpu')
6 |
7 | models = [
8 | "lama",
9 | "ldm",
10 | "zits",
11 | "mat",
12 | "fcf",
13 | "manga",
14 | ]
15 | for m in models:
16 | ModelManager(
17 | name=m,
18 | device="cpu",
19 | no_half=False,
20 | hf_access_token="",
21 | disable_nsfw=False,
22 | sd_cpu_textencoder=True,
23 | sd_run_local=True,
24 | local_files_only=True,
25 | cpu_offload=True,
26 | enable_xformers=False,
27 | )
28 |
29 |
30 | # def create_empty_file(tmp_dir, name):
31 | # tmp_model_dir = os.path.join(tmp_dir, "torch", "hub", "checkpoints")
32 | # Path(tmp_model_dir).mkdir(exist_ok=True, parents=True)
33 | # path = os.path.join(tmp_model_dir, name)
34 | # with open(path, "w") as f:
35 | # f.write("1")
36 | #
37 | #
38 | # def test_load_model_error():
39 | # MODELS = [
40 | # ("big-lama.pt", "e3aa4aaa15225a33ec84f9f4bc47e500"),
41 | # ("cond_stage_model_encode.pt", "23239fc9081956a3e70de56472b3f296"),
42 | # ("cond_stage_model_decode.pt", "fe419cd15a750d37a4733589d0d3585c"),
43 | # ("diffusion.pt", "b0afda12bf790c03aba2a7431f11d22d"),
44 | # ]
45 | # with tempfile.TemporaryDirectory() as tmp_dir:
46 | # os.environ["XDG_CACHE_HOME"] = tmp_dir
47 | # for name, md5 in MODELS:
48 | # create_empty_file(tmp_dir, name)
49 | # test_load_model()
50 |
--------------------------------------------------------------------------------
/lama_cleaner/tests/test_paint_by_example.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import cv2
4 | import pytest
5 | import torch
6 | from PIL import Image
7 |
8 | from lama_cleaner.model_manager import ModelManager
9 | from lama_cleaner.schema import HDStrategy
10 | from lama_cleaner.tests.test_model import get_config, get_data
11 |
12 | current_dir = Path(__file__).parent.absolute().resolve()
13 | save_dir = current_dir / 'result'
14 | save_dir.mkdir(exist_ok=True, parents=True)
15 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
16 | device = torch.device(device)
17 |
18 |
19 | def assert_equal(
20 | model, config, gt_name,
21 | fx: float = 1, fy: float = 1,
22 | img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
23 | mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
24 | example_p=current_dir / "bunny.jpeg",
25 | ):
26 | img, mask = get_data(fx=fx, fy=fy, img_p=img_p, mask_p=mask_p)
27 |
28 | example_image = cv2.imread(str(example_p))
29 | example_image = cv2.cvtColor(example_image, cv2.COLOR_BGRA2RGB)
30 | example_image = cv2.resize(example_image, None, fx=fx, fy=fy, interpolation=cv2.INTER_AREA)
31 |
32 | print(f"Input image shape: {img.shape}, example_image: {example_image.shape}")
33 | config.paint_by_example_example_image = Image.fromarray(example_image)
34 | res = model(img, mask, config)
35 | cv2.imwrite(str(save_dir / gt_name), res)
36 |
37 |
38 | @pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
39 | def test_paint_by_example(strategy):
40 | model = ModelManager(name="paint_by_example", device=device, disable_nsfw=True)
41 | cfg = get_config(strategy, paint_by_example_steps=30)
42 | assert_equal(
43 | model,
44 | cfg,
45 | f"paint_by_example_{strategy.capitalize()}.png",
46 | img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
47 | mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
48 | fy=0.9,
49 | fx=1.3,
50 | )
51 |
52 |
53 | @pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
54 | def test_paint_by_example_disable_nsfw(strategy):
55 | model = ModelManager(name="paint_by_example", device=device, disable_nsfw=False)
56 | cfg = get_config(strategy, paint_by_example_steps=30)
57 | assert_equal(
58 | model,
59 | cfg,
60 | f"paint_by_example_{strategy.capitalize()}_disable_nsfw.png",
61 | img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
62 | mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
63 | )
64 |
65 |
66 | @pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
67 | def test_paint_by_example_sd_scale(strategy):
68 | model = ModelManager(name="paint_by_example", device=device, disable_nsfw=True)
69 | cfg = get_config(strategy, paint_by_example_steps=30, sd_scale=0.85)
70 | assert_equal(
71 | model,
72 | cfg,
73 | f"paint_by_example_{strategy.capitalize()}_sdscale.png",
74 | img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
75 | mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
76 | fy=0.9,
77 | fx=1.3
78 | )
79 |
80 |
81 | @pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
82 | def test_paint_by_example_cpu_offload(strategy):
83 | model = ModelManager(name="paint_by_example", device=device, cpu_offload=True, disable_nsfw=False)
84 | cfg = get_config(strategy, paint_by_example_steps=30, sd_scale=0.85)
85 | assert_equal(
86 | model,
87 | cfg,
88 | f"paint_by_example_{strategy.capitalize()}_cpu_offload.png",
89 | img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
90 | mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
91 | )
92 |
93 |
94 | @pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
95 | def test_paint_by_example_cpu_offload_cpu_device(strategy):
96 | model = ModelManager(name="paint_by_example", device=torch.device('cpu'), cpu_offload=True, disable_nsfw=True)
97 | cfg = get_config(strategy, paint_by_example_steps=1, sd_scale=0.85)
98 | assert_equal(
99 | model,
100 | cfg,
101 | f"paint_by_example_{strategy.capitalize()}_cpu_offload_cpu_device.png",
102 | img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
103 | mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
104 | fy=0.9,
105 | fx=1.3
106 | )
107 |
--------------------------------------------------------------------------------
/lama_cleaner/tests/test_plugins.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import os
3 | import time
4 |
5 | from lama_cleaner.plugins.anime_seg import AnimeSeg
6 |
7 | os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
8 | from pathlib import Path
9 |
10 | import cv2
11 | import pytest
12 | import torch.cuda
13 |
14 | from lama_cleaner.plugins import (
15 | RemoveBG,
16 | RealESRGANUpscaler,
17 | GFPGANPlugin,
18 | RestoreFormerPlugin,
19 | InteractiveSeg,
20 | )
21 |
22 | current_dir = Path(__file__).parent.absolute().resolve()
23 | save_dir = current_dir / "result"
24 | save_dir.mkdir(exist_ok=True, parents=True)
25 | img_p = current_dir / "bunny.jpeg"
26 | img_bytes = open(img_p, "rb").read()
27 | bgr_img = cv2.imread(str(img_p))
28 | rgb_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB)
29 |
30 |
31 | def _save(img, name):
32 | cv2.imwrite(str(save_dir / name), img)
33 |
34 |
35 | def test_remove_bg():
36 | model = RemoveBG()
37 | res = model.forward(bgr_img)
38 | res = cv2.cvtColor(res, cv2.COLOR_RGBA2BGRA)
39 | _save(res, "test_remove_bg.png")
40 |
41 |
42 | def test_anime_seg():
43 | model = AnimeSeg()
44 | img = cv2.imread(str(current_dir / "anime_test.png"))
45 | res = model.forward(img)
46 | assert len(res.shape) == 3
47 | assert res.shape[-1] == 4
48 | _save(res, "test_anime_seg.png")
49 |
50 |
51 | @pytest.mark.parametrize("device", ["cuda", "cpu", "mps"])
52 | def test_upscale(device):
53 | if device == "cuda" and not torch.cuda.is_available():
54 | return
55 | if device == "mps" and not torch.backends.mps.is_available():
56 | return
57 |
58 | model = RealESRGANUpscaler("realesr-general-x4v3", device)
59 | res = model.forward(bgr_img, 2)
60 | _save(res, f"test_upscale_x2_{device}.png")
61 |
62 | res = model.forward(bgr_img, 4)
63 | _save(res, f"test_upscale_x4_{device}.png")
64 |
65 |
66 | @pytest.mark.parametrize("device", ["cuda", "cpu", "mps"])
67 | def test_gfpgan(device):
68 | if device == "cuda" and not torch.cuda.is_available():
69 | return
70 | if device == "mps" and not torch.backends.mps.is_available():
71 | return
72 | model = GFPGANPlugin(device)
73 | res = model(rgb_img, None, None)
74 | _save(res, f"test_gfpgan_{device}.png")
75 |
76 |
77 | @pytest.mark.parametrize("device", ["cuda", "cpu", "mps"])
78 | def test_restoreformer(device):
79 | if device == "cuda" and not torch.cuda.is_available():
80 | return
81 | if device == "mps" and not torch.backends.mps.is_available():
82 | return
83 | model = RestoreFormerPlugin(device)
84 | res = model(rgb_img, None, None)
85 | _save(res, f"test_restoreformer_{device}.png")
86 |
87 |
88 | @pytest.mark.parametrize("device", ["cuda", "cpu", "mps"])
89 | def test_segment_anything(device):
90 | if device == "cuda" and not torch.cuda.is_available():
91 | return
92 | if device == "mps" and not torch.backends.mps.is_available():
93 | return
94 | img_md5 = hashlib.md5(img_bytes).hexdigest()
95 | model = InteractiveSeg("vit_l", device)
96 | new_mask = model.forward(rgb_img, [[448 // 2, 394 // 2, 1]], img_md5)
97 |
98 | save_name = f"test_segment_anything_{device}.png"
99 | _save(new_mask, save_name)
100 |
101 | start = time.time()
102 | model.forward(rgb_img, [[448 // 2, 394 // 2, 1]], img_md5)
103 | print(f"Time for {save_name}: {time.time() - start:.2f}s")
104 |
--------------------------------------------------------------------------------
/lama_cleaner/tests/test_save_exif.py:
--------------------------------------------------------------------------------
1 | import io
2 | from pathlib import Path
3 |
4 | from PIL import Image
5 |
6 | from lama_cleaner.helper import pil_to_bytes, load_img
7 |
8 | current_dir = Path(__file__).parent.absolute().resolve()
9 |
10 |
11 | def print_exif(exif):
12 | for k, v in exif.items():
13 | print(f"{k}: {v}")
14 |
15 |
16 | def run_test(img_p: Path):
17 | print(img_p)
18 | ext = img_p.suffix.strip(".")
19 | img_bytes = img_p.read_bytes()
20 | np_img, _, exif_infos = load_img(img_bytes, False, True)
21 | print(exif_infos)
22 | print("Original exif_infos")
23 | print_exif(exif_infos["exif"])
24 |
25 | pil_to_bytes(Image.fromarray(np_img), ext=ext, exif_infos={})
26 |
27 | pil_bytes = pil_to_bytes(Image.fromarray(np_img), ext=ext, exif_infos=exif_infos)
28 | res_img = Image.open(io.BytesIO(pil_bytes))
29 | print(f"Result img info: {res_img.info}")
30 | res_exif = res_img.getexif()
31 | print_exif(res_exif)
32 | assert res_exif == exif_infos["exif"]
33 | assert exif_infos["parameters"] == res_img.info.get("parameters")
34 |
35 |
36 | def test_png():
37 | run_test(current_dir / "image.png")
38 | run_test(current_dir / "pnginfo_test.png")
39 |
40 |
41 | def test_jpeg():
42 | jpg_img_p = current_dir / "bunny.jpeg"
43 | run_test(jpg_img_p)
44 |
--------------------------------------------------------------------------------
/mobile_sam/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import warnings
8 |
9 | warnings.filterwarnings("ignore", category=UserWarning, module="mobile_sam")
10 |
11 | from .automatic_mask_generator import SamAutomaticMaskGenerator # noqa: E402
12 | from .build_sam import (build_sam, build_sam_vit_b, build_sam_vit_h, build_sam_vit_l, # noqa: E402
13 | build_sam_vit_t, sam_model_registry)
14 | from .predictor import SamPredictor # noqa: E402
15 |
16 | __all__ = [
17 | "build_sam",
18 | "build_sam_vit_h",
19 | "build_sam_vit_l",
20 | "build_sam_vit_b",
21 | "build_sam_vit_t",
22 | "sam_model_registry",
23 | "SamPredictor",
24 | "SamAutomaticMaskGenerator",
25 | ]
26 |
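For orientation, a minimal sketch of the public API re-exported above, assuming a locally downloaded MobileSAM (vit_t) checkpoint and an input image; both file names below are placeholders.

import cv2

from mobile_sam import SamAutomaticMaskGenerator, sam_model_registry

# Placeholder paths: substitute a real MobileSAM checkpoint and an RGB image.
sam = sam_model_registry["vit_t"](checkpoint="mobile_sam.pt")
sam.to("cpu")

rgb = cv2.cvtColor(cv2.imread("input.png"), cv2.COLOR_BGR2RGB)

# Fully automatic segmentation of the whole image; each entry is a dict with
# keys such as "segmentation", "area" and "bbox".
masks = SamAutomaticMaskGenerator(sam).generate(rgb)
print(len(masks))
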
--------------------------------------------------------------------------------
/mobile_sam/build_sam.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 |
9 | from functools import partial
10 |
11 | from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer, TinyViT
12 |
13 |
14 | def build_sam_vit_h(checkpoint=None):
15 | return _build_sam(
16 | encoder_embed_dim=1280,
17 | encoder_depth=32,
18 | encoder_num_heads=16,
19 | encoder_global_attn_indexes=[7, 15, 23, 31],
20 | checkpoint=checkpoint,
21 | )
22 |
23 |
24 | build_sam = build_sam_vit_h
25 |
26 |
27 | def build_sam_vit_l(checkpoint=None):
28 | return _build_sam(
29 | encoder_embed_dim=1024,
30 | encoder_depth=24,
31 | encoder_num_heads=16,
32 | encoder_global_attn_indexes=[5, 11, 17, 23],
33 | checkpoint=checkpoint,
34 | )
35 |
36 |
37 | def build_sam_vit_b(checkpoint=None):
38 | return _build_sam(
39 | encoder_embed_dim=768,
40 | encoder_depth=12,
41 | encoder_num_heads=12,
42 | encoder_global_attn_indexes=[2, 5, 8, 11],
43 | checkpoint=checkpoint,
44 | )
45 |
46 |
47 | def build_sam_vit_t(checkpoint=None):
48 | prompt_embed_dim = 256
49 | image_size = 1024
50 | vit_patch_size = 16
51 | image_embedding_size = image_size // vit_patch_size
52 | mobile_sam = Sam(
53 | image_encoder=TinyViT(
54 | img_size=1024, in_chans=3, num_classes=1000,
55 | embed_dims=[64, 128, 160, 320],
56 | depths=[2, 2, 6, 2],
57 | num_heads=[2, 4, 5, 10],
58 | window_sizes=[7, 7, 14, 7],
59 | mlp_ratio=4.,
60 | drop_rate=0.,
61 | drop_path_rate=0.0,
62 | use_checkpoint=False,
63 | mbconv_expand_ratio=4.0,
64 | local_conv_size=3,
65 | layer_lr_decay=0.8
66 | ),
67 | prompt_encoder=PromptEncoder(
68 | embed_dim=prompt_embed_dim,
69 | image_embedding_size=(image_embedding_size, image_embedding_size),
70 | input_image_size=(image_size, image_size),
71 | mask_in_chans=16,
72 | ),
73 | mask_decoder=MaskDecoder(
74 | num_multimask_outputs=3,
75 | transformer=TwoWayTransformer(
76 | depth=2,
77 | embedding_dim=prompt_embed_dim,
78 | mlp_dim=2048,
79 | num_heads=8,
80 | ),
81 | transformer_dim=prompt_embed_dim,
82 | iou_head_depth=3,
83 | iou_head_hidden_dim=256,
84 | ),
85 | pixel_mean=[123.675, 116.28, 103.53],
86 | pixel_std=[58.395, 57.12, 57.375],
87 | )
88 |
89 | mobile_sam.eval()
90 | if checkpoint is not None:
91 | with open(checkpoint, "rb") as f:
92 | state_dict = torch.load(f)
93 | mobile_sam.load_state_dict(state_dict)
94 | return mobile_sam
95 |
96 |
97 | sam_model_registry = {
98 | "default": build_sam_vit_h,
99 | "vit_h": build_sam_vit_h,
100 | "vit_l": build_sam_vit_l,
101 | "vit_b": build_sam_vit_b,
102 | "vit_t": build_sam_vit_t,
103 | }
104 |
105 |
106 | def _build_sam(
107 | encoder_embed_dim,
108 | encoder_depth,
109 | encoder_num_heads,
110 | encoder_global_attn_indexes,
111 | checkpoint=None,
112 | ):
113 | prompt_embed_dim = 256
114 | image_size = 1024
115 | vit_patch_size = 16
116 | image_embedding_size = image_size // vit_patch_size
117 | sam = Sam(
118 | image_encoder=ImageEncoderViT(
119 | depth=encoder_depth,
120 | embed_dim=encoder_embed_dim,
121 | img_size=image_size,
122 | mlp_ratio=4,
123 | norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
124 | num_heads=encoder_num_heads,
125 | patch_size=vit_patch_size,
126 | qkv_bias=True,
127 | use_rel_pos=True,
128 | global_attn_indexes=encoder_global_attn_indexes,
129 | window_size=14,
130 | out_chans=prompt_embed_dim,
131 | ),
132 | prompt_encoder=PromptEncoder(
133 | embed_dim=prompt_embed_dim,
134 | image_embedding_size=(image_embedding_size, image_embedding_size),
135 | input_image_size=(image_size, image_size),
136 | mask_in_chans=16,
137 | ),
138 | mask_decoder=MaskDecoder(
139 | num_multimask_outputs=3,
140 | transformer=TwoWayTransformer(
141 | depth=2,
142 | embedding_dim=prompt_embed_dim,
143 | mlp_dim=2048,
144 | num_heads=8,
145 | ),
146 | transformer_dim=prompt_embed_dim,
147 | iou_head_depth=3,
148 | iou_head_hidden_dim=256,
149 | ),
150 | pixel_mean=[123.675, 116.28, 103.53],
151 | pixel_std=[58.395, 57.12, 57.375],
152 | )
153 | sam.eval()
154 | if checkpoint is not None:
155 | with open(checkpoint, "rb") as f:
156 | state_dict = torch.load(f)
157 | sam.load_state_dict(state_dict)
158 | return sam
159 |
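A short sketch of point-prompted prediction through the registry defined above, assuming a vit_t checkpoint (the path is a placeholder); the zero-filled array stands in for a real HxWx3 uint8 RGB image.

import numpy as np

from mobile_sam import SamPredictor, sam_model_registry

# "vit_t" dispatches to build_sam_vit_t above; the checkpoint path is a placeholder.
sam = sam_model_registry["vit_t"](checkpoint="mobile_sam.pt")

predictor = SamPredictor(sam)
image = np.zeros((480, 640, 3), dtype=np.uint8)  # stand-in for a real RGB image
predictor.set_image(image)

# One positive point prompt at the image center.
masks, scores, low_res_logits = predictor.predict(
    point_coords=np.array([[320, 240]]),
    point_labels=np.array([1]),
    multimask_output=True,
)
print(masks.shape, scores)
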
--------------------------------------------------------------------------------
/mobile_sam/modeling/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from .sam import Sam
8 | from .image_encoder import ImageEncoderViT
9 | from .mask_decoder import MaskDecoder
10 | from .prompt_encoder import PromptEncoder
11 | from .transformer import TwoWayTransformer
12 | from .tiny_vit_sam import TinyViT
13 |
14 | __all__ = [
15 | "Sam",
16 | "ImageEncoderViT",
17 | "MaskDecoder",
18 | "PromptEncoder",
19 | "TwoWayTransformer",
20 | "TinyViT",
21 | ]
22 |
--------------------------------------------------------------------------------
/mobile_sam/modeling/common.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | import torch.nn as nn
9 |
10 | from typing import Type
11 |
12 |
13 | class MLPBlock(nn.Module):
14 | def __init__(
15 | self,
16 | embedding_dim: int,
17 | mlp_dim: int,
18 | act: Type[nn.Module] = nn.GELU,
19 | ) -> None:
20 | super().__init__()
21 | self.lin1 = nn.Linear(embedding_dim, mlp_dim)
22 | self.lin2 = nn.Linear(mlp_dim, embedding_dim)
23 | self.act = act()
24 |
25 | def forward(self, x: torch.Tensor) -> torch.Tensor:
26 | return self.lin2(self.act(self.lin1(x)))
27 |
28 |
29 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
30 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
31 | class LayerNorm2d(nn.Module):
32 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
33 | super().__init__()
34 | self.weight = nn.Parameter(torch.ones(num_channels))
35 | self.bias = nn.Parameter(torch.zeros(num_channels))
36 | self.eps = eps
37 |
38 | def forward(self, x: torch.Tensor) -> torch.Tensor:
39 | u = x.mean(1, keepdim=True)
40 | s = (x - u).pow(2).mean(1, keepdim=True)
41 | x = (x - u) / torch.sqrt(s + self.eps)
42 | x = self.weight[:, None, None] * x + self.bias[:, None, None]
43 | return x
44 |
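As a quick self-check of what LayerNorm2d computes (per-position normalization over the channel dimension of an NCHW tensor):

import torch

from mobile_sam.modeling.common import LayerNorm2d

x = torch.randn(2, 8, 4, 4)  # N, C, H, W
y = LayerNorm2d(8)(x)

# With the initial weight=1 and bias=0, every spatial location is normalized
# across its 8 channels: zero mean and (approximately) unit variance.
print(y.mean(dim=1).abs().max())             # ~0
print(y.var(dim=1, unbiased=False).mean())   # ~1
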
--------------------------------------------------------------------------------
/mobile_sam/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/mobile_sam/utils/onnx.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | import torch.nn as nn
9 | from torch.nn import functional as F
10 |
11 | from typing import Tuple
12 |
13 | from ..modeling import Sam
14 | from .amg import calculate_stability_score
15 |
16 |
17 | class SamOnnxModel(nn.Module):
18 | """
19 | This model should not be called directly, but is used in ONNX export.
20 | It combines the prompt encoder, mask decoder, and mask postprocessing of Sam,
21 | with some functions modified to enable model tracing. Also supports extra
22 | options controlling what information is returned. See the ONNX export script for details.
23 | """
24 |
25 | def __init__(
26 | self,
27 | model: Sam,
28 | return_single_mask: bool,
29 | use_stability_score: bool = False,
30 | return_extra_metrics: bool = False,
31 | ) -> None:
32 | super().__init__()
33 | self.mask_decoder = model.mask_decoder
34 | self.model = model
35 | self.img_size = model.image_encoder.img_size
36 | self.return_single_mask = return_single_mask
37 | self.use_stability_score = use_stability_score
38 | self.stability_score_offset = 1.0
39 | self.return_extra_metrics = return_extra_metrics
40 |
41 | @staticmethod
42 | def resize_longest_image_size(
43 | input_image_size: torch.Tensor, longest_side: int
44 | ) -> torch.Tensor:
45 | input_image_size = input_image_size.to(torch.float32)
46 | scale = longest_side / torch.max(input_image_size)
47 | transformed_size = scale * input_image_size
48 | transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64)
49 | return transformed_size
50 |
51 | def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor:
52 | point_coords = point_coords + 0.5
53 | point_coords = point_coords / self.img_size
54 | point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords)
55 | point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding)
56 |
57 | point_embedding = point_embedding * (point_labels != -1)
58 | point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * (
59 | point_labels == -1
60 | )
61 |
62 | for i in range(self.model.prompt_encoder.num_point_embeddings):
63 | point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[
64 | i
65 | ].weight * (point_labels == i)
66 |
67 | return point_embedding
68 |
69 | def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor:
70 | mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask)
71 | mask_embedding = mask_embedding + (
72 | 1 - has_mask_input
73 | ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1)
74 | return mask_embedding
75 |
76 | def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor:
77 | masks = F.interpolate(
78 | masks,
79 | size=(self.img_size, self.img_size),
80 | mode="bilinear",
81 | align_corners=False,
82 | )
83 |
84 | prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(torch.int64)
85 | masks = masks[..., : prepadded_size[0], : prepadded_size[1]] # type: ignore
86 |
87 | orig_im_size = orig_im_size.to(torch.int64)
88 | h, w = orig_im_size[0], orig_im_size[1]
89 | masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False)
90 | return masks
91 |
92 | def select_masks(
93 | self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int
94 | ) -> Tuple[torch.Tensor, torch.Tensor]:
95 | # Determine if we should return the multiclick mask or not from the number of points.
96 | # The reweighting is used to avoid control flow.
97 | score_reweight = torch.tensor(
98 | [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)]
99 | ).to(iou_preds.device)
100 | score = iou_preds + (num_points - 2.5) * score_reweight
101 | best_idx = torch.argmax(score, dim=1)
102 | masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1)
103 | iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1)
104 |
105 | return masks, iou_preds
106 |
107 | @torch.no_grad()
108 | def forward(
109 | self,
110 | image_embeddings: torch.Tensor,
111 | point_coords: torch.Tensor,
112 | point_labels: torch.Tensor,
113 | mask_input: torch.Tensor,
114 | has_mask_input: torch.Tensor,
115 | orig_im_size: torch.Tensor,
116 | ):
117 | sparse_embedding = self._embed_points(point_coords, point_labels)
118 | dense_embedding = self._embed_masks(mask_input, has_mask_input)
119 |
120 | masks, scores = self.model.mask_decoder.predict_masks(
121 | image_embeddings=image_embeddings,
122 | image_pe=self.model.prompt_encoder.get_dense_pe(),
123 | sparse_prompt_embeddings=sparse_embedding,
124 | dense_prompt_embeddings=dense_embedding,
125 | )
126 |
127 | if self.use_stability_score:
128 | scores = calculate_stability_score(
129 | masks, self.model.mask_threshold, self.stability_score_offset
130 | )
131 |
132 | if self.return_single_mask:
133 | masks, scores = self.select_masks(masks, scores, point_coords.shape[1])
134 |
135 | upscaled_masks = self.mask_postprocessing(masks, orig_im_size)
136 |
137 | if self.return_extra_metrics:
138 | stability_scores = calculate_stability_score(
139 | upscaled_masks, self.model.mask_threshold, self.stability_score_offset
140 | )
141 | areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1)
142 | return upscaled_masks, scores, stability_scores, areas, masks
143 |
144 | return upscaled_masks, scores, masks
145 |
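The docstring above points at the ONNX export script; as a hedged sketch only (the dummy-input shapes follow the usual SAM export convention of 64x64 image embeddings and a 256x256 low-res mask input, and the checkpoint path, output file, and output names are illustrative), an export could look like this:

import torch

from mobile_sam import sam_model_registry
from mobile_sam.utils.onnx import SamOnnxModel

sam = sam_model_registry["vit_t"](checkpoint="mobile_sam.pt")  # placeholder path
onnx_model = SamOnnxModel(sam, return_single_mask=True)

embed_dim = sam.prompt_encoder.embed_dim               # 256
embed_size = sam.prompt_encoder.image_embedding_size   # (64, 64)
dummy_inputs = {
    "image_embeddings": torch.randn(1, embed_dim, *embed_size),
    "point_coords": torch.randint(0, 1024, (1, 5, 2)).float(),
    "point_labels": torch.randint(0, 4, (1, 5)).float(),
    "mask_input": torch.randn(1, 1, 4 * embed_size[0], 4 * embed_size[1]),
    "has_mask_input": torch.tensor([1.0]),
    "orig_im_size": torch.tensor([1200.0, 1800.0]),
}
torch.onnx.export(
    onnx_model,
    tuple(dummy_inputs.values()),
    "mobile_sam_decoder.onnx",
    input_names=list(dummy_inputs.keys()),
    output_names=["masks", "iou_predictions", "low_res_masks"],
    opset_version=17,
)
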
--------------------------------------------------------------------------------
/mobile_sam/utils/torch_nms.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchvision.ops.boxes import box_iou
3 |
4 |
5 | def nms(bboxes: torch.Tensor, scores: torch.Tensor, iou_threshold: float) -> torch.Tensor:
6 | order = torch.argsort(-scores)
7 | keep = []
8 |
9 | while order.numel() > 0:
10 | i = order[0]
11 | keep.append(i.item())
12 |
13 | if order.numel() == 1:
14 | break
15 |
16 | ious = box_iou(bboxes[i].unsqueeze(0), bboxes[order[1:]])[0]
17 | mask = ious <= iou_threshold
18 | order = order[1:][mask]
19 |
20 | return torch.tensor(keep, device=bboxes.device)
21 |
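A tiny worked example of the greedy NMS above: the second box overlaps the top-scoring box with IoU ~0.81 and is suppressed, while the disjoint third box survives.

import torch

from mobile_sam.utils.torch_nms import nms

boxes = torch.tensor([
    [0.0, 0.0, 10.0, 10.0],    # kept (highest score)
    [1.0, 1.0, 10.0, 10.0],    # IoU with box 0 ~0.81 > 0.5 -> suppressed
    [20.0, 20.0, 30.0, 30.0],  # no overlap -> kept
])
scores = torch.tensor([0.9, 0.8, 0.7])
print(nms(boxes, scores, iou_threshold=0.5))  # tensor([0, 2])
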
--------------------------------------------------------------------------------
/mobile_sam/utils/transforms.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import numpy as np
8 | import torch
9 | from torch.nn import functional as F
10 | from torchvision.transforms.functional import resize, to_pil_image # type: ignore
11 |
12 | from copy import deepcopy
13 | from typing import Tuple
14 |
15 |
16 | class ResizeLongestSide:
17 | """
18 | Resizes images so that the longest side matches 'target_length', and provides
19 | methods for resizing coordinates and boxes. Works on both numpy arrays
20 | and batched torch tensors.
21 | """
22 |
23 | def __init__(self, target_length: int) -> None:
24 | self.target_length = target_length
25 |
26 | def apply_image(self, image: np.ndarray) -> np.ndarray:
27 | """
28 | Expects a numpy array with shape HxWxC in uint8 format.
29 | """
30 | target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length)
31 | return np.array(resize(to_pil_image(image), target_size))
32 |
33 | def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
34 | """
35 | Expects a numpy array of length 2 in the final dimension. Requires the
36 | original image size in (H, W) format.
37 | """
38 | old_h, old_w = original_size
39 | new_h, new_w = self.get_preprocess_shape(
40 | original_size[0], original_size[1], self.target_length
41 | )
42 | coords = deepcopy(coords).astype(float)
43 | coords[..., 0] = coords[..., 0] * (new_w / old_w)
44 | coords[..., 1] = coords[..., 1] * (new_h / old_h)
45 | return coords
46 |
47 | def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
48 | """
49 | Expects a numpy array shape Bx4. Requires the original image size
50 | in (H, W) format.
51 | """
52 | boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size)
53 | return boxes.reshape(-1, 4)
54 |
55 | def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor:
56 | """
57 | Expects batched images with shape BxCxHxW and float format. This
58 | transformation may not exactly match apply_image. apply_image is
59 | the transformation expected by the model.
60 | """
61 | # Expects an image in BCHW format. May not exactly match apply_image.
62 | target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length)
63 | return F.interpolate(
64 | image, target_size, mode="bilinear", align_corners=False, antialias=True
65 | )
66 |
67 | def apply_coords_torch(
68 | self, coords: torch.Tensor, original_size: Tuple[int, ...]
69 | ) -> torch.Tensor:
70 | """
71 | Expects a torch tensor with length 2 in the last dimension. Requires the
72 | original image size in (H, W) format.
73 | """
74 | old_h, old_w = original_size
75 | new_h, new_w = self.get_preprocess_shape(
76 | original_size[0], original_size[1], self.target_length
77 | )
78 | coords = deepcopy(coords).to(torch.float)
79 | coords[..., 0] = coords[..., 0] * (new_w / old_w)
80 | coords[..., 1] = coords[..., 1] * (new_h / old_h)
81 | return coords
82 |
83 | def apply_boxes_torch(
84 | self, boxes: torch.Tensor, original_size: Tuple[int, ...]
85 | ) -> torch.Tensor:
86 | """
87 | Expects a torch tensor with shape Bx4. Requires the original image
88 | size in (H, W) format.
89 | """
90 | boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size)
91 | return boxes.reshape(-1, 4)
92 |
93 | @staticmethod
94 | def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
95 | """
96 | Compute the output size given input size and target long side length.
97 | """
98 | scale = long_side_length * 1.0 / max(oldh, oldw)
99 | newh, neww = oldh * scale, oldw * scale
100 | neww = int(neww + 0.5)
101 | newh = int(newh + 0.5)
102 | return (newh, neww)
103 |
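A brief usage sketch of ResizeLongestSide as SAM-style preprocessing, assuming the 1024-pixel target used by the builders in this package; the zero-filled array stands in for a real image.

import numpy as np

from mobile_sam.utils.transforms import ResizeLongestSide

transform = ResizeLongestSide(1024)

image = np.zeros((480, 640, 3), dtype=np.uint8)         # H=480, W=640
resized = transform.apply_image(image)                  # longest side -> 1024
print(resized.shape)                                    # (768, 1024, 3)

# Prompt coordinates are rescaled using the *original* (H, W).
points = np.array([[320.0, 240.0]])
print(transform.apply_coords(points, image.shape[:2]))  # [[512. 384.]]
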
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | --extra-index-url https://download.pytorch.org/whl/cu121
2 | torch==2.3.1
3 | torchvision
4 | accelerate
5 | diffusers<0.32.0
6 | gradio<4.0.0
7 | huggingface-hub
8 | numpy<2.0.0
9 | opencv-python
10 | pillow
11 | segment-anything
12 | transformers<5.0.0
13 | xformers==0.0.27
14 | # lama-cleaner
15 | ultralytics
16 | tqdm
17 | packaging
18 | loguru
19 | rich
20 | pydantic
21 | timm
22 | onnxruntime
23 | hydra-core
24 | iopath
25 |
--------------------------------------------------------------------------------
/requirements_cu118.txt:
--------------------------------------------------------------------------------
1 | --extra-index-url https://download.pytorch.org/whl/cu118
2 | torch==2.3.1
3 | torchvision
4 | accelerate
5 | diffusers<0.32.0
6 | gradio<4.0.0
7 | huggingface-hub
8 | numpy<2.0.0
9 | opencv-python
10 | pillow
11 | segment-anything
12 | transformers<5.0.0
13 | xformers==0.0.27
14 | # lama-cleaner
15 | ultralytics
16 | tqdm
17 | packaging
18 | loguru
19 | rich
20 | pydantic
21 | timm
22 | onnxruntime
23 | hydra-core
24 | iopath
25 |
--------------------------------------------------------------------------------
/requirements_mac.txt:
--------------------------------------------------------------------------------
1 | torch==2.2.2
2 | torchvision
3 | accelerate
4 | diffusers<0.32.0
5 | gradio<4.0.0
6 | huggingface-hub
7 | numpy<2.0.0
8 | opencv-python
9 | pillow
10 | segment-anything
11 | transformers<5.0.0
12 | # xformers
13 | # lama-cleaner
14 | ultralytics
15 | tqdm
16 | packaging
17 | loguru
18 | rich
19 | pydantic
20 | timm
21 | onnxruntime
22 | hydra-core
23 | iopath
24 |
--------------------------------------------------------------------------------
/sam2/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import os
8 | import warnings
9 |
10 | from hydra import initialize_config_dir, initialize_config_module # noqa: F401
11 |
12 | warnings.filterwarnings("ignore", category=UserWarning, module="sam2")
13 |
14 | inpa_basedir = os.path.abspath(os.path.normpath(os.path.join(os.path.dirname(__file__), "..")))
15 | configs_path = os.path.join(inpa_basedir, "sam2_configs")
16 |
17 | try:
18 | initialize_config_dir(configs_path, version_base="1.2")
19 | except TypeError:
20 | initialize_config_dir(configs_path)
21 | # initialize_config_module("sam2_configs", version_base="1.2")
22 |
--------------------------------------------------------------------------------
/sam2/build_sam.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import logging
8 |
9 | import torch
10 | from hydra import compose
11 | from hydra.utils import instantiate
12 | from omegaconf import OmegaConf
13 |
14 |
15 | def build_sam2(
16 | config_file,
17 | ckpt_path=None,
18 | device="cuda",
19 | mode="eval",
20 | hydra_overrides_extra=[],
21 | apply_postprocessing=True,
22 | ):
23 |
24 | if apply_postprocessing:
25 | hydra_overrides_extra = hydra_overrides_extra.copy()
26 | hydra_overrides_extra += [
27 | # dynamically fall back to multi-mask if the single mask is not stable
28 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
29 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
30 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
31 | ]
32 | # Read config and init model
33 | cfg = compose(config_name=config_file, overrides=hydra_overrides_extra)
34 | OmegaConf.resolve(cfg)
35 | model = instantiate(cfg.model, _recursive_=True)
36 | _load_checkpoint(model, ckpt_path)
37 | model = model.to(device)
38 | if mode == "eval":
39 | model.eval()
40 | return model
41 |
42 |
43 | def build_sam2_video_predictor(
44 | config_file,
45 | ckpt_path=None,
46 | device="cuda",
47 | mode="eval",
48 | hydra_overrides_extra=[],
49 | apply_postprocessing=True,
50 | ):
51 | hydra_overrides = [
52 | "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor",
53 | ]
54 | if apply_postprocessing:
55 | hydra_overrides_extra = hydra_overrides_extra.copy()
56 | hydra_overrides_extra += [
57 | # dynamically fall back to multi-mask if the single mask is not stable
58 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
59 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
60 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
61 | # binarize the sigmoid mask logits on interacted frames with clicks in the memory encoder so that the encoded masks are exactly what users see from clicking
62 | "++model.binarize_mask_from_pts_for_mem_enc=true",
63 | # fill small holes in the low-res masks up to `fill_hole_area` (before resizing them to the original video resolution)
64 | "++model.fill_hole_area=8",
65 | ]
66 | hydra_overrides.extend(hydra_overrides_extra)
67 |
68 | # Read config and init model
69 | cfg = compose(config_name=config_file, overrides=hydra_overrides)
70 | OmegaConf.resolve(cfg)
71 | model = instantiate(cfg.model, _recursive_=True)
72 | _load_checkpoint(model, ckpt_path)
73 | model = model.to(device)
74 | if mode == "eval":
75 | model.eval()
76 | return model
77 |
78 |
79 | def _load_checkpoint(model, ckpt_path):
80 | if ckpt_path is not None:
81 | sd = torch.load(ckpt_path, map_location="cpu")["model"]
82 | missing_keys, unexpected_keys = model.load_state_dict(sd)
83 | if missing_keys:
84 | logging.error(missing_keys)
85 | raise RuntimeError()
86 | if unexpected_keys:
87 | logging.error(unexpected_keys)
88 | raise RuntimeError()
89 | logging.info("Loaded checkpoint successfully")
90 |
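A hedged sketch of building a SAM 2 image predictor from one of the bundled configs (the config name is resolved through the Hydra search path initialized in sam2/__init__.py; the checkpoint path is a placeholder, and the zero-filled array stands in for a real RGB image):

import numpy as np

from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

model = build_sam2("sam2_hiera_l.yaml", ckpt_path="sam2_hiera_large.pt", device="cpu")
predictor = SAM2ImagePredictor(model)

image = np.zeros((480, 640, 3), dtype=np.uint8)
predictor.set_image(image)
masks, scores, _ = predictor.predict(
    point_coords=np.array([[320, 240]]),
    point_labels=np.array([1]),
    multimask_output=True,
)
print(masks.shape, scores)
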
--------------------------------------------------------------------------------
/sam2/modeling/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/sam2/modeling/backbones/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/sam2/modeling/backbones/image_encoder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from typing import List, Optional
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 |
13 |
14 | class ImageEncoder(nn.Module):
15 | def __init__(
16 | self,
17 | trunk: nn.Module,
18 | neck: nn.Module,
19 | scalp: int = 0,
20 | ):
21 | super().__init__()
22 | self.trunk = trunk
23 | self.neck = neck
24 | self.scalp = scalp
25 | assert (
26 | self.trunk.channel_list == self.neck.backbone_channel_list
27 | ), f"Channel dims of trunk and neck do not match. Trunk: {self.trunk.channel_list}, neck: {self.neck.backbone_channel_list}"
28 |
29 | def forward(self, sample: torch.Tensor):
30 | # Forward through backbone
31 | features, pos = self.neck(self.trunk(sample))
32 | if self.scalp > 0:
33 | # Discard the lowest resolution features
34 | features, pos = features[: -self.scalp], pos[: -self.scalp]
35 |
36 | src = features[-1]
37 | output = {
38 | "vision_features": src,
39 | "vision_pos_enc": pos,
40 | "backbone_fpn": features,
41 | }
42 | return output
43 |
44 |
45 | class FpnNeck(nn.Module):
46 | """
47 | A modified variant of Feature Pyramid Network (FPN) neck
48 | (we remove output conv and also do bicubic interpolation similar to ViT
49 | pos embed interpolation)
50 | """
51 |
52 | def __init__(
53 | self,
54 | position_encoding: nn.Module,
55 | d_model: int,
56 | backbone_channel_list: List[int],
57 | kernel_size: int = 1,
58 | stride: int = 1,
59 | padding: int = 0,
60 | fpn_interp_model: str = "bilinear",
61 | fuse_type: str = "sum",
62 | fpn_top_down_levels: Optional[List[int]] = None,
63 | ):
64 | """Initialize the neck
65 | :param position_encoding: the positional encoding to use
66 | :param d_model: the dimension of the model
67 | :param backbone_channel_list: channel dims of the backbone feature maps fed to the neck
68 | :param fuse_type: how lateral and top-down features are combined ("sum" or "avg")
69 | """
70 | super().__init__()
71 | self.position_encoding = position_encoding
72 | self.convs = nn.ModuleList()
73 | self.backbone_channel_list = backbone_channel_list
74 | for dim in backbone_channel_list:
75 | current = nn.Sequential()
76 | current.add_module(
77 | "conv",
78 | nn.Conv2d(
79 | in_channels=dim,
80 | out_channels=d_model,
81 | kernel_size=kernel_size,
82 | stride=stride,
83 | padding=padding,
84 | ),
85 | )
86 |
87 | self.convs.append(current)
88 | self.fpn_interp_model = fpn_interp_model
89 | assert fuse_type in ["sum", "avg"]
90 | self.fuse_type = fuse_type
91 |
92 | # levels to have top-down features in its outputs
93 | # e.g. if fpn_top_down_levels is [2, 3], then only outputs of level 2 and 3
94 | # have top-down propagation, while outputs of level 0 and level 1 have only
95 | # lateral features from the same backbone level.
96 | if fpn_top_down_levels is None:
97 | # default is to have top-down features on all levels
98 | fpn_top_down_levels = range(len(self.convs))
99 | self.fpn_top_down_levels = list(fpn_top_down_levels)
100 |
101 | def forward(self, xs: List[torch.Tensor]):
102 |
103 | out = [None] * len(self.convs)
104 | pos = [None] * len(self.convs)
105 | assert len(xs) == len(self.convs)
106 | # fpn forward pass
107 | # see https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/fpn.py
108 | prev_features = None
109 | # forward in top-down order (from low to high resolution)
110 | n = len(self.convs) - 1
111 | for i in range(n, -1, -1):
112 | x = xs[i]
113 | lateral_features = self.convs[n - i](x)
114 | if i in self.fpn_top_down_levels and prev_features is not None:
115 | top_down_features = F.interpolate(
116 | prev_features.to(dtype=torch.float32),
117 | scale_factor=2.0,
118 | mode=self.fpn_interp_model,
119 | align_corners=(
120 | None if self.fpn_interp_model == "nearest" else False
121 | ),
122 | antialias=False,
123 | )
124 | prev_features = lateral_features + top_down_features
125 | if self.fuse_type == "avg":
126 | prev_features /= 2
127 | else:
128 | prev_features = lateral_features
129 | x_out = prev_features
130 | out[i] = x_out
131 | pos[i] = self.position_encoding(x_out).to(x_out.dtype)
132 |
133 | return out, pos
134 |
--------------------------------------------------------------------------------
/sam2/modeling/backbones/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | """Some utilities for backbones, in particular for windowing"""
8 |
9 | from typing import Tuple
10 |
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 |
15 |
16 | def window_partition(x, window_size):
17 | """
18 | Partition into non-overlapping windows with padding if needed.
19 | Args:
20 | x (tensor): input tokens with [B, H, W, C].
21 | window_size (int): window size.
22 | Returns:
23 | windows: windows after partition with [B * num_windows, window_size, window_size, C].
24 | (Hp, Wp): padded height and width before partition
25 | """
26 | B, H, W, C = x.shape
27 |
28 | pad_h = (window_size - H % window_size) % window_size
29 | pad_w = (window_size - W % window_size) % window_size
30 | if pad_h > 0 or pad_w > 0:
31 | x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
32 | Hp, Wp = H + pad_h, W + pad_w
33 |
34 | x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
35 | windows = (
36 | x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
37 | )
38 | return windows, (Hp, Wp)
39 |
40 |
41 | def window_unpartition(windows, window_size, pad_hw, hw):
42 | """
43 | Window unpartition into original sequences, removing padding.
44 | Args:
45 | windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
46 | window_size (int): window size.
47 | pad_hw (Tuple): padded height and width (Hp, Wp).
48 | hw (Tuple): original height and width (H, W) before padding.
49 | Returns:
50 | x: unpartitioned sequences with [B, H, W, C].
51 | """
52 | Hp, Wp = pad_hw
53 | H, W = hw
54 | B = windows.shape[0] // (Hp * Wp // window_size // window_size)
55 | x = windows.view(
56 | B, Hp // window_size, Wp // window_size, window_size, window_size, -1
57 | )
58 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
59 |
60 | if Hp > H or Wp > W:
61 | x = x[:, :H, :W, :].contiguous()
62 | return x
63 |
64 |
65 | class PatchEmbed(nn.Module):
66 | """
67 | Image to Patch Embedding.
68 | """
69 |
70 | def __init__(
71 | self,
72 | kernel_size: Tuple[int, ...] = (7, 7),
73 | stride: Tuple[int, ...] = (4, 4),
74 | padding: Tuple[int, ...] = (3, 3),
75 | in_chans: int = 3,
76 | embed_dim: int = 768,
77 | ):
78 | """
79 | Args:
80 | kernel_size (Tuple): kernel size of the projection layer.
81 | stride (Tuple): stride of the projection layer.
82 | padding (Tuple): padding size of the projection layer.
83 | in_chans (int): Number of input image channels.
84 | embed_dim (int): Patch embedding dimension.
85 | """
86 | super().__init__()
87 | self.proj = nn.Conv2d(
88 | in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
89 | )
90 |
91 | def forward(self, x: torch.Tensor) -> torch.Tensor:
92 | x = self.proj(x)
93 | # B C H W -> B H W C
94 | x = x.permute(0, 2, 3, 1)
95 | return x
96 |
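A quick round-trip check of the windowing helpers: window_partition pads to a multiple of the window size, and window_unpartition strips the padding again, so the input is recovered exactly.

import torch

from sam2.modeling.backbones.utils import window_partition, window_unpartition

x = torch.randn(2, 14, 18, 32)                  # B, H, W, C (not multiples of 8)
windows, (Hp, Wp) = window_partition(x, window_size=8)
print(windows.shape, (Hp, Wp))                  # torch.Size([12, 8, 8, 32]) (16, 24)

y = window_unpartition(windows, 8, (Hp, Wp), (14, 18))
print(torch.equal(x, y))                        # True
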
--------------------------------------------------------------------------------
/sam2/modeling/memory_attention.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from typing import Optional
8 |
9 | import torch
10 | from torch import nn, Tensor
11 |
12 | from sam2.modeling.sam.transformer import RoPEAttention
13 |
14 | from sam2.modeling.sam2_utils import get_activation_fn, get_clones
15 |
16 |
17 | class MemoryAttentionLayer(nn.Module):
18 |
19 | def __init__(
20 | self,
21 | activation: str,
22 | cross_attention: nn.Module,
23 | d_model: int,
24 | dim_feedforward: int,
25 | dropout: float,
26 | pos_enc_at_attn: bool,
27 | pos_enc_at_cross_attn_keys: bool,
28 | pos_enc_at_cross_attn_queries: bool,
29 | self_attention: nn.Module,
30 | ):
31 | super().__init__()
32 | self.d_model = d_model
33 | self.dim_feedforward = dim_feedforward
34 | self.dropout_value = dropout
35 | self.self_attn = self_attention
36 | self.cross_attn_image = cross_attention
37 |
38 | # Implementation of Feedforward model
39 | self.linear1 = nn.Linear(d_model, dim_feedforward)
40 | self.dropout = nn.Dropout(dropout)
41 | self.linear2 = nn.Linear(dim_feedforward, d_model)
42 |
43 | self.norm1 = nn.LayerNorm(d_model)
44 | self.norm2 = nn.LayerNorm(d_model)
45 | self.norm3 = nn.LayerNorm(d_model)
46 | self.dropout1 = nn.Dropout(dropout)
47 | self.dropout2 = nn.Dropout(dropout)
48 | self.dropout3 = nn.Dropout(dropout)
49 |
50 | self.activation_str = activation
51 | self.activation = get_activation_fn(activation)
52 |
53 | # Where to add pos enc
54 | self.pos_enc_at_attn = pos_enc_at_attn
55 | self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries
56 | self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys
57 |
58 | def _forward_sa(self, tgt, query_pos):
59 | # Self-Attention
60 | tgt2 = self.norm1(tgt)
61 | q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2
62 | tgt2 = self.self_attn(q, k, v=tgt2)
63 | tgt = tgt + self.dropout1(tgt2)
64 | return tgt
65 |
66 | def _forward_ca(self, tgt, memory, query_pos, pos, num_k_exclude_rope=0):
67 | kwds = {}
68 | if num_k_exclude_rope > 0:
69 | assert isinstance(self.cross_attn_image, RoPEAttention)
70 | kwds = {"num_k_exclude_rope": num_k_exclude_rope}
71 |
72 | # Cross-Attention
73 | tgt2 = self.norm2(tgt)
74 | tgt2 = self.cross_attn_image(
75 | q=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2,
76 | k=memory + pos if self.pos_enc_at_cross_attn_keys else memory,
77 | v=memory,
78 | **kwds,
79 | )
80 | tgt = tgt + self.dropout2(tgt2)
81 | return tgt
82 |
83 | def forward(
84 | self,
85 | tgt,
86 | memory,
87 | pos: Optional[Tensor] = None,
88 | query_pos: Optional[Tensor] = None,
89 | num_k_exclude_rope: int = 0,
90 | ) -> torch.Tensor:
91 |
92 | # Self-Attn, Cross-Attn
93 | tgt = self._forward_sa(tgt, query_pos)
94 | tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope)
95 | # MLP
96 | tgt2 = self.norm3(tgt)
97 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
98 | tgt = tgt + self.dropout3(tgt2)
99 | return tgt
100 |
101 |
102 | class MemoryAttention(nn.Module):
103 | def __init__(
104 | self,
105 | d_model: int,
106 | pos_enc_at_input: bool,
107 | layer: nn.Module,
108 | num_layers: int,
109 | batch_first: bool = True, # Do layers expect batch first input?
110 | ):
111 | super().__init__()
112 | self.d_model = d_model
113 | self.layers = get_clones(layer, num_layers)
114 | self.num_layers = num_layers
115 | self.norm = nn.LayerNorm(d_model)
116 | self.pos_enc_at_input = pos_enc_at_input
117 | self.batch_first = batch_first
118 |
119 | def forward(
120 | self,
121 | curr: torch.Tensor, # self-attention inputs
122 | memory: torch.Tensor, # cross-attention inputs
123 | curr_pos: Optional[Tensor] = None, # pos_enc for self-attention inputs
124 | memory_pos: Optional[Tensor] = None, # pos_enc for cross-attention inputs
125 | num_obj_ptr_tokens: int = 0, # number of object pointer *tokens*
126 | ):
127 | if isinstance(curr, list):
128 | assert isinstance(curr_pos, list)
129 | assert len(curr) == len(curr_pos) == 1
130 | curr, curr_pos = (
131 | curr[0],
132 | curr_pos[0],
133 | )
134 |
135 | assert (
136 | curr.shape[1] == memory.shape[1]
137 | ), "Batch size must be the same for curr and memory"
138 |
139 | output = curr
140 | if self.pos_enc_at_input and curr_pos is not None:
141 | output = output + 0.1 * curr_pos
142 |
143 | if self.batch_first:
144 | # Convert to batch first
145 | output = output.transpose(0, 1)
146 | curr_pos = curr_pos.transpose(0, 1)
147 | memory = memory.transpose(0, 1)
148 | memory_pos = memory_pos.transpose(0, 1)
149 |
150 | for layer in self.layers:
151 | kwds = {}
152 | if isinstance(layer.cross_attn_image, RoPEAttention):
153 | kwds = {"num_k_exclude_rope": num_obj_ptr_tokens}
154 |
155 | output = layer(
156 | tgt=output,
157 | memory=memory,
158 | pos=memory_pos,
159 | query_pos=curr_pos,
160 | **kwds,
161 | )
162 | normed_output = self.norm(output)
163 |
164 | if self.batch_first:
165 | # Convert back to seq first
166 | normed_output = normed_output.transpose(0, 1)
167 | curr_pos = curr_pos.transpose(0, 1)
168 |
169 | return normed_output
170 |
--------------------------------------------------------------------------------
/sam2/modeling/sam/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/sam2/modeling/sam2_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 |
8 | import copy
9 |
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 |
14 |
15 | def select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num):
16 | """
17 | Select up to `max_cond_frame_num` conditioning frames from `cond_frame_outputs`
18 | that are temporally closest to the current frame at `frame_idx`. Here, we take
19 | - a) the closest conditioning frame before `frame_idx` (if any);
20 | - b) the closest conditioning frame after `frame_idx` (if any);
21 | - c) any other temporally closest conditioning frames until reaching a total
22 | of `max_cond_frame_num` conditioning frames.
23 |
24 | Outputs:
25 | - selected_outputs: selected items (keys & values) from `cond_frame_outputs`.
26 | - unselected_outputs: items (keys & values) not selected in `cond_frame_outputs`.
27 | """
28 | if max_cond_frame_num == -1 or len(cond_frame_outputs) <= max_cond_frame_num:
29 | selected_outputs = cond_frame_outputs
30 | unselected_outputs = {}
31 | else:
32 | assert max_cond_frame_num >= 2, "we should allow using 2+ conditioning frames"
33 | selected_outputs = {}
34 |
35 | # the closest conditioning frame before `frame_idx` (if any)
36 | idx_before = max((t for t in cond_frame_outputs if t < frame_idx), default=None)
37 | if idx_before is not None:
38 | selected_outputs[idx_before] = cond_frame_outputs[idx_before]
39 |
40 | # the closest conditioning frame after `frame_idx` (if any)
41 | idx_after = min((t for t in cond_frame_outputs if t >= frame_idx), default=None)
42 | if idx_after is not None:
43 | selected_outputs[idx_after] = cond_frame_outputs[idx_after]
44 |
45 | # add other temporally closest conditioning frames until reaching a total
46 | # of `max_cond_frame_num` conditioning frames.
47 | num_remain = max_cond_frame_num - len(selected_outputs)
48 | inds_remain = sorted(
49 | (t for t in cond_frame_outputs if t not in selected_outputs),
50 | key=lambda x: abs(x - frame_idx),
51 | )[:num_remain]
52 | selected_outputs.update((t, cond_frame_outputs[t]) for t in inds_remain)
53 | unselected_outputs = {
54 | t: v for t, v in cond_frame_outputs.items() if t not in selected_outputs
55 | }
56 |
57 | return selected_outputs, unselected_outputs
58 |
59 |
60 | def get_1d_sine_pe(pos_inds, dim, temperature=10000):
61 | """
62 | Get 1D sine positional embedding as in the original Transformer paper.
63 | """
64 | pe_dim = dim // 2
65 | dim_t = torch.arange(pe_dim, dtype=torch.float32, device=pos_inds.device)
66 | dim_t = temperature ** (2 * (dim_t // 2) / pe_dim)
67 |
68 | pos_embed = pos_inds.unsqueeze(-1) / dim_t
69 | pos_embed = torch.cat([pos_embed.sin(), pos_embed.cos()], dim=-1)
70 | return pos_embed
71 |
72 |
73 | def get_activation_fn(activation):
74 | """Return an activation function given a string"""
75 | if activation == "relu":
76 | return F.relu
77 | if activation == "gelu":
78 | return F.gelu
79 | if activation == "glu":
80 | return F.glu
81 | raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
82 |
83 |
84 | def get_clones(module, N):
85 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
86 |
87 |
88 | class DropPath(nn.Module):
89 | # adapted from https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py
90 | def __init__(self, drop_prob=0.0, scale_by_keep=True):
91 | super(DropPath, self).__init__()
92 | self.drop_prob = drop_prob
93 | self.scale_by_keep = scale_by_keep
94 |
95 | def forward(self, x):
96 | if self.drop_prob == 0.0 or not self.training:
97 | return x
98 | keep_prob = 1 - self.drop_prob
99 | shape = (x.shape[0],) + (1,) * (x.ndim - 1)
100 | random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
101 | if keep_prob > 0.0 and self.scale_by_keep:
102 | random_tensor.div_(keep_prob)
103 | return x * random_tensor
104 |
105 |
106 | # Lightly adapted from
107 | # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
108 | class MLP(nn.Module):
109 | def __init__(
110 | self,
111 | input_dim: int,
112 | hidden_dim: int,
113 | output_dim: int,
114 | num_layers: int,
115 | activation: nn.Module = nn.ReLU,
116 | sigmoid_output: bool = False,
117 | ) -> None:
118 | super().__init__()
119 | self.num_layers = num_layers
120 | h = [hidden_dim] * (num_layers - 1)
121 | self.layers = nn.ModuleList(
122 | nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
123 | )
124 | self.sigmoid_output = sigmoid_output
125 | self.act = activation()
126 |
127 | def forward(self, x):
128 | for i, layer in enumerate(self.layers):
129 | x = self.act(layer(x)) if i < self.num_layers - 1 else layer(x)
130 | if self.sigmoid_output:
131 | x = F.sigmoid(x)
132 | return x
133 |
134 |
135 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
136 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
137 | class LayerNorm2d(nn.Module):
138 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
139 | super().__init__()
140 | self.weight = nn.Parameter(torch.ones(num_channels))
141 | self.bias = nn.Parameter(torch.zeros(num_channels))
142 | self.eps = eps
143 |
144 | def forward(self, x: torch.Tensor) -> torch.Tensor:
145 | u = x.mean(1, keepdim=True)
146 | s = (x - u).pow(2).mean(1, keepdim=True)
147 | x = (x - u) / torch.sqrt(s + self.eps)
148 | x = self.weight[:, None, None] * x + self.bias[:, None, None]
149 | return x
150 |
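A small worked example of select_closest_cond_frames: with conditioning outputs at frames {0, 5, 10, 20}, current frame 8, and a budget of 3, it keeps the nearest frame on each side (5 and 10) plus the next-closest remaining frame (0), leaving 20 unselected.

from sam2.modeling.sam2_utils import select_closest_cond_frames

cond = {0: "out0", 5: "out5", 10: "out10", 20: "out20"}
selected, unselected = select_closest_cond_frames(
    frame_idx=8, cond_frame_outputs=cond, max_cond_frame_num=3
)
print(sorted(selected))    # [0, 5, 10]
print(sorted(unselected))  # [20]
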
--------------------------------------------------------------------------------
/sam2/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/sam2/utils/torch_nms.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchvision.ops.boxes import box_iou
3 |
4 |
5 | def nms(bboxes: torch.Tensor, scores: torch.Tensor, iou_threshold: float) -> torch.Tensor:
6 | order = torch.argsort(-scores)
7 | keep = []
8 |
9 | while order.numel() > 0:
10 | i = order[0]
11 | keep.append(i.item())
12 |
13 | if order.numel() == 1:
14 | break
15 |
16 | ious = box_iou(bboxes[i].unsqueeze(0), bboxes[order[1:]])[0]
17 | mask = ious <= iou_threshold
18 | order = order[1:][mask]
19 |
20 | return torch.tensor(keep, device=bboxes.device)
21 |
--------------------------------------------------------------------------------
/sam2/utils/transforms.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | from torchvision.transforms import Normalize, Resize, ToTensor
11 |
12 |
13 | class SAM2Transforms(nn.Module):
14 | def __init__(
15 | self, resolution, mask_threshold, max_hole_area=0.0, max_sprinkle_area=0.0
16 | ):
17 | """
18 | Transforms for SAM2.
19 | """
20 | super().__init__()
21 | self.resolution = resolution
22 | self.mask_threshold = mask_threshold
23 | self.max_hole_area = max_hole_area
24 | self.max_sprinkle_area = max_sprinkle_area
25 | self.mean = [0.485, 0.456, 0.406]
26 | self.std = [0.229, 0.224, 0.225]
27 | self.to_tensor = ToTensor()
28 | self.transforms = nn.Sequential(
29 | Resize((self.resolution, self.resolution), antialias=True),
30 | Normalize(self.mean, self.std),
31 | )
32 |
33 | def __call__(self, x):
34 | x = self.to_tensor(x)
35 | return self.transforms(x)
36 |
37 | def forward_batch(self, img_list):
38 | img_batch = [self.transforms(self.to_tensor(img)) for img in img_list]
39 | img_batch = torch.stack(img_batch, dim=0)
40 | return img_batch
41 |
42 | def transform_coords(
43 | self, coords: torch.Tensor, normalize=False, orig_hw=None
44 | ) -> torch.Tensor:
45 | """
46 | Expects a torch tensor with length 2 in the last dimension. The coordinates can be in absolute image or normalized coordinates.
47 | If the coords are in absolute image coordinates, normalize should be set to True and the original image size is required.
48 |
49 | Returns
50 | Coordinates scaled to the SAM2 model's input resolution (multiplied by self.resolution), which is what the model expects.
51 | """
52 | if normalize:
53 | assert orig_hw is not None
54 | h, w = orig_hw
55 | coords = coords.clone()
56 | coords[..., 0] = coords[..., 0] / w
57 | coords[..., 1] = coords[..., 1] / h
58 |
59 | coords = coords * self.resolution # unnormalize coords
60 | return coords
61 |
62 | def transform_boxes(
63 | self, boxes: torch.Tensor, normalize=False, orig_hw=None
64 | ) -> torch.Tensor:
65 | """
66 | Expects a tensor of shape Bx4. The coordinates can be in absolute image or normalized coordinates.
67 | If the coords are in absolute image coordinates, normalize should be set to True and the original image size is required.
68 | """
69 | boxes = self.transform_coords(boxes.reshape(-1, 2, 2), normalize, orig_hw)
70 | return boxes
71 |
72 | def postprocess_masks(self, masks: torch.Tensor, orig_hw) -> torch.Tensor:
73 | """
74 | Perform PostProcessing on output masks.
75 | """
76 | from sam2.utils.misc import get_connected_components
77 |
78 | masks = masks.float()
79 | if self.max_hole_area > 0:
80 | # Holes are those connected components in background with area <= self.max_hole_area
81 | # (background regions are those with mask scores <= self.mask_threshold)
82 | mask_flat = masks.flatten(0, 1).unsqueeze(1) # flatten as 1-channel image
83 | labels, areas = get_connected_components(mask_flat <= self.mask_threshold)
84 | is_hole = (labels > 0) & (areas <= self.max_hole_area)
85 | is_hole = is_hole.reshape_as(masks)
86 | # We fill holes with a small positive mask score (10.0) to change them to foreground.
87 | masks = torch.where(is_hole, self.mask_threshold + 10.0, masks)
88 |
89 | if self.max_sprinkle_area > 0:
90 | labels, areas = get_connected_components(mask_flat > self.mask_threshold)
91 | is_hole = (labels > 0) & (areas <= self.max_sprinkle_area)
92 | is_hole = is_hole.reshape_as(masks)
93 | # We fill holes with negative mask score (-10.0) to change them to background.
94 | masks = torch.where(is_hole, self.mask_threshold - 10.0, masks)
95 |
96 | masks = F.interpolate(masks, orig_hw, mode="bilinear", align_corners=False)
97 | return masks
98 |
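A brief usage sketch of SAM2Transforms, assuming the 1024-pixel model resolution used by the bundled configs: calling the transform normalizes and resizes an HxWx3 uint8 image, and transform_coords rescales absolute prompt points to the model's input resolution.

import numpy as np
import torch

from sam2.utils.transforms import SAM2Transforms

transforms = SAM2Transforms(resolution=1024, mask_threshold=0.0)

image = np.zeros((480, 640, 3), dtype=np.uint8)   # stand-in RGB image
x = transforms(image)
print(x.shape)                                    # torch.Size([3, 1024, 1024])

# Absolute pixel coordinates -> model-resolution coordinates.
pts = torch.tensor([[320.0, 240.0]])
print(transforms.transform_coords(pts, normalize=True, orig_hw=(480, 640)))
# tensor([[512., 512.]])
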
--------------------------------------------------------------------------------
/sam2_configs/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/sam2_configs/sam2_hiera_b+.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # Model
4 | model:
5 | _target_: sam2.modeling.sam2_base.SAM2Base
6 | image_encoder:
7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8 | scalp: 1
9 | trunk:
10 | _target_: sam2.modeling.backbones.hieradet.Hiera
11 | embed_dim: 112
12 | num_heads: 2
13 | neck:
14 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck
15 | position_encoding:
16 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
17 | num_pos_feats: 256
18 | normalize: true
19 | scale: null
20 | temperature: 10000
21 | d_model: 256
22 | backbone_channel_list: [896, 448, 224, 112]
23 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
24 | fpn_interp_model: nearest
25 |
26 | memory_attention:
27 | _target_: sam2.modeling.memory_attention.MemoryAttention
28 | d_model: 256
29 | pos_enc_at_input: true
30 | layer:
31 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
32 | activation: relu
33 | dim_feedforward: 2048
34 | dropout: 0.1
35 | pos_enc_at_attn: false
36 | self_attention:
37 | _target_: sam2.modeling.sam.transformer.RoPEAttention
38 | rope_theta: 10000.0
39 | feat_sizes: [32, 32]
40 | embedding_dim: 256
41 | num_heads: 1
42 | downsample_rate: 1
43 | dropout: 0.1
44 | d_model: 256
45 | pos_enc_at_cross_attn_keys: true
46 | pos_enc_at_cross_attn_queries: false
47 | cross_attention:
48 | _target_: sam2.modeling.sam.transformer.RoPEAttention
49 | rope_theta: 10000.0
50 | feat_sizes: [32, 32]
51 | rope_k_repeat: True
52 | embedding_dim: 256
53 | num_heads: 1
54 | downsample_rate: 1
55 | dropout: 0.1
56 | kv_in_dim: 64
57 | num_layers: 4
58 |
59 | memory_encoder:
60 | _target_: sam2.modeling.memory_encoder.MemoryEncoder
61 | out_dim: 64
62 | position_encoding:
63 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
64 | num_pos_feats: 64
65 | normalize: true
66 | scale: null
67 | temperature: 10000
68 | mask_downsampler:
69 | _target_: sam2.modeling.memory_encoder.MaskDownSampler
70 | kernel_size: 3
71 | stride: 2
72 | padding: 1
73 | fuser:
74 | _target_: sam2.modeling.memory_encoder.Fuser
75 | layer:
76 | _target_: sam2.modeling.memory_encoder.CXBlock
77 | dim: 256
78 | kernel_size: 7
79 | padding: 3
80 | layer_scale_init_value: 1e-6
81 | use_dwconv: True # depth-wise convs
82 | num_layers: 2
83 |
84 | num_maskmem: 7
85 | image_size: 1024
86 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
87 | sigmoid_scale_for_mem_enc: 20.0
88 | sigmoid_bias_for_mem_enc: -10.0
89 | use_mask_input_as_output_without_sam: true
90 | # Memory
91 | directly_add_no_mem_embed: true
92 | # use high-resolution feature map in the SAM mask decoder
93 | use_high_res_features_in_sam: true
94 | # output 3 masks on the first click on initial conditioning frames
95 | multimask_output_in_sam: true
96 | # SAM heads
97 | iou_prediction_use_sigmoid: True
98 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
99 | use_obj_ptrs_in_encoder: true
100 | add_tpos_enc_to_obj_ptrs: false
101 | only_obj_ptrs_in_the_past_for_eval: true
102 | # object occlusion prediction
103 | pred_obj_scores: true
104 | pred_obj_scores_mlp: true
105 | fixed_no_obj_ptr: true
106 | # multimask tracking settings
107 | multimask_output_for_tracking: true
108 | use_multimask_token_for_obj_ptr: true
109 | multimask_min_pt_num: 0
110 | multimask_max_pt_num: 1
111 | use_mlp_for_obj_ptr_proj: true
112 | # Compilation flag
113 | compile_image_encoder: False
114 |
--------------------------------------------------------------------------------
/sam2_configs/sam2_hiera_l.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # Model
4 | model:
5 | _target_: sam2.modeling.sam2_base.SAM2Base
6 | image_encoder:
7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8 | scalp: 1
9 | trunk:
10 | _target_: sam2.modeling.backbones.hieradet.Hiera
11 | embed_dim: 144
12 | num_heads: 2
13 | stages: [2, 6, 36, 4]
14 | global_att_blocks: [23, 33, 43]
15 | window_pos_embed_bkg_spatial_size: [7, 7]
16 | window_spec: [8, 4, 16, 8]
17 | neck:
18 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck
19 | position_encoding:
20 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
21 | num_pos_feats: 256
22 | normalize: true
23 | scale: null
24 | temperature: 10000
25 | d_model: 256
26 | backbone_channel_list: [1152, 576, 288, 144]
27 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
28 | fpn_interp_model: nearest
29 |
30 | memory_attention:
31 | _target_: sam2.modeling.memory_attention.MemoryAttention
32 | d_model: 256
33 | pos_enc_at_input: true
34 | layer:
35 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
36 | activation: relu
37 | dim_feedforward: 2048
38 | dropout: 0.1
39 | pos_enc_at_attn: false
40 | self_attention:
41 | _target_: sam2.modeling.sam.transformer.RoPEAttention
42 | rope_theta: 10000.0
43 | feat_sizes: [32, 32]
44 | embedding_dim: 256
45 | num_heads: 1
46 | downsample_rate: 1
47 | dropout: 0.1
48 | d_model: 256
49 | pos_enc_at_cross_attn_keys: true
50 | pos_enc_at_cross_attn_queries: false
51 | cross_attention:
52 | _target_: sam2.modeling.sam.transformer.RoPEAttention
53 | rope_theta: 10000.0
54 | feat_sizes: [32, 32]
55 | rope_k_repeat: True
56 | embedding_dim: 256
57 | num_heads: 1
58 | downsample_rate: 1
59 | dropout: 0.1
60 | kv_in_dim: 64
61 | num_layers: 4
62 |
63 | memory_encoder:
64 | _target_: sam2.modeling.memory_encoder.MemoryEncoder
65 | out_dim: 64
66 | position_encoding:
67 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
68 | num_pos_feats: 64
69 | normalize: true
70 | scale: null
71 | temperature: 10000
72 | mask_downsampler:
73 | _target_: sam2.modeling.memory_encoder.MaskDownSampler
74 | kernel_size: 3
75 | stride: 2
76 | padding: 1
77 | fuser:
78 | _target_: sam2.modeling.memory_encoder.Fuser
79 | layer:
80 | _target_: sam2.modeling.memory_encoder.CXBlock
81 | dim: 256
82 | kernel_size: 7
83 | padding: 3
84 | layer_scale_init_value: 1e-6
85 | use_dwconv: True # depth-wise convs
86 | num_layers: 2
87 |
88 | num_maskmem: 7
89 | image_size: 1024
90 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
91 | sigmoid_scale_for_mem_enc: 20.0
92 | sigmoid_bias_for_mem_enc: -10.0
93 | use_mask_input_as_output_without_sam: true
94 | # Memory
95 | directly_add_no_mem_embed: true
96 | # use high-resolution feature map in the SAM mask decoder
97 | use_high_res_features_in_sam: true
98 | # output 3 masks on the first click on initial conditioning frames
99 | multimask_output_in_sam: true
100 | # SAM heads
101 | iou_prediction_use_sigmoid: True
102 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
103 | use_obj_ptrs_in_encoder: true
104 | add_tpos_enc_to_obj_ptrs: false
105 | only_obj_ptrs_in_the_past_for_eval: true
106 | # object occlusion prediction
107 | pred_obj_scores: true
108 | pred_obj_scores_mlp: true
109 | fixed_no_obj_ptr: true
110 | # multimask tracking settings
111 | multimask_output_for_tracking: true
112 | use_multimask_token_for_obj_ptr: true
113 | multimask_min_pt_num: 0
114 | multimask_max_pt_num: 1
115 | use_mlp_for_obj_ptr_proj: true
116 | # Compilation flag
117 | compile_image_encoder: False
118 |
--------------------------------------------------------------------------------
/sam2_configs/sam2_hiera_s.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # Model
4 | model:
5 | _target_: sam2.modeling.sam2_base.SAM2Base
6 | image_encoder:
7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8 | scalp: 1
9 | trunk:
10 | _target_: sam2.modeling.backbones.hieradet.Hiera
11 | embed_dim: 96
12 | num_heads: 1
13 | stages: [1, 2, 11, 2]
14 | global_att_blocks: [7, 10, 13]
15 | window_pos_embed_bkg_spatial_size: [7, 7]
16 | neck:
17 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck
18 | position_encoding:
19 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
20 | num_pos_feats: 256
21 | normalize: true
22 | scale: null
23 | temperature: 10000
24 | d_model: 256
25 | backbone_channel_list: [768, 384, 192, 96]
26 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
27 | fpn_interp_model: nearest
28 |
29 | memory_attention:
30 | _target_: sam2.modeling.memory_attention.MemoryAttention
31 | d_model: 256
32 | pos_enc_at_input: true
33 | layer:
34 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
35 | activation: relu
36 | dim_feedforward: 2048
37 | dropout: 0.1
38 | pos_enc_at_attn: false
39 | self_attention:
40 | _target_: sam2.modeling.sam.transformer.RoPEAttention
41 | rope_theta: 10000.0
42 | feat_sizes: [32, 32]
43 | embedding_dim: 256
44 | num_heads: 1
45 | downsample_rate: 1
46 | dropout: 0.1
47 | d_model: 256
48 | pos_enc_at_cross_attn_keys: true
49 | pos_enc_at_cross_attn_queries: false
50 | cross_attention:
51 | _target_: sam2.modeling.sam.transformer.RoPEAttention
52 | rope_theta: 10000.0
53 | feat_sizes: [32, 32]
54 | rope_k_repeat: True
55 | embedding_dim: 256
56 | num_heads: 1
57 | downsample_rate: 1
58 | dropout: 0.1
59 | kv_in_dim: 64
60 | num_layers: 4
61 |
62 | memory_encoder:
63 | _target_: sam2.modeling.memory_encoder.MemoryEncoder
64 | out_dim: 64
65 | position_encoding:
66 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
67 | num_pos_feats: 64
68 | normalize: true
69 | scale: null
70 | temperature: 10000
71 | mask_downsampler:
72 | _target_: sam2.modeling.memory_encoder.MaskDownSampler
73 | kernel_size: 3
74 | stride: 2
75 | padding: 1
76 | fuser:
77 | _target_: sam2.modeling.memory_encoder.Fuser
78 | layer:
79 | _target_: sam2.modeling.memory_encoder.CXBlock
80 | dim: 256
81 | kernel_size: 7
82 | padding: 3
83 | layer_scale_init_value: 1e-6
84 | use_dwconv: True # depth-wise convs
85 | num_layers: 2
86 |
87 | num_maskmem: 7
88 | image_size: 1024
89 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
90 | sigmoid_scale_for_mem_enc: 20.0
91 | sigmoid_bias_for_mem_enc: -10.0
92 | use_mask_input_as_output_without_sam: true
93 | # Memory
94 | directly_add_no_mem_embed: true
95 | # use high-resolution feature map in the SAM mask decoder
96 | use_high_res_features_in_sam: true
97 | # output 3 masks on the first click on initial conditioning frames
98 | multimask_output_in_sam: true
99 | # SAM heads
100 | iou_prediction_use_sigmoid: True
101 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
102 | use_obj_ptrs_in_encoder: true
103 | add_tpos_enc_to_obj_ptrs: false
104 | only_obj_ptrs_in_the_past_for_eval: true
105 | # object occlusion prediction
106 | pred_obj_scores: true
107 | pred_obj_scores_mlp: true
108 | fixed_no_obj_ptr: true
109 | # multimask tracking settings
110 | multimask_output_for_tracking: true
111 | use_multimask_token_for_obj_ptr: true
112 | multimask_min_pt_num: 0
113 | multimask_max_pt_num: 1
114 | use_mlp_for_obj_ptr_proj: true
115 | # Compilation flag
116 | compile_image_encoder: False
117 |
--------------------------------------------------------------------------------
/sam2_configs/sam2_hiera_t.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # Model
4 | model:
5 | _target_: sam2.modeling.sam2_base.SAM2Base
6 | image_encoder:
7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8 | scalp: 1
9 | trunk:
10 | _target_: sam2.modeling.backbones.hieradet.Hiera
11 | embed_dim: 96
12 | num_heads: 1
13 | stages: [1, 2, 7, 2]
14 | global_att_blocks: [5, 7, 9]
15 | window_pos_embed_bkg_spatial_size: [7, 7]
16 | neck:
17 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck
18 | position_encoding:
19 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
20 | num_pos_feats: 256
21 | normalize: true
22 | scale: null
23 | temperature: 10000
24 | d_model: 256
25 | backbone_channel_list: [768, 384, 192, 96]
26 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
27 | fpn_interp_model: nearest
28 |
29 | memory_attention:
30 | _target_: sam2.modeling.memory_attention.MemoryAttention
31 | d_model: 256
32 | pos_enc_at_input: true
33 | layer:
34 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
35 | activation: relu
36 | dim_feedforward: 2048
37 | dropout: 0.1
38 | pos_enc_at_attn: false
39 | self_attention:
40 | _target_: sam2.modeling.sam.transformer.RoPEAttention
41 | rope_theta: 10000.0
42 | feat_sizes: [32, 32]
43 | embedding_dim: 256
44 | num_heads: 1
45 | downsample_rate: 1
46 | dropout: 0.1
47 | d_model: 256
48 | pos_enc_at_cross_attn_keys: true
49 | pos_enc_at_cross_attn_queries: false
50 | cross_attention:
51 | _target_: sam2.modeling.sam.transformer.RoPEAttention
52 | rope_theta: 10000.0
53 | feat_sizes: [32, 32]
54 | rope_k_repeat: True
55 | embedding_dim: 256
56 | num_heads: 1
57 | downsample_rate: 1
58 | dropout: 0.1
59 | kv_in_dim: 64
60 | num_layers: 4
61 |
62 | memory_encoder:
63 | _target_: sam2.modeling.memory_encoder.MemoryEncoder
64 | out_dim: 64
65 | position_encoding:
66 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
67 | num_pos_feats: 64
68 | normalize: true
69 | scale: null
70 | temperature: 10000
71 | mask_downsampler:
72 | _target_: sam2.modeling.memory_encoder.MaskDownSampler
73 | kernel_size: 3
74 | stride: 2
75 | padding: 1
76 | fuser:
77 | _target_: sam2.modeling.memory_encoder.Fuser
78 | layer:
79 | _target_: sam2.modeling.memory_encoder.CXBlock
80 | dim: 256
81 | kernel_size: 7
82 | padding: 3
83 | layer_scale_init_value: 1e-6
84 | use_dwconv: True # depth-wise convs
85 | num_layers: 2
86 |
87 | num_maskmem: 7
88 | image_size: 1024
89 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
90 | # SAM decoder
91 | sigmoid_scale_for_mem_enc: 20.0
92 | sigmoid_bias_for_mem_enc: -10.0
93 | use_mask_input_as_output_without_sam: true
94 | # Memory
95 | directly_add_no_mem_embed: true
96 | # use high-resolution feature map in the SAM mask decoder
97 | use_high_res_features_in_sam: true
98 | # output 3 masks on the first click on initial conditioning frames
99 | multimask_output_in_sam: true
100 | # SAM heads
101 | iou_prediction_use_sigmoid: True
102 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
103 | use_obj_ptrs_in_encoder: true
104 | add_tpos_enc_to_obj_ptrs: false
105 | only_obj_ptrs_in_the_past_for_eval: true
106 | # object occlusion prediction
107 | pred_obj_scores: true
108 | pred_obj_scores_mlp: true
109 | fixed_no_obj_ptr: true
110 | # multimask tracking settings
111 | multimask_output_for_tracking: true
112 | use_multimask_token_for_obj_ptr: true
113 | multimask_min_pt_num: 0
114 | multimask_max_pt_num: 1
115 | use_mlp_for_obj_ptr_proj: true
116 | # Compilation flag
117 | # HieraT does not currently support compilation, should always be set to False
118 | compile_image_encoder: False
119 |
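
The four configs share identical memory-attention, memory-encoder, and runtime settings; they differ mainly in the Hiera trunk (embed_dim, stages, global_att_blocks) and the matching backbone_channel_list. A minimal image-prediction sketch, assuming the vendored sam2 package mirrors the upstream build_sam2 / SAM2ImagePredictor API (the file tree lists sam2/build_sam.py and sam2/sam2_image_predictor.py); the config name and checkpoint path are illustrative:

    import numpy as np
    from sam2.build_sam import build_sam2
    from sam2.sam2_image_predictor import SAM2ImagePredictor

    sam2_model = build_sam2("sam2_hiera_t.yaml", "models/sam2_hiera_tiny.pt", device="cpu")
    predictor = SAM2ImagePredictor(sam2_model)

    image = np.zeros((1024, 1024, 3), dtype=np.uint8)   # HxWxC uint8, matches image_size: 1024
    predictor.set_image(image)
    masks, scores, low_res_masks = predictor.predict(
        point_coords=np.array([[512, 512]]),
        point_labels=np.array([1]),                      # 1 = foreground click
        multimask_output=True,                           # cf. multimask_output_in_sam above
    )
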
--------------------------------------------------------------------------------
/segment_anything_fb/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from .build_sam import (
8 | build_sam,
9 | build_sam_vit_h,
10 | build_sam_vit_l,
11 | build_sam_vit_b,
12 | sam_model_registry,
13 | )
14 | from .predictor import SamPredictor
15 | from .automatic_mask_generator import SamAutomaticMaskGenerator
16 |
17 | __all__ = [
18 | "build_sam",
19 | "build_sam_vit_h",
20 | "build_sam_vit_l",
21 | "build_sam_vit_b",
22 | "sam_model_registry",
23 | "SamPredictor",
24 | "SamAutomaticMaskGenerator",
25 | ]
26 |
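
These exports form the public API of the vendored SAM package. A minimal usage sketch for both prompted and automatic segmentation, assuming a ViT-H checkpoint has already been downloaded (the path is illustrative):

    import numpy as np
    from segment_anything_fb import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator

    sam = sam_model_registry["vit_h"](checkpoint="models/sam_vit_h_4b8939.pth")  # path illustrative

    # Prompted segmentation with a single foreground point.
    predictor = SamPredictor(sam)
    image = np.zeros((768, 1024, 3), dtype=np.uint8)     # HxWxC uint8
    predictor.set_image(image)
    masks, scores, logits = predictor.predict(
        point_coords=np.array([[500, 375]]),
        point_labels=np.array([1]),
        multimask_output=True,
    )

    # Fully automatic mask generation over the whole image.
    mask_generator = SamAutomaticMaskGenerator(sam)
    annotations = mask_generator.generate(image)         # list of dicts: 'segmentation', 'area', ...
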
--------------------------------------------------------------------------------
/segment_anything_fb/build_sam.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 |
9 | from functools import partial
10 |
11 | from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer
12 |
13 |
14 | def build_sam_vit_h(checkpoint=None):
15 | return _build_sam(
16 | encoder_embed_dim=1280,
17 | encoder_depth=32,
18 | encoder_num_heads=16,
19 | encoder_global_attn_indexes=[7, 15, 23, 31],
20 | checkpoint=checkpoint,
21 | )
22 |
23 |
24 | build_sam = build_sam_vit_h
25 |
26 |
27 | def build_sam_vit_l(checkpoint=None):
28 | return _build_sam(
29 | encoder_embed_dim=1024,
30 | encoder_depth=24,
31 | encoder_num_heads=16,
32 | encoder_global_attn_indexes=[5, 11, 17, 23],
33 | checkpoint=checkpoint,
34 | )
35 |
36 |
37 | def build_sam_vit_b(checkpoint=None):
38 | return _build_sam(
39 | encoder_embed_dim=768,
40 | encoder_depth=12,
41 | encoder_num_heads=12,
42 | encoder_global_attn_indexes=[2, 5, 8, 11],
43 | checkpoint=checkpoint,
44 | )
45 |
46 |
47 | sam_model_registry = {
48 | "default": build_sam_vit_h,
49 | "vit_h": build_sam_vit_h,
50 | "vit_l": build_sam_vit_l,
51 | "vit_b": build_sam_vit_b,
52 | }
53 |
54 |
55 | def _build_sam(
56 | encoder_embed_dim,
57 | encoder_depth,
58 | encoder_num_heads,
59 | encoder_global_attn_indexes,
60 | checkpoint=None,
61 | ):
62 | prompt_embed_dim = 256
63 | image_size = 1024
64 | vit_patch_size = 16
65 | image_embedding_size = image_size // vit_patch_size
66 | sam = Sam(
67 | image_encoder=ImageEncoderViT(
68 | depth=encoder_depth,
69 | embed_dim=encoder_embed_dim,
70 | img_size=image_size,
71 | mlp_ratio=4,
72 | norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
73 | num_heads=encoder_num_heads,
74 | patch_size=vit_patch_size,
75 | qkv_bias=True,
76 | use_rel_pos=True,
77 | global_attn_indexes=encoder_global_attn_indexes,
78 | window_size=14,
79 | out_chans=prompt_embed_dim,
80 | ),
81 | prompt_encoder=PromptEncoder(
82 | embed_dim=prompt_embed_dim,
83 | image_embedding_size=(image_embedding_size, image_embedding_size),
84 | input_image_size=(image_size, image_size),
85 | mask_in_chans=16,
86 | ),
87 | mask_decoder=MaskDecoder(
88 | num_multimask_outputs=3,
89 | transformer=TwoWayTransformer(
90 | depth=2,
91 | embedding_dim=prompt_embed_dim,
92 | mlp_dim=2048,
93 | num_heads=8,
94 | ),
95 | transformer_dim=prompt_embed_dim,
96 | iou_head_depth=3,
97 | iou_head_hidden_dim=256,
98 | ),
99 | pixel_mean=[123.675, 116.28, 103.53],
100 | pixel_std=[58.395, 57.12, 57.375],
101 | )
102 | sam.eval()
103 | if checkpoint is not None:
104 | with open(checkpoint, "rb") as f:
105 | state_dict = torch.load(f)
106 | sam.load_state_dict(state_dict)
107 | return sam
108 |
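
Note that _build_sam above calls torch.load(f) without a map_location. If device placement needs to be controlled explicitly, one option is to build the architecture without a checkpoint and load the state dict separately; a small sketch with an illustrative checkpoint path:

    import torch
    from segment_anything_fb.build_sam import build_sam_vit_b

    # Build the model first, then load weights with an explicit map_location.
    sam = build_sam_vit_b(checkpoint=None)
    state_dict = torch.load("models/sam_vit_b_01ec64.pth", map_location="cpu")  # path illustrative
    sam.load_state_dict(state_dict)
    sam.to("cuda" if torch.cuda.is_available() else "cpu")
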
--------------------------------------------------------------------------------
/segment_anything_fb/modeling/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from .sam import Sam
8 | from .image_encoder import ImageEncoderViT
9 | from .mask_decoder import MaskDecoder
10 | from .prompt_encoder import PromptEncoder
11 | from .transformer import TwoWayTransformer
12 |
13 | __all__ = [
14 | "Sam",
15 | "ImageEncoderViT",
16 | "MaskDecoder",
17 | "PromptEncoder",
18 | "TwoWayTransformer",
19 | ]
20 |
--------------------------------------------------------------------------------
/segment_anything_fb/modeling/common.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | import torch.nn as nn
9 |
10 | from typing import Type
11 |
12 |
13 | class MLPBlock(nn.Module):
14 | def __init__(
15 | self,
16 | embedding_dim: int,
17 | mlp_dim: int,
18 | act: Type[nn.Module] = nn.GELU,
19 | ) -> None:
20 | super().__init__()
21 | self.lin1 = nn.Linear(embedding_dim, mlp_dim)
22 | self.lin2 = nn.Linear(mlp_dim, embedding_dim)
23 | self.act = act()
24 |
25 | def forward(self, x: torch.Tensor) -> torch.Tensor:
26 | return self.lin2(self.act(self.lin1(x)))
27 |
28 |
29 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
30 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
31 | class LayerNorm2d(nn.Module):
32 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
33 | super().__init__()
34 | self.weight = nn.Parameter(torch.ones(num_channels))
35 | self.bias = nn.Parameter(torch.zeros(num_channels))
36 | self.eps = eps
37 |
38 | def forward(self, x: torch.Tensor) -> torch.Tensor:
39 | u = x.mean(1, keepdim=True)
40 | s = (x - u).pow(2).mean(1, keepdim=True)
41 | x = (x - u) / torch.sqrt(s + self.eps)
42 | x = self.weight[:, None, None] * x + self.bias[:, None, None]
43 | return x
44 |
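
LayerNorm2d normalizes across the channel dimension of an NCHW feature map (unlike nn.LayerNorm, which normalizes trailing dimensions), which is why it is used on convolutional features in the image encoder. A quick shape check, with illustrative sizes:

    import torch
    from segment_anything_fb.modeling.common import LayerNorm2d, MLPBlock

    x = torch.randn(2, 256, 64, 64)            # NCHW feature map
    ln = LayerNorm2d(num_channels=256)
    assert ln(x).shape == x.shape              # per-position normalization over the 256 channels

    tokens = torch.randn(2, 4096, 256)         # (batch, tokens, embedding_dim)
    mlp = MLPBlock(embedding_dim=256, mlp_dim=2048)
    assert mlp(tokens).shape == tokens.shape   # expand to mlp_dim, project back
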
--------------------------------------------------------------------------------
/segment_anything_fb/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/segment_anything_fb/utils/onnx.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | import torch.nn as nn
9 | from torch.nn import functional as F
10 |
11 | from typing import Tuple
12 |
13 | from ..modeling import Sam
14 | from .amg import calculate_stability_score
15 |
16 |
17 | class SamOnnxModel(nn.Module):
18 | """
19 | This model should not be called directly, but is used in ONNX export.
20 | It combines the prompt encoder, mask decoder, and mask postprocessing of Sam,
21 | with some functions modified to enable model tracing. Also supports extra
22 | options controlling what information is returned. See the ONNX export script for details.
23 | """
24 |
25 | def __init__(
26 | self,
27 | model: Sam,
28 | return_single_mask: bool,
29 | use_stability_score: bool = False,
30 | return_extra_metrics: bool = False,
31 | ) -> None:
32 | super().__init__()
33 | self.mask_decoder = model.mask_decoder
34 | self.model = model
35 | self.img_size = model.image_encoder.img_size
36 | self.return_single_mask = return_single_mask
37 | self.use_stability_score = use_stability_score
38 | self.stability_score_offset = 1.0
39 | self.return_extra_metrics = return_extra_metrics
40 |
41 | @staticmethod
42 | def resize_longest_image_size(
43 | input_image_size: torch.Tensor, longest_side: int
44 | ) -> torch.Tensor:
45 | input_image_size = input_image_size.to(torch.float32)
46 | scale = longest_side / torch.max(input_image_size)
47 | transformed_size = scale * input_image_size
48 | transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64)
49 | return transformed_size
50 |
51 | def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor:
52 | point_coords = point_coords + 0.5
53 | point_coords = point_coords / self.img_size
54 | point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords)
55 | point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding)
56 |
57 | point_embedding = point_embedding * (point_labels != -1)
58 | point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * (
59 | point_labels == -1
60 | )
61 |
62 | for i in range(self.model.prompt_encoder.num_point_embeddings):
63 | point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[
64 | i
65 | ].weight * (point_labels == i)
66 |
67 | return point_embedding
68 |
69 | def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor:
70 | mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask)
71 | mask_embedding = mask_embedding + (
72 | 1 - has_mask_input
73 | ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1)
74 | return mask_embedding
75 |
76 | def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor:
77 | masks = F.interpolate(
78 | masks,
79 | size=(self.img_size, self.img_size),
80 | mode="bilinear",
81 | align_corners=False,
82 | )
83 |
84 | prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(torch.int64)
85 | masks = masks[..., : prepadded_size[0], : prepadded_size[1]] # type: ignore
86 |
87 | orig_im_size = orig_im_size.to(torch.int64)
88 | h, w = orig_im_size[0], orig_im_size[1]
89 | masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False)
90 | return masks
91 |
92 | def select_masks(
93 | self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int
94 | ) -> Tuple[torch.Tensor, torch.Tensor]:
95 | # Determine if we should return the multiclick mask or not from the number of points.
96 | # The reweighting is used to avoid control flow.
97 | score_reweight = torch.tensor(
98 | [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)]
99 | ).to(iou_preds.device)
100 | score = iou_preds + (num_points - 2.5) * score_reweight
101 | best_idx = torch.argmax(score, dim=1)
102 | masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1)
103 | iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1)
104 |
105 | return masks, iou_preds
106 |
107 | @torch.no_grad()
108 | def forward(
109 | self,
110 | image_embeddings: torch.Tensor,
111 | point_coords: torch.Tensor,
112 | point_labels: torch.Tensor,
113 | mask_input: torch.Tensor,
114 | has_mask_input: torch.Tensor,
115 | orig_im_size: torch.Tensor,
116 | ):
117 | sparse_embedding = self._embed_points(point_coords, point_labels)
118 | dense_embedding = self._embed_masks(mask_input, has_mask_input)
119 |
120 | masks, scores = self.model.mask_decoder.predict_masks(
121 | image_embeddings=image_embeddings,
122 | image_pe=self.model.prompt_encoder.get_dense_pe(),
123 | sparse_prompt_embeddings=sparse_embedding,
124 | dense_prompt_embeddings=dense_embedding,
125 | )
126 |
127 | if self.use_stability_score:
128 | scores = calculate_stability_score(
129 | masks, self.model.mask_threshold, self.stability_score_offset
130 | )
131 |
132 | if self.return_single_mask:
133 | masks, scores = self.select_masks(masks, scores, point_coords.shape[1])
134 |
135 | upscaled_masks = self.mask_postprocessing(masks, orig_im_size)
136 |
137 | if self.return_extra_metrics:
138 | stability_scores = calculate_stability_score(
139 | upscaled_masks, self.model.mask_threshold, self.stability_score_offset
140 | )
141 | areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1)
142 | return upscaled_masks, scores, stability_scores, areas, masks
143 |
144 | return upscaled_masks, scores, masks
145 |
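
SamOnnxModel wraps only the prompt encoder, mask decoder, and postprocessing; the image encoder runs separately and its embeddings are passed in. A hedged export sketch modelled on the upstream SAM export script (not a script shipped with this repository); dummy shapes follow the 1024 image size and 256-dim prompt embeddings used in build_sam.py, and paths are illustrative:

    import torch
    from segment_anything_fb import sam_model_registry
    from segment_anything_fb.utils.onnx import SamOnnxModel

    sam = sam_model_registry["vit_b"](checkpoint="models/sam_vit_b_01ec64.pth")  # path illustrative
    onnx_model = SamOnnxModel(sam, return_single_mask=True)

    embed_dim = sam.prompt_encoder.embed_dim
    embed_size = sam.prompt_encoder.image_embedding_size          # (64, 64) for 1024/16
    dummy_inputs = {
        "image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float),
        "point_coords": torch.randint(0, 1024, (1, 5, 2), dtype=torch.float),
        "point_labels": torch.randint(0, 4, (1, 5), dtype=torch.float),
        "mask_input": torch.randn(1, 1, 4 * embed_size[0], 4 * embed_size[1], dtype=torch.float),
        "has_mask_input": torch.tensor([1], dtype=torch.float),
        "orig_im_size": torch.tensor([1500, 2250], dtype=torch.float),
    }
    torch.onnx.export(
        onnx_model,
        tuple(dummy_inputs.values()),
        "sam_vit_b_decoder.onnx",
        input_names=list(dummy_inputs.keys()),
        output_names=["masks", "iou_predictions", "low_res_masks"],
        dynamic_axes={"point_coords": {1: "num_points"}, "point_labels": {1: "num_points"}},
        opset_version=17,
    )
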
--------------------------------------------------------------------------------
/segment_anything_fb/utils/torch_nms.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchvision.ops.boxes import box_iou
3 |
4 |
5 | def nms(bboxes: torch.Tensor, scores: torch.Tensor, iou_threshold: float) -> torch.Tensor:
6 | order = torch.argsort(-scores)
7 | keep = []
8 |
9 | while order.numel() > 0:
10 | i = order[0]
11 | keep.append(i.item())
12 |
13 | if order.numel() == 1:
14 | break
15 |
16 | ious = box_iou(bboxes[i].unsqueeze(0), bboxes[order[1:]])[0]
17 | mask = ious <= iou_threshold
18 | order = order[1:][mask]
19 |
20 | return torch.tensor(keep, device=bboxes.device)
21 |
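
This is a plain greedy non-maximum suppression over xyxy boxes built on torchvision's box_iou: keep the highest-scoring box, drop any remaining box whose IoU with it exceeds the threshold, repeat. A quick check with two heavily overlapping boxes and one disjoint box:

    import torch
    from segment_anything_fb.utils.torch_nms import nms

    boxes = torch.tensor([
        [0.0, 0.0, 10.0, 10.0],    # kept: highest score
        [1.0, 1.0, 10.0, 10.0],    # suppressed: IoU with the first box is 0.81
        [20.0, 20.0, 30.0, 30.0],  # kept: no overlap
    ])
    scores = torch.tensor([0.9, 0.8, 0.7])
    keep = nms(boxes, scores, iou_threshold=0.5)
    print(keep)   # tensor([0, 2])
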
--------------------------------------------------------------------------------
/segment_anything_fb/utils/transforms.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import numpy as np
8 | import torch
9 | from torch.nn import functional as F
10 | from torchvision.transforms.functional import resize, to_pil_image # type: ignore
11 |
12 | from copy import deepcopy
13 | from typing import Tuple
14 |
15 |
16 | class ResizeLongestSide:
17 | """
18 | Resizes images so that the longest side matches 'target_length', and provides
19 | methods for resizing coordinates and boxes accordingly. Supports both numpy
20 | arrays and batched torch tensors.
21 | """
22 |
23 | def __init__(self, target_length: int) -> None:
24 | self.target_length = target_length
25 |
26 | def apply_image(self, image: np.ndarray) -> np.ndarray:
27 | """
28 | Expects a numpy array with shape HxWxC in uint8 format.
29 | """
30 | target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length)
31 | return np.array(resize(to_pil_image(image), target_size))
32 |
33 | def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
34 | """
35 | Expects a numpy array of length 2 in the final dimension. Requires the
36 | original image size in (H, W) format.
37 | """
38 | old_h, old_w = original_size
39 | new_h, new_w = self.get_preprocess_shape(
40 | original_size[0], original_size[1], self.target_length
41 | )
42 | coords = deepcopy(coords).astype(float)
43 | coords[..., 0] = coords[..., 0] * (new_w / old_w)
44 | coords[..., 1] = coords[..., 1] * (new_h / old_h)
45 | return coords
46 |
47 | def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
48 | """
49 | Expects a numpy array with shape Bx4. Requires the original image size
50 | in (H, W) format.
51 | """
52 | boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size)
53 | return boxes.reshape(-1, 4)
54 |
55 | def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor:
56 | """
57 | Expects batched images with shape BxCxHxW and float format. This
58 | transformation may not exactly match apply_image. apply_image is
59 | the transformation expected by the model.
60 | """
61 | # Expects an image in BCHW format. May not exactly match apply_image.
62 | target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length)
63 | return F.interpolate(
64 | image, target_size, mode="bilinear", align_corners=False, antialias=True
65 | )
66 |
67 | def apply_coords_torch(
68 | self, coords: torch.Tensor, original_size: Tuple[int, ...]
69 | ) -> torch.Tensor:
70 | """
71 | Expects a torch tensor with length 2 in the last dimension. Requires the
72 | original image size in (H, W) format.
73 | """
74 | old_h, old_w = original_size
75 | new_h, new_w = self.get_preprocess_shape(
76 | original_size[0], original_size[1], self.target_length
77 | )
78 | coords = deepcopy(coords).to(torch.float)
79 | coords[..., 0] = coords[..., 0] * (new_w / old_w)
80 | coords[..., 1] = coords[..., 1] * (new_h / old_h)
81 | return coords
82 |
83 | def apply_boxes_torch(
84 | self, boxes: torch.Tensor, original_size: Tuple[int, ...]
85 | ) -> torch.Tensor:
86 | """
87 | Expects a torch tensor with shape Bx4. Requires the original image
88 | size in (H, W) format.
89 | """
90 | boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size)
91 | return boxes.reshape(-1, 4)
92 |
93 | @staticmethod
94 | def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
95 | """
96 | Compute the output size given input size and target long side length.
97 | """
98 | scale = long_side_length * 1.0 / max(oldh, oldw)
99 | newh, neww = oldh * scale, oldw * scale
100 | neww = int(neww + 0.5)
101 | newh = int(newh + 0.5)
102 | return (newh, neww)
103 |
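
ResizeLongestSide scales the longest image side to target_length and rescales prompt coordinates by the same factor, which is how images and point/box prompts stay aligned during preprocessing. A small numeric example (image and coordinate values illustrative):

    import numpy as np
    from segment_anything_fb.utils.transforms import ResizeLongestSide

    transform = ResizeLongestSide(target_length=1024)

    # A 768x1536 (HxW) image scales by 1024/1536 to 512x1024.
    print(transform.get_preprocess_shape(768, 1536, 1024))   # (512, 1024)

    image = np.zeros((768, 1536, 3), dtype=np.uint8)
    print(transform.apply_image(image).shape)                 # (512, 1024, 3)

    coords = np.array([[1536.0, 768.0]])                      # (x, y) in the original image
    print(transform.apply_coords(coords, original_size=(768, 1536)))   # [[1024. 512.]]
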
--------------------------------------------------------------------------------
/segment_anything_hq/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from .build_sam import (
8 | build_sam,
9 | build_sam_vit_h,
10 | build_sam_vit_l,
11 | build_sam_vit_b,
12 | sam_model_registry,
13 | )
14 | from .build_sam_baseline import sam_model_registry_baseline
15 | from .predictor import SamPredictor
16 | from .automatic_mask_generator import SamAutomaticMaskGenerator
17 |
18 | __all__ = [
19 | "build_sam",
20 | "build_sam_vit_h",
21 | "build_sam_vit_l",
22 | "build_sam_vit_b",
23 | "sam_model_registry",
24 | "sam_model_registry_baseline",
25 | "SamPredictor",
26 | "SamAutomaticMaskGenerator",
27 | ]
28 |
--------------------------------------------------------------------------------
/segment_anything_hq/build_sam.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 |
9 | from functools import partial
10 |
11 | from .modeling import ImageEncoderViT, MaskDecoderHQ, PromptEncoder, Sam, TwoWayTransformer
12 | import platform
13 |
14 |
15 | def build_sam_vit_h(checkpoint=None):
16 | return _build_sam(
17 | encoder_embed_dim=1280,
18 | encoder_depth=32,
19 | encoder_num_heads=16,
20 | encoder_global_attn_indexes=[7, 15, 23, 31],
21 | checkpoint=checkpoint,
22 | )
23 |
24 |
25 | build_sam = build_sam_vit_h
26 |
27 |
28 | def build_sam_vit_l(checkpoint=None):
29 | return _build_sam(
30 | encoder_embed_dim=1024,
31 | encoder_depth=24,
32 | encoder_num_heads=16,
33 | encoder_global_attn_indexes=[5, 11, 17, 23],
34 | checkpoint=checkpoint,
35 | )
36 |
37 |
38 | def build_sam_vit_b(checkpoint=None):
39 | return _build_sam(
40 | encoder_embed_dim=768,
41 | encoder_depth=12,
42 | encoder_num_heads=12,
43 | encoder_global_attn_indexes=[2, 5, 8, 11],
44 | checkpoint=checkpoint,
45 | )
46 |
47 |
48 | sam_model_registry = {
49 | "default": build_sam_vit_h,
50 | "vit_h": build_sam_vit_h,
51 | "vit_l": build_sam_vit_l,
52 | "vit_b": build_sam_vit_b,
53 | }
54 |
55 |
56 | def _build_sam(
57 | encoder_embed_dim,
58 | encoder_depth,
59 | encoder_num_heads,
60 | encoder_global_attn_indexes,
61 | checkpoint=None,
62 | ):
63 | prompt_embed_dim = 256
64 | image_size = 1024
65 | vit_patch_size = 16
66 | image_embedding_size = image_size // vit_patch_size
67 | sam = Sam(
68 | image_encoder=ImageEncoderViT(
69 | depth=encoder_depth,
70 | embed_dim=encoder_embed_dim,
71 | img_size=image_size,
72 | mlp_ratio=4,
73 | norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
74 | num_heads=encoder_num_heads,
75 | patch_size=vit_patch_size,
76 | qkv_bias=True,
77 | use_rel_pos=True,
78 | global_attn_indexes=encoder_global_attn_indexes,
79 | window_size=14,
80 | out_chans=prompt_embed_dim,
81 | ),
82 | prompt_encoder=PromptEncoder(
83 | embed_dim=prompt_embed_dim,
84 | image_embedding_size=(image_embedding_size, image_embedding_size),
85 | input_image_size=(image_size, image_size),
86 | mask_in_chans=16,
87 | ),
88 | mask_decoder=MaskDecoderHQ(
89 | num_multimask_outputs=3,
90 | transformer=TwoWayTransformer(
91 | depth=2,
92 | embedding_dim=prompt_embed_dim,
93 | mlp_dim=2048,
94 | num_heads=8,
95 | ),
96 | transformer_dim=prompt_embed_dim,
97 | iou_head_depth=3,
98 | iou_head_hidden_dim=256,
99 | vit_dim=encoder_embed_dim,
100 | ),
101 | pixel_mean=[123.675, 116.28, 103.53],
102 | pixel_std=[58.395, 57.12, 57.375],
103 | )
104 | sam.eval()
105 | if checkpoint is not None:
106 | with open(checkpoint, "rb") as f:
107 | if platform.system() == "Darwin":
108 | if torch.backends.mps.is_available() and torch.backends.mps.is_built():
109 | state_dict = torch.load(f, map_location=torch.device("mps"))
110 | else:
111 | state_dict = torch.load(f, map_location=torch.device("cpu"))
112 | else:
113 | if torch.cuda.is_available():
114 | state_dict = torch.load(f)
115 | else:
116 | state_dict = torch.load(f, map_location=torch.device("cpu"))
117 | # info = sam.load_state_dict(state_dict, strict=False)
118 | # print(info)
119 | sam.load_state_dict(state_dict, strict=False)
120 | for n, p in sam.named_parameters():
121 | if 'hf_token' not in n and 'hf_mlp' not in n and 'compress_vit_feat' not in n and 'embedding_encoder' not in n and 'embedding_maskfeature' not in n:
122 | p.requires_grad = False
123 |
124 | return sam
125 |
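
Compared with the segment_anything_fb builder, this one chooses a map_location for macOS/CPU machines, loads the checkpoint with strict=False, and then freezes every parameter except the HQ-specific modules (hf_token, hf_mlp, compress_vit_feat, embedding_encoder, embedding_maskfeature). A quick check of what stays trainable; the checkpoint path is illustrative:

    from segment_anything_hq import sam_model_registry

    sam = sam_model_registry["vit_b"](checkpoint="models/sam_hq_vit_b.pth")  # path illustrative

    trainable = [n for n, p in sam.named_parameters() if p.requires_grad]
    print(len(trainable), trainable[:3])
    # expected: only parameters whose names contain hf_token, hf_mlp, compress_vit_feat,
    # embedding_encoder, or embedding_maskfeature remain trainable
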
--------------------------------------------------------------------------------
/segment_anything_hq/build_sam_baseline.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 |
9 | from functools import partial
10 |
11 | from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer
12 |
13 |
14 | def build_sam_vit_h(checkpoint=None):
15 | return _build_sam(
16 | encoder_embed_dim=1280,
17 | encoder_depth=32,
18 | encoder_num_heads=16,
19 | encoder_global_attn_indexes=[7, 15, 23, 31],
20 | checkpoint=checkpoint,
21 | )
22 |
23 |
24 | build_sam = build_sam_vit_h
25 |
26 |
27 | def build_sam_vit_l(checkpoint=None):
28 | return _build_sam(
29 | encoder_embed_dim=1024,
30 | encoder_depth=24,
31 | encoder_num_heads=16,
32 | encoder_global_attn_indexes=[5, 11, 17, 23],
33 | checkpoint=checkpoint,
34 | )
35 |
36 |
37 | def build_sam_vit_b(checkpoint=None):
38 | return _build_sam(
39 | encoder_embed_dim=768,
40 | encoder_depth=12,
41 | encoder_num_heads=12,
42 | encoder_global_attn_indexes=[2, 5, 8, 11],
43 | checkpoint=checkpoint,
44 | )
45 |
46 |
47 | sam_model_registry_baseline = {
48 | "default": build_sam_vit_h,
49 | "vit_h": build_sam_vit_h,
50 | "vit_l": build_sam_vit_l,
51 | "vit_b": build_sam_vit_b,
52 | }
53 |
54 |
55 | def _build_sam(
56 | encoder_embed_dim,
57 | encoder_depth,
58 | encoder_num_heads,
59 | encoder_global_attn_indexes,
60 | checkpoint=None,
61 | ):
62 | prompt_embed_dim = 256
63 | image_size = 1024
64 | vit_patch_size = 16
65 | image_embedding_size = image_size // vit_patch_size
66 | sam = Sam(
67 | image_encoder=ImageEncoderViT(
68 | depth=encoder_depth,
69 | embed_dim=encoder_embed_dim,
70 | img_size=image_size,
71 | mlp_ratio=4,
72 | norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
73 | num_heads=encoder_num_heads,
74 | patch_size=vit_patch_size,
75 | qkv_bias=True,
76 | use_rel_pos=True,
77 | global_attn_indexes=encoder_global_attn_indexes,
78 | window_size=14,
79 | out_chans=prompt_embed_dim,
80 | ),
81 | prompt_encoder=PromptEncoder(
82 | embed_dim=prompt_embed_dim,
83 | image_embedding_size=(image_embedding_size, image_embedding_size),
84 | input_image_size=(image_size, image_size),
85 | mask_in_chans=16,
86 | ),
87 | mask_decoder=MaskDecoder(
88 | num_multimask_outputs=3,
89 | transformer=TwoWayTransformer(
90 | depth=2,
91 | embedding_dim=prompt_embed_dim,
92 | mlp_dim=2048,
93 | num_heads=8,
94 | ),
95 | transformer_dim=prompt_embed_dim,
96 | iou_head_depth=3,
97 | iou_head_hidden_dim=256,
98 | ),
99 | pixel_mean=[123.675, 116.28, 103.53],
100 | pixel_std=[58.395, 57.12, 57.375],
101 | )
102 | sam.eval()
103 | if checkpoint is not None:
104 | with open(checkpoint, "rb") as f:
105 | state_dict = torch.load(f)
106 | sam.load_state_dict(state_dict)
107 | return sam
108 |
--------------------------------------------------------------------------------
/segment_anything_hq/modeling/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from .sam import Sam
8 | from .image_encoder import ImageEncoderViT
9 | from .mask_decoder_hq import MaskDecoderHQ
10 | from .mask_decoder import MaskDecoder
11 | from .prompt_encoder import PromptEncoder
12 | from .transformer import TwoWayTransformer
13 |
14 | __all__ = [
15 | "Sam",
16 | "ImageEncoderViT",
17 | "MaskDecoderHQ",
18 | "MaskDecoder",
19 | "PromptEncoder",
20 | "TwoWayTransformer",
21 | ]
22 |
--------------------------------------------------------------------------------
/segment_anything_hq/modeling/common.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | import torch.nn as nn
9 |
10 | from typing import Type
11 |
12 |
13 | class MLPBlock(nn.Module):
14 | def __init__(
15 | self,
16 | embedding_dim: int,
17 | mlp_dim: int,
18 | act: Type[nn.Module] = nn.GELU,
19 | ) -> None:
20 | super().__init__()
21 | self.lin1 = nn.Linear(embedding_dim, mlp_dim)
22 | self.lin2 = nn.Linear(mlp_dim, embedding_dim)
23 | self.act = act()
24 |
25 | def forward(self, x: torch.Tensor) -> torch.Tensor:
26 | return self.lin2(self.act(self.lin1(x)))
27 |
28 |
29 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
30 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
31 | class LayerNorm2d(nn.Module):
32 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
33 | super().__init__()
34 | self.weight = nn.Parameter(torch.ones(num_channels))
35 | self.bias = nn.Parameter(torch.zeros(num_channels))
36 | self.eps = eps
37 |
38 | def forward(self, x: torch.Tensor) -> torch.Tensor:
39 | u = x.mean(1, keepdim=True)
40 | s = (x - u).pow(2).mean(1, keepdim=True)
41 | x = (x - u) / torch.sqrt(s + self.eps)
42 | x = self.weight[:, None, None] * x + self.bias[:, None, None]
43 | return x
44 |
--------------------------------------------------------------------------------
/segment_anything_hq/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/segment_anything_hq/utils/torch_nms.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchvision.ops.boxes import box_iou
3 |
4 |
5 | def nms(bboxes: torch.Tensor, scores: torch.Tensor, iou_threshold: float) -> torch.Tensor:
6 | order = torch.argsort(-scores)
7 | keep = []
8 |
9 | while order.numel() > 0:
10 | i = order[0]
11 | keep.append(i.item())
12 |
13 | if order.numel() == 1:
14 | break
15 |
16 | ious = box_iou(bboxes[i].unsqueeze(0), bboxes[order[1:]])[0]
17 | mask = ious <= iou_threshold
18 | order = order[1:][mask]
19 |
20 | return torch.tensor(keep, device=bboxes.device)
21 |
--------------------------------------------------------------------------------
/segment_anything_hq/utils/transforms.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import numpy as np
8 | import torch
9 | from torch.nn import functional as F
10 | from torchvision.transforms.functional import resize, to_pil_image # type: ignore
11 |
12 | from copy import deepcopy
13 | from typing import Tuple
14 |
15 |
16 | class ResizeLongestSide:
17 | """
18 | Resizes images so that the longest side matches 'target_length', and provides
19 | methods for resizing coordinates and boxes accordingly. Supports both numpy
20 | arrays and batched torch tensors.
21 | """
22 |
23 | def __init__(self, target_length: int) -> None:
24 | self.target_length = target_length
25 |
26 | def apply_image(self, image: np.ndarray) -> np.ndarray:
27 | """
28 | Expects a numpy array with shape HxWxC in uint8 format.
29 | """
30 | target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length)
31 | return np.array(resize(to_pil_image(image), target_size))
32 |
33 | def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
34 | """
35 | Expects a numpy array of length 2 in the final dimension. Requires the
36 | original image size in (H, W) format.
37 | """
38 | old_h, old_w = original_size
39 | new_h, new_w = self.get_preprocess_shape(
40 | original_size[0], original_size[1], self.target_length
41 | )
42 | coords = deepcopy(coords).astype(float)
43 | coords[..., 0] = coords[..., 0] * (new_w / old_w)
44 | coords[..., 1] = coords[..., 1] * (new_h / old_h)
45 | return coords
46 |
47 | def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
48 | """
49 | Expects a numpy array with shape Bx4. Requires the original image size
50 | in (H, W) format.
51 | """
52 | boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size)
53 | return boxes.reshape(-1, 4)
54 |
55 | def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor:
56 | """
57 | Expects batched images with shape BxCxHxW and float format. This
58 | transformation may not exactly match apply_image. apply_image is
59 | the transformation expected by the model.
60 | """
61 | # Expects an image in BCHW format. May not exactly match apply_image.
62 | target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length)
63 | return F.interpolate(
64 | image, target_size, mode="bilinear", align_corners=False, antialias=True
65 | )
66 |
67 | def apply_coords_torch(
68 | self, coords: torch.Tensor, original_size: Tuple[int, ...]
69 | ) -> torch.Tensor:
70 | """
71 | Expects a torch tensor with length 2 in the last dimension. Requires the
72 | original image size in (H, W) format.
73 | """
74 | old_h, old_w = original_size
75 | new_h, new_w = self.get_preprocess_shape(
76 | original_size[0], original_size[1], self.target_length
77 | )
78 | coords = deepcopy(coords).to(torch.float)
79 | coords[..., 0] = coords[..., 0] * (new_w / old_w)
80 | coords[..., 1] = coords[..., 1] * (new_h / old_h)
81 | return coords
82 |
83 | def apply_boxes_torch(
84 | self, boxes: torch.Tensor, original_size: Tuple[int, ...]
85 | ) -> torch.Tensor:
86 | """
87 | Expects a torch tensor with shape Bx4. Requires the original image
88 | size in (H, W) format.
89 | """
90 | boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size)
91 | return boxes.reshape(-1, 4)
92 |
93 | @staticmethod
94 | def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
95 | """
96 | Compute the output size given input size and target long side length.
97 | """
98 | scale = long_side_length * 1.0 / max(oldh, oldw)
99 | newh, neww = oldh * scale, oldw * scale
100 | neww = int(neww + 0.5)
101 | newh = int(newh + 0.5)
102 | return (newh, neww)
103 |
--------------------------------------------------------------------------------