├── .gitignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── Member.txt ├── MobileSAMv2 ├── Inference.py ├── PromptGuidedDecoder │ └── Prompt_guided_Mask_Decoder.pt ├── efficientvit │ ├── __init__.py │ ├── apps │ │ ├── __init__.py │ │ ├── data_provider │ │ │ ├── __init__.py │ │ │ ├── augment │ │ │ │ ├── __init__.py │ │ │ │ ├── bbox.py │ │ │ │ └── color_aug.py │ │ │ ├── base.py │ │ │ └── random_resolution │ │ │ │ ├── __init__.py │ │ │ │ ├── _data_loader.py │ │ │ │ ├── _data_worker.py │ │ │ │ └── controller.py │ │ ├── setup.py │ │ ├── trainer │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ └── run_config.py │ │ └── utils │ │ │ ├── __init__.py │ │ │ ├── dist.py │ │ │ ├── ema.py │ │ │ ├── export.py │ │ │ ├── init.py │ │ │ ├── lr.py │ │ │ ├── metric.py │ │ │ ├── misc.py │ │ │ └── opt.py │ ├── cls_model_zoo.py │ ├── clscore │ │ ├── __init__.py │ │ ├── data_provider │ │ │ ├── __init__.py │ │ │ └── imagenet.py │ │ └── trainer │ │ │ ├── __init__.py │ │ │ ├── cls_run_config.py │ │ │ ├── cls_trainer.py │ │ │ └── utils │ │ │ ├── __init__.py │ │ │ ├── label_smooth.py │ │ │ ├── metric.py │ │ │ └── mixup.py │ ├── models │ │ ├── __init__.py │ │ ├── efficientvit │ │ │ ├── __init__.py │ │ │ ├── backbone.py │ │ │ ├── cls.py │ │ │ ├── sam.py │ │ │ └── seg.py │ │ ├── nn │ │ │ ├── __init__.py │ │ │ ├── act.py │ │ │ ├── drop.py │ │ │ ├── norm.py │ │ │ └── ops.py │ │ └── utils │ │ │ ├── __init__.py │ │ │ ├── list.py │ │ │ ├── network.py │ │ │ └── random.py │ ├── sam_model_zoo.py │ └── seg_model_zoo.py ├── experiments │ └── mobilesamv2.sh ├── mobilesamv2 │ ├── __init__.py │ ├── automatic_mask_generator.py │ ├── build_sam.py │ ├── modeling │ │ ├── __init__.py │ │ ├── common.py │ │ ├── image_encoder.py │ │ ├── mask_decoder.py │ │ ├── prompt_encoder.py │ │ ├── sam.py │ │ └── transformer.py │ ├── predictor.py │ ├── promt_mobilesamv2 │ │ ├── __init__.py │ │ ├── model.py │ │ └── predict.py │ └── utils │ │ ├── __init__.py │ │ ├── amg.py │ │ ├── onnx.py │ │ └── transforms.py ├── test_images │ ├── 1.jpg │ └── 2.jpg ├── tinyvit │ └── tiny_vit.py └── ultralytics │ ├── __init__.py │ ├── assets │ ├── bus.jpg │ └── zidane.jpg │ ├── hub │ ├── __init__.py │ ├── auth.py │ ├── session.py │ └── utils.py │ ├── models │ ├── README.md │ ├── rt-detr │ │ ├── rtdetr-l.yaml │ │ └── rtdetr-x.yaml │ ├── v3 │ │ ├── yolov3-spp.yaml │ │ ├── yolov3-tiny.yaml │ │ └── yolov3.yaml │ ├── v5 │ │ ├── yolov5-p6.yaml │ │ └── yolov5.yaml │ ├── v6 │ │ └── yolov6.yaml │ └── v8 │ │ ├── yolov8-cls.yaml │ │ ├── yolov8-p2.yaml │ │ ├── yolov8-p6.yaml │ │ ├── yolov8-pose-p6.yaml │ │ ├── yolov8-pose.yaml │ │ ├── yolov8-rtdetr.yaml │ │ ├── yolov8-seg.yaml │ │ └── yolov8.yaml │ ├── nn │ ├── __init__.py │ ├── autobackend.py │ ├── autoshape.py │ ├── modules │ │ ├── __init__.py │ │ ├── block.py │ │ ├── conv.py │ │ ├── head.py │ │ ├── transformer.py │ │ └── utils.py │ └── tasks.py │ ├── tracker │ ├── README.md │ ├── __init__.py │ ├── cfg │ │ ├── botsort.yaml │ │ └── bytetrack.yaml │ ├── track.py │ ├── trackers │ │ ├── __init__.py │ │ ├── basetrack.py │ │ ├── bot_sort.py │ │ └── byte_tracker.py │ └── utils │ │ ├── __init__.py │ │ ├── gmc.py │ │ ├── kalman_filter.py │ │ └── matching.py │ ├── vit │ ├── __init__.py │ ├── rtdetr │ │ ├── __init__.py │ │ ├── model.py │ │ ├── predict.py │ │ ├── train.py │ │ └── val.py │ ├── sam │ │ ├── __init__.py │ │ ├── amg.py │ │ ├── autosize.py │ │ ├── build.py │ │ ├── model.py │ │ ├── modules │ │ │ ├── __init__.py │ │ │ ├── decoders.py │ │ │ ├── encoders.py │ │ │ ├── mask_generator.py │ │ │ ├── prompt_predictor.py │ │ │ ├── 
sam.py │ │ │ └── transformer.py │ │ └── predict.py │ └── utils │ │ ├── __init__.py │ │ ├── loss.py │ │ └── ops.py │ └── yolo │ ├── __init__.py │ ├── cfg │ ├── __init__.py │ └── default.yaml │ ├── data │ ├── __init__.py │ ├── annotator.py │ ├── augment.py │ ├── base.py │ ├── build.py │ ├── converter.py │ ├── dataloaders │ │ ├── __init__.py │ │ ├── stream_loaders.py │ │ ├── v5augmentations.py │ │ └── v5loader.py │ ├── dataset.py │ ├── dataset_wrappers.py │ ├── scripts │ │ ├── download_weights.sh │ │ ├── get_coco.sh │ │ ├── get_coco128.sh │ │ └── get_imagenet.sh │ └── utils.py │ ├── engine │ ├── __init__.py │ ├── exporter.py │ ├── model.py │ ├── predictor.py │ ├── results.py │ ├── trainer.py │ └── validator.py │ ├── nas │ ├── __init__.py │ ├── model.py │ ├── predict.py │ └── val.py │ ├── utils │ ├── __init__.py │ ├── autobatch.py │ ├── benchmarks.py │ ├── callbacks │ │ ├── __init__.py │ │ ├── base.py │ │ ├── clearml.py │ │ ├── comet.py │ │ ├── dvc.py │ │ ├── hub.py │ │ ├── mlflow.py │ │ ├── neptune.py │ │ ├── raytune.py │ │ ├── tensorboard.py │ │ └── wb.py │ ├── checks.py │ ├── dist.py │ ├── downloads.py │ ├── errors.py │ ├── files.py │ ├── instance.py │ ├── loss.py │ ├── metrics.py │ ├── ops.py │ ├── patches.py │ ├── plotting.py │ ├── tal.py │ ├── torch_utils.py │ └── tuner.py │ └── v8 │ ├── __init__.py │ ├── classify │ ├── __init__.py │ ├── predict.py │ ├── train.py │ └── val.py │ ├── detect │ ├── __init__.py │ ├── predict.py │ ├── train.py │ └── val.py │ ├── pose │ ├── __init__.py │ ├── predict.py │ ├── train.py │ └── val.py │ └── segment │ ├── __init__.py │ ├── predict.py │ ├── train.py │ └── val.py ├── README.md ├── app ├── .gitattributes ├── README.md ├── app.py ├── assets │ ├── .DS_Store │ ├── picture1.jpg │ ├── picture2.jpg │ ├── picture3.jpg │ ├── picture4.jpg │ ├── picture5.jpg │ └── picture6.jpg ├── requirements.txt └── utils │ ├── __init__.py │ ├── tools.py │ └── tools_gradio.py ├── assets ├── logo2.png ├── mask_box.jpg ├── mask_comparision.jpg ├── mask_point.jpg ├── model_diagram.jpg ├── notebook1.png └── notebook2.png ├── linter.sh ├── mobile_sam ├── __init__.py ├── automatic_mask_generator.py ├── build_sam.py ├── modeling │ ├── __init__.py │ ├── common.py │ ├── image_encoder.py │ ├── mask_decoder.py │ ├── prompt_encoder.py │ ├── sam.py │ ├── tiny_vit_sam.py │ └── transformer.py ├── predictor.py └── utils │ ├── __init__.py │ ├── amg.py │ ├── onnx.py │ └── transforms.py ├── notebooks ├── automatic_mask_generator_example.ipynb ├── images │ ├── picture1.jpg │ └── picture2.jpg ├── onnx_model_example.ipynb └── predictor_example.ipynb ├── scripts ├── amg.py └── export_onnx_model.py ├── setup.cfg ├── setup.py └── weights └── mobile_sam.pt /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.pyo 3 | *.pyd 4 | __py 5 | **/__pycache__/ 6 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 
11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 
71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq 81 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to segment-anything 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `main`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints, using the `linter.sh` script in the project's root directory. Linting requires `black==23.*`, `isort==5.12.0`, `flake8`, and `mypy`. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Facebook's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | ## License 30 | By contributing to segment-anything, you agree that your contributions will be licensed 31 | under the LICENSE file in the root directory of this source tree. 
32 | -------------------------------------------------------------------------------- /Member.txt: -------------------------------------------------------------------------------- 1 | Manager: ChaoningZhang, ChoonSeon 2 | Contributors: Dongshenhan, Qiaoyu1002, dhkim2810, ksugar, killian31, yaimwing 3 | -------------------------------------------------------------------------------- /MobileSAMv2/PromptGuidedDecoder/Prompt_guided_Mask_Decoder.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChaoningZhang/MobileSAM/34bbbfdface3c18e5221aa7de6032d7220c6c6a1/MobileSAMv2/PromptGuidedDecoder/Prompt_guided_Mask_Decoder.pt -------------------------------------------------------------------------------- /MobileSAMv2/efficientvit/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChaoningZhang/MobileSAM/34bbbfdface3c18e5221aa7de6032d7220c6c6a1/MobileSAMv2/efficientvit/__init__.py -------------------------------------------------------------------------------- /MobileSAMv2/efficientvit/apps/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChaoningZhang/MobileSAM/34bbbfdface3c18e5221aa7de6032d7220c6c6a1/MobileSAMv2/efficientvit/apps/__init__.py -------------------------------------------------------------------------------- /MobileSAMv2/efficientvit/apps/data_provider/__init__.py: -------------------------------------------------------------------------------- 1 | # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction 2 | # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han 3 | # International Conference on Computer Vision (ICCV), 2023 4 | 5 | from .augment import * 6 | from .base import * 7 | from .random_resolution import * 8 | -------------------------------------------------------------------------------- /MobileSAMv2/efficientvit/apps/data_provider/augment/__init__.py: -------------------------------------------------------------------------------- 1 | # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction 2 | # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han 3 | # International Conference on Computer Vision (ICCV), 2023 4 | 5 | from .bbox import * 6 | from .color_aug import * 7 | -------------------------------------------------------------------------------- /MobileSAMv2/efficientvit/apps/data_provider/augment/bbox.py: -------------------------------------------------------------------------------- 1 | # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction 2 | # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han 3 | # International Conference on Computer Vision (ICCV), 2023 4 | 5 | import numpy as np 6 | 7 | __all__ = ["rand_bbox"] 8 | 9 | 10 | def rand_bbox( 11 | h: int, 12 | w: int, 13 | lam: float, 14 | rand_func: callable = np.random.uniform, 15 | ) -> tuple[int, int, int, int]: 16 | """randomly sample bbox, used in cutmix""" 17 | cut_rat = np.sqrt(1.0 - lam) 18 | cut_w = w * cut_rat 19 | cut_h = h * cut_rat 20 | 21 | # uniform 22 | cx = rand_func(0, w) 23 | cy = rand_func(0, h) 24 | 25 | bbx1 = int(np.clip(cx - cut_w / 2, 0, w)) 26 | bby1 = int(np.clip(cy - cut_h / 2, 0, h)) 27 | bbx2 = int(np.clip(cx + cut_w / 2, 0, w)) 28 | bby2 = int(np.clip(cy + cut_h / 2, 0, h)) 29 | 30 | return bbx1, bby1, bbx2, bby2 31 | 
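A minimal usage sketch for `rand_bbox`, assuming the `MobileSAMv2` directory is on the Python path; the 224x224 resolution and `lam` value are illustrative, and `cutmix()` in `clscore/trainer/utils/mixup.py` applies the same pattern per image:

```python
from efficientvit.apps.data_provider.augment import rand_bbox

lam = 0.7  # mixing coefficient, typically drawn by the caller (e.g. from a Beta distribution)
# sample a box covering roughly (1 - lam) of a 224x224 image (the exact ratio varies with clipping)
bbx1, bby1, bbx2, bby2 = rand_bbox(h=224, w=224, lam=lam)
# cutmix then pastes the shuffled batch's pixels at [:, :, bby1:bby2, bbx1:bbx2] into the originals
```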
-------------------------------------------------------------------------------- /MobileSAMv2/efficientvit/apps/data_provider/augment/color_aug.py: -------------------------------------------------------------------------------- 1 | # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction 2 | # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han 3 | # International Conference on Computer Vision (ICCV), 2023 4 | 5 | import numpy as np 6 | import torchvision.transforms as transforms 7 | from PIL import Image 8 | from timm.data.auto_augment import rand_augment_transform 9 | 10 | __all__ = ["ColorAug", "RandAug"] 11 | 12 | 13 | class ImageAug: 14 | def aug_image(self, image: Image.Image) -> Image.Image: 15 | raise NotImplementedError 16 | 17 | def __call__(self, feed_dict: dict or np.ndarray or Image.Image) -> dict or np.ndarray or Image.Image: 18 | if isinstance(feed_dict, dict): 19 | output_dict = feed_dict 20 | image = feed_dict[self.key] 21 | else: 22 | output_dict = None 23 | image = feed_dict 24 | is_ndarray = isinstance(image, np.ndarray) 25 | if is_ndarray: 26 | image = Image.fromarray(image) 27 | 28 | image = self.aug_image(image) 29 | 30 | if is_ndarray: 31 | image = np.array(image) 32 | 33 | if output_dict is None: 34 | return image 35 | else: 36 | output_dict[self.key] = image 37 | return output_dict 38 | 39 | 40 | class ColorAug(transforms.ColorJitter, ImageAug): 41 | def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, key="data"): 42 | super().__init__( 43 | brightness=brightness, 44 | contrast=contrast, 45 | saturation=saturation, 46 | hue=hue, 47 | ) 48 | self.key = key 49 | 50 | def aug_image(self, image: Image.Image) -> Image.Image: 51 | return transforms.ColorJitter.forward(self, image) 52 | 53 | def forward(self, feed_dict: dict or np.ndarray or Image.Image) -> dict or np.ndarray or Image.Image: 54 | return ImageAug.__call__(self, feed_dict) 55 | 56 | 57 | class RandAug(ImageAug): 58 | def __init__(self, config: dict[str, any], mean: tuple[float, float, float], key="data"): 59 | n = config.get("n", 2) 60 | m = config.get("m", 9) 61 | mstd = config.get("mstd", 1.0) 62 | inc = config.get("inc", 1) 63 | tpct = config.get("tpct", 0.45) 64 | config_str = f"rand-n{n}-m{m}-mstd{mstd}-inc{inc}" 65 | 66 | aa_params = dict( 67 | translate_pct=tpct, 68 | img_mean=tuple([min(255, round(255 * x)) for x in mean]), 69 | interpolation=Image.BICUBIC, 70 | ) 71 | self.aug_op = rand_augment_transform(config_str, aa_params) 72 | self.key = key 73 | 74 | def aug_image(self, image: Image.Image) -> Image.Image: 75 | return self.aug_op(image) 76 | 77 | def __repr__(self): 78 | return self.aug_op.__repr__() 79 | -------------------------------------------------------------------------------- /MobileSAMv2/efficientvit/apps/data_provider/random_resolution/__init__.py: -------------------------------------------------------------------------------- 1 | """Random resolution data loader compatible with multi-processing and distributed training. 
2 | 3 | Replace Pytorch's DataLoader with RRSDataLoader to support random resolution 4 | at the training time, resolution sampling is controlled by RRSController 5 | """ 6 | from .controller import * 7 | -------------------------------------------------------------------------------- /MobileSAMv2/efficientvit/apps/data_provider/random_resolution/controller.py: -------------------------------------------------------------------------------- 1 | # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction 2 | # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han 3 | # International Conference on Computer Vision (ICCV), 2023 4 | 5 | import copy 6 | 7 | import torch 8 | import torchvision.transforms as transforms 9 | import torchvision.transforms.functional as F 10 | 11 | from efficientvit.models.utils import torch_random_choices 12 | 13 | __all__ = [ 14 | "RRSController", 15 | "get_interpolate", 16 | "MyRandomResizedCrop", 17 | ] 18 | 19 | 20 | class RRSController: 21 | ACTIVE_SIZE = (224, 224) 22 | IMAGE_SIZE_LIST = [(224, 224)] 23 | 24 | CHOICE_LIST = None 25 | 26 | @staticmethod 27 | def get_candidates() -> list[tuple[int, int]]: 28 | return copy.deepcopy(RRSController.IMAGE_SIZE_LIST) 29 | 30 | @staticmethod 31 | def sample_resolution(batch_id: int) -> None: 32 | RRSController.ACTIVE_SIZE = RRSController.CHOICE_LIST[batch_id] 33 | 34 | @staticmethod 35 | def set_epoch(epoch: int, batch_per_epoch: int) -> None: 36 | g = torch.Generator() 37 | g.manual_seed(epoch) 38 | RRSController.CHOICE_LIST = torch_random_choices( 39 | RRSController.get_candidates(), 40 | g, 41 | batch_per_epoch, 42 | ) 43 | 44 | 45 | def get_interpolate(name: str) -> F.InterpolationMode: 46 | mapping = { 47 | "nearest": F.InterpolationMode.NEAREST, 48 | "bilinear": F.InterpolationMode.BILINEAR, 49 | "bicubic": F.InterpolationMode.BICUBIC, 50 | "box": F.InterpolationMode.BOX, 51 | "hamming": F.InterpolationMode.HAMMING, 52 | "lanczos": F.InterpolationMode.LANCZOS, 53 | } 54 | if name in mapping: 55 | return mapping[name] 56 | elif name == "random": 57 | return torch_random_choices( 58 | [ 59 | F.InterpolationMode.NEAREST, 60 | F.InterpolationMode.BILINEAR, 61 | F.InterpolationMode.BICUBIC, 62 | F.InterpolationMode.BOX, 63 | F.InterpolationMode.HAMMING, 64 | F.InterpolationMode.LANCZOS, 65 | ], 66 | ) 67 | else: 68 | raise NotImplementedError 69 | 70 | 71 | class MyRandomResizedCrop(transforms.RandomResizedCrop): 72 | def __init__( 73 | self, 74 | scale=(0.08, 1.0), 75 | ratio=(3.0 / 4.0, 4.0 / 3.0), 76 | interpolation: str = "random", 77 | ): 78 | super(MyRandomResizedCrop, self).__init__(224, scale, ratio) 79 | self.interpolation = interpolation 80 | 81 | def forward(self, img: torch.Tensor) -> torch.Tensor: 82 | i, j, h, w = self.get_params(img, list(self.scale), list(self.ratio)) 83 | target_size = RRSController.ACTIVE_SIZE 84 | return F.resized_crop(img, i, j, h, w, list(target_size), get_interpolate(self.interpolation)) 85 | 86 | def __repr__(self) -> str: 87 | format_string = self.__class__.__name__ 88 | format_string += f"(\n\tsize={RRSController.get_candidates()},\n" 89 | format_string += f"\tscale={tuple(round(s, 4) for s in self.scale)},\n" 90 | format_string += f"\tratio={tuple(round(r, 4) for r in self.ratio)},\n" 91 | format_string += f"\tinterpolation={self.interpolation})" 92 | return format_string 93 | -------------------------------------------------------------------------------- /MobileSAMv2/efficientvit/apps/trainer/__init__.py: 
-------------------------------------------------------------------------------- 1 | # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction 2 | # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han 3 | # International Conference on Computer Vision (ICCV), 2023 4 | 5 | from .base import * 6 | from .run_config import * 7 | -------------------------------------------------------------------------------- /MobileSAMv2/efficientvit/apps/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction 2 | # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han 3 | # International Conference on Computer Vision (ICCV), 2023 4 | 5 | from .dist import * 6 | from .ema import * 7 | from .export import * 8 | from .init import * 9 | from .lr import * 10 | from .metric import * 11 | from .misc import * 12 | from .opt import * 13 | -------------------------------------------------------------------------------- /MobileSAMv2/efficientvit/apps/utils/dist.py: -------------------------------------------------------------------------------- 1 | # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction 2 | # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han 3 | # International Conference on Computer Vision (ICCV), 2023 4 | 5 | import torch 6 | import torch.distributed 7 | from torchpack import distributed 8 | 9 | from efficientvit.models.utils.list import list_mean, list_sum 10 | 11 | __all__ = ["sync_tensor"] 12 | 13 | 14 | def sync_tensor(tensor: torch.Tensor or float, reduce="mean") -> torch.Tensor or list[torch.Tensor]: 15 | if not isinstance(tensor, torch.Tensor): 16 | tensor = torch.Tensor(1).fill_(tensor).cuda() 17 | tensor_list = [torch.empty_like(tensor) for _ in range(distributed.size())] 18 | torch.distributed.all_gather(tensor_list, tensor.contiguous(), async_op=False) 19 | if reduce == "mean": 20 | return list_mean(tensor_list) 21 | elif reduce == "sum": 22 | return list_sum(tensor_list) 23 | elif reduce == "cat": 24 | return torch.cat(tensor_list, dim=0) 25 | elif reduce == "root": 26 | return tensor_list[0] 27 | else: 28 | return tensor_list 29 | -------------------------------------------------------------------------------- /MobileSAMv2/efficientvit/apps/utils/ema.py: -------------------------------------------------------------------------------- 1 | # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction 2 | # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han 3 | # International Conference on Computer Vision (ICCV), 2023 4 | 5 | import copy 6 | import math 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | from efficientvit.models.utils import is_parallel 12 | 13 | __all__ = ["EMA"] 14 | 15 | 16 | def update_ema(ema: nn.Module, new_state_dict: dict[str, torch.Tensor], decay: float) -> None: 17 | for k, v in ema.state_dict().items(): 18 | if v.dtype.is_floating_point: 19 | v -= (1.0 - decay) * (v - new_state_dict[k].detach()) 20 | 21 | 22 | class EMA: 23 | def __init__(self, model: nn.Module, decay: float, warmup_steps=2000): 24 | self.shadows = copy.deepcopy(model.module if is_parallel(model) else model).eval() 25 | self.decay = decay 26 | self.warmup_steps = warmup_steps 27 | 28 | for p in self.shadows.parameters(): 29 | p.requires_grad = False 30 | 31 | def step(self, model: nn.Module, global_step: int) -> None: 32 | with torch.no_grad(): 33 | msd = (model.module if is_parallel(model) else 
model).state_dict() 34 | update_ema(self.shadows, msd, self.decay * (1 - math.exp(-global_step / self.warmup_steps))) 35 | 36 | def state_dict(self) -> dict[float, dict[str, torch.Tensor]]: 37 | return {self.decay: self.shadows.state_dict()} 38 | 39 | def load_state_dict(self, state_dict: dict[float, dict[str, torch.Tensor]]) -> None: 40 | for decay in state_dict: 41 | if decay == self.decay: 42 | self.shadows.load_state_dict(state_dict[decay]) 43 | -------------------------------------------------------------------------------- /MobileSAMv2/efficientvit/apps/utils/export.py: -------------------------------------------------------------------------------- 1 | # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction 2 | # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han 3 | # International Conference on Computer Vision (ICCV), 2023 4 | 5 | import io 6 | import os 7 | 8 | import onnx 9 | import torch 10 | import torch.nn as nn 11 | from onnxsim import simplify as simplify_func 12 | 13 | __all__ = ["export_onnx"] 14 | 15 | 16 | def export_onnx(model: nn.Module, export_path: str, sample_inputs: any, simplify=True, opset=11) -> None: 17 | """Export a model to a platform-specific onnx format. 18 | 19 | Args: 20 | model: a torch.nn.Module object. 21 | export_path: export location. 22 | sample_inputs: Any. 23 | simplify: a flag to turn on onnx-simplifier 24 | opset: int 25 | """ 26 | model.eval() 27 | 28 | buffer = io.BytesIO() 29 | with torch.no_grad(): 30 | torch.onnx.export(model, sample_inputs, buffer, opset_version=opset) 31 | buffer.seek(0, 0) 32 | if simplify: 33 | onnx_model = onnx.load_model(buffer) 34 | onnx_model, success = simplify_func(onnx_model) 35 | assert success 36 | new_buffer = io.BytesIO() 37 | onnx.save(onnx_model, new_buffer) 38 | buffer = new_buffer 39 | buffer.seek(0, 0) 40 | 41 | if buffer.getbuffer().nbytes > 0: 42 | save_dir = os.path.dirname(export_path) 43 | os.makedirs(save_dir, exist_ok=True) 44 | with open(export_path, "wb") as f: 45 | f.write(buffer.read()) 46 | -------------------------------------------------------------------------------- /MobileSAMv2/efficientvit/apps/utils/init.py: -------------------------------------------------------------------------------- 1 | # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction 2 | # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han 3 | # International Conference on Computer Vision (ICCV), 2023 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torch.nn.modules.batchnorm import _BatchNorm 8 | 9 | __all__ = ["init_modules", "zero_last_gamma"] 10 | 11 | 12 | def init_modules(model: nn.Module or list[nn.Module], init_type="trunc_normal") -> None: 13 | _DEFAULT_INIT_PARAM = {"trunc_normal": 0.02} 14 | 15 | if isinstance(model, list): 16 | for sub_module in model: 17 | init_modules(sub_module, init_type) 18 | else: 19 | init_params = init_type.split("@") 20 | init_params = float(init_params[1]) if len(init_params) > 1 else None 21 | 22 | if init_type.startswith("trunc_normal"): 23 | init_func = lambda param: nn.init.trunc_normal_( 24 | param, std=(init_params or _DEFAULT_INIT_PARAM["trunc_normal"]) 25 | ) 26 | else: 27 | raise NotImplementedError 28 | 29 | for m in model.modules(): 30 | if isinstance(m, (nn.Conv2d, nn.Linear, nn.ConvTranspose2d)): 31 | init_func(m.weight) 32 | if m.bias is not None: 33 | m.bias.data.zero_() 34 | elif isinstance(m, nn.Embedding): 35 | init_func(m.weight) 36 | elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)): 37 | 
m.weight.data.fill_(1) 38 | m.bias.data.zero_() 39 | else: 40 | weight = getattr(m, "weight", None) 41 | bias = getattr(m, "bias", None) 42 | if isinstance(weight, torch.nn.Parameter): 43 | init_func(weight) 44 | if isinstance(bias, torch.nn.Parameter): 45 | bias.data.zero_() 46 | 47 | 48 | def zero_last_gamma(model: nn.Module, init_val=0) -> None: 49 | import efficientvit.models.nn.ops as ops 50 | 51 | for m in model.modules(): 52 | if isinstance(m, ops.ResidualBlock) and isinstance(m.shortcut, ops.IdentityLayer): 53 | if isinstance(m.main, (ops.DSConv, ops.MBConv, ops.FusedMBConv)): 54 | parent_module = m.main.point_conv 55 | elif isinstance(m.main, ops.ResBlock): 56 | parent_module = m.main.conv2 57 | elif isinstance(m.main, ops.ConvLayer): 58 | parent_module = m.main 59 | elif isinstance(m.main, (ops.LiteMLA)): 60 | parent_module = m.main.proj 61 | else: 62 | parent_module = None 63 | if parent_module is not None: 64 | norm = getattr(parent_module, "norm", None) 65 | if norm is not None: 66 | nn.init.constant_(norm.weight, init_val) 67 | -------------------------------------------------------------------------------- /MobileSAMv2/efficientvit/apps/utils/lr.py: -------------------------------------------------------------------------------- 1 | # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction 2 | # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han 3 | # International Conference on Computer Vision (ICCV), 2023 4 | 5 | import math 6 | 7 | import torch 8 | 9 | from efficientvit.models.utils.list import val2list 10 | 11 | __all__ = ["CosineLRwithWarmup"] 12 | 13 | 14 | class CosineLRwithWarmup(torch.optim.lr_scheduler._LRScheduler): 15 | def __init__( 16 | self, 17 | optimizer: torch.optim.Optimizer, 18 | warmup_steps: int, 19 | warmup_lr: float, 20 | decay_steps: int or list[int], 21 | last_epoch: int = -1, 22 | ) -> None: 23 | self.warmup_steps = warmup_steps 24 | self.warmup_lr = warmup_lr 25 | self.decay_steps = val2list(decay_steps) 26 | super().__init__(optimizer, last_epoch) 27 | 28 | def get_lr(self) -> list[float]: 29 | if self.last_epoch < self.warmup_steps: 30 | return [ 31 | (base_lr - self.warmup_lr) * (self.last_epoch + 1) / self.warmup_steps + self.warmup_lr 32 | for base_lr in self.base_lrs 33 | ] 34 | else: 35 | current_steps = self.last_epoch - self.warmup_steps 36 | decay_steps = [0] + self.decay_steps 37 | idx = len(decay_steps) - 2 38 | for i, decay_step in enumerate(decay_steps[:-1]): 39 | if decay_step <= current_steps < decay_steps[i + 1]: 40 | idx = i 41 | break 42 | current_steps -= decay_steps[idx] 43 | decay_step = decay_steps[idx + 1] - decay_steps[idx] 44 | return [0.5 * base_lr * (1 + math.cos(math.pi * current_steps / decay_step)) for base_lr in self.base_lrs] 45 | -------------------------------------------------------------------------------- /MobileSAMv2/efficientvit/apps/utils/metric.py: -------------------------------------------------------------------------------- 1 | # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction 2 | # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han 3 | # International Conference on Computer Vision (ICCV), 2023 4 | 5 | import torch 6 | 7 | from efficientvit.apps.utils.dist import sync_tensor 8 | 9 | __all__ = ["AverageMeter"] 10 | 11 | 12 | class AverageMeter: 13 | """Computes and stores the average and current value.""" 14 | 15 | def __init__(self, is_distributed=True): 16 | self.is_distributed = is_distributed 17 | self.sum = 0 18 | self.count = 0 19 | 20 | def 
_sync(self, val: torch.Tensor or int or float) -> torch.Tensor or int or float: 21 | return sync_tensor(val, reduce="sum") if self.is_distributed else val 22 | 23 | def update(self, val: torch.Tensor or int or float, delta_n=1): 24 | self.count += self._sync(delta_n) 25 | self.sum += self._sync(val * delta_n) 26 | 27 | def get_count(self) -> torch.Tensor or int or float: 28 | return self.count.item() if isinstance(self.count, torch.Tensor) and self.count.numel() == 1 else self.count 29 | 30 | @property 31 | def avg(self): 32 | avg = -1 if self.count == 0 else self.sum / self.count 33 | return avg.item() if isinstance(avg, torch.Tensor) and avg.numel() == 1 else avg 34 | -------------------------------------------------------------------------------- /MobileSAMv2/efficientvit/apps/utils/misc.py: -------------------------------------------------------------------------------- 1 | # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction 2 | # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han 3 | # International Conference on Computer Vision (ICCV), 2023 4 | 5 | import os 6 | 7 | import yaml 8 | 9 | __all__ = [ 10 | "parse_with_yaml", 11 | "parse_unknown_args", 12 | "partial_update_config", 13 | "resolve_and_load_config", 14 | "load_config", 15 | "dump_config", 16 | ] 17 | 18 | 19 | def parse_with_yaml(config_str: str) -> str or dict: 20 | try: 21 | # add space manually for dict 22 | if "{" in config_str and "}" in config_str and ":" in config_str: 23 | out_str = config_str.replace(":", ": ") 24 | else: 25 | out_str = config_str 26 | return yaml.safe_load(out_str) 27 | except ValueError: 28 | # return raw string if parsing fails 29 | return config_str 30 | 31 | 32 | def parse_unknown_args(unknown: list) -> dict: 33 | """Parse unknown args.""" 34 | index = 0 35 | parsed_dict = {} 36 | while index < len(unknown): 37 | key, val = unknown[index], unknown[index + 1] 38 | index += 2 39 | if not key.startswith("--"): 40 | continue 41 | key = key[2:] 42 | 43 | # try parsing with either dot notation or full yaml notation 44 | # Note that the vanilla case "--key value" will be parsed the same 45 | if "." in key: 46 | # key == a.b.c, val == val --> parsed_dict[a][b][c] = val 47 | keys = key.split(".") 48 | dict_to_update = parsed_dict 49 | for key in keys[:-1]: 50 | if not (key in dict_to_update and isinstance(dict_to_update[key], dict)): 51 | dict_to_update[key] = {} 52 | dict_to_update = dict_to_update[key] 53 | dict_to_update[keys[-1]] = parse_with_yaml(val) # so we can parse lists, bools, etc... 
54 | else: 55 | parsed_dict[key] = parse_with_yaml(val) 56 | return parsed_dict 57 | 58 | 59 | def partial_update_config(config: dict, partial_config: dict) -> dict: 60 | for key in partial_config: 61 | if key in config and isinstance(partial_config[key], dict) and isinstance(config[key], dict): 62 | partial_update_config(config[key], partial_config[key]) 63 | else: 64 | config[key] = partial_config[key] 65 | return config 66 | 67 | 68 | def resolve_and_load_config(path: str, config_name="config.yaml") -> dict: 69 | path = os.path.realpath(os.path.expanduser(path)) 70 | if os.path.isdir(path): 71 | config_path = os.path.join(path, config_name) 72 | else: 73 | config_path = path 74 | if os.path.isfile(config_path): 75 | pass 76 | else: 77 | raise Exception(f"Cannot find a valid config at {path}") 78 | config = load_config(config_path) 79 | return config 80 | 81 | 82 | class SafeLoaderWithTuple(yaml.SafeLoader): 83 | """A yaml safe loader with python tuple loading capabilities.""" 84 | 85 | def construct_python_tuple(self, node): 86 | return tuple(self.construct_sequence(node)) 87 | 88 | 89 | SafeLoaderWithTuple.add_constructor("tag:yaml.org,2002:python/tuple", SafeLoaderWithTuple.construct_python_tuple) 90 | 91 | 92 | def load_config(filename: str) -> dict: 93 | """Load a yaml file.""" 94 | filename = os.path.realpath(os.path.expanduser(filename)) 95 | return yaml.load(open(filename), Loader=SafeLoaderWithTuple) 96 | 97 | 98 | def dump_config(config: dict, filename: str) -> None: 99 | """Dump a config file""" 100 | filename = os.path.realpath(os.path.expanduser(filename)) 101 | yaml.dump(config, open(filename, "w"), sort_keys=False) 102 | -------------------------------------------------------------------------------- /MobileSAMv2/efficientvit/apps/utils/opt.py: -------------------------------------------------------------------------------- 1 | # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction 2 | # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han 3 | # International Conference on Computer Vision (ICCV), 2023 4 | 5 | import torch 6 | 7 | __all__ = ["REGISTERED_OPTIMIZER_DICT", "build_optimizer"] 8 | 9 | # register optimizer here 10 | # name: optimizer, kwargs with default values 11 | REGISTERED_OPTIMIZER_DICT: dict[str, tuple[type, dict[str, any]]] = { 12 | "sgd": (torch.optim.SGD, {"momentum": 0.9, "nesterov": True}), 13 | "adam": (torch.optim.Adam, {"betas": (0.9, 0.999), "eps": 1e-8, "amsgrad": False}), 14 | "adamw": (torch.optim.AdamW, {"betas": (0.9, 0.999), "eps": 1e-8, "amsgrad": False}), 15 | } 16 | 17 | 18 | def build_optimizer( 19 | net_params, optimizer_name: str, optimizer_params: dict or None, init_lr: float 20 | ) -> torch.optim.Optimizer: 21 | optimizer_class, default_params = REGISTERED_OPTIMIZER_DICT[optimizer_name] 22 | optimizer_params = optimizer_params or {} 23 | 24 | for key in default_params: 25 | if key in optimizer_params: 26 | default_params[key] = optimizer_params[key] 27 | optimizer = optimizer_class(net_params, init_lr, **default_params) 28 | return optimizer 29 | -------------------------------------------------------------------------------- /MobileSAMv2/efficientvit/cls_model_zoo.py: -------------------------------------------------------------------------------- 1 | # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction 2 | # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han 3 | # International Conference on Computer Vision (ICCV), 2023 4 | 5 | from efficientvit.models.efficientvit import ( 6 | 
EfficientViTCls, 7 | efficientvit_cls_b0, 8 | efficientvit_cls_b1, 9 | efficientvit_cls_b2, 10 | efficientvit_cls_b3, 11 | efficientvit_cls_l1, 12 | efficientvit_cls_l2, 13 | efficientvit_cls_l3, 14 | ) 15 | from efficientvit.models.nn.norm import set_norm_eps 16 | from efficientvit.models.utils import load_state_dict_from_file 17 | 18 | __all__ = ["create_cls_model"] 19 | 20 | 21 | REGISTERED_CLS_MODEL: dict[str, str] = { 22 | "b0-r224": "assets/checkpoints/cls/b0-r224.pt", 23 | ############################################### 24 | "b1-r224": "assets/checkpoints/cls/b1-r224.pt", 25 | "b1-r256": "assets/checkpoints/cls/b1-r256.pt", 26 | "b1-r288": "assets/checkpoints/cls/b1-r288.pt", 27 | ############################################### 28 | "b2-r224": "assets/checkpoints/cls/b2-r224.pt", 29 | "b2-r256": "assets/checkpoints/cls/b2-r256.pt", 30 | "b2-r288": "assets/checkpoints/cls/b2-r288.pt", 31 | ############################################### 32 | "b3-r224": "assets/checkpoints/cls/b3-r224.pt", 33 | "b3-r256": "assets/checkpoints/cls/b3-r256.pt", 34 | "b3-r288": "assets/checkpoints/cls/b3-r288.pt", 35 | ############################################### 36 | "l1-r224": "assets/checkpoints/cls/l1-r224.pt", 37 | ############################################### 38 | "l2-r224": "assets/checkpoints/cls/l2-r224.pt", 39 | "l2-r256": "assets/checkpoints/cls/l2-r256.pt", 40 | "l2-r288": "assets/checkpoints/cls/l2-r288.pt", 41 | "l2-r320": "assets/checkpoints/cls/l2-r320.pt", 42 | "l2-r384": "assets/checkpoints/cls/l2-r384.pt", 43 | ############################################### 44 | "l3-r224": "assets/checkpoints/cls/l3-r224.pt", 45 | "l3-r256": "assets/checkpoints/cls/l3-r256.pt", 46 | "l3-r288": "assets/checkpoints/cls/l3-r288.pt", 47 | "l3-r320": "assets/checkpoints/cls/l3-r320.pt", 48 | "l3-r384": "assets/checkpoints/cls/l3-r384.pt", 49 | } 50 | 51 | 52 | def create_cls_model(name: str, pretrained=True, weight_url: str or None = None, **kwargs) -> EfficientViTCls: 53 | model_dict = { 54 | "b0": efficientvit_cls_b0, 55 | "b1": efficientvit_cls_b1, 56 | "b2": efficientvit_cls_b2, 57 | "b3": efficientvit_cls_b3, 58 | ######################### 59 | "l1": efficientvit_cls_l1, 60 | "l2": efficientvit_cls_l2, 61 | "l3": efficientvit_cls_l3, 62 | } 63 | 64 | model_id = name.split("-")[0] 65 | if model_id not in model_dict: 66 | raise ValueError(f"Do not find {name} in the model zoo. 
List of models: {list(model_dict.keys())}") 67 | else: 68 | model = model_dict[model_id](**kwargs) 69 | if model_id in ["l1", "l2", "l3"]: 70 | set_norm_eps(model, 1e-7) 71 | 72 | if pretrained: 73 | weight_url = weight_url or REGISTERED_CLS_MODEL.get(name, None) 74 | if weight_url is None: 75 | raise ValueError(f"Do not find the pretrained weight of {name}.") 76 | else: 77 | weight = load_state_dict_from_file(weight_url) 78 | model.load_state_dict(weight) 79 | return model 80 | -------------------------------------------------------------------------------- /MobileSAMv2/efficientvit/clscore/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChaoningZhang/MobileSAM/34bbbfdface3c18e5221aa7de6032d7220c6c6a1/MobileSAMv2/efficientvit/clscore/__init__.py -------------------------------------------------------------------------------- /MobileSAMv2/efficientvit/clscore/data_provider/__init__.py: -------------------------------------------------------------------------------- 1 | # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction 2 | # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han 3 | # International Conference on Computer Vision (ICCV), 2023 4 | 5 | from .imagenet import * 6 | -------------------------------------------------------------------------------- /MobileSAMv2/efficientvit/clscore/trainer/__init__.py: -------------------------------------------------------------------------------- 1 | # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction 2 | # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han 3 | # International Conference on Computer Vision (ICCV), 2023 4 | 5 | from .cls_run_config import * 6 | from .cls_trainer import * 7 | -------------------------------------------------------------------------------- /MobileSAMv2/efficientvit/clscore/trainer/cls_run_config.py: -------------------------------------------------------------------------------- 1 | # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction 2 | # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han 3 | # International Conference on Computer Vision (ICCV), 2023 4 | 5 | from efficientvit.apps.trainer.run_config import RunConfig 6 | 7 | __all__ = ["ClsRunConfig"] 8 | 9 | 10 | class ClsRunConfig(RunConfig): 11 | label_smooth: float 12 | mixup_config: dict # allow none to turn off mixup 13 | bce: bool 14 | mesa: dict 15 | 16 | @property 17 | def none_allowed(self): 18 | return ["mixup_config", "mesa"] + super().none_allowed 19 | -------------------------------------------------------------------------------- /MobileSAMv2/efficientvit/clscore/trainer/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction 2 | # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han 3 | # International Conference on Computer Vision (ICCV), 2023 4 | 5 | from .label_smooth import * 6 | from .metric import * 7 | from .mixup import * 8 | -------------------------------------------------------------------------------- /MobileSAMv2/efficientvit/clscore/trainer/utils/label_smooth.py: -------------------------------------------------------------------------------- 1 | # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction 2 | # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han 3 | # International Conference on Computer Vision (ICCV), 2023 4 | 
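# label_smooth() below converts integer class indices to one-hot vectors and blends them with a
# uniform distribution: soft_target = (1 - smooth_factor) * one_hot(target) + smooth_factor / n_classes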
5 | import torch 6 | 7 | __all__ = ["label_smooth"] 8 | 9 | 10 | def label_smooth(target: torch.Tensor, n_classes: int, smooth_factor=0.1) -> torch.Tensor: 11 | # convert to one-hot 12 | batch_size = target.shape[0] 13 | target = torch.unsqueeze(target, 1) 14 | soft_target = torch.zeros((batch_size, n_classes), device=target.device) 15 | soft_target.scatter_(1, target, 1) 16 | # label smoothing 17 | soft_target = torch.add(soft_target * (1 - smooth_factor), smooth_factor / n_classes) 18 | return soft_target 19 | -------------------------------------------------------------------------------- /MobileSAMv2/efficientvit/clscore/trainer/utils/metric.py: -------------------------------------------------------------------------------- 1 | # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction 2 | # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han 3 | # International Conference on Computer Vision (ICCV), 2023 4 | 5 | import torch 6 | 7 | __all__ = ["accuracy"] 8 | 9 | 10 | def accuracy(output: torch.Tensor, target: torch.Tensor, topk=(1,)) -> list[torch.Tensor]: 11 | """Computes the precision@k for the specified values of k.""" 12 | maxk = max(topk) 13 | batch_size = target.shape[0] 14 | 15 | _, pred = output.topk(maxk, 1, True, True) 16 | pred = pred.t() 17 | correct = pred.eq(target.reshape(1, -1).expand_as(pred)) 18 | 19 | res = [] 20 | for k in topk: 21 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 22 | res.append(correct_k.mul_(100.0 / batch_size)) 23 | return res 24 | -------------------------------------------------------------------------------- /MobileSAMv2/efficientvit/clscore/trainer/utils/mixup.py: -------------------------------------------------------------------------------- 1 | # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction 2 | # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han 3 | # International Conference on Computer Vision (ICCV), 2023 4 | 5 | import torch.distributions 6 | 7 | from efficientvit.apps.data_provider.augment import rand_bbox 8 | from efficientvit.models.utils.random import torch_randint, torch_shuffle 9 | 10 | __all__ = ["apply_mixup", "mixup", "cutmix"] 11 | 12 | 13 | def apply_mixup( 14 | images: torch.Tensor, 15 | labels: torch.Tensor, 16 | lam: float, 17 | mix_type="mixup", 18 | ) -> tuple[torch.Tensor, torch.Tensor]: 19 | if mix_type == "mixup": 20 | return mixup(images, labels, lam) 21 | elif mix_type == "cutmix": 22 | return cutmix(images, labels, lam) 23 | else: 24 | raise NotImplementedError 25 | 26 | 27 | def mixup( 28 | images: torch.Tensor, 29 | target: torch.Tensor, 30 | lam: float, 31 | ) -> tuple[torch.Tensor, torch.Tensor]: 32 | rand_index = torch_shuffle(list(range(0, images.shape[0]))) 33 | 34 | flipped_images = images[rand_index] 35 | flipped_target = target[rand_index] 36 | 37 | return ( 38 | lam * images + (1 - lam) * flipped_images, 39 | lam * target + (1 - lam) * flipped_target, 40 | ) 41 | 42 | 43 | def cutmix( 44 | images: torch.Tensor, 45 | target: torch.Tensor, 46 | lam: float, 47 | ) -> tuple[torch.Tensor, torch.Tensor]: 48 | rand_index = torch_shuffle(list(range(0, images.shape[0]))) 49 | 50 | flipped_images = images[rand_index] 51 | flipped_target = target[rand_index] 52 | 53 | b, _, h, w = images.shape 54 | lam_list = [] 55 | for i in range(b): 56 | bbx1, bby1, bbx2, bby2 = rand_bbox( 57 | h=h, 58 | w=w, 59 | lam=lam, 60 | rand_func=torch_randint, 61 | ) 62 | images[i, :, bby1:bby2, bbx1:bbx2] = flipped_images[i, :, bby1:bby2, bbx1:bbx2] 63 | 
lam_list.append(1 - ((bbx2 - bbx1) * (bby2 - bby1) / (h * w))) 64 | lam = torch.Tensor(lam_list).to(images.device).view(b, 1) 65 | return images, lam * target + (1 - lam) * flipped_target 66 | -------------------------------------------------------------------------------- /MobileSAMv2/efficientvit/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChaoningZhang/MobileSAM/34bbbfdface3c18e5221aa7de6032d7220c6c6a1/MobileSAMv2/efficientvit/models/__init__.py -------------------------------------------------------------------------------- /MobileSAMv2/efficientvit/models/efficientvit/__init__.py: -------------------------------------------------------------------------------- 1 | # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction 2 | # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han 3 | # International Conference on Computer Vision (ICCV), 2023 4 | 5 | from .backbone import * 6 | from .cls import * 7 | from .sam import * 8 | from .seg import * 9 | -------------------------------------------------------------------------------- /MobileSAMv2/efficientvit/models/nn/__init__.py: -------------------------------------------------------------------------------- 1 | # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction 2 | # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han 3 | # International Conference on Computer Vision (ICCV), 2023 4 | 5 | from .act import * 6 | from .drop import * 7 | from .norm import * 8 | from .ops import * 9 | -------------------------------------------------------------------------------- /MobileSAMv2/efficientvit/models/nn/act.py: -------------------------------------------------------------------------------- 1 | # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction 2 | # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han 3 | # International Conference on Computer Vision (ICCV), 2023 4 | 5 | from functools import partial 6 | 7 | import torch.nn as nn 8 | 9 | from efficientvit.models.utils import build_kwargs_from_config 10 | 11 | __all__ = ["build_act"] 12 | 13 | 14 | # register activation function here 15 | REGISTERED_ACT_DICT: dict[str, type] = { 16 | "relu": nn.ReLU, 17 | "relu6": nn.ReLU6, 18 | "hswish": nn.Hardswish, 19 | "silu": nn.SiLU, 20 | "gelu": partial(nn.GELU, approximate="tanh"), 21 | } 22 | 23 | 24 | def build_act(name: str, **kwargs) -> nn.Module or None: 25 | if name in REGISTERED_ACT_DICT: 26 | act_cls = REGISTERED_ACT_DICT[name] 27 | args = build_kwargs_from_config(kwargs, act_cls) 28 | return act_cls(**args) 29 | else: 30 | return None 31 | -------------------------------------------------------------------------------- /MobileSAMv2/efficientvit/models/nn/drop.py: -------------------------------------------------------------------------------- 1 | # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction 2 | # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han 3 | # International Conference on Computer Vision (ICCV), 2023 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | 9 | from efficientvit.apps.trainer.run_config import Scheduler 10 | from efficientvit.models.nn.ops import IdentityLayer, ResidualBlock 11 | from efficientvit.models.utils import build_kwargs_from_config 12 | 13 | __all__ = ["apply_drop_func"] 14 | 15 | 16 | def apply_drop_func(network: nn.Module, drop_config: dict[str, any] or None) -> None: 17 | if drop_config is 
None: 18 | return 19 | 20 | drop_lookup_table = { 21 | "droppath": apply_droppath, 22 | } 23 | 24 | drop_func = drop_lookup_table[drop_config["name"]] 25 | drop_kwargs = build_kwargs_from_config(drop_config, drop_func) 26 | 27 | drop_func(network, **drop_kwargs) 28 | 29 | 30 | def apply_droppath( 31 | network: nn.Module, 32 | drop_prob: float, 33 | linear_decay=True, 34 | scheduled=True, 35 | skip=0, 36 | ) -> None: 37 | all_valid_blocks = [] 38 | for m in network.modules(): 39 | for name, sub_module in m.named_children(): 40 | if isinstance(sub_module, ResidualBlock) and isinstance(sub_module.shortcut, IdentityLayer): 41 | all_valid_blocks.append((m, name, sub_module)) 42 | all_valid_blocks = all_valid_blocks[skip:] 43 | for i, (m, name, sub_module) in enumerate(all_valid_blocks): 44 | prob = drop_prob * (i + 1) / len(all_valid_blocks) if linear_decay else drop_prob 45 | new_module = DropPathResidualBlock( 46 | sub_module.main, 47 | sub_module.shortcut, 48 | sub_module.post_act, 49 | sub_module.pre_norm, 50 | prob, 51 | scheduled, 52 | ) 53 | m._modules[name] = new_module 54 | 55 | 56 | class DropPathResidualBlock(ResidualBlock): 57 | def __init__( 58 | self, 59 | main: nn.Module, 60 | shortcut: nn.Module or None, 61 | post_act=None, 62 | pre_norm: nn.Module or None = None, 63 | ###################################### 64 | drop_prob: float = 0, 65 | scheduled=True, 66 | ): 67 | super().__init__(main, shortcut, post_act, pre_norm) 68 | 69 | self.drop_prob = drop_prob 70 | self.scheduled = scheduled 71 | 72 | def forward(self, x: torch.Tensor) -> torch.Tensor: 73 | if not self.training or self.drop_prob == 0 or not isinstance(self.shortcut, IdentityLayer): 74 | return ResidualBlock.forward(self, x) 75 | else: 76 | drop_prob = self.drop_prob 77 | if self.scheduled: 78 | drop_prob *= np.clip(Scheduler.PROGRESS, 0, 1) 79 | keep_prob = 1 - drop_prob 80 | 81 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) 82 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) 83 | random_tensor.floor_() # binarize 84 | 85 | res = self.forward_main(x) / keep_prob * random_tensor + self.shortcut(x) 86 | if self.post_act: 87 | res = self.post_act(res) 88 | return res 89 | -------------------------------------------------------------------------------- /MobileSAMv2/efficientvit/models/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction 2 | # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han 3 | # International Conference on Computer Vision (ICCV), 2023 4 | 5 | from .list import * 6 | from .network import * 7 | from .random import * 8 | -------------------------------------------------------------------------------- /MobileSAMv2/efficientvit/models/utils/list.py: -------------------------------------------------------------------------------- 1 | # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction 2 | # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han 3 | # International Conference on Computer Vision (ICCV), 2023 4 | 5 | __all__ = [ 6 | "list_sum", 7 | "list_mean", 8 | "weighted_list_sum", 9 | "list_join", 10 | "val2list", 11 | "val2tuple", 12 | "squeeze_list", 13 | ] 14 | 15 | 16 | def list_sum(x: list) -> any: 17 | return x[0] if len(x) == 1 else x[0] + list_sum(x[1:]) 18 | 19 | 20 | def list_mean(x: list) -> any: 21 | return list_sum(x) / len(x) 22 | 23 | 24 | def weighted_list_sum(x: list, weights: list) -> any: 25 | assert 
len(x) == len(weights) 26 | return x[0] * weights[0] if len(x) == 1 else x[0] * weights[0] + weighted_list_sum(x[1:], weights[1:]) 27 | 28 | 29 | def list_join(x: list, sep="\t", format_str="%s") -> str: 30 | return sep.join([format_str % val for val in x]) 31 | 32 | 33 | def val2list(x: list or tuple or any, repeat_time=1) -> list: 34 | if isinstance(x, (list, tuple)): 35 | return list(x) 36 | return [x for _ in range(repeat_time)] 37 | 38 | 39 | def val2tuple(x: list or tuple or any, min_len: int = 1, idx_repeat: int = -1) -> tuple: 40 | x = val2list(x) 41 | 42 | # repeat elements if necessary 43 | if len(x) > 0: 44 | x[idx_repeat:idx_repeat] = [x[idx_repeat] for _ in range(min_len - len(x))] 45 | 46 | return tuple(x) 47 | 48 | 49 | def squeeze_list(x: list or None) -> list or any: 50 | if x is not None and len(x) == 1: 51 | return x[0] 52 | else: 53 | return x 54 | -------------------------------------------------------------------------------- /MobileSAMv2/efficientvit/models/utils/network.py: -------------------------------------------------------------------------------- 1 | # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction 2 | # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han 3 | # International Conference on Computer Vision (ICCV), 2023 4 | 5 | import os 6 | from inspect import signature 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | __all__ = [ 13 | "is_parallel", 14 | "get_device", 15 | "get_same_padding", 16 | "resize", 17 | "build_kwargs_from_config", 18 | "load_state_dict_from_file", 19 | ] 20 | 21 | 22 | def is_parallel(model: nn.Module) -> bool: 23 | return isinstance(model, (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)) 24 | 25 | 26 | def get_device(model: nn.Module) -> torch.device: 27 | return model.parameters().__next__().device 28 | 29 | 30 | def get_same_padding(kernel_size: int or tuple[int, ...]) -> int or tuple[int, ...]: 31 | if isinstance(kernel_size, tuple): 32 | return tuple([get_same_padding(ks) for ks in kernel_size]) 33 | else: 34 | assert kernel_size % 2 > 0, "kernel size should be odd number" 35 | return kernel_size // 2 36 | 37 | 38 | def resize( 39 | x: torch.Tensor, 40 | size: any or None = None, 41 | scale_factor: list[float] or None = None, 42 | mode: str = "bicubic", 43 | align_corners: bool or None = False, 44 | ) -> torch.Tensor: 45 | if mode in {"bilinear", "bicubic"}: 46 | return F.interpolate( 47 | x, 48 | size=size, 49 | scale_factor=scale_factor, 50 | mode=mode, 51 | align_corners=align_corners, 52 | ) 53 | elif mode in {"nearest", "area"}: 54 | return F.interpolate(x, size=size, scale_factor=scale_factor, mode=mode) 55 | else: 56 | raise NotImplementedError(f"resize(mode={mode}) not implemented.") 57 | 58 | 59 | def build_kwargs_from_config(config: dict, target_func: callable) -> dict[str, any]: 60 | valid_keys = list(signature(target_func).parameters) 61 | kwargs = {} 62 | for key in config: 63 | if key in valid_keys: 64 | kwargs[key] = config[key] 65 | return kwargs 66 | 67 | 68 | def load_state_dict_from_file(file: str, only_state_dict=True) -> dict[str, torch.Tensor]: 69 | file = os.path.realpath(os.path.expanduser(file)) 70 | checkpoint = torch.load(file, map_location="cpu") 71 | if only_state_dict and "state_dict" in checkpoint: 72 | checkpoint = checkpoint["state_dict"] 73 | return checkpoint 74 | -------------------------------------------------------------------------------- /MobileSAMv2/efficientvit/models/utils/random.py: 
-------------------------------------------------------------------------------- 1 | # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction 2 | # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han 3 | # International Conference on Computer Vision (ICCV), 2023 4 | 5 | import numpy as np 6 | import torch 7 | 8 | __all__ = [ 9 | "torch_randint", 10 | "torch_random", 11 | "torch_shuffle", 12 | "torch_uniform", 13 | "torch_random_choices", 14 | ] 15 | 16 | 17 | def torch_randint(low: int, high: int, generator: torch.Generator or None = None) -> int: 18 | """uniform: [low, high)""" 19 | if low == high: 20 | return low 21 | else: 22 | assert low < high 23 | return int(torch.randint(low=low, high=high, generator=generator, size=(1,))) 24 | 25 | 26 | def torch_random(generator: torch.Generator or None = None) -> float: 27 | """uniform distribution on the interval [0, 1)""" 28 | return float(torch.rand(1, generator=generator)) 29 | 30 | 31 | def torch_shuffle(src_list: list[any], generator: torch.Generator or None = None) -> list[any]: 32 | rand_indexes = torch.randperm(len(src_list), generator=generator).tolist() 33 | return [src_list[i] for i in rand_indexes] 34 | 35 | 36 | def torch_uniform(low: float, high: float, generator: torch.Generator or None = None) -> float: 37 | """uniform distribution on the interval [low, high)""" 38 | rand_val = torch_random(generator) 39 | return (high - low) * rand_val + low 40 | 41 | 42 | def torch_random_choices( 43 | src_list: list[any], 44 | generator: torch.Generator or None = None, 45 | k=1, 46 | weight_list: list[float] or None = None, 47 | ) -> any or list: 48 | if weight_list is None: 49 | rand_idx = torch.randint(low=0, high=len(src_list), generator=generator, size=(k,)) 50 | out_list = [src_list[i] for i in rand_idx] 51 | else: 52 | assert len(weight_list) == len(src_list) 53 | accumulate_weight_list = np.cumsum(weight_list) 54 | 55 | out_list = [] 56 | for _ in range(k): 57 | val = torch_uniform(0, accumulate_weight_list[-1], generator) 58 | active_id = 0 59 | for i, weight_val in enumerate(accumulate_weight_list): 60 | active_id = i 61 | if weight_val > val: 62 | break 63 | out_list.append(src_list[active_id]) 64 | 65 | return out_list[0] if k == 1 else out_list 66 | -------------------------------------------------------------------------------- /MobileSAMv2/efficientvit/sam_model_zoo.py: -------------------------------------------------------------------------------- 1 | # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction 2 | # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han 3 | # International Conference on Computer Vision (ICCV), 2023 4 | 5 | from efficientvit.models.efficientvit import ( 6 | EfficientViTSam, 7 | efficientvit_sam_l0, 8 | efficientvit_sam_l1, 9 | efficientvit_sam_l2, 10 | ) 11 | from efficientvit.models.nn.norm import set_norm_eps 12 | from efficientvit.models.utils import load_state_dict_from_file 13 | 14 | __all__ = ["create_sam_model"] 15 | 16 | 17 | REGISTERED_SAM_MODEL: dict[str, str] = { 18 | "l0": "assets/checkpoints/sam/l0.pt", 19 | "l1": "assets/checkpoints/sam/l1.pt", 20 | "l2": "assets/checkpoints/sam/l2.pt", 21 | } 22 | 23 | 24 | def create_sam_model(name: str, pretrained=True, weight_url: str or None = None, **kwargs) -> EfficientViTSam: 25 | model_dict = { 26 | "l0": efficientvit_sam_l0, 27 | "l1": efficientvit_sam_l1, 28 | "l2": efficientvit_sam_l2, 29 | } 30 | 31 | model_id = name.split("-")[0] 32 | if model_id not in model_dict: 33 | raise ValueError(f"Do 
not find {name} in the model zoo. List of models: {list(model_dict.keys())}") 34 | else: 35 | model = model_dict[model_id](**kwargs) 36 | set_norm_eps(model, 1e-6) 37 | 38 | if pretrained: 39 | weight_url = weight_url or REGISTERED_SAM_MODEL.get(name, None) 40 | if weight_url is None: 41 | raise ValueError(f"Do not find the pretrained weight of {name}.") 42 | else: 43 | weight = load_state_dict_from_file(weight_url) 44 | model.load_state_dict(weight) 45 | return model 46 | -------------------------------------------------------------------------------- /MobileSAMv2/efficientvit/seg_model_zoo.py: -------------------------------------------------------------------------------- 1 | # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction 2 | # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han 3 | # International Conference on Computer Vision (ICCV), 2023 4 | 5 | from efficientvit.models.efficientvit import ( 6 | EfficientViTSeg, 7 | efficientvit_seg_b0, 8 | efficientvit_seg_b1, 9 | efficientvit_seg_b2, 10 | efficientvit_seg_b3, 11 | efficientvit_seg_l1, 12 | efficientvit_seg_l2, 13 | ) 14 | from efficientvit.models.nn.norm import set_norm_eps 15 | from efficientvit.models.utils import load_state_dict_from_file 16 | 17 | __all__ = ["create_seg_model"] 18 | 19 | 20 | REGISTERED_SEG_MODEL: dict[str, dict[str, str]] = { 21 | "cityscapes": { 22 | "b0": "assets/checkpoints/seg/cityscapes/b0.pt", 23 | "b1": "assets/checkpoints/seg/cityscapes/b1.pt", 24 | "b2": "assets/checkpoints/seg/cityscapes/b2.pt", 25 | "b3": "assets/checkpoints/seg/cityscapes/b3.pt", 26 | ################################################ 27 | "l1": "assets/checkpoints/seg/cityscapes/l1.pt", 28 | "l2": "assets/checkpoints/seg/cityscapes/l2.pt", 29 | }, 30 | "ade20k": { 31 | "b1": "assets/checkpoints/seg/ade20k/b1.pt", 32 | "b2": "assets/checkpoints/seg/ade20k/b2.pt", 33 | "b3": "assets/checkpoints/seg/ade20k/b3.pt", 34 | ################################################ 35 | "l1": "assets/checkpoints/seg/ade20k/l1.pt", 36 | "l2": "assets/checkpoints/seg/ade20k/l2.pt", 37 | }, 38 | } 39 | 40 | 41 | def create_seg_model( 42 | name: str, dataset: str, pretrained=True, weight_url: str or None = None, **kwargs 43 | ) -> EfficientViTSeg: 44 | model_dict = { 45 | "b0": efficientvit_seg_b0, 46 | "b1": efficientvit_seg_b1, 47 | "b2": efficientvit_seg_b2, 48 | "b3": efficientvit_seg_b3, 49 | ######################### 50 | "l1": efficientvit_seg_l1, 51 | "l2": efficientvit_seg_l2, 52 | } 53 | 54 | model_id = name.split("-")[0] 55 | if model_id not in model_dict: 56 | raise ValueError(f"Do not find {name} in the model zoo. 
List of models: {list(model_dict.keys())}") 57 | else: 58 | model = model_dict[model_id](dataset=dataset, **kwargs) 59 | 60 | if model_id in ["l1", "l2"]: 61 | set_norm_eps(model, 1e-7) 62 | 63 | if pretrained: 64 | weight_url = weight_url or REGISTERED_SEG_MODEL[dataset].get(name, None) 65 | if weight_url is None: 66 | raise ValueError(f"Do not find the pretrained weight of {name}.") 67 | else: 68 | weight = load_state_dict_from_file(weight_url) 69 | model.load_state_dict(weight) 70 | return model 71 | -------------------------------------------------------------------------------- /MobileSAMv2/experiments/mobilesamv2.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python Inference.py \ 2 | --img_path './test_images/' \ 3 | --output_dir './' \ 4 | --encoder_type 'efficientvit_l2' \ -------------------------------------------------------------------------------- /MobileSAMv2/mobilesamv2/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .build_sam import ( 8 | build_sam_vit_h, 9 | build_sam_vit_l, 10 | build_sam_vit_b, 11 | sam_model_registry, 12 | ) 13 | from .predictor import SamPredictor 14 | from .automatic_mask_generator import SamAutomaticMaskGenerator 15 | 16 | -------------------------------------------------------------------------------- /MobileSAMv2/mobilesamv2/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .sam import Sam 8 | from .image_encoder import ImageEncoderViT 9 | from .mask_decoder import MaskDecoder 10 | from .prompt_encoder import PromptEncoder 11 | from .transformer import TwoWayTransformer 12 | -------------------------------------------------------------------------------- /MobileSAMv2/mobilesamv2/modeling/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | from typing import Type 11 | 12 | 13 | class MLPBlock(nn.Module): 14 | def __init__( 15 | self, 16 | embedding_dim: int, 17 | mlp_dim: int, 18 | act: Type[nn.Module] = nn.GELU, 19 | ) -> None: 20 | super().__init__() 21 | self.lin1 = nn.Linear(embedding_dim, mlp_dim) 22 | self.lin2 = nn.Linear(mlp_dim, embedding_dim) 23 | self.act = act() 24 | 25 | def forward(self, x: torch.Tensor) -> torch.Tensor: 26 | return self.lin2(self.act(self.lin1(x))) 27 | 28 | 29 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa 30 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa 31 | class LayerNorm2d(nn.Module): 32 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None: 33 | super().__init__() 34 | self.weight = nn.Parameter(torch.ones(num_channels)) 35 | self.bias = nn.Parameter(torch.zeros(num_channels)) 36 | self.eps = eps 37 | 38 | def forward(self, x: torch.Tensor) -> torch.Tensor: 39 | u = x.mean(1, keepdim=True) 40 | s = (x - u).pow(2).mean(1, keepdim=True) 41 | x = (x - u) / torch.sqrt(s + self.eps) 42 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 43 | return x 44 | -------------------------------------------------------------------------------- /MobileSAMv2/mobilesamv2/promt_mobilesamv2/__init__.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | from .model import ObjectAwareModel 4 | from .predict import PromptModelPredictor 5 | 6 | __all__ = 'ObjectAwareModel', 'PromptModelPredictor' 7 | -------------------------------------------------------------------------------- /MobileSAMv2/mobilesamv2/promt_mobilesamv2/model.py: -------------------------------------------------------------------------------- 1 | from ultralytics.yolo.cfg import get_cfg 2 | from ultralytics.yolo.engine.exporter import Exporter 3 | from ultralytics.yolo.engine.model import YOLO 4 | from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, ROOT, is_git_dir 5 | 6 | from ultralytics.yolo.utils.torch_utils import model_info, smart_inference_mode 7 | from .predict import PromptModelPredictor 8 | 9 | 10 | class ObjectAwareModel(YOLO): 11 | 12 | @smart_inference_mode() 13 | def predict(self, source=None, stream=False, **kwargs): 14 | """ 15 | Perform prediction using the YOLO model. 16 | 17 | Args: 18 | source (str | int | PIL | np.ndarray): The source of the image to make predictions on. 19 | Accepts all source types accepted by the YOLO model. 20 | stream (bool): Whether to stream the predictions or not. Defaults to False. 21 | **kwargs : Additional keyword arguments passed to the predictor. 22 | Check the 'configuration' section in the documentation for all available options. 23 | 24 | Returns: 25 | (List[ultralytics.yolo.engine.results.Results]): The prediction results. 26 | """ 27 | if source is None: 28 | source = ROOT / 'assets' if is_git_dir() else 'https://ultralytics.com/images/bus.jpg' 29 | LOGGER.warning(f"WARNING ⚠️ 'source' is missing. 
Using 'source={source}'.") 30 | overrides = self.overrides.copy() 31 | overrides['conf'] = 0.25 32 | overrides.update(kwargs) # prefer kwargs 33 | overrides['mode'] = kwargs.get('mode', 'predict') 34 | assert overrides['mode'] in ['track', 'predict'] 35 | overrides['save'] = kwargs.get('save', False) # do not save by default if called in Python 36 | self.predictor = PromptModelPredictor(overrides=overrides) 37 | 38 | self.predictor.setup_model(model=self.model, verbose=False) 39 | 40 | try: 41 | 42 | return self.predictor(source, stream=stream) 43 | except Exception as e: 44 | return None 45 | 46 | def train(self, **kwargs): 47 | raise NotImplementedError("Currently, the training codes are on the way.") 48 | 49 | @smart_inference_mode() 50 | def export(self, **kwargs): 51 | """ 52 | Export model. 53 | 54 | Args: 55 | **kwargs : Any other args accepted by the predictors. To see all args check 'configuration' section in docs 56 | """ 57 | overrides = dict(task='detect') 58 | overrides.update(kwargs) 59 | overrides['mode'] = 'export' 60 | args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides) 61 | args.task = self.task 62 | if args.imgsz == DEFAULT_CFG.imgsz: 63 | args.imgsz = self.model.args['imgsz'] # use trained imgsz unless custom value is passed 64 | if args.batch == DEFAULT_CFG.batch: 65 | args.batch = 1 # default to 1 if not modified 66 | return Exporter(overrides=args)(model=self.model) 67 | 68 | def info(self, detailed=False, verbose=True): 69 | """ 70 | Logs model info. 71 | 72 | Args: 73 | detailed (bool): Show detailed information about model. 74 | verbose (bool): Controls verbosity. 75 | """ 76 | return model_info(self.model, detailed=detailed, verbose=verbose, imgsz=640) 77 | 78 | def __call__(self, source=None, stream=False, **kwargs): 79 | """Calls the 'predict' function with given arguments to perform object detection.""" 80 | return self.predict(source, stream, **kwargs) 81 | 82 | def __getattr__(self, attr): 83 | """Raises error if object has no requested attribute.""" 84 | name = self.__class__.__name__ 85 | raise AttributeError(f"'{name}' object has no attribute '{attr}'. 
See valid attributes below.\n{self.__doc__}") 86 | -------------------------------------------------------------------------------- /MobileSAMv2/mobilesamv2/promt_mobilesamv2/predict.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from ultralytics.yolo.engine.results import Results 3 | from ultralytics.yolo.utils import DEFAULT_CFG, ops 4 | from ultralytics.yolo.v8.detect.predict import DetectionPredictor 5 | 6 | class PromptModelPredictor(DetectionPredictor): 7 | 8 | def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): 9 | super().__init__(cfg, overrides, _callbacks) 10 | self.args.task = 'segment' 11 | def adjust_bboxes_to_image_border(self, boxes, image_shape, threshold=20): 12 | h, w = image_shape 13 | boxes[:, 0] = torch.where(boxes[:, 0] < threshold, torch.tensor( 14 | 0, dtype=torch.float, device=boxes.device), boxes[:, 0]) # x1 15 | boxes[:, 1] = torch.where(boxes[:, 1] < threshold, torch.tensor( 16 | 0, dtype=torch.float, device=boxes.device), boxes[:, 1]) # y1 17 | boxes[:, 2] = torch.where(boxes[:, 2] > w - threshold, torch.tensor( 18 | w, dtype=torch.float, device=boxes.device), boxes[:, 2]) # x2 19 | boxes[:, 3] = torch.where(boxes[:, 3] > h - threshold, torch.tensor( 20 | h, dtype=torch.float, device=boxes.device), boxes[:, 3]) # y2 21 | return boxes 22 | def postprocess(self, preds, img, orig_imgs): 23 | p = ops.non_max_suppression(preds[0], 24 | self.args.conf, 25 | self.args.iou, 26 | agnostic=self.args.agnostic_nms, 27 | max_det=self.args.max_det, 28 | nc=len(self.model.names), 29 | classes=self.args.classes) 30 | results = [] 31 | if len(p) == 0 or len(p[0]) == 0: 32 | print("No object detected.") 33 | return results 34 | full_box = torch.zeros_like(p[0][0]) 35 | full_box[2], full_box[3], full_box[4], full_box[6:] = img.shape[3], img.shape[2], 1.0, 1.0 36 | full_box = full_box.view(1, -1) 37 | self.adjust_bboxes_to_image_border(p[0][:, :4], img.shape[2:]) 38 | for i, pred in enumerate(p): 39 | orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs 40 | path = self.batch[0] 41 | img_path = path[i] if isinstance(path, list) else path 42 | if not len(pred): 43 | results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6])) 44 | continue 45 | if self.args.retina_masks: 46 | if not isinstance(orig_imgs, torch.Tensor): 47 | pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape) 48 | else: 49 | if not isinstance(orig_imgs, torch.Tensor): 50 | pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape) 51 | results.append( 52 | Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=torch.zeros_like(img))) 53 | return results 54 | -------------------------------------------------------------------------------- /MobileSAMv2/mobilesamv2/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | -------------------------------------------------------------------------------- /MobileSAMv2/test_images/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChaoningZhang/MobileSAM/34bbbfdface3c18e5221aa7de6032d7220c6c6a1/MobileSAMv2/test_images/1.jpg -------------------------------------------------------------------------------- /MobileSAMv2/test_images/2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChaoningZhang/MobileSAM/34bbbfdface3c18e5221aa7de6032d7220c6c6a1/MobileSAMv2/test_images/2.jpg -------------------------------------------------------------------------------- /MobileSAMv2/ultralytics/__init__.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | __version__ = '8.0.120' 4 | 5 | from ultralytics.hub import start 6 | from ultralytics.vit.rtdetr import RTDETR 7 | from ultralytics.vit.sam import SAM 8 | from ultralytics.yolo.engine.model import YOLO 9 | from ultralytics.yolo.nas import NAS 10 | from ultralytics.yolo.utils.checks import check_yolo as checks 11 | 12 | __all__ = '__version__', 'YOLO', 'NAS', 'SAM', 'RTDETR', 'checks', 'start' # allow simpler import 13 | -------------------------------------------------------------------------------- /MobileSAMv2/ultralytics/assets/bus.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChaoningZhang/MobileSAM/34bbbfdface3c18e5221aa7de6032d7220c6c6a1/MobileSAMv2/ultralytics/assets/bus.jpg -------------------------------------------------------------------------------- /MobileSAMv2/ultralytics/assets/zidane.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChaoningZhang/MobileSAM/34bbbfdface3c18e5221aa7de6032d7220c6c6a1/MobileSAMv2/ultralytics/assets/zidane.jpg -------------------------------------------------------------------------------- /MobileSAMv2/ultralytics/models/README.md: -------------------------------------------------------------------------------- 1 | ## Models 2 | 3 | Welcome to the Ultralytics Models directory! Here you will find a wide variety of pre-configured model configuration 4 | files (`*.yaml`s) that can be used to create custom YOLO models. The models in this directory have been expertly crafted 5 | and fine-tuned by the Ultralytics team to provide the best performance for a wide range of object detection and image 6 | segmentation tasks. 7 | 8 | These model configurations cover a wide range of scenarios, from simple object detection to more complex tasks like 9 | instance segmentation and object tracking. They are also designed to run efficiently on a variety of hardware platforms, 10 | from CPUs to GPUs. Whether you are a seasoned machine learning practitioner or just getting started with YOLO, this 11 | directory provides a great starting point for your custom model development needs. 12 | 13 | To get started, simply browse through the models in this directory and find one that best suits your needs. Once you've 14 | selected a model, you can use the provided `*.yaml` file to train and deploy your custom YOLO model with ease. See full 15 | details at the Ultralytics [Docs](https://docs.ultralytics.com/models), and if you need help or have any questions, feel free 16 | to reach out to the Ultralytics team for support. 
So, don't wait, start creating your custom YOLO model now! 17 | 18 | ### Usage 19 | 20 | Model `*.yaml` files may be used directly in the Command Line Interface (CLI) with a `yolo` command: 21 | 22 | ```bash 23 | yolo task=detect mode=train model=yolov8n.yaml data=coco128.yaml epochs=100 24 | ``` 25 | 26 | They may also be used directly in a Python environment, and accept the same 27 | [arguments](https://docs.ultralytics.com/usage/cfg/) as in the CLI example above: 28 | 29 | ```python 30 | from ultralytics import YOLO 31 | 32 | model = YOLO("model.yaml") # build a YOLOv8n model from scratch 33 | # YOLO("model.pt") use pre-trained model if available 34 | model.info() # display model information 35 | model.train(data="coco128.yaml", epochs=100) # train the model 36 | ``` 37 | 38 | ## Pre-trained Model Architectures 39 | 40 | Ultralytics supports many model architectures. Visit https://docs.ultralytics.com/models to view detailed information 41 | and usage. Any of these models can be used by loading their configs or pretrained checkpoints if available. 42 | 43 | ## Contributing New Models 44 | 45 | If you've developed a new model architecture or have improvements for existing models that you'd like to contribute to the Ultralytics community, please submit your contribution in a new Pull Request. For more details, visit our [Contributing Guide](https://docs.ultralytics.com/help/contributing). 46 | -------------------------------------------------------------------------------- /MobileSAMv2/ultralytics/models/rt-detr/rtdetr-l.yaml: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | # RT-DETR-l object detection model with P3-P5 outputs. For details see https://docs.ultralytics.com/models/rtdetr 3 | 4 | # Parameters 5 | nc: 80 # number of classes 6 | scales: # model compound scaling constants, i.e. 
'model=yolov8n-cls.yaml' will call yolov8-cls.yaml with scale 'n' 7 | # [depth, width, max_channels] 8 | l: [1.00, 1.00, 1024] 9 | 10 | backbone: 11 | # [from, repeats, module, args] 12 | - [-1, 1, HGStem, [32, 48]] # 0-P2/4 13 | - [-1, 6, HGBlock, [48, 128, 3]] # stage 1 14 | 15 | - [-1, 1, DWConv, [128, 3, 2, 1, False]] # 2-P3/8 16 | - [-1, 6, HGBlock, [96, 512, 3]] # stage 2 17 | 18 | - [-1, 1, DWConv, [512, 3, 2, 1, False]] # 4-P3/16 19 | - [-1, 6, HGBlock, [192, 1024, 5, True, False]] # cm, c2, k, light, shortcut 20 | - [-1, 6, HGBlock, [192, 1024, 5, True, True]] 21 | - [-1, 6, HGBlock, [192, 1024, 5, True, True]] # stage 3 22 | 23 | - [-1, 1, DWConv, [1024, 3, 2, 1, False]] # 8-P4/32 24 | - [-1, 6, HGBlock, [384, 2048, 5, True, False]] # stage 4 25 | 26 | head: 27 | - [-1, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 10 input_proj.2 28 | - [-1, 1, AIFI, [1024, 8]] 29 | - [-1, 1, Conv, [256, 1, 1]] # 12, Y5, lateral_convs.0 30 | 31 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']] 32 | - [7, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 14 input_proj.1 33 | - [[-2, -1], 1, Concat, [1]] 34 | - [-1, 3, RepC3, [256]] # 16, fpn_blocks.0 35 | - [-1, 1, Conv, [256, 1, 1]] # 17, Y4, lateral_convs.1 36 | 37 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']] 38 | - [3, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 19 input_proj.0 39 | - [[-2, -1], 1, Concat, [1]] # cat backbone P4 40 | - [-1, 3, RepC3, [256]] # X3 (21), fpn_blocks.1 41 | 42 | - [-1, 1, Conv, [256, 3, 2]] # 22, downsample_convs.0 43 | - [[-1, 17], 1, Concat, [1]] # cat Y4 44 | - [-1, 3, RepC3, [256]] # F4 (24), pan_blocks.0 45 | 46 | - [-1, 1, Conv, [256, 3, 2]] # 25, downsample_convs.1 47 | - [[-1, 12], 1, Concat, [1]] # cat Y5 48 | - [-1, 3, RepC3, [256]] # F5 (27), pan_blocks.1 49 | 50 | - [[21, 24, 27], 1, RTDETRDecoder, [nc]] # Detect(P3, P4, P5) 51 | -------------------------------------------------------------------------------- /MobileSAMv2/ultralytics/models/rt-detr/rtdetr-x.yaml: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | # RT-DETR-x object detection model with P3-P5 outputs. For details see https://docs.ultralytics.com/models/rtdetr 3 | 4 | # Parameters 5 | nc: 80 # number of classes 6 | scales: # model compound scaling constants, i.e. 
'model=yolov8n-cls.yaml' will call yolov8-cls.yaml with scale 'n' 7 | # [depth, width, max_channels] 8 | x: [1.00, 1.00, 2048] 9 | 10 | backbone: 11 | # [from, repeats, module, args] 12 | - [-1, 1, HGStem, [32, 64]] # 0-P2/4 13 | - [-1, 6, HGBlock, [64, 128, 3]] # stage 1 14 | 15 | - [-1, 1, DWConv, [128, 3, 2, 1, False]] # 2-P3/8 16 | - [-1, 6, HGBlock, [128, 512, 3]] 17 | - [-1, 6, HGBlock, [128, 512, 3, False, True]] # 4-stage 2 18 | 19 | - [-1, 1, DWConv, [512, 3, 2, 1, False]] # 5-P3/16 20 | - [-1, 6, HGBlock, [256, 1024, 5, True, False]] # cm, c2, k, light, shortcut 21 | - [-1, 6, HGBlock, [256, 1024, 5, True, True]] 22 | - [-1, 6, HGBlock, [256, 1024, 5, True, True]] 23 | - [-1, 6, HGBlock, [256, 1024, 5, True, True]] 24 | - [-1, 6, HGBlock, [256, 1024, 5, True, True]] # 10-stage 3 25 | 26 | - [-1, 1, DWConv, [1024, 3, 2, 1, False]] # 11-P4/32 27 | - [-1, 6, HGBlock, [512, 2048, 5, True, False]] 28 | - [-1, 6, HGBlock, [512, 2048, 5, True, True]] # 13-stage 4 29 | 30 | head: 31 | - [-1, 1, Conv, [384, 1, 1, None, 1, 1, False]] # 14 input_proj.2 32 | - [-1, 1, AIFI, [2048, 8]] 33 | - [-1, 1, Conv, [384, 1, 1]] # 16, Y5, lateral_convs.0 34 | 35 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']] 36 | - [10, 1, Conv, [384, 1, 1, None, 1, 1, False]] # 18 input_proj.1 37 | - [[-2, -1], 1, Concat, [1]] 38 | - [-1, 3, RepC3, [384]] # 20, fpn_blocks.0 39 | - [-1, 1, Conv, [384, 1, 1]] # 21, Y4, lateral_convs.1 40 | 41 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']] 42 | - [4, 1, Conv, [384, 1, 1, None, 1, 1, False]] # 23 input_proj.0 43 | - [[-2, -1], 1, Concat, [1]] # cat backbone P4 44 | - [-1, 3, RepC3, [384]] # X3 (25), fpn_blocks.1 45 | 46 | - [-1, 1, Conv, [384, 3, 2]] # 26, downsample_convs.0 47 | - [[-1, 21], 1, Concat, [1]] # cat Y4 48 | - [-1, 3, RepC3, [384]] # F4 (28), pan_blocks.0 49 | 50 | - [-1, 1, Conv, [384, 3, 2]] # 29, downsample_convs.1 51 | - [[-1, 16], 1, Concat, [1]] # cat Y5 52 | - [-1, 3, RepC3, [384]] # F5 (31), pan_blocks.1 53 | 54 | - [[25, 28, 31], 1, RTDETRDecoder, [nc]] # Detect(P3, P4, P5) 55 | -------------------------------------------------------------------------------- /MobileSAMv2/ultralytics/models/v3/yolov3-spp.yaml: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | # YOLOv3-SPP object detection model with P3-P5 outputs. 
For details see https://docs.ultralytics.com/models/yolov3 3 | 4 | # Parameters 5 | nc: 80 # number of classes 6 | depth_multiple: 1.0 # model depth multiple 7 | width_multiple: 1.0 # layer channel multiple 8 | 9 | # darknet53 backbone 10 | backbone: 11 | # [from, number, module, args] 12 | [[-1, 1, Conv, [32, 3, 1]], # 0 13 | [-1, 1, Conv, [64, 3, 2]], # 1-P1/2 14 | [-1, 1, Bottleneck, [64]], 15 | [-1, 1, Conv, [128, 3, 2]], # 3-P2/4 16 | [-1, 2, Bottleneck, [128]], 17 | [-1, 1, Conv, [256, 3, 2]], # 5-P3/8 18 | [-1, 8, Bottleneck, [256]], 19 | [-1, 1, Conv, [512, 3, 2]], # 7-P4/16 20 | [-1, 8, Bottleneck, [512]], 21 | [-1, 1, Conv, [1024, 3, 2]], # 9-P5/32 22 | [-1, 4, Bottleneck, [1024]], # 10 23 | ] 24 | 25 | # YOLOv3-SPP head 26 | head: 27 | [[-1, 1, Bottleneck, [1024, False]], 28 | [-1, 1, SPP, [512, [5, 9, 13]]], 29 | [-1, 1, Conv, [1024, 3, 1]], 30 | [-1, 1, Conv, [512, 1, 1]], 31 | [-1, 1, Conv, [1024, 3, 1]], # 15 (P5/32-large) 32 | 33 | [-2, 1, Conv, [256, 1, 1]], 34 | [-1, 1, nn.Upsample, [None, 2, 'nearest']], 35 | [[-1, 8], 1, Concat, [1]], # cat backbone P4 36 | [-1, 1, Bottleneck, [512, False]], 37 | [-1, 1, Bottleneck, [512, False]], 38 | [-1, 1, Conv, [256, 1, 1]], 39 | [-1, 1, Conv, [512, 3, 1]], # 22 (P4/16-medium) 40 | 41 | [-2, 1, Conv, [128, 1, 1]], 42 | [-1, 1, nn.Upsample, [None, 2, 'nearest']], 43 | [[-1, 6], 1, Concat, [1]], # cat backbone P3 44 | [-1, 1, Bottleneck, [256, False]], 45 | [-1, 2, Bottleneck, [256, False]], # 27 (P3/8-small) 46 | 47 | [[27, 22, 15], 1, Detect, [nc]], # Detect(P3, P4, P5) 48 | ] 49 | -------------------------------------------------------------------------------- /MobileSAMv2/ultralytics/models/v3/yolov3-tiny.yaml: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | # YOLOv3-tiny object detection model with P4-P5 outputs. For details see https://docs.ultralytics.com/models/yolov3 3 | 4 | # Parameters 5 | nc: 80 # number of classes 6 | depth_multiple: 1.0 # model depth multiple 7 | width_multiple: 1.0 # layer channel multiple 8 | 9 | # YOLOv3-tiny backbone 10 | backbone: 11 | # [from, number, module, args] 12 | [[-1, 1, Conv, [16, 3, 1]], # 0 13 | [-1, 1, nn.MaxPool2d, [2, 2, 0]], # 1-P1/2 14 | [-1, 1, Conv, [32, 3, 1]], 15 | [-1, 1, nn.MaxPool2d, [2, 2, 0]], # 3-P2/4 16 | [-1, 1, Conv, [64, 3, 1]], 17 | [-1, 1, nn.MaxPool2d, [2, 2, 0]], # 5-P3/8 18 | [-1, 1, Conv, [128, 3, 1]], 19 | [-1, 1, nn.MaxPool2d, [2, 2, 0]], # 7-P4/16 20 | [-1, 1, Conv, [256, 3, 1]], 21 | [-1, 1, nn.MaxPool2d, [2, 2, 0]], # 9-P5/32 22 | [-1, 1, Conv, [512, 3, 1]], 23 | [-1, 1, nn.ZeroPad2d, [[0, 1, 0, 1]]], # 11 24 | [-1, 1, nn.MaxPool2d, [2, 1, 0]], # 12 25 | ] 26 | 27 | # YOLOv3-tiny head 28 | head: 29 | [[-1, 1, Conv, [1024, 3, 1]], 30 | [-1, 1, Conv, [256, 1, 1]], 31 | [-1, 1, Conv, [512, 3, 1]], # 15 (P5/32-large) 32 | 33 | [-2, 1, Conv, [128, 1, 1]], 34 | [-1, 1, nn.Upsample, [None, 2, 'nearest']], 35 | [[-1, 8], 1, Concat, [1]], # cat backbone P4 36 | [-1, 1, Conv, [256, 3, 1]], # 19 (P4/16-medium) 37 | 38 | [[19, 15], 1, Detect, [nc]], # Detect(P4, P5) 39 | ] 40 | -------------------------------------------------------------------------------- /MobileSAMv2/ultralytics/models/v3/yolov3.yaml: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | # YOLOv3 object detection model with P3-P5 outputs. 
For details see https://docs.ultralytics.com/models/yolov3 3 | 4 | # Parameters 5 | nc: 80 # number of classes 6 | depth_multiple: 1.0 # model depth multiple 7 | width_multiple: 1.0 # layer channel multiple 8 | 9 | # darknet53 backbone 10 | backbone: 11 | # [from, number, module, args] 12 | [[-1, 1, Conv, [32, 3, 1]], # 0 13 | [-1, 1, Conv, [64, 3, 2]], # 1-P1/2 14 | [-1, 1, Bottleneck, [64]], 15 | [-1, 1, Conv, [128, 3, 2]], # 3-P2/4 16 | [-1, 2, Bottleneck, [128]], 17 | [-1, 1, Conv, [256, 3, 2]], # 5-P3/8 18 | [-1, 8, Bottleneck, [256]], 19 | [-1, 1, Conv, [512, 3, 2]], # 7-P4/16 20 | [-1, 8, Bottleneck, [512]], 21 | [-1, 1, Conv, [1024, 3, 2]], # 9-P5/32 22 | [-1, 4, Bottleneck, [1024]], # 10 23 | ] 24 | 25 | # YOLOv3 head 26 | head: 27 | [[-1, 1, Bottleneck, [1024, False]], 28 | [-1, 1, Conv, [512, 1, 1]], 29 | [-1, 1, Conv, [1024, 3, 1]], 30 | [-1, 1, Conv, [512, 1, 1]], 31 | [-1, 1, Conv, [1024, 3, 1]], # 15 (P5/32-large) 32 | 33 | [-2, 1, Conv, [256, 1, 1]], 34 | [-1, 1, nn.Upsample, [None, 2, 'nearest']], 35 | [[-1, 8], 1, Concat, [1]], # cat backbone P4 36 | [-1, 1, Bottleneck, [512, False]], 37 | [-1, 1, Bottleneck, [512, False]], 38 | [-1, 1, Conv, [256, 1, 1]], 39 | [-1, 1, Conv, [512, 3, 1]], # 22 (P4/16-medium) 40 | 41 | [-2, 1, Conv, [128, 1, 1]], 42 | [-1, 1, nn.Upsample, [None, 2, 'nearest']], 43 | [[-1, 6], 1, Concat, [1]], # cat backbone P3 44 | [-1, 1, Bottleneck, [256, False]], 45 | [-1, 2, Bottleneck, [256, False]], # 27 (P3/8-small) 46 | 47 | [[27, 22, 15], 1, Detect, [nc]], # Detect(P3, P4, P5) 48 | ] 49 | -------------------------------------------------------------------------------- /MobileSAMv2/ultralytics/models/v5/yolov5-p6.yaml: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | # YOLOv5 object detection model with P3-P6 outputs. For details see https://docs.ultralytics.com/models/yolov5 3 | 4 | # Parameters 5 | nc: 80 # number of classes 6 | scales: # model compound scaling constants, i.e. 
'model=yolov5n-p6.yaml' will call yolov5-p6.yaml with scale 'n' 7 | # [depth, width, max_channels] 8 | n: [0.33, 0.25, 1024] 9 | s: [0.33, 0.50, 1024] 10 | m: [0.67, 0.75, 1024] 11 | l: [1.00, 1.00, 1024] 12 | x: [1.33, 1.25, 1024] 13 | 14 | # YOLOv5 v6.0 backbone 15 | backbone: 16 | # [from, number, module, args] 17 | [[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2 18 | [-1, 1, Conv, [128, 3, 2]], # 1-P2/4 19 | [-1, 3, C3, [128]], 20 | [-1, 1, Conv, [256, 3, 2]], # 3-P3/8 21 | [-1, 6, C3, [256]], 22 | [-1, 1, Conv, [512, 3, 2]], # 5-P4/16 23 | [-1, 9, C3, [512]], 24 | [-1, 1, Conv, [768, 3, 2]], # 7-P5/32 25 | [-1, 3, C3, [768]], 26 | [-1, 1, Conv, [1024, 3, 2]], # 9-P6/64 27 | [-1, 3, C3, [1024]], 28 | [-1, 1, SPPF, [1024, 5]], # 11 29 | ] 30 | 31 | # YOLOv5 v6.0 head 32 | head: 33 | [[-1, 1, Conv, [768, 1, 1]], 34 | [-1, 1, nn.Upsample, [None, 2, 'nearest']], 35 | [[-1, 8], 1, Concat, [1]], # cat backbone P5 36 | [-1, 3, C3, [768, False]], # 15 37 | 38 | [-1, 1, Conv, [512, 1, 1]], 39 | [-1, 1, nn.Upsample, [None, 2, 'nearest']], 40 | [[-1, 6], 1, Concat, [1]], # cat backbone P4 41 | [-1, 3, C3, [512, False]], # 19 42 | 43 | [-1, 1, Conv, [256, 1, 1]], 44 | [-1, 1, nn.Upsample, [None, 2, 'nearest']], 45 | [[-1, 4], 1, Concat, [1]], # cat backbone P3 46 | [-1, 3, C3, [256, False]], # 23 (P3/8-small) 47 | 48 | [-1, 1, Conv, [256, 3, 2]], 49 | [[-1, 20], 1, Concat, [1]], # cat head P4 50 | [-1, 3, C3, [512, False]], # 26 (P4/16-medium) 51 | 52 | [-1, 1, Conv, [512, 3, 2]], 53 | [[-1, 16], 1, Concat, [1]], # cat head P5 54 | [-1, 3, C3, [768, False]], # 29 (P5/32-large) 55 | 56 | [-1, 1, Conv, [768, 3, 2]], 57 | [[-1, 12], 1, Concat, [1]], # cat head P6 58 | [-1, 3, C3, [1024, False]], # 32 (P6/64-xlarge) 59 | 60 | [[23, 26, 29, 32], 1, Detect, [nc]], # Detect(P3, P4, P5, P6) 61 | ] 62 | -------------------------------------------------------------------------------- /MobileSAMv2/ultralytics/models/v5/yolov5.yaml: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | # YOLOv5 object detection model with P3-P5 outputs. For details see https://docs.ultralytics.com/models/yolov5 3 | 4 | # Parameters 5 | nc: 80 # number of classes 6 | scales: # model compound scaling constants, i.e. 
'model=yolov5n.yaml' will call yolov5.yaml with scale 'n' 7 | # [depth, width, max_channels] 8 | n: [0.33, 0.25, 1024] 9 | s: [0.33, 0.50, 1024] 10 | m: [0.67, 0.75, 1024] 11 | l: [1.00, 1.00, 1024] 12 | x: [1.33, 1.25, 1024] 13 | 14 | # YOLOv5 v6.0 backbone 15 | backbone: 16 | # [from, number, module, args] 17 | [[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2 18 | [-1, 1, Conv, [128, 3, 2]], # 1-P2/4 19 | [-1, 3, C3, [128]], 20 | [-1, 1, Conv, [256, 3, 2]], # 3-P3/8 21 | [-1, 6, C3, [256]], 22 | [-1, 1, Conv, [512, 3, 2]], # 5-P4/16 23 | [-1, 9, C3, [512]], 24 | [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32 25 | [-1, 3, C3, [1024]], 26 | [-1, 1, SPPF, [1024, 5]], # 9 27 | ] 28 | 29 | # YOLOv5 v6.0 head 30 | head: 31 | [[-1, 1, Conv, [512, 1, 1]], 32 | [-1, 1, nn.Upsample, [None, 2, 'nearest']], 33 | [[-1, 6], 1, Concat, [1]], # cat backbone P4 34 | [-1, 3, C3, [512, False]], # 13 35 | 36 | [-1, 1, Conv, [256, 1, 1]], 37 | [-1, 1, nn.Upsample, [None, 2, 'nearest']], 38 | [[-1, 4], 1, Concat, [1]], # cat backbone P3 39 | [-1, 3, C3, [256, False]], # 17 (P3/8-small) 40 | 41 | [-1, 1, Conv, [256, 3, 2]], 42 | [[-1, 14], 1, Concat, [1]], # cat head P4 43 | [-1, 3, C3, [512, False]], # 20 (P4/16-medium) 44 | 45 | [-1, 1, Conv, [512, 3, 2]], 46 | [[-1, 10], 1, Concat, [1]], # cat head P5 47 | [-1, 3, C3, [1024, False]], # 23 (P5/32-large) 48 | 49 | [[17, 20, 23], 1, Detect, [nc]], # Detect(P3, P4, P5) 50 | ] 51 | -------------------------------------------------------------------------------- /MobileSAMv2/ultralytics/models/v6/yolov6.yaml: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | # YOLOv6 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/models/yolov6 3 | 4 | # Parameters 5 | nc: 80 # number of classes 6 | activation: nn.ReLU() # (optional) model default activation function 7 | scales: # model compound scaling constants, i.e. 
'model=yolov6n.yaml' will call yolov8.yaml with scale 'n' 8 | # [depth, width, max_channels] 9 | n: [0.33, 0.25, 1024] 10 | s: [0.33, 0.50, 1024] 11 | m: [0.67, 0.75, 768] 12 | l: [1.00, 1.00, 512] 13 | x: [1.00, 1.25, 512] 14 | 15 | # YOLOv6-3.0s backbone 16 | backbone: 17 | # [from, repeats, module, args] 18 | - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 19 | - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 20 | - [-1, 6, Conv, [128, 3, 1]] 21 | - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 22 | - [-1, 12, Conv, [256, 3, 1]] 23 | - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 24 | - [-1, 18, Conv, [512, 3, 1]] 25 | - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32 26 | - [-1, 6, Conv, [1024, 3, 1]] 27 | - [-1, 1, SPPF, [1024, 5]] # 9 28 | 29 | # YOLOv6-3.0s head 30 | head: 31 | - [-1, 1, Conv, [256, 1, 1]] 32 | - [-1, 1, nn.ConvTranspose2d, [256, 2, 2, 0]] 33 | - [[-1, 6], 1, Concat, [1]] # cat backbone P4 34 | - [-1, 1, Conv, [256, 3, 1]] 35 | - [-1, 9, Conv, [256, 3, 1]] # 14 36 | 37 | - [-1, 1, Conv, [128, 1, 1]] 38 | - [-1, 1, nn.ConvTranspose2d, [128, 2, 2, 0]] 39 | - [[-1, 4], 1, Concat, [1]] # cat backbone P3 40 | - [-1, 1, Conv, [128, 3, 1]] 41 | - [-1, 9, Conv, [128, 3, 1]] # 19 42 | 43 | - [-1, 1, Conv, [128, 3, 2]] 44 | - [[-1, 15], 1, Concat, [1]] # cat head P4 45 | - [-1, 1, Conv, [256, 3, 1]] 46 | - [-1, 9, Conv, [256, 3, 1]] # 23 47 | 48 | - [-1, 1, Conv, [256, 3, 2]] 49 | - [[-1, 10], 1, Concat, [1]] # cat head P5 50 | - [-1, 1, Conv, [512, 3, 1]] 51 | - [-1, 9, Conv, [512, 3, 1]] # 27 52 | 53 | - [[19, 23, 27], 1, Detect, [nc]] # Detect(P3, P4, P5) 54 | -------------------------------------------------------------------------------- /MobileSAMv2/ultralytics/models/v8/yolov8-cls.yaml: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | # YOLOv8-cls image classification model. For Usage examples see https://docs.ultralytics.com/tasks/classify 3 | 4 | # Parameters 5 | nc: 1000 # number of classes 6 | scales: # model compound scaling constants, i.e. 'model=yolov8n-cls.yaml' will call yolov8-cls.yaml with scale 'n' 7 | # [depth, width, max_channels] 8 | n: [0.33, 0.25, 1024] 9 | s: [0.33, 0.50, 1024] 10 | m: [0.67, 0.75, 1024] 11 | l: [1.00, 1.00, 1024] 12 | x: [1.00, 1.25, 1024] 13 | 14 | # YOLOv8.0n backbone 15 | backbone: 16 | # [from, repeats, module, args] 17 | - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 18 | - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 19 | - [-1, 3, C2f, [128, True]] 20 | - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 21 | - [-1, 6, C2f, [256, True]] 22 | - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 23 | - [-1, 6, C2f, [512, True]] 24 | - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32 25 | - [-1, 3, C2f, [1024, True]] 26 | 27 | # YOLOv8.0n head 28 | head: 29 | - [-1, 1, Classify, [nc]] # Classify 30 | -------------------------------------------------------------------------------- /MobileSAMv2/ultralytics/models/v8/yolov8-p2.yaml: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | # YOLOv8 object detection model with P2-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect 3 | 4 | # Parameters 5 | nc: 80 # number of classes 6 | scales: # model compound scaling constants, i.e. 
'model=yolov8n.yaml' will call yolov8.yaml with scale 'n' 7 | # [depth, width, max_channels] 8 | n: [0.33, 0.25, 1024] 9 | s: [0.33, 0.50, 1024] 10 | m: [0.67, 0.75, 768] 11 | l: [1.00, 1.00, 512] 12 | x: [1.00, 1.25, 512] 13 | 14 | # YOLOv8.0 backbone 15 | backbone: 16 | # [from, repeats, module, args] 17 | - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 18 | - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 19 | - [-1, 3, C2f, [128, True]] 20 | - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 21 | - [-1, 6, C2f, [256, True]] 22 | - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 23 | - [-1, 6, C2f, [512, True]] 24 | - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32 25 | - [-1, 3, C2f, [1024, True]] 26 | - [-1, 1, SPPF, [1024, 5]] # 9 27 | 28 | # YOLOv8.0-p2 head 29 | head: 30 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']] 31 | - [[-1, 6], 1, Concat, [1]] # cat backbone P4 32 | - [-1, 3, C2f, [512]] # 12 33 | 34 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']] 35 | - [[-1, 4], 1, Concat, [1]] # cat backbone P3 36 | - [-1, 3, C2f, [256]] # 15 (P3/8-small) 37 | 38 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']] 39 | - [[-1, 2], 1, Concat, [1]] # cat backbone P2 40 | - [-1, 3, C2f, [128]] # 18 (P2/4-xsmall) 41 | 42 | - [-1, 1, Conv, [128, 3, 2]] 43 | - [[-1, 15], 1, Concat, [1]] # cat head P3 44 | - [-1, 3, C2f, [256]] # 21 (P3/8-small) 45 | 46 | - [-1, 1, Conv, [256, 3, 2]] 47 | - [[-1, 12], 1, Concat, [1]] # cat head P4 48 | - [-1, 3, C2f, [512]] # 24 (P4/16-medium) 49 | 50 | - [-1, 1, Conv, [512, 3, 2]] 51 | - [[-1, 9], 1, Concat, [1]] # cat head P5 52 | - [-1, 3, C2f, [1024]] # 27 (P5/32-large) 53 | 54 | - [[18, 21, 24, 27], 1, Detect, [nc]] # Detect(P2, P3, P4, P5) 55 | -------------------------------------------------------------------------------- /MobileSAMv2/ultralytics/models/v8/yolov8-p6.yaml: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | # YOLOv8 object detection model with P3-P6 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect 3 | 4 | # Parameters 5 | nc: 80 # number of classes 6 | scales: # model compound scaling constants, i.e. 
'model=yolov8n-p6.yaml' will call yolov8-p6.yaml with scale 'n' 7 | # [depth, width, max_channels] 8 | n: [0.33, 0.25, 1024] 9 | s: [0.33, 0.50, 1024] 10 | m: [0.67, 0.75, 768] 11 | l: [1.00, 1.00, 512] 12 | x: [1.00, 1.25, 512] 13 | 14 | # YOLOv8.0x6 backbone 15 | backbone: 16 | # [from, repeats, module, args] 17 | - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 18 | - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 19 | - [-1, 3, C2f, [128, True]] 20 | - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 21 | - [-1, 6, C2f, [256, True]] 22 | - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 23 | - [-1, 6, C2f, [512, True]] 24 | - [-1, 1, Conv, [768, 3, 2]] # 7-P5/32 25 | - [-1, 3, C2f, [768, True]] 26 | - [-1, 1, Conv, [1024, 3, 2]] # 9-P6/64 27 | - [-1, 3, C2f, [1024, True]] 28 | - [-1, 1, SPPF, [1024, 5]] # 11 29 | 30 | # YOLOv8.0x6 head 31 | head: 32 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']] 33 | - [[-1, 8], 1, Concat, [1]] # cat backbone P5 34 | - [-1, 3, C2, [768, False]] # 14 35 | 36 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']] 37 | - [[-1, 6], 1, Concat, [1]] # cat backbone P4 38 | - [-1, 3, C2, [512, False]] # 17 39 | 40 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']] 41 | - [[-1, 4], 1, Concat, [1]] # cat backbone P3 42 | - [-1, 3, C2, [256, False]] # 20 (P3/8-small) 43 | 44 | - [-1, 1, Conv, [256, 3, 2]] 45 | - [[-1, 17], 1, Concat, [1]] # cat head P4 46 | - [-1, 3, C2, [512, False]] # 23 (P4/16-medium) 47 | 48 | - [-1, 1, Conv, [512, 3, 2]] 49 | - [[-1, 14], 1, Concat, [1]] # cat head P5 50 | - [-1, 3, C2, [768, False]] # 26 (P5/32-large) 51 | 52 | - [-1, 1, Conv, [768, 3, 2]] 53 | - [[-1, 11], 1, Concat, [1]] # cat head P6 54 | - [-1, 3, C2, [1024, False]] # 29 (P6/64-xlarge) 55 | 56 | - [[20, 23, 26, 29], 1, Detect, [nc]] # Detect(P3, P4, P5, P6) 57 | -------------------------------------------------------------------------------- /MobileSAMv2/ultralytics/models/v8/yolov8-pose-p6.yaml: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | # YOLOv8-pose keypoints/pose estimation model. For Usage examples see https://docs.ultralytics.com/tasks/pose 3 | 4 | # Parameters 5 | nc: 1 # number of classes 6 | kpt_shape: [17, 3] # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible) 7 | scales: # model compound scaling constants, i.e. 
'model=yolov8n-p6.yaml' will call yolov8-p6.yaml with scale 'n' 8 | # [depth, width, max_channels] 9 | n: [0.33, 0.25, 1024] 10 | s: [0.33, 0.50, 1024] 11 | m: [0.67, 0.75, 768] 12 | l: [1.00, 1.00, 512] 13 | x: [1.00, 1.25, 512] 14 | 15 | # YOLOv8.0x6 backbone 16 | backbone: 17 | # [from, repeats, module, args] 18 | - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 19 | - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 20 | - [-1, 3, C2f, [128, True]] 21 | - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 22 | - [-1, 6, C2f, [256, True]] 23 | - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 24 | - [-1, 6, C2f, [512, True]] 25 | - [-1, 1, Conv, [768, 3, 2]] # 7-P5/32 26 | - [-1, 3, C2f, [768, True]] 27 | - [-1, 1, Conv, [1024, 3, 2]] # 9-P6/64 28 | - [-1, 3, C2f, [1024, True]] 29 | - [-1, 1, SPPF, [1024, 5]] # 11 30 | 31 | # YOLOv8.0x6 head 32 | head: 33 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']] 34 | - [[-1, 8], 1, Concat, [1]] # cat backbone P5 35 | - [-1, 3, C2, [768, False]] # 14 36 | 37 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']] 38 | - [[-1, 6], 1, Concat, [1]] # cat backbone P4 39 | - [-1, 3, C2, [512, False]] # 17 40 | 41 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']] 42 | - [[-1, 4], 1, Concat, [1]] # cat backbone P3 43 | - [-1, 3, C2, [256, False]] # 20 (P3/8-small) 44 | 45 | - [-1, 1, Conv, [256, 3, 2]] 46 | - [[-1, 17], 1, Concat, [1]] # cat head P4 47 | - [-1, 3, C2, [512, False]] # 23 (P4/16-medium) 48 | 49 | - [-1, 1, Conv, [512, 3, 2]] 50 | - [[-1, 14], 1, Concat, [1]] # cat head P5 51 | - [-1, 3, C2, [768, False]] # 26 (P5/32-large) 52 | 53 | - [-1, 1, Conv, [768, 3, 2]] 54 | - [[-1, 11], 1, Concat, [1]] # cat head P6 55 | - [-1, 3, C2, [1024, False]] # 29 (P6/64-xlarge) 56 | 57 | - [[20, 23, 26, 29], 1, Pose, [nc, kpt_shape]] # Pose(P3, P4, P5, P6) 58 | -------------------------------------------------------------------------------- /MobileSAMv2/ultralytics/models/v8/yolov8-pose.yaml: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | # YOLOv8-pose keypoints/pose estimation model. For Usage examples see https://docs.ultralytics.com/tasks/pose 3 | 4 | # Parameters 5 | nc: 1 # number of classes 6 | kpt_shape: [17, 3] # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible) 7 | scales: # model compound scaling constants, i.e. 
'model=yolov8n-pose.yaml' will call yolov8-pose.yaml with scale 'n' 8 | # [depth, width, max_channels] 9 | n: [0.33, 0.25, 1024] 10 | s: [0.33, 0.50, 1024] 11 | m: [0.67, 0.75, 768] 12 | l: [1.00, 1.00, 512] 13 | x: [1.00, 1.25, 512] 14 | 15 | # YOLOv8.0n backbone 16 | backbone: 17 | # [from, repeats, module, args] 18 | - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 19 | - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 20 | - [-1, 3, C2f, [128, True]] 21 | - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 22 | - [-1, 6, C2f, [256, True]] 23 | - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 24 | - [-1, 6, C2f, [512, True]] 25 | - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32 26 | - [-1, 3, C2f, [1024, True]] 27 | - [-1, 1, SPPF, [1024, 5]] # 9 28 | 29 | # YOLOv8.0n head 30 | head: 31 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']] 32 | - [[-1, 6], 1, Concat, [1]] # cat backbone P4 33 | - [-1, 3, C2f, [512]] # 12 34 | 35 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']] 36 | - [[-1, 4], 1, Concat, [1]] # cat backbone P3 37 | - [-1, 3, C2f, [256]] # 15 (P3/8-small) 38 | 39 | - [-1, 1, Conv, [256, 3, 2]] 40 | - [[-1, 12], 1, Concat, [1]] # cat head P4 41 | - [-1, 3, C2f, [512]] # 18 (P4/16-medium) 42 | 43 | - [-1, 1, Conv, [512, 3, 2]] 44 | - [[-1, 9], 1, Concat, [1]] # cat head P5 45 | - [-1, 3, C2f, [1024]] # 21 (P5/32-large) 46 | 47 | - [[15, 18, 21], 1, Pose, [nc, kpt_shape]] # Pose(P3, P4, P5) 48 | -------------------------------------------------------------------------------- /MobileSAMv2/ultralytics/models/v8/yolov8-rtdetr.yaml: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | # YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect 3 | 4 | # Parameters 5 | nc: 80 # number of classes 6 | scales: # model compound scaling constants, i.e. 
'model=yolov8n.yaml' will call yolov8.yaml with scale 'n' 7 | # [depth, width, max_channels] 8 | n: [0.33, 0.25, 1024] # YOLOv8n summary: 225 layers, 3157200 parameters, 3157184 gradients, 8.9 GFLOPs 9 | s: [0.33, 0.50, 1024] # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients, 28.8 GFLOPs 10 | m: [0.67, 0.75, 768] # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients, 79.3 GFLOPs 11 | l: [1.00, 1.00, 512] # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs 12 | x: [1.00, 1.25, 512] # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs 13 | 14 | # YOLOv8.0n backbone 15 | backbone: 16 | # [from, repeats, module, args] 17 | - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 18 | - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 19 | - [-1, 3, C2f, [128, True]] 20 | - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 21 | - [-1, 6, C2f, [256, True]] 22 | - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 23 | - [-1, 6, C2f, [512, True]] 24 | - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32 25 | - [-1, 3, C2f, [1024, True]] 26 | - [-1, 1, SPPF, [1024, 5]] # 9 27 | 28 | # YOLOv8.0n head 29 | head: 30 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']] 31 | - [[-1, 6], 1, Concat, [1]] # cat backbone P4 32 | - [-1, 3, C2f, [512]] # 12 33 | 34 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']] 35 | - [[-1, 4], 1, Concat, [1]] # cat backbone P3 36 | - [-1, 3, C2f, [256]] # 15 (P3/8-small) 37 | 38 | - [-1, 1, Conv, [256, 3, 2]] 39 | - [[-1, 12], 1, Concat, [1]] # cat head P4 40 | - [-1, 3, C2f, [512]] # 18 (P4/16-medium) 41 | 42 | - [-1, 1, Conv, [512, 3, 2]] 43 | - [[-1, 9], 1, Concat, [1]] # cat head P5 44 | - [-1, 3, C2f, [1024]] # 21 (P5/32-large) 45 | 46 | - [[15, 18, 21], 1, RTDETRDecoder, [nc]] # Detect(P3, P4, P5) 47 | -------------------------------------------------------------------------------- /MobileSAMv2/ultralytics/models/v8/yolov8-seg.yaml: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | # YOLOv8-seg instance segmentation model. For Usage examples see https://docs.ultralytics.com/tasks/segment 3 | 4 | # Parameters 5 | nc: 80 # number of classes 6 | scales: # model compound scaling constants, i.e. 
'model=yolov8n-seg.yaml' will call yolov8-seg.yaml with scale 'n' 7 | # [depth, width, max_channels] 8 | n: [0.33, 0.25, 1024] 9 | s: [0.33, 0.50, 1024] 10 | m: [0.67, 0.75, 768] 11 | l: [1.00, 1.00, 512] 12 | x: [1.00, 1.25, 512] 13 | 14 | # YOLOv8.0n backbone 15 | backbone: 16 | # [from, repeats, module, args] 17 | - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 18 | - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 19 | - [-1, 3, C2f, [128, True]] 20 | - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 21 | - [-1, 6, C2f, [256, True]] 22 | - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 23 | - [-1, 6, C2f, [512, True]] 24 | - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32 25 | - [-1, 3, C2f, [1024, True]] 26 | - [-1, 1, SPPF, [1024, 5]] # 9 27 | 28 | # YOLOv8.0n head 29 | head: 30 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']] 31 | - [[-1, 6], 1, Concat, [1]] # cat backbone P4 32 | - [-1, 3, C2f, [512]] # 12 33 | 34 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']] 35 | - [[-1, 4], 1, Concat, [1]] # cat backbone P3 36 | - [-1, 3, C2f, [256]] # 15 (P3/8-small) 37 | 38 | - [-1, 1, Conv, [256, 3, 2]] 39 | - [[-1, 12], 1, Concat, [1]] # cat head P4 40 | - [-1, 3, C2f, [512]] # 18 (P4/16-medium) 41 | 42 | - [-1, 1, Conv, [512, 3, 2]] 43 | - [[-1, 9], 1, Concat, [1]] # cat head P5 44 | - [-1, 3, C2f, [1024]] # 21 (P5/32-large) 45 | 46 | - [[15, 18, 21], 1, Segment, [nc, 32, 256]] # Segment(P3, P4, P5) 47 | -------------------------------------------------------------------------------- /MobileSAMv2/ultralytics/models/v8/yolov8.yaml: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | # YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect 3 | 4 | # Parameters 5 | nc: 80 # number of classes 6 | scales: # model compound scaling constants, i.e. 
'model=yolov8n.yaml' will call yolov8.yaml with scale 'n' 7 | # [depth, width, max_channels] 8 | n: [0.33, 0.25, 1024] # YOLOv8n summary: 225 layers, 3157200 parameters, 3157184 gradients, 8.9 GFLOPs 9 | s: [0.33, 0.50, 1024] # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients, 28.8 GFLOPs 10 | m: [0.67, 0.75, 768] # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients, 79.3 GFLOPs 11 | l: [1.00, 1.00, 512] # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs 12 | x: [1.00, 1.25, 512] # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs 13 | 14 | # YOLOv8.0n backbone 15 | backbone: 16 | # [from, repeats, module, args] 17 | - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 18 | - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 19 | - [-1, 3, C2f, [128, True]] 20 | - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 21 | - [-1, 6, C2f, [256, True]] 22 | - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 23 | - [-1, 6, C2f, [512, True]] 24 | - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32 25 | - [-1, 3, C2f, [1024, True]] 26 | - [-1, 1, SPPF, [1024, 5]] # 9 27 | 28 | # YOLOv8.0n head 29 | head: 30 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']] 31 | - [[-1, 6], 1, Concat, [1]] # cat backbone P4 32 | - [-1, 3, C2f, [512]] # 12 33 | 34 | - [-1, 1, nn.Upsample, [None, 2, 'nearest']] 35 | - [[-1, 4], 1, Concat, [1]] # cat backbone P3 36 | - [-1, 3, C2f, [256]] # 15 (P3/8-small) 37 | 38 | - [-1, 1, Conv, [256, 3, 2]] 39 | - [[-1, 12], 1, Concat, [1]] # cat head P4 40 | - [-1, 3, C2f, [512]] # 18 (P4/16-medium) 41 | 42 | - [-1, 1, Conv, [512, 3, 2]] 43 | - [[-1, 9], 1, Concat, [1]] # cat head P5 44 | - [-1, 3, C2f, [1024]] # 21 (P5/32-large) 45 | 46 | - [[15, 18, 21], 1, Detect, [nc]] # Detect(P3, P4, P5) 47 | -------------------------------------------------------------------------------- /MobileSAMv2/ultralytics/nn/__init__.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | from .tasks import (BaseModel, ClassificationModel, DetectionModel, SegmentationModel, attempt_load_one_weight, 4 | attempt_load_weights, guess_model_scale, guess_model_task, parse_model, torch_safe_load, 5 | yaml_model_load) 6 | 7 | __all__ = ('attempt_load_one_weight', 'attempt_load_weights', 'parse_model', 'yaml_model_load', 'guess_model_task', 8 | 'guess_model_scale', 'torch_safe_load', 'DetectionModel', 'SegmentationModel', 'ClassificationModel', 9 | 'BaseModel') 10 | -------------------------------------------------------------------------------- /MobileSAMv2/ultralytics/nn/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | """ 3 | Ultralytics modules. 
Visualize with: 4 | 5 | from ultralytics.nn.modules import * 6 | import torch 7 | import os 8 | 9 | x = torch.ones(1, 128, 40, 40) 10 | m = Conv(128, 128) 11 | f = f'{m._get_name()}.onnx' 12 | torch.onnx.export(m, x, f) 13 | os.system(f'onnxsim {f} {f} && open {f}') 14 | """ 15 | 16 | from .block import (C1, C2, C3, C3TR, DFL, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x, GhostBottleneck, 17 | HGBlock, HGStem, Proto, RepC3) 18 | from .conv import (CBAM, ChannelAttention, Concat, Conv, Conv2, ConvTranspose, DWConv, DWConvTranspose2d, Focus, 19 | GhostConv, LightConv, RepConv, SpatialAttention) 20 | from .head import Classify, Detect, Pose, RTDETRDecoder, Segment 21 | from .transformer import (AIFI, MLP, DeformableTransformerDecoder, DeformableTransformerDecoderLayer, LayerNorm2d, 22 | MLPBlock, MSDeformAttn, TransformerBlock, TransformerEncoderLayer, TransformerLayer) 23 | 24 | __all__ = ('Conv', 'Conv2', 'LightConv', 'RepConv', 'DWConv', 'DWConvTranspose2d', 'ConvTranspose', 'Focus', 25 | 'GhostConv', 'ChannelAttention', 'SpatialAttention', 'CBAM', 'Concat', 'TransformerLayer', 26 | 'TransformerBlock', 'MLPBlock', 'LayerNorm2d', 'DFL', 'HGBlock', 'HGStem', 'SPP', 'SPPF', 'C1', 'C2', 'C3', 27 | 'C2f', 'C3x', 'C3TR', 'C3Ghost', 'GhostBottleneck', 'Bottleneck', 'BottleneckCSP', 'Proto', 'Detect', 28 | 'Segment', 'Pose', 'Classify', 'TransformerEncoderLayer', 'RepC3', 'RTDETRDecoder', 'AIFI', 29 | 'DeformableTransformerDecoder', 'DeformableTransformerDecoderLayer', 'MSDeformAttn', 'MLP') 30 | -------------------------------------------------------------------------------- /MobileSAMv2/ultralytics/nn/modules/utils.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | """ 3 | Module utils 4 | """ 5 | 6 | import copy 7 | import math 8 | 9 | import numpy as np 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | from torch.nn.init import uniform_ 14 | 15 | __all__ = 'multi_scale_deformable_attn_pytorch', 'inverse_sigmoid' 16 | 17 | 18 | def _get_clones(module, n): 19 | return nn.ModuleList([copy.deepcopy(module) for _ in range(n)]) 20 | 21 | 22 | def bias_init_with_prob(prior_prob=0.01): 23 | """initialize conv/fc bias value according to a given probability value.""" 24 | return float(-np.log((1 - prior_prob) / prior_prob)) # return bias_init 25 | 26 | 27 | def linear_init_(module): 28 | bound = 1 / math.sqrt(module.weight.shape[0]) 29 | uniform_(module.weight, -bound, bound) 30 | if hasattr(module, 'bias') and module.bias is not None: 31 | uniform_(module.bias, -bound, bound) 32 | 33 | 34 | def inverse_sigmoid(x, eps=1e-5): 35 | x = x.clamp(min=0, max=1) 36 | x1 = x.clamp(min=eps) 37 | x2 = (1 - x).clamp(min=eps) 38 | return torch.log(x1 / x2) 39 | 40 | 41 | def multi_scale_deformable_attn_pytorch(value: torch.Tensor, value_spatial_shapes: torch.Tensor, 42 | sampling_locations: torch.Tensor, 43 | attention_weights: torch.Tensor) -> torch.Tensor: 44 | """ 45 | Multi-scale deformable attention. 
46 | https://github.com/IDEA-Research/detrex/blob/main/detrex/layers/multi_scale_deform_attn.py 47 | """ 48 | 49 | bs, _, num_heads, embed_dims = value.shape 50 | _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape 51 | value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1) 52 | sampling_grids = 2 * sampling_locations - 1 53 | sampling_value_list = [] 54 | for level, (H_, W_) in enumerate(value_spatial_shapes): 55 | # bs, H_*W_, num_heads, embed_dims -> 56 | # bs, H_*W_, num_heads*embed_dims -> 57 | # bs, num_heads*embed_dims, H_*W_ -> 58 | # bs*num_heads, embed_dims, H_, W_ 59 | value_l_ = (value_list[level].flatten(2).transpose(1, 2).reshape(bs * num_heads, embed_dims, H_, W_)) 60 | # bs, num_queries, num_heads, num_points, 2 -> 61 | # bs, num_heads, num_queries, num_points, 2 -> 62 | # bs*num_heads, num_queries, num_points, 2 63 | sampling_grid_l_ = sampling_grids[:, :, :, level].transpose(1, 2).flatten(0, 1) 64 | # bs*num_heads, embed_dims, num_queries, num_points 65 | sampling_value_l_ = F.grid_sample(value_l_, 66 | sampling_grid_l_, 67 | mode='bilinear', 68 | padding_mode='zeros', 69 | align_corners=False) 70 | sampling_value_list.append(sampling_value_l_) 71 | # (bs, num_queries, num_heads, num_levels, num_points) -> 72 | # (bs, num_heads, num_queries, num_levels, num_points) -> 73 | # (bs, num_heads, 1, num_queries, num_levels*num_points) 74 | attention_weights = attention_weights.transpose(1, 2).reshape(bs * num_heads, 1, num_queries, 75 | num_levels * num_points) 76 | output = ((torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view( 77 | bs, num_heads * embed_dims, num_queries)) 78 | return output.transpose(1, 2).contiguous() 79 | -------------------------------------------------------------------------------- /MobileSAMv2/ultralytics/tracker/README.md: -------------------------------------------------------------------------------- 1 | # Tracker 2 | 3 | ## Supported Trackers 4 | 5 | - [x] ByteTracker 6 | - [x] BoT-SORT 7 | 8 | ## Usage 9 | 10 | ### python interface: 11 | 12 | You can use the Python interface to track objects using the YOLO model. 13 | 14 | ```python 15 | from ultralytics import YOLO 16 | 17 | model = YOLO("yolov8n.pt") # or a segmentation model .i.e yolov8n-seg.pt 18 | model.track( 19 | source="video/streams", 20 | stream=True, 21 | tracker="botsort.yaml", # or 'bytetrack.yaml' 22 | show=True, 23 | ) 24 | ``` 25 | 26 | You can get the IDs of the tracked objects using the following code: 27 | 28 | ```python 29 | from ultralytics import YOLO 30 | 31 | model = YOLO("yolov8n.pt") 32 | 33 | for result in model.track(source="video.mp4"): 34 | print( 35 | result.boxes.id.cpu().numpy().astype(int) 36 | ) # this will print the IDs of the tracked objects in the frame 37 | ``` 38 | 39 | If you want to use the tracker with a folder of images or when you loop on the video frames, you should use the `persist` parameter to tell the model that these frames are related to each other so the IDs will be fixed for the same objects. Otherwise, the IDs will be different in each frame because in each loop, the model creates a new object for tracking, but the `persist` parameter makes it use the same object for tracking. 
40 |  41 | ```python 42 | import cv2 43 | from ultralytics import YOLO 44 |  45 | cap = cv2.VideoCapture("video.mp4") 46 | model = YOLO("yolov8n.pt") 47 | while True: 48 | ret, frame = cap.read() 49 | if not ret: 50 | break 51 | results = model.track(frame, persist=True) 52 | boxes = results[0].boxes.xyxy.cpu().numpy().astype(int) 53 | ids = results[0].boxes.id.cpu().numpy().astype(int) 54 | for box, id in zip(boxes, ids): 55 | cv2.rectangle(frame, (box[0], box[1]), (box[2], box[3]), (0, 255, 0), 2) 56 | cv2.putText( 57 | frame, 58 | f"Id {id}", 59 | (box[0], box[1]), 60 | cv2.FONT_HERSHEY_SIMPLEX, 61 | 1, 62 | (0, 0, 255), 63 | 2, 64 | ) 65 | cv2.imshow("frame", frame) 66 | if cv2.waitKey(1) & 0xFF == ord("q"): 67 | break 68 | ``` 69 |  70 | ## Change tracker parameters 71 |  72 | You can change the tracker parameters by editing the tracker config file (e.g. `botsort.yaml` or `bytetrack.yaml`), which is located in the `ultralytics/tracker/cfg` folder. 73 |  74 | ## Command Line Interface (CLI) 75 |  76 | You can also use the command line interface to track objects using the YOLO model. 77 |  78 | ```bash 79 | yolo detect track source=... tracker=... 80 | yolo segment track source=... tracker=... 81 | yolo pose track source=... tracker=... 82 | ``` 83 |  84 | By default, trackers will use the configuration in `ultralytics/tracker/cfg`. 85 | We also support using a modified tracker config file. Please refer to the tracker config files 86 | in `ultralytics/tracker/cfg`.
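For example, a minimal sketch of passing a modified config through the Python interface (here `custom_tracker.yaml` is a hypothetical copy of `botsort.yaml` or `bytetrack.yaml` with edited thresholds):

```python
from ultralytics import YOLO

model = YOLO("yolov8n.pt")
model.track(
    source="video.mp4",
    tracker="custom_tracker.yaml",  # hypothetical modified copy of botsort.yaml / bytetrack.yaml
    show=True,
)
```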
87 | -------------------------------------------------------------------------------- /MobileSAMv2/ultralytics/tracker/__init__.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | from .track import register_tracker 4 | from .trackers import BOTSORT, BYTETracker 5 | 6 | __all__ = 'register_tracker', 'BOTSORT', 'BYTETracker' # allow simpler import 7 | -------------------------------------------------------------------------------- /MobileSAMv2/ultralytics/tracker/cfg/botsort.yaml: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | # Default YOLO tracker settings for BoT-SORT tracker https://github.com/NirAharon/BoT-SORT 3 | 4 | tracker_type: botsort # tracker type, ['botsort', 'bytetrack'] 5 | track_high_thresh: 0.5 # threshold for the first association 6 | track_low_thresh: 0.1 # threshold for the second association 7 | new_track_thresh: 0.6 # threshold for init new track if the detection does not match any tracks 8 | track_buffer: 30 # buffer to calculate the time when to remove tracks 9 | match_thresh: 0.8 # threshold for matching tracks 10 | # min_box_area: 10 # threshold for min box areas(for tracker evaluation, not used for now) 11 | # mot20: False # for tracker evaluation(not used for now) 12 | 13 | # BoT-SORT settings 14 | cmc_method: sparseOptFlow # method of global motion compensation 15 | # ReID model related thresh (not supported yet) 16 | proximity_thresh: 0.5 17 | appearance_thresh: 0.25 18 | with_reid: False 19 | -------------------------------------------------------------------------------- /MobileSAMv2/ultralytics/tracker/cfg/bytetrack.yaml: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | # Default YOLO tracker settings for ByteTrack tracker https://github.com/ifzhang/ByteTrack 3 | 4 | tracker_type: bytetrack # tracker type, ['botsort', 'bytetrack'] 5 | track_high_thresh: 0.5 # threshold for the first association 6 | track_low_thresh: 0.1 # threshold for the second association 7 | new_track_thresh: 0.6 # threshold for init new track if the detection does not match any tracks 8 | track_buffer: 30 # buffer to calculate the time when to remove tracks 9 | match_thresh: 0.8 # threshold for matching tracks 10 | # min_box_area: 10 # threshold for min box areas(for tracker evaluation, not used for now) 11 | # mot20: False # for tracker evaluation(not used for now) 12 | -------------------------------------------------------------------------------- /MobileSAMv2/ultralytics/tracker/track.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | from functools import partial 4 | 5 | import torch 6 | 7 | from ultralytics.yolo.utils import IterableSimpleNamespace, yaml_load 8 | from ultralytics.yolo.utils.checks import check_yaml 9 | 10 | from .trackers import BOTSORT, BYTETracker 11 | 12 | TRACKER_MAP = {'bytetrack': BYTETracker, 'botsort': BOTSORT} 13 | 14 | 15 | def on_predict_start(predictor, persist=False): 16 | """ 17 | Initialize trackers for object tracking during prediction. 18 | 19 | Args: 20 | predictor (object): The predictor object to initialize trackers for. 21 | persist (bool, optional): Whether to persist the trackers if they already exist. Defaults to False. 22 | 23 | Raises: 24 | AssertionError: If the tracker_type is not 'bytetrack' or 'botsort'. 
25 | """ 26 | if hasattr(predictor, 'trackers') and persist: 27 | return 28 | tracker = check_yaml(predictor.args.tracker) 29 | cfg = IterableSimpleNamespace(**yaml_load(tracker)) 30 | assert cfg.tracker_type in ['bytetrack', 'botsort'], \ 31 | f"Only support 'bytetrack' and 'botsort' for now, but got '{cfg.tracker_type}'" 32 | trackers = [] 33 | for _ in range(predictor.dataset.bs): 34 | tracker = TRACKER_MAP[cfg.tracker_type](args=cfg, frame_rate=30) 35 | trackers.append(tracker) 36 | predictor.trackers = trackers 37 | 38 | 39 | def on_predict_postprocess_end(predictor): 40 | """Postprocess detected boxes and update with object tracking.""" 41 | bs = predictor.dataset.bs 42 | im0s = predictor.batch[1] 43 | for i in range(bs): 44 | det = predictor.results[i].boxes.cpu().numpy() 45 | if len(det) == 0: 46 | continue 47 | tracks = predictor.trackers[i].update(det, im0s[i]) 48 | if len(tracks) == 0: 49 | continue 50 | idx = tracks[:, -1].astype(int) 51 | predictor.results[i] = predictor.results[i][idx] 52 | predictor.results[i].update(boxes=torch.as_tensor(tracks[:, :-1])) 53 | 54 | 55 | def register_tracker(model, persist): 56 | """ 57 | Register tracking callbacks to the model for object tracking during prediction. 58 | 59 | Args: 60 | model (object): The model object to register tracking callbacks for. 61 | persist (bool): Whether to persist the trackers if they already exist. 62 | 63 | """ 64 | model.add_callback('on_predict_start', partial(on_predict_start, persist=persist)) 65 | model.add_callback('on_predict_postprocess_end', on_predict_postprocess_end) 66 | -------------------------------------------------------------------------------- /MobileSAMv2/ultralytics/tracker/trackers/__init__.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | from .bot_sort import BOTSORT 4 | from .byte_tracker import BYTETracker 5 | 6 | __all__ = 'BOTSORT', 'BYTETracker' # allow simpler import 7 | -------------------------------------------------------------------------------- /MobileSAMv2/ultralytics/tracker/trackers/basetrack.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | from collections import OrderedDict 4 | 5 | import numpy as np 6 | 7 | 8 | class TrackState: 9 | """Enumeration of possible object tracking states.""" 10 | 11 | New = 0 12 | Tracked = 1 13 | Lost = 2 14 | Removed = 3 15 | 16 | 17 | class BaseTrack: 18 | """Base class for object tracking, handling basic track attributes and operations.""" 19 | 20 | _count = 0 21 | 22 | track_id = 0 23 | is_activated = False 24 | state = TrackState.New 25 | 26 | history = OrderedDict() 27 | features = [] 28 | curr_feature = None 29 | score = 0 30 | start_frame = 0 31 | frame_id = 0 32 | time_since_update = 0 33 | 34 | # Multi-camera 35 | location = (np.inf, np.inf) 36 | 37 | @property 38 | def end_frame(self): 39 | """Return the last frame ID of the track.""" 40 | return self.frame_id 41 | 42 | @staticmethod 43 | def next_id(): 44 | """Increment and return the global track ID counter.""" 45 | BaseTrack._count += 1 46 | return BaseTrack._count 47 | 48 | def activate(self, *args): 49 | """Activate the track with the provided arguments.""" 50 | raise NotImplementedError 51 | 52 | def predict(self): 53 | """Predict the next state of the track.""" 54 | raise NotImplementedError 55 | 56 | def update(self, *args, **kwargs): 57 | """Update the track with new observations.""" 58 | raise 
NotImplementedError 59 | 60 | def mark_lost(self): 61 | """Mark the track as lost.""" 62 | self.state = TrackState.Lost 63 | 64 | def mark_removed(self): 65 | """Mark the track as removed.""" 66 | self.state = TrackState.Removed 67 | 68 | @staticmethod 69 | def reset_id(): 70 | """Reset the global track ID counter.""" 71 | BaseTrack._count = 0 72 | -------------------------------------------------------------------------------- /MobileSAMv2/ultralytics/tracker/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChaoningZhang/MobileSAM/34bbbfdface3c18e5221aa7de6032d7220c6c6a1/MobileSAMv2/ultralytics/tracker/utils/__init__.py -------------------------------------------------------------------------------- /MobileSAMv2/ultralytics/vit/__init__.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | from .rtdetr import RTDETR 4 | from .sam import SAM 5 | 6 | __all__ = 'RTDETR', 'SAM' # allow simpler import 7 | -------------------------------------------------------------------------------- /MobileSAMv2/ultralytics/vit/rtdetr/__init__.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | from .model import RTDETR 4 | from .predict import RTDETRPredictor 5 | from .val import RTDETRValidator 6 | 7 | __all__ = 'RTDETRPredictor', 'RTDETRValidator', 'RTDETR' 8 | -------------------------------------------------------------------------------- /MobileSAMv2/ultralytics/vit/rtdetr/predict.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | import torch 4 | 5 | from ultralytics.yolo.data.augment import LetterBox 6 | from ultralytics.yolo.engine.predictor import BasePredictor 7 | from ultralytics.yolo.engine.results import Results 8 | from ultralytics.yolo.utils import ops 9 | 10 | 11 | class RTDETRPredictor(BasePredictor): 12 | 13 | def postprocess(self, preds, img, orig_imgs): 14 | """Postprocess predictions and returns a list of Results objects.""" 15 | bboxes, scores = preds[:2] # (1, bs, 300, 4), (1, bs, 300, nc) 16 | bboxes, scores = bboxes.squeeze_(0), scores.squeeze_(0) 17 | results = [] 18 | for i, bbox in enumerate(bboxes): # (300, 4) 19 | bbox = ops.xywh2xyxy(bbox) 20 | score, cls = scores[i].max(-1, keepdim=True) # (300, 1) 21 | idx = score.squeeze(-1) > self.args.conf # (300, ) 22 | if self.args.classes is not None: 23 | idx = (cls == torch.tensor(self.args.classes, device=cls.device)).any(1) & idx 24 | pred = torch.cat([bbox, score, cls], dim=-1)[idx] # filter 25 | orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs 26 | oh, ow = orig_img.shape[:2] 27 | if not isinstance(orig_imgs, torch.Tensor): 28 | pred[..., [0, 2]] *= ow 29 | pred[..., [1, 3]] *= oh 30 | path = self.batch[0] 31 | img_path = path[i] if isinstance(path, list) else path 32 | results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred)) 33 | return results 34 | 35 | def pre_transform(self, im): 36 | """Pre-transform input image before inference. 37 | 38 | Args: 39 | im (List(np.ndarray)): (N, 3, h, w) for tensor, [(h, w, 3) x N] for list. 40 | 41 | Return: A list of transformed imgs. 42 | """ 43 | # The size must be square(640) and scaleFilled. 
44 | return [LetterBox(self.imgsz, auto=False, scaleFill=True)(image=x) for x in im] 45 | -------------------------------------------------------------------------------- /MobileSAMv2/ultralytics/vit/rtdetr/train.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | from copy import copy 4 | 5 | import torch 6 | 7 | from ultralytics.nn.tasks import RTDETRDetectionModel 8 | from ultralytics.yolo.utils import DEFAULT_CFG, RANK, colorstr 9 | from ultralytics.yolo.v8.detect import DetectionTrainer 10 | 11 | from .val import RTDETRDataset, RTDETRValidator 12 | 13 | 14 | class RTDETRTrainer(DetectionTrainer): 15 | 16 | def get_model(self, cfg=None, weights=None, verbose=True): 17 | """Return a YOLO detection model.""" 18 | model = RTDETRDetectionModel(cfg, nc=self.data['nc'], verbose=verbose and RANK == -1) 19 | if weights: 20 | model.load(weights) 21 | return model 22 | 23 | def build_dataset(self, img_path, mode='val', batch=None): 24 | """Build RTDETR Dataset 25 | 26 | Args: 27 | img_path (str): Path to the folder containing images. 28 | mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode. 29 | batch (int, optional): Size of batches, this is for `rect`. Defaults to None. 30 | """ 31 | return RTDETRDataset( 32 | img_path=img_path, 33 | imgsz=self.args.imgsz, 34 | batch_size=batch, 35 | augment=mode == 'train', # no augmentation 36 | hyp=self.args, 37 | rect=False, # no rect 38 | cache=self.args.cache or None, 39 | prefix=colorstr(f'{mode}: '), 40 | data=self.data) 41 | 42 | def get_validator(self): 43 | """Returns a DetectionValidator for RTDETR model validation.""" 44 | self.loss_names = 'giou_loss', 'cls_loss', 'l1_loss' 45 | return RTDETRValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args)) 46 | 47 | def preprocess_batch(self, batch): 48 | """Preprocesses a batch of images by scaling and converting to float.""" 49 | batch = super().preprocess_batch(batch) 50 | bs = len(batch['img']) 51 | batch_idx = batch['batch_idx'] 52 | gt_bbox, gt_class = [], [] 53 | for i in range(bs): 54 | gt_bbox.append(batch['bboxes'][batch_idx == i].to(batch_idx.device)) 55 | gt_class.append(batch['cls'][batch_idx == i].to(device=batch_idx.device, dtype=torch.long)) 56 | return batch 57 | 58 | 59 | def train(cfg=DEFAULT_CFG, use_python=False): 60 | """Train and optimize RTDETR model given training data and device.""" 61 | model = 'rtdetr-l.yaml' 62 | data = cfg.data or 'coco128.yaml' # or yolo.ClassificationDataset("mnist") 63 | device = cfg.device if cfg.device is not None else '' 64 | 65 | # NOTE: F.grid_sample which is in rt-detr does not support deterministic=True 66 | # NOTE: amp training causes nan outputs and end with error while doing bipartite graph matching 67 | args = dict(model=model, 68 | data=data, 69 | device=device, 70 | imgsz=640, 71 | exist_ok=True, 72 | batch=4, 73 | deterministic=False, 74 | amp=False) 75 | trainer = RTDETRTrainer(overrides=args) 76 | trainer.train() 77 | 78 | 79 | if __name__ == '__main__': 80 | train() 81 | -------------------------------------------------------------------------------- /MobileSAMv2/ultralytics/vit/sam/__init__.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | from .build import build_sam # noqa 4 | from .model import SAM # noqa 5 | from .modules.prompt_predictor import PromptPredictor # noqa 6 | 
-------------------------------------------------------------------------------- /MobileSAMv2/ultralytics/vit/sam/autosize.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | # Copyright (c) Meta Platforms, Inc. and affiliates. 4 | # All rights reserved. 5 | 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | from copy import deepcopy 10 | from typing import Tuple 11 | 12 | import numpy as np 13 | import torch 14 | from torch.nn import functional as F 15 | from torchvision.transforms.functional import resize, to_pil_image # type: ignore 16 | 17 | 18 | class ResizeLongestSide: 19 | """ 20 | Resizes images to the longest side 'target_length', as well as provides 21 | methods for resizing coordinates and boxes. Provides methods for 22 | transforming both numpy array and batched torch tensors. 23 | """ 24 | 25 | def __init__(self, target_length: int) -> None: 26 | self.target_length = target_length 27 | 28 | def apply_image(self, image: np.ndarray) -> np.ndarray: 29 | """ 30 | Expects a numpy array with shape HxWxC in uint8 format. 31 | """ 32 | target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) 33 | return np.array(resize(to_pil_image(image), target_size)) 34 | 35 | def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 36 | """ 37 | Expects a numpy array of length 2 in the final dimension. Requires the 38 | original image size in (H, W) format. 39 | """ 40 | old_h, old_w = original_size 41 | new_h, new_w = self.get_preprocess_shape(original_size[0], original_size[1], self.target_length) 42 | coords = deepcopy(coords).astype(float) 43 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 44 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 45 | return coords 46 | 47 | def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 48 | """ 49 | Expects a numpy array shape Bx4. Requires the original image size 50 | in (H, W) format. 51 | """ 52 | boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size) 53 | return boxes.reshape(-1, 4) 54 | 55 | def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor: 56 | """ 57 | Expects batched images with shape BxCxHxW and float format. This 58 | transformation may not exactly match apply_image. apply_image is 59 | the transformation expected by the model. 60 | """ 61 | # Expects an image in BCHW format. May not exactly match apply_image. 62 | target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length) 63 | return F.interpolate(image, target_size, mode='bilinear', align_corners=False, antialias=True) 64 | 65 | def apply_coords_torch(self, coords: torch.Tensor, original_size: Tuple[int, ...]) -> torch.Tensor: 66 | """ 67 | Expects a torch tensor with length 2 in the last dimension. Requires the 68 | original image size in (H, W) format. 69 | """ 70 | old_h, old_w = original_size 71 | new_h, new_w = self.get_preprocess_shape(original_size[0], original_size[1], self.target_length) 72 | coords = deepcopy(coords).to(torch.float) 73 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 74 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 75 | return coords 76 | 77 | def apply_boxes_torch(self, boxes: torch.Tensor, original_size: Tuple[int, ...]) -> torch.Tensor: 78 | """ 79 | Expects a torch tensor with shape Bx4. 
Requires the original image 80 | size in (H, W) format. 81 | """ 82 | boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size) 83 | return boxes.reshape(-1, 4) 84 | 85 | @staticmethod 86 | def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]: 87 | """ 88 | Compute the output size given input size and target long side length. 89 | """ 90 | scale = long_side_length * 1.0 / max(oldh, oldw) 91 | newh, neww = oldh * scale, oldw * scale 92 | neww = int(neww + 0.5) 93 | newh = int(newh + 0.5) 94 | return (newh, neww) 95 | -------------------------------------------------------------------------------- /MobileSAMv2/ultralytics/vit/sam/model.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | """ 3 | SAM model interface 4 | """ 5 | 6 | from ultralytics.yolo.cfg import get_cfg 7 | 8 | from ...yolo.utils.torch_utils import model_info 9 | from .build import build_sam 10 | from .predict import Predictor 11 | 12 | 13 | class SAM: 14 | 15 | def __init__(self, model='sam_b.pt') -> None: 16 | if model and not model.endswith('.pt') and not model.endswith('.pth'): 17 | # Should raise AssertionError instead? 18 | raise NotImplementedError('Segment anything prediction requires pre-trained checkpoint') 19 | self.model = build_sam(model) 20 | self.task = 'segment' # required 21 | self.predictor = None # reuse predictor 22 | 23 | def predict(self, source, stream=False, **kwargs): 24 | """Predicts and returns segmentation masks for given image or video source.""" 25 | overrides = dict(conf=0.25, task='segment', mode='predict') 26 | overrides.update(kwargs) # prefer kwargs 27 | if not self.predictor: 28 | self.predictor = Predictor(overrides=overrides) 29 | self.predictor.setup_model(model=self.model) 30 | else: # only update args if predictor is already setup 31 | self.predictor.args = get_cfg(self.predictor.args, overrides) 32 | return self.predictor(source, stream=stream) 33 | 34 | def train(self, **kwargs): 35 | """Function trains models but raises an error as SAM models do not support training.""" 36 | raise NotImplementedError("SAM models don't support training") 37 | 38 | def val(self, **kwargs): 39 | """Run validation given dataset.""" 40 | raise NotImplementedError("SAM models don't support validation") 41 | 42 | def __call__(self, source=None, stream=False, **kwargs): 43 | """Calls the 'predict' function with given arguments to perform object detection.""" 44 | return self.predict(source, stream, **kwargs) 45 | 46 | def __getattr__(self, attr): 47 | """Raises error if object has no requested attribute.""" 48 | name = self.__class__.__name__ 49 | raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}") 50 | 51 | def info(self, detailed=False, verbose=True): 52 | """ 53 | Logs model info. 54 | 55 | Args: 56 | detailed (bool): Show detailed information about model. 57 | verbose (bool): Controls verbosity. 
58 | """ 59 | return model_info(self.model, detailed=detailed, verbose=verbose) 60 | -------------------------------------------------------------------------------- /MobileSAMv2/ultralytics/vit/sam/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | -------------------------------------------------------------------------------- /MobileSAMv2/ultralytics/vit/sam/predict.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | import numpy as np 4 | import torch 5 | 6 | from ultralytics.yolo.engine.predictor import BasePredictor 7 | from ultralytics.yolo.engine.results import Results 8 | from ultralytics.yolo.utils.torch_utils import select_device 9 | 10 | from .modules.mask_generator import SamAutomaticMaskGenerator 11 | 12 | 13 | class Predictor(BasePredictor): 14 | 15 | def preprocess(self, im): 16 | """Prepares input image for inference.""" 17 | # TODO: Only support bs=1 for now 18 | # im = ResizeLongestSide(1024).apply_image(im[0]) 19 | # im = torch.as_tensor(im, device=self.device) 20 | # im = im.permute(2, 0, 1).contiguous()[None, :, :, :] 21 | return im[0] 22 | 23 | def setup_model(self, model): 24 | """Set up YOLO model with specified thresholds and device.""" 25 | device = select_device(self.args.device) 26 | model.eval() 27 | self.model = SamAutomaticMaskGenerator(model.to(device), 28 | pred_iou_thresh=self.args.conf, 29 | box_nms_thresh=self.args.iou) 30 | self.device = device 31 | # TODO: Temporary settings for compatibility 32 | self.model.pt = False 33 | self.model.triton = False 34 | self.model.stride = 32 35 | self.model.fp16 = False 36 | self.done_warmup = True 37 | 38 | def postprocess(self, preds, path, orig_imgs): 39 | """Postprocesses inference output predictions to create detection masks for objects.""" 40 | names = dict(enumerate(list(range(len(preds))))) 41 | results = [] 42 | # TODO 43 | for i, pred in enumerate([preds]): 44 | masks = torch.from_numpy(np.stack([p['segmentation'] for p in pred], axis=0)) 45 | orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs 46 | path = self.batch[0] 47 | img_path = path[i] if isinstance(path, list) else path 48 | results.append(Results(orig_img=orig_img, path=img_path, names=names, masks=masks)) 49 | return results 50 | 51 | # def __call__(self, source=None, model=None, stream=False): 52 | # frame = cv2.imread(source) 53 | # preds = self.model.generate(frame) 54 | # return self.postprocess(preds, source, frame) 55 | -------------------------------------------------------------------------------- /MobileSAMv2/ultralytics/vit/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | -------------------------------------------------------------------------------- /MobileSAMv2/ultralytics/yolo/__init__.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | from . 
import v8 4 | 5 | __all__ = 'v8', # tuple or list 6 | -------------------------------------------------------------------------------- /MobileSAMv2/ultralytics/yolo/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | from .base import BaseDataset 4 | from .build import build_dataloader, build_yolo_dataset, load_inference_source 5 | from .dataset import ClassificationDataset, SemanticDataset, YOLODataset 6 | from .dataset_wrappers import MixAndRectDataset 7 | 8 | __all__ = ('BaseDataset', 'ClassificationDataset', 'MixAndRectDataset', 'SemanticDataset', 'YOLODataset', 9 | 'build_yolo_dataset', 'build_dataloader', 'load_inference_source') 10 | -------------------------------------------------------------------------------- /MobileSAMv2/ultralytics/yolo/data/annotator.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from ultralytics import YOLO 4 | from ultralytics.vit.sam import PromptPredictor, build_sam 5 | from ultralytics.yolo.utils.torch_utils import select_device 6 | 7 | 8 | def auto_annotate(data, det_model='yolov8x.pt', sam_model='sam_b.pt', device='', output_dir=None): 9 | """ 10 | Automatically annotates images using a YOLO object detection model and a SAM segmentation model. 11 | Args: 12 | data (str): Path to a folder containing images to be annotated. 13 | det_model (str, optional): Pre-trained YOLO detection model. Defaults to 'yolov8x.pt'. 14 | sam_model (str, optional): Pre-trained SAM segmentation model. Defaults to 'sam_b.pt'. 15 | device (str, optional): Device to run the models on. Defaults to an empty string (CPU or GPU, if available). 16 | output_dir (str | None | optional): Directory to save the annotated results. 17 | Defaults to a 'labels' folder in the same directory as 'data'. 
18 | """ 19 | device = select_device(device) 20 | det_model = YOLO(det_model) 21 | sam_model = build_sam(sam_model) 22 | det_model.to(device) 23 | sam_model.to(device) 24 | 25 | if not output_dir: 26 | output_dir = Path(str(data)).parent / 'labels' 27 | Path(output_dir).mkdir(exist_ok=True, parents=True) 28 | 29 | prompt_predictor = PromptPredictor(sam_model) 30 | det_results = det_model(data, stream=True) 31 | 32 | for result in det_results: 33 | boxes = result.boxes.xyxy # Boxes object for bbox outputs 34 | class_ids = result.boxes.cls.int().tolist() # noqa 35 | if len(class_ids): 36 | prompt_predictor.set_image(result.orig_img) 37 | masks, _, _ = prompt_predictor.predict_torch( 38 | point_coords=None, 39 | point_labels=None, 40 | boxes=prompt_predictor.transform.apply_boxes_torch(boxes, result.orig_shape[:2]), 41 | multimask_output=False, 42 | ) 43 | 44 | result.update(masks=masks.squeeze(1)) 45 | segments = result.masks.xyn # noqa 46 | 47 | with open(str(Path(output_dir) / Path(result.path).stem) + '.txt', 'w') as f: 48 | for i in range(len(segments)): 49 | s = segments[i] 50 | if len(s) == 0: 51 | continue 52 | segment = map(str, segments[i].reshape(-1).tolist()) 53 | f.write(f'{class_ids[i]} ' + ' '.join(segment) + '\n') 54 | -------------------------------------------------------------------------------- /MobileSAMv2/ultralytics/yolo/data/dataloaders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChaoningZhang/MobileSAM/34bbbfdface3c18e5221aa7de6032d7220c6c6a1/MobileSAMv2/ultralytics/yolo/data/dataloaders/__init__.py -------------------------------------------------------------------------------- /MobileSAMv2/ultralytics/yolo/data/dataset_wrappers.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | import collections 4 | from copy import deepcopy 5 | 6 | from .augment import LetterBox 7 | 8 | 9 | class MixAndRectDataset: 10 | """ 11 | A dataset class that applies mosaic and mixup transformations as well as rectangular training. 12 | 13 | Attributes: 14 | dataset: The base dataset. 15 | imgsz: The size of the images in the dataset. 16 | """ 17 | 18 | def __init__(self, dataset): 19 | """ 20 | Args: 21 | dataset (BaseDataset): The base dataset to apply transformations to. 22 | """ 23 | self.dataset = dataset 24 | self.imgsz = dataset.imgsz 25 | 26 | def __len__(self): 27 | """Returns the number of items in the dataset.""" 28 | return len(self.dataset) 29 | 30 | def __getitem__(self, index): 31 | """ 32 | Applies mosaic, mixup and rectangular training transformations to an item in the dataset. 33 | 34 | Args: 35 | index (int): Index of the item in the dataset. 36 | 37 | Returns: 38 | (dict): A dictionary containing the transformed item data. 
39 | """ 40 | labels = deepcopy(self.dataset[index]) 41 | for transform in self.dataset.transforms.tolist(): 42 | # Mosaic and mixup 43 | if hasattr(transform, 'get_indexes'): 44 | indexes = transform.get_indexes(self.dataset) 45 | if not isinstance(indexes, collections.abc.Sequence): 46 | indexes = [indexes] 47 | labels['mix_labels'] = [deepcopy(self.dataset[index]) for index in indexes] 48 | if self.dataset.rect and isinstance(transform, LetterBox): 49 | transform.new_shape = self.dataset.batch_shapes[self.dataset.batch[index]] 50 | labels = transform(labels) 51 | if 'mix_labels' in labels: 52 | labels.pop('mix_labels') 53 | return labels 54 | -------------------------------------------------------------------------------- /MobileSAMv2/ultralytics/yolo/data/scripts/download_weights.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Ultralytics YOLO 🚀, AGPL-3.0 license 3 | # Download latest models from https://github.com/ultralytics/assets/releases 4 | # Example usage: bash ultralytics/yolo/data/scripts/download_weights.sh 5 | # parent 6 | # └── weights 7 | # ├── yolov8n.pt ← downloads here 8 | # ├── yolov8s.pt 9 | # └── ... 10 | 11 | python - <= batch_sizes[i]: # y intercept above failure point 80 | b = batch_sizes[max(i - 1, 0)] # select prior safe point 81 | if b < 1 or b > 1024: # b outside of safe range 82 | b = batch_size 83 | LOGGER.info(f'{prefix}WARNING ⚠️ CUDA anomaly detected, using default batch-size {batch_size}.') 84 | 85 | fraction = (np.polyval(p, b) + r + a) / t # actual fraction predicted 86 | LOGGER.info(f'{prefix}Using batch-size {b} for {d} {t * fraction:.2f}G/{t:.2f}G ({fraction * 100:.0f}%) ✅') 87 | return b 88 | except Exception as e: 89 | LOGGER.warning(f'{prefix}WARNING ⚠️ error detected: {e}, using default batch-size {batch_size}.') 90 | return batch_size 91 | -------------------------------------------------------------------------------- /MobileSAMv2/ultralytics/yolo/utils/callbacks/__init__.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | from .base import add_integration_callbacks, default_callbacks, get_default_callbacks 4 | 5 | __all__ = 'add_integration_callbacks', 'default_callbacks', 'get_default_callbacks' 6 | -------------------------------------------------------------------------------- /MobileSAMv2/ultralytics/yolo/utils/callbacks/hub.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | import json 4 | from time import time 5 | 6 | from ultralytics.hub.utils import PREFIX, events 7 | from ultralytics.yolo.utils import LOGGER 8 | from ultralytics.yolo.utils.torch_utils import model_info_for_loggers 9 | 10 | 11 | def on_pretrain_routine_end(trainer): 12 | """Logs info before starting timer for upload rate limit.""" 13 | session = getattr(trainer, 'hub_session', None) 14 | if session: 15 | # Start timer for upload rate limit 16 | LOGGER.info(f'{PREFIX}View model at https://hub.ultralytics.com/models/{session.model_id} 🚀') 17 | session.timers = {'metrics': time(), 'ckpt': time()} # start timer on session.rate_limit 18 | 19 | 20 | def on_fit_epoch_end(trainer): 21 | """Uploads training progress metrics at the end of each epoch.""" 22 | session = getattr(trainer, 'hub_session', None) 23 | if session: 24 | # Upload metrics after val end 25 | all_plots = {**trainer.label_loss_items(trainer.tloss, prefix='train'), **trainer.metrics} 26 
| if trainer.epoch == 0: 27 | all_plots = {**all_plots, **model_info_for_loggers(trainer)} 28 | session.metrics_queue[trainer.epoch] = json.dumps(all_plots) 29 | if time() - session.timers['metrics'] > session.rate_limits['metrics']: 30 | session.upload_metrics() 31 | session.timers['metrics'] = time() # reset timer 32 | session.metrics_queue = {} # reset queue 33 | 34 | 35 | def on_model_save(trainer): 36 | """Saves checkpoints to Ultralytics HUB with rate limiting.""" 37 | session = getattr(trainer, 'hub_session', None) 38 | if session: 39 | # Upload checkpoints with rate limiting 40 | is_best = trainer.best_fitness == trainer.fitness 41 | if time() - session.timers['ckpt'] > session.rate_limits['ckpt']: 42 | LOGGER.info(f'{PREFIX}Uploading checkpoint https://hub.ultralytics.com/models/{session.model_id}') 43 | session.upload_model(trainer.epoch, trainer.last, is_best) 44 | session.timers['ckpt'] = time() # reset timer 45 | 46 | 47 | def on_train_end(trainer): 48 | """Upload final model and metrics to Ultralytics HUB at the end of training.""" 49 | session = getattr(trainer, 'hub_session', None) 50 | if session: 51 | # Upload final model and metrics with exponential standoff 52 | LOGGER.info(f'{PREFIX}Syncing final model...') 53 | session.upload_model(trainer.epoch, trainer.best, map=trainer.metrics.get('metrics/mAP50-95(B)', 0), final=True) 54 | session.alive = False # stop heartbeats 55 | LOGGER.info(f'{PREFIX}Done ✅\n' 56 | f'{PREFIX}View model at https://hub.ultralytics.com/models/{session.model_id} 🚀') 57 | 58 | 59 | def on_train_start(trainer): 60 | """Run events on train start.""" 61 | events(trainer.args) 62 | 63 | 64 | def on_val_start(validator): 65 | """Runs events on validation start.""" 66 | events(validator.args) 67 | 68 | 69 | def on_predict_start(predictor): 70 | """Run events on predict start.""" 71 | events(predictor.args) 72 | 73 | 74 | def on_export_start(exporter): 75 | """Run events on export start.""" 76 | events(exporter.args) 77 | 78 | 79 | callbacks = { 80 | 'on_pretrain_routine_end': on_pretrain_routine_end, 81 | 'on_fit_epoch_end': on_fit_epoch_end, 82 | 'on_model_save': on_model_save, 83 | 'on_train_end': on_train_end, 84 | 'on_train_start': on_train_start, 85 | 'on_val_start': on_val_start, 86 | 'on_predict_start': on_predict_start, 87 | 'on_export_start': on_export_start} 88 | -------------------------------------------------------------------------------- /MobileSAMv2/ultralytics/yolo/utils/callbacks/mlflow.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | import os 4 | import re 5 | from pathlib import Path 6 | 7 | from ultralytics.yolo.utils import LOGGER, TESTS_RUNNING, colorstr 8 | 9 | try: 10 | import mlflow 11 | 12 | assert not TESTS_RUNNING # do not log pytest 13 | assert hasattr(mlflow, '__version__') # verify package is not directory 14 | except (ImportError, AssertionError): 15 | mlflow = None 16 | 17 | 18 | def on_pretrain_routine_end(trainer): 19 | """Logs training parameters to MLflow.""" 20 | global mlflow, run, run_id, experiment_name 21 | 22 | if os.environ.get('MLFLOW_TRACKING_URI') is None: 23 | mlflow = None 24 | 25 | if mlflow: 26 | mlflow_location = os.environ['MLFLOW_TRACKING_URI'] # "http://192.168.xxx.xxx:5000" 27 | mlflow.set_tracking_uri(mlflow_location) 28 | 29 | experiment_name = os.environ.get('MLFLOW_EXPERIMENT') or trainer.args.project or '/Shared/YOLOv8' 30 | experiment = mlflow.get_experiment_by_name(experiment_name) 31 | if experiment is 
None: 32 | mlflow.create_experiment(experiment_name) 33 | mlflow.set_experiment(experiment_name) 34 | 35 | prefix = colorstr('MLFlow: ') 36 | try: 37 | run, active_run = mlflow, mlflow.active_run() 38 | if not active_run: 39 | active_run = mlflow.start_run(experiment_id=experiment.experiment_id) 40 | run_id = active_run.info.run_id 41 | LOGGER.info(f'{prefix}Using run_id({run_id}) at {mlflow_location}') 42 | run.log_params(vars(trainer.model.args)) 43 | except Exception as err: 44 | LOGGER.error(f'{prefix}Failing init - {repr(err)}') 45 | LOGGER.warning(f'{prefix}Continuing without Mlflow') 46 | 47 | 48 | def on_fit_epoch_end(trainer): 49 | """Logs training metrics to Mlflow.""" 50 | if mlflow: 51 | metrics_dict = {f"{re.sub('[()]', '', k)}": float(v) for k, v in trainer.metrics.items()} 52 | run.log_metrics(metrics=metrics_dict, step=trainer.epoch) 53 | 54 | 55 | def on_train_end(trainer): 56 | """Called at end of train loop to log model artifact info.""" 57 | if mlflow: 58 | root_dir = Path(__file__).resolve().parents[3] 59 | run.log_artifact(trainer.last) 60 | run.log_artifact(trainer.best) 61 | run.pyfunc.log_model(artifact_path=experiment_name, 62 | code_path=[str(root_dir)], 63 | artifacts={'model_path': str(trainer.save_dir)}, 64 | python_model=run.pyfunc.PythonModel()) 65 | 66 | 67 | callbacks = { 68 | 'on_pretrain_routine_end': on_pretrain_routine_end, 69 | 'on_fit_epoch_end': on_fit_epoch_end, 70 | 'on_train_end': on_train_end} if mlflow else {} 71 | -------------------------------------------------------------------------------- /MobileSAMv2/ultralytics/yolo/utils/callbacks/neptune.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | import matplotlib.image as mpimg 4 | import matplotlib.pyplot as plt 5 | 6 | from ultralytics.yolo.utils import LOGGER, TESTS_RUNNING 7 | from ultralytics.yolo.utils.torch_utils import model_info_for_loggers 8 | 9 | try: 10 | import neptune 11 | from neptune.types import File 12 | 13 | assert not TESTS_RUNNING # do not log pytest 14 | assert hasattr(neptune, '__version__') 15 | except (ImportError, AssertionError): 16 | neptune = None 17 | 18 | run = None # NeptuneAI experiment logger instance 19 | 20 | 21 | def _log_scalars(scalars, step=0): 22 | """Log scalars to the NeptuneAI experiment logger.""" 23 | if run: 24 | for k, v in scalars.items(): 25 | run[k].append(value=v, step=step) 26 | 27 | 28 | def _log_images(imgs_dict, group=''): 29 | """Log scalars to the NeptuneAI experiment logger.""" 30 | if run: 31 | for k, v in imgs_dict.items(): 32 | run[f'{group}/{k}'].upload(File(v)) 33 | 34 | 35 | def _log_plot(title, plot_path): 36 | """Log plots to the NeptuneAI experiment logger.""" 37 | """ 38 | Log image as plot in the plot section of NeptuneAI 39 | 40 | arguments: 41 | title (str) Title of the plot 42 | plot_path (PosixPath or str) Path to the saved image file 43 | """ 44 | img = mpimg.imread(plot_path) 45 | fig = plt.figure() 46 | ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect='auto', xticks=[], yticks=[]) # no ticks 47 | ax.imshow(img) 48 | run[f'Plots/{title}'].upload(fig) 49 | 50 | 51 | def on_pretrain_routine_start(trainer): 52 | """Callback function called before the training routine starts.""" 53 | try: 54 | global run 55 | run = neptune.init_run(project=trainer.args.project or 'YOLOv8', name=trainer.args.name, tags=['YOLOv8']) 56 | run['Configuration/Hyperparameters'] = {k: '' if v is None else v for k, v in vars(trainer.args).items()} 57 | except 
Exception as e: 58 | LOGGER.warning(f'WARNING ⚠️ NeptuneAI installed but not initialized correctly, not logging this run. {e}') 59 | 60 | 61 | def on_train_epoch_end(trainer): 62 | """Callback function called at end of each training epoch.""" 63 | _log_scalars(trainer.label_loss_items(trainer.tloss, prefix='train'), trainer.epoch + 1) 64 | _log_scalars(trainer.lr, trainer.epoch + 1) 65 | if trainer.epoch == 1: 66 | _log_images({f.stem: str(f) for f in trainer.save_dir.glob('train_batch*.jpg')}, 'Mosaic') 67 | 68 | 69 | def on_fit_epoch_end(trainer): 70 | """Callback function called at end of each fit (train+val) epoch.""" 71 | if run and trainer.epoch == 0: 72 | run['Configuration/Model'] = model_info_for_loggers(trainer) 73 | _log_scalars(trainer.metrics, trainer.epoch + 1) 74 | 75 | 76 | def on_val_end(validator): 77 | """Callback function called at end of each validation.""" 78 | if run: 79 | # Log val_labels and val_pred 80 | _log_images({f.stem: str(f) for f in validator.save_dir.glob('val*.jpg')}, 'Validation') 81 | 82 | 83 | def on_train_end(trainer): 84 | """Callback function called at end of training.""" 85 | if run: 86 | # Log final results, CM matrix + PR plots 87 | files = [ 88 | 'results.png', 'confusion_matrix.png', 'confusion_matrix_normalized.png', 89 | *(f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R'))] 90 | files = [(trainer.save_dir / f) for f in files if (trainer.save_dir / f).exists()] # filter 91 | for f in files: 92 | _log_plot(title=f.stem, plot_path=f) 93 | # Log the final model 94 | run[f'weights/{trainer.args.name or trainer.args.task}/{str(trainer.best.name)}'].upload(File(str( 95 | trainer.best))) 96 | 97 | 98 | callbacks = { 99 | 'on_pretrain_routine_start': on_pretrain_routine_start, 100 | 'on_train_epoch_end': on_train_epoch_end, 101 | 'on_fit_epoch_end': on_fit_epoch_end, 102 | 'on_val_end': on_val_end, 103 | 'on_train_end': on_train_end} if neptune else {} 104 | -------------------------------------------------------------------------------- /MobileSAMv2/ultralytics/yolo/utils/callbacks/raytune.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | try: 4 | import ray 5 | from ray import tune 6 | from ray.air import session 7 | except (ImportError, AssertionError): 8 | tune = None 9 | 10 | 11 | def on_fit_epoch_end(trainer): 12 | """Sends training metrics to Ray Tune at end of each epoch.""" 13 | if ray.tune.is_session_enabled(): 14 | metrics = trainer.metrics 15 | metrics['epoch'] = trainer.epoch 16 | session.report(metrics) 17 | 18 | 19 | callbacks = { 20 | 'on_fit_epoch_end': on_fit_epoch_end, } if tune else {} 21 | -------------------------------------------------------------------------------- /MobileSAMv2/ultralytics/yolo/utils/callbacks/tensorboard.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | from ultralytics.yolo.utils import LOGGER, TESTS_RUNNING, colorstr 4 | 5 | try: 6 | from torch.utils.tensorboard import SummaryWriter 7 | 8 | assert not TESTS_RUNNING # do not log pytest 9 | except (ImportError, AssertionError): 10 | SummaryWriter = None 11 | 12 | writer = None # TensorBoard SummaryWriter instance 13 | 14 | 15 | def _log_scalars(scalars, step=0): 16 | """Logs scalar values to TensorBoard.""" 17 | if writer: 18 | for k, v in scalars.items(): 19 | writer.add_scalar(k, v, step) 20 | 21 | 22 | def on_pretrain_routine_start(trainer): 23 | """Initialize TensorBoard logging with 
SummaryWriter.""" 24 | if SummaryWriter: 25 | try: 26 | global writer 27 | writer = SummaryWriter(str(trainer.save_dir)) 28 | prefix = colorstr('TensorBoard: ') 29 | LOGGER.info(f"{prefix}Start with 'tensorboard --logdir {trainer.save_dir}', view at http://localhost:6006/") 30 | except Exception as e: 31 | LOGGER.warning(f'WARNING ⚠️ TensorBoard not initialized correctly, not logging this run. {e}') 32 | 33 | 34 | def on_batch_end(trainer): 35 | """Logs scalar statistics at the end of a training batch.""" 36 | _log_scalars(trainer.label_loss_items(trainer.tloss, prefix='train'), trainer.epoch + 1) 37 | 38 | 39 | def on_fit_epoch_end(trainer): 40 | """Logs epoch metrics at end of training epoch.""" 41 | _log_scalars(trainer.metrics, trainer.epoch + 1) 42 | 43 | 44 | callbacks = { 45 | 'on_pretrain_routine_start': on_pretrain_routine_start, 46 | 'on_fit_epoch_end': on_fit_epoch_end, 47 | 'on_batch_end': on_batch_end} 48 | -------------------------------------------------------------------------------- /MobileSAMv2/ultralytics/yolo/utils/callbacks/wb.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | from ultralytics.yolo.utils import TESTS_RUNNING 3 | from ultralytics.yolo.utils.torch_utils import model_info_for_loggers 4 | 5 | try: 6 | import wandb as wb 7 | 8 | assert hasattr(wb, '__version__') 9 | assert not TESTS_RUNNING # do not log pytest 10 | except (ImportError, AssertionError): 11 | wb = None 12 | 13 | _processed_plots = {} 14 | 15 | 16 | def _log_plots(plots, step): 17 | for name, params in plots.items(): 18 | timestamp = params['timestamp'] 19 | if _processed_plots.get(name, None) != timestamp: 20 | wb.run.log({name.stem: wb.Image(str(name))}, step=step) 21 | _processed_plots[name] = timestamp 22 | 23 | 24 | def on_pretrain_routine_start(trainer): 25 | """Initiate and start project if module is present.""" 26 | wb.run or wb.init(project=trainer.args.project or 'YOLOv8', name=trainer.args.name, config=vars(trainer.args)) 27 | 28 | 29 | def on_fit_epoch_end(trainer): 30 | """Logs training metrics and model information at the end of an epoch.""" 31 | wb.run.log(trainer.metrics, step=trainer.epoch + 1) 32 | _log_plots(trainer.plots, step=trainer.epoch + 1) 33 | _log_plots(trainer.validator.plots, step=trainer.epoch + 1) 34 | if trainer.epoch == 0: 35 | wb.run.log(model_info_for_loggers(trainer), step=trainer.epoch + 1) 36 | 37 | 38 | def on_train_epoch_end(trainer): 39 | """Log metrics and save images at the end of each training epoch.""" 40 | wb.run.log(trainer.label_loss_items(trainer.tloss, prefix='train'), step=trainer.epoch + 1) 41 | wb.run.log(trainer.lr, step=trainer.epoch + 1) 42 | if trainer.epoch == 1: 43 | _log_plots(trainer.plots, step=trainer.epoch + 1) 44 | 45 | 46 | def on_train_end(trainer): 47 | """Save the best model as an artifact at end of training.""" 48 | _log_plots(trainer.validator.plots, step=trainer.epoch + 1) 49 | _log_plots(trainer.plots, step=trainer.epoch + 1) 50 | art = wb.Artifact(type='model', name=f'run_{wb.run.id}_model') 51 | if trainer.best.exists(): 52 | art.add_file(trainer.best) 53 | wb.run.log_artifact(art) 54 | 55 | 56 | callbacks = { 57 | 'on_pretrain_routine_start': on_pretrain_routine_start, 58 | 'on_train_epoch_end': on_train_epoch_end, 59 | 'on_fit_epoch_end': on_fit_epoch_end, 60 | 'on_train_end': on_train_end} if wb else {} 61 | -------------------------------------------------------------------------------- /MobileSAMv2/ultralytics/yolo/utils/dist.py: 
-------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | import os 4 | import re 5 | import shutil 6 | import socket 7 | import sys 8 | import tempfile 9 | from pathlib import Path 10 | 11 | from . import USER_CONFIG_DIR 12 | from .torch_utils import TORCH_1_9 13 | 14 | 15 | def find_free_network_port() -> int: 16 | """Finds a free port on localhost. 17 | 18 | It is useful in single-node training when we don't want to connect to a real main node but have to set the 19 | `MASTER_PORT` environment variable. 20 | """ 21 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: 22 | s.bind(('127.0.0.1', 0)) 23 | return s.getsockname()[1] # port 24 | 25 | 26 | def generate_ddp_file(trainer): 27 | """Generates a DDP file and returns its file name.""" 28 | module, name = f'{trainer.__class__.__module__}.{trainer.__class__.__name__}'.rsplit('.', 1) 29 | 30 | content = f'''overrides = {vars(trainer.args)} \nif __name__ == "__main__": 31 | from {module} import {name} 32 | from ultralytics.yolo.utils import DEFAULT_CFG_DICT 33 | 34 | cfg = DEFAULT_CFG_DICT.copy() 35 | cfg.update(save_dir='') # handle the extra key 'save_dir' 36 | trainer = {name}(cfg=cfg, overrides=overrides) 37 | trainer.train()''' 38 | (USER_CONFIG_DIR / 'DDP').mkdir(exist_ok=True) 39 | with tempfile.NamedTemporaryFile(prefix='_temp_', 40 | suffix=f'{id(trainer)}.py', 41 | mode='w+', 42 | encoding='utf-8', 43 | dir=USER_CONFIG_DIR / 'DDP', 44 | delete=False) as file: 45 | file.write(content) 46 | return file.name 47 | 48 | 49 | def generate_ddp_command(world_size, trainer): 50 | """Generates and returns command for distributed training.""" 51 | import __main__ # noqa local import to avoid https://github.com/Lightning-AI/lightning/issues/15218 52 | if not trainer.resume: 53 | shutil.rmtree(trainer.save_dir) # remove the save_dir 54 | file = str(Path(sys.argv[0]).resolve()) 55 | safe_pattern = re.compile(r'^[a-zA-Z0-9_. /\\-]{1,128}$') # allowed characters and maximum of 100 characters 56 | if not (safe_pattern.match(file) and Path(file).exists() and file.endswith('.py')): # using CLI 57 | file = generate_ddp_file(trainer) 58 | dist_cmd = 'torch.distributed.run' if TORCH_1_9 else 'torch.distributed.launch' 59 | port = find_free_network_port() 60 | cmd = [sys.executable, '-m', dist_cmd, '--nproc_per_node', f'{world_size}', '--master_port', f'{port}', file] 61 | return cmd, file 62 | 63 | 64 | def ddp_cleanup(trainer, file): 65 | """Delete temp file if created.""" 66 | if f'{id(trainer)}.py' in file: # if temp_file suffix in file 67 | os.remove(file) 68 | -------------------------------------------------------------------------------- /MobileSAMv2/ultralytics/yolo/utils/errors.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | from ultralytics.yolo.utils import emojis 4 | 5 | 6 | class HUBModelError(Exception): 7 | 8 | def __init__(self, message='Model not found. 
Please check model URL and try again.'): 9 | """Create an exception for when a model is not found.""" 10 | super().__init__(emojis(message)) 11 | -------------------------------------------------------------------------------- /MobileSAMv2/ultralytics/yolo/utils/files.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | import contextlib 4 | import glob 5 | import os 6 | import shutil 7 | from datetime import datetime 8 | from pathlib import Path 9 | 10 | 11 | class WorkingDirectory(contextlib.ContextDecorator): 12 | """Usage: @WorkingDirectory(dir) decorator or 'with WorkingDirectory(dir):' context manager.""" 13 | 14 | def __init__(self, new_dir): 15 | """Sets the working directory to 'new_dir' upon instantiation.""" 16 | self.dir = new_dir # new dir 17 | self.cwd = Path.cwd().resolve() # current dir 18 | 19 | def __enter__(self): 20 | """Changes the current directory to the specified directory.""" 21 | os.chdir(self.dir) 22 | 23 | def __exit__(self, exc_type, exc_val, exc_tb): 24 | """Restore the current working directory on context exit.""" 25 | os.chdir(self.cwd) 26 | 27 | 28 | def increment_path(path, exist_ok=False, sep='', mkdir=False): 29 | """ 30 | Increments a file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc. 31 | 32 | If the path exists and exist_ok is not set to True, the path will be incremented by appending a number and sep to 33 | the end of the path. If the path is a file, the file extension will be preserved. If the path is a directory, the 34 | number will be appended directly to the end of the path. If mkdir is set to True, the path will be created as a 35 | directory if it does not already exist. 36 | 37 | Args: 38 | path (str, pathlib.Path): Path to increment. 39 | exist_ok (bool, optional): If True, the path will not be incremented and returned as-is. Defaults to False. 40 | sep (str, optional): Separator to use between the path and the incrementation number. Defaults to ''. 41 | mkdir (bool, optional): Create a directory if it does not exist. Defaults to False. 42 | 43 | Returns: 44 | (pathlib.Path): Incremented path. 45 | """ 46 | path = Path(path) # os-agnostic 47 | if path.exists() and not exist_ok: 48 | path, suffix = (path.with_suffix(''), path.suffix) if path.is_file() else (path, '') 49 | 50 | # Method 1 51 | for n in range(2, 9999): 52 | p = f'{path}{sep}{n}{suffix}' # increment path 53 | if not os.path.exists(p): # 54 | break 55 | path = Path(p) 56 | 57 | if mkdir: 58 | path.mkdir(parents=True, exist_ok=True) # make directory 59 | 60 | return path 61 | 62 | 63 | def file_age(path=__file__): 64 | """Return days since last file update.""" 65 | dt = (datetime.now() - datetime.fromtimestamp(Path(path).stat().st_mtime)) # delta 66 | return dt.days # + dt.seconds / 86400 # fractional days 67 | 68 | 69 | def file_date(path=__file__): 70 | """Return human-readable file modification date, i.e. 
'2021-3-26'.""" 71 | t = datetime.fromtimestamp(Path(path).stat().st_mtime) 72 | return f'{t.year}-{t.month}-{t.day}' 73 | 74 | 75 | def file_size(path): 76 | """Return file/dir size (MB).""" 77 | if isinstance(path, (str, Path)): 78 | mb = 1 << 20 # bytes to MiB (1024 ** 2) 79 | path = Path(path) 80 | if path.is_file(): 81 | return path.stat().st_size / mb 82 | elif path.is_dir(): 83 | return sum(f.stat().st_size for f in path.glob('**/*') if f.is_file()) / mb 84 | return 0.0 85 | 86 | 87 | def get_latest_run(search_dir='.'): 88 | """Return path to most recent 'last.pt' in /runs (i.e. to --resume from).""" 89 | last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True) 90 | return max(last_list, key=os.path.getctime) if last_list else '' 91 | 92 | 93 | def make_dirs(dir='new_dir/'): 94 | # Create folders 95 | dir = Path(dir) 96 | if dir.exists(): 97 | shutil.rmtree(dir) # delete dir 98 | for p in dir, dir / 'labels', dir / 'images': 99 | p.mkdir(parents=True, exist_ok=True) # make dir 100 | return dir 101 | -------------------------------------------------------------------------------- /MobileSAMv2/ultralytics/yolo/utils/patches.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | """ 3 | Monkey patches to update/extend functionality of existing functions 4 | """ 5 | 6 | from pathlib import Path 7 | 8 | import cv2 9 | import numpy as np 10 | import torch 11 | 12 | # OpenCV Multilanguage-friendly functions ------------------------------------------------------------------------------ 13 | _imshow = cv2.imshow # copy to avoid recursion errors 14 | 15 | 16 | def imread(filename, flags=cv2.IMREAD_COLOR): 17 | return cv2.imdecode(np.fromfile(filename, np.uint8), flags) 18 | 19 | 20 | def imwrite(filename, img): 21 | try: 22 | cv2.imencode(Path(filename).suffix, img)[1].tofile(filename) 23 | return True 24 | except Exception: 25 | return False 26 | 27 | 28 | def imshow(path, im): 29 | _imshow(path.encode('unicode_escape').decode(), im) 30 | 31 | 32 | # PyTorch functions ---------------------------------------------------------------------------------------------------- 33 | _torch_save = torch.save # copy to avoid recursion errors 34 | 35 | 36 | def torch_save(*args, **kwargs): 37 | # Use dill (if exists) to serialize the lambda functions where pickle does not do this 38 | try: 39 | import dill as pickle 40 | except ImportError: 41 | import pickle 42 | 43 | if 'pickle_module' not in kwargs: 44 | kwargs['pickle_module'] = pickle 45 | return _torch_save(*args, **kwargs) 46 | -------------------------------------------------------------------------------- /MobileSAMv2/ultralytics/yolo/utils/tuner.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | from ultralytics.yolo.utils import LOGGER 4 | 5 | try: 6 | from ray import tune 7 | from ray.air import RunConfig, session # noqa 8 | from ray.air.integrations.wandb import WandbLoggerCallback # noqa 9 | from ray.tune.schedulers import ASHAScheduler # noqa 10 | from ray.tune.schedulers import AsyncHyperBandScheduler as AHB # noqa 11 | 12 | except ImportError: 13 | LOGGER.info("Tuning hyperparameters requires ray/tune. 
Install using `pip install 'ray[tune]'`") 14 | tune = None 15 | 16 | default_space = { 17 | # 'optimizer': tune.choice(['SGD', 'Adam', 'AdamW', 'NAdam', 'RAdam', 'RMSProp']), 18 | 'lr0': tune.uniform(1e-5, 1e-1), 19 | 'lrf': tune.uniform(0.01, 1.0), # final OneCycleLR learning rate (lr0 * lrf) 20 | 'momentum': tune.uniform(0.6, 0.98), # SGD momentum/Adam beta1 21 | 'weight_decay': tune.uniform(0.0, 0.001), # optimizer weight decay 5e-4 22 | 'warmup_epochs': tune.uniform(0.0, 5.0), # warmup epochs (fractions ok) 23 | 'warmup_momentum': tune.uniform(0.0, 0.95), # warmup initial momentum 24 | 'box': tune.uniform(0.02, 0.2), # box loss gain 25 | 'cls': tune.uniform(0.2, 4.0), # cls loss gain (scale with pixels) 26 | 'hsv_h': tune.uniform(0.0, 0.1), # image HSV-Hue augmentation (fraction) 27 | 'hsv_s': tune.uniform(0.0, 0.9), # image HSV-Saturation augmentation (fraction) 28 | 'hsv_v': tune.uniform(0.0, 0.9), # image HSV-Value augmentation (fraction) 29 | 'degrees': tune.uniform(0.0, 45.0), # image rotation (+/- deg) 30 | 'translate': tune.uniform(0.0, 0.9), # image translation (+/- fraction) 31 | 'scale': tune.uniform(0.0, 0.9), # image scale (+/- gain) 32 | 'shear': tune.uniform(0.0, 10.0), # image shear (+/- deg) 33 | 'perspective': tune.uniform(0.0, 0.001), # image perspective (+/- fraction), range 0-0.001 34 | 'flipud': tune.uniform(0.0, 1.0), # image flip up-down (probability) 35 | 'fliplr': tune.uniform(0.0, 1.0), # image flip left-right (probability) 36 | 'mosaic': tune.uniform(0.0, 1.0), # image mixup (probability) 37 | 'mixup': tune.uniform(0.0, 1.0), # image mixup (probability) 38 | 'copy_paste': tune.uniform(0.0, 1.0)} # segment copy-paste (probability) 39 | 40 | task_metric_map = { 41 | 'detect': 'metrics/mAP50-95(B)', 42 | 'segment': 'metrics/mAP50-95(M)', 43 | 'classify': 'metrics/accuracy_top1', 44 | 'pose': 'metrics/mAP50-95(P)'} 45 | -------------------------------------------------------------------------------- /MobileSAMv2/ultralytics/yolo/v8/__init__.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | from ultralytics.yolo.v8 import classify, detect, pose, segment 4 | 5 | __all__ = 'classify', 'segment', 'detect', 'pose' 6 | -------------------------------------------------------------------------------- /MobileSAMv2/ultralytics/yolo/v8/classify/__init__.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | from ultralytics.yolo.v8.classify.predict import ClassificationPredictor, predict 4 | from ultralytics.yolo.v8.classify.train import ClassificationTrainer, train 5 | from ultralytics.yolo.v8.classify.val import ClassificationValidator, val 6 | 7 | __all__ = 'ClassificationPredictor', 'predict', 'ClassificationTrainer', 'train', 'ClassificationValidator', 'val' 8 | -------------------------------------------------------------------------------- /MobileSAMv2/ultralytics/yolo/v8/classify/predict.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | import torch 4 | 5 | from ultralytics.yolo.engine.predictor import BasePredictor 6 | from ultralytics.yolo.engine.results import Results 7 | from ultralytics.yolo.utils import DEFAULT_CFG, ROOT 8 | 9 | 10 | class ClassificationPredictor(BasePredictor): 11 | 12 | def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): 13 | super().__init__(cfg, overrides, _callbacks) 14 | 
self.args.task = 'classify' 15 | 16 | def preprocess(self, img): 17 | """Converts input image to model-compatible data type.""" 18 | if not isinstance(img, torch.Tensor): 19 | img = torch.stack([self.transforms(im) for im in img], dim=0) 20 | img = (img if isinstance(img, torch.Tensor) else torch.from_numpy(img)).to(self.model.device) 21 | return img.half() if self.model.fp16 else img.float() # uint8 to fp16/32 22 | 23 | def postprocess(self, preds, img, orig_imgs): 24 | """Postprocesses predictions to return Results objects.""" 25 | results = [] 26 | for i, pred in enumerate(preds): 27 | orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs 28 | path = self.batch[0] 29 | img_path = path[i] if isinstance(path, list) else path 30 | results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, probs=pred)) 31 | 32 | return results 33 | 34 | 35 | def predict(cfg=DEFAULT_CFG, use_python=False): 36 | """Run YOLO model predictions on input images/videos.""" 37 | model = cfg.model or 'yolov8n-cls.pt' # or "resnet18" 38 | source = cfg.source if cfg.source is not None else ROOT / 'assets' if (ROOT / 'assets').exists() \ 39 | else 'https://ultralytics.com/images/bus.jpg' 40 | 41 | args = dict(model=model, source=source) 42 | if use_python: 43 | from ultralytics import YOLO 44 | YOLO(model)(**args) 45 | else: 46 | predictor = ClassificationPredictor(overrides=args) 47 | predictor.predict_cli() 48 | 49 | 50 | if __name__ == '__main__': 51 | predict() 52 | -------------------------------------------------------------------------------- /MobileSAMv2/ultralytics/yolo/v8/detect/__init__.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | from .predict import DetectionPredictor, predict 4 | from .train import DetectionTrainer, train 5 | from .val import DetectionValidator, val 6 | 7 | __all__ = 'DetectionPredictor', 'predict', 'DetectionTrainer', 'train', 'DetectionValidator', 'val' 8 | -------------------------------------------------------------------------------- /MobileSAMv2/ultralytics/yolo/v8/detect/predict.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | import torch 4 | 5 | from ultralytics.yolo.engine.predictor import BasePredictor 6 | from ultralytics.yolo.engine.results import Results 7 | from ultralytics.yolo.utils import DEFAULT_CFG, ROOT, ops 8 | 9 | 10 | class DetectionPredictor(BasePredictor): 11 | 12 | def postprocess(self, preds, img, orig_imgs): 13 | """Postprocesses predictions and returns a list of Results objects.""" 14 | preds = ops.non_max_suppression(preds, 15 | self.args.conf, 16 | self.args.iou, 17 | agnostic=self.args.agnostic_nms, 18 | max_det=self.args.max_det, 19 | classes=self.args.classes) 20 | 21 | results = [] 22 | for i, pred in enumerate(preds): 23 | orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs 24 | if not isinstance(orig_imgs, torch.Tensor): 25 | pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape) 26 | path = self.batch[0] 27 | img_path = path[i] if isinstance(path, list) else path 28 | results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred)) 29 | return results 30 | 31 | 32 | def predict(cfg=DEFAULT_CFG, use_python=False): 33 | """Runs YOLO model inference on input image(s).""" 34 | model = cfg.model or 'yolov8n.pt' 35 | source = cfg.source if cfg.source is not None else ROOT / 
'assets' if (ROOT / 'assets').exists() \ 36 | else 'https://ultralytics.com/images/bus.jpg' 37 | 38 | args = dict(model=model, source=source) 39 | if use_python: 40 | from ultralytics import YOLO 41 | YOLO(model)(**args) 42 | else: 43 | predictor = DetectionPredictor(overrides=args) 44 | predictor.predict_cli() 45 | 46 | 47 | if __name__ == '__main__': 48 | predict() 49 | -------------------------------------------------------------------------------- /MobileSAMv2/ultralytics/yolo/v8/pose/__init__.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | from .predict import PosePredictor, predict 4 | from .train import PoseTrainer, train 5 | from .val import PoseValidator, val 6 | 7 | __all__ = 'PoseTrainer', 'train', 'PoseValidator', 'val', 'PosePredictor', 'predict' 8 | -------------------------------------------------------------------------------- /MobileSAMv2/ultralytics/yolo/v8/pose/predict.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | from ultralytics.yolo.engine.results import Results 4 | from ultralytics.yolo.utils import DEFAULT_CFG, ROOT, ops 5 | from ultralytics.yolo.v8.detect.predict import DetectionPredictor 6 | 7 | 8 | class PosePredictor(DetectionPredictor): 9 | 10 | def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): 11 | super().__init__(cfg, overrides, _callbacks) 12 | self.args.task = 'pose' 13 | 14 | def postprocess(self, preds, img, orig_imgs): 15 | """Return detection results for a given input image or list of images.""" 16 | preds = ops.non_max_suppression(preds, 17 | self.args.conf, 18 | self.args.iou, 19 | agnostic=self.args.agnostic_nms, 20 | max_det=self.args.max_det, 21 | classes=self.args.classes, 22 | nc=len(self.model.names)) 23 | 24 | results = [] 25 | for i, pred in enumerate(preds): 26 | orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs 27 | shape = orig_img.shape 28 | pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], shape).round() 29 | pred_kpts = pred[:, 6:].view(len(pred), *self.model.kpt_shape) if len(pred) else pred[:, 6:] 30 | pred_kpts = ops.scale_coords(img.shape[2:], pred_kpts, shape) 31 | path = self.batch[0] 32 | img_path = path[i] if isinstance(path, list) else path 33 | results.append( 34 | Results(orig_img=orig_img, 35 | path=img_path, 36 | names=self.model.names, 37 | boxes=pred[:, :6], 38 | keypoints=pred_kpts)) 39 | return results 40 | 41 | 42 | def predict(cfg=DEFAULT_CFG, use_python=False): 43 | """Runs YOLO to predict objects in an image or video.""" 44 | model = cfg.model or 'yolov8n-pose.pt' 45 | source = cfg.source if cfg.source is not None else ROOT / 'assets' if (ROOT / 'assets').exists() \ 46 | else 'https://ultralytics.com/images/bus.jpg' 47 | 48 | args = dict(model=model, source=source) 49 | if use_python: 50 | from ultralytics import YOLO 51 | YOLO(model)(**args) 52 | else: 53 | predictor = PosePredictor(overrides=args) 54 | predictor.predict_cli() 55 | 56 | 57 | if __name__ == '__main__': 58 | predict() 59 | -------------------------------------------------------------------------------- /MobileSAMv2/ultralytics/yolo/v8/pose/train.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | from copy import copy 4 | 5 | from ultralytics.nn.tasks import PoseModel 6 | from ultralytics.yolo import v8 7 | from ultralytics.yolo.utils import 
DEFAULT_CFG 8 | from ultralytics.yolo.utils.plotting import plot_images, plot_results 9 | 10 | 11 | # BaseTrainer python usage 12 | class PoseTrainer(v8.detect.DetectionTrainer): 13 | 14 | def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): 15 | """Initialize a PoseTrainer object with specified configurations and overrides.""" 16 | if overrides is None: 17 | overrides = {} 18 | overrides['task'] = 'pose' 19 | super().__init__(cfg, overrides, _callbacks) 20 | 21 | def get_model(self, cfg=None, weights=None, verbose=True): 22 | """Get pose estimation model with specified configuration and weights.""" 23 | model = PoseModel(cfg, ch=3, nc=self.data['nc'], data_kpt_shape=self.data['kpt_shape'], verbose=verbose) 24 | if weights: 25 | model.load(weights) 26 | 27 | return model 28 | 29 | def set_model_attributes(self): 30 | """Sets keypoints shape attribute of PoseModel.""" 31 | super().set_model_attributes() 32 | self.model.kpt_shape = self.data['kpt_shape'] 33 | 34 | def get_validator(self): 35 | """Returns an instance of the PoseValidator class for validation.""" 36 | self.loss_names = 'box_loss', 'pose_loss', 'kobj_loss', 'cls_loss', 'dfl_loss' 37 | return v8.pose.PoseValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args)) 38 | 39 | def plot_training_samples(self, batch, ni): 40 | """Plot a batch of training samples with annotated class labels, bounding boxes, and keypoints.""" 41 | images = batch['img'] 42 | kpts = batch['keypoints'] 43 | cls = batch['cls'].squeeze(-1) 44 | bboxes = batch['bboxes'] 45 | paths = batch['im_file'] 46 | batch_idx = batch['batch_idx'] 47 | plot_images(images, 48 | batch_idx, 49 | cls, 50 | bboxes, 51 | kpts=kpts, 52 | paths=paths, 53 | fname=self.save_dir / f'train_batch{ni}.jpg', 54 | on_plot=self.on_plot) 55 | 56 | def plot_metrics(self): 57 | """Plots training/val metrics.""" 58 | plot_results(file=self.csv, pose=True, on_plot=self.on_plot) # save results.png 59 | 60 | 61 | def train(cfg=DEFAULT_CFG, use_python=False): 62 | """Train the YOLO model on the given data and device.""" 63 | model = cfg.model or 'yolov8n-pose.yaml' 64 | data = cfg.data or 'coco8-pose.yaml' 65 | device = cfg.device if cfg.device is not None else '' 66 | 67 | args = dict(model=model, data=data, device=device) 68 | if use_python: 69 | from ultralytics import YOLO 70 | YOLO(model).train(**args) 71 | else: 72 | trainer = PoseTrainer(overrides=args) 73 | trainer.train() 74 | 75 | 76 | if __name__ == '__main__': 77 | train() 78 | -------------------------------------------------------------------------------- /MobileSAMv2/ultralytics/yolo/v8/segment/__init__.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | from .predict import SegmentationPredictor, predict 4 | from .train import SegmentationTrainer, train 5 | from .val import SegmentationValidator, val 6 | 7 | __all__ = 'SegmentationPredictor', 'predict', 'SegmentationTrainer', 'train', 'SegmentationValidator', 'val' 8 | -------------------------------------------------------------------------------- /MobileSAMv2/ultralytics/yolo/v8/segment/predict.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | import torch 4 | 5 | from ultralytics.yolo.engine.results import Results 6 | from ultralytics.yolo.utils import DEFAULT_CFG, ROOT, ops 7 | from ultralytics.yolo.v8.detect.predict import DetectionPredictor 8 | 9 | 10 | class 
SegmentationPredictor(DetectionPredictor): 11 | 12 | def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): 13 | super().__init__(cfg, overrides, _callbacks) 14 | self.args.task = 'segment' 15 | 16 | def postprocess(self, preds, img, orig_imgs): 17 | """TODO: filter by classes.""" 18 | p = ops.non_max_suppression(preds[0], 19 | self.args.conf, 20 | self.args.iou, 21 | agnostic=self.args.agnostic_nms, 22 | max_det=self.args.max_det, 23 | nc=len(self.model.names), 24 | classes=self.args.classes) 25 | results = [] 26 | proto = preds[1][-1] if len(preds[1]) == 3 else preds[1] # second output is len 3 if pt, but only 1 if exported 27 | for i, pred in enumerate(p): 28 | orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs 29 | path = self.batch[0] 30 | img_path = path[i] if isinstance(path, list) else path 31 | if not len(pred): # save empty boxes 32 | results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6])) 33 | continue 34 | if self.args.retina_masks: 35 | if not isinstance(orig_imgs, torch.Tensor): 36 | pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape) 37 | masks = ops.process_mask_native(proto[i], pred[:, 6:], pred[:, :4], orig_img.shape[:2]) # HWC 38 | else: 39 | masks = ops.process_mask(proto[i], pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True) # HWC 40 | if not isinstance(orig_imgs, torch.Tensor): 41 | pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape) 42 | results.append( 43 | Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks)) 44 | return results 45 | 46 | 47 | def predict(cfg=DEFAULT_CFG, use_python=False): 48 | """Runs YOLO object detection on an image or video source.""" 49 | model = cfg.model or 'yolov8n-seg.pt' 50 | source = cfg.source if cfg.source is not None else ROOT / 'assets' if (ROOT / 'assets').exists() \ 51 | else 'https://ultralytics.com/images/bus.jpg' 52 | 53 | args = dict(model=model, source=source) 54 | if use_python: 55 | from ultralytics import YOLO 56 | YOLO(model)(**args) 57 | else: 58 | predictor = SegmentationPredictor(overrides=args) 59 | predictor.predict_cli() 60 | 61 | 62 | if __name__ == '__main__': 63 | predict() 64 | -------------------------------------------------------------------------------- /MobileSAMv2/ultralytics/yolo/v8/segment/train.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | from copy import copy 3 | 4 | from ultralytics.nn.tasks import SegmentationModel 5 | from ultralytics.yolo import v8 6 | from ultralytics.yolo.utils import DEFAULT_CFG, RANK 7 | from ultralytics.yolo.utils.plotting import plot_images, plot_results 8 | 9 | 10 | # BaseTrainer python usage 11 | class SegmentationTrainer(v8.detect.DetectionTrainer): 12 | 13 | def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): 14 | """Initialize a SegmentationTrainer object with given arguments.""" 15 | if overrides is None: 16 | overrides = {} 17 | overrides['task'] = 'segment' 18 | super().__init__(cfg, overrides, _callbacks) 19 | 20 | def get_model(self, cfg=None, weights=None, verbose=True): 21 | """Return SegmentationModel initialized with specified config and weights.""" 22 | model = SegmentationModel(cfg, ch=3, nc=self.data['nc'], verbose=verbose and RANK == -1) 23 | if weights: 24 | model.load(weights) 25 | 26 | return model 27 | 28 | def get_validator(self): 29 | """Return an instance of 
SegmentationValidator for validation of YOLO model.""" 30 | self.loss_names = 'box_loss', 'seg_loss', 'cls_loss', 'dfl_loss' 31 | return v8.segment.SegmentationValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args)) 32 | 33 | def plot_training_samples(self, batch, ni): 34 | """Creates a plot of training sample images with labels and box coordinates.""" 35 | plot_images(batch['img'], 36 | batch['batch_idx'], 37 | batch['cls'].squeeze(-1), 38 | batch['bboxes'], 39 | batch['masks'], 40 | paths=batch['im_file'], 41 | fname=self.save_dir / f'train_batch{ni}.jpg', 42 | on_plot=self.on_plot) 43 | 44 | def plot_metrics(self): 45 | """Plots training/val metrics.""" 46 | plot_results(file=self.csv, segment=True, on_plot=self.on_plot) # save results.png 47 | 48 | 49 | def train(cfg=DEFAULT_CFG, use_python=False): 50 | """Train a YOLO segmentation model based on passed arguments.""" 51 | model = cfg.model or 'yolov8n-seg.pt' 52 | data = cfg.data or 'coco128-seg.yaml' # or yolo.ClassificationDataset("mnist") 53 | device = cfg.device if cfg.device is not None else '' 54 | 55 | args = dict(model=model, data=data, device=device) 56 | if use_python: 57 | from ultralytics import YOLO 58 | YOLO(model).train(**args) 59 | else: 60 | trainer = SegmentationTrainer(overrides=args) 61 | trainer.train() 62 | 63 | 64 | if __name__ == '__main__': 65 | train() 66 | -------------------------------------------------------------------------------- /app/.gitattributes: -------------------------------------------------------------------------------- 1 | *.7z filter=lfs diff=lfs merge=lfs -text 2 | *.arrow filter=lfs diff=lfs merge=lfs -text 3 | *.bin filter=lfs diff=lfs merge=lfs -text 4 | *.bz2 filter=lfs diff=lfs merge=lfs -text 5 | *.ckpt filter=lfs diff=lfs merge=lfs -text 6 | *.ftz filter=lfs diff=lfs merge=lfs -text 7 | *.gz filter=lfs diff=lfs merge=lfs -text 8 | *.h5 filter=lfs diff=lfs merge=lfs -text 9 | *.joblib filter=lfs diff=lfs merge=lfs -text 10 | *.lfs.* filter=lfs diff=lfs merge=lfs -text 11 | *.mlmodel filter=lfs diff=lfs merge=lfs -text 12 | *.model filter=lfs diff=lfs merge=lfs -text 13 | *.msgpack filter=lfs diff=lfs merge=lfs -text 14 | *.npy filter=lfs diff=lfs merge=lfs -text 15 | *.npz filter=lfs diff=lfs merge=lfs -text 16 | *.onnx filter=lfs diff=lfs merge=lfs -text 17 | *.ot filter=lfs diff=lfs merge=lfs -text 18 | *.parquet filter=lfs diff=lfs merge=lfs -text 19 | *.pb filter=lfs diff=lfs merge=lfs -text 20 | *.pickle filter=lfs diff=lfs merge=lfs -text 21 | *.pkl filter=lfs diff=lfs merge=lfs -text 22 | *.pt filter=lfs diff=lfs merge=lfs -text 23 | *.pth filter=lfs diff=lfs merge=lfs -text 24 | *.rar filter=lfs diff=lfs merge=lfs -text 25 | *.safetensors filter=lfs diff=lfs merge=lfs -text 26 | saved_model/**/* filter=lfs diff=lfs merge=lfs -text 27 | *.tar.* filter=lfs diff=lfs merge=lfs -text 28 | *.tar filter=lfs diff=lfs merge=lfs -text 29 | *.tflite filter=lfs diff=lfs merge=lfs -text 30 | *.tgz filter=lfs diff=lfs merge=lfs -text 31 | *.wasm filter=lfs diff=lfs merge=lfs -text 32 | *.xz filter=lfs diff=lfs merge=lfs -text 33 | *.zip filter=lfs diff=lfs merge=lfs -text 34 | *.zst filter=lfs diff=lfs merge=lfs -text 35 | *tfevents* filter=lfs diff=lfs merge=lfs -text 36 | assets/sa_1309.jpg filter=lfs diff=lfs merge=lfs -text 37 | assets/sa_192.jpg filter=lfs diff=lfs merge=lfs -text 38 | assets/sa_414.jpg filter=lfs diff=lfs merge=lfs -text 39 | assets/sa_862.jpg filter=lfs diff=lfs merge=lfs -text 40 | 
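The `train()` helper at the bottom of `MobileSAMv2/ultralytics/yolo/v8/segment/train.py` above supports two call paths: the high-level `YOLO` API when `use_python=True`, and direct construction of `SegmentationTrainer` otherwise. Below is a minimal sketch of both paths, assuming the vendored `ultralytics` package is on the Python path and that the default `yolov8n-seg.pt` weights and `coco128-seg.yaml` dataset can be fetched automatically by the library:

```python
# Minimal sketch, assuming the vendored MobileSAMv2/ultralytics package is importable
# and that 'yolov8n-seg.pt' / 'coco128-seg.yaml' can be downloaded automatically.
from ultralytics import YOLO
from ultralytics.yolo.v8.segment import SegmentationTrainer

# Path 1: high-level API, mirroring the `use_python=True` branch of train().
YOLO('yolov8n-seg.pt').train(data='coco128-seg.yaml', device='')  # '' lets the trainer pick a device

# Path 2: the trainer class directly, mirroring the default branch of train().
trainer = SegmentationTrainer(overrides=dict(model='yolov8n-seg.pt',
                                             data='coco128-seg.yaml',
                                             device=''))
trainer.train()
```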
-------------------------------------------------------------------------------- /app/README.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: MobileSAM 3 | emoji: 🐠 4 | colorFrom: indigo 5 | colorTo: yellow 6 | sdk: gradio 7 | python_version: 3.8.10 8 | sdk_version: 3.35.2 9 | app_file: app.py 10 | pinned: false 11 | license: apache-2.0 12 | --- 13 | 14 | # Faster Segment Anything (MobileSAM) 15 | 16 | Demo of the official PyTorch implementation of [MobileSAM](https://github.com/ChaoningZhang/MobileSAM). 17 | 18 | 19 | **MobileSAM** performs on par with the original SAM (at least visually) and keeps exactly the same pipeline as the original SAM except for a change to the image encoder. 20 | Specifically, we replace the original heavyweight ViT-H encoder (632M) with a much smaller TinyViT (5M). On a single GPU, MobileSAM runs around 12ms per image: 8ms on the image encoder and 4ms on the mask decoder. 21 | 22 | ## To run on a local PC 23 | First, mobile_sam must be installed to run on a PC. Refer to the [Installation Instructions](https://github.com/dhkim2810/MobileSAM/tree/master#installation). 24 | 25 | Then run the following: 26 | 27 | ``` 28 | python app.py 29 | ``` 30 | 31 | ## License 32 | 33 | The model is licensed under the [Apache 2.0 license](LICENSE). 34 | 35 | 36 | ## Acknowledgement 37 | 38 | - [Segment Anything](https://segment-anything.com/) provides the SA-1B dataset and the base code. 39 | - [TinyViT](https://github.com/microsoft/Cream/tree/main/TinyViT) provides code and pre-trained models. 40 | 41 | ## Citing MobileSAM 42 | 43 | If you find this project useful for your research, please consider citing the following BibTeX entry. 44 | 45 | ```bibtex 46 | @article{mobile_sam, 47 | title={Faster Segment Anything: Towards Lightweight SAM for Mobile Applications}, 48 | author={Zhang, Chaoning and Han, Dongshen and Qiao, Yu and Kim, Jung Uk and Bae, Sung Ho and Lee, Seungkyu and Hong, Choong Seon}, 49 | journal={arXiv preprint arXiv:2306.14289}, 50 | year={2023} 51 | } 52 | ``` 53 | -------------------------------------------------------------------------------- /app/assets/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChaoningZhang/MobileSAM/34bbbfdface3c18e5221aa7de6032d7220c6c6a1/app/assets/.DS_Store -------------------------------------------------------------------------------- /app/assets/picture1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChaoningZhang/MobileSAM/34bbbfdface3c18e5221aa7de6032d7220c6c6a1/app/assets/picture1.jpg -------------------------------------------------------------------------------- /app/assets/picture2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChaoningZhang/MobileSAM/34bbbfdface3c18e5221aa7de6032d7220c6c6a1/app/assets/picture2.jpg -------------------------------------------------------------------------------- /app/assets/picture3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChaoningZhang/MobileSAM/34bbbfdface3c18e5221aa7de6032d7220c6c6a1/app/assets/picture3.jpg -------------------------------------------------------------------------------- /app/assets/picture4.jpg: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChaoningZhang/MobileSAM/34bbbfdface3c18e5221aa7de6032d7220c6c6a1/app/assets/picture4.jpg -------------------------------------------------------------------------------- /app/assets/picture5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChaoningZhang/MobileSAM/34bbbfdface3c18e5221aa7de6032d7220c6c6a1/app/assets/picture5.jpg -------------------------------------------------------------------------------- /app/assets/picture6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChaoningZhang/MobileSAM/34bbbfdface3c18e5221aa7de6032d7220c6c6a1/app/assets/picture6.jpg -------------------------------------------------------------------------------- /app/requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | timm 4 | opencv-python 5 | git+https://github.com/dhkim2810/MobileSAM.git 6 | -------------------------------------------------------------------------------- /app/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChaoningZhang/MobileSAM/34bbbfdface3c18e5221aa7de6032d7220c6c6a1/app/utils/__init__.py -------------------------------------------------------------------------------- /assets/logo2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChaoningZhang/MobileSAM/34bbbfdface3c18e5221aa7de6032d7220c6c6a1/assets/logo2.png -------------------------------------------------------------------------------- /assets/mask_box.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChaoningZhang/MobileSAM/34bbbfdface3c18e5221aa7de6032d7220c6c6a1/assets/mask_box.jpg -------------------------------------------------------------------------------- /assets/mask_comparision.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChaoningZhang/MobileSAM/34bbbfdface3c18e5221aa7de6032d7220c6c6a1/assets/mask_comparision.jpg -------------------------------------------------------------------------------- /assets/mask_point.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChaoningZhang/MobileSAM/34bbbfdface3c18e5221aa7de6032d7220c6c6a1/assets/mask_point.jpg -------------------------------------------------------------------------------- /assets/model_diagram.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChaoningZhang/MobileSAM/34bbbfdface3c18e5221aa7de6032d7220c6c6a1/assets/model_diagram.jpg -------------------------------------------------------------------------------- /assets/notebook1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChaoningZhang/MobileSAM/34bbbfdface3c18e5221aa7de6032d7220c6c6a1/assets/notebook1.png -------------------------------------------------------------------------------- /assets/notebook2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChaoningZhang/MobileSAM/34bbbfdface3c18e5221aa7de6032d7220c6c6a1/assets/notebook2.png 
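The demo README above covers the Gradio app; the `mobile_sam` package listed below exposes the same programmatic interface as the original SAM (`sam_model_registry`, `SamPredictor`, `SamAutomaticMaskGenerator`). The following is a minimal sketch of point-prompted prediction, assuming the package and the `weights/mobile_sam.pt` checkpoint are available locally; the `'vit_t'` registry key, the example image path, and the prompt coordinates are illustrative assumptions:

```python
# Minimal sketch, assuming mobile_sam is installed and weights/mobile_sam.pt is present;
# 'vit_t' is assumed to be the registry key for the TinyViT-based image encoder.
import cv2
import numpy as np
import torch

from mobile_sam import SamPredictor, sam_model_registry

device = 'cuda' if torch.cuda.is_available() else 'cpu'
sam = sam_model_registry['vit_t'](checkpoint='weights/mobile_sam.pt')
sam.to(device).eval()

predictor = SamPredictor(sam)
image = cv2.cvtColor(cv2.imread('app/assets/picture1.jpg'), cv2.COLOR_BGR2RGB)  # HWC, RGB, uint8
predictor.set_image(image)

# One foreground point prompt (x, y); label 1 = foreground, 0 = background.
masks, scores, _ = predictor.predict(point_coords=np.array([[400, 400]]),
                                     point_labels=np.array([1]))
print(masks.shape, scores)  # masks: (num_masks, H, W) boolean array
```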
-------------------------------------------------------------------------------- /linter.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -e 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | { 5 | black --version | grep -E "23\." > /dev/null 6 | } || { 7 | echo "Linter requires 'black==23.*' !" 8 | exit 1 9 | } 10 | 11 | ISORT_VERSION=$(isort --version-number) 12 | if [[ "$ISORT_VERSION" != 5.12* ]]; then 13 | echo "Linter requires isort==5.12.0 !" 14 | exit 1 15 | fi 16 | 17 | echo "Running isort ..." 18 | isort . --atomic 19 | 20 | echo "Running black ..." 21 | black -l 100 . 22 | 23 | echo "Running flake8 ..." 24 | if [ -x "$(command -v flake8)" ]; then 25 | flake8 . 26 | else 27 | python3 -m flake8 . 28 | fi 29 | 30 | echo "Running mypy..." 31 | 32 | mypy --exclude 'setup.py|notebooks' . 33 | -------------------------------------------------------------------------------- /mobile_sam/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .build_sam import ( 8 | build_sam, 9 | build_sam_vit_h, 10 | build_sam_vit_l, 11 | build_sam_vit_b, 12 | build_sam_vit_t, 13 | sam_model_registry, 14 | ) 15 | from .predictor import SamPredictor 16 | from .automatic_mask_generator import SamAutomaticMaskGenerator 17 | -------------------------------------------------------------------------------- /mobile_sam/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .sam import Sam 8 | from .image_encoder import ImageEncoderViT 9 | from .mask_decoder import MaskDecoder 10 | from .prompt_encoder import PromptEncoder 11 | from .transformer import TwoWayTransformer 12 | from .tiny_vit_sam import TinyViT 13 | -------------------------------------------------------------------------------- /mobile_sam/modeling/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | from typing import Type 11 | 12 | 13 | class MLPBlock(nn.Module): 14 | def __init__( 15 | self, 16 | embedding_dim: int, 17 | mlp_dim: int, 18 | act: Type[nn.Module] = nn.GELU, 19 | ) -> None: 20 | super().__init__() 21 | self.lin1 = nn.Linear(embedding_dim, mlp_dim) 22 | self.lin2 = nn.Linear(mlp_dim, embedding_dim) 23 | self.act = act() 24 | 25 | def forward(self, x: torch.Tensor) -> torch.Tensor: 26 | return self.lin2(self.act(self.lin1(x))) 27 | 28 | 29 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa 30 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa 31 | class LayerNorm2d(nn.Module): 32 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None: 33 | super().__init__() 34 | self.weight = nn.Parameter(torch.ones(num_channels)) 35 | self.bias = nn.Parameter(torch.zeros(num_channels)) 36 | self.eps = eps 37 | 38 | def forward(self, x: torch.Tensor) -> torch.Tensor: 39 | u = x.mean(1, keepdim=True) 40 | s = (x - u).pow(2).mean(1, keepdim=True) 41 | x = (x - u) / torch.sqrt(s + self.eps) 42 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 43 | return x 44 | -------------------------------------------------------------------------------- /mobile_sam/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /mobile_sam/utils/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | from torch.nn import functional as F 10 | from torchvision.transforms.functional import resize, to_pil_image # type: ignore 11 | 12 | from copy import deepcopy 13 | from typing import Tuple 14 | 15 | 16 | class ResizeLongestSide: 17 | """ 18 | Resizes images to the longest side 'target_length', as well as provides 19 | methods for resizing coordinates and boxes. Provides methods for 20 | transforming both numpy array and batched torch tensors. 21 | """ 22 | 23 | def __init__(self, target_length: int) -> None: 24 | self.target_length = target_length 25 | 26 | def apply_image(self, image: np.ndarray) -> np.ndarray: 27 | """ 28 | Expects a numpy array with shape HxWxC in uint8 format. 29 | """ 30 | target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) 31 | return np.array(resize(to_pil_image(image), target_size)) 32 | 33 | def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 34 | """ 35 | Expects a numpy array of length 2 in the final dimension. Requires the 36 | original image size in (H, W) format. 
37 | """ 38 | old_h, old_w = original_size 39 | new_h, new_w = self.get_preprocess_shape( 40 | original_size[0], original_size[1], self.target_length 41 | ) 42 | coords = deepcopy(coords).astype(float) 43 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 44 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 45 | return coords 46 | 47 | def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 48 | """ 49 | Expects a numpy array shape Bx4. Requires the original image size 50 | in (H, W) format. 51 | """ 52 | boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size) 53 | return boxes.reshape(-1, 4) 54 | 55 | def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor: 56 | """ 57 | Expects batched images with shape BxCxHxW and float format. This 58 | transformation may not exactly match apply_image. apply_image is 59 | the transformation expected by the model. 60 | """ 61 | # Expects an image in BCHW format. May not exactly match apply_image. 62 | target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length) 63 | return F.interpolate( 64 | image, target_size, mode="bilinear", align_corners=False, antialias=True 65 | ) 66 | 67 | def apply_coords_torch( 68 | self, coords: torch.Tensor, original_size: Tuple[int, ...] 69 | ) -> torch.Tensor: 70 | """ 71 | Expects a torch tensor with length 2 in the last dimension. Requires the 72 | original image size in (H, W) format. 73 | """ 74 | old_h, old_w = original_size 75 | new_h, new_w = self.get_preprocess_shape( 76 | original_size[0], original_size[1], self.target_length 77 | ) 78 | coords = deepcopy(coords).to(torch.float) 79 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 80 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 81 | return coords 82 | 83 | def apply_boxes_torch( 84 | self, boxes: torch.Tensor, original_size: Tuple[int, ...] 85 | ) -> torch.Tensor: 86 | """ 87 | Expects a torch tensor with shape Bx4. Requires the original image 88 | size in (H, W) format. 89 | """ 90 | boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size) 91 | return boxes.reshape(-1, 4) 92 | 93 | @staticmethod 94 | def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]: 95 | """ 96 | Compute the output size given input size and target long side length. 
97 | """ 98 | scale = long_side_length * 1.0 / max(oldh, oldw) 99 | newh, neww = oldh * scale, oldw * scale 100 | neww = int(neww + 0.5) 101 | newh = int(newh + 0.5) 102 | return (newh, neww) 103 | -------------------------------------------------------------------------------- /notebooks/images/picture1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChaoningZhang/MobileSAM/34bbbfdface3c18e5221aa7de6032d7220c6c6a1/notebooks/images/picture1.jpg -------------------------------------------------------------------------------- /notebooks/images/picture2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChaoningZhang/MobileSAM/34bbbfdface3c18e5221aa7de6032d7220c6c6a1/notebooks/images/picture2.jpg -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [isort] 2 | line_length=100 3 | multi_line_output=3 4 | include_trailing_comma=True 5 | known_standard_library=numpy,setuptools 6 | skip_glob=*/__init__.py 7 | known_myself=mobile_sam 8 | known_third_party=matplotlib,cv2,torch,torchvision,pycocotools,onnx,black,isort 9 | no_lines_before=STDLIB,THIRDPARTY 10 | sections=FUTURE,STDLIB,THIRDPARTY,MYSELF,FIRSTPARTY,LOCALFOLDER 11 | default_section=FIRSTPARTY 12 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from setuptools import find_packages, setup 8 | 9 | setup( 10 | name="mobile_sam", 11 | version="1.0", 12 | install_requires=[], 13 | packages=find_packages(exclude="notebooks"), 14 | extras_require={ 15 | "all": ["matplotlib", "pycocotools", "opencv-python", "onnx", "onnxruntime"], 16 | "dev": ["flake8", "isort", "black", "mypy"], 17 | }, 18 | ) 19 | -------------------------------------------------------------------------------- /weights/mobile_sam.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChaoningZhang/MobileSAM/34bbbfdface3c18e5221aa7de6032d7220c6c6a1/weights/mobile_sam.pt --------------------------------------------------------------------------------