├── .gitignore ├── LICENSE ├── README.md ├── assets └── overview.jpg ├── configs ├── base_config.py ├── config_hrsid.py ├── config_nwpu.py └── config_whu.py ├── data ├── HRSID │ └── Annotations │ │ ├── inshore │ │ ├── inshore_test.json │ │ └── inshore_train.json │ │ └── offshore │ │ ├── offshore_test.json │ │ └── offshore_train.json ├── NWPU │ └── Annotations │ │ ├── NWPU_instances_train.json │ │ └── NWPU_instances_val.json └── WHU │ └── annotations │ ├── WHU_building_test.json │ ├── WHU_building_train.json │ └── WHU_building_val.json ├── datasets ├── HRSID.py ├── NWPU.py ├── WHU.py ├── __init__.py ├── augmentation.py └── tools.py ├── inference.py ├── pretrain └── Where_To_Save_Pretrained_SAM_Checkpoints ├── requirements.txt ├── scripts ├── train_hrsid_pointsam.sh ├── train_hrsid_selftrain.sh ├── train_hrsid_supervise.sh ├── train_nwpu_pointsam.sh ├── train_nwpu_selftrain.sh ├── train_nwpu_supervise.sh ├── train_whu_pointsam.sh ├── train_whu_selftrain.sh └── train_whu_supervise.sh ├── segment_anything ├── __init__.py ├── automatic_mask_generator.py ├── build_sam.py ├── modeling │ ├── __init__.py │ ├── common.py │ ├── image_encoder.py │ ├── image_encoder_adapter.py │ ├── mask_decoder.py │ ├── prompt_encoder.py │ ├── sam.py │ └── transformer.py ├── predictor.py └── utils │ ├── __init__.py │ ├── amg.py │ ├── onnx.py │ └── transforms.py ├── train_pointsam.py ├── train_selftrain.py ├── train_supervise.py └── utils ├── eval_utils.py ├── finch.py ├── losses.py ├── model.py ├── sam_lora.py ├── sample_utils.py ├── tools.py ├── utils.py └── visualize.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .nox/ 42 | .coverage 43 | .coverage.* 44 | .cache 45 | nosetests.xml 46 | coverage.xml 47 | *.cover 48 | .hypothesis/ 49 | .pytest_cache/ 50 | .idea/* 51 | 52 | # Translations 53 | *.mo 54 | *.pot 55 | 56 | # Django stuff: 57 | *.log 58 | local_settings.py 59 | db.sqlite3 60 | 61 | # Flask stuff: 62 | instance/ 63 | .webassets-cache 64 | 65 | # Scrapy stuff: 66 | .scrapy 67 | 68 | # Sphinx documentation 69 | docs/en/_build/ 70 | docs/zh_cn/_build/ 71 | src 72 | 73 | # PyBuilder 74 | target/ 75 | 76 | # Jupyter Notebook 77 | .ipynb_checkpoints 78 | 79 | # IPython 80 | profile_default/ 81 | ipython_config.py 82 | 83 | # pyenv 84 | .python-version 85 | 86 | # celery beat schedule file 87 | celerybeat-schedule 88 | 89 | # SageMath parsed files 90 | *.sage.py 91 | 92 | # Environments 93 | .env 94 | .venv 95 | env/ 96 | venv/ 97 | ENV/ 98 | env.bak/ 99 | venv.bak/ 100 | 101 | # Spyder project settings 102 | .spyderproject 103 | .spyproject 104 | 105 | # Rope project settings 106 | .ropeproject 107 | 108 | # mkdocs documentation 109 | /site 110 | 111 | # mypy 112 | .mypy_cache/ 113 | .dmypy.json 114 | dmypy.json 115 | 116 | # Pyre type checker 117 | .pyre/ 118 | .DS_Store 119 | .idea 120 | *work_dirs* 121 | tmp 122 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Nanqing Liu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # PointSAM: Pointly-Supervised Segment Anything Model for Remote Sensing Images 4 |

5 | 6 |

7 | 8 | [![paper](https://img.shields.io/badge/arXiv-Paper-.svg)](https://arxiv.org/abs/2409.13401) 9 | 10 |
11 | 12 | 19 | 20 | 21 | --- 22 | ## 📢 Latest Updates 23 | - **2 Jan 2025**: **PointSAM** has been accepted by TGRS and is now available [here](https://ieeexplore.ieee.org/document/10839471). 24 | - **8 Dec 2024**: The complete code is released. 25 | - **20 Sep 2024**: The arXiv version is released [here](https://arxiv.org/abs/2409.13401). 26 | --- 27 | 28 | 29 | 30 | ## 🎨 Overview 31 | 32 | ![Overview](assets/overview.jpg) 33 | 34 | ## 🎮 Getting Started 35 | ### 1. Install Environment 36 | To ensure compatibility, **the Python version must not exceed 3.10**. Follow these steps to set up your environment: 37 | ```bash 38 | conda create --name pointsam python=3.10 39 | conda activate pointsam 40 | 41 | pip install torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 --index-url https://download.pytorch.org/whl/cu118 42 | git clone https://github.com/Lans1ng/PointSAM.git 43 | cd PointSAM 44 | pip install -r requirements.txt 45 | ``` 46 | 47 | **Note:** 48 | The CUDA version in the `pip install` command is specified as `cu118` (CUDA 11.8). If your system uses a different CUDA version (e.g., CUDA 12.1), replace `cu118` with the appropriate version tag (e.g., `cu121`). 49 | 50 | ### 2. Prepare Dataset 51 | 52 | #### WHU Building Dataset 53 | 54 | - Dataset download address: [WHU Building Dataset](https://aistudio.baidu.com/datasetdetail/56502). 55 | 56 | - To convert the semantic labels to instance labels, you can refer to the corresponding [conversion script](https://github.com/KyanChen/RSPrompter/blob/release/tools/rsprompter/whu2coco.py). 57 | 58 | #### HRSID Dataset 59 | 60 | - Dataset download address: [HRSID Dataset](https://github.com/chaozhong2010/HRSID). 61 | 62 | #### NWPU VHR-10 Dataset 63 | 64 | - Dataset download address: [NWPU VHR-10 Dataset](https://aistudio.baidu.com/datasetdetail/52812). 65 | 66 | - Instance label download address: [NWPU VHR-10 Instance Label](https://github.com/chaozhong2010/VHR-10_dataset_coco). 67 | 68 | For convenience, the necessary JSON annotations are included in this repo. You only need to download the corresponding images. Organize your dataset as follows: 69 | 70 | ``` 71 | data 72 | ├── WHU 73 | │ ├── annotations 74 | │ │ ├── WHU_building_train.json 75 | │ │ ├── WHU_building_test.json 76 | │ │ └── WHU_building_val.json 77 | │ └── images 78 | │ ├── train 79 | │ │ ├── image 80 | │ │ └── label 81 | │ ├── val 82 | │ │ ├── image 83 | │ │ └── label 84 | │ └── test 85 | │ ├── image 86 | │ └── label 87 | ├── HRSID 88 | │ ├── Annotations 89 | │ │ ├── all 90 | │ │ ├── inshore 91 | │ │ │ ├── inshore_test.json 92 | │ │ │ └── inshore_train.json 93 | │ │ └── offshore 94 | │ └── Images 95 | └── NWPU 96 | ├── Annotations 97 | │ ├── NWPU_instances_train.json 98 | │ └── NWPU_instances_val.json 99 | └── Images 100 | 101 | ``` 102 | ### 3. Download Checkpoints 103 | 104 | Click the links below to download the checkpoint for the corresponding model type. 105 | 106 | - `vit-h`: [ViT-H SAM model.](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth) 107 | - `vit-l`: [ViT-L SAM model.](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth) 108 | - `vit-b`: [ViT-B SAM model.](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth) 109 | 110 | After downloading, move the models to the `pretrain` folder. 111 | 112 | **Note**: In our project, only the `vit-b` model is used.
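If you prefer to script this step, the snippet below is a minimal helper sketch (not part of this repository): it downloads the ViT-B checkpoint into `pretrain/`. The URL and folder name are taken from the instructions above; everything else is illustrative.

```python
# Minimal sketch: fetch the ViT-B SAM checkpoint into ./pretrain/ (a large download,
# several hundred MB). URL and target folder are taken from the README above.
import os
import urllib.request

SAM_VIT_B_URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
target_dir = "pretrain"
target_path = os.path.join(target_dir, os.path.basename(SAM_VIT_B_URL))

os.makedirs(target_dir, exist_ok=True)
if os.path.exists(target_path):
    print(f"Checkpoint already present: {target_path}")
else:
    print(f"Downloading {SAM_VIT_B_URL} -> {target_path}")
    urllib.request.urlretrieve(SAM_VIT_B_URL, target_path)
```

The same pattern works for the ViT-L and ViT-H links above if you want to experiment with larger backbones.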
113 | 114 | ### 4. Training 115 | For convenience, the `scripts` folder provides training scripts for **Supervised Training**, **Self-Training**, and **PointSAM** on the NWPU VHR-10, WHU, and HRSID datasets. 116 | 117 | Here’s an example of training PointSAM on the WHU dataset: 118 | ```bash 119 | bash scripts/train_whu_pointsam.sh 120 | ``` 121 | 122 | ### 5. Inference 123 | 124 | Here’s an example of how to perform inference: 125 | 126 | ``` 127 | python inference.py --cfg <cfg> --out_dir <out_dir> --ckpt <ckpt> 128 | ``` 129 | 130 | Please replace `<cfg>`, `<out_dir>`, and `<ckpt>` with the actual config module, output directory, and checkpoint path. 131 | 132 | **Note:** The generated results consist of four images arranged side by side: 133 | 134 | - The first image is the original input image. 135 | - The second image is the visualization of the GT mask. 136 | - The third image is the result obtained by directly testing the original SAM. 137 | - The fourth image is the result obtained using the provided checkpoint. 138 | 139 | 140 | ## 💡 Acknowledgement 141 | 142 | - [wesam](https://github.com/zhang-haojie/wesam) 143 | - [OWOD](https://github.com/JosephKJ/OWOD) 144 | - [RSPrompter](https://github.com/KyanChen/RSPrompter) 145 | 146 | 147 | ## 🖊️ Citation 148 | 149 | If you find this project useful in your research, please consider starring ⭐ and citing 📚: 150 | 151 | ```BibTeX 152 | @ARTICLE{10839471, 153 | author={Liu, Nanqing and Xu, Xun and Su, Yongyi and Zhang, Haojie and Li, Heng-Chao}, 154 | journal={IEEE Transactions on Geoscience and Remote Sensing}, 155 | title={PointSAM: Pointly-Supervised Segment Anything Model for Remote Sensing Images}, 156 | year={2025}, 157 | volume={63}, 158 | number={}, 159 | pages={1-15}, 160 | doi={10.1109/TGRS.2025.3529031}} 161 | 162 | ``` 163 | -------------------------------------------------------------------------------- /assets/overview.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lans1ng/PointSAM/d6831edf23de4bd7304fb46a5ce2a0baa69dcbd5/assets/overview.jpg -------------------------------------------------------------------------------- /configs/base_config.py: -------------------------------------------------------------------------------- 1 | base_config = { 2 | "eval_interval": 1, 3 | "ema_rate": 0.999, 4 | "csv_keys": [ "Prompt", "IoU", "Recall", "Precision", "F1", "epoch"], 5 | "opt": { 6 | "learning_rate": 1e-5, 7 | "weight_decay": 1e-4, 8 | "decay_factor": 10, 9 | "steps": [3000, 8000], 10 | "warmup_steps": 250, 11 | }, 12 | "corruptions": [ 13 | "gaussian_noise", 14 | "shot_noise", 15 | "impulse_noise", 16 | "defocus_blur", 17 | "glass_blur", 18 | "motion_blur", 19 | "zoom_blur", 20 | "snow", 21 | "frost", 22 | "fog", 23 | "brightness", 24 | "contrast", 25 | "elastic_transform", 26 | "pixelate", 27 | "jpeg_compression", 28 | ], 29 | "model": { 30 | "type": "vit_b", 31 | "checkpoint": "./pretrain/", 32 | "ckpt": "", 33 | "freeze": { 34 | "image_encoder": True, 35 | "prompt_encoder": True, 36 | "mask_decoder": True, 37 | }, 38 | }, 39 | "datasets": { 40 | "NWPU": { 41 | "root_dir": "data/NWPU/Images", 42 | "annotation_file_train": "data/NWPU/Annotations/NWPU_instances_train.json", 43 | "annotation_file_val": "data/NWPU/Annotations/NWPU_instances_val.json", 44 | }, 45 | "WHU": { 46 | "root_dir": "data/WHU", 47 | "annotation_file_train": "data/WHU/annotations/WHU_building_train.json", 48 | "annotation_file_val": "data/WHU/annotations/WHU_building_val.json", 49 | }, 50 | "HRSID": { 51 | "root_dir": "/root/autodl-fs/_DATASETS/HRSID/Images",
52 | "annotation_file_train": "/root/autodl-fs/_DATASETS/HRSID/Annotations/inshore/inshore_train.json", 53 | "annotation_file_val": "/root/autodl-fs/_DATASETS/HRSID/Annotations/inshore/inshore_test.json" 54 | }, 55 | }, 56 | } 57 | -------------------------------------------------------------------------------- /configs/config_hrsid.py: -------------------------------------------------------------------------------- 1 | from box import Box 2 | from configs.base_config import base_config 3 | 4 | config = { 5 | "dataset": "HRSID", 6 | "load_type": "soft", 7 | "num_points": 1, 8 | 9 | "batch_size": 1, #only support 1 10 | "val_batchsize": 1, 11 | "num_workers": 0, 12 | "num_epochs": 10, 13 | "max_nums": 50, 14 | "resume": False, 15 | 16 | "start_lora_layer": 1, 17 | "lora_rank": 4, 18 | "mem_bank_max_len": 512, 19 | "match_interval": 30, 20 | "iou_thr": 0.1, 21 | 22 | "prompt": "point", 23 | "out_dir": "", 24 | "name": "base", 25 | "corrupt": None, 26 | "visual": False, 27 | "model": { 28 | "type": "vit_b", 29 | }, 30 | "opt": { 31 | "learning_rate": 5e-4, 32 | "weight_decay": 1e-4, 33 | "decay_factor": 10, 34 | "steps": [1500, 2000], 35 | "warmup_steps": 250, 36 | }, 37 | } 38 | 39 | cfg = Box(base_config) 40 | cfg.merge_update(config) 41 | -------------------------------------------------------------------------------- /configs/config_nwpu.py: -------------------------------------------------------------------------------- 1 | from box import Box 2 | from configs.base_config import base_config 3 | 4 | config = { 5 | "dataset": "NWPU", 6 | "load_type": "soft", 7 | "num_points": 1, 8 | 9 | "batch_size": 1, #only support 1 10 | "val_batchsize": 1, 11 | "num_workers": 0, 12 | "num_epochs": 10, 13 | "max_nums": 50, 14 | "resume": False, 15 | 16 | "start_lora_layer": 6, 17 | "lora_rank": 4, 18 | "mem_bank_max_len": 128, 19 | "match_interval": 30, 20 | "iou_thr": 0.1, 21 | 22 | "prompt": "point", 23 | "out_dir": "", 24 | "name": "base", 25 | "corrupt": None, 26 | "visual": False, 27 | "model": { 28 | "type": "vit_b", 29 | }, 30 | "opt": { 31 | "learning_rate": 5e-4, 32 | "weight_decay": 1e-4, 33 | "decay_factor": 10, 34 | "steps": [2000, 4000], 35 | "warmup_steps": 250, 36 | }, 37 | } 38 | 39 | cfg = Box(base_config) 40 | cfg.merge_update(config) 41 | -------------------------------------------------------------------------------- /configs/config_whu.py: -------------------------------------------------------------------------------- 1 | from box import Box 2 | from configs.base_config import base_config 3 | 4 | config = { 5 | "dataset": "WHU", 6 | "load_type": "soft", 7 | "num_points": 1, 8 | 9 | "batch_size": 1, #only support 1 10 | "val_batchsize": 1, 11 | "num_workers": 0, 12 | "num_epochs": 5, 13 | "max_nums": 30, 14 | "resume": False, 15 | 16 | "start_lora_layer": 1, 17 | "lora_rank": 4, 18 | "mem_bank_max_len": 512, 19 | "match_interval": 30, 20 | "iou_thr": 0.1, 21 | 22 | "prompt": "point", 23 | "out_dir": "", 24 | "name": "base", 25 | "corrupt": None, 26 | "visual": False, 27 | "model": { 28 | "type": "vit_b", 29 | }, 30 | "opt": { 31 | "learning_rate": 5e-4, 32 | "weight_decay": 1e-4, 33 | "decay_factor": 10, 34 | "steps": [8000,15000], 35 | "warmup_steps": 250, 36 | }, 37 | } 38 | 39 | cfg = Box(base_config) 40 | cfg.merge_update(config) 41 | -------------------------------------------------------------------------------- /datasets/HRSID.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | import 
torch 5 | from torch.utils.data import DataLoader 6 | from torch.utils.data import Dataset 7 | from pycocotools.coco import COCO 8 | from datasets.tools import ResizeAndPad, soft_transform, collate_fn, collate_fn_soft, collate_fn_ 9 | 10 | class HRSIDDataset(Dataset): 11 | def __init__(self, cfg, root_dir, annotation_file, transform=None, training=False, if_self_training=False, gen_pt=False): 12 | self.cfg = cfg 13 | self.root_dir = root_dir 14 | self.transform = transform 15 | self.coco = COCO(annotation_file) 16 | 17 | image_ids = sorted(list(self.coco.imgs.keys())) 18 | 19 | # only for inshore 20 | if gen_pt: 21 | removed_ids = [351,375,376] 22 | else: 23 | if training: 24 | removed_ids = [351,375,376] 25 | else: 26 | removed_ids = [166] 27 | 28 | 29 | if removed_ids: 30 | for i in removed_ids: 31 | image_ids.remove(i) 32 | 33 | # Filter out image_ids without any annotations 34 | self.image_ids = [ 35 | image_id 36 | for image_id in image_ids 37 | if len(self.coco.getAnnIds(imgIds=image_id)) > 0 38 | ] 39 | self.if_self_training = if_self_training 40 | 41 | def __len__(self): 42 | return len(self.image_ids) 43 | 44 | def __getitem__(self, idx): 45 | image_id = self.image_ids[idx] 46 | image_info = self.coco.loadImgs(image_id)[0] 47 | image_path = os.path.join(self.root_dir, image_info["file_name"]) 48 | image = cv2.imread(image_path) 49 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 50 | ann_ids = self.coco.getAnnIds(imgIds=image_id)#the num of objects in one image 51 | anns = self.coco.loadAnns(ann_ids) 52 | bboxes = [] 53 | masks = [] 54 | categories = [] 55 | for ann in anns: 56 | x, y, w, h = ann["bbox"] 57 | 58 | mask = self.coco.annToMask(ann) 59 | 60 | bboxes.append([x, y, x + w, y + h]) 61 | masks.append(mask) 62 | categories.append(ann["category_id"]) 63 | 64 | 65 | if self.if_self_training: 66 | image_weak, bboxes_weak, masks_weak, image_strong = soft_transform(image, bboxes, masks, categories) 67 | if self.transform: 68 | image_weak, masks_weak, bboxes_weak = self.transform(image_weak, masks_weak, np.array(bboxes_weak)) 69 | image_strong = self.transform.transform_image(image_strong) 70 | 71 | bboxes_weak = np.stack(bboxes_weak, axis=0) 72 | masks_weak = np.stack(masks_weak, axis=0) 73 | return image_weak, image_strong, torch.tensor(bboxes_weak), torch.tensor(masks_weak).float(), image_path 74 | 75 | elif self.cfg.visual: 76 | origin_image = image 77 | origin_bboxes = bboxes 78 | origin_masks = masks 79 | if self.transform: 80 | padding, image, masks, bboxes = self.transform(image, masks, np.array(bboxes), True) 81 | 82 | bboxes = np.stack(bboxes, axis=0) 83 | masks = np.stack(masks, axis=0) 84 | origin_bboxes = np.stack(origin_bboxes, axis=0) 85 | origin_masks = np.stack(origin_masks, axis=0) 86 | return image_id, padding, origin_image, origin_bboxes, origin_masks, image, torch.tensor(bboxes), torch.tensor(masks).float() 87 | 88 | else: 89 | if self.transform: 90 | image, masks, bboxes = self.transform(image, masks, np.array(bboxes)) 91 | bboxes = np.stack(bboxes, axis=0) 92 | masks = np.stack(masks, axis=0) 93 | return image, torch.tensor(bboxes), torch.tensor(masks).float(), image_path 94 | 95 | def load_datasets(cfg, img_size): 96 | transform = ResizeAndPad(img_size) 97 | train = HRSIDDataset( 98 | cfg, 99 | root_dir=cfg.datasets.HRSID.root_dir, 100 | annotation_file=cfg.datasets.HRSID.annotation_file_train, 101 | transform=transform, 102 | training=True, 103 | ) 104 | train_dataloader = DataLoader( 105 | train, 106 | batch_size=cfg.batch_size, 107 | shuffle=True, 108 | 
num_workers=cfg.num_workers, 109 | collate_fn=collate_fn, 110 | ) 111 | 112 | val = HRSIDDataset( 113 | cfg, 114 | root_dir=cfg.datasets.HRSID.root_dir, 115 | annotation_file=cfg.datasets.HRSID.annotation_file_val, 116 | transform=transform, 117 | ) 118 | val_dataloader = DataLoader( 119 | val, 120 | batch_size=cfg.val_batchsize, 121 | shuffle=False, 122 | num_workers=cfg.num_workers, 123 | collate_fn=collate_fn, 124 | ) 125 | return train_dataloader, val_dataloader 126 | 127 | 128 | def load_datasets_soft(cfg, img_size, return_pt = False): 129 | transform = ResizeAndPad(img_size) 130 | 131 | soft_train = HRSIDDataset( 132 | cfg, 133 | root_dir=cfg.datasets.HRSID.root_dir, 134 | annotation_file=cfg.datasets.HRSID.annotation_file_train, 135 | transform=transform, 136 | training=True, 137 | if_self_training=True, 138 | ) 139 | soft_train_dataloader = DataLoader( 140 | soft_train, 141 | batch_size=cfg.batch_size, 142 | shuffle=True, 143 | num_workers=cfg.num_workers, 144 | collate_fn=collate_fn_soft, 145 | ) 146 | 147 | val = HRSIDDataset( 148 | cfg, 149 | root_dir=cfg.datasets.HRSID.root_dir, 150 | annotation_file=cfg.datasets.HRSID.annotation_file_val, 151 | transform=transform, 152 | ) 153 | val_dataloader = DataLoader( 154 | val, 155 | batch_size=cfg.val_batchsize, 156 | shuffle=False, 157 | num_workers=cfg.num_workers, 158 | collate_fn=collate_fn, 159 | ) 160 | 161 | if return_pt: 162 | pt = HRSIDDataset( 163 | cfg, 164 | root_dir=cfg.datasets.HRSID.root_dir, 165 | annotation_file=cfg.datasets.HRSID.annotation_file_train, 166 | transform=transform, 167 | gen_pt = return_pt 168 | ) 169 | pt_dataloader = DataLoader( 170 | pt, 171 | batch_size=cfg.val_batchsize, 172 | shuffle=False, 173 | num_workers=cfg.num_workers, 174 | collate_fn=collate_fn, 175 | ) 176 | return soft_train_dataloader, val_dataloader, pt_dataloader 177 | else: 178 | return soft_train_dataloader, val_dataloader 179 | 180 | def load_datasets_visual(cfg, img_size): 181 | transform = ResizeAndPad(img_size) 182 | val = HRSIDDataset( 183 | cfg, 184 | root_dir=cfg.datasets.HRSID.root_dir, 185 | annotation_file=cfg.datasets.HRSID.annotation_file_val, 186 | transform=transform, 187 | ) 188 | val_dataloader = DataLoader( 189 | val, 190 | batch_size=cfg.val_batchsize, 191 | shuffle=False, 192 | num_workers=cfg.num_workers, 193 | collate_fn=collate_fn_, 194 | ) 195 | return val_dataloader 196 | -------------------------------------------------------------------------------- /datasets/NWPU.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | import torch 5 | from torch.utils.data import DataLoader 6 | from torch.utils.data import Dataset 7 | from pycocotools.coco import COCO 8 | from datasets.tools import ResizeAndPad, soft_transform, collate_fn, collate_fn_soft, collate_fn_ 9 | 10 | class NWPUDataset(Dataset): 11 | def __init__(self, cfg, root_dir, annotation_file, transform=None, training=False, if_self_training=False): 12 | self.cfg = cfg 13 | self.root_dir = root_dir 14 | self.transform = transform 15 | self.coco = COCO(annotation_file) 16 | self.image_ids = sorted(list(self.coco.imgs.keys())) 17 | 18 | self.if_self_training = if_self_training 19 | 20 | def __len__(self): 21 | return len(self.image_ids) 22 | 23 | def __getitem__(self, idx): 24 | image_id = self.image_ids[idx] 25 | image_info = self.coco.loadImgs(image_id)[0] 26 | image_path = os.path.join(self.root_dir, image_info["file_name"]) 27 | image = cv2.imread(image_path) 28 | image = 
cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 29 | origin_image = image 30 | ann_ids = self.coco.getAnnIds(imgIds=image_id) 31 | anns = self.coco.loadAnns(ann_ids) 32 | bboxes = [] 33 | masks = [] 34 | categories = [] 35 | for ann in anns: 36 | x, y, w, h = ann["bbox"] 37 | bboxes.append([x, y, x + w, y + h]) 38 | mask = self.coco.annToMask(ann) 39 | masks.append(mask) 40 | categories.append(ann["category_id"]) 41 | if self.if_self_training: 42 | image_weak, bboxes_weak, masks_weak, image_strong = soft_transform(image, bboxes, masks, categories) 43 | # image_origin = image_weak 44 | 45 | if self.transform: 46 | image_weak, masks_weak, bboxes_weak = self.transform(image_weak, masks_weak, np.array(bboxes_weak)) 47 | image_strong = self.transform.transform_image(image_strong) 48 | 49 | bboxes_weak = np.stack(bboxes_weak, axis=0) 50 | masks_weak = np.stack(masks_weak, axis=0) 51 | return image_weak, image_strong, torch.tensor(bboxes_weak), torch.tensor(masks_weak).float(), image_path 52 | 53 | elif self.cfg.visual: 54 | origin_image = image 55 | origin_bboxes = bboxes 56 | origin_masks = masks 57 | if self.transform: 58 | padding, image, masks, bboxes = self.transform(image, masks, np.array(bboxes), True) 59 | 60 | bboxes = np.stack(bboxes, axis=0) 61 | masks = np.stack(masks, axis=0) 62 | origin_bboxes = np.stack(origin_bboxes, axis=0) 63 | origin_masks = np.stack(origin_masks, axis=0) 64 | return image_id, padding, origin_image, origin_bboxes, origin_masks, image, torch.tensor(bboxes), torch.tensor(masks).float() 65 | 66 | else: 67 | if self.transform: 68 | image, masks, bboxes = self.transform(image, masks, np.array(bboxes)) 69 | 70 | bboxes = np.stack(bboxes, axis=0) 71 | masks = np.stack(masks, axis=0) 72 | return image, torch.tensor(bboxes), torch.tensor(masks).float(), image_path 73 | 74 | def load_datasets(cfg, img_size): 75 | transform = ResizeAndPad(img_size) 76 | train = NWPUDataset( 77 | cfg, 78 | root_dir=cfg.datasets.NWPU.root_dir, 79 | annotation_file=cfg.datasets.NWPU.annotation_file_train, 80 | transform=transform, 81 | training=True, 82 | ) 83 | train_dataloader = DataLoader( 84 | train, 85 | batch_size=cfg.batch_size, 86 | shuffle=True, 87 | num_workers=cfg.num_workers, 88 | collate_fn=collate_fn, 89 | ) 90 | 91 | val = NWPUDataset( 92 | cfg, 93 | root_dir=cfg.datasets.NWPU.root_dir, 94 | annotation_file=cfg.datasets.NWPU.annotation_file_val, 95 | transform=transform, 96 | ) 97 | val_dataloader = DataLoader( 98 | val, 99 | batch_size=cfg.val_batchsize, 100 | shuffle=False, 101 | num_workers=cfg.num_workers, 102 | collate_fn=collate_fn, 103 | ) 104 | return train_dataloader, val_dataloader 105 | 106 | 107 | def load_datasets_soft(cfg, img_size, return_pt = False): 108 | transform = ResizeAndPad(img_size) 109 | 110 | soft_train = NWPUDataset( 111 | cfg, 112 | root_dir=cfg.datasets.NWPU.root_dir, 113 | annotation_file=cfg.datasets.NWPU.annotation_file_train, 114 | transform=transform, 115 | training=True, 116 | if_self_training=True, 117 | ) 118 | soft_train_dataloader = DataLoader( 119 | soft_train, 120 | batch_size=cfg.batch_size, 121 | shuffle=True, 122 | num_workers=cfg.num_workers, 123 | collate_fn=collate_fn_soft, 124 | ) 125 | 126 | val = NWPUDataset( 127 | cfg, 128 | root_dir=cfg.datasets.NWPU.root_dir, 129 | annotation_file=cfg.datasets.NWPU.annotation_file_val, 130 | transform=transform, 131 | ) 132 | val_dataloader = DataLoader( 133 | val, 134 | batch_size=cfg.val_batchsize, 135 | shuffle=False, 136 | num_workers=cfg.num_workers, 137 | collate_fn=collate_fn, 138 | ) 139 | 
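    # When `return_pt` is True, the branch below builds and returns a third DataLoader that
    # sweeps the same training annotations without shuffling and without the weak/strong
    # augmentation used by the self-training loader above (plain `collate_fn`, batch size
    # `cfg.val_batchsize`), presumably so point prompts can be pre-generated deterministically
    # over the training images.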
if return_pt: 140 | pt = NWPUDataset( 141 | cfg, 142 | root_dir=cfg.datasets.NWPU.root_dir, 143 | annotation_file=cfg.datasets.NWPU.annotation_file_train, 144 | transform=transform, 145 | ) 146 | pt_dataloader = DataLoader( 147 | pt, 148 | batch_size=cfg.val_batchsize, 149 | shuffle=False, 150 | num_workers=cfg.num_workers, 151 | collate_fn=collate_fn, 152 | ) 153 | 154 | return soft_train_dataloader, val_dataloader, pt_dataloader 155 | else: 156 | return soft_train_dataloader, val_dataloader 157 | 158 | def load_datasets_visual(cfg, img_size): 159 | transform = ResizeAndPad(img_size) 160 | val = NWPUDataset( 161 | cfg, 162 | root_dir=cfg.datasets.NWPU.root_dir, 163 | annotation_file=cfg.datasets.NWPU.annotation_file_val, 164 | transform=transform, 165 | ) 166 | val_dataloader = DataLoader( 167 | val, 168 | batch_size=cfg.val_batchsize, 169 | shuffle=False, 170 | num_workers=cfg.num_workers, 171 | collate_fn=collate_fn_, 172 | ) 173 | return val_dataloader 174 | 175 | -------------------------------------------------------------------------------- /datasets/WHU.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import random 4 | import numpy as np 5 | import torch 6 | from torch.utils.data import DataLoader 7 | from torch.utils.data import Dataset 8 | from pycocotools.coco import COCO 9 | from skimage.draw import polygon2mask 10 | from datasets.tools import ResizeAndPad, soft_transform, collate_fn, collate_fn_soft, collate_fn_ 11 | 12 | 13 | class WHUDataset(Dataset): 14 | def __init__(self, cfg, root_dir, annotation_file, rate=(5, 1), transform=None, training=False, if_self_training=False, gen_pt = False): 15 | self.cfg = cfg 16 | self.root_dir = root_dir 17 | self.transform = transform 18 | self.coco = COCO(annotation_file) 19 | self.image_ids = sorted(list(self.coco.imgs.keys())) 20 | self.training = training 21 | self.gen_pt = gen_pt 22 | 23 | # # Filter out image_ids without any annotations 24 | self.image_ids = [ 25 | image_id 26 | for image_id in self.image_ids 27 | if len(self.coco.getAnnIds(imgIds=image_id)) > 0 28 | ] 29 | 30 | self.if_self_training = if_self_training 31 | 32 | def __len__(self): 33 | return len(self.image_ids) 34 | 35 | def __getitem__(self, idx): 36 | image_id = self.image_ids[idx] 37 | image_info = self.coco.loadImgs(image_id)[0] 38 | if self.training or self.gen_pt: 39 | image_path = os.path.join(self.root_dir,'train/image', image_info["file_name"]) 40 | else: 41 | image_path = os.path.join(self.root_dir,'val/image', image_info["file_name"]) 42 | image = cv2.imread(image_path) 43 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 44 | 45 | ann_ids = self.coco.getAnnIds(imgIds=image_id) 46 | anns = self.coco.loadAnns(ann_ids) 47 | bboxes = [] 48 | masks = [] 49 | categories = [] 50 | for ann in anns: 51 | x, y, w, h = ann["bbox"] 52 | bboxes.append([x, y, x + w, y + h]) 53 | mask = self.coco.annToMask(ann) 54 | masks.append(mask) 55 | categories.append(ann["category_id"]) 56 | 57 | if self.if_self_training: 58 | image_weak, bboxes_weak, masks_weak, image_strong = soft_transform(image, bboxes, masks, categories) 59 | 60 | if self.transform: 61 | image_weak, masks_weak, bboxes_weak = self.transform(image_weak, masks_weak, np.array(bboxes_weak)) 62 | image_strong = self.transform.transform_image(image_strong) 63 | 64 | bboxes_weak = np.stack(bboxes_weak, axis=0) 65 | masks_weak = np.stack(masks_weak, axis=0) 66 | return image_weak, image_strong, torch.tensor(bboxes_weak), 
torch.tensor(masks_weak).float(), image_path 67 | 68 | elif self.cfg.visual: 69 | origin_image = image 70 | origin_bboxes = bboxes 71 | origin_masks = masks 72 | if self.transform: 73 | padding, image, masks, bboxes = self.transform(image, masks, np.array(bboxes), True) 74 | 75 | bboxes = np.stack(bboxes, axis=0) 76 | masks = np.stack(masks, axis=0) 77 | origin_bboxes = np.stack(origin_bboxes, axis=0) 78 | origin_masks = np.stack(origin_masks, axis=0) 79 | return image_id, padding, origin_image, origin_bboxes, origin_masks, image, torch.tensor(bboxes), torch.tensor(masks).float() 80 | 81 | else: 82 | if self.transform: 83 | image, masks, bboxes = self.transform(image, masks, np.array(bboxes)) 84 | bboxes = np.stack(bboxes, axis=0) 85 | masks = np.stack(masks, axis=0) 86 | return image, torch.tensor(bboxes), torch.tensor(masks).float(),image_path 87 | 88 | def load_datasets(cfg, img_size): 89 | transform = ResizeAndPad(img_size) 90 | train = WHUDataset( 91 | cfg, 92 | root_dir=cfg.datasets.WHU.root_dir, 93 | annotation_file=cfg.datasets.WHU.annotation_file_train, 94 | transform=transform, 95 | training=True, 96 | ) 97 | train_dataloader = DataLoader( 98 | train, 99 | batch_size=cfg.batch_size, 100 | shuffle=True, 101 | num_workers=cfg.num_workers, 102 | collate_fn=collate_fn, 103 | ) 104 | 105 | val = WHUDataset( 106 | cfg, 107 | root_dir=cfg.datasets.WHU.root_dir, 108 | annotation_file=cfg.datasets.WHU.annotation_file_val, 109 | transform=transform, 110 | ) 111 | val_dataloader = DataLoader( 112 | val, 113 | batch_size=cfg.val_batchsize, 114 | shuffle=False, 115 | num_workers=cfg.num_workers, 116 | collate_fn=collate_fn, 117 | ) 118 | return train_dataloader, val_dataloader 119 | 120 | 121 | def load_datasets_soft(cfg, img_size, return_pt = False): 122 | transform = ResizeAndPad(img_size) 123 | 124 | soft_train = WHUDataset( 125 | cfg, 126 | root_dir=cfg.datasets.WHU.root_dir, 127 | annotation_file=cfg.datasets.WHU.annotation_file_train, 128 | transform=transform, 129 | training=True, 130 | if_self_training=True, 131 | ) 132 | soft_train_dataloader = DataLoader( 133 | soft_train, 134 | batch_size=cfg.batch_size, 135 | shuffle=True, 136 | pin_memory=True, 137 | num_workers=cfg.num_workers, 138 | collate_fn=collate_fn_soft, 139 | ) 140 | 141 | val = WHUDataset( 142 | cfg, 143 | root_dir=cfg.datasets.WHU.root_dir, 144 | annotation_file=cfg.datasets.WHU.annotation_file_val, 145 | transform=transform, 146 | ) 147 | val_dataloader = DataLoader( 148 | val, 149 | batch_size=cfg.val_batchsize, 150 | shuffle=False, 151 | num_workers=cfg.num_workers, 152 | collate_fn=collate_fn, 153 | ) 154 | 155 | if return_pt: 156 | pt = WHUDataset( 157 | cfg, 158 | root_dir=cfg.datasets.WHU.root_dir, 159 | annotation_file=cfg.datasets.WHU.annotation_file_train, 160 | transform=transform, 161 | gen_pt = return_pt, 162 | ) 163 | pt_dataloader = DataLoader( 164 | pt, 165 | batch_size=cfg.val_batchsize, 166 | shuffle=False, 167 | num_workers=cfg.num_workers, 168 | collate_fn=collate_fn, 169 | ) 170 | return soft_train_dataloader, val_dataloader, pt_dataloader 171 | else: 172 | return soft_train_dataloader, val_dataloader 173 | 174 | def load_datasets_visual(cfg, img_size): 175 | transform = ResizeAndPad(img_size) 176 | val = WHUDataset( 177 | cfg, 178 | root_dir=cfg.datasets.WHU.root_dir, 179 | annotation_file=cfg.datasets.WHU.annotation_file_val, 180 | transform=transform, 181 | ) 182 | val_dataloader = DataLoader( 183 | val, 184 | batch_size=cfg.val_batchsize, 185 | shuffle=False, 186 | num_workers=cfg.num_workers, 
187 | collate_fn=collate_fn_, 188 | ) 189 | return val_dataloader 190 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | def call_load_dataset(cfg): 2 | name, type = cfg.dataset, cfg.load_type 3 | key = name.split("-")[0] 4 | module_name = f"datasets.{key}" 5 | 6 | if type == "load": 7 | function_name = "load_datasets" 8 | elif type == "soft": 9 | function_name = "load_datasets_soft" 10 | elif type == "visual": 11 | function_name = "load_datasets_visual" 12 | 13 | if cfg.prompt == "coarse": 14 | function_name = function_name + "_" + "coarse" 15 | 16 | exec(f"from {module_name} import {function_name}") 17 | func = eval(function_name) 18 | return func 19 | -------------------------------------------------------------------------------- /datasets/augmentation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import torch 4 | # import kornia as K 5 | import albumentations as A 6 | import numpy as np 7 | import matplotlib.pyplot as plt 8 | from typing import List, Union 9 | from imagecorruptions import corrupt, get_corruption_names 10 | 11 | weak_transforms = A.Compose( 12 | [A.Flip()], 13 | bbox_params=A.BboxParams(format="pascal_voc", label_fields=["category_ids"]), 14 | # keypoint_params=A.KeypointParams(format='xy') 15 | ) 16 | 17 | strong_transforms = A.Compose( 18 | [ 19 | A.Posterize(), 20 | A.Equalize(), 21 | A.Sharpen(), 22 | A.Solarize(), 23 | A.RandomBrightnessContrast(), 24 | A.RandomShadow(), 25 | ] 26 | ) 27 | 28 | 29 | def corrupt_image(image, filename): 30 | file_name = os.path.basename(os.path.abspath(filename)) 31 | file_path = os.path.dirname(os.path.abspath(filename)) 32 | for corruption in get_corruption_names(): 33 | corrupted = corrupt(image, severity=5, corruption_name=corruption) 34 | corrupt_path = file_path.replace( 35 | "val2017", os.path.join("corruption", corruption) 36 | ) 37 | if not os.path.exists(corrupt_path): 38 | os.makedirs(corrupt_path, exist_ok=True) 39 | cv2.imwrite(os.path.join(corrupt_path, file_name), corrupted) 40 | -------------------------------------------------------------------------------- /datasets/tools.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torchvision.transforms as transforms 6 | from segment_anything.utils.transforms import ResizeLongestSide 7 | from datasets.augmentation import weak_transforms, strong_transforms 8 | import torchvision.transforms.functional as F 9 | 10 | class ResizeAndPad: 11 | 12 | def __init__(self, target_size): 13 | self.target_size = target_size 14 | self.transform = ResizeLongestSide(target_size) 15 | self.to_tensor = transforms.ToTensor() 16 | 17 | def __call__(self, image, masks, bboxes=None, visual=False): 18 | # Resize image and masks 19 | og_h, og_w, _ = image.shape 20 | image = self.transform.apply_image(image) 21 | masks = [torch.tensor(self.transform.apply_image(mask)) for mask in masks] 22 | image = self.to_tensor(image) 23 | 24 | # Pad image and masks to form a square 25 | _, h, w = image.shape 26 | max_dim = max(w, h) 27 | pad_w = (max_dim - w) // 2 28 | pad_h = (max_dim - h) // 2 29 | 30 | padding = (pad_w, pad_h, max_dim - w - pad_w, max_dim - h - pad_h) 31 | image = transforms.Pad(padding)(image) 32 | masks = [transforms.Pad(padding)(mask) for mask in masks] 
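        # `padding` is (left, top, right, bottom): ResizeLongestSide scales the longest side
        # to `target_size`, the image and masks are then zero-padded (split as evenly as
        # possible) into a square, and the same offsets are applied to the boxes below.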
33 | 34 | # Adjust bounding boxes 35 | if bboxes is not None: 36 | bboxes = self.transform.apply_boxes(bboxes, (og_h, og_w)) 37 | bboxes = [ 38 | [bbox[0] + pad_w, bbox[1] + pad_h, bbox[2] + pad_w, bbox[3] + pad_h] 39 | for bbox in bboxes 40 | ] 41 | if visual: 42 | return padding, image, masks, bboxes 43 | else: 44 | return image, masks, bboxes 45 | else: 46 | if visual: 47 | return padding, image, masks 48 | else: 49 | return image, masks 50 | 51 | def transform_image(self, image): 52 | # Resize image and masks 53 | image = self.transform.apply_image(image) 54 | image = self.to_tensor(image) 55 | 56 | # Pad image and masks to form a square 57 | _, h, w = image.shape 58 | max_dim = max(w, h) 59 | pad_w = (max_dim - w) // 2 60 | pad_h = (max_dim - h) // 2 61 | 62 | padding = (pad_w, pad_h, max_dim - w - pad_w, max_dim - h - pad_h) 63 | image = transforms.Pad(padding)(image) 64 | return image 65 | 66 | def transform_coord(self, points, image): 67 | og_h, og_w, _ = image.shape 68 | coords = points.reshape(1, -1, 2) 69 | points = self.transform.apply_coords(coords, (og_h, og_w)) 70 | return points.reshape(-1, 2) 71 | 72 | def transform_coords(self, points, image, n): 73 | og_h, og_w, _ = image.shape 74 | coords = points.reshape(-1, n, 2) 75 | points = self.transform.apply_coords(coords, (og_h, og_w)) 76 | return points.reshape(-1, n, 2) 77 | 78 | 79 | def soft_transform( 80 | image: np.ndarray, bboxes: list, masks: list, categories: list 81 | ): 82 | weak_transformed = weak_transforms( 83 | image=image, bboxes=bboxes, masks=masks, category_ids=categories) 84 | image_weak = weak_transformed["image"] 85 | bboxes_weak = weak_transformed["bboxes"] 86 | masks_weak = weak_transformed["masks"] 87 | 88 | strong_transformed = strong_transforms(image=image_weak) 89 | image_strong = strong_transformed["image"] 90 | return image_weak, bboxes_weak, masks_weak, image_strong 91 | 92 | 93 | def soft_transform_all( 94 | image: np.ndarray, bboxes: list, masks: list, points: list, categories: list 95 | ): 96 | weak_transformed = weak_transforms( 97 | image=image, bboxes=bboxes, masks=masks, category_ids=categories, keypoints=points) 98 | image_weak = weak_transformed["image"] 99 | bboxes_weak = weak_transformed["bboxes"] 100 | masks_weak = weak_transformed["masks"] 101 | keypoints_weak = weak_transformed["keypoints"] 102 | 103 | strong_transformed = strong_transforms(image=image_weak) 104 | image_strong = strong_transformed["image"] 105 | return image_weak, bboxes_weak, masks_weak, keypoints_weak, image_strong 106 | 107 | 108 | def collate_fn(batch): 109 | images, bboxes, masks, img_paths= zip(*batch) 110 | images = torch.stack(images) 111 | return images, bboxes, masks, img_paths 112 | 113 | 114 | def collate_fn_soft(batch): 115 | images_soft, images, bboxes, masks, img_paths = zip(*batch) 116 | images = torch.stack(images) 117 | images_soft = torch.stack(images_soft) 118 | return images_soft, images, bboxes, masks, img_paths 119 | 120 | 121 | def collate_fn_coarse(batch): 122 | images, bboxes, masks, coarse_masks = zip(*batch) 123 | images = torch.stack(images) 124 | return images, bboxes, masks, coarse_masks 125 | 126 | 127 | def collate_fn_(batch): 128 | return zip(*batch) 129 | 130 | 131 | def decode_mask(mask): 132 | """ 133 | Convert mask with shape [1, h, w] using 1, 2, 3, ... to represent different objects 134 | to a mask with shape [n, h, w] using a new dimension to represent the number of objects. 135 | 136 | Args: 137 | mask (torch.Tensor): Mask tensor with shape [1, h, w] using 1, 2, 3, ... 
to represent different objects. 138 | 139 | Returns: 140 | torch.Tensor: Mask tensor with shape [n, h, w] using a new dimension to represent the number of objects. 141 | """ 142 | unique_labels = torch.unique(mask) 143 | unique_labels = unique_labels[unique_labels != 0] 144 | n_objects = len(unique_labels) 145 | new_mask = torch.zeros((n_objects, *mask.shape[1:]), dtype=torch.int64) 146 | for i, label in enumerate(unique_labels): 147 | new_mask[i] = (mask == label).squeeze(0) 148 | return new_mask 149 | 150 | 151 | def encode_mask(mask): 152 | """ 153 | Convert mask with shape [n, h, w] using a new dimension to represent the number of objects 154 | to a mask with shape [1, h, w] using 1, 2, 3, ... to represent different objects. 155 | 156 | Args: 157 | mask (torch.Tensor): Mask tensor with shape [n, h, w] using a new dimension to represent the number of objects. 158 | 159 | Returns: 160 | torch.Tensor: Mask tensor with shape [1, h, w] using 1, 2, 3, ... to represent different objects. 161 | """ 162 | n_objects = mask.shape[0] 163 | new_mask = torch.zeros((1, *mask.shape[1:]), dtype=torch.int64) 164 | for i in range(n_objects): 165 | new_mask[0][mask[i] == 1] = i + 1 166 | return new_mask 167 | 168 | 169 | if __name__ == "__main__": 170 | mask_encode = np.array([[[0, 0, 1], [2, 0, 2], [0, 3, 3]]]) 171 | mask_decode = np.array([[[0, 0, 1], [0, 0, 0], [0, 0, 0]], 172 | [[0, 0, 0], [1, 0, 1], [0, 0, 0]], 173 | [[0, 0, 0], [0, 0, 0], [0, 1, 1]]]) 174 | encoded_mask = encode_mask(torch.tensor(mask_decode)) 175 | decoded_mask = decode_mask(torch.tensor(mask_encode)) -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import copy 4 | import torch 5 | import argparse 6 | import torch.nn.functional as F 7 | import numpy as np 8 | import segmentation_models_pytorch as smp 9 | import lightning as L 10 | from lightning.fabric.loggers import TensorBoardLogger 11 | from lightning.fabric.fabric import _FabricOptimizer 12 | import matplotlib.pyplot as plt 13 | from torch.utils.data import DataLoader 14 | from box import Box 15 | from datasets import call_load_dataset 16 | from utils.model import Model 17 | from utils.tools import copy_model, create_csv, reduce_instances 18 | from utils.eval_utils import AverageMeter 19 | from utils.sample_utils import uniform_sampling 20 | from tqdm import tqdm 21 | 22 | torch.set_float32_matmul_precision('high') 23 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 24 | 25 | def compute_centroids(masks): 26 | centroids = [] 27 | for mask in masks: 28 | if not isinstance(mask, np.ndarray): 29 | mask = mask.cpu().numpy().astype(np.uint8) 30 | 31 | # Find connected components 32 | num_labels, _, _, centroids_data = cv2.connectedComponentsWithStats(mask, connectivity=8) 33 | 34 | # Extract centroids (skip background label 0) 35 | component_centroids = centroids_data[1:] # Skip the first entry, which is for the background 36 | 37 | # Append centroids as a list 38 | centroids.append(component_centroids.tolist()) 39 | 40 | return centroids 41 | 42 | def plotwithpoint(fabric: L.Fabric, anchor_model: Model, model: Model, val_dataloader: DataLoader): 43 | model.eval() 44 | anchor_model.eval() 45 | ious = AverageMeter() 46 | f1_scores = AverageMeter() 47 | num_points = cfg.num_points 48 | transform = val_dataloader.dataset.transform 49 | 50 | save_path = cfg.out_dir 51 | if not os.path.exists(save_path): 52 | os.makedirs(save_path) 53 
| 54 | with torch.no_grad(): 55 | for iter, data in enumerate(tqdm(val_dataloader, desc="Processing", unit="batch")): 56 | image_ids, paddings, ori_images, ori_bboxes, origin_masks, images, bboxes, gt_masks = data 57 | images = torch.stack(images).to(device=fabric.device) 58 | num_images = images.size(0) 59 | 60 | prompts = [] 61 | for mask in gt_masks: 62 | try: 63 | po_points = compute_centroids(mask) 64 | po_point_coords = torch.tensor(po_points, device=fabric.device) 65 | except: 66 | continue 67 | na_points = uniform_sampling((~mask.to(bool)).to(float), num_points) 68 | na_point_coords = torch.tensor(na_points, device=fabric.device) 69 | point_coords = torch.cat((po_point_coords, na_point_coords), dim=1) 70 | po_point_labels = torch.ones(po_point_coords.shape[:2], dtype=torch.int, device=fabric.device) 71 | na_point_labels = torch.zeros(na_point_coords.shape[:2], dtype=torch.int, device=fabric.device) 72 | point_labels = torch.cat((po_point_labels, na_point_labels), dim=1) 73 | in_points = (point_coords, point_labels) 74 | prompts.append(in_points) 75 | 76 | _, base_masks, _, _ = anchor_model(images, prompts) 77 | _, pred_masks, _, _ = model(images, prompts) 78 | 79 | draw_points = [] 80 | for ori_mask in origin_masks: 81 | ori_po_points = uniform_sampling(ori_mask, num_points) 82 | draw_points.append(ori_po_points) 83 | 84 | # for pred_mask, gt_mask in zip(pred_masks, gt_masks): 85 | # batch_stats = smp.metrics.get_stats( 86 | # pred_mask, 87 | # gt_mask.to(device=fabric.device).int(), 88 | # mode='binary', 89 | # threshold=0.5, 90 | # ) 91 | # batch_iou = smp.metrics.iou_score(*batch_stats, reduction="micro-imagewise") 92 | # batch_f1 = smp.metrics.f1_score(*batch_stats, reduction="micro-imagewise") 93 | # ious.update(batch_iou, num_images) 94 | # f1_scores.update(batch_f1, num_images) 95 | # fabric.print( 96 | # f'Val:[{iter}/{len(val_dataloader)}]: Mean IoU: [{ious.avg:.4f}] -- Mean F1: [{f1_scores.avg:.4f}]' 97 | # ) 98 | # torch.cuda.empty_cache() 99 | 100 | for image_id, padding, base_mask, pred_mask, ori_mask, points, image in zip(image_ids, paddings, base_masks, pred_masks, origin_masks, draw_points, ori_images): 101 | H, W, C = image.shape 102 | 103 | base_mask = base_mask.unsqueeze(1) 104 | base_mask = base_mask[..., padding[1] : base_mask.shape[-2] - padding[3], padding[0] : base_mask.shape[-1] - padding[2]] 105 | base_mask = F.interpolate(base_mask, (H, W), mode="bilinear", align_corners=False) 106 | 107 | pred_mask = pred_mask.unsqueeze(1) 108 | pred_mask = pred_mask[..., padding[1] : pred_mask.shape[-2] - padding[3], padding[0] : pred_mask.shape[-1] - padding[2]] 109 | pred_mask = F.interpolate(pred_mask, (H, W), mode="bilinear", align_corners=False) 110 | 111 | fig, axs = plt.subplots(1, 4) 112 | fig.set_size_inches(W/100.0*4, H/100.0) 113 | 114 | image_0 = copy.deepcopy(image) 115 | image_1 = copy.deepcopy(image) 116 | image_2 = copy.deepcopy(image) 117 | image_3 = copy.deepcopy(image) 118 | axs[0].imshow(image_0) 119 | axs[1].imshow(image_1) 120 | axs[2].imshow(image_2) 121 | axs[3].imshow(image_3) 122 | axs[0].axis('off') 123 | axs[1].axis('off') 124 | axs[2].axis('off') 125 | axs[3].axis('off') 126 | 127 | masked_image_1 = np.zeros((H, W, 4)) 128 | masked_image_2 = np.zeros((H, W, 4)) 129 | masked_image_3 = np.zeros((H, W, 4)) 130 | for point, ori_mask_i, base_mask_i, pred_mask_i in zip(points, ori_mask, base_mask, pred_mask): 131 | color = np.random.random(3) 132 | x_coords = [] 133 | y_coords = [] 134 | for point_i in point: 135 | x, y = point_i 136 | 
x_coords.append(x) 137 | y_coords.append(y) 138 | point_color = np.concatenate([color, [1.0]]) 139 | axs[0].scatter(x_coords, y_coords, color=point_color) 140 | 141 | base_mask_i = (base_mask_i.squeeze(0) > 0.).cpu().numpy().astype(bool) 142 | pred_mask_i = (pred_mask_i.squeeze(0) > 0.).cpu().numpy().astype(bool) 143 | ori_mask_i = ori_mask_i.astype(bool) 144 | mask_color = np.concatenate([color, [0.7]]) 145 | 146 | masked_image_1[ori_mask_i] = mask_color 147 | axs[1].imshow(masked_image_1) 148 | masked_image_2[base_mask_i] = mask_color 149 | axs[2].imshow(masked_image_2) 150 | masked_image_3[pred_mask_i] = mask_color 151 | axs[3].imshow(masked_image_3) 152 | 153 | plt.subplots_adjust(wspace=0) 154 | plt.savefig(os.path.join(save_path, f"{image_id}.jpg"), dpi=100, bbox_inches='tight', pad_inches=0) 155 | plt.close(fig) 156 | 157 | def main(cfg: Box, args) -> None: 158 | gpu_ids = [str(i) for i in range(torch.cuda.device_count())] 159 | num_devices = len(gpu_ids) 160 | 161 | fabric = L.Fabric(accelerator="auto", 162 | devices=num_devices, 163 | strategy="auto", 164 | loggers=[TensorBoardLogger(cfg.out_dir)]) 165 | fabric.launch() 166 | fabric.seed_everything(1337 + fabric.global_rank) 167 | 168 | with fabric.device: 169 | anchor_model = Model(cfg) 170 | anchor_model.setup() 171 | 172 | model = Model(cfg) 173 | model.setup() 174 | full_checkpoint = fabric.load(args.ckpt) 175 | model.load_state_dict(full_checkpoint["model"]) 176 | 177 | load_datasets = call_load_dataset(cfg) 178 | val_data = load_datasets(cfg, model.model.image_encoder.img_size) 179 | val_data = fabric._setup_dataloader(val_data) 180 | 181 | plotwithpoint(fabric, anchor_model, model, val_data) 182 | 183 | def parse_args(): 184 | 185 | parser = argparse.ArgumentParser(description='Test a detector with specified config and checkpoint.') 186 | parser.add_argument( 187 | '--cfg', 188 | type=str, 189 | default="configs.config_hrsid", 190 | help='Path to the configuration file (e.g., "configs.config_hrsid").' 191 | ) 192 | parser.add_argument( 193 | '--out_dir', 194 | type=str, 195 | default="output", 196 | help='Directory to save predicted results.' 197 | ) 198 | parser.add_argument( 199 | '--ckpt', 200 | type=str, 201 | default="checkpoints/best-ckpt.pth", 202 | help='Path to the model checkpoint file.' 
203 | ) 204 | 205 | args = parser.parse_args() 206 | return args 207 | 208 | if __name__ == "__main__": 209 | torch.cuda.empty_cache() 210 | args = parse_args() 211 | args_dict = vars(args) 212 | exec(f'from {args.cfg} import cfg') 213 | cfg.merge_update(args_dict) 214 | cfg.visual = True 215 | cfg.load_type = 'visual' 216 | main(cfg, args) -------------------------------------------------------------------------------- /pretrain/Where_To_Save_Pretrained_SAM_Checkpoints: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lans1ng/PointSAM/d6831edf23de4bd7304fb46a5ce2a0baa69dcbd5/pretrain/Where_To_Save_Pretrained_SAM_Checkpoints -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | python-box==7.0.1 2 | pycocotools==2.0.6 3 | numpy==1.26.4 4 | opencv_python==4.7.0.72 5 | Pillow>=9.4.0 6 | lightning==2.0.1 7 | segmentation-models-pytorch==0.3.2 8 | albumentations==1.3.1 9 | imagecorruptions==1.1.2 10 | safetensors==0.4.1 11 | torchsummary==1.5.1 12 | tensorboard==2.14.0 13 | pandas 14 | -------------------------------------------------------------------------------- /scripts/train_hrsid_pointsam.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cfg_file="configs.config_hrsid" 4 | prompt="point" 5 | load_type="soft" 6 | num_points_list=(1 2 3) 7 | output_dirs=("work_dir/hrsid/pointsam") 8 | 9 | for output_dir in "${output_dirs[@]}"; do 10 | for num_points in "${num_points_list[@]}"; do 11 | out_dir="${output_dir}/point_${num_points}" 12 | CUDA_VISIBLE_DEVICES=0 python train_pointsam.py --cfg "$cfg_file" --prompt "$prompt" --num_points "$num_points" --out_dir "$out_dir" --load_type "$load_type" 13 | done 14 | done 15 | -------------------------------------------------------------------------------- /scripts/train_hrsid_selftrain.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cfg_file="configs.config_hrsid" 4 | prompt="point" 5 | load_type="soft" 6 | num_points_list=(1 2 3) 7 | output_dirs=("work_dir/hrsid/selftrain") 8 | 9 | for output_dir in "${output_dirs[@]}"; do 10 | for num_points in "${num_points_list[@]}"; do 11 | out_dir="${output_dir}/point_${num_points}" 12 | CUDA_VISIBLE_DEVICES=0 python train_selftrain.py --cfg "$cfg_file" --prompt "$prompt" --num_points "$num_points" --out_dir "$out_dir" --load_type "$load_type" 13 | done 14 | done 15 | -------------------------------------------------------------------------------- /scripts/train_hrsid_supervise.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cfg_file="configs.config_hrsid" 4 | prompt="point" 5 | load_type="load" 6 | num_points_list=(1 2 3) 7 | output_dirs=("work_dir/hrsid/supervise") 8 | 9 | for output_dir in "${output_dirs[@]}"; do 10 | for num_points in "${num_points_list[@]}"; do 11 | out_dir="${output_dir}/point_${num_points}" 12 | CUDA_VISIBLE_DEVICES=0 python train_supervise.py --cfg "$cfg_file" --prompt "$prompt" --num_points "$num_points" --out_dir "$out_dir" --load_type "$load_type" 13 | done 14 | done 15 | -------------------------------------------------------------------------------- /scripts/train_nwpu_pointsam.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 
cfg_file="configs.config_nwpu" 4 | prompt="point" 5 | load_type="soft" 6 | num_points_list=(1 2 3) 7 | output_dirs=("work_dir/nwpu/pointsam") 8 | 9 | for output_dir in "${output_dirs[@]}"; do 10 | for num_points in "${num_points_list[@]}"; do 11 | out_dir="${output_dir}/point_${num_points}" 12 | CUDA_VISIBLE_DEVICES=0 python train_pointsam.py --cfg "$cfg_file" --prompt "$prompt" --num_points "$num_points" --out_dir "$out_dir" --load_type "$load_type" 13 | done 14 | done 15 | -------------------------------------------------------------------------------- /scripts/train_nwpu_selftrain.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cfg_file="configs.config_nwpu" 4 | prompt="point" 5 | load_type="soft" 6 | num_points_list=(1 2 3) 7 | output_dirs=("work_dir/nwpu/selftrain") 8 | 9 | for output_dir in "${output_dirs[@]}"; do 10 | for num_points in "${num_points_list[@]}"; do 11 | out_dir="${output_dir}/point_${num_points}" 12 | CUDA_VISIBLE_DEVICES=0 python train_selftrain.py --cfg "$cfg_file" --prompt "$prompt" --num_points "$num_points" --out_dir "$out_dir" --load_type "$load_type" 13 | done 14 | done 15 | -------------------------------------------------------------------------------- /scripts/train_nwpu_supervise.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cfg_file="configs.config_nwpu" 4 | prompt="point" 5 | load_type="load" 6 | num_points_list=(1 2 3) 7 | output_dirs=("work_dir/nwpu/supervise") 8 | 9 | for output_dir in "${output_dirs[@]}"; do 10 | for num_points in "${num_points_list[@]}"; do 11 | out_dir="${output_dir}/point_${num_points}" 12 | CUDA_VISIBLE_DEVICES=0 python train_supervise.py --cfg "$cfg_file" --prompt "$prompt" --num_points "$num_points" --out_dir "$out_dir" --load_type "$load_type" 13 | done 14 | done 15 | -------------------------------------------------------------------------------- /scripts/train_whu_pointsam.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cfg_file="configs.config_whu" 4 | prompt="point" 5 | load_type="soft" 6 | num_points_list=(1 2 3) 7 | output_dirs=("work_dir/whu/pointsam") 8 | 9 | for output_dir in "${output_dirs[@]}"; do 10 | for num_points in "${num_points_list[@]}"; do 11 | out_dir="${output_dir}/point_${num_points}" 12 | CUDA_VISIBLE_DEVICES=0 python train_pointsam.py --cfg "$cfg_file" --prompt "$prompt" --num_points "$num_points" --out_dir "$out_dir" --load_type "$load_type" 13 | done 14 | done 15 | -------------------------------------------------------------------------------- /scripts/train_whu_selftrain.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cfg_file="configs.config_whu" 4 | prompt="point" 5 | load_type="soft" 6 | num_points_list=(1 2 3) 7 | output_dirs=("work_dir/whu/selftrain") 8 | 9 | for output_dir in "${output_dirs[@]}"; do 10 | for num_points in "${num_points_list[@]}"; do 11 | out_dir="${output_dir}/point_${num_points}" 12 | CUDA_VISIBLE_DEVICES=0 python train_selftrain.py --cfg "$cfg_file" --prompt "$prompt" --num_points "$num_points" --out_dir "$out_dir" --load_type "$load_type" 13 | done 14 | done 15 | -------------------------------------------------------------------------------- /scripts/train_whu_supervise.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cfg_file="configs.config_whu" 4 | prompt="point" 5 |
5 | load_type="load" 6 | num_points_list=(1 2 3) 7 | output_dirs=("work_dir/whu/supervise") 8 | 9 | for output_dir in "${output_dirs[@]}"; do 10 | for num_points in "${num_points_list[@]}"; do 11 | out_dir="${output_dir}/point_${num_points}" 12 | CUDA_VISIBLE_DEVICES=0 python train_supervise.py --cfg "$cfg_file" --prompt "$prompt" --num_points "$num_points" --out_dir "$out_dir" --load_type "$load_type" 13 | done 14 | done 15 | -------------------------------------------------------------------------------- /segment_anything/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .build_sam import ( 8 | build_sam, 9 | build_sam_vit_h, 10 | build_sam_vit_l, 11 | build_sam_vit_b, 12 | sam_model_registry, 13 | ) 14 | from .predictor import SamPredictor 15 | from .automatic_mask_generator import SamAutomaticMaskGenerator 16 | -------------------------------------------------------------------------------- /segment_anything/build_sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | 9 | from functools import partial 10 | 11 | from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer 12 | 13 | 14 | def build_sam_vit_h(checkpoint=None): 15 | return _build_sam( 16 | encoder_embed_dim=1280, 17 | encoder_depth=32, 18 | encoder_num_heads=16, 19 | encoder_global_attn_indexes=[7, 15, 23, 31], 20 | checkpoint=checkpoint, 21 | ) 22 | 23 | 24 | build_sam = build_sam_vit_h 25 | 26 | 27 | def build_sam_vit_l(checkpoint=None): 28 | return _build_sam( 29 | encoder_embed_dim=1024, 30 | encoder_depth=24, 31 | encoder_num_heads=16, 32 | encoder_global_attn_indexes=[5, 11, 17, 23], 33 | checkpoint=checkpoint, 34 | ) 35 | 36 | 37 | def build_sam_vit_b(checkpoint=None): 38 | return _build_sam( 39 | encoder_embed_dim=768, 40 | encoder_depth=12, 41 | encoder_num_heads=12, 42 | encoder_global_attn_indexes=[2, 5, 8, 11], 43 | checkpoint=checkpoint, 44 | ) 45 | 46 | 47 | sam_model_registry = { 48 | "default": build_sam_vit_h, 49 | "vit_h": build_sam_vit_h, 50 | "vit_l": build_sam_vit_l, 51 | "vit_b": build_sam_vit_b, 52 | } 53 | 54 | 55 | def _build_sam( 56 | encoder_embed_dim, 57 | encoder_depth, 58 | encoder_num_heads, 59 | encoder_global_attn_indexes, 60 | checkpoint=None, 61 | ): 62 | prompt_embed_dim = 256 63 | image_size = 1024 64 | vit_patch_size = 16 65 | image_embedding_size = image_size // vit_patch_size 66 | sam = Sam( 67 | image_encoder=ImageEncoderViT( 68 | depth=encoder_depth, 69 | embed_dim=encoder_embed_dim, 70 | img_size=image_size, 71 | mlp_ratio=4, 72 | norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), 73 | num_heads=encoder_num_heads, 74 | patch_size=vit_patch_size, 75 | qkv_bias=True, 76 | use_rel_pos=True, 77 | global_attn_indexes=encoder_global_attn_indexes, 78 | window_size=14, 79 | out_chans=prompt_embed_dim, 80 | ), 81 | prompt_encoder=PromptEncoder( 82 | embed_dim=prompt_embed_dim, 83 | image_embedding_size=(image_embedding_size, image_embedding_size), 84 | input_image_size=(image_size, image_size), 85 | mask_in_chans=16, 86 | ), 87
| mask_decoder=MaskDecoder( 88 | num_multimask_outputs=3, 89 | transformer=TwoWayTransformer( 90 | depth=2, 91 | embedding_dim=prompt_embed_dim, 92 | mlp_dim=2048, 93 | num_heads=8, 94 | ), 95 | transformer_dim=prompt_embed_dim, 96 | iou_head_depth=3, 97 | iou_head_hidden_dim=256, 98 | ), 99 | pixel_mean=[123.675, 116.28, 103.53], 100 | pixel_std=[58.395, 57.12, 57.375], 101 | ) 102 | sam.eval() 103 | if checkpoint is not None: 104 | with open(checkpoint, "rb") as f: 105 | state_dict = torch.load(f) 106 | # sam.load_state_dict(state_dict) 107 | # Create a new state dictionary with only the parameters that exist in the model 108 | new_state_dict = {k: v for k, v in state_dict.items() if k in sam.state_dict() and sam.state_dict()[k].shape == v.shape} 109 | sam.load_state_dict(new_state_dict, strict = False) 110 | return sam 111 | -------------------------------------------------------------------------------- /segment_anything/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .sam import Sam 8 | from .image_encoder import ImageEncoderViT 9 | # from .image_encoder_adapter import ImageEncoderViT 10 | from .mask_decoder import MaskDecoder 11 | from .prompt_encoder import PromptEncoder 12 | from .transformer import TwoWayTransformer 13 | -------------------------------------------------------------------------------- /segment_anything/modeling/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | from typing import Type 11 | 12 | 13 | class MLPBlock(nn.Module): 14 | def __init__( 15 | self, 16 | embedding_dim: int, 17 | mlp_dim: int, 18 | act: Type[nn.Module] = nn.GELU, 19 | ) -> None: 20 | super().__init__() 21 | self.lin1 = nn.Linear(embedding_dim, mlp_dim) 22 | self.lin2 = nn.Linear(mlp_dim, embedding_dim) 23 | self.act = act() 24 | 25 | def forward(self, x: torch.Tensor) -> torch.Tensor: 26 | return self.lin2(self.act(self.lin1(x))) 27 | 28 | 29 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa 30 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa 31 | class LayerNorm2d(nn.Module): 32 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None: 33 | super().__init__() 34 | self.weight = nn.Parameter(torch.ones(num_channels)) 35 | self.bias = nn.Parameter(torch.zeros(num_channels)) 36 | self.eps = eps 37 | 38 | def forward(self, x: torch.Tensor) -> torch.Tensor: 39 | u = x.mean(1, keepdim=True) 40 | s = (x - u).pow(2).mean(1, keepdim=True) 41 | x = (x - u) / torch.sqrt(s + self.eps) 42 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 43 | return x 44 | -------------------------------------------------------------------------------- /segment_anything/modeling/mask_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | 11 | from typing import List, Tuple, Type 12 | 13 | from .common import LayerNorm2d 14 | 15 | 16 | class MaskDecoder(nn.Module): 17 | def __init__( 18 | self, 19 | *, 20 | transformer_dim: int, 21 | transformer: nn.Module, 22 | num_multimask_outputs: int = 3, 23 | activation: Type[nn.Module] = nn.GELU, 24 | iou_head_depth: int = 3, 25 | iou_head_hidden_dim: int = 256, 26 | ) -> None: 27 | """ 28 | Predicts masks given an image and prompt embeddings, using a 29 | transformer architecture. 30 | 31 | Arguments: 32 | transformer_dim (int): the channel dimension of the transformer 33 | transformer (nn.Module): the transformer used to predict masks 34 | num_multimask_outputs (int): the number of masks to predict 35 | when disambiguating masks 36 | activation (nn.Module): the type of activation to use when 37 | upscaling masks 38 | iou_head_depth (int): the depth of the MLP used to predict 39 | mask quality 40 | iou_head_hidden_dim (int): the hidden dimension of the MLP 41 | used to predict mask quality 42 | """ 43 | super().__init__() 44 | self.transformer_dim = transformer_dim 45 | self.transformer = transformer 46 | 47 | self.num_multimask_outputs = num_multimask_outputs 48 | 49 | self.iou_token = nn.Embedding(1, transformer_dim) 50 | self.num_mask_tokens = num_multimask_outputs + 1 51 | self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) 52 | 53 | self.output_upscaling = nn.Sequential( 54 | nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), 55 | LayerNorm2d(transformer_dim // 4), 56 | activation(), 57 | nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), 58 | activation(), 59 | ) 60 | self.output_hypernetworks_mlps = nn.ModuleList( 61 | [ 62 | MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) 63 | for i in range(self.num_mask_tokens) 64 | ] 65 | ) 66 | 67 | self.iou_prediction_head = MLP( 68 | transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth 69 | ) 70 | 71 | def forward( 72 | self, 73 | image_embeddings: torch.Tensor, 74 | image_pe: torch.Tensor, 75 | sparse_prompt_embeddings: torch.Tensor, 76 | dense_prompt_embeddings: torch.Tensor, 77 | multimask_output: bool, 78 | ) -> Tuple[torch.Tensor, torch.Tensor]: 79 | """ 80 | Predict masks given image and prompt embeddings. 81 | 82 | Arguments: 83 | image_embeddings (torch.Tensor): the embeddings from the image encoder 84 | image_pe (torch.Tensor): positional encoding with the shape of image_embeddings 85 | sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes 86 | dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs 87 | multimask_output (bool): Whether to return multiple masks or a single 88 | mask. 
89 | 90 | Returns: 91 | torch.Tensor: batched predicted masks 92 | torch.Tensor: batched predictions of mask quality 93 | """ 94 | masks, iou_pred = self.predict_masks( 95 | image_embeddings=image_embeddings, 96 | image_pe=image_pe, 97 | sparse_prompt_embeddings=sparse_prompt_embeddings, 98 | dense_prompt_embeddings=dense_prompt_embeddings, 99 | ) 100 | 101 | # Select the correct mask or masks for output 102 | if multimask_output: 103 | mask_slice = slice(1, None) 104 | else: 105 | mask_slice = slice(0, 1) 106 | masks = masks[:, mask_slice, :, :] 107 | iou_pred = iou_pred[:, mask_slice] 108 | 109 | # Prepare output 110 | return masks, iou_pred 111 | 112 | def predict_masks( 113 | self, 114 | image_embeddings: torch.Tensor, 115 | image_pe: torch.Tensor, 116 | sparse_prompt_embeddings: torch.Tensor, 117 | dense_prompt_embeddings: torch.Tensor, 118 | ) -> Tuple[torch.Tensor, torch.Tensor]: 119 | """Predicts masks. See 'forward' for more details.""" 120 | # Concatenate output tokens 121 | output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) 122 | output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1) 123 | tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) 124 | 125 | # Expand per-image data in batch direction to be per-mask 126 | src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) 127 | src = src + dense_prompt_embeddings 128 | pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) 129 | b, c, h, w = src.shape 130 | 131 | # Run the transformer 132 | hs, src = self.transformer(src, pos_src, tokens) 133 | iou_token_out = hs[:, 0, :] 134 | mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :] 135 | 136 | # Upscale mask embeddings and predict masks using the mask tokens 137 | src = src.transpose(1, 2).view(b, c, h, w) 138 | upscaled_embedding = self.output_upscaling(src) 139 | hyper_in_list: List[torch.Tensor] = [] 140 | for i in range(self.num_mask_tokens): 141 | hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])) 142 | hyper_in = torch.stack(hyper_in_list, dim=1) 143 | b, c, h, w = upscaled_embedding.shape 144 | masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) 145 | 146 | # Generate mask quality predictions 147 | iou_pred = self.iou_prediction_head(iou_token_out) 148 | 149 | return masks, iou_pred 150 | 151 | 152 | # Lightly adapted from 153 | # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa 154 | class MLP(nn.Module): 155 | def __init__( 156 | self, 157 | input_dim: int, 158 | hidden_dim: int, 159 | output_dim: int, 160 | num_layers: int, 161 | sigmoid_output: bool = False, 162 | ) -> None: 163 | super().__init__() 164 | self.num_layers = num_layers 165 | h = [hidden_dim] * (num_layers - 1) 166 | self.layers = nn.ModuleList( 167 | nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) 168 | ) 169 | self.sigmoid_output = sigmoid_output 170 | 171 | def forward(self, x): 172 | for i, layer in enumerate(self.layers): 173 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) 174 | if self.sigmoid_output: 175 | x = F.sigmoid(x) 176 | return x 177 | -------------------------------------------------------------------------------- /segment_anything/modeling/prompt_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | from torch import nn 10 | 11 | from typing import Any, Optional, Tuple, Type 12 | 13 | from .common import LayerNorm2d 14 | 15 | 16 | class PromptEncoder(nn.Module): 17 | def __init__( 18 | self, 19 | embed_dim: int, 20 | image_embedding_size: Tuple[int, int], 21 | input_image_size: Tuple[int, int], 22 | mask_in_chans: int, 23 | activation: Type[nn.Module] = nn.GELU, 24 | ) -> None: 25 | """ 26 | Encodes prompts for input to SAM's mask decoder. 27 | 28 | Arguments: 29 | embed_dim (int): The prompts' embedding dimension 30 | image_embedding_size (tuple(int, int)): The spatial size of the 31 | image embedding, as (H, W). 32 | input_image_size (int): The padded size of the image as input 33 | to the image encoder, as (H, W). 34 | mask_in_chans (int): The number of hidden channels used for 35 | encoding input masks. 36 | activation (nn.Module): The activation to use when encoding 37 | input masks. 38 | """ 39 | super().__init__() 40 | self.embed_dim = embed_dim 41 | self.input_image_size = input_image_size 42 | self.image_embedding_size = image_embedding_size 43 | self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) 44 | 45 | self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners 46 | point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)] 47 | self.point_embeddings = nn.ModuleList(point_embeddings) 48 | self.not_a_point_embed = nn.Embedding(1, embed_dim) 49 | 50 | self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1]) 51 | self.mask_downscaling = nn.Sequential( 52 | nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), 53 | LayerNorm2d(mask_in_chans // 4), 54 | activation(), 55 | nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), 56 | LayerNorm2d(mask_in_chans), 57 | activation(), 58 | nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), 59 | ) 60 | self.no_mask_embed = nn.Embedding(1, embed_dim) 61 | 62 | def get_dense_pe(self) -> torch.Tensor: 63 | """ 64 | Returns the positional encoding used to encode point prompts, 65 | applied to a dense set of points the shape of the image encoding. 
66 | 67 | Returns: 68 | torch.Tensor: Positional encoding with shape 69 | 1x(embed_dim)x(embedding_h)x(embedding_w) 70 | """ 71 | return self.pe_layer(self.image_embedding_size).unsqueeze(0) 72 | 73 | def _embed_points( 74 | self, 75 | points: torch.Tensor, 76 | labels: torch.Tensor, 77 | pad: bool, 78 | ) -> torch.Tensor: 79 | """Embeds point prompts.""" 80 | points = points + 0.5 # Shift to center of pixel 81 | if pad: 82 | padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) 83 | padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) 84 | # print('points.shape',points.shape) 85 | # print('padding_point',padding_point.shape) 86 | # print('point',points) 87 | # print('padding_point',padding_point) 88 | points = torch.cat([points, padding_point], dim=1) 89 | labels = torch.cat([labels, padding_label], dim=1) 90 | point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size) 91 | point_embedding[labels == -1] = 0.0 92 | point_embedding[labels == -1] += self.not_a_point_embed.weight 93 | point_embedding[labels == 0] += self.point_embeddings[0].weight 94 | point_embedding[labels == 1] += self.point_embeddings[1].weight 95 | return point_embedding 96 | 97 | def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: 98 | """Embeds box prompts.""" 99 | boxes = boxes + 0.5 # Shift to center of pixel 100 | coords = boxes.reshape(-1, 2, 2) 101 | corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size) 102 | corner_embedding[:, 0, :] += self.point_embeddings[2].weight 103 | corner_embedding[:, 1, :] += self.point_embeddings[3].weight 104 | return corner_embedding 105 | 106 | def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: 107 | """Embeds mask inputs.""" 108 | mask_embedding = self.mask_downscaling(masks) 109 | return mask_embedding 110 | 111 | def _get_batch_size( 112 | self, 113 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 114 | boxes: Optional[torch.Tensor], 115 | masks: Optional[torch.Tensor], 116 | ) -> int: 117 | """ 118 | Gets the batch size of the output given the batch size of the input prompts. 119 | """ 120 | if points is not None: 121 | return points[0].shape[0] 122 | elif boxes is not None: 123 | return boxes.shape[0] 124 | elif masks is not None: 125 | return masks.shape[0] 126 | else: 127 | return 1 128 | 129 | def _get_device(self) -> torch.device: 130 | return self.point_embeddings[0].weight.device 131 | 132 | def forward( 133 | self, 134 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 135 | boxes: Optional[torch.Tensor], 136 | masks: Optional[torch.Tensor], 137 | ) -> Tuple[torch.Tensor, torch.Tensor]: 138 | """ 139 | Embeds different types of prompts, returning both sparse and dense 140 | embeddings. 141 | 142 | Arguments: 143 | points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates 144 | and labels to embed. 145 | boxes (torch.Tensor or none): boxes to embed 146 | masks (torch.Tensor or none): masks to embed 147 | 148 | Returns: 149 | torch.Tensor: sparse embeddings for the points and boxes, with shape 150 | BxNx(embed_dim), where N is determined by the number of input points 151 | and boxes. 
152 | torch.Tensor: dense embeddings for the masks, in the shape 153 | Bx(embed_dim)x(embed_H)x(embed_W) 154 | """ 155 | bs = self._get_batch_size(points, boxes, masks) 156 | sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device()) 157 | if points is not None: 158 | coords, labels = points 159 | point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) 160 | sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) 161 | if boxes is not None: 162 | box_embeddings = self._embed_boxes(boxes) 163 | sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) 164 | 165 | if masks is not None: 166 | dense_embeddings = self._embed_masks(masks) 167 | else: 168 | dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( 169 | bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] 170 | ) 171 | 172 | return sparse_embeddings, dense_embeddings 173 | 174 | 175 | class PositionEmbeddingRandom(nn.Module): 176 | """ 177 | Positional encoding using random spatial frequencies. 178 | """ 179 | 180 | def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: 181 | super().__init__() 182 | if scale is None or scale <= 0.0: 183 | scale = 1.0 184 | self.register_buffer( 185 | "positional_encoding_gaussian_matrix", 186 | scale * torch.randn((2, num_pos_feats)), 187 | ) 188 | 189 | def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: 190 | """Positionally encode points that are normalized to [0,1].""" 191 | # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape 192 | coords = 2 * coords - 1 193 | coords = coords @ self.positional_encoding_gaussian_matrix 194 | coords = 2 * np.pi * coords 195 | # outputs d_1 x ... x d_n x C shape 196 | return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) 197 | 198 | def forward(self, size: Tuple[int, int]) -> torch.Tensor: 199 | """Generate positional encoding for a grid of the specified size.""" 200 | h, w = size 201 | device: Any = self.positional_encoding_gaussian_matrix.device 202 | grid = torch.ones((h, w), device=device, dtype=torch.float32) 203 | y_embed = grid.cumsum(dim=0) - 0.5 204 | x_embed = grid.cumsum(dim=1) - 0.5 205 | y_embed = y_embed / h 206 | x_embed = x_embed / w 207 | 208 | pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) 209 | return pe.permute(2, 0, 1) # C x H x W 210 | 211 | def forward_with_coords( 212 | self, coords_input: torch.Tensor, image_size: Tuple[int, int] 213 | ) -> torch.Tensor: 214 | """Positionally encode points that are not normalized to [0,1].""" 215 | coords = coords_input.clone() 216 | coords[:, :, 0] = coords[:, :, 0] / image_size[1] 217 | coords[:, :, 1] = coords[:, :, 1] / image_size[0] 218 | return self._pe_encoding(coords.to(torch.float)) # B x N x C 219 | -------------------------------------------------------------------------------- /segment_anything/modeling/sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | 11 | from typing import Any, Dict, List, Tuple 12 | 13 | from .image_encoder import ImageEncoderViT 14 | from .mask_decoder import MaskDecoder 15 | from .prompt_encoder import PromptEncoder 16 | 17 | 18 | class Sam(nn.Module): 19 | mask_threshold: float = 0.0 20 | image_format: str = "RGB" 21 | 22 | def __init__( 23 | self, 24 | image_encoder: ImageEncoderViT, 25 | prompt_encoder: PromptEncoder, 26 | mask_decoder: MaskDecoder, 27 | pixel_mean: List[float] = [123.675, 116.28, 103.53], 28 | pixel_std: List[float] = [58.395, 57.12, 57.375], 29 | ) -> None: 30 | """ 31 | SAM predicts object masks from an image and input prompts. 32 | 33 | Arguments: 34 | image_encoder (ImageEncoderViT): The backbone used to encode the 35 | image into image embeddings that allow for efficient mask prediction. 36 | prompt_encoder (PromptEncoder): Encodes various types of input prompts. 37 | mask_decoder (MaskDecoder): Predicts masks from the image embeddings 38 | and encoded prompts. 39 | pixel_mean (list(float)): Mean values for normalizing pixels in the input image. 40 | pixel_std (list(float)): Std values for normalizing pixels in the input image. 41 | """ 42 | super().__init__() 43 | self.image_encoder = image_encoder 44 | self.prompt_encoder = prompt_encoder 45 | self.mask_decoder = mask_decoder 46 | self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) 47 | self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) 48 | 49 | @property 50 | def device(self) -> Any: 51 | return self.pixel_mean.device 52 | 53 | @torch.no_grad() 54 | def forward( 55 | self, 56 | batched_input: List[Dict[str, Any]], 57 | multimask_output: bool, 58 | ) -> List[Dict[str, torch.Tensor]]: 59 | """ 60 | Predicts masks end-to-end from provided images and prompts. 61 | If prompts are not known in advance, using SamPredictor is 62 | recommended over calling the model directly. 63 | 64 | Arguments: 65 | batched_input (list(dict)): A list over input images, each a 66 | dictionary with the following keys. A prompt key can be 67 | excluded if it is not present. 68 | 'image': The image as a torch tensor in 3xHxW format, 69 | already transformed for input to the model. 70 | 'original_size': (tuple(int, int)) The original size of 71 | the image before transformation, as (H, W). 72 | 'point_coords': (torch.Tensor) Batched point prompts for 73 | this image, with shape BxNx2. Already transformed to the 74 | input frame of the model. 75 | 'point_labels': (torch.Tensor) Batched labels for point prompts, 76 | with shape BxN. 77 | 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. 78 | Already transformed to the input frame of the model. 79 | 'mask_inputs': (torch.Tensor) Batched mask inputs to the model, 80 | in the form Bx1xHxW. 81 | multimask_output (bool): Whether the model should predict multiple 82 | disambiguating masks, or return a single mask. 83 | 84 | Returns: 85 | (list(dict)): A list over input images, where each element is 86 | as dictionary with the following keys. 87 | 'masks': (torch.Tensor) Batched binary mask predictions, 88 | with shape BxCxHxW, where B is the number of input prompts, 89 | C is determined by multimask_output, and (H, W) is the 90 | original size of the image. 91 | 'iou_predictions': (torch.Tensor) The model's predictions 92 | of mask quality, in shape BxC. 93 | 'low_res_logits': (torch.Tensor) Low resolution logits with 94 | shape BxCxHxW, where H=W=256. 
Can be passed as mask input 95 | to subsequent iterations of prediction. 96 | """ 97 | input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0) 98 | image_embeddings = self.image_encoder(input_images) 99 | 100 | outputs = [] 101 | for image_record, curr_embedding in zip(batched_input, image_embeddings): 102 | if "point_coords" in image_record: 103 | points = (image_record["point_coords"], image_record["point_labels"]) 104 | else: 105 | points = None 106 | sparse_embeddings, dense_embeddings = self.prompt_encoder( 107 | points=points, 108 | boxes=image_record.get("boxes", None), 109 | masks=image_record.get("mask_inputs", None), 110 | ) 111 | low_res_masks, iou_predictions = self.mask_decoder( 112 | image_embeddings=curr_embedding.unsqueeze(0), 113 | image_pe=self.prompt_encoder.get_dense_pe(), 114 | sparse_prompt_embeddings=sparse_embeddings, 115 | dense_prompt_embeddings=dense_embeddings, 116 | multimask_output=multimask_output, 117 | ) 118 | masks = self.postprocess_masks( 119 | low_res_masks, 120 | input_size=image_record["image"].shape[-2:], 121 | original_size=image_record["original_size"], 122 | ) 123 | masks = masks > self.mask_threshold 124 | outputs.append( 125 | { 126 | "masks": masks, 127 | "iou_predictions": iou_predictions, 128 | "low_res_logits": low_res_masks, 129 | } 130 | ) 131 | return outputs 132 | 133 | def postprocess_masks( 134 | self, 135 | masks: torch.Tensor, 136 | input_size: Tuple[int, ...], 137 | original_size: Tuple[int, ...], 138 | ) -> torch.Tensor: 139 | """ 140 | Remove padding and upscale masks to the original image size. 141 | 142 | Arguments: 143 | masks (torch.Tensor): Batched masks from the mask_decoder, 144 | in BxCxHxW format. 145 | input_size (tuple(int, int)): The size of the image input to the 146 | model, in (H, W) format. Used to remove padding. 147 | original_size (tuple(int, int)): The original size of the image 148 | before resizing for input to the model, in (H, W) format. 149 | 150 | Returns: 151 | (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) 152 | is given by original_size. 153 | """ 154 | masks = F.interpolate( 155 | masks, 156 | (self.image_encoder.img_size, self.image_encoder.img_size), 157 | mode="bilinear", 158 | align_corners=False, 159 | ) 160 | masks = masks[..., : input_size[0], : input_size[1]] 161 | masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False) 162 | return masks 163 | 164 | def preprocess(self, x: torch.Tensor) -> torch.Tensor: 165 | """Normalize pixel values and pad to a square input.""" 166 | # Normalize colors 167 | x = (x - self.pixel_mean) / self.pixel_std 168 | 169 | # Pad 170 | h, w = x.shape[-2:] 171 | padh = self.image_encoder.img_size - h 172 | padw = self.image_encoder.img_size - w 173 | x = F.pad(x, (0, padw, 0, padh)) 174 | return x 175 | -------------------------------------------------------------------------------- /segment_anything/modeling/transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
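# Overview: TwoWayTransformer treats the prompt/output tokens as queries and the flattened
# image embedding (B x HW x C) as keys. Each TwoWayAttentionBlock applies token self-attention,
# token-to-image cross-attention, an MLP on the tokens, and image-to-token cross-attention;
# a final token-to-image attention plus LayerNorm then refines the queries before returning.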
6 | 7 | import torch 8 | from torch import Tensor, nn 9 | 10 | import math 11 | from typing import Tuple, Type 12 | 13 | from .common import MLPBlock 14 | 15 | 16 | class TwoWayTransformer(nn.Module): 17 | def __init__( 18 | self, 19 | depth: int, 20 | embedding_dim: int, 21 | num_heads: int, 22 | mlp_dim: int, 23 | activation: Type[nn.Module] = nn.ReLU, 24 | attention_downsample_rate: int = 2, 25 | ) -> None: 26 | """ 27 | A transformer decoder that attends to an input image using 28 | queries whose positional embedding is supplied. 29 | 30 | Args: 31 | depth (int): number of layers in the transformer 32 | embedding_dim (int): the channel dimension for the input embeddings 33 | num_heads (int): the number of heads for multihead attention. Must 34 | divide embedding_dim 35 | mlp_dim (int): the channel dimension internal to the MLP block 36 | activation (nn.Module): the activation to use in the MLP block 37 | """ 38 | super().__init__() 39 | self.depth = depth 40 | self.embedding_dim = embedding_dim 41 | self.num_heads = num_heads 42 | self.mlp_dim = mlp_dim 43 | self.layers = nn.ModuleList() 44 | 45 | for i in range(depth): 46 | self.layers.append( 47 | TwoWayAttentionBlock( 48 | embedding_dim=embedding_dim, 49 | num_heads=num_heads, 50 | mlp_dim=mlp_dim, 51 | activation=activation, 52 | attention_downsample_rate=attention_downsample_rate, 53 | skip_first_layer_pe=(i == 0), 54 | ) 55 | ) 56 | 57 | self.final_attn_token_to_image = Attention( 58 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 59 | ) 60 | self.norm_final_attn = nn.LayerNorm(embedding_dim) 61 | 62 | def forward( 63 | self, 64 | image_embedding: Tensor, 65 | image_pe: Tensor, 66 | point_embedding: Tensor, 67 | ) -> Tuple[Tensor, Tensor]: 68 | """ 69 | Args: 70 | image_embedding (torch.Tensor): image to attend to. Should be shape 71 | B x embedding_dim x h x w for any h and w. 72 | image_pe (torch.Tensor): the positional encoding to add to the image. Must 73 | have the same shape as image_embedding. 74 | point_embedding (torch.Tensor): the embedding to add to the query points. 75 | Must have shape B x N_points x embedding_dim for any N_points. 
76 | 77 | Returns: 78 | torch.Tensor: the processed point_embedding 79 | torch.Tensor: the processed image_embedding 80 | """ 81 | # BxCxHxW -> BxHWxC == B x N_image_tokens x C 82 | bs, c, h, w = image_embedding.shape 83 | image_embedding = image_embedding.flatten(2).permute(0, 2, 1) 84 | image_pe = image_pe.flatten(2).permute(0, 2, 1) 85 | 86 | # Prepare queries 87 | queries = point_embedding 88 | keys = image_embedding 89 | 90 | # Apply transformer blocks and final layernorm 91 | for layer in self.layers: 92 | queries, keys = layer( 93 | queries=queries, 94 | keys=keys, 95 | query_pe=point_embedding, 96 | key_pe=image_pe, 97 | ) 98 | 99 | # Apply the final attention layer from the points to the image 100 | q = queries + point_embedding 101 | k = keys + image_pe 102 | attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) 103 | queries = queries + attn_out 104 | queries = self.norm_final_attn(queries) 105 | 106 | return queries, keys 107 | 108 | 109 | class TwoWayAttentionBlock(nn.Module): 110 | def __init__( 111 | self, 112 | embedding_dim: int, 113 | num_heads: int, 114 | mlp_dim: int = 2048, 115 | activation: Type[nn.Module] = nn.ReLU, 116 | attention_downsample_rate: int = 2, 117 | skip_first_layer_pe: bool = False, 118 | ) -> None: 119 | """ 120 | A transformer block with four layers: (1) self-attention of sparse 121 | inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp 122 | block on sparse inputs, and (4) cross attention of dense inputs to sparse 123 | inputs. 124 | 125 | Arguments: 126 | embedding_dim (int): the channel dimension of the embeddings 127 | num_heads (int): the number of heads in the attention layers 128 | mlp_dim (int): the hidden dimension of the mlp block 129 | activation (nn.Module): the activation of the mlp block 130 | skip_first_layer_pe (bool): skip the PE on the first layer 131 | """ 132 | super().__init__() 133 | self.self_attn = Attention(embedding_dim, num_heads) 134 | self.norm1 = nn.LayerNorm(embedding_dim) 135 | 136 | self.cross_attn_token_to_image = Attention( 137 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 138 | ) 139 | self.norm2 = nn.LayerNorm(embedding_dim) 140 | 141 | self.mlp = MLPBlock(embedding_dim, mlp_dim, activation) 142 | self.norm3 = nn.LayerNorm(embedding_dim) 143 | 144 | self.norm4 = nn.LayerNorm(embedding_dim) 145 | self.cross_attn_image_to_token = Attention( 146 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 147 | ) 148 | 149 | self.skip_first_layer_pe = skip_first_layer_pe 150 | 151 | def forward( 152 | self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor 153 | ) -> Tuple[Tensor, Tensor]: 154 | # Self attention block 155 | if self.skip_first_layer_pe: 156 | queries = self.self_attn(q=queries, k=queries, v=queries) 157 | else: 158 | q = queries + query_pe 159 | attn_out = self.self_attn(q=q, k=q, v=queries) 160 | queries = queries + attn_out 161 | queries = self.norm1(queries) 162 | 163 | # Cross attention block, tokens attending to image embedding 164 | q = queries + query_pe 165 | k = keys + key_pe 166 | attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) 167 | queries = queries + attn_out 168 | queries = self.norm2(queries) 169 | 170 | # MLP block 171 | mlp_out = self.mlp(queries) 172 | queries = queries + mlp_out 173 | queries = self.norm3(queries) 174 | 175 | # Cross attention block, image embedding attending to tokens 176 | q = queries + query_pe 177 | k = keys + key_pe 178 | attn_out = self.cross_attn_image_to_token(q=k, k=q, 
v=queries) 179 | keys = keys + attn_out 180 | keys = self.norm4(keys) 181 | 182 | return queries, keys 183 | 184 | 185 | class Attention(nn.Module): 186 | """ 187 | An attention layer that allows for downscaling the size of the embedding 188 | after projection to queries, keys, and values. 189 | """ 190 | 191 | def __init__( 192 | self, 193 | embedding_dim: int, 194 | num_heads: int, 195 | downsample_rate: int = 1, 196 | ) -> None: 197 | super().__init__() 198 | self.embedding_dim = embedding_dim 199 | self.internal_dim = embedding_dim // downsample_rate 200 | self.num_heads = num_heads 201 | assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim." 202 | 203 | self.q_proj = nn.Linear(embedding_dim, self.internal_dim) 204 | self.k_proj = nn.Linear(embedding_dim, self.internal_dim) 205 | self.v_proj = nn.Linear(embedding_dim, self.internal_dim) 206 | self.out_proj = nn.Linear(self.internal_dim, embedding_dim) 207 | 208 | def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: 209 | b, n, c = x.shape 210 | x = x.reshape(b, n, num_heads, c // num_heads) 211 | return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head 212 | 213 | def _recombine_heads(self, x: Tensor) -> Tensor: 214 | b, n_heads, n_tokens, c_per_head = x.shape 215 | x = x.transpose(1, 2) 216 | return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C 217 | 218 | def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: 219 | # Input projections 220 | q = self.q_proj(q) 221 | k = self.k_proj(k) 222 | v = self.v_proj(v) 223 | 224 | # Separate into heads 225 | q = self._separate_heads(q, self.num_heads) 226 | k = self._separate_heads(k, self.num_heads) 227 | v = self._separate_heads(v, self.num_heads) 228 | 229 | # Attention 230 | _, _, _, c_per_head = q.shape 231 | attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens 232 | attn = attn / math.sqrt(c_per_head) 233 | attn = torch.softmax(attn, dim=-1) 234 | 235 | # Get output 236 | out = attn @ v 237 | out = self._recombine_heads(out) 238 | out = self.out_proj(out) 239 | 240 | return out 241 | -------------------------------------------------------------------------------- /segment_anything/predictor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | 10 | from segment_anything.modeling import Sam 11 | 12 | from typing import Optional, Tuple 13 | 14 | from .utils.transforms import ResizeLongestSide 15 | 16 | 17 | class SamPredictor: 18 | def __init__( 19 | self, 20 | sam_model: Sam, 21 | ) -> None: 22 | """ 23 | Uses SAM to calculate the image embedding for an image, and then 24 | allow repeated, efficient mask prediction given prompts. 25 | 26 | Arguments: 27 | sam_model (Sam): The model to use for mask prediction. 28 | """ 29 | super().__init__() 30 | self.model = sam_model 31 | self.transform = ResizeLongestSide(sam_model.image_encoder.img_size) 32 | self.reset_image() 33 | 34 | def set_image( 35 | self, 36 | image: np.ndarray, 37 | image_format: str = "RGB", 38 | ) -> None: 39 | """ 40 | Calculates the image embeddings for the provided image, allowing 41 | masks to be predicted with the 'predict' method. 42 | 43 | Arguments: 44 | image (np.ndarray): The image for calculating masks. 
Expects an 45 | image in HWC uint8 format, with pixel values in [0, 255]. 46 | image_format (str): The color format of the image, in ['RGB', 'BGR']. 47 | """ 48 | assert image_format in [ 49 | "RGB", 50 | "BGR", 51 | ], f"image_format must be in ['RGB', 'BGR'], is {image_format}." 52 | if image_format != self.model.image_format: 53 | image = image[..., ::-1] 54 | 55 | # Transform the image to the form expected by the model 56 | input_image = self.transform.apply_image(image) 57 | input_image_torch = torch.as_tensor(input_image, device=self.device) 58 | input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :] 59 | 60 | self.set_torch_image(input_image_torch, image.shape[:2]) 61 | 62 | @torch.no_grad() 63 | def set_torch_image( 64 | self, 65 | transformed_image: torch.Tensor, 66 | original_image_size: Tuple[int, ...], 67 | ) -> None: 68 | """ 69 | Calculates the image embeddings for the provided image, allowing 70 | masks to be predicted with the 'predict' method. Expects the input 71 | image to be already transformed to the format expected by the model. 72 | 73 | Arguments: 74 | transformed_image (torch.Tensor): The input image, with shape 75 | 1x3xHxW, which has been transformed with ResizeLongestSide. 76 | original_image_size (tuple(int, int)): The size of the image 77 | before transformation, in (H, W) format. 78 | """ 79 | assert ( 80 | len(transformed_image.shape) == 4 81 | and transformed_image.shape[1] == 3 82 | and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size 83 | ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}." 84 | self.reset_image() 85 | 86 | self.original_size = original_image_size 87 | self.input_size = tuple(transformed_image.shape[-2:]) 88 | input_image = self.model.preprocess(transformed_image) 89 | self.features = self.model.image_encoder(input_image) 90 | self.is_image_set = True 91 | 92 | def predict( 93 | self, 94 | point_coords: Optional[np.ndarray] = None, 95 | point_labels: Optional[np.ndarray] = None, 96 | box: Optional[np.ndarray] = None, 97 | mask_input: Optional[np.ndarray] = None, 98 | multimask_output: bool = True, 99 | return_logits: bool = False, 100 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: 101 | """ 102 | Predict masks for the given input prompts, using the currently set image. 103 | 104 | Arguments: 105 | point_coords (np.ndarray or None): A Nx2 array of point prompts to the 106 | model. Each point is in (X,Y) in pixels. 107 | point_labels (np.ndarray or None): A length N array of labels for the 108 | point prompts. 1 indicates a foreground point and 0 indicates a 109 | background point. 110 | box (np.ndarray or None): A length 4 array given a box prompt to the 111 | model, in XYXY format. 112 | mask_input (np.ndarray): A low resolution mask input to the model, typically 113 | coming from a previous prediction iteration. Has form 1xHxW, where 114 | for SAM, H=W=256. 115 | multimask_output (bool): If true, the model will return three masks. 116 | For ambiguous input prompts (such as a single click), this will often 117 | produce better masks than a single prediction. If only a single 118 | mask is needed, the model's predicted quality score can be used 119 | to select the best mask. For non-ambiguous prompts, such as multiple 120 | input prompts, multimask_output=False can give better results. 121 | return_logits (bool): If true, returns un-thresholded masks logits 122 | instead of a binary mask. 
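        Example (illustrative; x, y are placeholder pixel coordinates):
            masks, scores, low_res_logits = predictor.predict(
                point_coords=np.array([[x, y]]),
                point_labels=np.array([1]),  # 1 = foreground click
                multimask_output=True,
            )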
123 | 124 | Returns: 125 | (np.ndarray): The output masks in CxHxW format, where C is the 126 | number of masks, and (H, W) is the original image size. 127 | (np.ndarray): An array of length C containing the model's 128 | predictions for the quality of each mask. 129 | (np.ndarray): An array of shape CxHxW, where C is the number 130 | of masks and H=W=256. These low resolution logits can be passed to 131 | a subsequent iteration as mask input. 132 | """ 133 | if not self.is_image_set: 134 | raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") 135 | 136 | # Transform input prompts 137 | coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None 138 | if point_coords is not None: 139 | assert ( 140 | point_labels is not None 141 | ), "point_labels must be supplied if point_coords is supplied." 142 | point_coords = self.transform.apply_coords(point_coords, self.original_size) 143 | coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device) 144 | labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device) 145 | coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :] 146 | if box is not None: 147 | box = self.transform.apply_boxes(box, self.original_size) 148 | box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device) 149 | box_torch = box_torch[None, :] 150 | if mask_input is not None: 151 | mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device) 152 | mask_input_torch = mask_input_torch[None, :, :, :] 153 | 154 | masks, iou_predictions, low_res_masks = self.predict_torch( 155 | coords_torch, 156 | labels_torch, 157 | box_torch, 158 | mask_input_torch, 159 | multimask_output, 160 | return_logits=return_logits, 161 | ) 162 | 163 | masks_np = masks[0].detach().cpu().numpy() 164 | iou_predictions_np = iou_predictions[0].detach().cpu().numpy() 165 | low_res_masks_np = low_res_masks[0].detach().cpu().numpy() 166 | return masks_np, iou_predictions_np, low_res_masks_np 167 | 168 | @torch.no_grad() 169 | def predict_torch( 170 | self, 171 | point_coords: Optional[torch.Tensor], 172 | point_labels: Optional[torch.Tensor], 173 | boxes: Optional[torch.Tensor] = None, 174 | mask_input: Optional[torch.Tensor] = None, 175 | multimask_output: bool = True, 176 | return_logits: bool = False, 177 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 178 | """ 179 | Predict masks for the given input prompts, using the currently set image. 180 | Input prompts are batched torch tensors and are expected to already be 181 | transformed to the input frame using ResizeLongestSide. 182 | 183 | Arguments: 184 | point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the 185 | model. Each point is in (X,Y) in pixels. 186 | point_labels (torch.Tensor or None): A BxN array of labels for the 187 | point prompts. 1 indicates a foreground point and 0 indicates a 188 | background point. 189 | boxes (np.ndarray or None): A Bx4 array given a box prompt to the 190 | model, in XYXY format. 191 | mask_input (np.ndarray): A low resolution mask input to the model, typically 192 | coming from a previous prediction iteration. Has form Bx1xHxW, where 193 | for SAM, H=W=256. Masks returned by a previous iteration of the 194 | predict method do not need further transformation. 195 | multimask_output (bool): If true, the model will return three masks. 
196 | For ambiguous input prompts (such as a single click), this will often 197 | produce better masks than a single prediction. If only a single 198 | mask is needed, the model's predicted quality score can be used 199 | to select the best mask. For non-ambiguous prompts, such as multiple 200 | input prompts, multimask_output=False can give better results. 201 | return_logits (bool): If true, returns un-thresholded masks logits 202 | instead of a binary mask. 203 | 204 | Returns: 205 | (torch.Tensor): The output masks in BxCxHxW format, where C is the 206 | number of masks, and (H, W) is the original image size. 207 | (torch.Tensor): An array of shape BxC containing the model's 208 | predictions for the quality of each mask. 209 | (torch.Tensor): An array of shape BxCxHxW, where C is the number 210 | of masks and H=W=256. These low res logits can be passed to 211 | a subsequent iteration as mask input. 212 | """ 213 | if not self.is_image_set: 214 | raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") 215 | 216 | if point_coords is not None: 217 | points = (point_coords, point_labels) 218 | else: 219 | points = None 220 | 221 | # Embed prompts 222 | sparse_embeddings, dense_embeddings = self.model.prompt_encoder( 223 | points=points, 224 | boxes=boxes, 225 | masks=mask_input, 226 | ) 227 | 228 | # Predict masks 229 | low_res_masks, iou_predictions = self.model.mask_decoder( 230 | image_embeddings=self.features, 231 | image_pe=self.model.prompt_encoder.get_dense_pe(), 232 | sparse_prompt_embeddings=sparse_embeddings, 233 | dense_prompt_embeddings=dense_embeddings, 234 | multimask_output=multimask_output, 235 | ) 236 | 237 | # Upscale the masks to the original image resolution 238 | masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size) 239 | 240 | if not return_logits: 241 | masks = masks > self.model.mask_threshold 242 | 243 | return masks, iou_predictions, low_res_masks 244 | 245 | def get_image_embedding(self) -> torch.Tensor: 246 | """ 247 | Returns the image embeddings for the currently set image, with 248 | shape 1xCxHxW, where C is the embedding dimension and (H,W) are 249 | the embedding spatial dimension of SAM (typically C=256, H=W=64). 250 | """ 251 | if not self.is_image_set: 252 | raise RuntimeError( 253 | "An image must be set with .set_image(...) to generate an embedding." 254 | ) 255 | assert self.features is not None, "Features must exist if an image has been set." 256 | return self.features 257 | 258 | @property 259 | def device(self) -> torch.device: 260 | return self.model.device 261 | 262 | def reset_image(self) -> None: 263 | """Resets the currently set image.""" 264 | self.is_image_set = False 265 | self.features = None 266 | self.orig_h = None 267 | self.orig_w = None 268 | self.input_h = None 269 | self.input_w = None 270 | -------------------------------------------------------------------------------- /segment_anything/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /segment_anything/utils/amg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | 10 | import math 11 | from copy import deepcopy 12 | from itertools import product 13 | from typing import Any, Dict, Generator, ItemsView, List, Tuple 14 | 15 | 16 | class MaskData: 17 | """ 18 | A structure for storing masks and their related data in batched format. 19 | Implements basic filtering and concatenation. 20 | """ 21 | 22 | def __init__(self, **kwargs) -> None: 23 | for v in kwargs.values(): 24 | assert isinstance( 25 | v, (list, np.ndarray, torch.Tensor) 26 | ), "MaskData only supports list, numpy arrays, and torch tensors." 27 | self._stats = dict(**kwargs) 28 | 29 | def __setitem__(self, key: str, item: Any) -> None: 30 | assert isinstance( 31 | item, (list, np.ndarray, torch.Tensor) 32 | ), "MaskData only supports list, numpy arrays, and torch tensors." 33 | self._stats[key] = item 34 | 35 | def __delitem__(self, key: str) -> None: 36 | del self._stats[key] 37 | 38 | def __getitem__(self, key: str) -> Any: 39 | return self._stats[key] 40 | 41 | def items(self) -> ItemsView[str, Any]: 42 | return self._stats.items() 43 | 44 | def filter(self, keep: torch.Tensor) -> None: 45 | for k, v in self._stats.items(): 46 | if v is None: 47 | self._stats[k] = None 48 | elif isinstance(v, torch.Tensor): 49 | self._stats[k] = v[torch.as_tensor(keep, device=v.device)] 50 | elif isinstance(v, np.ndarray): 51 | self._stats[k] = v[keep.detach().cpu().numpy()] 52 | elif isinstance(v, list) and keep.dtype == torch.bool: 53 | self._stats[k] = [a for i, a in enumerate(v) if keep[i]] 54 | elif isinstance(v, list): 55 | self._stats[k] = [v[i] for i in keep] 56 | else: 57 | raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") 58 | 59 | def cat(self, new_stats: "MaskData") -> None: 60 | for k, v in new_stats.items(): 61 | if k not in self._stats or self._stats[k] is None: 62 | self._stats[k] = deepcopy(v) 63 | elif isinstance(v, torch.Tensor): 64 | self._stats[k] = torch.cat([self._stats[k], v], dim=0) 65 | elif isinstance(v, np.ndarray): 66 | self._stats[k] = np.concatenate([self._stats[k], v], axis=0) 67 | elif isinstance(v, list): 68 | self._stats[k] = self._stats[k] + deepcopy(v) 69 | else: 70 | raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") 71 | 72 | def to_numpy(self) -> None: 73 | for k, v in self._stats.items(): 74 | if isinstance(v, torch.Tensor): 75 | self._stats[k] = v.detach().cpu().numpy() 76 | 77 | 78 | def is_box_near_crop_edge( 79 | boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0 80 | ) -> torch.Tensor: 81 | """Filter masks at the edge of a crop, but not at the edge of the original image.""" 82 | crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device) 83 | orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device) 84 | boxes = uncrop_boxes_xyxy(boxes, crop_box).float() 85 | near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0) 86 | near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0) 87 | near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge) 88 | return torch.any(near_crop_edge, dim=1) 89 | 90 | 91 | def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor: 92 | box_xywh = deepcopy(box_xyxy) 93 | box_xywh[2] = box_xywh[2] - box_xywh[0] 94 | box_xywh[3] = box_xywh[3] - 
box_xywh[1] 95 | return box_xywh 96 | 97 | 98 | def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]: 99 | assert len(args) > 0 and all( 100 | len(a) == len(args[0]) for a in args 101 | ), "Batched iteration must have inputs of all the same size." 102 | n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0) 103 | for b in range(n_batches): 104 | yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args] 105 | 106 | 107 | def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]: 108 | """ 109 | Encodes masks to an uncompressed RLE, in the format expected by 110 | pycoco tools. 111 | """ 112 | # Put in fortran order and flatten h,w 113 | b, h, w = tensor.shape 114 | tensor = tensor.permute(0, 2, 1).flatten(1) 115 | 116 | # Compute change indices 117 | diff = tensor[:, 1:] ^ tensor[:, :-1] 118 | change_indices = diff.nonzero() 119 | 120 | # Encode run length 121 | out = [] 122 | for i in range(b): 123 | cur_idxs = change_indices[change_indices[:, 0] == i, 1] 124 | cur_idxs = torch.cat( 125 | [ 126 | torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device), 127 | cur_idxs + 1, 128 | torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device), 129 | ] 130 | ) 131 | btw_idxs = cur_idxs[1:] - cur_idxs[:-1] 132 | counts = [] if tensor[i, 0] == 0 else [0] 133 | counts.extend(btw_idxs.detach().cpu().tolist()) 134 | out.append({"size": [h, w], "counts": counts}) 135 | return out 136 | 137 | 138 | def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray: 139 | """Compute a binary mask from an uncompressed RLE.""" 140 | h, w = rle["size"] 141 | mask = np.empty(h * w, dtype=bool) 142 | idx = 0 143 | parity = False 144 | for count in rle["counts"]: 145 | mask[idx : idx + count] = parity 146 | idx += count 147 | parity ^= True 148 | mask = mask.reshape(w, h) 149 | return mask.transpose() # Put in C order 150 | 151 | 152 | def area_from_rle(rle: Dict[str, Any]) -> int: 153 | return sum(rle["counts"][1::2]) 154 | 155 | 156 | def calculate_stability_score( 157 | masks: torch.Tensor, mask_threshold: float, threshold_offset: float 158 | ) -> torch.Tensor: 159 | """ 160 | Computes the stability score for a batch of masks. The stability 161 | score is the IoU between the binary masks obtained by thresholding 162 | the predicted mask logits at high and low values. 163 | """ 164 | # One mask is always contained inside the other. 
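# so the IoU reduces to area(logits > threshold + offset) / area(logits > threshold - offset).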
165 | # Save memory by preventing unnecessary cast to torch.int64 166 | intersections = ( 167 | (masks > (mask_threshold + threshold_offset)) 168 | .sum(-1, dtype=torch.int16) 169 | .sum(-1, dtype=torch.int32) 170 | ) 171 | unions = ( 172 | (masks > (mask_threshold - threshold_offset)) 173 | .sum(-1, dtype=torch.int16) 174 | .sum(-1, dtype=torch.int32) 175 | ) 176 | return intersections / unions 177 | 178 | 179 | def build_point_grid(n_per_side: int) -> np.ndarray: 180 | """Generates a 2D grid of points evenly spaced in [0,1]x[0,1].""" 181 | offset = 1 / (2 * n_per_side) 182 | points_one_side = np.linspace(offset, 1 - offset, n_per_side) 183 | points_x = np.tile(points_one_side[None, :], (n_per_side, 1)) 184 | points_y = np.tile(points_one_side[:, None], (1, n_per_side)) 185 | points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2) 186 | return points 187 | 188 | 189 | def build_all_layer_point_grids( 190 | n_per_side: int, n_layers: int, scale_per_layer: int 191 | ) -> List[np.ndarray]: 192 | """Generates point grids for all crop layers.""" 193 | points_by_layer = [] 194 | for i in range(n_layers + 1): 195 | n_points = int(n_per_side / (scale_per_layer**i)) 196 | points_by_layer.append(build_point_grid(n_points)) 197 | return points_by_layer 198 | 199 | 200 | def generate_crop_boxes( 201 | im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float 202 | ) -> Tuple[List[List[int]], List[int]]: 203 | """ 204 | Generates a list of crop boxes of different sizes. Each layer 205 | has (2**i)**2 boxes for the ith layer. 206 | """ 207 | crop_boxes, layer_idxs = [], [] 208 | im_h, im_w = im_size 209 | short_side = min(im_h, im_w) 210 | 211 | # Original image 212 | crop_boxes.append([0, 0, im_w, im_h]) 213 | layer_idxs.append(0) 214 | 215 | def crop_len(orig_len, n_crops, overlap): 216 | return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops)) 217 | 218 | for i_layer in range(n_layers): 219 | n_crops_per_side = 2 ** (i_layer + 1) 220 | overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side)) 221 | 222 | crop_w = crop_len(im_w, n_crops_per_side, overlap) 223 | crop_h = crop_len(im_h, n_crops_per_side, overlap) 224 | 225 | crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)] 226 | crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)] 227 | 228 | # Crops in XYWH format 229 | for x0, y0 in product(crop_box_x0, crop_box_y0): 230 | box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)] 231 | crop_boxes.append(box) 232 | layer_idxs.append(i_layer + 1) 233 | 234 | return crop_boxes, layer_idxs 235 | 236 | 237 | def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor: 238 | x0, y0, _, _ = crop_box 239 | offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device) 240 | # Check if boxes has a channel dimension 241 | if len(boxes.shape) == 3: 242 | offset = offset.unsqueeze(1) 243 | return boxes + offset 244 | 245 | 246 | def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor: 247 | x0, y0, _, _ = crop_box 248 | offset = torch.tensor([[x0, y0]], device=points.device) 249 | # Check if points has a channel dimension 250 | if len(points.shape) == 3: 251 | offset = offset.unsqueeze(1) 252 | return points + offset 253 | 254 | 255 | def uncrop_masks( 256 | masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int 257 | ) -> torch.Tensor: 258 | x0, y0, x1, y1 = crop_box 259 | if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h: 260 | return masks 261 | # Coordinate 
transform masks 262 | pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0) 263 | pad = (x0, pad_x - x0, y0, pad_y - y0) 264 | return torch.nn.functional.pad(masks, pad, value=0) 265 | 266 | 267 | def remove_small_regions( 268 | mask: np.ndarray, area_thresh: float, mode: str 269 | ) -> Tuple[np.ndarray, bool]: 270 | """ 271 | Removes small disconnected regions and holes in a mask. Returns the 272 | mask and an indicator of if the mask has been modified. 273 | """ 274 | import cv2 # type: ignore 275 | 276 | assert mode in ["holes", "islands"] 277 | correct_holes = mode == "holes" 278 | working_mask = (correct_holes ^ mask).astype(np.uint8) 279 | n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8) 280 | sizes = stats[:, -1][1:] # Row 0 is background label 281 | small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh] 282 | if len(small_regions) == 0: 283 | return mask, False 284 | fill_labels = [0] + small_regions 285 | if not correct_holes: 286 | fill_labels = [i for i in range(n_labels) if i not in fill_labels] 287 | # If every region is below threshold, keep largest 288 | if len(fill_labels) == 0: 289 | fill_labels = [int(np.argmax(sizes)) + 1] 290 | mask = np.isin(regions, fill_labels) 291 | return mask, True 292 | 293 | 294 | def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]: 295 | from pycocotools import mask as mask_utils # type: ignore 296 | 297 | h, w = uncompressed_rle["size"] 298 | rle = mask_utils.frPyObjects(uncompressed_rle, h, w) 299 | rle["counts"] = rle["counts"].decode("utf-8") # Necessary to serialize with json 300 | return rle 301 | 302 | 303 | def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor: 304 | """ 305 | Calculates boxes in XYXY format around masks. Return [0,0,0,0] for 306 | an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4. 307 | """ 308 | # torch.max below raises an error on empty inputs, just skip in this case 309 | if torch.numel(masks) == 0: 310 | return torch.zeros(*masks.shape[:-2], 4, device=masks.device) 311 | 312 | # Normalize shape to CxHxW 313 | shape = masks.shape 314 | h, w = shape[-2:] 315 | if len(shape) > 2: 316 | masks = masks.flatten(0, -3) 317 | else: 318 | masks = masks.unsqueeze(0) 319 | 320 | # Get top and bottom edges 321 | in_height, _ = torch.max(masks, dim=-1) 322 | in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :] 323 | bottom_edges, _ = torch.max(in_height_coords, dim=-1) 324 | in_height_coords = in_height_coords + h * (~in_height) 325 | top_edges, _ = torch.min(in_height_coords, dim=-1) 326 | 327 | # Get left and right edges 328 | in_width, _ = torch.max(masks, dim=-2) 329 | in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :] 330 | right_edges, _ = torch.max(in_width_coords, dim=-1) 331 | in_width_coords = in_width_coords + w * (~in_width) 332 | left_edges, _ = torch.min(in_width_coords, dim=-1) 333 | 334 | # If the mask is empty the right edge will be to the left of the left edge. 
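# Illustrative: for an all-False mask the reductions above give bottom_edge = 0 and
# right_edge = 0, while adding h (resp. w) to the masked-out coordinates pushes
# top_edge to h and left_edge to w; the comparison below therefore detects the empty
# mask and zeroes out its box.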
335 | # Replace these boxes with [0, 0, 0, 0] 336 | empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges) 337 | out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1) 338 | out = out * (~empty_filter).unsqueeze(-1) 339 | 340 | # Return to original shape 341 | if len(shape) > 2: 342 | out = out.reshape(*shape[:-2], 4) 343 | else: 344 | out = out[0] 345 | 346 | return out 347 | -------------------------------------------------------------------------------- /segment_anything/utils/onnx.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch.nn import functional as F 10 | 11 | from typing import Tuple 12 | 13 | from ..modeling import Sam 14 | from .amg import calculate_stability_score 15 | 16 | 17 | class SamOnnxModel(nn.Module): 18 | """ 19 | This model should not be called directly, but is used in ONNX export. 20 | It combines the prompt encoder, mask decoder, and mask postprocessing of Sam, 21 | with some functions modified to enable model tracing. Also supports extra 22 | options controlling what information. See the ONNX export script for details. 23 | """ 24 | 25 | def __init__( 26 | self, 27 | model: Sam, 28 | return_single_mask: bool, 29 | use_stability_score: bool = False, 30 | return_extra_metrics: bool = False, 31 | ) -> None: 32 | super().__init__() 33 | self.mask_decoder = model.mask_decoder 34 | self.model = model 35 | self.img_size = model.image_encoder.img_size 36 | self.return_single_mask = return_single_mask 37 | self.use_stability_score = use_stability_score 38 | self.stability_score_offset = 1.0 39 | self.return_extra_metrics = return_extra_metrics 40 | 41 | @staticmethod 42 | def resize_longest_image_size( 43 | input_image_size: torch.Tensor, longest_side: int 44 | ) -> torch.Tensor: 45 | input_image_size = input_image_size.to(torch.float32) 46 | scale = longest_side / torch.max(input_image_size) 47 | transformed_size = scale * input_image_size 48 | transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64) 49 | return transformed_size 50 | 51 | def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor: 52 | point_coords = point_coords + 0.5 53 | point_coords = point_coords / self.img_size 54 | point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords) 55 | point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding) 56 | 57 | point_embedding = point_embedding * (point_labels != -1) 58 | point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * ( 59 | point_labels == -1 60 | ) 61 | 62 | for i in range(self.model.prompt_encoder.num_point_embeddings): 63 | point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[ 64 | i 65 | ].weight * (point_labels == i) 66 | 67 | return point_embedding 68 | 69 | def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor: 70 | mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask) 71 | mask_embedding = mask_embedding + ( 72 | 1 - has_mask_input 73 | ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1) 74 | return mask_embedding 75 | 76 | def mask_postprocessing(self, masks: 
torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor: 77 | masks = F.interpolate( 78 | masks, 79 | size=(self.img_size, self.img_size), 80 | mode="bilinear", 81 | align_corners=False, 82 | ) 83 | 84 | prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(torch.int64) 85 | masks = masks[..., : prepadded_size[0], : prepadded_size[1]] # type: ignore 86 | 87 | orig_im_size = orig_im_size.to(torch.int64) 88 | h, w = orig_im_size[0], orig_im_size[1] 89 | masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False) 90 | return masks 91 | 92 | def select_masks( 93 | self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int 94 | ) -> Tuple[torch.Tensor, torch.Tensor]: 95 | # Determine if we should return the multiclick mask or not from the number of points. 96 | # The reweighting is used to avoid control flow. 97 | score_reweight = torch.tensor( 98 | [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)] 99 | ).to(iou_preds.device) 100 | score = iou_preds + (num_points - 2.5) * score_reweight 101 | best_idx = torch.argmax(score, dim=1) 102 | masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1) 103 | iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1) 104 | 105 | return masks, iou_preds 106 | 107 | @torch.no_grad() 108 | def forward( 109 | self, 110 | image_embeddings: torch.Tensor, 111 | point_coords: torch.Tensor, 112 | point_labels: torch.Tensor, 113 | mask_input: torch.Tensor, 114 | has_mask_input: torch.Tensor, 115 | orig_im_size: torch.Tensor, 116 | ): 117 | sparse_embedding = self._embed_points(point_coords, point_labels) 118 | dense_embedding = self._embed_masks(mask_input, has_mask_input) 119 | 120 | masks, scores = self.model.mask_decoder.predict_masks( 121 | image_embeddings=image_embeddings, 122 | image_pe=self.model.prompt_encoder.get_dense_pe(), 123 | sparse_prompt_embeddings=sparse_embedding, 124 | dense_prompt_embeddings=dense_embedding, 125 | ) 126 | 127 | if self.use_stability_score: 128 | scores = calculate_stability_score( 129 | masks, self.model.mask_threshold, self.stability_score_offset 130 | ) 131 | 132 | if self.return_single_mask: 133 | masks, scores = self.select_masks(masks, scores, point_coords.shape[1]) 134 | 135 | upscaled_masks = self.mask_postprocessing(masks, orig_im_size) 136 | 137 | if self.return_extra_metrics: 138 | stability_scores = calculate_stability_score( 139 | upscaled_masks, self.model.mask_threshold, self.stability_score_offset 140 | ) 141 | areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1) 142 | return upscaled_masks, scores, stability_scores, areas, masks 143 | 144 | return upscaled_masks, scores, masks 145 | -------------------------------------------------------------------------------- /segment_anything/utils/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | from torch.nn import functional as F 10 | from torchvision.transforms.functional import resize, to_pil_image # type: ignore 11 | 12 | from copy import deepcopy 13 | from typing import Tuple 14 | 15 | 16 | class ResizeLongestSide: 17 | """ 18 | Resizes images to the longest side 'target_length', as well as provides 19 | methods for resizing coordinates and boxes. 
Provides methods for 20 | transforming both numpy array and batched torch tensors. 21 | """ 22 | 23 | def __init__(self, target_length: int) -> None: 24 | self.target_length = target_length 25 | 26 | def apply_image(self, image: np.ndarray) -> np.ndarray: 27 | """ 28 | Expects a numpy array with shape HxWxC in uint8 format. 29 | """ 30 | target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) 31 | return np.array(resize(to_pil_image(image), target_size)) 32 | 33 | def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 34 | """ 35 | Expects a numpy array of length 2 in the final dimension. Requires the 36 | original image size in (H, W) format. 37 | """ 38 | old_h, old_w = original_size 39 | new_h, new_w = self.get_preprocess_shape( 40 | original_size[0], original_size[1], self.target_length 41 | ) 42 | coords = deepcopy(coords).astype(float) 43 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 44 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 45 | return coords 46 | 47 | def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 48 | """ 49 | Expects a numpy array shape Bx4. Requires the original image size 50 | in (H, W) format. 51 | """ 52 | boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size) 53 | return boxes.reshape(-1, 4) 54 | 55 | def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor: 56 | """ 57 | Expects batched images with shape BxCxHxW and float format. This 58 | transformation may not exactly match apply_image. apply_image is 59 | the transformation expected by the model. 60 | """ 61 | # Expects an image in BCHW format. May not exactly match apply_image. 62 | target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length) 63 | return F.interpolate( 64 | image, target_size, mode="bilinear", align_corners=False, antialias=True 65 | ) 66 | 67 | def apply_coords_torch( 68 | self, coords: torch.Tensor, original_size: Tuple[int, ...] 69 | ) -> torch.Tensor: 70 | """ 71 | Expects a torch tensor with length 2 in the last dimension. Requires the 72 | original image size in (H, W) format. 73 | """ 74 | old_h, old_w = original_size 75 | new_h, new_w = self.get_preprocess_shape( 76 | original_size[0], original_size[1], self.target_length 77 | ) 78 | coords = deepcopy(coords).to(torch.float) 79 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 80 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 81 | return coords 82 | 83 | def apply_boxes_torch( 84 | self, boxes: torch.Tensor, original_size: Tuple[int, ...] 85 | ) -> torch.Tensor: 86 | """ 87 | Expects a torch tensor with shape Bx4. Requires the original image 88 | size in (H, W) format. 89 | """ 90 | boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size) 91 | return boxes.reshape(-1, 4) 92 | 93 | @staticmethod 94 | def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]: 95 | """ 96 | Compute the output size given input size and target long side length. 
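For example, (oldh, oldw) = (600, 800) with long_side_length = 1024 gives
scale = 1024 / 800 = 1.28 and returns (768, 1024).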
97 | """ 98 | scale = long_side_length * 1.0 / max(oldh, oldw) 99 | newh, neww = oldh * scale, oldw * scale 100 | neww = int(neww + 0.5) 101 | newh = int(newh + 0.5) 102 | return (newh, neww) 103 | -------------------------------------------------------------------------------- /train_pointsam.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import argparse 4 | import random 5 | from abc import ABC 6 | 7 | import cv2 8 | import numpy as np 9 | import torch 10 | import torch.nn.functional as F 11 | from torch.utils.data import DataLoader 12 | 13 | import lightning as L 14 | from lightning.fabric.loggers import TensorBoardLogger 15 | from lightning.fabric.fabric import _FabricOptimizer 16 | 17 | from box import Box 18 | from datasets import call_load_dataset 19 | from utils.model import Model 20 | from utils.losses import DiceLoss, FocalLoss, Matching_Loss 21 | from utils.eval_utils import AverageMeter, validate, get_prompts, calc_iou 22 | from utils.tools import copy_model, create_csv, reduce_instances 23 | from utils.utils import * 24 | from utils.finch import FINCH 25 | 26 | vis = False 27 | 28 | def train_sam( 29 | cfg: Box, 30 | fabric: L.Fabric, 31 | model: Model, 32 | optimizer: _FabricOptimizer, 33 | scheduler: _FabricOptimizer, 34 | train_dataloader: DataLoader, 35 | val_dataloader: DataLoader, 36 | target_pts, 37 | ): 38 | 39 | focal_loss = FocalLoss() 40 | dice_loss = DiceLoss() 41 | max_iou = 0. 42 | mem_bank = Store(1, cfg.mem_bank_max_len) 43 | match_interval = cfg.match_interval 44 | for epoch in range(1, cfg.num_epochs + 1): 45 | batch_time = AverageMeter() 46 | data_time = AverageMeter() 47 | focal_losses = AverageMeter() 48 | dice_losses = AverageMeter() 49 | iou_losses = AverageMeter() 50 | total_losses = AverageMeter() 51 | match_losses = AverageMeter() 52 | end = time.time() 53 | num_iter = len(train_dataloader) 54 | 55 | for iter, data in enumerate(train_dataloader): 56 | 57 | data_time.update(time.time() - end) 58 | images_weak, images_strong, bboxes, gt_masks, img_paths= data 59 | del data 60 | 61 | batch_size = images_weak.size(0) 62 | num_insts = sum(len(gt_mask) for gt_mask in gt_masks) 63 | if num_insts > cfg.max_nums: 64 | bboxes, gt_masks = reduce_instances(bboxes, gt_masks, cfg.max_nums) 65 | prompts = get_prompts(cfg, bboxes, gt_masks) 66 | 67 | #1. caculate pairwise IoUs of masks 68 | mask_ious, init_masks = cal_mask_ious(cfg, model, images_weak, prompts, gt_masks) 69 | 70 | #2. get new prompts through neg_prompt_calibration 71 | new_prompts = neg_prompt_calibration(cfg, mask_ious, prompts) 72 | 73 | #3. 
start training using new prompt 74 | soft_image_embeds, soft_masks, _, _ = model(images_weak, new_prompts) # teacher 75 | 76 | if isinstance(soft_image_embeds, dict): 77 | soft_image_embeds = soft_image_embeds['vision_features'] 78 | 79 | _, pred_masks, iou_predictions, _= model(images_strong, prompts) # student 80 | 81 | del _ 82 | 83 | num_masks = sum(len(pred_mask) for pred_mask in pred_masks) 84 | loss_focal = torch.tensor(0., device=fabric.device) 85 | loss_dice = torch.tensor(0., device=fabric.device) 86 | loss_match = torch.tensor(0., device=fabric.device) 87 | loss_iou = torch.tensor(0., device=fabric.device) 88 | 89 | for i, (embed, pred_mask, soft_mask, gt_mask, prompt, iou_prediction) in enumerate(zip(soft_image_embeds, pred_masks, soft_masks, gt_masks, prompts, iou_predictions)): 90 | 91 | soft_mask = (soft_mask > 0.).float() 92 | pred_feats = generate_predict_feats(cfg, embed, soft_mask, prompt) 93 | target_pts_ = target_pts['target_pts'] 94 | pred_feats = pred_feats.cpu().tolist() 95 | for pred_feat in pred_feats: 96 | mem_bank.add([[pred_feat]], [0]) 97 | if len(mem_bank.retrieve(0)) >= cfg.mem_bank_max_len and (iter + 1) % match_interval == 0: 98 | pred_feats = mem_bank.retrieve(0) 99 | pred_feats = np.array(pred_feats) 100 | 101 | #FINCH 102 | fin = FINCH(verbose=False) 103 | results = fin.fit(pred_feats) 104 | last_key = list(results.partitions.keys())[-1] 105 | pred_pts = results.partitions[last_key]['cluster_centers'] 106 | loss_match += Matching_Loss(pred_pts, target_pts_, device = fabric.device) 107 | 108 | del embed 109 | 110 | if vis: 111 | img_name = os.path.basename(img_paths[i]).split('.')[0] 112 | 113 | image_weak = images_weak[0].permute(1,2,0).cpu().numpy()* 255 114 | image_weak = cv2.cvtColor(image_weak, cv2.COLOR_BGR2RGB) 115 | 116 | if vis: 117 | for j in range(len(soft_mask)): 118 | mask_iou = torch.max(mask_ious[j]) 119 | image_weak_ = image_weak.copy() 120 | mask_area = torch.sum(gt_mask[j]) 121 | 122 | gt_mask_np = gt_mask[j].cpu().numpy() * 255 123 | gt_mask_img = cv2.cvtColor(gt_mask_np, cv2.COLOR_GRAY2RGB) 124 | 125 | init_prompt_po = prompts[0][0][j][:cfg.num_points] 126 | init_prompt_ne = prompts[0][0][j][cfg.num_points:] 127 | 128 | for po in init_prompt_po: 129 | cv2.circle(image_weak_, (int(po[0]), int(po[1])), 12, (0, 0, 255), -1) 130 | 131 | init_mask_img = init_masks[j].cpu().detach().numpy() * 255 132 | init_mask_img = cv2.cvtColor(init_mask_img, cv2.COLOR_GRAY2RGB) 133 | for po,ne in zip(init_prompt_po,init_prompt_ne): 134 | cv2.circle(init_mask_img, (int(po[0]), int(po[1])), 12, (0, 0, 255), -1) 135 | cv2.circle(init_mask_img, (int(ne[0]), int(ne[1])), 12, (0, 255, 0), -1) 136 | 137 | prompt_po = new_prompts[0][0][j][:cfg.num_points] 138 | prompt_ne = new_prompts[0][0][j][cfg.num_points:] 139 | soft_mask_img = soft_mask[j].cpu().detach().numpy() * 255 140 | 141 | soft_mask_img = cv2.cvtColor(soft_mask_img, cv2.COLOR_GRAY2RGB) 142 | 143 | for po,ne in zip(prompt_po,prompt_ne): 144 | cv2.circle(soft_mask_img, (int(po[0]), int(po[1])), 12, (0, 0, 255), -1) 145 | cv2.circle(soft_mask_img, (int(ne[0]), int(ne[1])), 12, (0, 255, 0), -1) 146 | 147 | output_dir = "./save_mask_{}/{}/{}/".format(str(float(cfg.num_points)),cfg.dataset,str(epoch)) 148 | if not os.path.exists(output_dir): 149 | os.makedirs(output_dir) 150 | merged_image = concatenate_images_with_padding([image_weak_, gt_mask_img, init_mask_img, soft_mask_img]) 151 | img_name_ = '{}_{}_iou{}.jpg'.format(img_name,str(j),str(mask_iou)) 152 | if mask_iou>float(cfg.iou_thr) and 
mask_area>3000: 153 | cv2.imwrite(os.path.join(output_dir,img_name_), merged_image) 154 | del init_masks, mask_ious 155 | 156 | loss_focal += focal_loss(pred_mask, soft_mask, num_masks) 157 | loss_dice += dice_loss(pred_mask, soft_mask, num_masks) 158 | batch_iou = calc_iou(pred_mask, soft_mask) 159 | loss_iou += F.mse_loss(iou_prediction, batch_iou, reduction='sum') / num_masks 160 | 161 | del soft_image_embeds, pred_masks, iou_predictions, gt_masks 162 | 163 | loss_total = 20. * loss_focal + loss_dice + loss_iou + 0.1*loss_match 164 | 165 | fabric.backward(loss_total) 166 | 167 | optimizer.step() 168 | scheduler.step() 169 | optimizer.zero_grad() 170 | torch.cuda.empty_cache() 171 | 172 | batch_time.update(time.time() - end) 173 | end = time.time() 174 | 175 | focal_losses.update(loss_focal.item(), batch_size) 176 | dice_losses.update(loss_dice.item(), batch_size) 177 | iou_losses.update(loss_iou.item(), batch_size) 178 | total_losses.update(loss_total.item(), batch_size) 179 | match_losses.update(loss_match.item(), batch_size) 180 | 181 | if (iter+1) %match_interval==0: 182 | fabric.print(f'Epoch: [{epoch}][{iter + 1}/{len(train_dataloader)}]' 183 | f' | Time [{batch_time.val:.3f}s ({batch_time.avg:.3f}s)]' 184 | f' | Data [{data_time.val:.3f}s ({data_time.avg:.3f}s)]' 185 | f' | Focal Loss [{focal_losses.val:.4f} ({focal_losses.avg:.4f})]' 186 | f' | Dice Loss [{dice_losses.val:.4f} ({dice_losses.avg:.4f})]' 187 | f' | IoU Loss [{iou_losses.val:.4f} ({iou_losses.avg:.4f})]' 188 | f' | Match Loss [{match_losses.val:.4f} ({match_losses.avg:.4f})]' 189 | f' | Total Loss [{total_losses.val:.4f} ({total_losses.avg:.4f})]') 190 | 191 | # loss_logger = { 192 | # "Focal Loss": focal_losses.avg, 193 | # "Dice Loss": dice_losses.avg, 194 | # "IoU Loss": iou_losses.avg, 195 | # "Total Loss": total_losses.avg 196 | # } 197 | # fabric.log_dict(loss_logger, num_iter * (epoch - 1) + iter) 198 | torch.cuda.empty_cache() 199 | 200 | if epoch % cfg.eval_interval == 0: 201 | iou, _= validate(fabric, cfg, model, val_dataloader, cfg.name, epoch) 202 | if iou > max_iou: 203 | state = {"model": model, "optimizer": optimizer} 204 | fabric.save(os.path.join(cfg.out_dir, "save", "best-ckpt.pth"), state) 205 | max_iou = iou 206 | del iou 207 | 208 | def configure_opt(cfg: Box, model: Model): 209 | 210 | def lr_lambda(step): 211 | if step < cfg.opt.warmup_steps: 212 | return step / cfg.opt.warmup_steps 213 | elif step < cfg.opt.steps[0]: 214 | return 1.0 215 | elif step < cfg.opt.steps[1]: 216 | return 1 / cfg.opt.decay_factor 217 | else: 218 | return 1 / (cfg.opt.decay_factor**2) 219 | 220 | optimizer = torch.optim.Adam(model.model.parameters(), lr=cfg.opt.learning_rate, weight_decay=cfg.opt.weight_decay) 221 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) 222 | 223 | return optimizer, scheduler 224 | 225 | 226 | def corrupt_main(cfg): 227 | for corrupt in cfg.corruptions: 228 | cfg.corrupt = corrupt 229 | cfg.out_name = corrupt 230 | torch.cuda.empty_cache() 231 | main(cfg) 232 | 233 | 234 | def main(cfg: Box) -> None: 235 | gpu_ids = [str(i) for i in range(torch.cuda.device_count())] 236 | num_devices = len(gpu_ids) 237 | fabric = L.Fabric(accelerator="auto", 238 | devices=num_devices, 239 | strategy="auto", 240 | loggers=[TensorBoardLogger(cfg.out_dir)]) 241 | fabric.launch() 242 | fabric.seed_everything(1337 + fabric.global_rank) 243 | 244 | if fabric.global_rank == 0: 245 | os.makedirs(os.path.join(cfg.out_dir, "save"), exist_ok=True) 246 | create_csv(os.path.join(cfg.out_dir, 
"metrics.csv"), csv_head=cfg.csv_keys) 247 | 248 | with fabric.device: 249 | model = Model(cfg) 250 | model.setup() 251 | 252 | load_datasets = call_load_dataset(cfg) 253 | train_data, val_data, pt_data = load_datasets(cfg, img_size=1024, return_pt = True) 254 | train_data = fabric._setup_dataloader(train_data) 255 | val_data = fabric._setup_dataloader(val_data) 256 | pt_data = fabric._setup_dataloader(pt_data) 257 | optimizer, scheduler = configure_opt(cfg, model) 258 | model, optimizer = fabric.setup(model, optimizer) 259 | 260 | if cfg.resume and cfg.model.ckpt is not None: 261 | full_checkpoint = fabric.load(cfg.model.ckpt) 262 | model.load_state_dict(full_checkpoint["model"]) 263 | optimizer.load_state_dict(full_checkpoint["optimizer"]) 264 | print('-'*100) 265 | print('\033[92mDirect test on the original SAM.\033[0m') 266 | _, _, = validate(fabric, cfg, model, val_data, name=cfg.name, epoch=0) 267 | print('-'*100) 268 | del _ 269 | 270 | target_pts = offline_prototypes_generation(cfg, model, pt_data) 271 | 272 | train_sam(cfg, fabric, model, optimizer, scheduler, train_data, val_data, target_pts) 273 | 274 | del model, train_data, val_data 275 | 276 | 277 | def parse_args(): 278 | parser = argparse.ArgumentParser(description='Train a detector') 279 | parser.add_argument('--cfg', help='train config file path') 280 | parser.add_argument('--prompt', help='the type of prompt') 281 | parser.add_argument('--num_points',type=int, help='the number of points') 282 | parser.add_argument('--out_dir', help='the dir to save logs and models') 283 | parser.add_argument('--load_type', help='the dir to save logs and models') 284 | args = parser.parse_args() 285 | return args 286 | 287 | if __name__ == "__main__": 288 | print(torch.cuda.current_device()) 289 | torch.cuda.empty_cache() 290 | torch.set_float32_matmul_precision('high') 291 | args = parse_args() 292 | 293 | exec(f'from {args.cfg} import cfg') 294 | 295 | # transfer the args to a dict 296 | args_dict = vars(args) 297 | cfg.merge_update(args_dict) 298 | 299 | main(cfg) 300 | torch.cuda.empty_cache() 301 | -------------------------------------------------------------------------------- /train_selftrain.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from box import Box 7 | import lightning as L 8 | from lightning.fabric.fabric import _FabricOptimizer 9 | from lightning.fabric.loggers import TensorBoardLogger, CSVLogger 10 | from torch.utils.data import DataLoader 11 | import cv2 12 | import argparse 13 | from datasets import call_load_dataset 14 | from utils.losses import DiceLoss, FocalLoss, ContraLoss 15 | from utils.model import Model 16 | from utils.eval_utils import AverageMeter, calc_iou, validate, get_prompts, calc_iou_matrix 17 | from utils.tools import copy_model, create_csv, check_grad, momentum_update, reduce_instances 18 | 19 | 20 | 21 | def train_sam( 22 | cfg: Box, 23 | fabric: L.Fabric, 24 | model: Model, 25 | teacher_model: Model, 26 | optimizer: _FabricOptimizer, 27 | scheduler: _FabricOptimizer, 28 | train_dataloader: DataLoader, 29 | val_dataloader: DataLoader, 30 | ): 31 | """The SAM training loop.""" 32 | 33 | focal_loss = FocalLoss() 34 | dice_loss = DiceLoss() 35 | max_iou = 0. 36 | max_camo = 0. 37 | max_chame = 0. 38 | max_cod = 0. 
39 | 40 | for epoch in range(1, cfg.num_epochs + 1): 41 | batch_time = AverageMeter() 42 | data_time = AverageMeter() 43 | focal_losses = AverageMeter() 44 | dice_losses = AverageMeter() 45 | iou_losses = AverageMeter() 46 | total_losses = AverageMeter() 47 | end = time.time() 48 | num_iter = len(train_dataloader) 49 | 50 | for iter, data in enumerate(train_dataloader): 51 | 52 | data_time.update(time.time() - end) 53 | images_weak, images_strong, bboxes, gt_masks, img_paths= data 54 | batch_size = images_weak.size(0) 55 | num_insts = sum(len(gt_mask) for gt_mask in gt_masks) 56 | if num_insts > 100: 57 | continue 58 | if num_insts > cfg.max_nums: 59 | bboxes, gt_masks = reduce_instances(bboxes, gt_masks, cfg.max_nums) 60 | 61 | prompts = get_prompts(cfg, bboxes, gt_masks) 62 | 63 | soft_image_embeds, soft_masks, soft_iou_predictions, soft_res_masks = teacher_model(images_weak, prompts) # teacher 64 | pred_image_embeds, pred_masks, iou_predictions, pred_res_masks = model(images_strong, prompts) # student 65 | 66 | num_masks = sum(len(pred_mask) for pred_mask in pred_masks) 67 | loss_focal = torch.tensor(0., device=fabric.device) 68 | loss_dice = torch.tensor(0., device=fabric.device) 69 | loss_iou = torch.tensor(0., device=fabric.device) 70 | for pred_mask, soft_mask, iou_prediction in zip(pred_masks, soft_masks, iou_predictions): 71 | soft_mask = (soft_mask > 0.).float() 72 | 73 | loss_focal += focal_loss(pred_mask, soft_mask, num_masks) 74 | loss_dice += dice_loss(pred_mask, soft_mask, num_masks) 75 | batch_iou = calc_iou(pred_mask, soft_mask) 76 | loss_iou += F.mse_loss(iou_prediction, batch_iou, reduction='sum') / num_masks 77 | 78 | loss_total = 20. * loss_focal + loss_dice + loss_iou 79 | fabric.backward(loss_total) 80 | 81 | # if iter + 1 % 5 == 0: 82 | optimizer.step() 83 | scheduler.step() 84 | optimizer.zero_grad() 85 | torch.cuda.empty_cache() 86 | 87 | batch_time.update(time.time() - end) 88 | end = time.time() 89 | 90 | focal_losses.update(loss_focal.item(), batch_size) 91 | dice_losses.update(loss_dice.item(), batch_size) 92 | iou_losses.update(loss_iou.item(), batch_size) 93 | total_losses.update(loss_total.item(), batch_size) 94 | if iter%50==0: 95 | fabric.print(f'Epoch: [{epoch}][{iter + 1}/{len(train_dataloader)}]' 96 | f' | Time [{batch_time.val:.3f}s ({batch_time.avg:.3f}s)]' 97 | f' | Data [{data_time.val:.3f}s ({data_time.avg:.3f}s)]' 98 | f' | Focal Loss [{focal_losses.val:.4f} ({focal_losses.avg:.4f})]' 99 | f' | Dice Loss [{dice_losses.val:.4f} ({dice_losses.avg:.4f})]' 100 | f' | IoU Loss [{iou_losses.val:.4f} ({iou_losses.avg:.4f})]' 101 | f' | Total Loss [{total_losses.val:.4f} ({total_losses.avg:.4f})]') 102 | 103 | # loss_logger = { 104 | # "Focal Loss": focal_losses.avg, 105 | # "Dice Loss": dice_losses.avg, 106 | # "IoU Loss": iou_losses.avg, 107 | # "Total Loss": total_losses.avg 108 | # } 109 | # fabric.log_dict(loss_logger, num_iter * (epoch - 1) + iter) 110 | torch.cuda.empty_cache() 111 | 112 | if epoch % cfg.eval_interval == 0: 113 | iou, f1_score = validate(fabric, cfg, model, val_dataloader, cfg.name, epoch) 114 | if iou > max_iou: 115 | state = {"model": model, "optimizer": optimizer} 116 | fabric.save(os.path.join(cfg.out_dir, "save", "last-ckpt.pth"), state) 117 | max_iou = iou 118 | 119 | def configure_opt(cfg: Box, model: Model): 120 | 121 | def lr_lambda(step): 122 | if step < cfg.opt.warmup_steps: 123 | return step / cfg.opt.warmup_steps 124 | elif step < cfg.opt.steps[0]: 125 | return 1.0 126 | elif step < cfg.opt.steps[1]: 127 | return 1 / 
cfg.opt.decay_factor 128 | else: 129 | return 1 / (cfg.opt.decay_factor**2) 130 | 131 | optimizer = torch.optim.Adam(model.model.parameters(), lr=cfg.opt.learning_rate, weight_decay=cfg.opt.weight_decay) 132 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) 133 | 134 | return optimizer, scheduler 135 | 136 | 137 | def corrupt_main(cfg): 138 | for corrupt in cfg.corruptions: 139 | cfg.corrupt = corrupt 140 | cfg.out_name = corrupt 141 | torch.cuda.empty_cache() 142 | main(cfg) 143 | 144 | 145 | def main(cfg: Box) -> None: 146 | gpu_ids = [str(i) for i in range(torch.cuda.device_count())] 147 | num_devices = len(gpu_ids) 148 | fabric = L.Fabric(accelerator="auto", 149 | devices=num_devices, 150 | strategy="auto", 151 | loggers=[TensorBoardLogger(cfg.out_dir)]) 152 | fabric.launch() 153 | fabric.seed_everything(1337 + fabric.global_rank) 154 | 155 | if fabric.global_rank == 0: 156 | os.makedirs(os.path.join(cfg.out_dir, "save"), exist_ok=True) 157 | create_csv(os.path.join(cfg.out_dir, "metrics.csv"), csv_head=cfg.csv_keys) 158 | 159 | with fabric.device: 160 | model = Model(cfg) 161 | model.setup() 162 | 163 | load_datasets = call_load_dataset(cfg) 164 | train_data, val_data = load_datasets(cfg, 1024) 165 | train_data = fabric._setup_dataloader(train_data) 166 | val_data = fabric._setup_dataloader(val_data) 167 | optimizer, scheduler = configure_opt(cfg, model) 168 | model, optimizer = fabric.setup(model, optimizer) 169 | 170 | if cfg.resume and cfg.model.ckpt is not None: 171 | full_checkpoint = fabric.load(cfg.model.ckpt) 172 | model.load_state_dict(full_checkpoint["model"]) 173 | optimizer.load_state_dict(full_checkpoint["optimizer"]) 174 | 175 | anchor_model = copy_model(model) 176 | print('-'*100) 177 | print('\033[92mDirect test on the original SAM.\033[0m') 178 | print('-'*100) 179 | validate(fabric, cfg, anchor_model, val_data, name=cfg.name, epoch=0) 180 | train_sam(cfg, fabric, model, anchor_model, optimizer, scheduler, train_data, val_data) 181 | 182 | del model, anchor_model, train_data, val_data 183 | 184 | 185 | def parse_args(): 186 | parser = argparse.ArgumentParser(description='Train a detector') 187 | parser.add_argument('--cfg', help='train config file path') 188 | parser.add_argument('--prompt', help='the type of prompt') 189 | parser.add_argument('--num_points',type=int, help='the number of points') 190 | parser.add_argument('--out_dir', help='the dir to save logs and models') 191 | parser.add_argument('--load_type', help='the dir to save logs and models') 192 | args = parser.parse_args() 193 | return args 194 | 195 | if __name__ == "__main__": 196 | print(torch.cuda.current_device()) 197 | torch.cuda.empty_cache() 198 | torch.set_float32_matmul_precision('high') 199 | args = parse_args() 200 | 201 | exec(f'from {args.cfg} import cfg') 202 | 203 | # transfer the args to a dict 204 | args_dict = vars(args) 205 | cfg.merge_update(args_dict) 206 | 207 | main(cfg) 208 | torch.cuda.empty_cache() 209 | -------------------------------------------------------------------------------- /train_supervise.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import torch 4 | import lightning as L 5 | import torch.nn.functional as F 6 | import segmentation_models_pytorch as smp 7 | from box import Box 8 | from lightning.fabric.fabric import _FabricOptimizer 9 | from lightning.fabric.loggers import TensorBoardLogger, CSVLogger 10 | from torch.utils.data import DataLoader 11 | import cv2 12 | import 
argparse 13 | 14 | from datasets import call_load_dataset 15 | from utils.losses import DiceLoss, FocalLoss, ContraLoss 16 | from utils.model import Model 17 | from utils.eval_utils import AverageMeter, calc_iou, validate, get_prompts 18 | from utils.tools import copy_model, create_csv, check_grad, momentum_update, reduce_instances 19 | 20 | def train_sam( 21 | cfg: Box, 22 | fabric: L.Fabric, 23 | model: Model, 24 | optimizer: _FabricOptimizer, 25 | scheduler: _FabricOptimizer, 26 | train_dataloader: DataLoader, 27 | val_dataloader: DataLoader, 28 | ): 29 | """The SAM training loop.""" 30 | 31 | focal_loss = FocalLoss() 32 | dice_loss = DiceLoss() 33 | max_iou = 0. 34 | 35 | for epoch in range(1, cfg.num_epochs + 1): 36 | batch_time = AverageMeter() 37 | data_time = AverageMeter() 38 | focal_losses = AverageMeter() 39 | dice_losses = AverageMeter() 40 | iou_losses = AverageMeter() 41 | total_losses = AverageMeter() 42 | end = time.time() 43 | num_iter = len(train_dataloader) 44 | for iter, data in enumerate(train_dataloader): 45 | 46 | data_time.update(time.time() - end) 47 | images, bboxes, gt_masks, img_paths = data 48 | 49 | batch_size = images.size(0) 50 | num_insts = sum(len(gt_mask) for gt_mask in gt_masks) 51 | if num_insts > cfg.max_nums: 52 | bboxes, gt_masks = reduce_instances(bboxes, gt_masks, cfg.max_nums) 53 | 54 | prompts = get_prompts(cfg, bboxes, gt_masks) 55 | 56 | _, pred_masks, iou_predictions, _ = model(images, prompts) 57 | num_masks = sum(len(pred_mask) for pred_mask in pred_masks) 58 | loss_focal = torch.tensor(0., device=fabric.device) 59 | loss_dice = torch.tensor(0., device=fabric.device) 60 | loss_iou = torch.tensor(0., device=fabric.device) 61 | for i,(pred_mask, gt_mask, iou_prediction) in enumerate(zip(pred_masks, gt_masks, iou_predictions)): 62 | gt_mask = gt_mask.to(device=fabric.device) 63 | batch_iou = calc_iou(pred_mask, gt_mask) 64 | loss_focal += focal_loss(pred_mask, gt_mask, num_masks) 65 | loss_dice += dice_loss(pred_mask, gt_mask, num_masks) 66 | loss_iou += F.mse_loss(iou_prediction, batch_iou, reduction='sum') / num_masks 67 | 68 | loss_total = 20. 
* loss_focal + loss_dice + loss_iou 69 | optimizer.zero_grad() 70 | fabric.backward(loss_total) 71 | optimizer.step() 72 | scheduler.step() 73 | batch_time.update(time.time() - end) 74 | end = time.time() 75 | 76 | focal_losses.update(loss_focal.item(), batch_size) 77 | dice_losses.update(loss_dice.item(), batch_size) 78 | iou_losses.update(loss_iou.item(), batch_size) 79 | total_losses.update(loss_total.item(), batch_size) 80 | if iter%50==0: 81 | fabric.print(f'Epoch: [{epoch}][{iter+1}/{len(train_dataloader)}]' 82 | f' | Time [{batch_time.val:.3f}s ({batch_time.avg:.3f}s)]' 83 | f' | Data [{data_time.val:.3f}s ({data_time.avg:.3f}s)]' 84 | f' | Focal Loss [{focal_losses.val:.4f} ({focal_losses.avg:.4f})]' 85 | f' | Dice Loss [{dice_losses.val:.4f} ({dice_losses.avg:.4f})]' 86 | f' | IoU Loss [{iou_losses.val:.4f} ({iou_losses.avg:.4f})]' 87 | f' | Total Loss [{total_losses.val:.4f} ({total_losses.avg:.4f})]') 88 | 89 | if epoch % cfg.eval_interval == 0: 90 | iou, f1_score = validate(fabric, cfg, model, val_dataloader, cfg.name, epoch) 91 | if iou > max_iou: 92 | state = {"model": model, "optimizer": optimizer} 93 | fabric.save(os.path.join(cfg.out_dir, "save", "last-ckpt.pth"), state) 94 | max_iou = iou 95 | 96 | def configure_opt(cfg: Box, model: Model): 97 | 98 | def lr_lambda(step): 99 | if step < cfg.opt.warmup_steps: 100 | return step / cfg.opt.warmup_steps 101 | elif step < cfg.opt.steps[0]: 102 | return 1.0 103 | elif step < cfg.opt.steps[1]: 104 | return 1 / cfg.opt.decay_factor 105 | else: 106 | return 1 / (cfg.opt.decay_factor**2) 107 | 108 | optimizer = torch.optim.Adam(model.model.parameters(), lr=cfg.opt.learning_rate, weight_decay=cfg.opt.weight_decay) 109 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) 110 | 111 | return optimizer, scheduler 112 | 113 | 114 | def corrupt_main(cfg): 115 | for corrupt in cfg.corruptions: 116 | cfg.corrupt = corrupt 117 | cfg.out_name = corrupt 118 | torch.cuda.empty_cache() 119 | main(cfg) 120 | 121 | 122 | def main(cfg: Box) -> None: 123 | gpu_ids = [str(i) for i in range(torch.cuda.device_count())] 124 | num_devices = len(gpu_ids) 125 | fabric = L.Fabric(accelerator="auto", 126 | devices=num_devices, 127 | strategy="auto", 128 | loggers=[TensorBoardLogger(cfg.out_dir)]) 129 | fabric.launch() 130 | fabric.seed_everything(1337 + fabric.global_rank) 131 | 132 | if fabric.global_rank == 0: 133 | os.makedirs(os.path.join(cfg.out_dir, "save"), exist_ok=True) 134 | create_csv(os.path.join(cfg.out_dir, "metrics.csv"), csv_head=cfg.csv_keys) 135 | 136 | with fabric.device: 137 | model = Model(cfg) 138 | model.setup() 139 | 140 | load_datasets = call_load_dataset(cfg) 141 | train_data, val_data = load_datasets(cfg, 1024) 142 | train_data = fabric._setup_dataloader(train_data) 143 | val_data = fabric._setup_dataloader(val_data) 144 | optimizer, scheduler = configure_opt(cfg, model) 145 | model, optimizer = fabric.setup(model, optimizer) 146 | 147 | if cfg.resume and cfg.model.ckpt is not None: 148 | full_checkpoint = fabric.load(cfg.model.ckpt) 149 | model.load_state_dict(full_checkpoint["model"]) 150 | optimizer.load_state_dict(full_checkpoint["optimizer"]) 151 | 152 | anchor_model = copy_model(model) 153 | print('-'*100) 154 | print('\033[92mDirect test on the original SAM.\033[0m') 155 | print('-'*100) 156 | validate(fabric, cfg, anchor_model, val_data, name=cfg.name, epoch=0) 157 | train_sam(cfg, fabric, model, optimizer, scheduler, train_data, val_data) 158 | 159 | del model, anchor_model, train_data, val_data 160 | 161 | 
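# Illustrative invocation (the config module, point count and output directory are
# placeholder values; parse_args below defines the accepted flags):
#   python train_supervise.py --cfg configs.config_whu --prompt point \
#       --num_points 1 --out_dir work_dirs/whu_supervise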
162 | def parse_args(): 163 | parser = argparse.ArgumentParser(description='Train a detector') 164 | parser.add_argument('--cfg', help='train config file path') 165 | parser.add_argument('--prompt', help='the type of prompt') 166 | parser.add_argument('--num_points',type=int, help='the number of points') 167 | parser.add_argument('--out_dir', help='the dir to save logs and models') 168 | parser.add_argument('--load_type', help='the dir to save logs and models') 169 | args = parser.parse_args() 170 | return args 171 | 172 | if __name__ == "__main__": 173 | print(torch.cuda.current_device()) 174 | torch.cuda.empty_cache() 175 | torch.set_float32_matmul_precision('high') 176 | args = parse_args() 177 | 178 | exec(f'from {args.cfg} import cfg') 179 | 180 | # transfer the args to a dict 181 | args_dict = vars(args) 182 | cfg.merge_update(args_dict) 183 | 184 | main(cfg) 185 | torch.cuda.empty_cache() 186 | -------------------------------------------------------------------------------- /utils/eval_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from tqdm import tqdm 4 | from scipy.ndimage import map_coordinates 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | from torch.utils.data import DataLoader 9 | 10 | import lightning as L 11 | 12 | import segmentation_models_pytorch as smp 13 | 14 | from box import Box 15 | from utils.model import Model 16 | from utils.sample_utils import get_point_prompts 17 | from utils.tools import write_csv 18 | 19 | 20 | class AverageMeter: 21 | """Computes and stores the average and current value.""" 22 | 23 | def __init__(self): 24 | self.reset() 25 | 26 | def reset(self): 27 | self.val = 0 28 | self.avg = 0 29 | self.sum = 0 30 | self.count = 0 31 | 32 | def update(self, val, n=1): 33 | self.val = val 34 | self.sum += val * n 35 | self.count += n 36 | self.avg = self.sum / self.count 37 | 38 | 39 | def calc_iou(pred_mask: torch.Tensor, gt_mask: torch.Tensor): 40 | pred_mask = (pred_mask >= 0.5).float() 41 | intersection = torch.sum(torch.mul(pred_mask, gt_mask), dim=(1, 2)) 42 | union = torch.sum(pred_mask, dim=(1, 2)) + torch.sum(gt_mask, dim=(1, 2)) - intersection 43 | epsilon = 1e-7 44 | batch_iou = intersection / (union + epsilon) 45 | 46 | batch_iou = batch_iou.unsqueeze(1) 47 | return batch_iou 48 | 49 | def calc_iou_instance(pred_masks: torch.Tensor, gt_masks: torch.Tensor): 50 | iou_list = [] 51 | for pred_mask, gt_mask in zip(pred_masks, gt_masks): 52 | pred_mask = (pred_mask >= 0.5).float() 53 | # print(pred_mask.shape) 54 | intersection = torch.sum(torch.mul(pred_mask, gt_mask), dim=(0, 1)) 55 | 56 | union = torch.sum(pred_mask, dim=(0, 1)) + torch.sum(gt_mask, dim=(0, 1)) - intersection 57 | epsilon = 1e-7 58 | iou = intersection / (union + epsilon) 59 | # print(iou) 60 | # batch_iou = batch_iou.unsqueeze(1) 61 | iou_list.append(iou) 62 | return iou_list 63 | 64 | 65 | #intersection 66 | # mask1: mask2: intersection: 67 | # [[1, 1, 0, 0], [[1, 0, 0, 0], [[1, 0, 0, 0], 68 | # [1, 0, 0, 0], [1, 1, 0, 0], [1, 0, 0, 0], 69 | # [1, 1, 0, 0], [0, 1, 0, 0], [0, 1, 0, 0], 70 | # [0, 0, 0, 0]] [0, 0, 0, 0]] [0, 0, 0, 0]] 71 | 72 | #union 73 | # mask1: mask2: union: 74 | # [[1, 1, 0, 0], [[1, 0, 0, 0], [[1, 1, 0, 0], 75 | # [1, 0, 0, 0], [1, 1, 0, 0], [1, 1, 0, 0], 76 | # [1, 1, 0, 0], [0, 1, 0, 0], [1, 1, 0, 0], 77 | # [0, 0, 0, 0]] [0, 0, 0, 0]] [0, 0, 0, 0]] 78 | 79 | 80 | def calculate_iou(mask1, mask2): 81 | 82 | intersection = torch.logical_and(mask1, mask2) 83 
| union = torch.logical_or(mask1, mask2) 84 | iou = torch.sum(intersection).float() / torch.sum(union).float() 85 | return iou 86 | 87 | def calc_iou_matrix(mask_list1, mask_list2): 88 | 89 | iou_matrix = torch.zeros((len(mask_list1), len(mask_list2))) 90 | for i, mask1 in enumerate(mask_list1): 91 | for j, mask2 in enumerate(mask_list2): 92 | iou_matrix[i, j] = calculate_iou(mask1, mask2) 93 | return iou_matrix 94 | 95 | def get_prompts(cfg: Box, bboxes, gt_masks): 96 | if cfg.prompt == "box" or cfg.prompt == "coarse": 97 | prompts = bboxes 98 | elif cfg.prompt == "point": 99 | prompts = get_point_prompts(gt_masks, cfg.num_points) 100 | else: 101 | raise ValueError("Prompt Type Error!") 102 | return prompts 103 | 104 | def validate(fabric: L.Fabric, cfg: Box, model: Model, val_dataloader: DataLoader, name: str, epoch: int = 0): 105 | model.eval() 106 | ious = AverageMeter() 107 | f1_scores = AverageMeter() 108 | recall = AverageMeter() 109 | precision = AverageMeter() 110 | 111 | with torch.no_grad(): 112 | for iter, data in enumerate(tqdm(val_dataloader, desc='Validation', ncols=100)): 113 | images, bboxes, gt_masks, img_paths = data 114 | num_images = images.size(0) 115 | prompts = get_prompts(cfg, bboxes, gt_masks) 116 | 117 | _, pred_masks, _, _ = model(images, prompts) 118 | 119 | for pred_mask, gt_mask in zip(pred_masks, gt_masks): 120 | batch_stats = smp.metrics.get_stats( 121 | pred_mask, 122 | gt_mask.int(), 123 | mode='binary', 124 | threshold=0.5, 125 | ) 126 | batch_recall = smp.metrics.recall(*batch_stats, reduction="micro-imagewise") 127 | batch_precision = smp.metrics.precision(*batch_stats, reduction="micro-imagewise") 128 | batch_iou = smp.metrics.iou_score(*batch_stats, reduction="micro-imagewise") 129 | batch_f1 = smp.metrics.f1_score(*batch_stats, reduction="micro-imagewise") 130 | ious.update(batch_iou, num_images) 131 | f1_scores.update(batch_f1, num_images) 132 | recall.update(batch_recall, num_images) 133 | precision.update(batch_precision, num_images) 134 | 135 | torch.cuda.empty_cache() 136 | 137 | fabric.print( 138 | f'Val: [{epoch}] - [{iter+1}/{len(val_dataloader)}]: IoU: [{ious.avg:.4f}] -- Recall: [{recall.avg:.4f}] -- Precision [{precision.avg:.4f}] -- F1: [{f1_scores.avg:.4f}]' 139 | ) 140 | csv_dict = {"Prompt": cfg.prompt, "IoU": f"{ious.avg:.4f}","Recall": f"{recall.avg:.4f}", "Precision": f"{precision.avg:.4f}", "F1": f"{f1_scores.avg:.4f}", "epoch": epoch} 141 | 142 | if fabric.global_rank == 0: 143 | write_csv(os.path.join(cfg.out_dir, "metrics.csv"), csv_dict, csv_head=cfg.csv_keys) 144 | return ious.avg, f1_scores.avg 145 | 146 | 147 | 148 | 149 | 150 | -------------------------------------------------------------------------------- /utils/finch.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | FINCH - First Integer Neighbor Clustering Hierarchy Algorithm 4 | """ 5 | 6 | # Author: Eren Cakmak 7 | # 8 | # License: MIT 9 | 10 | import numpy as np 11 | from sklearn.neighbors import NearestNeighbors 12 | from scipy.sparse.csgraph import connected_components 13 | from sklearn.utils import check_array 14 | from sklearn.metrics import silhouette_score 15 | 16 | 17 | class FINCH(): 18 | """ 19 | A class to perform the FINCH clustering 20 | 21 | Read more in paper see reference below. 
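
Examples
--------
Illustrative usage (mirrors how the training code calls this class):

>>> fin = FINCH(verbose=False)
>>> fin.fit(X)  # X: ndarray of shape (n_samples, n_features)
>>> last_partition = list(fin.partitions.keys())[-1]
>>> centers = fin.partitions[last_partition]['cluster_centers']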
22 | 23 | Parameters 24 | ---------- 25 | metric : string default='euclidean' 26 | The used distance metric - more options are 27 | ‘braycurtis’, ‘canberra’, ‘chebyshev’, ‘cityblock’, ‘correlation’, 28 | ‘cosine’, ‘dice’, ‘euclidean’, ‘hamming’, ‘jaccard’, ‘jensenshannon’, 29 | ‘kulsinski’, ‘mahalanobis’, ‘matching’, ‘rogerstanimoto’, ‘sqeuclidean’, 30 | ‘russellrao’, ‘seuclidean’, ‘sokalmichener’, ‘sokalsneath’, ‘yule’. 31 | 32 | n_jobs : int or None, default=1 33 | The number of processes to start -1 means use all processors 34 | 35 | Attributes 36 | ---------- 37 | labels : array, shape = [n_samples] 38 | Cluster labels for the data 39 | 40 | partitions : dict, contains all partitioning and their resulting labels, cluster centroids 41 | Cluster labels for the data 42 | 43 | References 44 | ---------- 45 | Sarfraz, Saquib, Vivek Sharma, and Rainer Stiefelhagen. 46 | "Efficient parameter-free clustering using first neighbor relations." 47 | Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2019. 48 | 49 | """ 50 | def __init__(self, metric='cosine', n_jobs=1, verbose=True): 51 | self.metric = metric 52 | self.n_jobs = n_jobs 53 | self.verbose = verbose 54 | 55 | def _finch(self, X, prev_clusters, prev_cluster_core_indices): 56 | """ 57 | Compute the adjacency link matrix as described in the paper Eq.1 58 | Afterwards get connected components and their cluster centroids 59 | ---------- 60 | X : ndarray of shape (n_samples, n_features) 61 | The input data samples that should be clustered. 62 | 63 | prev_clusters : list of ndarray of shape (2,) 64 | The cluster centroids of the previous partitioning. 65 | 66 | prev_cluster_core_indices : list 67 | The previous samples belonging to the cluster of prev_clusters 68 | Returns 69 | ------- 70 | n_connected_components_ : int 71 | The number of clusters in the partitioning. 72 | 73 | labels_ : list 74 | Cluster labels for all data samples between 0 and n_connected_components_ 75 | 76 | cluster_centers_ : list of ndarray of shape (2,) 77 | The cluster centroids of the partitioning. 
78 | 79 | cluster_core_indices_ : list 80 | The samples belonging to the cluster of cluster_centers_ 81 | """ 82 | 83 | # Adjacency link matrix by Eq.1 84 | connectivity = None 85 | 86 | # compute the adjacency link matrix 87 | if not prev_clusters: 88 | # first partitioning 89 | data = X 90 | else: 91 | data = prev_clusters 92 | 93 | # Compute the adjacency link matrix as described in the paper Eq.1 94 | # NN in sklearn 95 | nbrs = NearestNeighbors(n_neighbors=2,#2 96 | metric=self.metric, 97 | n_jobs=self.n_jobs).fit(data) 98 | 99 | # condition j = k_i - link nearest neighbors 100 | connectivity = nbrs.kneighbors_graph(data) 101 | 102 | # condition k_i = k_j - link same first neighbors 103 | # dot product forces symmtery therefore k_j plus k_j = i 104 | connectivity @= connectivity.T 105 | 106 | # remove diagonal 107 | connectivity.setdiag(0) 108 | connectivity.eliminate_zeros() 109 | 110 | # TODO this could also be solved by computing a linkage matrix 111 | # and then just calling the method scipy.cluster.hierarchy.fcluster 112 | # This will be probably increase the performance of the method further 113 | # 114 | # set values to one required for the linkage matrix 115 | # connectivity.data[:] = 1 116 | 117 | # get connected components 118 | n_connected_components_, labels_ = connected_components( 119 | csgraph=connectivity) 120 | 121 | # labels remap to previous cluster core indices 122 | # only called for second paritioning 123 | if len(labels_) < self.n_samples: 124 | new_labels = np.full(self.n_samples, 0) 125 | for i in range(n_connected_components_): 126 | idx = np.where(labels_ == i)[0] 127 | idx = sum([prev_cluster_core_indices[j] for j in idx], []) 128 | new_labels[idx] = i 129 | labels_ = new_labels 130 | 131 | # list of centroids and sample indices for each cluster 132 | cluster_centers_ = [] 133 | cluster_core_indices_ = [] 134 | 135 | # compute cluster centers with labels indicies 136 | for i in range(n_connected_components_): 137 | # update the cluster core indicies 138 | idx = np.where(labels_ == i)[0] 139 | cluster_core_indices_.append(idx.tolist()) 140 | 141 | # compute the cluster means 142 | xc_mean = X[idx].mean(axis=0) 143 | cluster_centers_.append(xc_mean) 144 | 145 | return n_connected_components_, labels_, cluster_centers_, cluster_core_indices_ 146 | 147 | def fit(self, X): 148 | """ 149 | Apply the FINCH algorithm 150 | ---------- 151 | X : ndarray of shape (n_samples, n_features) 152 | The data samples that are clustered 153 | 154 | Returns 155 | ------- 156 | self 157 | """ 158 | # check if input is correct 159 | X = check_array(X) 160 | 161 | self.n_samples = X.shape[0] 162 | 163 | # the results of the partitioning 164 | results = {} 165 | 166 | # intermediate results 167 | cluster_centers_ = None 168 | cluster_core_indices_ = None 169 | 170 | n_connected_components_ = len(X) 171 | if self.verbose: 172 | print('FINCH Partitionings') 173 | print('-------------------') 174 | 175 | i = 0 176 | while n_connected_components_ > 1: 177 | n_connected_components_, labels_, cluster_centers_, cluster_core_indices_ = self._finch( 178 | X, cluster_centers_, cluster_core_indices_) 179 | 180 | if n_connected_components_ == 1: 181 | break 182 | else: 183 | if self.verbose: 184 | print('Clusters in %s partition: %d' % 185 | (i, n_connected_components_)) 186 | 187 | results['parition_' + str(i)] = { 188 | 'n_clusters': n_connected_components_, 189 | 'labels': labels_, 190 | 'cluster_centers': cluster_centers_, 191 | 'cluster_core_indices': cluster_core_indices_ 192 | } 193 | i 
+= 1 194 | 195 | self.partitions = results 196 | 197 | return self 198 | 199 | def fit_predict(self, X, verbose): 200 | """ 201 | Apply the FINCH algorithm and returns a reasonable partitioning labels based on the silhouette coeffcient 202 | ---------- 203 | X : ndarray of shape (n_samples, n_features) 204 | The data samples that are clustered 205 | 206 | Returns 207 | ------- 208 | self 209 | """ 210 | # check if input is correct 211 | X = check_array(X) 212 | 213 | self.n_samples = X.shape[0] 214 | 215 | # the results of the partitioning 216 | results = {} 217 | 218 | # intermediate results 219 | cluster_centers_ = None 220 | cluster_core_indices_ = None 221 | 222 | # min silhouette coefficent score 223 | max_sil_score = -1 224 | 225 | n_connected_components_ = len(X) 226 | 227 | print('FINCH Partitionings') 228 | print('-------------------') 229 | 230 | i = 0 231 | while n_connected_components_ > 1: 232 | n_connected_components_, labels_, cluster_centers_, cluster_core_indices_ = self._finch( 233 | X, cluster_centers_, cluster_core_indices_) 234 | 235 | if n_connected_components_ == 1: 236 | break 237 | else: 238 | # in this version the silhouette coefficent is computed 239 | sil_score = silhouette_score(X, labels_, metric=self.metric) 240 | # store the max silhouette coefficent 241 | # do not pick the first partitioning 242 | if max_sil_score <= sil_score and i != 0: 243 | best_labels = labels_ 244 | max_sil_score = sil_score 245 | 246 | print( 247 | 'Clusters in %s partition: %d with average silhouette score %0.2f' 248 | % (i, n_connected_components_, sil_score)) 249 | 250 | results['parition_' + str(i)] = { 251 | 'n_clusters': n_connected_components_, 252 | 'labels': labels_, 253 | 'cluster_centers': cluster_centers_, 254 | 'cluster_core_indices': cluster_core_indices_, 255 | 'silhouette_coefficient': sil_score 256 | } 257 | i += 1 258 | 259 | self.labels = best_labels 260 | self.partitions = results 261 | 262 | return self.labels 263 | -------------------------------------------------------------------------------- /utils/losses.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from abc import ABC 6 | from scipy.optimize import linear_sum_assignment 7 | 8 | ALPHA = 0.8 9 | GAMMA = 2 10 | 11 | class FocalLoss(nn.Module): 12 | 13 | def __init__(self, weight=None, size_average=True): 14 | super().__init__() 15 | 16 | # def forward(self, inputs, targets, alpha=ALPHA, gamma=GAMMA, smooth=1): 17 | # inputs = F.sigmoid(inputs) 18 | # inputs = torch.clamp(inputs, min=0, max=1) 19 | # #flatten label and prediction tensors 20 | # inputs = inputs.view(-1) 21 | # targets = targets.view(-1) 22 | 23 | # BCE = F.binary_cross_entropy(inputs, targets, reduction='mean') 24 | # BCE_EXP = torch.exp(-BCE) 25 | # focal_loss = alpha * (1 - BCE_EXP)**gamma * BCE 26 | 27 | # return focal_loss 28 | 29 | def forward(self, inputs, targets, alpha=ALPHA, gamma=GAMMA, smooth=1): 30 | inputs = F.sigmoid(inputs) 31 | inputs = torch.clamp(inputs, min=0, max=1) 32 | #flatten label and prediction tensors 33 | inputs = inputs.view(-1) 34 | targets = targets.view(-1) 35 | 36 | BCE = F.binary_cross_entropy(inputs, targets, reduction='none') 37 | BCE_EXP = torch.exp(-BCE) 38 | focal_loss = alpha * (1 - BCE_EXP)**gamma * BCE 39 | focal_loss = focal_loss.mean() 40 | 41 | return focal_loss.mean() 42 | 43 | 44 | def dice_coefficient(x, target): 45 | eps = 1e-5 46 | n_inst = x.size(0) 47 | 
print(x.shape) 48 | x = x.reshape(n_inst, -1) 49 | target = target.reshape(n_inst, -1) 50 | intersection = (x * target).sum(dim=1) 51 | union = (x ** 2.0).sum(dim=1) + (target ** 2.0).sum(dim=1) + eps 52 | loss = 1. - (2 * intersection / union) 53 | return loss 54 | 55 | class DiceLoss(nn.Module): 56 | 57 | def __init__(self, weight=None, size_average=True): 58 | super().__init__() 59 | 60 | def forward(self, inputs, targets, smooth=1): 61 | inputs = F.sigmoid(inputs) 62 | inputs = torch.clamp(inputs, min=0, max=1) 63 | #flatten label and prediction tensors 64 | inputs = inputs.view(-1) 65 | targets = targets.view(-1) 66 | 67 | intersection = (inputs * targets).sum() 68 | dice = (2. * intersection + smooth) / (inputs.sum() + targets.sum() + smooth) 69 | 70 | return 1 - dice 71 | 72 | class ContraLoss(nn.Module): 73 | 74 | def __init__(self, temperature = 0.3, weight=None, size_average=True): 75 | super().__init__() 76 | self.temperature = temperature 77 | self.criterion = torch.nn.CrossEntropyLoss() 78 | 79 | def forward(self, embedd_x: torch.Tensor, embedd_y: torch.Tensor, mask_x: torch.Tensor, mask_y: torch.Tensor): 80 | x_embedding = self.norm_embed(embedd_x) # embedd_x: [256, 64, 64] 81 | y_embedding = self.norm_embed(embedd_y) 82 | 83 | # mask_x = mask_x.float() 84 | # mask_y = mask_y.float() 85 | x_masks = F.interpolate(mask_x, size=x_embedding.shape[-2:], mode="bilinear", align_corners=False).detach() 86 | # x_masks = F.sigmoid(x_masks) 87 | # x_masks = torch.clamp(x_masks, min=0, max=1) 88 | # x_masks = (x_masks > 0.5).float() 89 | sum_x = x_masks.sum(dim=[-1, -2]).clone() 90 | # sum_x[sum_x[:, 0] == 0.] = 1. 91 | 92 | y_masks = F.interpolate(mask_y, size=y_embedding.shape[-2:], mode="bilinear", align_corners=False).detach() 93 | # y_masks = F.sigmoid(y_masks) 94 | # y_masks = torch.clamp(y_masks, min=0, max=1) 95 | # y_masks = (y_masks > 0.5).float() 96 | sum_y = y_masks.sum(dim=[-1, -2]).clone() 97 | # sum_y[sum_y[:, 0] == 0.] = 1. 
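# Note: the two lines below perform a masked average pooling. Broadcasting the
# [256, 64, 64] embedding against the [n, 1, 64, 64] interpolated masks and dividing
# the spatial sum by sum_x / sum_y yields one 256-d vector per instance mask, which
# the cosine-similarity matrix further below then compares across the two views.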
98 | # [n, 1, H, W] 99 | multi_embedd_x = (x_embedding * x_masks).sum(dim=[-1, -2]) / sum_x # [n, 256, 64, 64] >> [n, 256] 100 | multi_embedd_y = (y_embedding * y_masks).sum(dim=[-1, -2]) / sum_y 101 | 102 | flatten_x = multi_embedd_x.view(multi_embedd_x.size(0), -1) 103 | flatten_y = multi_embedd_y.view(multi_embedd_y.size(0), -1) 104 | # similarity_matrix = torch.matmul(multi_embedd_x, multi_embedd_y.T) 105 | similarity_matrix = F.cosine_similarity(flatten_x.unsqueeze(1), flatten_y.unsqueeze(0), dim=2) 106 | 107 | label_pos = torch.eye(x_masks.size(0)).bool().to(embedd_x.device) 108 | label_nag = ~label_pos 109 | 110 | similarity_matrix = similarity_matrix / self.temperature # [n insts, n insts] 111 | loss = -torch.log( 112 | similarity_matrix.masked_select(label_pos).exp().sum() / 113 | similarity_matrix.exp().sum() 114 | ) 115 | # loss = -torch.log( 116 | # similarity_matrix.masked_select(label_pos).exp().sum() 117 | # ) 118 | # loss = -torch.log( 119 | # similarity_matrix.masked_select(label_pos).exp().sum() / 120 | # (similarity_matrix.masked_select(label_nag).exp().sum() + 1e-7) 121 | # ) 122 | return loss 123 | 124 | def norm_embed(self, embedding: torch.Tensor): 125 | embedding = F.normalize(embedding, dim=0, p=2) 126 | return embedding 127 | 128 | 129 | def cosine_similarity(vec1, vec2, device): 130 | vec1 = torch.tensor(vec1).to(device).float() 131 | vec2 = torch.tensor(vec2).to(device).float() 132 | dot_product = torch.dot(vec1, vec2) 133 | norm1 = torch.norm(vec1) 134 | norm2 = torch.norm(vec2) 135 | similarity = dot_product / (norm1 * norm2) 136 | return similarity 137 | 138 | def Matching_Loss(pred_prototypes, target_prototypes, device): 139 | """ 140 | pred_prototypes: list of predicted prototypes 141 | target_prototypes: list of target prototypes 142 | """ 143 | num_pred = len(pred_prototypes) 144 | num_target = len(target_prototypes) 145 | 146 | num = min(num_pred,num_target) 147 | cost_matrix = torch.zeros((num_pred, num_target)) 148 | for i, pred_proto in enumerate(pred_prototypes): 149 | for j, target_proto in enumerate(target_prototypes): 150 | 151 | cos_sim = cosine_similarity(pred_proto, target_proto, device) 152 | cost_matrix[i, j] = 1 - cos_sim 153 | 154 | row_indices, col_indices = linear_sum_assignment(cost_matrix.numpy()) 155 | 156 | total_loss = 0 157 | for row, col in zip(row_indices, col_indices): 158 | total_loss += cost_matrix[row, col].item() 159 | 160 | return total_loss/len(row_indices) 161 | 162 | -------------------------------------------------------------------------------- /utils/model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from segment_anything import sam_model_registry 7 | # from segment_anything_2.model_registry import sam2_model_registry 8 | from .sam_lora import LoRA_Sam 9 | 10 | class Model(nn.Module): 11 | 12 | def __init__(self, cfg): 13 | super().__init__() 14 | self.cfg = cfg 15 | self.image_embeddings = None 16 | 17 | def get_checkpoint(self, model_type): 18 | if model_type == "vit_b": 19 | checkpoint = os.path.join(self.cfg.model.checkpoint, "sam_vit_b_01ec64.pth") 20 | elif model_type == "vit_l": 21 | checkpoint = os.path.join(self.cfg.model.checkpoint, "sam_vit_l_0b3195.pth") 22 | elif model_type == "vit_h": 23 | checkpoint = os.path.join(self.cfg.model.checkpoint, "sam_vit_h_4b8939.pth") 24 | 25 | return checkpoint 26 | 27 | def setup(self): 28 | if self.cfg.model.type in 
['vit_b','vit_l','vit_h']:
29 |             checkpoint = self.get_checkpoint(self.cfg.model.type)
30 |             self.model = sam_model_registry[self.cfg.model.type](checkpoint=checkpoint)
31 |             self.base = 'sam'
32 |         elif self.cfg.model.type in ["hiera_b", 'hiera_l']:
33 |             self.model = sam2_model_registry[self.cfg.model.type]()
34 |             self.base = 'sam2'
35 |         else:
36 |             raise ValueError("Model type error!")
37 | 
38 |         # for param in self.model.parameters():
39 |         #     param.requires_grad = False
40 |         if self.cfg.model.freeze.image_encoder:
41 |             for param in self.model.image_encoder.parameters():
42 |                 param.requires_grad = False
43 |         if self.cfg.model.freeze.prompt_encoder:
44 |             try:
45 |                 for param in self.model.prompt_encoder.parameters():
46 |                     param.requires_grad = False
47 |             except AttributeError:  # SAM2 exposes the module as sam_prompt_encoder
48 |                 for param in self.model.sam_prompt_encoder.parameters():
49 |                     param.requires_grad = False
50 | 
51 |         if self.cfg.model.freeze.mask_decoder:
52 |             try:
53 |                 for param in self.model.mask_decoder.parameters():
54 |                     param.requires_grad = False
55 |             except AttributeError:  # SAM2 exposes the module as sam_mask_decoder
56 |                 for param in self.model.sam_mask_decoder.parameters():
57 |                     param.requires_grad = False
58 | 
59 |         self.model.train()
60 |         self.finetune()
61 | 
62 |     def finetune(self):
63 |         LoRA_Sam(self.model, self.cfg.lora_rank, lora_layer=list(range(self.cfg.start_lora_layer, len(self.model.image_encoder.blocks))))
64 |         # self.set_adapter_layer()
65 |         # self.set_norm_layer()
66 |         # print(self.model)
67 | 
68 |     def set_norm_layer(self):
69 |         for name, param in self.model.image_encoder.named_parameters():
70 |             if "norm" in name:
71 |                 param.requires_grad = True
72 | 
73 |     def set_adapter_layer(self):
74 |         for block in self.model.image_encoder.blocks:
75 |             if hasattr(block, "Space_Adapter"):
76 |                 for param in block.Space_Adapter.parameters():
77 |                     param.requires_grad = True
78 |             if hasattr(block, "MLP_Adapter"):
79 |                 for param in block.MLP_Adapter.parameters():
80 |                     param.requires_grad = True
81 | 
82 |     def reset_parameters(self) -> None:
83 |         for name, param in self.model.named_parameters():
84 |             if param.requires_grad:
85 |                 if "linear_a" in name:
86 |                     nn.init.kaiming_uniform_(param, a=math.sqrt(5))
87 |                 if "linear_b" in name:
88 |                     nn.init.zeros_(param)
89 | 
90 |     def forward(self, images, prompts):
91 |         _, _, H, W = images.shape  # [n, 3, 1024, 1024]
92 | 
93 |         image_embeddings = self.model.image_encoder(images)
94 |         pred_masks, ious, res_masks = self.decode((H, W), prompts, image_embeddings)
95 |         return image_embeddings, pred_masks, ious, res_masks
96 | 
97 |     # def encode(self, images):
98 |     #     self.image_embeddings = self.model.image_encoder(images)
99 |     #     return self.image_embeddings
100 | 
101 |     def decode(self, image_shape, prompts, image_embeddings):
102 |         if self.base == 'sam2':
103 |             _bb_feat_sizes = [
104 |                 (256, 256),
105 |                 (128, 128),
106 |                 (64, 64),
107 |             ]
108 |             _, vision_feats, _, _ = self.model._prepare_backbone_features(image_embeddings)
109 | 
110 |             feats = [feat.permute(1, 2, 0).view(1, -1, *feat_size)
111 |                      for feat, feat_size in zip(vision_feats[::-1], _bb_feat_sizes[::-1])][::-1]
112 |             self._features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
113 |             image_embeddings = feats[-1]
114 |             high_res_features = feats[:-1]
115 | 
116 |         if image_embeddings is None:
117 |             raise ValueError("No image embeddings")
118 | 
119 |         pred_masks = []
120 |         ious = []
121 |         res_masks = []
122 |         for prompt, embedding in zip(prompts, image_embeddings):
123 | 
124 |             if self.base == "sam":
125 |                 if isinstance(prompt, torch.Tensor):
126 |                     prompt = prompt.to(device=embedding.device)
127 |                     sparse_embeddings, 
dense_embeddings = self.model.prompt_encoder( 128 | points=None, 129 | boxes=prompt, 130 | masks=None, 131 | ) 132 | elif isinstance(prompt, tuple): 133 | sparse_embeddings, dense_embeddings = self.model.prompt_encoder( 134 | points=prompt, 135 | boxes=None, 136 | masks=None, 137 | ) 138 | low_res_masks, iou_predictions = self.model.mask_decoder( 139 | image_embeddings=embedding.unsqueeze(0), 140 | image_pe=self.model.prompt_encoder.get_dense_pe(), 141 | sparse_prompt_embeddings=sparse_embeddings, 142 | dense_prompt_embeddings=dense_embeddings, 143 | multimask_output=False, 144 | ) 145 | else: 146 | if isinstance(prompt, torch.Tensor): 147 | prompt = prompt.to(device=embedding.device) 148 | sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder( 149 | points=None, 150 | boxes=prompt, 151 | masks=None, 152 | ) 153 | elif isinstance(prompt, tuple): 154 | sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder( 155 | points=prompt, 156 | boxes=None, 157 | masks=None, 158 | ) 159 | low_res_masks, iou_predictions,_,_ = self.model.sam_mask_decoder( 160 | image_embeddings=embedding.unsqueeze(0), 161 | image_pe=self.model.sam_prompt_encoder.get_dense_pe(), 162 | sparse_prompt_embeddings=sparse_embeddings, 163 | dense_prompt_embeddings=dense_embeddings, 164 | multimask_output=False, 165 | repeat_image = True, 166 | high_res_features=high_res_features, 167 | ) 168 | 169 | masks = F.interpolate( 170 | low_res_masks, 171 | image_shape, 172 | mode="bilinear", 173 | align_corners=False, 174 | ) 175 | pred_masks.append(masks.squeeze(1)) 176 | ious.append(iou_predictions) 177 | res_masks.append(low_res_masks) 178 | return pred_masks, ious, res_masks -------------------------------------------------------------------------------- /utils/sam_lora.py: -------------------------------------------------------------------------------- 1 | # Sheng Wang at Apr 6 2023 2 | # What a time to be alive (first half of 2023) 3 | 4 | import math 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torch import Tensor 9 | from torch.nn.parameter import Parameter 10 | from safetensors import safe_open 11 | from safetensors.torch import save_file 12 | 13 | 14 | class _LoRA_qkv(nn.Module): 15 | """In Sam it is implemented as 16 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 17 | B, N, C = x.shape 18 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) 19 | q, k, v = qkv.unbind(0) 20 | """ 21 | 22 | def __init__( 23 | self, 24 | qkv: nn.Module, 25 | linear_a_q: nn.Module, 26 | linear_b_q: nn.Module, 27 | linear_a_v: nn.Module, 28 | linear_b_v: nn.Module, 29 | ): 30 | super().__init__() 31 | self.qkv = qkv 32 | self.linear_a_q = linear_a_q 33 | self.linear_b_q = linear_b_q 34 | self.linear_a_v = linear_a_v 35 | self.linear_b_v = linear_b_v 36 | self.dim = qkv.in_features 37 | self.w_identity = torch.eye(qkv.in_features) 38 | 39 | def forward(self, x): 40 | # x: [25, 14, 14, 768]; self.qkv: Linear(in_features=768, out_features=2304, bias=True) 41 | qkv = self.qkv(x) # B,N,N,3*org_C 42 | new_q = self.linear_b_q(self.linear_a_q(x)) 43 | new_v = self.linear_b_v(self.linear_a_v(x)) 44 | qkv[:, :, :, : self.dim] += new_q 45 | qkv[:, :, :, -self.dim :] += new_v 46 | return qkv 47 | 48 | 49 | class LoRA(nn.Module): 50 | def __init__(self, *args, **kwargs) -> None: 51 | super().__init__(*args, **kwargs) 52 | 53 | def save_fc_parameters(self, filename: str) -> None: 54 | r"""Only safetensors is supported now. 
55 | 56 | pip install safetensor if you do not have one installed yet. 57 | """ 58 | assert filename.endswith(".safetensors") 59 | _in = self.lora_vit.head.in_features 60 | _out = self.lora_vit.head.out_features 61 | fc_tensors = {f"fc_{_in}in_{_out}out": self.lora_vit.head.weight} 62 | save_file(fc_tensors, filename) 63 | 64 | def load_fc_parameters(self, filename: str) -> None: 65 | r"""Only safetensors is supported now. 66 | 67 | pip install safetensor if you do not have one installed yet. 68 | """ 69 | 70 | assert filename.endswith(".safetensors") 71 | _in = self.lora_vit.head.in_features 72 | _out = self.lora_vit.head.out_features 73 | with safe_open(filename, framework="pt") as f: 74 | saved_key = f"fc_{_in}in_{_out}out" 75 | try: 76 | saved_tensor = f.get_tensor(saved_key) 77 | self.lora_vit.head.weight = Parameter(saved_tensor) 78 | except ValueError: 79 | print("this fc weight is not for this model") 80 | 81 | def save_lora_parameters(self, filename: str) -> None: 82 | r"""Only safetensors is supported now. 83 | 84 | pip install safetensor if you do not have one installed yet. 85 | 86 | save both lora and fc parameters. 87 | """ 88 | 89 | assert filename.endswith(".safetensors") 90 | 91 | num_layer = len(self.w_As) # actually, it is half 92 | a_tensors = {f"w_a_{i:03d}": self.w_As[i].weight for i in range(num_layer)} 93 | b_tensors = {f"w_b_{i:03d}": self.w_Bs[i].weight for i in range(num_layer)} 94 | 95 | _in = self.lora_vit.head.in_features 96 | _out = self.lora_vit.head.out_features 97 | fc_tensors = {f"fc_{_in}in_{_out}out": self.lora_vit.head.weight} 98 | 99 | merged_dict = {**a_tensors, **b_tensors, **fc_tensors} 100 | save_file(merged_dict, filename) 101 | 102 | def load_lora_parameters(self, filename: str) -> None: 103 | r"""Only safetensors is supported now. 104 | 105 | pip install safetensor if you do not have one installed yet.\ 106 | 107 | load both lora and fc parameters. 108 | """ 109 | 110 | assert filename.endswith(".safetensors") 111 | 112 | with safe_open(filename, framework="pt") as f: 113 | for i, w_A_linear in enumerate(self.w_As): 114 | saved_key = f"w_a_{i:03d}" 115 | saved_tensor = f.get_tensor(saved_key) 116 | w_A_linear.weight = Parameter(saved_tensor) 117 | 118 | for i, w_B_linear in enumerate(self.w_Bs): 119 | saved_key = f"w_b_{i:03d}" 120 | saved_tensor = f.get_tensor(saved_key) 121 | w_B_linear.weight = Parameter(saved_tensor) 122 | 123 | _in = self.lora_vit.head.in_features 124 | _out = self.lora_vit.head.out_features 125 | saved_key = f"fc_{_in}in_{_out}out" 126 | try: 127 | saved_tensor = f.get_tensor(saved_key) 128 | self.lora_vit.head.weight = Parameter(saved_tensor) 129 | except ValueError: 130 | print("this fc weight is not for this model") 131 | 132 | def reset_parameters(self) -> None: 133 | for w_A in self.w_As: 134 | nn.init.kaiming_uniform_(w_A.weight, a=math.sqrt(5)) 135 | for w_B in self.w_Bs: 136 | nn.init.zeros_(w_B.weight) 137 | 138 | 139 | class LoRA_Sam(LoRA): 140 | """Applies low-rank adaptation to a Sam model's image encoder. 141 | 142 | Args: 143 | sam_model: a vision transformer model, see base_vit.py 144 | r: rank of LoRA 145 | num_classes: how many classes the model output, default to the vit model 146 | lora_layer: which layer we apply LoRA. 
147 | 
148 |     Examples::
149 |         >>> sam = sam_model_registry['vit_b'](checkpoint='sam_vit_b_01ec64.pth')
150 |         >>> lora_sam = LoRA_Sam(sam, r=4)
151 |         >>> trainable = [n for n, p in sam.named_parameters() if p.requires_grad]
152 |         >>> print(any('linear_a_q' in n for n in trainable))
153 |         True
154 |     """
155 | 
156 |     def __init__(self, sam_model, r: int, lora_layer=None):
157 |         super(LoRA_Sam, self).__init__()
158 | 
159 |         assert r > 0
160 |         # base_vit_dim = sam_model.image_encoder.patch_embed.proj.out_channels
161 |         # dim = base_vit_dim
162 |         if lora_layer:
163 |             self.lora_layer = lora_layer
164 |             flag = 0
165 |         else:
166 |             try:
167 |                 self.lora_layer = list(range(len(sam_model.image_encoder.blocks)))
168 |                 flag = 0
169 |             except AttributeError:  # SAM2 image encoders keep their blocks under image_encoder.trunk
170 |                 self.lora_layer = list(range(len(sam_model.image_encoder.trunk.blocks)))
171 |                 flag = 1
172 |         # create for storage, then we can init them or load weights
173 |         self.w_As = []  # These are linear layers
174 |         self.w_Bs = []
175 | 
176 |         # lets freeze first
177 |         for param in sam_model.image_encoder.parameters():
178 |             param.requires_grad = False
179 |         if flag == 0:
180 |             blocks = sam_model.image_encoder.blocks
181 |         elif flag == 1:
182 |             blocks = sam_model.image_encoder.trunk.blocks
183 |         # Here, we do the surgery
184 |         for t_layer_i, blk in enumerate(blocks):
185 |             # If we only want a few LoRA layers instead of all
186 |             if t_layer_i not in self.lora_layer:
187 |                 continue
188 |             w_qkv_linear = blk.attn.qkv
189 | 
190 |             self.dim = w_qkv_linear.in_features
191 |             w_a_linear_q = nn.Linear(self.dim, r, bias=False)
192 |             w_b_linear_q = nn.Linear(r, self.dim, bias=False)
193 |             w_a_linear_v = nn.Linear(self.dim, r, bias=False)
194 |             w_b_linear_v = nn.Linear(r, self.dim, bias=False)
195 |             self.w_As.append(w_a_linear_q)
196 |             self.w_Bs.append(w_b_linear_q)
197 |             self.w_As.append(w_a_linear_v)
198 |             self.w_Bs.append(w_b_linear_v)
199 |             blk.attn.qkv = _LoRA_qkv(
200 |                 w_qkv_linear,
201 |                 w_a_linear_q,
202 |                 w_b_linear_q,
203 |                 w_a_linear_v,
204 |                 w_b_linear_v,
205 |             )
206 |         self.reset_parameters()
207 |         # self.sam = sam_model
208 |         self.lora_vit = sam_model
209 | 
210 | 
--------------------------------------------------------------------------------
/utils/sample_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from sklearn.cluster import KMeans
4 | 
5 | 
6 | def uniform_sampling(masks, N=1):
7 |     n_points = []
8 |     for i, mask in enumerate(masks):
9 |         if not isinstance(mask, np.ndarray):
10 |             mask = mask.cpu().numpy()
11 |         indices = np.argwhere(mask == 1)  # [y, x]
12 |         sampled_indices = np.random.choice(len(indices), N, replace=True)
13 |         sampled_points = np.flip(indices[sampled_indices], axis=1)
14 |         n_points.append(sampled_points.tolist())
15 | 
16 |     return n_points
17 | 
18 | 
19 | def get_multi_distance_points(input_point, mask, points_number):
20 |     new_points = np.zeros((points_number + 1, 2))
21 |     new_points[0] = [input_point[1], input_point[0]]
22 |     for i in range(points_number):
23 |         new_points[i + 1] = get_next_distance_point(new_points[:i + 1, :], mask)
24 | 
25 |     new_points = swap_xy(new_points)
26 |     return new_points
27 | 
28 | 
29 | def get_next_distance_point(input_points, mask):
30 |     max_distance_point = [0, 0]
31 |     max_distance = 0
32 |     input_points = np.array(input_points)
33 | 
34 |     indices = np.argwhere(mask == True)
35 |     for x, y in indices:
36 |         # print(x,y,input_points)
37 |         distance = np.sum(np.sqrt((x - input_points[:, 0]) ** 2 + (y - input_points[:, 1]) ** 2))
38 |         if max_distance < distance:
39 |             max_distance_point = [x, y]
40 |             max_distance = distance
41 |     return 
max_distance_point 42 | 43 | 44 | def swap_xy(points): 45 | new_points = np.zeros((len(points),2)) 46 | new_points[:,0] = points[:,1] 47 | new_points[:,1] = points[:,0] 48 | return new_points 49 | 50 | 51 | def k_means_sampling(mask, k): 52 | points = np.argwhere(mask == 1) # [y, x] 53 | points = np.flip(points, axis=1) 54 | 55 | kmeans = KMeans(n_clusters=k) 56 | kmeans.fit(points) 57 | points = kmeans.cluster_centers_ 58 | return points 59 | 60 | 61 | def get_point_prompt_max_dist(masks, num_points): 62 | n_points = [] 63 | for mask in masks: 64 | mask_np = mask.cpu().numpy() 65 | 66 | indices = np.argwhere(mask_np > 0) 67 | random_index = np.random.choice(len(indices), 1)[0] 68 | 69 | first_point = [indices[random_index][1], indices[random_index][0]] 70 | new_points = get_multi_distance_points(first_point, mask_np, num_points - 1) 71 | n_points.append(new_points) 72 | 73 | return n_points 74 | 75 | 76 | def get_point_prompt_kmeans(masks, num_points): 77 | n_points = [] 78 | for mask in masks: 79 | mask_np = mask.cpu().numpy() 80 | points = k_means_sampling(mask_np, num_points) 81 | n_points.append(points.astype(int)) 82 | return n_points 83 | 84 | 85 | def get_point_prompts(gt_masks, num_points): 86 | prompts = [] 87 | # print('prompt',len(gt_masks)) 88 | for mask in gt_masks: 89 | # print('prompt',len(mask)) 90 | po_points = uniform_sampling(mask, num_points) 91 | # print('ori_po',po_points) 92 | na_points = uniform_sampling((~mask.to(bool)).to(float), num_points) 93 | # print('na_points',na_points) 94 | # print('na_points',na_points) 95 | po_point_coords = torch.tensor(po_points, device=mask.device) 96 | na_point_coords = torch.tensor(na_points, device=mask.device) 97 | # print('po_point_coords',po_point_coords.shape) 98 | # print('na_point_coords',na_point_coords.shape) 99 | point_coords = torch.cat((po_point_coords, na_point_coords), dim=1) 100 | po_point_labels = torch.ones(po_point_coords.shape[:2], dtype=torch.int, device=po_point_coords.device) 101 | na_point_labels = torch.zeros(na_point_coords.shape[:2], dtype=torch.int, device=na_point_coords.device) 102 | point_labels = torch.cat((po_point_labels, na_point_labels), dim=1) 103 | in_points = (point_coords, point_labels) 104 | prompts.append(in_points) 105 | return prompts -------------------------------------------------------------------------------- /utils/tools.py: -------------------------------------------------------------------------------- 1 | import os 2 | import csv 3 | import torch 4 | import copy 5 | import numpy as np 6 | from torchsummary import summary 7 | from torch import Tensor 8 | 9 | def freeze(model: torch.nn.Module): 10 | model.eval() 11 | for param in model.parameters(): 12 | param.requires_grad = False 13 | 14 | 15 | def momentum_update(student_model, teacher_model, momentum=0.99): 16 | for (src_name, src_param), (tgt_name, tgt_param) in zip( 17 | student_model.named_parameters(), teacher_model.named_parameters() 18 | ): 19 | if src_param.requires_grad: 20 | # print('src_name',src_name) 21 | # print('tgt_name',tgt_name) 22 | tgt_param.data.mul_(momentum).add_(src_param.data, alpha=1 - momentum) 23 | 24 | 25 | def decode_mask(mask): 26 | """ 27 | Convert mask with shape [1, h, w] using 1, 2, 3, ... to represent different objects 28 | to a mask with shape [n, h, w] using a new dimension to represent the number of objects. 29 | 30 | Args: 31 | mask (torch.Tensor): Mask tensor with shape [1, h, w] using 1, 2, 3, ... to represent different objects. 
32 | 33 | Returns: 34 | torch.Tensor: Mask tensor with shape [n, h, w] using a new dimension to represent the number of objects. 35 | """ 36 | unique_labels = torch.unique(mask) 37 | unique_labels = unique_labels[unique_labels != 0] 38 | n_objects = len(unique_labels) 39 | new_mask = torch.zeros((n_objects, *mask.shape[1:]), dtype=torch.int64) 40 | for i, label in enumerate(unique_labels): 41 | new_mask[i] = (mask == label).squeeze(0) 42 | return new_mask 43 | 44 | 45 | def encode_mask(mask): 46 | """ 47 | Convert mask with shape [n, h, w] using a new dimension to represent the number of objects 48 | to a mask with shape [1, h, w] using 1, 2, 3, ... to represent different objects. 49 | 50 | Args: 51 | mask (torch.Tensor): Mask tensor with shape [n, h, w] using a new dimension to represent the number of objects. 52 | 53 | Returns: 54 | torch.Tensor: Mask tensor with shape [1, h, w] using 1, 2, 3, ... to represent different objects. 55 | """ 56 | n_objects = mask.shape[0] 57 | new_mask = torch.zeros((1, *mask.shape[1:]), dtype=torch.int64) 58 | for i in range(n_objects): 59 | new_mask[0][mask[i] == 1] = i + 1 60 | return new_mask 61 | 62 | 63 | def copy_model(model: torch.nn.Module): 64 | new_model = copy.deepcopy(model) 65 | freeze(new_model) 66 | return new_model 67 | 68 | 69 | def create_csv(filename, csv_head=["corrupt", "Mean IoU", "Mean F1", "epoch"]): 70 | if os.path.exists(filename): 71 | return 72 | with open(filename, 'w') as csvfile: 73 | csv_write = csv.DictWriter(csvfile, fieldnames=csv_head) 74 | csv_write.writeheader() 75 | 76 | 77 | def write_csv(filename, csv_dict, csv_head=["corrupt", "Mean IoU", "Mean F1", "epoch"]): 78 | with open(filename, 'a+') as csvfile: 79 | csv_write = csv.DictWriter(csvfile, fieldnames=csv_head, extrasaction='ignore') 80 | csv_write.writerow(csv_dict) 81 | 82 | 83 | def check_grad(model: torch.nn.Module): 84 | for name, param in model.named_parameters(): 85 | print(f"{name}: {param.requires_grad}") 86 | 87 | 88 | def check_model(model): 89 | return summary(model, (3, 1024, 1024), batch_size=1, device="cuda") 90 | 91 | 92 | def reduce_instances(bboxes, gt_masks, max_nums=50): 93 | bboxes_ = [] 94 | gt_masks_ = [] 95 | for bbox, gt_mask in zip(bboxes, gt_masks): 96 | idx = np.arange(bbox.shape[0]) 97 | np.random.shuffle(idx) 98 | bboxes_.append(bbox[idx[:max_nums]]) 99 | gt_masks_.append(gt_mask[idx[:max_nums]]) 100 | 101 | bboxes = bboxes_ 102 | gt_masks = gt_masks_ 103 | return bboxes, gt_masks 104 | 105 | def _to_cpu(data): 106 | """transfer all tensors to cpu.""" 107 | if isinstance(data, Tensor): 108 | return data.to('cpu') 109 | elif isinstance(data, list): 110 | return [_to_cpu(d) for d in data] 111 | elif isinstance(data, tuple): 112 | return tuple(_to_cpu(d) for d in data) 113 | elif isinstance(data, dict): 114 | return {k: _to_cpu(v) for k, v in data.items()} 115 | else: 116 | return data 117 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | from collections import deque 3 | from tqdm import tqdm 4 | from box import Box 5 | import numpy as np 6 | from scipy.optimize import linear_sum_assignment 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | from .finch import FINCH 13 | from .sample_utils import get_point_prompts 14 | 15 | class Store: 16 | def __init__(self, total_num_classes, items_per_class, shuffle=False): 17 | self.shuffle = shuffle 18 | 
self.items_per_class = items_per_class 19 | self.total_num_classes = total_num_classes 20 | self.store = [deque(maxlen=self.items_per_class) for _ in range(self.total_num_classes)] 21 | 22 | def add(self, items, class_ids): 23 | for idx, class_id in enumerate(class_ids): 24 | self.store[class_id].append(items[idx]) 25 | 26 | def retrieve(self, class_id): 27 | if class_id != -1: 28 | items = [] 29 | for item in self.store[class_id]: 30 | items.extend(list(item)) 31 | if self.shuffle: 32 | random.shuffle(items) 33 | return items 34 | else: 35 | all_items = [] 36 | for i in range(self.total_num_classes): 37 | items = [] 38 | for item in self.store[i]: 39 | items.append(list(item)) 40 | all_items.append(items) 41 | return all_items 42 | 43 | def reset(self): 44 | self.store = [deque(maxlen=self.items_per_class) for _ in range(self.total_num_classes)] 45 | 46 | def __str__(self): 47 | s = self.__class__.__name__ + '(' 48 | for idx, item in enumerate(self.store): 49 | s += '\n Class ' + str(idx) + ' --> ' + str(len(list(item))) + ' items' 50 | s = s + ' )' 51 | return s 52 | 53 | def __repr__(self): 54 | return self.__str__() 55 | 56 | def __len__(self): 57 | return sum([len(s) for s in self.store]) 58 | 59 | def concatenate_images_with_padding(images, padding=10, color=(255, 255, 255)): 60 | heights = [image.shape[0] for image in images] 61 | widths = [image.shape[1] for image in images] 62 | 63 | total_width = sum(widths) + padding * (len(images) - 1) 64 | max_height = max(heights) 65 | 66 | if len(images[0].shape) == 3: 67 | new_image = np.full((max_height, total_width, 3), color, dtype=np.uint8) 68 | else: 69 | new_image = np.full((max_height, total_width), color[0], dtype=np.uint8) 70 | 71 | x_offset = 0 72 | for image in images: 73 | new_image[0:image.shape[0], x_offset:x_offset + image.shape[1]] = image 74 | x_offset += image.shape[1] + padding 75 | 76 | return new_image 77 | 78 | def calculate_iou(mask1, mask2): 79 | intersection = torch.logical_and(mask1, mask2) 80 | union = torch.logical_or(mask1, mask2) 81 | iou = torch.sum(intersection).float() / torch.sum(union).float() 82 | return iou 83 | 84 | def calc_iou_matrix(mask_list1, mask_list2): 85 | iou_matrix = torch.zeros((len(mask_list1), len(mask_list2))) 86 | for i, mask1 in enumerate(mask_list1): 87 | for j, mask2 in enumerate(mask_list2): 88 | iou_matrix[i, j] = calculate_iou(mask1, mask2) 89 | return iou_matrix 90 | 91 | def cal_mask_ious( 92 | cfg, 93 | model, 94 | images_weak, 95 | prompts, 96 | gt_masks, 97 | ): 98 | with torch.no_grad(): 99 | _, soft_masks, _, _ = model(images_weak, prompts) 100 | 101 | for i, (soft_mask, gt_mask) in enumerate(zip(soft_masks, gt_masks)): 102 | soft_mask = (soft_mask > 0).float() 103 | mask_ious = calc_iou_matrix(soft_mask, soft_mask) 104 | indices = torch.arange(mask_ious.size(0)) 105 | mask_ious[indices, indices] = 0.0 106 | return mask_ious, soft_mask 107 | 108 | 109 | def neg_prompt_calibration( 110 | cfg, 111 | mask_ious, 112 | prompts, 113 | ): 114 | ''' 115 | mask_ious:[mask_nums,mask_nums] 116 | ''' 117 | point_list = [] 118 | point_labels_list = [] 119 | num_points = cfg.num_points 120 | for m in range(len(mask_ious)): 121 | 122 | pos_point_coords = prompts[0][0][m][:num_points].unsqueeze(0) 123 | neg_point = prompts[0][0][m][num_points:].unsqueeze(0) 124 | neg_points_list = [] 125 | neg_points_list.extend(neg_point[0]) 126 | 127 | indices = torch.nonzero(mask_ious[m] > float(cfg.iou_thr)).squeeze(1) 128 | 129 | if indices.numel() != 0: 130 | # neg_points_list = [] 131 | for indice in 
indices: 132 | neg_points_list.extend(prompts[0][0][indice][:num_points]) 133 | neg_points = random.sample(neg_points_list, num_points) 134 | else: 135 | neg_points =neg_points_list 136 | 137 | neg_point_coords = torch.tensor([p.tolist() for p in neg_points], device=neg_point.device).unsqueeze(0) 138 | 139 | point_coords = torch.cat((pos_point_coords, neg_point_coords), dim=1) 140 | 141 | point_list.append(point_coords) 142 | pos_point_labels = torch.ones(pos_point_coords.shape[0:2], dtype=torch.int, device=neg_point.device) 143 | neg_point_labels = torch.zeros(neg_point_coords.shape[0:2], dtype=torch.int, device=neg_point.device) 144 | point_labels = torch.cat((pos_point_labels, neg_point_labels), dim=1) 145 | point_labels_list.append(point_labels) 146 | 147 | point_ = torch.cat(point_list).squeeze(1) 148 | point_labels_ = torch.cat(point_labels_list) 149 | new_prompts = [(point_, point_labels_)] 150 | return new_prompts 151 | 152 | def get_prompts(cfg: Box, bboxes, gt_masks): 153 | if cfg.prompt == "box" or cfg.prompt == "coarse": 154 | prompts = bboxes 155 | elif cfg.prompt == "point": 156 | prompts = get_point_prompts(gt_masks, cfg.num_points) 157 | else: 158 | raise ValueError("Prompt Type Error!") 159 | return prompts 160 | 161 | def generate_predict_feats(cfg, embed, pseudo_label, gts): 162 | coords, lbls = gts 163 | selected_coords = [] 164 | 165 | num_insts = len(pseudo_label) 166 | num_points = cfg.num_points 167 | for coord_grp, lbl_grp in zip(coords, lbls): 168 | for coord, lbl in zip(coord_grp, lbl_grp): 169 | if lbl.item() == 1: 170 | selected_coords.append(coord.tolist()) 171 | 172 | # Downsample coordinates (SAM's stride is 16) 173 | coords = [[int(c // 16) for c in pair] for pair in selected_coords] 174 | 175 | embed = embed.permute(1, 2, 0) # [H, W, C] 176 | 177 | pos_pts = [] 178 | 179 | for index in range(0, num_insts * num_points, num_points): 180 | index = random.randint(0, num_points - 1) 181 | x, y = coords[index] 182 | pos_pt = embed[x, y] 183 | pos_pts.append(pos_pt) 184 | 185 | predict_feats = torch.stack(pos_pts, dim=0) 186 | 187 | return predict_feats 188 | 189 | 190 | def offline_prototypes_generation(cfg, model, loader): 191 | model.eval() 192 | pts = [] 193 | max_iters = 128 194 | num_points = cfg.num_points 195 | 196 | with torch.no_grad(): 197 | for i, batch in enumerate(tqdm(loader, desc='Generating target prototypes', ncols=100)): 198 | if i >= max_iters: 199 | break 200 | imgs, boxes, masks, _ = batch 201 | prompts = get_prompts(cfg, boxes, masks) 202 | 203 | embeds, masks, _, _ = model(imgs, prompts) 204 | del _ 205 | 206 | if isinstance(embeds, dict): 207 | embeds = embeds['vision_features'] 208 | 209 | for embed, prompt, mask in zip(embeds, prompts, masks): 210 | num_insts = len(mask) 211 | embed = embed.permute(1, 2, 0) # [H, W, C] 212 | coords = [] 213 | 214 | points, labels = prompt 215 | for point_grp, label_grp in zip(points, labels): 216 | for point, label in zip(point_grp, label_grp): 217 | if label.item() == 1: 218 | coords.append(point.tolist()) 219 | 220 | # 16 is the stride of SAM 221 | coords = [[int(pt / 16) for pt in pair] for pair in coords] 222 | for index in range(0, num_insts*num_points, num_points): 223 | x, y = coords[index] 224 | pt = embed[x,y] 225 | pts.append(pt) 226 | 227 | fin = FINCH(verbose=True) 228 | pts = torch.stack(pts).cpu().numpy() 229 | res = fin.fit(pts) 230 | 231 | last_key = list(res.partitions.keys())[-1] 232 | pt_stats = {'target_pts': res.partitions[last_key]['cluster_centers']} 233 | return pt_stats 234 | 
235 | 236 | -------------------------------------------------------------------------------- /utils/visualize.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import torch 4 | from torchvision.utils import draw_bounding_boxes 5 | from torchvision.utils import draw_segmentation_masks 6 | 7 | from box import Box 8 | from tqdm import tqdm 9 | from model import Model 10 | from datasets.COCO import COCODataset 11 | 12 | 13 | def draw_image(image, masks, boxes, labels, alpha=0.4): 14 | image = torch.from_numpy(image).permute(2, 0, 1) 15 | if boxes is not None: 16 | image = draw_bounding_boxes(image, boxes, colors=['red'] * len(boxes), labels=labels, width=2) 17 | if masks is not None: 18 | image = draw_segmentation_masks(image, masks=masks, colors=['red'] * len(masks), alpha=alpha) 19 | return image.numpy().transpose(1, 2, 0) 20 | 21 | 22 | def visualize(cfg: Box): 23 | model = Model(cfg) 24 | model.setup() 25 | model.eval() 26 | model.cuda() 27 | dataset = COCODataset(root_dir=cfg.dataset.val.root_dir, 28 | annotation_file=cfg.dataset.val.annotation_file, 29 | transform=None) 30 | predictor = model.get_predictor() 31 | os.makedirs(cfg.out_dir, exist_ok=True) 32 | 33 | for image_id in tqdm(dataset.image_ids): 34 | image_info = dataset.coco.loadImgs(image_id)[0] 35 | image_path = os.path.join(dataset.root_dir, image_info['file_name']) 36 | image_output_path = os.path.join(cfg.out_dir, image_info['file_name']) 37 | image = cv2.imread(image_path) 38 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 39 | ann_ids = dataset.coco.getAnnIds(imgIds=image_id) 40 | anns = dataset.coco.loadAnns(ann_ids) 41 | bboxes = [] 42 | for ann in anns: 43 | x, y, w, h = ann['bbox'] 44 | bboxes.append([x, y, x + w, y + h]) 45 | bboxes = torch.as_tensor(bboxes, device=model.model.device) 46 | transformed_boxes = predictor.transform.apply_boxes_torch(bboxes, image.shape[:2]) 47 | predictor.set_image(image) 48 | masks, _, _ = predictor.predict_torch( 49 | point_coords=None, 50 | point_labels=None, 51 | boxes=transformed_boxes, 52 | multimask_output=False, 53 | ) 54 | image_output = draw_image(image, masks.squeeze(1), boxes=None, labels=None) 55 | cv2.imwrite(image_output_path, image_output) 56 | 57 | --------------------------------------------------------------------------------
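For quick reference, here is a minimal usage sketch of the LoRA wrapper in utils/sam_lora.py, mirroring what Model.setup()/Model.finetune() in utils/model.py do. It is not a file of this repository; the checkpoint path under pretrain/ and running from the repository root are assumptions.

```python
from segment_anything import sam_model_registry
from utils.sam_lora import LoRA_Sam

# Assumed checkpoint location; see pretrain/Where_To_Save_Pretrained_SAM_Checkpoints.
sam = sam_model_registry["vit_b"](checkpoint="pretrain/sam_vit_b_01ec64.pth")

# Inject rank-4 LoRA into the q/v projections of every image-encoder block.
# LoRA_Sam freezes the original image-encoder weights before the surgery,
# so only the new linear_a_* / linear_b_* layers stay trainable there.
LoRA_Sam(sam, r=4, lora_layer=list(range(len(sam.image_encoder.blocks))))

trainable = [n for n, p in sam.image_encoder.named_parameters() if p.requires_grad]
print(f"{len(trainable)} trainable tensors in the image encoder, e.g. {trainable[:2]}")
```

Model.finetune() issues the same call with cfg.lora_rank and cfg.start_lora_layer, so the rank and the first LoRA-adapted block are presumably set by the config files under configs/.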