├── .gitignore
├── LICENSE
├── README.md
├── assets
│   └── overview.jpg
├── configs
│   ├── base_config.py
│   ├── config_hrsid.py
│   ├── config_nwpu.py
│   └── config_whu.py
├── data
│   ├── HRSID
│   │   └── Annotations
│   │       ├── inshore
│   │       │   ├── inshore_test.json
│   │       │   └── inshore_train.json
│   │       └── offshore
│   │           ├── offshore_test.json
│   │           └── offshore_train.json
│   ├── NWPU
│   │   └── Annotations
│   │       ├── NWPU_instances_train.json
│   │       └── NWPU_instances_val.json
│   └── WHU
│       └── annotations
│           ├── WHU_building_test.json
│           ├── WHU_building_train.json
│           └── WHU_building_val.json
├── datasets
│   ├── HRSID.py
│   ├── NWPU.py
│   ├── WHU.py
│   ├── __init__.py
│   ├── augmentation.py
│   └── tools.py
├── inference.py
├── pretrain
│   └── Where_To_Save_Pretrained_SAM_Checkpoints
├── requirements.txt
├── scripts
│   ├── train_hrsid_pointsam.sh
│   ├── train_hrsid_selftrain.sh
│   ├── train_hrsid_supervise.sh
│   ├── train_nwpu_pointsam.sh
│   ├── train_nwpu_selftrain.sh
│   ├── train_nwpu_supervise.sh
│   ├── train_whu_pointsam.sh
│   ├── train_whu_selftrain.sh
│   └── train_whu_supervise.sh
├── segment_anything
│   ├── __init__.py
│   ├── automatic_mask_generator.py
│   ├── build_sam.py
│   ├── modeling
│   │   ├── __init__.py
│   │   ├── common.py
│   │   ├── image_encoder.py
│   │   ├── image_encoder_adapter.py
│   │   ├── mask_decoder.py
│   │   ├── prompt_encoder.py
│   │   ├── sam.py
│   │   └── transformer.py
│   ├── predictor.py
│   └── utils
│       ├── __init__.py
│       ├── amg.py
│       ├── onnx.py
│       └── transforms.py
├── train_pointsam.py
├── train_selftrain.py
├── train_supervise.py
└── utils
    ├── eval_utils.py
    ├── finch.py
    ├── losses.py
    ├── model.py
    ├── sam_lora.py
    ├── sample_utils.py
    ├── tools.py
    ├── utils.py
    └── visualize.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .nox/
42 | .coverage
43 | .coverage.*
44 | .cache
45 | nosetests.xml
46 | coverage.xml
47 | *.cover
48 | .hypothesis/
49 | .pytest_cache/
50 | .idea/*
51 |
52 | # Translations
53 | *.mo
54 | *.pot
55 |
56 | # Django stuff:
57 | *.log
58 | local_settings.py
59 | db.sqlite3
60 |
61 | # Flask stuff:
62 | instance/
63 | .webassets-cache
64 |
65 | # Scrapy stuff:
66 | .scrapy
67 |
68 | # Sphinx documentation
69 | docs/en/_build/
70 | docs/zh_cn/_build/
71 | src
72 |
73 | # PyBuilder
74 | target/
75 |
76 | # Jupyter Notebook
77 | .ipynb_checkpoints
78 |
79 | # IPython
80 | profile_default/
81 | ipython_config.py
82 |
83 | # pyenv
84 | .python-version
85 |
86 | # celery beat schedule file
87 | celerybeat-schedule
88 |
89 | # SageMath parsed files
90 | *.sage.py
91 |
92 | # Environments
93 | .env
94 | .venv
95 | env/
96 | venv/
97 | ENV/
98 | env.bak/
99 | venv.bak/
100 |
101 | # Spyder project settings
102 | .spyderproject
103 | .spyproject
104 |
105 | # Rope project settings
106 | .ropeproject
107 |
108 | # mkdocs documentation
109 | /site
110 |
111 | # mypy
112 | .mypy_cache/
113 | .dmypy.json
114 | dmypy.json
115 |
116 | # Pyre type checker
117 | .pyre/
118 | .DS_Store
119 | .idea
120 | *work_dirs*
121 | tmp
122 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Nanqing Liu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # PointSAM: Pointly-Supervised Segment Anything Model for Remote Sensing Images
4 |
5 |
6 |
7 |
8 | [arXiv](https://arxiv.org/abs/2409.13401)
9 |
10 |
11 |
12 |
19 |
20 |
21 | ---
22 | ## 📢 Latest Updates
23 | - **2 Jan 2025**: **PointSAM** has been accepted by TGRS and is now available [here](https://ieeexplore.ieee.org/document/10839471).
24 | - **8 Dec 2024**: The complete code is released.
25 | - **20 Sep 2024**: The arXiv version is released [here](https://arxiv.org/abs/2409.13401).
26 | ---
27 |
28 |
29 |
30 | ## 🎨 Overview
31 |
32 | 
33 |
34 | ## 🎮 Getting Started
35 | ### 1. Install Environment
36 | To ensure compatibility, **the Python version must not exceed 3.10**. Follow these steps to set up your environment:
37 | ```bash
38 | conda create --name pointsam python=3.10
39 | conda activate pointsam
40 |
41 | pip install torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 --index-url https://download.pytorch.org/whl/cu118
42 | git clone https://github.com/Lans1ng/PointSAM.git
43 | cd PointSAM
44 | pip install -r requirements.txt
45 | ```
46 |
47 | **Note:**
48 | The CUDA version in the `pip install` command is specified as `cu118` (CUDA 11.8). If your system uses a different CUDA version (e.g., CUDA 12.1), replace `cu118` with the appropriate version tag (e.g., `cu121`).
49 |
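For example, on a machine with CUDA 12.1, the same command would use the `cu121` index (check that a wheel for your CUDA version is actually published):

```bash
pip install torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 --index-url https://download.pytorch.org/whl/cu121
```
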
50 | ### 2. Prepare Dataset
51 |
52 | #### WHU Building Dataset
53 |
54 | - Dataset download address: [WHU Building Dataset](https://aistudio.baidu.com/datasetdetail/56502).
55 | 
56 | - To convert the semantic labels to instance labels, you can refer to the corresponding [conversion script](https://github.com/KyanChen/RSPrompter/blob/release/tools/rsprompter/whu2coco.py).
57 |
58 | #### HRSID Dataset
59 |
60 | - Dataset download address: [HRSID Dataset](https://github.com/chaozhong2010/HRSID).
61 |
62 | #### NWPU VHR-10 Dataset
63 |
64 | - Dataset download address: [NWPU VHR-10 Dataset](https://aistudio.baidu.com/datasetdetail/52812).
65 |
66 | - Instance label download address: [NWPU VHR-10 Instance Label](https://github.com/chaozhong2010/VHR-10_dataset_coco).
67 |
68 | For convenience, the necessary JSON annotations are included in this repo. You only need to download the corresponding images. Organize your dataset as follows:
69 |
70 | ```
71 | data
72 | ├── WHU
73 | │   ├── annotations
74 | │   │   ├── WHU_building_train.json
75 | │   │   ├── WHU_building_test.json
76 | │   │   └── WHU_building_val.json
77 | │   └── images
78 | │       ├── train
79 | │       │   ├── image
80 | │       │   └── label
81 | │       ├── val
82 | │       │   ├── image
83 | │       │   └── label
84 | │       └── test
85 | │           ├── image
86 | │           └── label
87 | ├── HRSID
88 | │   ├── Annotations
89 | │   │   ├── all
90 | │   │   ├── inshore
91 | │   │   │   ├── inshore_test.json
92 | │   │   │   └── inshore_train.json
93 | │   │   └── offshore
94 | │   └── Images
95 | └── NWPU
96 |     ├── Annotations
97 |     │   ├── NWPU_instances_train.json
98 |     │   └── NWPU_instances_val.json
99 |     └── Images
100 | 
101 | ```
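
Once the images are in place, a quick way to sanity-check the layout is to load one of the bundled annotation files with `pycocotools` (a minimal sketch; the path below assumes the WHU layout shown above):

```python
from pycocotools.coco import COCO

# Load one of the bundled COCO-style annotation files.
coco = COCO("data/WHU/annotations/WHU_building_train.json")
print(f"{len(coco.imgs)} images, {len(coco.anns)} annotations")

# Each referenced image is expected under data/WHU/images/train/image/.
example = next(iter(coco.imgs.values()))
print("example file name:", example["file_name"])
```
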
102 | ### 3. Download Checkpoints
103 |
104 | Click the links below to download the checkpoint for the corresponding model type.
105 |
106 | - `vit-h`: [ViT-H SAM model.](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth)
107 | - `vit-l`: [ViT-L SAM model.](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth)
108 | - `vit-b`: [ViT-B SAM model.](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth)
109 |
110 | After downloading, move the models to the `pretrain` folder.
111 |
112 | **Note**: In our project, only the `vit-b` model is used.
113 |
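For instance, the ViT-B checkpoint used in this project can be fetched directly into the expected folder (illustrative command):

```bash
mkdir -p pretrain
wget -P pretrain https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth
```
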
114 | ### 4. Training
115 | For convenience, the `scripts` folder provides ready-to-run scripts for **Supervised Training**, **Self-Training**, and **PointSAM** on the NWPU VHR-10, WHU, and HRSID datasets.
116 |
117 | Here’s an example of training PointSAM on the WHU dataset:
118 | ```bash
119 | bash scripts/train_whu_pointsam.sh
120 | ```
121 |
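Each script loops over several point counts and calls the corresponding training entry point. For reference, a single iteration of `scripts/train_whu_pointsam.sh` expands to roughly:

```bash
CUDA_VISIBLE_DEVICES=0 python train_pointsam.py \
    --cfg configs.config_whu \
    --prompt point \
    --num_points 1 \
    --out_dir work_dir/whu/pointsam/point_1 \
    --load_type soft
```
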
122 | ### 5. Inference
123 |
124 | Here’s an example of how to perform inference:
125 |
126 | ```bash
127 | python inference.py --cfg <CONFIG_FILE> --out_dir <OUTPUT_DIR> --ckpt <CHECKPOINT_PATH>
128 | ```
129 |
130 | Please replace `<CONFIG_FILE>`, `<OUTPUT_DIR>`, and `<CHECKPOINT_PATH>` with the actual config module, output directory, and checkpoint path.
131 |
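For example, to visualize a model trained on WHU (the checkpoint path below is illustrative; point it at the checkpoint produced by your own training run):

```bash
python inference.py --cfg configs.config_whu --out_dir output/whu_vis --ckpt work_dir/whu/pointsam/point_1/best-ckpt.pth
```
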
132 | **Note:** The generated results consist of four images arranged side by side:
133 | 
134 | - The first image is the original input image.
135 | - The second image is a visualization of the ground-truth (GT) masks.
136 | - The third image is the result of directly testing the original SAM.
137 | - The fourth image is the result obtained using the provided checkpoint.
138 |
139 |
140 | ## 💡 Acknowledgement
141 |
142 | - [wesam](https://github.com/zhang-haojie/wesam)
143 | - [OWOD](https://github.com/JosephKJ/OWOD)
144 | - [RSPrompter](https://github.com/KyanChen/RSPrompter)
145 |
146 |
147 | ## 🖊️ Citation
148 |
149 | If you find this project useful in your research, please consider starring ⭐ and citing 📚:
150 |
151 | ```BibTeX
152 | @ARTICLE{10839471,
153 | author={Liu, Nanqing and Xu, Xun and Su, Yongyi and Zhang, Haojie and Li, Heng-Chao},
154 | journal={IEEE Transactions on Geoscience and Remote Sensing},
155 | title={PointSAM: Pointly-Supervised Segment Anything Model for Remote Sensing Images},
156 | year={2025},
157 | volume={63},
158 | number={},
159 | pages={1-15},
160 | doi={10.1109/TGRS.2025.3529031}}
161 |
162 | ```
163 |
--------------------------------------------------------------------------------
/assets/overview.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lans1ng/PointSAM/d6831edf23de4bd7304fb46a5ce2a0baa69dcbd5/assets/overview.jpg
--------------------------------------------------------------------------------
/configs/base_config.py:
--------------------------------------------------------------------------------
1 | base_config = {
2 | "eval_interval": 1,
3 | "ema_rate": 0.999,
4 | "csv_keys": [ "Prompt", "IoU", "Recall", "Precision", "F1", "epoch"],
5 | "opt": {
6 | "learning_rate": 1e-5,
7 |         "weight_decay": 1e-4,
8 | "decay_factor": 10,
9 | "steps": [3000, 8000],
10 | "warmup_steps": 250,
11 | },
12 | "corruptions": [
13 | "gaussian_noise",
14 | "shot_noise",
15 | "impulse_noise",
16 | "defocus_blur",
17 | "glass_blur",
18 | "motion_blur",
19 | "zoom_blur",
20 | "snow",
21 | "frost",
22 | "fog",
23 | "brightness",
24 | "contrast",
25 | "elastic_transform",
26 | "pixelate",
27 | "jpeg_compression",
28 | ],
29 | "model": {
30 | "type": "vit_b",
31 | "checkpoint": "./pretrain/",
32 | "ckpt": "",
33 | "freeze": {
34 | "image_encoder": True,
35 | "prompt_encoder": True,
36 | "mask_decoder": True,
37 | },
38 | },
39 | "datasets": {
40 | "NWPU": {
41 | "root_dir": "data/NWPU/Images",
42 | "annotation_file_train": "data/NWPU/Annotations/NWPU_instances_train.json",
43 | "annotation_file_val": "data/NWPU/Annotations/NWPU_instances_val.json",
44 | },
45 | "WHU": {
46 | "root_dir": "data/WHU",
47 | "annotation_file_train": "data/WHU/annotations/WHU_building_train.json",
48 | "annotation_file_val": "data/WHU/annotations/WHU_building_val.json",
49 | },
50 | "HRSID": {
51 |             "root_dir": "data/HRSID/Images",
52 |             "annotation_file_train": "data/HRSID/Annotations/inshore/inshore_train.json",
53 |             "annotation_file_val": "data/HRSID/Annotations/inshore/inshore_test.json"
54 | },
55 | },
56 | }
57 |
--------------------------------------------------------------------------------
/configs/config_hrsid.py:
--------------------------------------------------------------------------------
1 | from box import Box
2 | from configs.base_config import base_config
3 |
4 | config = {
5 | "dataset": "HRSID",
6 | "load_type": "soft",
7 | "num_points": 1,
8 |
9 |     "batch_size": 1,  # only a batch size of 1 is supported
10 | "val_batchsize": 1,
11 | "num_workers": 0,
12 | "num_epochs": 10,
13 | "max_nums": 50,
14 | "resume": False,
15 |
16 | "start_lora_layer": 1,
17 | "lora_rank": 4,
18 | "mem_bank_max_len": 512,
19 | "match_interval": 30,
20 | "iou_thr": 0.1,
21 |
22 | "prompt": "point",
23 | "out_dir": "",
24 | "name": "base",
25 | "corrupt": None,
26 | "visual": False,
27 | "model": {
28 | "type": "vit_b",
29 | },
30 | "opt": {
31 | "learning_rate": 5e-4,
32 | "weight_decay": 1e-4,
33 | "decay_factor": 10,
34 | "steps": [1500, 2000],
35 | "warmup_steps": 250,
36 | },
37 | }
38 |
39 | cfg = Box(base_config)
40 | cfg.merge_update(config)
41 |
--------------------------------------------------------------------------------
/configs/config_nwpu.py:
--------------------------------------------------------------------------------
1 | from box import Box
2 | from configs.base_config import base_config
3 |
4 | config = {
5 | "dataset": "NWPU",
6 | "load_type": "soft",
7 | "num_points": 1,
8 |
9 |     "batch_size": 1,  # only a batch size of 1 is supported
10 | "val_batchsize": 1,
11 | "num_workers": 0,
12 | "num_epochs": 10,
13 | "max_nums": 50,
14 | "resume": False,
15 |
16 | "start_lora_layer": 6,
17 | "lora_rank": 4,
18 | "mem_bank_max_len": 128,
19 | "match_interval": 30,
20 | "iou_thr": 0.1,
21 |
22 | "prompt": "point",
23 | "out_dir": "",
24 | "name": "base",
25 | "corrupt": None,
26 | "visual": False,
27 | "model": {
28 | "type": "vit_b",
29 | },
30 | "opt": {
31 | "learning_rate": 5e-4,
32 | "weight_decay": 1e-4,
33 | "decay_factor": 10,
34 | "steps": [2000, 4000],
35 | "warmup_steps": 250,
36 | },
37 | }
38 |
39 | cfg = Box(base_config)
40 | cfg.merge_update(config)
41 |
--------------------------------------------------------------------------------
/configs/config_whu.py:
--------------------------------------------------------------------------------
1 | from box import Box
2 | from configs.base_config import base_config
3 |
4 | config = {
5 | "dataset": "WHU",
6 | "load_type": "soft",
7 | "num_points": 1,
8 |
9 |     "batch_size": 1,  # only a batch size of 1 is supported
10 | "val_batchsize": 1,
11 | "num_workers": 0,
12 | "num_epochs": 5,
13 | "max_nums": 30,
14 | "resume": False,
15 |
16 | "start_lora_layer": 1,
17 | "lora_rank": 4,
18 | "mem_bank_max_len": 512,
19 | "match_interval": 30,
20 | "iou_thr": 0.1,
21 |
22 | "prompt": "point",
23 | "out_dir": "",
24 | "name": "base",
25 | "corrupt": None,
26 | "visual": False,
27 | "model": {
28 | "type": "vit_b",
29 | },
30 | "opt": {
31 | "learning_rate": 5e-4,
32 | "weight_decay": 1e-4,
33 | "decay_factor": 10,
34 | "steps": [8000,15000],
35 | "warmup_steps": 250,
36 | },
37 | }
38 |
39 | cfg = Box(base_config)
40 | cfg.merge_update(config)
41 |
--------------------------------------------------------------------------------
/datasets/HRSID.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import numpy as np
4 | import torch
5 | from torch.utils.data import DataLoader
6 | from torch.utils.data import Dataset
7 | from pycocotools.coco import COCO
8 | from datasets.tools import ResizeAndPad, soft_transform, collate_fn, collate_fn_soft, collate_fn_
9 |
10 | class HRSIDDataset(Dataset):
11 | def __init__(self, cfg, root_dir, annotation_file, transform=None, training=False, if_self_training=False, gen_pt=False):
12 | self.cfg = cfg
13 | self.root_dir = root_dir
14 | self.transform = transform
15 | self.coco = COCO(annotation_file)
16 |
17 | image_ids = sorted(list(self.coco.imgs.keys()))
18 |
19 | # only for inshore
20 | if gen_pt:
21 | removed_ids = [351,375,376]
22 | else:
23 | if training:
24 | removed_ids = [351,375,376]
25 | else:
26 | removed_ids = [166]
27 |
28 |
29 | if removed_ids:
30 | for i in removed_ids:
31 | image_ids.remove(i)
32 |
33 | # Filter out image_ids without any annotations
34 | self.image_ids = [
35 | image_id
36 | for image_id in image_ids
37 | if len(self.coco.getAnnIds(imgIds=image_id)) > 0
38 | ]
39 | self.if_self_training = if_self_training
40 |
41 | def __len__(self):
42 | return len(self.image_ids)
43 |
44 | def __getitem__(self, idx):
45 | image_id = self.image_ids[idx]
46 | image_info = self.coco.loadImgs(image_id)[0]
47 | image_path = os.path.join(self.root_dir, image_info["file_name"])
48 | image = cv2.imread(image_path)
49 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
50 |         ann_ids = self.coco.getAnnIds(imgIds=image_id)  # annotation IDs of all objects in this image
51 | anns = self.coco.loadAnns(ann_ids)
52 | bboxes = []
53 | masks = []
54 | categories = []
55 | for ann in anns:
56 | x, y, w, h = ann["bbox"]
57 |
58 | mask = self.coco.annToMask(ann)
59 |
60 | bboxes.append([x, y, x + w, y + h])
61 | masks.append(mask)
62 | categories.append(ann["category_id"])
63 |
64 |
65 | if self.if_self_training:
66 | image_weak, bboxes_weak, masks_weak, image_strong = soft_transform(image, bboxes, masks, categories)
67 | if self.transform:
68 | image_weak, masks_weak, bboxes_weak = self.transform(image_weak, masks_weak, np.array(bboxes_weak))
69 | image_strong = self.transform.transform_image(image_strong)
70 |
71 | bboxes_weak = np.stack(bboxes_weak, axis=0)
72 | masks_weak = np.stack(masks_weak, axis=0)
73 | return image_weak, image_strong, torch.tensor(bboxes_weak), torch.tensor(masks_weak).float(), image_path
74 |
75 | elif self.cfg.visual:
76 | origin_image = image
77 | origin_bboxes = bboxes
78 | origin_masks = masks
79 | if self.transform:
80 | padding, image, masks, bboxes = self.transform(image, masks, np.array(bboxes), True)
81 |
82 | bboxes = np.stack(bboxes, axis=0)
83 | masks = np.stack(masks, axis=0)
84 | origin_bboxes = np.stack(origin_bboxes, axis=0)
85 | origin_masks = np.stack(origin_masks, axis=0)
86 | return image_id, padding, origin_image, origin_bboxes, origin_masks, image, torch.tensor(bboxes), torch.tensor(masks).float()
87 |
88 | else:
89 | if self.transform:
90 | image, masks, bboxes = self.transform(image, masks, np.array(bboxes))
91 | bboxes = np.stack(bboxes, axis=0)
92 | masks = np.stack(masks, axis=0)
93 | return image, torch.tensor(bboxes), torch.tensor(masks).float(), image_path
94 |
95 | def load_datasets(cfg, img_size):
96 | transform = ResizeAndPad(img_size)
97 | train = HRSIDDataset(
98 | cfg,
99 | root_dir=cfg.datasets.HRSID.root_dir,
100 | annotation_file=cfg.datasets.HRSID.annotation_file_train,
101 | transform=transform,
102 | training=True,
103 | )
104 | train_dataloader = DataLoader(
105 | train,
106 | batch_size=cfg.batch_size,
107 | shuffle=True,
108 | num_workers=cfg.num_workers,
109 | collate_fn=collate_fn,
110 | )
111 |
112 | val = HRSIDDataset(
113 | cfg,
114 | root_dir=cfg.datasets.HRSID.root_dir,
115 | annotation_file=cfg.datasets.HRSID.annotation_file_val,
116 | transform=transform,
117 | )
118 | val_dataloader = DataLoader(
119 | val,
120 | batch_size=cfg.val_batchsize,
121 | shuffle=False,
122 | num_workers=cfg.num_workers,
123 | collate_fn=collate_fn,
124 | )
125 | return train_dataloader, val_dataloader
126 |
127 |
128 | def load_datasets_soft(cfg, img_size, return_pt = False):
129 | transform = ResizeAndPad(img_size)
130 |
131 | soft_train = HRSIDDataset(
132 | cfg,
133 | root_dir=cfg.datasets.HRSID.root_dir,
134 | annotation_file=cfg.datasets.HRSID.annotation_file_train,
135 | transform=transform,
136 | training=True,
137 | if_self_training=True,
138 | )
139 | soft_train_dataloader = DataLoader(
140 | soft_train,
141 | batch_size=cfg.batch_size,
142 | shuffle=True,
143 | num_workers=cfg.num_workers,
144 | collate_fn=collate_fn_soft,
145 | )
146 |
147 | val = HRSIDDataset(
148 | cfg,
149 | root_dir=cfg.datasets.HRSID.root_dir,
150 | annotation_file=cfg.datasets.HRSID.annotation_file_val,
151 | transform=transform,
152 | )
153 | val_dataloader = DataLoader(
154 | val,
155 | batch_size=cfg.val_batchsize,
156 | shuffle=False,
157 | num_workers=cfg.num_workers,
158 | collate_fn=collate_fn,
159 | )
160 |
161 | if return_pt:
162 | pt = HRSIDDataset(
163 | cfg,
164 | root_dir=cfg.datasets.HRSID.root_dir,
165 | annotation_file=cfg.datasets.HRSID.annotation_file_train,
166 | transform=transform,
167 | gen_pt = return_pt
168 | )
169 | pt_dataloader = DataLoader(
170 | pt,
171 | batch_size=cfg.val_batchsize,
172 | shuffle=False,
173 | num_workers=cfg.num_workers,
174 | collate_fn=collate_fn,
175 | )
176 | return soft_train_dataloader, val_dataloader, pt_dataloader
177 | else:
178 | return soft_train_dataloader, val_dataloader
179 |
180 | def load_datasets_visual(cfg, img_size):
181 | transform = ResizeAndPad(img_size)
182 | val = HRSIDDataset(
183 | cfg,
184 | root_dir=cfg.datasets.HRSID.root_dir,
185 | annotation_file=cfg.datasets.HRSID.annotation_file_val,
186 | transform=transform,
187 | )
188 | val_dataloader = DataLoader(
189 | val,
190 | batch_size=cfg.val_batchsize,
191 | shuffle=False,
192 | num_workers=cfg.num_workers,
193 | collate_fn=collate_fn_,
194 | )
195 | return val_dataloader
196 |
--------------------------------------------------------------------------------
/datasets/NWPU.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import numpy as np
4 | import torch
5 | from torch.utils.data import DataLoader
6 | from torch.utils.data import Dataset
7 | from pycocotools.coco import COCO
8 | from datasets.tools import ResizeAndPad, soft_transform, collate_fn, collate_fn_soft, collate_fn_
9 |
10 | class NWPUDataset(Dataset):
11 | def __init__(self, cfg, root_dir, annotation_file, transform=None, training=False, if_self_training=False):
12 | self.cfg = cfg
13 | self.root_dir = root_dir
14 | self.transform = transform
15 | self.coco = COCO(annotation_file)
16 | self.image_ids = sorted(list(self.coco.imgs.keys()))
17 |
18 | self.if_self_training = if_self_training
19 |
20 | def __len__(self):
21 | return len(self.image_ids)
22 |
23 | def __getitem__(self, idx):
24 | image_id = self.image_ids[idx]
25 | image_info = self.coco.loadImgs(image_id)[0]
26 | image_path = os.path.join(self.root_dir, image_info["file_name"])
27 | image = cv2.imread(image_path)
28 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
29 | origin_image = image
30 | ann_ids = self.coco.getAnnIds(imgIds=image_id)
31 | anns = self.coco.loadAnns(ann_ids)
32 | bboxes = []
33 | masks = []
34 | categories = []
35 | for ann in anns:
36 | x, y, w, h = ann["bbox"]
37 | bboxes.append([x, y, x + w, y + h])
38 | mask = self.coco.annToMask(ann)
39 | masks.append(mask)
40 | categories.append(ann["category_id"])
41 | if self.if_self_training:
42 | image_weak, bboxes_weak, masks_weak, image_strong = soft_transform(image, bboxes, masks, categories)
43 | # image_origin = image_weak
44 |
45 | if self.transform:
46 | image_weak, masks_weak, bboxes_weak = self.transform(image_weak, masks_weak, np.array(bboxes_weak))
47 | image_strong = self.transform.transform_image(image_strong)
48 |
49 | bboxes_weak = np.stack(bboxes_weak, axis=0)
50 | masks_weak = np.stack(masks_weak, axis=0)
51 | return image_weak, image_strong, torch.tensor(bboxes_weak), torch.tensor(masks_weak).float(), image_path
52 |
53 | elif self.cfg.visual:
54 | origin_image = image
55 | origin_bboxes = bboxes
56 | origin_masks = masks
57 | if self.transform:
58 | padding, image, masks, bboxes = self.transform(image, masks, np.array(bboxes), True)
59 |
60 | bboxes = np.stack(bboxes, axis=0)
61 | masks = np.stack(masks, axis=0)
62 | origin_bboxes = np.stack(origin_bboxes, axis=0)
63 | origin_masks = np.stack(origin_masks, axis=0)
64 | return image_id, padding, origin_image, origin_bboxes, origin_masks, image, torch.tensor(bboxes), torch.tensor(masks).float()
65 |
66 | else:
67 | if self.transform:
68 | image, masks, bboxes = self.transform(image, masks, np.array(bboxes))
69 |
70 | bboxes = np.stack(bboxes, axis=0)
71 | masks = np.stack(masks, axis=0)
72 | return image, torch.tensor(bboxes), torch.tensor(masks).float(), image_path
73 |
74 | def load_datasets(cfg, img_size):
75 | transform = ResizeAndPad(img_size)
76 | train = NWPUDataset(
77 | cfg,
78 | root_dir=cfg.datasets.NWPU.root_dir,
79 | annotation_file=cfg.datasets.NWPU.annotation_file_train,
80 | transform=transform,
81 | training=True,
82 | )
83 | train_dataloader = DataLoader(
84 | train,
85 | batch_size=cfg.batch_size,
86 | shuffle=True,
87 | num_workers=cfg.num_workers,
88 | collate_fn=collate_fn,
89 | )
90 |
91 | val = NWPUDataset(
92 | cfg,
93 | root_dir=cfg.datasets.NWPU.root_dir,
94 | annotation_file=cfg.datasets.NWPU.annotation_file_val,
95 | transform=transform,
96 | )
97 | val_dataloader = DataLoader(
98 | val,
99 | batch_size=cfg.val_batchsize,
100 | shuffle=False,
101 | num_workers=cfg.num_workers,
102 | collate_fn=collate_fn,
103 | )
104 | return train_dataloader, val_dataloader
105 |
106 |
107 | def load_datasets_soft(cfg, img_size, return_pt = False):
108 | transform = ResizeAndPad(img_size)
109 |
110 | soft_train = NWPUDataset(
111 | cfg,
112 | root_dir=cfg.datasets.NWPU.root_dir,
113 | annotation_file=cfg.datasets.NWPU.annotation_file_train,
114 | transform=transform,
115 | training=True,
116 | if_self_training=True,
117 | )
118 | soft_train_dataloader = DataLoader(
119 | soft_train,
120 | batch_size=cfg.batch_size,
121 | shuffle=True,
122 | num_workers=cfg.num_workers,
123 | collate_fn=collate_fn_soft,
124 | )
125 |
126 | val = NWPUDataset(
127 | cfg,
128 | root_dir=cfg.datasets.NWPU.root_dir,
129 | annotation_file=cfg.datasets.NWPU.annotation_file_val,
130 | transform=transform,
131 | )
132 | val_dataloader = DataLoader(
133 | val,
134 | batch_size=cfg.val_batchsize,
135 | shuffle=False,
136 | num_workers=cfg.num_workers,
137 | collate_fn=collate_fn,
138 | )
139 | if return_pt:
140 | pt = NWPUDataset(
141 | cfg,
142 | root_dir=cfg.datasets.NWPU.root_dir,
143 | annotation_file=cfg.datasets.NWPU.annotation_file_train,
144 | transform=transform,
145 | )
146 | pt_dataloader = DataLoader(
147 | pt,
148 | batch_size=cfg.val_batchsize,
149 | shuffle=False,
150 | num_workers=cfg.num_workers,
151 | collate_fn=collate_fn,
152 | )
153 |
154 | return soft_train_dataloader, val_dataloader, pt_dataloader
155 | else:
156 | return soft_train_dataloader, val_dataloader
157 |
158 | def load_datasets_visual(cfg, img_size):
159 | transform = ResizeAndPad(img_size)
160 | val = NWPUDataset(
161 | cfg,
162 | root_dir=cfg.datasets.NWPU.root_dir,
163 | annotation_file=cfg.datasets.NWPU.annotation_file_val,
164 | transform=transform,
165 | )
166 | val_dataloader = DataLoader(
167 | val,
168 | batch_size=cfg.val_batchsize,
169 | shuffle=False,
170 | num_workers=cfg.num_workers,
171 | collate_fn=collate_fn_,
172 | )
173 | return val_dataloader
174 |
175 |
--------------------------------------------------------------------------------
/datasets/WHU.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import random
4 | import numpy as np
5 | import torch
6 | from torch.utils.data import DataLoader
7 | from torch.utils.data import Dataset
8 | from pycocotools.coco import COCO
9 | from skimage.draw import polygon2mask
10 | from datasets.tools import ResizeAndPad, soft_transform, collate_fn, collate_fn_soft, collate_fn_
11 |
12 |
13 | class WHUDataset(Dataset):
14 | def __init__(self, cfg, root_dir, annotation_file, rate=(5, 1), transform=None, training=False, if_self_training=False, gen_pt = False):
15 | self.cfg = cfg
16 | self.root_dir = root_dir
17 | self.transform = transform
18 | self.coco = COCO(annotation_file)
19 | self.image_ids = sorted(list(self.coco.imgs.keys()))
20 | self.training = training
21 | self.gen_pt = gen_pt
22 |
23 | # # Filter out image_ids without any annotations
24 | self.image_ids = [
25 | image_id
26 | for image_id in self.image_ids
27 | if len(self.coco.getAnnIds(imgIds=image_id)) > 0
28 | ]
29 |
30 | self.if_self_training = if_self_training
31 |
32 | def __len__(self):
33 | return len(self.image_ids)
34 |
35 | def __getitem__(self, idx):
36 | image_id = self.image_ids[idx]
37 | image_info = self.coco.loadImgs(image_id)[0]
38 | if self.training or self.gen_pt:
39 | image_path = os.path.join(self.root_dir,'train/image', image_info["file_name"])
40 | else:
41 | image_path = os.path.join(self.root_dir,'val/image', image_info["file_name"])
42 | image = cv2.imread(image_path)
43 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
44 |
45 | ann_ids = self.coco.getAnnIds(imgIds=image_id)
46 | anns = self.coco.loadAnns(ann_ids)
47 | bboxes = []
48 | masks = []
49 | categories = []
50 | for ann in anns:
51 | x, y, w, h = ann["bbox"]
52 | bboxes.append([x, y, x + w, y + h])
53 | mask = self.coco.annToMask(ann)
54 | masks.append(mask)
55 | categories.append(ann["category_id"])
56 |
57 | if self.if_self_training:
58 | image_weak, bboxes_weak, masks_weak, image_strong = soft_transform(image, bboxes, masks, categories)
59 |
60 | if self.transform:
61 | image_weak, masks_weak, bboxes_weak = self.transform(image_weak, masks_weak, np.array(bboxes_weak))
62 | image_strong = self.transform.transform_image(image_strong)
63 |
64 | bboxes_weak = np.stack(bboxes_weak, axis=0)
65 | masks_weak = np.stack(masks_weak, axis=0)
66 | return image_weak, image_strong, torch.tensor(bboxes_weak), torch.tensor(masks_weak).float(), image_path
67 |
68 | elif self.cfg.visual:
69 | origin_image = image
70 | origin_bboxes = bboxes
71 | origin_masks = masks
72 | if self.transform:
73 | padding, image, masks, bboxes = self.transform(image, masks, np.array(bboxes), True)
74 |
75 | bboxes = np.stack(bboxes, axis=0)
76 | masks = np.stack(masks, axis=0)
77 | origin_bboxes = np.stack(origin_bboxes, axis=0)
78 | origin_masks = np.stack(origin_masks, axis=0)
79 | return image_id, padding, origin_image, origin_bboxes, origin_masks, image, torch.tensor(bboxes), torch.tensor(masks).float()
80 |
81 | else:
82 | if self.transform:
83 | image, masks, bboxes = self.transform(image, masks, np.array(bboxes))
84 | bboxes = np.stack(bboxes, axis=0)
85 | masks = np.stack(masks, axis=0)
86 | return image, torch.tensor(bboxes), torch.tensor(masks).float(),image_path
87 |
88 | def load_datasets(cfg, img_size):
89 | transform = ResizeAndPad(img_size)
90 | train = WHUDataset(
91 | cfg,
92 | root_dir=cfg.datasets.WHU.root_dir,
93 | annotation_file=cfg.datasets.WHU.annotation_file_train,
94 | transform=transform,
95 | training=True,
96 | )
97 | train_dataloader = DataLoader(
98 | train,
99 | batch_size=cfg.batch_size,
100 | shuffle=True,
101 | num_workers=cfg.num_workers,
102 | collate_fn=collate_fn,
103 | )
104 |
105 | val = WHUDataset(
106 | cfg,
107 | root_dir=cfg.datasets.WHU.root_dir,
108 | annotation_file=cfg.datasets.WHU.annotation_file_val,
109 | transform=transform,
110 | )
111 | val_dataloader = DataLoader(
112 | val,
113 | batch_size=cfg.val_batchsize,
114 | shuffle=False,
115 | num_workers=cfg.num_workers,
116 | collate_fn=collate_fn,
117 | )
118 | return train_dataloader, val_dataloader
119 |
120 |
121 | def load_datasets_soft(cfg, img_size, return_pt = False):
122 | transform = ResizeAndPad(img_size)
123 |
124 | soft_train = WHUDataset(
125 | cfg,
126 | root_dir=cfg.datasets.WHU.root_dir,
127 | annotation_file=cfg.datasets.WHU.annotation_file_train,
128 | transform=transform,
129 | training=True,
130 | if_self_training=True,
131 | )
132 | soft_train_dataloader = DataLoader(
133 | soft_train,
134 | batch_size=cfg.batch_size,
135 | shuffle=True,
136 | pin_memory=True,
137 | num_workers=cfg.num_workers,
138 | collate_fn=collate_fn_soft,
139 | )
140 |
141 | val = WHUDataset(
142 | cfg,
143 | root_dir=cfg.datasets.WHU.root_dir,
144 | annotation_file=cfg.datasets.WHU.annotation_file_val,
145 | transform=transform,
146 | )
147 | val_dataloader = DataLoader(
148 | val,
149 | batch_size=cfg.val_batchsize,
150 | shuffle=False,
151 | num_workers=cfg.num_workers,
152 | collate_fn=collate_fn,
153 | )
154 |
155 | if return_pt:
156 | pt = WHUDataset(
157 | cfg,
158 | root_dir=cfg.datasets.WHU.root_dir,
159 | annotation_file=cfg.datasets.WHU.annotation_file_train,
160 | transform=transform,
161 | gen_pt = return_pt,
162 | )
163 | pt_dataloader = DataLoader(
164 | pt,
165 | batch_size=cfg.val_batchsize,
166 | shuffle=False,
167 | num_workers=cfg.num_workers,
168 | collate_fn=collate_fn,
169 | )
170 | return soft_train_dataloader, val_dataloader, pt_dataloader
171 | else:
172 | return soft_train_dataloader, val_dataloader
173 |
174 | def load_datasets_visual(cfg, img_size):
175 | transform = ResizeAndPad(img_size)
176 | val = WHUDataset(
177 | cfg,
178 | root_dir=cfg.datasets.WHU.root_dir,
179 | annotation_file=cfg.datasets.WHU.annotation_file_val,
180 | transform=transform,
181 | )
182 | val_dataloader = DataLoader(
183 | val,
184 | batch_size=cfg.val_batchsize,
185 | shuffle=False,
186 | num_workers=cfg.num_workers,
187 | collate_fn=collate_fn_,
188 | )
189 | return val_dataloader
190 |
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | 
3 | 
4 | def call_load_dataset(cfg):
5 |     """Return the dataset-loading function matching cfg.dataset and cfg.load_type."""
6 |     name, load_type = cfg.dataset, cfg.load_type
7 |     key = name.split("-")[0]
8 |     module_name = f"datasets.{key}"
9 | 
10 |     if load_type == "load":
11 |         function_name = "load_datasets"
12 |     elif load_type == "soft":
13 |         function_name = "load_datasets_soft"
14 |     elif load_type == "visual":
15 |         function_name = "load_datasets_visual"
16 |     else:
17 |         raise ValueError(f"Unsupported load_type: {load_type}")
18 | 
19 |     if cfg.prompt == "coarse":
20 |         function_name = f"{function_name}_coarse"
21 | 
22 |     # Import the dataset module and fetch the requested loader explicitly
23 |     # instead of relying on exec/eval.
24 |     module = importlib.import_module(module_name)
25 |     return getattr(module, function_name)
26 | 
--------------------------------------------------------------------------------
/datasets/augmentation.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import torch
4 | # import kornia as K
5 | import albumentations as A
6 | import numpy as np
7 | import matplotlib.pyplot as plt
8 | from typing import List, Union
9 | from imagecorruptions import corrupt, get_corruption_names
10 |
11 | weak_transforms = A.Compose(
12 | [A.Flip()],
13 | bbox_params=A.BboxParams(format="pascal_voc", label_fields=["category_ids"]),
14 | # keypoint_params=A.KeypointParams(format='xy')
15 | )
16 |
17 | strong_transforms = A.Compose(
18 | [
19 | A.Posterize(),
20 | A.Equalize(),
21 | A.Sharpen(),
22 | A.Solarize(),
23 | A.RandomBrightnessContrast(),
24 | A.RandomShadow(),
25 | ]
26 | )
27 |
28 |
29 | def corrupt_image(image, filename):
30 | file_name = os.path.basename(os.path.abspath(filename))
31 | file_path = os.path.dirname(os.path.abspath(filename))
32 | for corruption in get_corruption_names():
33 | corrupted = corrupt(image, severity=5, corruption_name=corruption)
34 | corrupt_path = file_path.replace(
35 | "val2017", os.path.join("corruption", corruption)
36 | )
37 | if not os.path.exists(corrupt_path):
38 | os.makedirs(corrupt_path, exist_ok=True)
39 | cv2.imwrite(os.path.join(corrupt_path, file_name), corrupted)
40 |
--------------------------------------------------------------------------------
/datasets/tools.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | import torchvision.transforms as transforms
6 | from segment_anything.utils.transforms import ResizeLongestSide
7 | from datasets.augmentation import weak_transforms, strong_transforms
8 | import torchvision.transforms.functional as F
9 |
10 | class ResizeAndPad:
11 |
12 | def __init__(self, target_size):
13 | self.target_size = target_size
14 | self.transform = ResizeLongestSide(target_size)
15 | self.to_tensor = transforms.ToTensor()
16 |
17 | def __call__(self, image, masks, bboxes=None, visual=False):
18 | # Resize image and masks
19 | og_h, og_w, _ = image.shape
20 | image = self.transform.apply_image(image)
21 | masks = [torch.tensor(self.transform.apply_image(mask)) for mask in masks]
22 | image = self.to_tensor(image)
23 |
24 | # Pad image and masks to form a square
25 | _, h, w = image.shape
26 | max_dim = max(w, h)
27 | pad_w = (max_dim - w) // 2
28 | pad_h = (max_dim - h) // 2
29 |
30 | padding = (pad_w, pad_h, max_dim - w - pad_w, max_dim - h - pad_h)
31 | image = transforms.Pad(padding)(image)
32 | masks = [transforms.Pad(padding)(mask) for mask in masks]
33 |
34 | # Adjust bounding boxes
35 | if bboxes is not None:
36 | bboxes = self.transform.apply_boxes(bboxes, (og_h, og_w))
37 | bboxes = [
38 | [bbox[0] + pad_w, bbox[1] + pad_h, bbox[2] + pad_w, bbox[3] + pad_h]
39 | for bbox in bboxes
40 | ]
41 | if visual:
42 | return padding, image, masks, bboxes
43 | else:
44 | return image, masks, bboxes
45 | else:
46 | if visual:
47 | return padding, image, masks
48 | else:
49 | return image, masks
50 |
51 | def transform_image(self, image):
52 | # Resize image and masks
53 | image = self.transform.apply_image(image)
54 | image = self.to_tensor(image)
55 |
56 | # Pad image and masks to form a square
57 | _, h, w = image.shape
58 | max_dim = max(w, h)
59 | pad_w = (max_dim - w) // 2
60 | pad_h = (max_dim - h) // 2
61 |
62 | padding = (pad_w, pad_h, max_dim - w - pad_w, max_dim - h - pad_h)
63 | image = transforms.Pad(padding)(image)
64 | return image
65 |
66 | def transform_coord(self, points, image):
67 | og_h, og_w, _ = image.shape
68 | coords = points.reshape(1, -1, 2)
69 | points = self.transform.apply_coords(coords, (og_h, og_w))
70 | return points.reshape(-1, 2)
71 |
72 | def transform_coords(self, points, image, n):
73 | og_h, og_w, _ = image.shape
74 | coords = points.reshape(-1, n, 2)
75 | points = self.transform.apply_coords(coords, (og_h, og_w))
76 | return points.reshape(-1, n, 2)
77 |
78 |
79 | def soft_transform(
80 | image: np.ndarray, bboxes: list, masks: list, categories: list
81 | ):
82 | weak_transformed = weak_transforms(
83 | image=image, bboxes=bboxes, masks=masks, category_ids=categories)
84 | image_weak = weak_transformed["image"]
85 | bboxes_weak = weak_transformed["bboxes"]
86 | masks_weak = weak_transformed["masks"]
87 |
88 | strong_transformed = strong_transforms(image=image_weak)
89 | image_strong = strong_transformed["image"]
90 | return image_weak, bboxes_weak, masks_weak, image_strong
91 |
92 |
93 | def soft_transform_all(
94 | image: np.ndarray, bboxes: list, masks: list, points: list, categories: list
95 | ):
96 | weak_transformed = weak_transforms(
97 | image=image, bboxes=bboxes, masks=masks, category_ids=categories, keypoints=points)
98 | image_weak = weak_transformed["image"]
99 | bboxes_weak = weak_transformed["bboxes"]
100 | masks_weak = weak_transformed["masks"]
101 | keypoints_weak = weak_transformed["keypoints"]
102 |
103 | strong_transformed = strong_transforms(image=image_weak)
104 | image_strong = strong_transformed["image"]
105 | return image_weak, bboxes_weak, masks_weak, keypoints_weak, image_strong
106 |
107 |
108 | def collate_fn(batch):
109 | images, bboxes, masks, img_paths= zip(*batch)
110 | images = torch.stack(images)
111 | return images, bboxes, masks, img_paths
112 |
113 |
114 | def collate_fn_soft(batch):
115 | images_soft, images, bboxes, masks, img_paths = zip(*batch)
116 | images = torch.stack(images)
117 | images_soft = torch.stack(images_soft)
118 | return images_soft, images, bboxes, masks, img_paths
119 |
120 |
121 | def collate_fn_coarse(batch):
122 | images, bboxes, masks, coarse_masks = zip(*batch)
123 | images = torch.stack(images)
124 | return images, bboxes, masks, coarse_masks
125 |
126 |
127 | def collate_fn_(batch):
128 | return zip(*batch)
129 |
130 |
131 | def decode_mask(mask):
132 | """
133 | Convert mask with shape [1, h, w] using 1, 2, 3, ... to represent different objects
134 | to a mask with shape [n, h, w] using a new dimension to represent the number of objects.
135 |
136 | Args:
137 | mask (torch.Tensor): Mask tensor with shape [1, h, w] using 1, 2, 3, ... to represent different objects.
138 |
139 | Returns:
140 | torch.Tensor: Mask tensor with shape [n, h, w] using a new dimension to represent the number of objects.
141 | """
142 | unique_labels = torch.unique(mask)
143 | unique_labels = unique_labels[unique_labels != 0]
144 | n_objects = len(unique_labels)
145 | new_mask = torch.zeros((n_objects, *mask.shape[1:]), dtype=torch.int64)
146 | for i, label in enumerate(unique_labels):
147 | new_mask[i] = (mask == label).squeeze(0)
148 | return new_mask
149 |
150 |
151 | def encode_mask(mask):
152 | """
153 | Convert mask with shape [n, h, w] using a new dimension to represent the number of objects
154 | to a mask with shape [1, h, w] using 1, 2, 3, ... to represent different objects.
155 |
156 | Args:
157 | mask (torch.Tensor): Mask tensor with shape [n, h, w] using a new dimension to represent the number of objects.
158 |
159 | Returns:
160 | torch.Tensor: Mask tensor with shape [1, h, w] using 1, 2, 3, ... to represent different objects.
161 | """
162 | n_objects = mask.shape[0]
163 | new_mask = torch.zeros((1, *mask.shape[1:]), dtype=torch.int64)
164 | for i in range(n_objects):
165 | new_mask[0][mask[i] == 1] = i + 1
166 | return new_mask
167 |
168 |
169 | if __name__ == "__main__":
170 | mask_encode = np.array([[[0, 0, 1], [2, 0, 2], [0, 3, 3]]])
171 | mask_decode = np.array([[[0, 0, 1], [0, 0, 0], [0, 0, 0]],
172 | [[0, 0, 0], [1, 0, 1], [0, 0, 0]],
173 | [[0, 0, 0], [0, 0, 0], [0, 1, 1]]])
174 | encoded_mask = encode_mask(torch.tensor(mask_decode))
175 | decoded_mask = decode_mask(torch.tensor(mask_encode))
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import copy
4 | import torch
5 | import argparse
6 | import torch.nn.functional as F
7 | import numpy as np
8 | import segmentation_models_pytorch as smp
9 | import lightning as L
10 | from lightning.fabric.loggers import TensorBoardLogger
11 | from lightning.fabric.fabric import _FabricOptimizer
12 | import matplotlib.pyplot as plt
13 | from torch.utils.data import DataLoader
14 | from box import Box
15 | from datasets import call_load_dataset
16 | from utils.model import Model
17 | from utils.tools import copy_model, create_csv, reduce_instances
18 | from utils.eval_utils import AverageMeter
19 | from utils.sample_utils import uniform_sampling
20 | from tqdm import tqdm
21 |
22 | torch.set_float32_matmul_precision('high')
23 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
24 |
25 | def compute_centroids(masks):
26 | centroids = []
27 | for mask in masks:
28 | if not isinstance(mask, np.ndarray):
29 | mask = mask.cpu().numpy().astype(np.uint8)
30 |
31 | # Find connected components
32 | num_labels, _, _, centroids_data = cv2.connectedComponentsWithStats(mask, connectivity=8)
33 |
34 | # Extract centroids (skip background label 0)
35 | component_centroids = centroids_data[1:] # Skip the first entry, which is for the background
36 |
37 | # Append centroids as a list
38 | centroids.append(component_centroids.tolist())
39 |
40 | return centroids
41 |
42 | def plotwithpoint(fabric: L.Fabric, anchor_model: Model, model: Model, val_dataloader: DataLoader):
43 | model.eval()
44 | anchor_model.eval()
45 | ious = AverageMeter()
46 | f1_scores = AverageMeter()
47 | num_points = cfg.num_points
48 | transform = val_dataloader.dataset.transform
49 |
50 | save_path = cfg.out_dir
51 | if not os.path.exists(save_path):
52 | os.makedirs(save_path)
53 |
54 | with torch.no_grad():
55 | for iter, data in enumerate(tqdm(val_dataloader, desc="Processing", unit="batch")):
56 | image_ids, paddings, ori_images, ori_bboxes, origin_masks, images, bboxes, gt_masks = data
57 | images = torch.stack(images).to(device=fabric.device)
58 | num_images = images.size(0)
59 |
60 | prompts = []
61 | for mask in gt_masks:
62 | try:
63 | po_points = compute_centroids(mask)
64 | po_point_coords = torch.tensor(po_points, device=fabric.device)
65 | except:
66 | continue
67 | na_points = uniform_sampling((~mask.to(bool)).to(float), num_points)
68 | na_point_coords = torch.tensor(na_points, device=fabric.device)
69 | point_coords = torch.cat((po_point_coords, na_point_coords), dim=1)
70 | po_point_labels = torch.ones(po_point_coords.shape[:2], dtype=torch.int, device=fabric.device)
71 | na_point_labels = torch.zeros(na_point_coords.shape[:2], dtype=torch.int, device=fabric.device)
72 | point_labels = torch.cat((po_point_labels, na_point_labels), dim=1)
73 | in_points = (point_coords, point_labels)
74 | prompts.append(in_points)
75 |
76 | _, base_masks, _, _ = anchor_model(images, prompts)
77 | _, pred_masks, _, _ = model(images, prompts)
78 |
79 | draw_points = []
80 | for ori_mask in origin_masks:
81 | ori_po_points = uniform_sampling(ori_mask, num_points)
82 | draw_points.append(ori_po_points)
83 |
84 | # for pred_mask, gt_mask in zip(pred_masks, gt_masks):
85 | # batch_stats = smp.metrics.get_stats(
86 | # pred_mask,
87 | # gt_mask.to(device=fabric.device).int(),
88 | # mode='binary',
89 | # threshold=0.5,
90 | # )
91 | # batch_iou = smp.metrics.iou_score(*batch_stats, reduction="micro-imagewise")
92 | # batch_f1 = smp.metrics.f1_score(*batch_stats, reduction="micro-imagewise")
93 | # ious.update(batch_iou, num_images)
94 | # f1_scores.update(batch_f1, num_images)
95 | # fabric.print(
96 | # f'Val:[{iter}/{len(val_dataloader)}]: Mean IoU: [{ious.avg:.4f}] -- Mean F1: [{f1_scores.avg:.4f}]'
97 | # )
98 | # torch.cuda.empty_cache()
99 |
100 | for image_id, padding, base_mask, pred_mask, ori_mask, points, image in zip(image_ids, paddings, base_masks, pred_masks, origin_masks, draw_points, ori_images):
101 | H, W, C = image.shape
102 |
103 | base_mask = base_mask.unsqueeze(1)
104 | base_mask = base_mask[..., padding[1] : base_mask.shape[-2] - padding[3], padding[0] : base_mask.shape[-1] - padding[2]]
105 | base_mask = F.interpolate(base_mask, (H, W), mode="bilinear", align_corners=False)
106 |
107 | pred_mask = pred_mask.unsqueeze(1)
108 | pred_mask = pred_mask[..., padding[1] : pred_mask.shape[-2] - padding[3], padding[0] : pred_mask.shape[-1] - padding[2]]
109 | pred_mask = F.interpolate(pred_mask, (H, W), mode="bilinear", align_corners=False)
110 |
111 | fig, axs = plt.subplots(1, 4)
112 | fig.set_size_inches(W/100.0*4, H/100.0)
113 |
114 | image_0 = copy.deepcopy(image)
115 | image_1 = copy.deepcopy(image)
116 | image_2 = copy.deepcopy(image)
117 | image_3 = copy.deepcopy(image)
118 | axs[0].imshow(image_0)
119 | axs[1].imshow(image_1)
120 | axs[2].imshow(image_2)
121 | axs[3].imshow(image_3)
122 | axs[0].axis('off')
123 | axs[1].axis('off')
124 | axs[2].axis('off')
125 | axs[3].axis('off')
126 |
127 | masked_image_1 = np.zeros((H, W, 4))
128 | masked_image_2 = np.zeros((H, W, 4))
129 | masked_image_3 = np.zeros((H, W, 4))
130 | for point, ori_mask_i, base_mask_i, pred_mask_i in zip(points, ori_mask, base_mask, pred_mask):
131 | color = np.random.random(3)
132 | x_coords = []
133 | y_coords = []
134 | for point_i in point:
135 | x, y = point_i
136 | x_coords.append(x)
137 | y_coords.append(y)
138 | point_color = np.concatenate([color, [1.0]])
139 | axs[0].scatter(x_coords, y_coords, color=point_color)
140 |
141 | base_mask_i = (base_mask_i.squeeze(0) > 0.).cpu().numpy().astype(bool)
142 | pred_mask_i = (pred_mask_i.squeeze(0) > 0.).cpu().numpy().astype(bool)
143 | ori_mask_i = ori_mask_i.astype(bool)
144 | mask_color = np.concatenate([color, [0.7]])
145 |
146 | masked_image_1[ori_mask_i] = mask_color
147 | axs[1].imshow(masked_image_1)
148 | masked_image_2[base_mask_i] = mask_color
149 | axs[2].imshow(masked_image_2)
150 | masked_image_3[pred_mask_i] = mask_color
151 | axs[3].imshow(masked_image_3)
152 |
153 | plt.subplots_adjust(wspace=0)
154 | plt.savefig(os.path.join(save_path, f"{image_id}.jpg"), dpi=100, bbox_inches='tight', pad_inches=0)
155 | plt.close(fig)
156 |
157 | def main(cfg: Box, args) -> None:
158 | gpu_ids = [str(i) for i in range(torch.cuda.device_count())]
159 | num_devices = len(gpu_ids)
160 |
161 | fabric = L.Fabric(accelerator="auto",
162 | devices=num_devices,
163 | strategy="auto",
164 | loggers=[TensorBoardLogger(cfg.out_dir)])
165 | fabric.launch()
166 | fabric.seed_everything(1337 + fabric.global_rank)
167 |
168 | with fabric.device:
169 | anchor_model = Model(cfg)
170 | anchor_model.setup()
171 |
172 | model = Model(cfg)
173 | model.setup()
174 | full_checkpoint = fabric.load(args.ckpt)
175 | model.load_state_dict(full_checkpoint["model"])
176 |
177 | load_datasets = call_load_dataset(cfg)
178 | val_data = load_datasets(cfg, model.model.image_encoder.img_size)
179 | val_data = fabric._setup_dataloader(val_data)
180 |
181 | plotwithpoint(fabric, anchor_model, model, val_data)
182 |
183 | def parse_args():
184 |
185 | parser = argparse.ArgumentParser(description='Test a detector with specified config and checkpoint.')
186 | parser.add_argument(
187 | '--cfg',
188 | type=str,
189 | default="configs.config_hrsid",
190 | help='Path to the configuration file (e.g., "configs.config_hrsid").'
191 | )
192 | parser.add_argument(
193 | '--out_dir',
194 | type=str,
195 | default="output",
196 | help='Directory to save predicted results.'
197 | )
198 | parser.add_argument(
199 | '--ckpt',
200 | type=str,
201 | default="checkpoints/best-ckpt.pth",
202 | help='Path to the model checkpoint file.'
203 | )
204 |
205 | args = parser.parse_args()
206 | return args
207 |
208 | if __name__ == "__main__":
209 | torch.cuda.empty_cache()
210 | args = parse_args()
211 | args_dict = vars(args)
212 | exec(f'from {args.cfg} import cfg')
213 | cfg.merge_update(args_dict)
214 | cfg.visual = True
215 | cfg.load_type = 'visual'
216 | main(cfg, args)
--------------------------------------------------------------------------------
/pretrain/Where_To_Save_Pretrained_SAM_Checkpoints:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lans1ng/PointSAM/d6831edf23de4bd7304fb46a5ce2a0baa69dcbd5/pretrain/Where_To_Save_Pretrained_SAM_Checkpoints
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | python-box==7.0.1
2 | pycocotools==2.0.6
3 | numpy==1.26.4
4 | opencv_python==4.7.0.72
5 | Pillow>=9.4.0
6 | lightning==2.0.1
7 | segmentation-models-pytorch==0.3.2
8 | albumentations==1.3.1
9 | imagecorruptions==1.1.2
10 | safetensors==0.4.1
11 | torchsummary==1.5.1
12 | tensorboard==2.14.0
13 | pandas
14 |
--------------------------------------------------------------------------------
/scripts/train_hrsid_pointsam.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | cfg_file="configs.config_hrsid"
4 | prompt="point"
5 | load_type="soft"
6 | num_points_list=(1 2 3)
7 | output_dirs=("work_dir/hrsid/pointsam")
8 |
9 | for output_dir in "${output_dirs[@]}"; do
10 | for num_points in "${num_points_list[@]}"; do
11 | out_dir="${output_dir}/point_${num_points}"
12 | CUDA_VISIBLE_DEVICES=0 python train_pointsam.py --cfg "$cfg_file" --prompt "$prompt" --num_points "$num_points" --out_dir "$out_dir" --load_type "$load_type"
13 | done
14 | done
15 |
--------------------------------------------------------------------------------
/scripts/train_hrsid_selftrain.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | cfg_file="configs.config_hrsid"
4 | prompt="point"
5 | load_type="soft"
6 | num_points_list=(1 2 3)
7 | output_dirs=("work_dir/hrsid/selftrain")
8 |
9 | for output_dir in "${output_dirs[@]}"; do
10 | for num_points in "${num_points_list[@]}"; do
11 | out_dir="${output_dir}/point_${num_points}"
12 | CUDA_VISIBLE_DEVICES=0 python train_selftrain.py --cfg "$cfg_file" --prompt "$prompt" --num_points "$num_points" --out_dir "$out_dir" --load_type "$load_type"
13 | done
14 | done
15 |
--------------------------------------------------------------------------------
/scripts/train_hrsid_supervise.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | cfg_file="configs.config_hrsid"
4 | prompt="point"
5 | load_type="load"
6 | num_points_list=(1 2 3)
7 | output_dirs=("work_dir/hrsid/supervise")
8 |
9 | for output_dir in "${output_dirs[@]}"; do
10 | for num_points in "${num_points_list[@]}"; do
11 | out_dir="${output_dir}/point_${num_points}"
12 | CUDA_VISIBLE_DEVICES=0 python train_supervise.py --cfg "$cfg_file" --prompt "$prompt" --num_points "$num_points" --out_dir "$out_dir" --load_type "$load_type"
13 | done
14 | done
15 |
--------------------------------------------------------------------------------
/scripts/train_nwpu_pointsam.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | cfg_file="configs.config_nwpu"
4 | prompt="point"
5 | load_type="soft"
6 | num_points_list=(1 2 3)
7 | output_dirs=("work_dir/nwpu/pointsam")
8 |
9 | for output_dir in "${output_dirs[@]}"; do
10 | for num_points in "${num_points_list[@]}"; do
11 | out_dir="${output_dir}/point_${num_points}"
12 | CUDA_VISIBLE_DEVICES=0 python train_pointsam.py --cfg "$cfg_file" --prompt "$prompt" --num_points "$num_points" --out_dir "$out_dir" --load_type "$load_type"
13 | done
14 | done
15 |
--------------------------------------------------------------------------------
/scripts/train_nwpu_selftrain.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | cfg_file="configs.config_nwpu"
4 | prompt="point"
5 | load_type="soft"
6 | num_points_list=(1 2 3)
7 | output_dirs=("work_dir/nwpu/selftrain")
8 |
9 | for output_dir in "${output_dirs[@]}"; do
10 | for num_points in "${num_points_list[@]}"; do
11 | out_dir="${output_dir}/point_${num_points}"
12 | CUDA_VISIBLE_DEVICES=0 python train_selftrain.py --cfg "$cfg_file" --prompt "$prompt" --num_points "$num_points" --out_dir "$out_dir" --load_type "$load_type"
13 | done
14 | done
15 |
--------------------------------------------------------------------------------
/scripts/train_nwpu_supervise.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | cfg_file="configs.config_nwpu"
4 | prompt="point"
5 | load_type="load"
6 | num_points_list=(1 2 3)
7 | output_dirs=("work_dir/nwpu/supervise")
8 |
9 | for output_dir in "${output_dirs[@]}"; do
10 | for num_points in "${num_points_list[@]}"; do
11 | out_dir="${output_dir}/point_${num_points}"
12 | CUDA_VISIBLE_DEVICES=0 python train_supervise.py --cfg "$cfg_file" --prompt "$prompt" --num_points "$num_points" --out_dir "$out_dir" --load_type "$load_type"
13 | done
14 | done
15 |
--------------------------------------------------------------------------------
/scripts/train_whu_pointsam.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | cfg_file="configs.config_whu"
4 | prompt="point"
5 | load_type="soft"
6 | num_points_list=(1 2 3)
7 | output_dirs=("work_dir/whu/pointsam")
8 |
9 | for output_dir in "${output_dirs[@]}"; do
10 | for num_points in "${num_points_list[@]}"; do
11 | out_dir="${output_dir}/point_${num_points}"
12 | CUDA_VISIBLE_DEVICES=0 python train_pointsam.py --cfg "$cfg_file" --prompt "$prompt" --num_points "$num_points" --out_dir "$out_dir" --load_type "$load_type"
13 | done
14 | done
15 |
--------------------------------------------------------------------------------
/scripts/train_whu_selftrain.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | cfg_file="configs.config_whu"
4 | prompt="point"
5 | load_type="soft"
6 | num_points_list=(1 2 3)
7 | output_dirs=("work_dir/whu/selftrain")
8 |
9 | for output_dir in "${output_dirs[@]}"; do
10 | for num_points in "${num_points_list[@]}"; do
11 | out_dir="${output_dir}/point_${num_points}"
12 | CUDA_VISIBLE_DEVICES=0 python train_selftrain.py --cfg "$cfg_file" --prompt "$prompt" --num_points "$num_points" --out_dir "$out_dir" --load_type "$load_type"
13 | done
14 | done
15 |
--------------------------------------------------------------------------------
/scripts/train_whu_supervise.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | cfg_file="configs.config_whu"
4 | prompt="point"
5 | load_type="load"
6 | num_points_list=(1 2 3)
7 | output_dirs=("work_dir/whu/supervise")
8 |
9 | for output_dir in "${output_dirs[@]}"; do
10 | for num_points in "${num_points_list[@]}"; do
11 | out_dir="${output_dir}/point_${num_points}"
12 | CUDA_VISIBLE_DEVICES=0 python train_supervise.py --cfg "$cfg_file" --prompt "$prompt" --num_points "$num_points" --out_dir "$out_dir" --load_type "$load_type"
13 | done
14 | done
15 |
--------------------------------------------------------------------------------
/segment_anything/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from .build_sam import (
8 | build_sam,
9 | build_sam_vit_h,
10 | build_sam_vit_l,
11 | build_sam_vit_b,
12 | sam_model_registry,
13 | )
14 | from .predictor import SamPredictor
15 | from .automatic_mask_generator import SamAutomaticMaskGenerator
16 |
--------------------------------------------------------------------------------
/segment_anything/build_sam.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 |
9 | from functools import partial
10 |
11 | from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer
12 |
13 |
14 | def build_sam_vit_h(checkpoint=None):
15 | return _build_sam(
16 | encoder_embed_dim=1280,
17 | encoder_depth=32,
18 | encoder_num_heads=16,
19 | encoder_global_attn_indexes=[7, 15, 23, 31],
20 | checkpoint=checkpoint,
21 | )
22 |
23 |
24 | build_sam = build_sam_vit_h
25 |
26 |
27 | def build_sam_vit_l(checkpoint=None):
28 | return _build_sam(
29 | encoder_embed_dim=1024,
30 | encoder_depth=24,
31 | encoder_num_heads=16,
32 | encoder_global_attn_indexes=[5, 11, 17, 23],
33 | checkpoint=checkpoint,
34 | )
35 |
36 |
37 | def build_sam_vit_b(checkpoint=None):
38 | return _build_sam(
39 | encoder_embed_dim=768,
40 | encoder_depth=12,
41 | encoder_num_heads=12,
42 | encoder_global_attn_indexes=[2, 5, 8, 11],
43 | checkpoint=checkpoint,
44 | )
45 |
46 |
47 | sam_model_registry = {
48 | "default": build_sam_vit_h,
49 | "vit_h": build_sam_vit_h,
50 | "vit_l": build_sam_vit_l,
51 | "vit_b": build_sam_vit_b,
52 | }
53 |
54 |
55 | def _build_sam(
56 | encoder_embed_dim,
57 | encoder_depth,
58 | encoder_num_heads,
59 | encoder_global_attn_indexes,
60 | checkpoint=None,
61 | ):
62 | prompt_embed_dim = 256
63 | image_size = 1024
64 | vit_patch_size = 16
65 | image_embedding_size = image_size // vit_patch_size
66 | sam = Sam(
67 | image_encoder=ImageEncoderViT(
68 | depth=encoder_depth,
69 | embed_dim=encoder_embed_dim,
70 | img_size=image_size,
71 | mlp_ratio=4,
72 | norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
73 | num_heads=encoder_num_heads,
74 | patch_size=vit_patch_size,
75 | qkv_bias=True,
76 | use_rel_pos=True,
77 | global_attn_indexes=encoder_global_attn_indexes,
78 | window_size=14,
79 | out_chans=prompt_embed_dim,
80 | ),
81 | prompt_encoder=PromptEncoder(
82 | embed_dim=prompt_embed_dim,
83 | image_embedding_size=(image_embedding_size, image_embedding_size),
84 | input_image_size=(image_size, image_size),
85 | mask_in_chans=16,
86 | ),
87 | mask_decoder=MaskDecoder(
88 | num_multimask_outputs=3,
89 | transformer=TwoWayTransformer(
90 | depth=2,
91 | embedding_dim=prompt_embed_dim,
92 | mlp_dim=2048,
93 | num_heads=8,
94 | ),
95 | transformer_dim=prompt_embed_dim,
96 | iou_head_depth=3,
97 | iou_head_hidden_dim=256,
98 | ),
99 | pixel_mean=[123.675, 116.28, 103.53],
100 | pixel_std=[58.395, 57.12, 57.375],
101 | )
102 | sam.eval()
103 | if checkpoint is not None:
104 | with open(checkpoint, "rb") as f:
105 | state_dict = torch.load(f)
106 | # sam.load_state_dict(state_dict)
107 | # Create a new state dictionary with only the parameters that exist in the model
108 | new_state_dict = {k: v for k, v in state_dict.items() if k in sam.state_dict() and sam.state_dict()[k].shape == v.shape}
109 | sam.load_state_dict(new_state_dict, strict=False)
110 | return sam
111 |
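
A minimal sketch of using the registry defined above to build a ViT-B SAM and load pretrained weights. The checkpoint filename and the pretrain/ location are assumptions for illustration, not fixed by this file; substitute wherever the downloaded SAM weights actually live.

    from segment_anything import sam_model_registry

    # Build a ViT-B SAM; the checkpoint path below is illustrative only.
    sam = sam_model_registry["vit_b"](checkpoint="pretrain/sam_vit_b_01ec64.pth")
    sam.to("cuda")  # optional, if a GPU is available; the builder returns the model in eval mode
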
--------------------------------------------------------------------------------
/segment_anything/modeling/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from .sam import Sam
8 | from .image_encoder import ImageEncoderViT
9 | # from .image_encoder_adapter import ImageEncoderViT
10 | from .mask_decoder import MaskDecoder
11 | from .prompt_encoder import PromptEncoder
12 | from .transformer import TwoWayTransformer
13 |
--------------------------------------------------------------------------------
/segment_anything/modeling/common.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | import torch.nn as nn
9 |
10 | from typing import Type
11 |
12 |
13 | class MLPBlock(nn.Module):
14 | def __init__(
15 | self,
16 | embedding_dim: int,
17 | mlp_dim: int,
18 | act: Type[nn.Module] = nn.GELU,
19 | ) -> None:
20 | super().__init__()
21 | self.lin1 = nn.Linear(embedding_dim, mlp_dim)
22 | self.lin2 = nn.Linear(mlp_dim, embedding_dim)
23 | self.act = act()
24 |
25 | def forward(self, x: torch.Tensor) -> torch.Tensor:
26 | return self.lin2(self.act(self.lin1(x)))
27 |
28 |
29 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
30 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
31 | class LayerNorm2d(nn.Module):
32 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
33 | super().__init__()
34 | self.weight = nn.Parameter(torch.ones(num_channels))
35 | self.bias = nn.Parameter(torch.zeros(num_channels))
36 | self.eps = eps
37 |
38 | def forward(self, x: torch.Tensor) -> torch.Tensor:
39 | u = x.mean(1, keepdim=True)
40 | s = (x - u).pow(2).mean(1, keepdim=True)
41 | x = (x - u) / torch.sqrt(s + self.eps)
42 | x = self.weight[:, None, None] * x + self.bias[:, None, None]
43 | return x
44 |
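
LayerNorm2d normalizes across the channel dimension at each spatial location of an NCHW tensor, unlike nn.LayerNorm, which expects channels last. A small shape sketch:

    import torch
    from segment_anything.modeling.common import LayerNorm2d

    norm = LayerNorm2d(num_channels=256)
    x = torch.randn(2, 256, 64, 64)   # N x C x H x W
    y = norm(x)                       # same shape; each pixel normalized over its 256 channels
    assert y.shape == x.shape
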
--------------------------------------------------------------------------------
/segment_anything/modeling/mask_decoder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | from torch import nn
9 | from torch.nn import functional as F
10 |
11 | from typing import List, Tuple, Type
12 |
13 | from .common import LayerNorm2d
14 |
15 |
16 | class MaskDecoder(nn.Module):
17 | def __init__(
18 | self,
19 | *,
20 | transformer_dim: int,
21 | transformer: nn.Module,
22 | num_multimask_outputs: int = 3,
23 | activation: Type[nn.Module] = nn.GELU,
24 | iou_head_depth: int = 3,
25 | iou_head_hidden_dim: int = 256,
26 | ) -> None:
27 | """
28 | Predicts masks given an image and prompt embeddings, using a
29 | transformer architecture.
30 |
31 | Arguments:
32 | transformer_dim (int): the channel dimension of the transformer
33 | transformer (nn.Module): the transformer used to predict masks
34 | num_multimask_outputs (int): the number of masks to predict
35 | when disambiguating masks
36 | activation (nn.Module): the type of activation to use when
37 | upscaling masks
38 | iou_head_depth (int): the depth of the MLP used to predict
39 | mask quality
40 | iou_head_hidden_dim (int): the hidden dimension of the MLP
41 | used to predict mask quality
42 | """
43 | super().__init__()
44 | self.transformer_dim = transformer_dim
45 | self.transformer = transformer
46 |
47 | self.num_multimask_outputs = num_multimask_outputs
48 |
49 | self.iou_token = nn.Embedding(1, transformer_dim)
50 | self.num_mask_tokens = num_multimask_outputs + 1
51 | self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
52 |
53 | self.output_upscaling = nn.Sequential(
54 | nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
55 | LayerNorm2d(transformer_dim // 4),
56 | activation(),
57 | nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
58 | activation(),
59 | )
60 | self.output_hypernetworks_mlps = nn.ModuleList(
61 | [
62 | MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
63 | for i in range(self.num_mask_tokens)
64 | ]
65 | )
66 |
67 | self.iou_prediction_head = MLP(
68 | transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth
69 | )
70 |
71 | def forward(
72 | self,
73 | image_embeddings: torch.Tensor,
74 | image_pe: torch.Tensor,
75 | sparse_prompt_embeddings: torch.Tensor,
76 | dense_prompt_embeddings: torch.Tensor,
77 | multimask_output: bool,
78 | ) -> Tuple[torch.Tensor, torch.Tensor]:
79 | """
80 | Predict masks given image and prompt embeddings.
81 |
82 | Arguments:
83 | image_embeddings (torch.Tensor): the embeddings from the image encoder
84 | image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
85 | sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
86 | dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
87 | multimask_output (bool): Whether to return multiple masks or a single
88 | mask.
89 |
90 | Returns:
91 | torch.Tensor: batched predicted masks
92 | torch.Tensor: batched predictions of mask quality
93 | """
94 | masks, iou_pred = self.predict_masks(
95 | image_embeddings=image_embeddings,
96 | image_pe=image_pe,
97 | sparse_prompt_embeddings=sparse_prompt_embeddings,
98 | dense_prompt_embeddings=dense_prompt_embeddings,
99 | )
100 |
101 | # Select the correct mask or masks for output
102 | if multimask_output:
103 | mask_slice = slice(1, None)
104 | else:
105 | mask_slice = slice(0, 1)
106 | masks = masks[:, mask_slice, :, :]
107 | iou_pred = iou_pred[:, mask_slice]
108 |
109 | # Prepare output
110 | return masks, iou_pred
111 |
112 | def predict_masks(
113 | self,
114 | image_embeddings: torch.Tensor,
115 | image_pe: torch.Tensor,
116 | sparse_prompt_embeddings: torch.Tensor,
117 | dense_prompt_embeddings: torch.Tensor,
118 | ) -> Tuple[torch.Tensor, torch.Tensor]:
119 | """Predicts masks. See 'forward' for more details."""
120 | # Concatenate output tokens
121 | output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
122 | output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
123 | tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
124 |
125 | # Expand per-image data in batch direction to be per-mask
126 | src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
127 | src = src + dense_prompt_embeddings
128 | pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
129 | b, c, h, w = src.shape
130 |
131 | # Run the transformer
132 | hs, src = self.transformer(src, pos_src, tokens)
133 | iou_token_out = hs[:, 0, :]
134 | mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]
135 |
136 | # Upscale mask embeddings and predict masks using the mask tokens
137 | src = src.transpose(1, 2).view(b, c, h, w)
138 | upscaled_embedding = self.output_upscaling(src)
139 | hyper_in_list: List[torch.Tensor] = []
140 | for i in range(self.num_mask_tokens):
141 | hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
142 | hyper_in = torch.stack(hyper_in_list, dim=1)
143 | b, c, h, w = upscaled_embedding.shape
144 | masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
145 |
146 | # Generate mask quality predictions
147 | iou_pred = self.iou_prediction_head(iou_token_out)
148 |
149 | return masks, iou_pred
150 |
151 |
152 | # Lightly adapted from
153 | # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
154 | class MLP(nn.Module):
155 | def __init__(
156 | self,
157 | input_dim: int,
158 | hidden_dim: int,
159 | output_dim: int,
160 | num_layers: int,
161 | sigmoid_output: bool = False,
162 | ) -> None:
163 | super().__init__()
164 | self.num_layers = num_layers
165 | h = [hidden_dim] * (num_layers - 1)
166 | self.layers = nn.ModuleList(
167 | nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
168 | )
169 | self.sigmoid_output = sigmoid_output
170 |
171 | def forward(self, x):
172 | for i, layer in enumerate(self.layers):
173 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
174 | if self.sigmoid_output:
175 | x = F.sigmoid(x)
176 | return x
177 |
--------------------------------------------------------------------------------
/segment_anything/modeling/prompt_encoder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import numpy as np
8 | import torch
9 | from torch import nn
10 |
11 | from typing import Any, Optional, Tuple, Type
12 |
13 | from .common import LayerNorm2d
14 |
15 |
16 | class PromptEncoder(nn.Module):
17 | def __init__(
18 | self,
19 | embed_dim: int,
20 | image_embedding_size: Tuple[int, int],
21 | input_image_size: Tuple[int, int],
22 | mask_in_chans: int,
23 | activation: Type[nn.Module] = nn.GELU,
24 | ) -> None:
25 | """
26 | Encodes prompts for input to SAM's mask decoder.
27 |
28 | Arguments:
29 | embed_dim (int): The prompts' embedding dimension
30 | image_embedding_size (tuple(int, int)): The spatial size of the
31 | image embedding, as (H, W).
32 | input_image_size (tuple(int, int)): The padded size of the image as input
33 | to the image encoder, as (H, W).
34 | mask_in_chans (int): The number of hidden channels used for
35 | encoding input masks.
36 | activation (nn.Module): The activation to use when encoding
37 | input masks.
38 | """
39 | super().__init__()
40 | self.embed_dim = embed_dim
41 | self.input_image_size = input_image_size
42 | self.image_embedding_size = image_embedding_size
43 | self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
44 |
45 | self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners
46 | point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)]
47 | self.point_embeddings = nn.ModuleList(point_embeddings)
48 | self.not_a_point_embed = nn.Embedding(1, embed_dim)
49 |
50 | self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1])
51 | self.mask_downscaling = nn.Sequential(
52 | nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
53 | LayerNorm2d(mask_in_chans // 4),
54 | activation(),
55 | nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
56 | LayerNorm2d(mask_in_chans),
57 | activation(),
58 | nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
59 | )
60 | self.no_mask_embed = nn.Embedding(1, embed_dim)
61 |
62 | def get_dense_pe(self) -> torch.Tensor:
63 | """
64 | Returns the positional encoding used to encode point prompts,
65 | applied to a dense set of points the shape of the image encoding.
66 |
67 | Returns:
68 | torch.Tensor: Positional encoding with shape
69 | 1x(embed_dim)x(embedding_h)x(embedding_w)
70 | """
71 | return self.pe_layer(self.image_embedding_size).unsqueeze(0)
72 |
73 | def _embed_points(
74 | self,
75 | points: torch.Tensor,
76 | labels: torch.Tensor,
77 | pad: bool,
78 | ) -> torch.Tensor:
79 | """Embeds point prompts."""
80 | points = points + 0.5 # Shift to center of pixel
81 | if pad:
82 | padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
83 | padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
84 | # print('points.shape',points.shape)
85 | # print('padding_point',padding_point.shape)
86 | # print('point',points)
87 | # print('padding_point',padding_point)
88 | points = torch.cat([points, padding_point], dim=1)
89 | labels = torch.cat([labels, padding_label], dim=1)
90 | point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)
91 | point_embedding[labels == -1] = 0.0
92 | point_embedding[labels == -1] += self.not_a_point_embed.weight
93 | point_embedding[labels == 0] += self.point_embeddings[0].weight
94 | point_embedding[labels == 1] += self.point_embeddings[1].weight
95 | return point_embedding
96 |
97 | def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
98 | """Embeds box prompts."""
99 | boxes = boxes + 0.5 # Shift to center of pixel
100 | coords = boxes.reshape(-1, 2, 2)
101 | corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)
102 | corner_embedding[:, 0, :] += self.point_embeddings[2].weight
103 | corner_embedding[:, 1, :] += self.point_embeddings[3].weight
104 | return corner_embedding
105 |
106 | def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
107 | """Embeds mask inputs."""
108 | mask_embedding = self.mask_downscaling(masks)
109 | return mask_embedding
110 |
111 | def _get_batch_size(
112 | self,
113 | points: Optional[Tuple[torch.Tensor, torch.Tensor]],
114 | boxes: Optional[torch.Tensor],
115 | masks: Optional[torch.Tensor],
116 | ) -> int:
117 | """
118 | Gets the batch size of the output given the batch size of the input prompts.
119 | """
120 | if points is not None:
121 | return points[0].shape[0]
122 | elif boxes is not None:
123 | return boxes.shape[0]
124 | elif masks is not None:
125 | return masks.shape[0]
126 | else:
127 | return 1
128 |
129 | def _get_device(self) -> torch.device:
130 | return self.point_embeddings[0].weight.device
131 |
132 | def forward(
133 | self,
134 | points: Optional[Tuple[torch.Tensor, torch.Tensor]],
135 | boxes: Optional[torch.Tensor],
136 | masks: Optional[torch.Tensor],
137 | ) -> Tuple[torch.Tensor, torch.Tensor]:
138 | """
139 | Embeds different types of prompts, returning both sparse and dense
140 | embeddings.
141 |
142 | Arguments:
143 | points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates
144 | and labels to embed.
145 | boxes (torch.Tensor or none): boxes to embed
146 | masks (torch.Tensor or none): masks to embed
147 |
148 | Returns:
149 | torch.Tensor: sparse embeddings for the points and boxes, with shape
150 | BxNx(embed_dim), where N is determined by the number of input points
151 | and boxes.
152 | torch.Tensor: dense embeddings for the masks, in the shape
153 | Bx(embed_dim)x(embed_H)x(embed_W)
154 | """
155 | bs = self._get_batch_size(points, boxes, masks)
156 | sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())
157 | if points is not None:
158 | coords, labels = points
159 | point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
160 | sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
161 | if boxes is not None:
162 | box_embeddings = self._embed_boxes(boxes)
163 | sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
164 |
165 | if masks is not None:
166 | dense_embeddings = self._embed_masks(masks)
167 | else:
168 | dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
169 | bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
170 | )
171 |
172 | return sparse_embeddings, dense_embeddings
173 |
174 |
175 | class PositionEmbeddingRandom(nn.Module):
176 | """
177 | Positional encoding using random spatial frequencies.
178 | """
179 |
180 | def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
181 | super().__init__()
182 | if scale is None or scale <= 0.0:
183 | scale = 1.0
184 | self.register_buffer(
185 | "positional_encoding_gaussian_matrix",
186 | scale * torch.randn((2, num_pos_feats)),
187 | )
188 |
189 | def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
190 | """Positionally encode points that are normalized to [0,1]."""
191 | # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
192 | coords = 2 * coords - 1
193 | coords = coords @ self.positional_encoding_gaussian_matrix
194 | coords = 2 * np.pi * coords
195 | # outputs d_1 x ... x d_n x C shape
196 | return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
197 |
198 | def forward(self, size: Tuple[int, int]) -> torch.Tensor:
199 | """Generate positional encoding for a grid of the specified size."""
200 | h, w = size
201 | device: Any = self.positional_encoding_gaussian_matrix.device
202 | grid = torch.ones((h, w), device=device, dtype=torch.float32)
203 | y_embed = grid.cumsum(dim=0) - 0.5
204 | x_embed = grid.cumsum(dim=1) - 0.5
205 | y_embed = y_embed / h
206 | x_embed = x_embed / w
207 |
208 | pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
209 | return pe.permute(2, 0, 1) # C x H x W
210 |
211 | def forward_with_coords(
212 | self, coords_input: torch.Tensor, image_size: Tuple[int, int]
213 | ) -> torch.Tensor:
214 | """Positionally encode points that are not normalized to [0,1]."""
215 | coords = coords_input.clone()
216 | coords[:, :, 0] = coords[:, :, 0] / image_size[1]
217 | coords[:, :, 1] = coords[:, :, 1] / image_size[0]
218 | return self._pe_encoding(coords.to(torch.float)) # B x N x C
219 |
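
A small sketch of the encoder's input and output shapes with a single foreground point and no box or mask prompt; the dimensions match the defaults used in build_sam.py:

    import torch
    from segment_anything.modeling import PromptEncoder

    pe = PromptEncoder(
        embed_dim=256,
        image_embedding_size=(64, 64),
        input_image_size=(1024, 1024),
        mask_in_chans=16,
    )
    coords = torch.tensor([[[512.0, 384.0]]])  # B x N x 2, in input-image pixel coordinates
    labels = torch.tensor([[1]])               # B x N, 1 = foreground, 0 = background
    sparse, dense = pe(points=(coords, labels), boxes=None, masks=None)
    # sparse: (1, 2, 256) -- the point plus one padding token (added because no box is given)
    # dense:  (1, 256, 64, 64) -- the learned no-mask embedding broadcast over the grid
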
--------------------------------------------------------------------------------
/segment_anything/modeling/sam.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | from torch import nn
9 | from torch.nn import functional as F
10 |
11 | from typing import Any, Dict, List, Tuple
12 |
13 | from .image_encoder import ImageEncoderViT
14 | from .mask_decoder import MaskDecoder
15 | from .prompt_encoder import PromptEncoder
16 |
17 |
18 | class Sam(nn.Module):
19 | mask_threshold: float = 0.0
20 | image_format: str = "RGB"
21 |
22 | def __init__(
23 | self,
24 | image_encoder: ImageEncoderViT,
25 | prompt_encoder: PromptEncoder,
26 | mask_decoder: MaskDecoder,
27 | pixel_mean: List[float] = [123.675, 116.28, 103.53],
28 | pixel_std: List[float] = [58.395, 57.12, 57.375],
29 | ) -> None:
30 | """
31 | SAM predicts object masks from an image and input prompts.
32 |
33 | Arguments:
34 | image_encoder (ImageEncoderViT): The backbone used to encode the
35 | image into image embeddings that allow for efficient mask prediction.
36 | prompt_encoder (PromptEncoder): Encodes various types of input prompts.
37 | mask_decoder (MaskDecoder): Predicts masks from the image embeddings
38 | and encoded prompts.
39 | pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
40 | pixel_std (list(float)): Std values for normalizing pixels in the input image.
41 | """
42 | super().__init__()
43 | self.image_encoder = image_encoder
44 | self.prompt_encoder = prompt_encoder
45 | self.mask_decoder = mask_decoder
46 | self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
47 | self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
48 |
49 | @property
50 | def device(self) -> Any:
51 | return self.pixel_mean.device
52 |
53 | @torch.no_grad()
54 | def forward(
55 | self,
56 | batched_input: List[Dict[str, Any]],
57 | multimask_output: bool,
58 | ) -> List[Dict[str, torch.Tensor]]:
59 | """
60 | Predicts masks end-to-end from provided images and prompts.
61 | If prompts are not known in advance, using SamPredictor is
62 | recommended over calling the model directly.
63 |
64 | Arguments:
65 | batched_input (list(dict)): A list over input images, each a
66 | dictionary with the following keys. A prompt key can be
67 | excluded if it is not present.
68 | 'image': The image as a torch tensor in 3xHxW format,
69 | already transformed for input to the model.
70 | 'original_size': (tuple(int, int)) The original size of
71 | the image before transformation, as (H, W).
72 | 'point_coords': (torch.Tensor) Batched point prompts for
73 | this image, with shape BxNx2. Already transformed to the
74 | input frame of the model.
75 | 'point_labels': (torch.Tensor) Batched labels for point prompts,
76 | with shape BxN.
77 | 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4.
78 | Already transformed to the input frame of the model.
79 | 'mask_inputs': (torch.Tensor) Batched mask inputs to the model,
80 | in the form Bx1xHxW.
81 | multimask_output (bool): Whether the model should predict multiple
82 | disambiguating masks, or return a single mask.
83 |
84 | Returns:
85 | (list(dict)): A list over input images, where each element is
86 | a dictionary with the following keys.
87 | 'masks': (torch.Tensor) Batched binary mask predictions,
88 | with shape BxCxHxW, where B is the number of input prompts,
89 | C is determined by multimask_output, and (H, W) is the
90 | original size of the image.
91 | 'iou_predictions': (torch.Tensor) The model's predictions
92 | of mask quality, in shape BxC.
93 | 'low_res_logits': (torch.Tensor) Low resolution logits with
94 | shape BxCxHxW, where H=W=256. Can be passed as mask input
95 | to subsequent iterations of prediction.
96 | """
97 | input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0)
98 | image_embeddings = self.image_encoder(input_images)
99 |
100 | outputs = []
101 | for image_record, curr_embedding in zip(batched_input, image_embeddings):
102 | if "point_coords" in image_record:
103 | points = (image_record["point_coords"], image_record["point_labels"])
104 | else:
105 | points = None
106 | sparse_embeddings, dense_embeddings = self.prompt_encoder(
107 | points=points,
108 | boxes=image_record.get("boxes", None),
109 | masks=image_record.get("mask_inputs", None),
110 | )
111 | low_res_masks, iou_predictions = self.mask_decoder(
112 | image_embeddings=curr_embedding.unsqueeze(0),
113 | image_pe=self.prompt_encoder.get_dense_pe(),
114 | sparse_prompt_embeddings=sparse_embeddings,
115 | dense_prompt_embeddings=dense_embeddings,
116 | multimask_output=multimask_output,
117 | )
118 | masks = self.postprocess_masks(
119 | low_res_masks,
120 | input_size=image_record["image"].shape[-2:],
121 | original_size=image_record["original_size"],
122 | )
123 | masks = masks > self.mask_threshold
124 | outputs.append(
125 | {
126 | "masks": masks,
127 | "iou_predictions": iou_predictions,
128 | "low_res_logits": low_res_masks,
129 | }
130 | )
131 | return outputs
132 |
133 | def postprocess_masks(
134 | self,
135 | masks: torch.Tensor,
136 | input_size: Tuple[int, ...],
137 | original_size: Tuple[int, ...],
138 | ) -> torch.Tensor:
139 | """
140 | Remove padding and upscale masks to the original image size.
141 |
142 | Arguments:
143 | masks (torch.Tensor): Batched masks from the mask_decoder,
144 | in BxCxHxW format.
145 | input_size (tuple(int, int)): The size of the image input to the
146 | model, in (H, W) format. Used to remove padding.
147 | original_size (tuple(int, int)): The original size of the image
148 | before resizing for input to the model, in (H, W) format.
149 |
150 | Returns:
151 | (torch.Tensor): Batched masks in BxCxHxW format, where (H, W)
152 | is given by original_size.
153 | """
154 | masks = F.interpolate(
155 | masks,
156 | (self.image_encoder.img_size, self.image_encoder.img_size),
157 | mode="bilinear",
158 | align_corners=False,
159 | )
160 | masks = masks[..., : input_size[0], : input_size[1]]
161 | masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
162 | return masks
163 |
164 | def preprocess(self, x: torch.Tensor) -> torch.Tensor:
165 | """Normalize pixel values and pad to a square input."""
166 | # Normalize colors
167 | x = (x - self.pixel_mean) / self.pixel_std
168 |
169 | # Pad
170 | h, w = x.shape[-2:]
171 | padh = self.image_encoder.img_size - h
172 | padw = self.image_encoder.img_size - w
173 | x = F.pad(x, (0, padw, 0, padh))
174 | return x
175 |
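
A sketch of the batched_input format described in the forward docstring, using a randomly initialized ViT-B model and a dummy image already at the model's 1024-pixel input size. Real use would load a checkpoint and transform the image and point coordinates with ResizeLongestSide first.

    import torch
    from segment_anything import build_sam_vit_b

    sam = build_sam_vit_b(checkpoint=None)  # random weights; pass a checkpoint path for real use
    image = torch.randint(0, 256, (3, 1024, 1024)).float()  # 3 x H x W, already in the input frame
    batched_input = [
        {
            "image": image,
            "original_size": (1024, 1024),                     # (H, W) before resizing
            "point_coords": torch.tensor([[[512.0, 512.0]]]),  # B x N x 2, in the input frame
            "point_labels": torch.tensor([[1]]),               # B x N, 1 = foreground
        }
    ]
    outputs = sam(batched_input, multimask_output=False)       # forward already runs under no_grad
    masks = outputs[0]["masks"]                # 1 x 1 x 1024 x 1024 boolean mask
    iou = outputs[0]["iou_predictions"]        # 1 x 1 predicted mask quality
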
--------------------------------------------------------------------------------
/segment_anything/modeling/transformer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | from torch import Tensor, nn
9 |
10 | import math
11 | from typing import Tuple, Type
12 |
13 | from .common import MLPBlock
14 |
15 |
16 | class TwoWayTransformer(nn.Module):
17 | def __init__(
18 | self,
19 | depth: int,
20 | embedding_dim: int,
21 | num_heads: int,
22 | mlp_dim: int,
23 | activation: Type[nn.Module] = nn.ReLU,
24 | attention_downsample_rate: int = 2,
25 | ) -> None:
26 | """
27 | A transformer decoder that attends to an input image using
28 | queries whose positional embedding is supplied.
29 |
30 | Args:
31 | depth (int): number of layers in the transformer
32 | embedding_dim (int): the channel dimension for the input embeddings
33 | num_heads (int): the number of heads for multihead attention. Must
34 | divide embedding_dim
35 | mlp_dim (int): the channel dimension internal to the MLP block
36 | activation (nn.Module): the activation to use in the MLP block
37 | """
38 | super().__init__()
39 | self.depth = depth
40 | self.embedding_dim = embedding_dim
41 | self.num_heads = num_heads
42 | self.mlp_dim = mlp_dim
43 | self.layers = nn.ModuleList()
44 |
45 | for i in range(depth):
46 | self.layers.append(
47 | TwoWayAttentionBlock(
48 | embedding_dim=embedding_dim,
49 | num_heads=num_heads,
50 | mlp_dim=mlp_dim,
51 | activation=activation,
52 | attention_downsample_rate=attention_downsample_rate,
53 | skip_first_layer_pe=(i == 0),
54 | )
55 | )
56 |
57 | self.final_attn_token_to_image = Attention(
58 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate
59 | )
60 | self.norm_final_attn = nn.LayerNorm(embedding_dim)
61 |
62 | def forward(
63 | self,
64 | image_embedding: Tensor,
65 | image_pe: Tensor,
66 | point_embedding: Tensor,
67 | ) -> Tuple[Tensor, Tensor]:
68 | """
69 | Args:
70 | image_embedding (torch.Tensor): image to attend to. Should be shape
71 | B x embedding_dim x h x w for any h and w.
72 | image_pe (torch.Tensor): the positional encoding to add to the image. Must
73 | have the same shape as image_embedding.
74 | point_embedding (torch.Tensor): the embedding to add to the query points.
75 | Must have shape B x N_points x embedding_dim for any N_points.
76 |
77 | Returns:
78 | torch.Tensor: the processed point_embedding
79 | torch.Tensor: the processed image_embedding
80 | """
81 | # BxCxHxW -> BxHWxC == B x N_image_tokens x C
82 | bs, c, h, w = image_embedding.shape
83 | image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
84 | image_pe = image_pe.flatten(2).permute(0, 2, 1)
85 |
86 | # Prepare queries
87 | queries = point_embedding
88 | keys = image_embedding
89 |
90 | # Apply transformer blocks and final layernorm
91 | for layer in self.layers:
92 | queries, keys = layer(
93 | queries=queries,
94 | keys=keys,
95 | query_pe=point_embedding,
96 | key_pe=image_pe,
97 | )
98 |
99 | # Apply the final attention layer from the points to the image
100 | q = queries + point_embedding
101 | k = keys + image_pe
102 | attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
103 | queries = queries + attn_out
104 | queries = self.norm_final_attn(queries)
105 |
106 | return queries, keys
107 |
108 |
109 | class TwoWayAttentionBlock(nn.Module):
110 | def __init__(
111 | self,
112 | embedding_dim: int,
113 | num_heads: int,
114 | mlp_dim: int = 2048,
115 | activation: Type[nn.Module] = nn.ReLU,
116 | attention_downsample_rate: int = 2,
117 | skip_first_layer_pe: bool = False,
118 | ) -> None:
119 | """
120 | A transformer block with four layers: (1) self-attention of sparse
121 | inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
122 | block on sparse inputs, and (4) cross attention of dense inputs to sparse
123 | inputs.
124 |
125 | Arguments:
126 | embedding_dim (int): the channel dimension of the embeddings
127 | num_heads (int): the number of heads in the attention layers
128 | mlp_dim (int): the hidden dimension of the mlp block
129 | activation (nn.Module): the activation of the mlp block
130 | skip_first_layer_pe (bool): skip the PE on the first layer
131 | """
132 | super().__init__()
133 | self.self_attn = Attention(embedding_dim, num_heads)
134 | self.norm1 = nn.LayerNorm(embedding_dim)
135 |
136 | self.cross_attn_token_to_image = Attention(
137 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate
138 | )
139 | self.norm2 = nn.LayerNorm(embedding_dim)
140 |
141 | self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
142 | self.norm3 = nn.LayerNorm(embedding_dim)
143 |
144 | self.norm4 = nn.LayerNorm(embedding_dim)
145 | self.cross_attn_image_to_token = Attention(
146 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate
147 | )
148 |
149 | self.skip_first_layer_pe = skip_first_layer_pe
150 |
151 | def forward(
152 | self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
153 | ) -> Tuple[Tensor, Tensor]:
154 | # Self attention block
155 | if self.skip_first_layer_pe:
156 | queries = self.self_attn(q=queries, k=queries, v=queries)
157 | else:
158 | q = queries + query_pe
159 | attn_out = self.self_attn(q=q, k=q, v=queries)
160 | queries = queries + attn_out
161 | queries = self.norm1(queries)
162 |
163 | # Cross attention block, tokens attending to image embedding
164 | q = queries + query_pe
165 | k = keys + key_pe
166 | attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
167 | queries = queries + attn_out
168 | queries = self.norm2(queries)
169 |
170 | # MLP block
171 | mlp_out = self.mlp(queries)
172 | queries = queries + mlp_out
173 | queries = self.norm3(queries)
174 |
175 | # Cross attention block, image embedding attending to tokens
176 | q = queries + query_pe
177 | k = keys + key_pe
178 | attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
179 | keys = keys + attn_out
180 | keys = self.norm4(keys)
181 |
182 | return queries, keys
183 |
184 |
185 | class Attention(nn.Module):
186 | """
187 | An attention layer that allows for downscaling the size of the embedding
188 | after projection to queries, keys, and values.
189 | """
190 |
191 | def __init__(
192 | self,
193 | embedding_dim: int,
194 | num_heads: int,
195 | downsample_rate: int = 1,
196 | ) -> None:
197 | super().__init__()
198 | self.embedding_dim = embedding_dim
199 | self.internal_dim = embedding_dim // downsample_rate
200 | self.num_heads = num_heads
201 | assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."
202 |
203 | self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
204 | self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
205 | self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
206 | self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
207 |
208 | def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
209 | b, n, c = x.shape
210 | x = x.reshape(b, n, num_heads, c // num_heads)
211 | return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head
212 |
213 | def _recombine_heads(self, x: Tensor) -> Tensor:
214 | b, n_heads, n_tokens, c_per_head = x.shape
215 | x = x.transpose(1, 2)
216 | return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C
217 |
218 | def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
219 | # Input projections
220 | q = self.q_proj(q)
221 | k = self.k_proj(k)
222 | v = self.v_proj(v)
223 |
224 | # Separate into heads
225 | q = self._separate_heads(q, self.num_heads)
226 | k = self._separate_heads(k, self.num_heads)
227 | v = self._separate_heads(v, self.num_heads)
228 |
229 | # Attention
230 | _, _, _, c_per_head = q.shape
231 | attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens
232 | attn = attn / math.sqrt(c_per_head)
233 | attn = torch.softmax(attn, dim=-1)
234 |
235 | # Get output
236 | out = attn @ v
237 | out = self._recombine_heads(out)
238 | out = self.out_proj(out)
239 |
240 | return out
241 |
--------------------------------------------------------------------------------
/segment_anything/predictor.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import numpy as np
8 | import torch
9 |
10 | from segment_anything.modeling import Sam
11 |
12 | from typing import Optional, Tuple
13 |
14 | from .utils.transforms import ResizeLongestSide
15 |
16 |
17 | class SamPredictor:
18 | def __init__(
19 | self,
20 | sam_model: Sam,
21 | ) -> None:
22 | """
23 | Uses SAM to calculate the image embedding for an image, and then
24 | allows repeated, efficient mask prediction given prompts.
25 |
26 | Arguments:
27 | sam_model (Sam): The model to use for mask prediction.
28 | """
29 | super().__init__()
30 | self.model = sam_model
31 | self.transform = ResizeLongestSide(sam_model.image_encoder.img_size)
32 | self.reset_image()
33 |
34 | def set_image(
35 | self,
36 | image: np.ndarray,
37 | image_format: str = "RGB",
38 | ) -> None:
39 | """
40 | Calculates the image embeddings for the provided image, allowing
41 | masks to be predicted with the 'predict' method.
42 |
43 | Arguments:
44 | image (np.ndarray): The image for calculating masks. Expects an
45 | image in HWC uint8 format, with pixel values in [0, 255].
46 | image_format (str): The color format of the image, in ['RGB', 'BGR'].
47 | """
48 | assert image_format in [
49 | "RGB",
50 | "BGR",
51 | ], f"image_format must be in ['RGB', 'BGR'], is {image_format}."
52 | if image_format != self.model.image_format:
53 | image = image[..., ::-1]
54 |
55 | # Transform the image to the form expected by the model
56 | input_image = self.transform.apply_image(image)
57 | input_image_torch = torch.as_tensor(input_image, device=self.device)
58 | input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]
59 |
60 | self.set_torch_image(input_image_torch, image.shape[:2])
61 |
62 | @torch.no_grad()
63 | def set_torch_image(
64 | self,
65 | transformed_image: torch.Tensor,
66 | original_image_size: Tuple[int, ...],
67 | ) -> None:
68 | """
69 | Calculates the image embeddings for the provided image, allowing
70 | masks to be predicted with the 'predict' method. Expects the input
71 | image to be already transformed to the format expected by the model.
72 |
73 | Arguments:
74 | transformed_image (torch.Tensor): The input image, with shape
75 | 1x3xHxW, which has been transformed with ResizeLongestSide.
76 | original_image_size (tuple(int, int)): The size of the image
77 | before transformation, in (H, W) format.
78 | """
79 | assert (
80 | len(transformed_image.shape) == 4
81 | and transformed_image.shape[1] == 3
82 | and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size
83 | ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}."
84 | self.reset_image()
85 |
86 | self.original_size = original_image_size
87 | self.input_size = tuple(transformed_image.shape[-2:])
88 | input_image = self.model.preprocess(transformed_image)
89 | self.features = self.model.image_encoder(input_image)
90 | self.is_image_set = True
91 |
92 | def predict(
93 | self,
94 | point_coords: Optional[np.ndarray] = None,
95 | point_labels: Optional[np.ndarray] = None,
96 | box: Optional[np.ndarray] = None,
97 | mask_input: Optional[np.ndarray] = None,
98 | multimask_output: bool = True,
99 | return_logits: bool = False,
100 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
101 | """
102 | Predict masks for the given input prompts, using the currently set image.
103 |
104 | Arguments:
105 | point_coords (np.ndarray or None): A Nx2 array of point prompts to the
106 | model. Each point is in (X,Y) in pixels.
107 | point_labels (np.ndarray or None): A length N array of labels for the
108 | point prompts. 1 indicates a foreground point and 0 indicates a
109 | background point.
110 | box (np.ndarray or None): A length 4 array giving a box prompt to the
111 | model, in XYXY format.
112 | mask_input (np.ndarray): A low resolution mask input to the model, typically
113 | coming from a previous prediction iteration. Has form 1xHxW, where
114 | for SAM, H=W=256.
115 | multimask_output (bool): If true, the model will return three masks.
116 | For ambiguous input prompts (such as a single click), this will often
117 | produce better masks than a single prediction. If only a single
118 | mask is needed, the model's predicted quality score can be used
119 | to select the best mask. For non-ambiguous prompts, such as multiple
120 | input prompts, multimask_output=False can give better results.
121 | return_logits (bool): If true, returns un-thresholded masks logits
122 | instead of a binary mask.
123 |
124 | Returns:
125 | (np.ndarray): The output masks in CxHxW format, where C is the
126 | number of masks, and (H, W) is the original image size.
127 | (np.ndarray): An array of length C containing the model's
128 | predictions for the quality of each mask.
129 | (np.ndarray): An array of shape CxHxW, where C is the number
130 | of masks and H=W=256. These low resolution logits can be passed to
131 | a subsequent iteration as mask input.
132 | """
133 | if not self.is_image_set:
134 | raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
135 |
136 | # Transform input prompts
137 | coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None
138 | if point_coords is not None:
139 | assert (
140 | point_labels is not None
141 | ), "point_labels must be supplied if point_coords is supplied."
142 | point_coords = self.transform.apply_coords(point_coords, self.original_size)
143 | coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device)
144 | labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device)
145 | coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]
146 | if box is not None:
147 | box = self.transform.apply_boxes(box, self.original_size)
148 | box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device)
149 | box_torch = box_torch[None, :]
150 | if mask_input is not None:
151 | mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device)
152 | mask_input_torch = mask_input_torch[None, :, :, :]
153 |
154 | masks, iou_predictions, low_res_masks = self.predict_torch(
155 | coords_torch,
156 | labels_torch,
157 | box_torch,
158 | mask_input_torch,
159 | multimask_output,
160 | return_logits=return_logits,
161 | )
162 |
163 | masks_np = masks[0].detach().cpu().numpy()
164 | iou_predictions_np = iou_predictions[0].detach().cpu().numpy()
165 | low_res_masks_np = low_res_masks[0].detach().cpu().numpy()
166 | return masks_np, iou_predictions_np, low_res_masks_np
167 |
168 | @torch.no_grad()
169 | def predict_torch(
170 | self,
171 | point_coords: Optional[torch.Tensor],
172 | point_labels: Optional[torch.Tensor],
173 | boxes: Optional[torch.Tensor] = None,
174 | mask_input: Optional[torch.Tensor] = None,
175 | multimask_output: bool = True,
176 | return_logits: bool = False,
177 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
178 | """
179 | Predict masks for the given input prompts, using the currently set image.
180 | Input prompts are batched torch tensors and are expected to already be
181 | transformed to the input frame using ResizeLongestSide.
182 |
183 | Arguments:
184 | point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the
185 | model. Each point is in (X,Y) in pixels.
186 | point_labels (torch.Tensor or None): A BxN array of labels for the
187 | point prompts. 1 indicates a foreground point and 0 indicates a
188 | background point.
189 | boxes (torch.Tensor or None): A Bx4 array giving a box prompt to the
190 | model, in XYXY format.
191 | mask_input (torch.Tensor): A low resolution mask input to the model, typically
192 | coming from a previous prediction iteration. Has form Bx1xHxW, where
193 | for SAM, H=W=256. Masks returned by a previous iteration of the
194 | predict method do not need further transformation.
195 | multimask_output (bool): If true, the model will return three masks.
196 | For ambiguous input prompts (such as a single click), this will often
197 | produce better masks than a single prediction. If only a single
198 | mask is needed, the model's predicted quality score can be used
199 | to select the best mask. For non-ambiguous prompts, such as multiple
200 | input prompts, multimask_output=False can give better results.
201 | return_logits (bool): If true, returns un-thresholded masks logits
202 | instead of a binary mask.
203 |
204 | Returns:
205 | (torch.Tensor): The output masks in BxCxHxW format, where C is the
206 | number of masks, and (H, W) is the original image size.
207 | (torch.Tensor): An array of shape BxC containing the model's
208 | predictions for the quality of each mask.
209 | (torch.Tensor): An array of shape BxCxHxW, where C is the number
210 | of masks and H=W=256. These low res logits can be passed to
211 | a subsequent iteration as mask input.
212 | """
213 | if not self.is_image_set:
214 | raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
215 |
216 | if point_coords is not None:
217 | points = (point_coords, point_labels)
218 | else:
219 | points = None
220 |
221 | # Embed prompts
222 | sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
223 | points=points,
224 | boxes=boxes,
225 | masks=mask_input,
226 | )
227 |
228 | # Predict masks
229 | low_res_masks, iou_predictions = self.model.mask_decoder(
230 | image_embeddings=self.features,
231 | image_pe=self.model.prompt_encoder.get_dense_pe(),
232 | sparse_prompt_embeddings=sparse_embeddings,
233 | dense_prompt_embeddings=dense_embeddings,
234 | multimask_output=multimask_output,
235 | )
236 |
237 | # Upscale the masks to the original image resolution
238 | masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size)
239 |
240 | if not return_logits:
241 | masks = masks > self.model.mask_threshold
242 |
243 | return masks, iou_predictions, low_res_masks
244 |
245 | def get_image_embedding(self) -> torch.Tensor:
246 | """
247 | Returns the image embeddings for the currently set image, with
248 | shape 1xCxHxW, where C is the embedding dimension and (H,W) are
249 | the embedding spatial dimension of SAM (typically C=256, H=W=64).
250 | """
251 | if not self.is_image_set:
252 | raise RuntimeError(
253 | "An image must be set with .set_image(...) to generate an embedding."
254 | )
255 | assert self.features is not None, "Features must exist if an image has been set."
256 | return self.features
257 |
258 | @property
259 | def device(self) -> torch.device:
260 | return self.model.device
261 |
262 | def reset_image(self) -> None:
263 | """Resets the currently set image."""
264 | self.is_image_set = False
265 | self.features = None
266 | self.orig_h = None
267 | self.orig_w = None
268 | self.input_h = None
269 | self.input_w = None
270 |
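
The typical interactive loop described in the docstrings, sketched end to end. The checkpoint path is an assumption, and the zero image is only a stand-in for a real HWC uint8 RGB array:

    import numpy as np
    from segment_anything import SamPredictor, sam_model_registry

    sam = sam_model_registry["vit_b"](checkpoint="pretrain/sam_vit_b_01ec64.pth")  # assumed path
    predictor = SamPredictor(sam)

    image = np.zeros((512, 512, 3), dtype=np.uint8)   # stand-in for a real image
    predictor.set_image(image)                        # computes the image embedding once

    masks, scores, low_res_logits = predictor.predict(
        point_coords=np.array([[256, 256]]),   # N x 2, (X, Y) in original-image pixels
        point_labels=np.array([1]),            # 1 = foreground
        multimask_output=True,
    )
    # masks: (3, 512, 512) boolean; scores: (3,); low_res_logits: (3, 256, 256)
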
--------------------------------------------------------------------------------
/segment_anything/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/segment_anything/utils/amg.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import numpy as np
8 | import torch
9 |
10 | import math
11 | from copy import deepcopy
12 | from itertools import product
13 | from typing import Any, Dict, Generator, ItemsView, List, Tuple
14 |
15 |
16 | class MaskData:
17 | """
18 | A structure for storing masks and their related data in batched format.
19 | Implements basic filtering and concatenation.
20 | """
21 |
22 | def __init__(self, **kwargs) -> None:
23 | for v in kwargs.values():
24 | assert isinstance(
25 | v, (list, np.ndarray, torch.Tensor)
26 | ), "MaskData only supports list, numpy arrays, and torch tensors."
27 | self._stats = dict(**kwargs)
28 |
29 | def __setitem__(self, key: str, item: Any) -> None:
30 | assert isinstance(
31 | item, (list, np.ndarray, torch.Tensor)
32 | ), "MaskData only supports list, numpy arrays, and torch tensors."
33 | self._stats[key] = item
34 |
35 | def __delitem__(self, key: str) -> None:
36 | del self._stats[key]
37 |
38 | def __getitem__(self, key: str) -> Any:
39 | return self._stats[key]
40 |
41 | def items(self) -> ItemsView[str, Any]:
42 | return self._stats.items()
43 |
44 | def filter(self, keep: torch.Tensor) -> None:
45 | for k, v in self._stats.items():
46 | if v is None:
47 | self._stats[k] = None
48 | elif isinstance(v, torch.Tensor):
49 | self._stats[k] = v[torch.as_tensor(keep, device=v.device)]
50 | elif isinstance(v, np.ndarray):
51 | self._stats[k] = v[keep.detach().cpu().numpy()]
52 | elif isinstance(v, list) and keep.dtype == torch.bool:
53 | self._stats[k] = [a for i, a in enumerate(v) if keep[i]]
54 | elif isinstance(v, list):
55 | self._stats[k] = [v[i] for i in keep]
56 | else:
57 | raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")
58 |
59 | def cat(self, new_stats: "MaskData") -> None:
60 | for k, v in new_stats.items():
61 | if k not in self._stats or self._stats[k] is None:
62 | self._stats[k] = deepcopy(v)
63 | elif isinstance(v, torch.Tensor):
64 | self._stats[k] = torch.cat([self._stats[k], v], dim=0)
65 | elif isinstance(v, np.ndarray):
66 | self._stats[k] = np.concatenate([self._stats[k], v], axis=0)
67 | elif isinstance(v, list):
68 | self._stats[k] = self._stats[k] + deepcopy(v)
69 | else:
70 | raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")
71 |
72 | def to_numpy(self) -> None:
73 | for k, v in self._stats.items():
74 | if isinstance(v, torch.Tensor):
75 | self._stats[k] = v.detach().cpu().numpy()
76 |
77 |
78 | def is_box_near_crop_edge(
79 | boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0
80 | ) -> torch.Tensor:
81 | """Filter masks at the edge of a crop, but not at the edge of the original image."""
82 | crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
83 | orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)
84 | boxes = uncrop_boxes_xyxy(boxes, crop_box).float()
85 | near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0)
86 | near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0)
87 | near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge)
88 | return torch.any(near_crop_edge, dim=1)
89 |
90 |
91 | def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor:
92 | box_xywh = deepcopy(box_xyxy)
93 | box_xywh[2] = box_xywh[2] - box_xywh[0]
94 | box_xywh[3] = box_xywh[3] - box_xywh[1]
95 | return box_xywh
96 |
97 |
98 | def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
99 | assert len(args) > 0 and all(
100 | len(a) == len(args[0]) for a in args
101 | ), "Batched iteration must have inputs of all the same size."
102 | n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0)
103 | for b in range(n_batches):
104 | yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args]
105 |
106 |
107 | def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]:
108 | """
109 | Encodes masks to an uncompressed RLE, in the format expected by
110 | pycoco tools.
111 | """
112 | # Put in fortran order and flatten h,w
113 | b, h, w = tensor.shape
114 | tensor = tensor.permute(0, 2, 1).flatten(1)
115 |
116 | # Compute change indices
117 | diff = tensor[:, 1:] ^ tensor[:, :-1]
118 | change_indices = diff.nonzero()
119 |
120 | # Encode run length
121 | out = []
122 | for i in range(b):
123 | cur_idxs = change_indices[change_indices[:, 0] == i, 1]
124 | cur_idxs = torch.cat(
125 | [
126 | torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device),
127 | cur_idxs + 1,
128 | torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device),
129 | ]
130 | )
131 | btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
132 | counts = [] if tensor[i, 0] == 0 else [0]
133 | counts.extend(btw_idxs.detach().cpu().tolist())
134 | out.append({"size": [h, w], "counts": counts})
135 | return out
136 |
137 |
138 | def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray:
139 | """Compute a binary mask from an uncompressed RLE."""
140 | h, w = rle["size"]
141 | mask = np.empty(h * w, dtype=bool)
142 | idx = 0
143 | parity = False
144 | for count in rle["counts"]:
145 | mask[idx : idx + count] = parity
146 | idx += count
147 | parity ^= True
148 | mask = mask.reshape(w, h)
149 | return mask.transpose() # Put in C order
150 |
151 |
152 | def area_from_rle(rle: Dict[str, Any]) -> int:
153 | return sum(rle["counts"][1::2])
154 |
155 |
156 | def calculate_stability_score(
157 | masks: torch.Tensor, mask_threshold: float, threshold_offset: float
158 | ) -> torch.Tensor:
159 | """
160 | Computes the stability score for a batch of masks. The stability
161 | score is the IoU between the binary masks obtained by thresholding
162 | the predicted mask logits at high and low values.
163 | """
164 | # One mask is always contained inside the other.
165 | # Save memory by preventing unnecessary cast to torch.int64
166 | intersections = (
167 | (masks > (mask_threshold + threshold_offset))
168 | .sum(-1, dtype=torch.int16)
169 | .sum(-1, dtype=torch.int32)
170 | )
171 | unions = (
172 | (masks > (mask_threshold - threshold_offset))
173 | .sum(-1, dtype=torch.int16)
174 | .sum(-1, dtype=torch.int32)
175 | )
176 | return intersections / unions
177 |
178 |
179 | def build_point_grid(n_per_side: int) -> np.ndarray:
180 | """Generates a 2D grid of points evenly spaced in [0,1]x[0,1]."""
181 | offset = 1 / (2 * n_per_side)
182 | points_one_side = np.linspace(offset, 1 - offset, n_per_side)
183 | points_x = np.tile(points_one_side[None, :], (n_per_side, 1))
184 | points_y = np.tile(points_one_side[:, None], (1, n_per_side))
185 | points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2)
186 | return points
187 |
188 |
189 | def build_all_layer_point_grids(
190 | n_per_side: int, n_layers: int, scale_per_layer: int
191 | ) -> List[np.ndarray]:
192 | """Generates point grids for all crop layers."""
193 | points_by_layer = []
194 | for i in range(n_layers + 1):
195 | n_points = int(n_per_side / (scale_per_layer**i))
196 | points_by_layer.append(build_point_grid(n_points))
197 | return points_by_layer
198 |
199 |
200 | def generate_crop_boxes(
201 | im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float
202 | ) -> Tuple[List[List[int]], List[int]]:
203 | """
204 | Generates a list of crop boxes of different sizes. Each layer
205 | has (2**i)**2 boxes for the ith layer.
206 | """
207 | crop_boxes, layer_idxs = [], []
208 | im_h, im_w = im_size
209 | short_side = min(im_h, im_w)
210 |
211 | # Original image
212 | crop_boxes.append([0, 0, im_w, im_h])
213 | layer_idxs.append(0)
214 |
215 | def crop_len(orig_len, n_crops, overlap):
216 | return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops))
217 |
218 | for i_layer in range(n_layers):
219 | n_crops_per_side = 2 ** (i_layer + 1)
220 | overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side))
221 |
222 | crop_w = crop_len(im_w, n_crops_per_side, overlap)
223 | crop_h = crop_len(im_h, n_crops_per_side, overlap)
224 |
225 | crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)]
226 | crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)]
227 |
228 | # Crops in XYXY format
229 | for x0, y0 in product(crop_box_x0, crop_box_y0):
230 | box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)]
231 | crop_boxes.append(box)
232 | layer_idxs.append(i_layer + 1)
233 |
234 | return crop_boxes, layer_idxs
235 |
236 |
237 | def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
238 | x0, y0, _, _ = crop_box
239 | offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device)
240 | # Check if boxes has a channel dimension
241 | if len(boxes.shape) == 3:
242 | offset = offset.unsqueeze(1)
243 | return boxes + offset
244 |
245 |
246 | def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
247 | x0, y0, _, _ = crop_box
248 | offset = torch.tensor([[x0, y0]], device=points.device)
249 | # Check if points has a channel dimension
250 | if len(points.shape) == 3:
251 | offset = offset.unsqueeze(1)
252 | return points + offset
253 |
254 |
255 | def uncrop_masks(
256 | masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int
257 | ) -> torch.Tensor:
258 | x0, y0, x1, y1 = crop_box
259 | if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h:
260 | return masks
261 | # Coordinate transform masks
262 | pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0)
263 | pad = (x0, pad_x - x0, y0, pad_y - y0)
264 | return torch.nn.functional.pad(masks, pad, value=0)
265 |
266 |
267 | def remove_small_regions(
268 | mask: np.ndarray, area_thresh: float, mode: str
269 | ) -> Tuple[np.ndarray, bool]:
270 | """
271 | Removes small disconnected regions and holes in a mask. Returns the
272 | mask and an indicator of if the mask has been modified.
273 | """
274 | import cv2 # type: ignore
275 |
276 | assert mode in ["holes", "islands"]
277 | correct_holes = mode == "holes"
278 | working_mask = (correct_holes ^ mask).astype(np.uint8)
279 | n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
280 | sizes = stats[:, -1][1:] # Row 0 is background label
281 | small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
282 | if len(small_regions) == 0:
283 | return mask, False
284 | fill_labels = [0] + small_regions
285 | if not correct_holes:
286 | fill_labels = [i for i in range(n_labels) if i not in fill_labels]
287 | # If every region is below threshold, keep largest
288 | if len(fill_labels) == 0:
289 | fill_labels = [int(np.argmax(sizes)) + 1]
290 | mask = np.isin(regions, fill_labels)
291 | return mask, True
292 |
293 |
294 | def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]:
295 | from pycocotools import mask as mask_utils # type: ignore
296 |
297 | h, w = uncompressed_rle["size"]
298 | rle = mask_utils.frPyObjects(uncompressed_rle, h, w)
299 | rle["counts"] = rle["counts"].decode("utf-8") # Necessary to serialize with json
300 | return rle
301 |
302 |
303 | def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
304 | """
305 | Calculates boxes in XYXY format around masks. Return [0,0,0,0] for
306 | an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4.
307 | """
308 | # torch.max below raises an error on empty inputs, just skip in this case
309 | if torch.numel(masks) == 0:
310 | return torch.zeros(*masks.shape[:-2], 4, device=masks.device)
311 |
312 | # Normalize shape to CxHxW
313 | shape = masks.shape
314 | h, w = shape[-2:]
315 | if len(shape) > 2:
316 | masks = masks.flatten(0, -3)
317 | else:
318 | masks = masks.unsqueeze(0)
319 |
320 | # Get top and bottom edges
321 | in_height, _ = torch.max(masks, dim=-1)
322 | in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :]
323 | bottom_edges, _ = torch.max(in_height_coords, dim=-1)
324 | in_height_coords = in_height_coords + h * (~in_height)
325 | top_edges, _ = torch.min(in_height_coords, dim=-1)
326 |
327 | # Get left and right edges
328 | in_width, _ = torch.max(masks, dim=-2)
329 | in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :]
330 | right_edges, _ = torch.max(in_width_coords, dim=-1)
331 | in_width_coords = in_width_coords + w * (~in_width)
332 | left_edges, _ = torch.min(in_width_coords, dim=-1)
333 |
334 | # If the mask is empty the right edge will be to the left of the left edge.
335 | # Replace these boxes with [0, 0, 0, 0]
336 | empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
337 | out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1)
338 | out = out * (~empty_filter).unsqueeze(-1)
339 |
340 | # Return to original shape
341 | if len(shape) > 2:
342 | out = out.reshape(*shape[:-2], 4)
343 | else:
344 | out = out[0]
345 |
346 | return out
347 |
--------------------------------------------------------------------------------
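
A short usage sketch (an illustration, not a file in this repository) showing how the RLE helpers and batched_mask_to_box above fit together; it assumes the module is importable as segment_anything.utils.amg and uses a toy boolean mask:

import torch
from segment_anything.utils.amg import (
    area_from_rle, batched_mask_to_box, mask_to_rle_pytorch, rle_to_mask,
)

# One 32x32 boolean mask (BxHxW) with a single 8x16 foreground rectangle.
masks = torch.zeros(1, 32, 32, dtype=torch.bool)
masks[0, 8:16, 4:20] = True

rles = mask_to_rle_pytorch(masks)   # uncompressed RLE, one dict per mask
recovered = rle_to_mask(rles[0])    # back to an HxW numpy bool array
assert recovered.sum() == area_from_rle(rles[0]) == 8 * 16

print(batched_mask_to_box(masks))   # XYXY boxes; here tensor([[ 4,  8, 19, 15]])
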
/segment_anything/utils/onnx.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | import torch.nn as nn
9 | from torch.nn import functional as F
10 |
11 | from typing import Tuple
12 |
13 | from ..modeling import Sam
14 | from .amg import calculate_stability_score
15 |
16 |
17 | class SamOnnxModel(nn.Module):
18 | """
19 | This model should not be called directly, but is used in ONNX export.
20 | It combines the prompt encoder, mask decoder, and mask postprocessing of Sam,
21 | with some functions modified to enable model tracing. Also supports extra
22 | options controlling what information is returned. See the ONNX export script for details.
23 | """
24 |
25 | def __init__(
26 | self,
27 | model: Sam,
28 | return_single_mask: bool,
29 | use_stability_score: bool = False,
30 | return_extra_metrics: bool = False,
31 | ) -> None:
32 | super().__init__()
33 | self.mask_decoder = model.mask_decoder
34 | self.model = model
35 | self.img_size = model.image_encoder.img_size
36 | self.return_single_mask = return_single_mask
37 | self.use_stability_score = use_stability_score
38 | self.stability_score_offset = 1.0
39 | self.return_extra_metrics = return_extra_metrics
40 |
41 | @staticmethod
42 | def resize_longest_image_size(
43 | input_image_size: torch.Tensor, longest_side: int
44 | ) -> torch.Tensor:
45 | input_image_size = input_image_size.to(torch.float32)
46 | scale = longest_side / torch.max(input_image_size)
47 | transformed_size = scale * input_image_size
48 | transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64)
49 | return transformed_size
50 |
51 | def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor:
52 | point_coords = point_coords + 0.5
53 | point_coords = point_coords / self.img_size
54 | point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords)
55 | point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding)
56 |
57 | point_embedding = point_embedding * (point_labels != -1)
58 | point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * (
59 | point_labels == -1
60 | )
61 |
62 | for i in range(self.model.prompt_encoder.num_point_embeddings):
63 | point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[
64 | i
65 | ].weight * (point_labels == i)
66 |
67 | return point_embedding
68 |
69 | def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor:
70 | mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask)
71 | mask_embedding = mask_embedding + (
72 | 1 - has_mask_input
73 | ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1)
74 | return mask_embedding
75 |
76 | def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor:
77 | masks = F.interpolate(
78 | masks,
79 | size=(self.img_size, self.img_size),
80 | mode="bilinear",
81 | align_corners=False,
82 | )
83 |
84 | prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(torch.int64)
85 | masks = masks[..., : prepadded_size[0], : prepadded_size[1]] # type: ignore
86 |
87 | orig_im_size = orig_im_size.to(torch.int64)
88 | h, w = orig_im_size[0], orig_im_size[1]
89 | masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False)
90 | return masks
91 |
92 | def select_masks(
93 | self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int
94 | ) -> Tuple[torch.Tensor, torch.Tensor]:
95 | # Determine if we should return the multiclick mask or not from the number of points.
96 | # The reweighting is used to avoid control flow.
97 | score_reweight = torch.tensor(
98 | [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)]
99 | ).to(iou_preds.device)
100 | score = iou_preds + (num_points - 2.5) * score_reweight
101 | best_idx = torch.argmax(score, dim=1)
102 | masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1)
103 | iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1)
104 |
105 | return masks, iou_preds
106 |
107 | @torch.no_grad()
108 | def forward(
109 | self,
110 | image_embeddings: torch.Tensor,
111 | point_coords: torch.Tensor,
112 | point_labels: torch.Tensor,
113 | mask_input: torch.Tensor,
114 | has_mask_input: torch.Tensor,
115 | orig_im_size: torch.Tensor,
116 | ):
117 | sparse_embedding = self._embed_points(point_coords, point_labels)
118 | dense_embedding = self._embed_masks(mask_input, has_mask_input)
119 |
120 | masks, scores = self.model.mask_decoder.predict_masks(
121 | image_embeddings=image_embeddings,
122 | image_pe=self.model.prompt_encoder.get_dense_pe(),
123 | sparse_prompt_embeddings=sparse_embedding,
124 | dense_prompt_embeddings=dense_embedding,
125 | )
126 |
127 | if self.use_stability_score:
128 | scores = calculate_stability_score(
129 | masks, self.model.mask_threshold, self.stability_score_offset
130 | )
131 |
132 | if self.return_single_mask:
133 | masks, scores = self.select_masks(masks, scores, point_coords.shape[1])
134 |
135 | upscaled_masks = self.mask_postprocessing(masks, orig_im_size)
136 |
137 | if self.return_extra_metrics:
138 | stability_scores = calculate_stability_score(
139 | upscaled_masks, self.model.mask_threshold, self.stability_score_offset
140 | )
141 | areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1)
142 | return upscaled_masks, scores, stability_scores, areas, masks
143 |
144 | return upscaled_masks, scores, masks
145 |
--------------------------------------------------------------------------------
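
A hedged sketch of how SamOnnxModel above might be traced with torch.onnx.export (not part of this repository; the checkpoint path, dummy input shapes, output names and opset version are assumptions modelled on the class's forward signature):

import torch
from segment_anything import sam_model_registry
from segment_anything.utils.onnx import SamOnnxModel

# Assumed ViT-B checkpoint placed under pretrain/ (illustrative path).
sam = sam_model_registry["vit_b"](checkpoint="pretrain/sam_vit_b_01ec64.pth")
onnx_model = SamOnnxModel(sam, return_single_mask=True)

embed_dim = sam.prompt_encoder.embed_dim
embed_size = sam.prompt_encoder.image_embedding_size
dummy_inputs = {
    "image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float),
    "point_coords": torch.randint(0, 1024, (1, 5, 2), dtype=torch.float),
    "point_labels": torch.randint(0, 4, (1, 5), dtype=torch.float),
    "mask_input": torch.randn(1, 1, 4 * embed_size[0], 4 * embed_size[1], dtype=torch.float),
    "has_mask_input": torch.tensor([1], dtype=torch.float),
    "orig_im_size": torch.tensor([1500, 2250], dtype=torch.float),
}

torch.onnx.export(
    onnx_model,
    tuple(dummy_inputs.values()),
    "sam_prompt_decoder.onnx",
    input_names=list(dummy_inputs.keys()),
    output_names=["masks", "iou_predictions", "low_res_masks"],
    opset_version=17,
)
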
/segment_anything/utils/transforms.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import numpy as np
8 | import torch
9 | from torch.nn import functional as F
10 | from torchvision.transforms.functional import resize, to_pil_image # type: ignore
11 |
12 | from copy import deepcopy
13 | from typing import Tuple
14 |
15 |
16 | class ResizeLongestSide:
17 | """
18 | Resizes images to the longest side 'target_length' and provides methods
19 | for resizing coordinates and boxes. Supports transforming both numpy
20 | arrays and batched torch tensors.
21 | """
22 |
23 | def __init__(self, target_length: int) -> None:
24 | self.target_length = target_length
25 |
26 | def apply_image(self, image: np.ndarray) -> np.ndarray:
27 | """
28 | Expects a numpy array with shape HxWxC in uint8 format.
29 | """
30 | target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length)
31 | return np.array(resize(to_pil_image(image), target_size))
32 |
33 | def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
34 | """
35 | Expects a numpy array of length 2 in the final dimension. Requires the
36 | original image size in (H, W) format.
37 | """
38 | old_h, old_w = original_size
39 | new_h, new_w = self.get_preprocess_shape(
40 | original_size[0], original_size[1], self.target_length
41 | )
42 | coords = deepcopy(coords).astype(float)
43 | coords[..., 0] = coords[..., 0] * (new_w / old_w)
44 | coords[..., 1] = coords[..., 1] * (new_h / old_h)
45 | return coords
46 |
47 | def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
48 | """
49 | Expects a numpy array shape Bx4. Requires the original image size
50 | in (H, W) format.
51 | """
52 | boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size)
53 | return boxes.reshape(-1, 4)
54 |
55 | def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor:
56 | """
57 | Expects batched images with shape BxCxHxW and float format. This
58 | transformation may not exactly match apply_image. apply_image is
59 | the transformation expected by the model.
60 | """
61 | # Expects an image in BCHW format. May not exactly match apply_image.
62 | target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length)
63 | return F.interpolate(
64 | image, target_size, mode="bilinear", align_corners=False, antialias=True
65 | )
66 |
67 | def apply_coords_torch(
68 | self, coords: torch.Tensor, original_size: Tuple[int, ...]
69 | ) -> torch.Tensor:
70 | """
71 | Expects a torch tensor with length 2 in the last dimension. Requires the
72 | original image size in (H, W) format.
73 | """
74 | old_h, old_w = original_size
75 | new_h, new_w = self.get_preprocess_shape(
76 | original_size[0], original_size[1], self.target_length
77 | )
78 | coords = deepcopy(coords).to(torch.float)
79 | coords[..., 0] = coords[..., 0] * (new_w / old_w)
80 | coords[..., 1] = coords[..., 1] * (new_h / old_h)
81 | return coords
82 |
83 | def apply_boxes_torch(
84 | self, boxes: torch.Tensor, original_size: Tuple[int, ...]
85 | ) -> torch.Tensor:
86 | """
87 | Expects a torch tensor with shape Bx4. Requires the original image
88 | size in (H, W) format.
89 | """
90 | boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size)
91 | return boxes.reshape(-1, 4)
92 |
93 | @staticmethod
94 | def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
95 | """
96 | Compute the output size given input size and target long side length.
97 | """
98 | scale = long_side_length * 1.0 / max(oldh, oldw)
99 | newh, neww = oldh * scale, oldw * scale
100 | neww = int(neww + 0.5)
101 | newh = int(newh + 0.5)
102 | return (newh, neww)
103 |
--------------------------------------------------------------------------------
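
A minimal sketch (not part of this repository) of the ResizeLongestSide helper above; the image size and point are made up for illustration:

import numpy as np
from segment_anything.utils.transforms import ResizeLongestSide

transform = ResizeLongestSide(target_length=1024)

image = np.zeros((480, 640, 3), dtype=np.uint8)          # HxWxC uint8
resized = transform.apply_image(image)                    # longest side -> 1024
print(resized.shape)                                       # (768, 1024, 3)

points = np.array([[320.0, 240.0]])                        # (x, y) in original coords
print(transform.apply_coords(points, image.shape[:2]))     # [[512. 384.]]
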
/train_pointsam.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import argparse
4 | import random
5 | from abc import ABC
6 |
7 | import cv2
8 | import numpy as np
9 | import torch
10 | import torch.nn.functional as F
11 | from torch.utils.data import DataLoader
12 |
13 | import lightning as L
14 | from lightning.fabric.loggers import TensorBoardLogger
15 | from lightning.fabric.fabric import _FabricOptimizer
16 |
17 | from box import Box
18 | from datasets import call_load_dataset
19 | from utils.model import Model
20 | from utils.losses import DiceLoss, FocalLoss, Matching_Loss
21 | from utils.eval_utils import AverageMeter, validate, get_prompts, calc_iou
22 | from utils.tools import copy_model, create_csv, reduce_instances
23 | from utils.utils import *
24 | from utils.finch import FINCH
25 |
26 | vis = False
27 |
28 | def train_sam(
29 | cfg: Box,
30 | fabric: L.Fabric,
31 | model: Model,
32 | optimizer: _FabricOptimizer,
33 | scheduler: _FabricOptimizer,
34 | train_dataloader: DataLoader,
35 | val_dataloader: DataLoader,
36 | target_pts,
37 | ):
38 |
39 | focal_loss = FocalLoss()
40 | dice_loss = DiceLoss()
41 | max_iou = 0.
42 | mem_bank = Store(1, cfg.mem_bank_max_len)
43 | match_interval = cfg.match_interval
44 | for epoch in range(1, cfg.num_epochs + 1):
45 | batch_time = AverageMeter()
46 | data_time = AverageMeter()
47 | focal_losses = AverageMeter()
48 | dice_losses = AverageMeter()
49 | iou_losses = AverageMeter()
50 | total_losses = AverageMeter()
51 | match_losses = AverageMeter()
52 | end = time.time()
53 | num_iter = len(train_dataloader)
54 |
55 | for iter, data in enumerate(train_dataloader):
56 |
57 | data_time.update(time.time() - end)
58 | images_weak, images_strong, bboxes, gt_masks, img_paths = data
59 | del data
60 |
61 | batch_size = images_weak.size(0)
62 | num_insts = sum(len(gt_mask) for gt_mask in gt_masks)
63 | if num_insts > cfg.max_nums:
64 | bboxes, gt_masks = reduce_instances(bboxes, gt_masks, cfg.max_nums)
65 | prompts = get_prompts(cfg, bboxes, gt_masks)
66 |
67 | # 1. Calculate pairwise IoUs of the masks
68 | mask_ious, init_masks = cal_mask_ious(cfg, model, images_weak, prompts, gt_masks)
69 |
70 | # 2. Get new prompts through neg_prompt_calibration
71 | new_prompts = neg_prompt_calibration(cfg, mask_ious, prompts)
72 |
73 | # 3. Start training using the new prompts
74 | soft_image_embeds, soft_masks, _, _ = model(images_weak, new_prompts) # teacher
75 |
76 | if isinstance(soft_image_embeds, dict):
77 | soft_image_embeds = soft_image_embeds['vision_features']
78 |
79 | _, pred_masks, iou_predictions, _= model(images_strong, prompts) # student
80 |
81 | del _
82 |
83 | num_masks = sum(len(pred_mask) for pred_mask in pred_masks)
84 | loss_focal = torch.tensor(0., device=fabric.device)
85 | loss_dice = torch.tensor(0., device=fabric.device)
86 | loss_match = torch.tensor(0., device=fabric.device)
87 | loss_iou = torch.tensor(0., device=fabric.device)
88 |
89 | for i, (embed, pred_mask, soft_mask, gt_mask, prompt, iou_prediction) in enumerate(zip(soft_image_embeds, pred_masks, soft_masks, gt_masks, prompts, iou_predictions)):
90 |
91 | soft_mask = (soft_mask > 0.).float()
92 | pred_feats = generate_predict_feats(cfg, embed, soft_mask, prompt)
93 | target_pts_ = target_pts['target_pts']
94 | pred_feats = pred_feats.cpu().tolist()
95 | for pred_feat in pred_feats:
96 | mem_bank.add([[pred_feat]], [0])
97 | if len(mem_bank.retrieve(0)) >= cfg.mem_bank_max_len and (iter + 1) % match_interval == 0:
98 | pred_feats = mem_bank.retrieve(0)
99 | pred_feats = np.array(pred_feats)
100 |
101 | #FINCH
102 | fin = FINCH(verbose=False)
103 | results = fin.fit(pred_feats)
104 | last_key = list(results.partitions.keys())[-1]
105 | pred_pts = results.partitions[last_key]['cluster_centers']
106 | loss_match += Matching_Loss(pred_pts, target_pts_, device = fabric.device)
107 |
108 | del embed
109 |
110 | if vis:
111 | img_name = os.path.basename(img_paths[i]).split('.')[0]
112 |
113 | image_weak = images_weak[0].permute(1, 2, 0).cpu().numpy() * 255
114 | image_weak = cv2.cvtColor(image_weak, cv2.COLOR_BGR2RGB)
115 |
116 | if vis:
117 | for j in range(len(soft_mask)):
118 | mask_iou = torch.max(mask_ious[j])
119 | image_weak_ = image_weak.copy()
120 | mask_area = torch.sum(gt_mask[j])
121 |
122 | gt_mask_np = gt_mask[j].cpu().numpy() * 255
123 | gt_mask_img = cv2.cvtColor(gt_mask_np, cv2.COLOR_GRAY2RGB)
124 |
125 | init_prompt_po = prompts[0][0][j][:cfg.num_points]
126 | init_prompt_ne = prompts[0][0][j][cfg.num_points:]
127 |
128 | for po in init_prompt_po:
129 | cv2.circle(image_weak_, (int(po[0]), int(po[1])), 12, (0, 0, 255), -1)
130 |
131 | init_mask_img = init_masks[j].cpu().detach().numpy() * 255
132 | init_mask_img = cv2.cvtColor(init_mask_img, cv2.COLOR_GRAY2RGB)
133 | for po,ne in zip(init_prompt_po,init_prompt_ne):
134 | cv2.circle(init_mask_img, (int(po[0]), int(po[1])), 12, (0, 0, 255), -1)
135 | cv2.circle(init_mask_img, (int(ne[0]), int(ne[1])), 12, (0, 255, 0), -1)
136 |
137 | prompt_po = new_prompts[0][0][j][:cfg.num_points]
138 | prompt_ne = new_prompts[0][0][j][cfg.num_points:]
139 | soft_mask_img = soft_mask[j].cpu().detach().numpy() * 255
140 |
141 | soft_mask_img = cv2.cvtColor(soft_mask_img, cv2.COLOR_GRAY2RGB)
142 |
143 | for po,ne in zip(prompt_po,prompt_ne):
144 | cv2.circle(soft_mask_img, (int(po[0]), int(po[1])), 12, (0, 0, 255), -1)
145 | cv2.circle(soft_mask_img, (int(ne[0]), int(ne[1])), 12, (0, 255, 0), -1)
146 |
147 | output_dir = "./save_mask_{}/{}/{}/".format(str(float(cfg.num_points)),cfg.dataset,str(epoch))
148 | if not os.path.exists(output_dir):
149 | os.makedirs(output_dir)
150 | merged_image = concatenate_images_with_padding([image_weak_, gt_mask_img, init_mask_img, soft_mask_img])
151 | img_name_ = '{}_{}_iou{}.jpg'.format(img_name,str(j),str(mask_iou))
152 | if mask_iou > float(cfg.iou_thr) and mask_area > 3000:
153 | cv2.imwrite(os.path.join(output_dir,img_name_), merged_image)
154 | del init_masks, mask_ious
155 |
156 | loss_focal += focal_loss(pred_mask, soft_mask, num_masks)
157 | loss_dice += dice_loss(pred_mask, soft_mask, num_masks)
158 | batch_iou = calc_iou(pred_mask, soft_mask)
159 | loss_iou += F.mse_loss(iou_prediction, batch_iou, reduction='sum') / num_masks
160 |
161 | del soft_image_embeds, pred_masks, iou_predictions, gt_masks
162 |
163 | loss_total = 20. * loss_focal + loss_dice + loss_iou + 0.1*loss_match
164 |
165 | fabric.backward(loss_total)
166 |
167 | optimizer.step()
168 | scheduler.step()
169 | optimizer.zero_grad()
170 | torch.cuda.empty_cache()
171 |
172 | batch_time.update(time.time() - end)
173 | end = time.time()
174 |
175 | focal_losses.update(loss_focal.item(), batch_size)
176 | dice_losses.update(loss_dice.item(), batch_size)
177 | iou_losses.update(loss_iou.item(), batch_size)
178 | total_losses.update(loss_total.item(), batch_size)
179 | match_losses.update(loss_match.item(), batch_size)
180 |
181 | if (iter + 1) % match_interval == 0:
182 | fabric.print(f'Epoch: [{epoch}][{iter + 1}/{len(train_dataloader)}]'
183 | f' | Time [{batch_time.val:.3f}s ({batch_time.avg:.3f}s)]'
184 | f' | Data [{data_time.val:.3f}s ({data_time.avg:.3f}s)]'
185 | f' | Focal Loss [{focal_losses.val:.4f} ({focal_losses.avg:.4f})]'
186 | f' | Dice Loss [{dice_losses.val:.4f} ({dice_losses.avg:.4f})]'
187 | f' | IoU Loss [{iou_losses.val:.4f} ({iou_losses.avg:.4f})]'
188 | f' | Match Loss [{match_losses.val:.4f} ({match_losses.avg:.4f})]'
189 | f' | Total Loss [{total_losses.val:.4f} ({total_losses.avg:.4f})]')
190 |
191 | # loss_logger = {
192 | # "Focal Loss": focal_losses.avg,
193 | # "Dice Loss": dice_losses.avg,
194 | # "IoU Loss": iou_losses.avg,
195 | # "Total Loss": total_losses.avg
196 | # }
197 | # fabric.log_dict(loss_logger, num_iter * (epoch - 1) + iter)
198 | torch.cuda.empty_cache()
199 |
200 | if epoch % cfg.eval_interval == 0:
201 | iou, _= validate(fabric, cfg, model, val_dataloader, cfg.name, epoch)
202 | if iou > max_iou:
203 | state = {"model": model, "optimizer": optimizer}
204 | fabric.save(os.path.join(cfg.out_dir, "save", "best-ckpt.pth"), state)
205 | max_iou = iou
206 | del iou
207 |
208 | def configure_opt(cfg: Box, model: Model):
209 |
210 | def lr_lambda(step):
211 | if step < cfg.opt.warmup_steps:
212 | return step / cfg.opt.warmup_steps
213 | elif step < cfg.opt.steps[0]:
214 | return 1.0
215 | elif step < cfg.opt.steps[1]:
216 | return 1 / cfg.opt.decay_factor
217 | else:
218 | return 1 / (cfg.opt.decay_factor**2)
219 |
220 | optimizer = torch.optim.Adam(model.model.parameters(), lr=cfg.opt.learning_rate, weight_decay=cfg.opt.weight_decay)
221 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
222 |
223 | return optimizer, scheduler
224 |
225 |
226 | def corrupt_main(cfg):
227 | for corrupt in cfg.corruptions:
228 | cfg.corrupt = corrupt
229 | cfg.out_name = corrupt
230 | torch.cuda.empty_cache()
231 | main(cfg)
232 |
233 |
234 | def main(cfg: Box) -> None:
235 | gpu_ids = [str(i) for i in range(torch.cuda.device_count())]
236 | num_devices = len(gpu_ids)
237 | fabric = L.Fabric(accelerator="auto",
238 | devices=num_devices,
239 | strategy="auto",
240 | loggers=[TensorBoardLogger(cfg.out_dir)])
241 | fabric.launch()
242 | fabric.seed_everything(1337 + fabric.global_rank)
243 |
244 | if fabric.global_rank == 0:
245 | os.makedirs(os.path.join(cfg.out_dir, "save"), exist_ok=True)
246 | create_csv(os.path.join(cfg.out_dir, "metrics.csv"), csv_head=cfg.csv_keys)
247 |
248 | with fabric.device:
249 | model = Model(cfg)
250 | model.setup()
251 |
252 | load_datasets = call_load_dataset(cfg)
253 | train_data, val_data, pt_data = load_datasets(cfg, img_size=1024, return_pt = True)
254 | train_data = fabric._setup_dataloader(train_data)
255 | val_data = fabric._setup_dataloader(val_data)
256 | pt_data = fabric._setup_dataloader(pt_data)
257 | optimizer, scheduler = configure_opt(cfg, model)
258 | model, optimizer = fabric.setup(model, optimizer)
259 |
260 | if cfg.resume and cfg.model.ckpt is not None:
261 | full_checkpoint = fabric.load(cfg.model.ckpt)
262 | model.load_state_dict(full_checkpoint["model"])
263 | optimizer.load_state_dict(full_checkpoint["optimizer"])
264 | print('-'*100)
265 | print('\033[92mDirect test on the original SAM.\033[0m')
266 | _, _ = validate(fabric, cfg, model, val_data, name=cfg.name, epoch=0)
267 | print('-'*100)
268 | del _
269 |
270 | target_pts = offline_prototypes_generation(cfg, model, pt_data)
271 |
272 | train_sam(cfg, fabric, model, optimizer, scheduler, train_data, val_data, target_pts)
273 |
274 | del model, train_data, val_data
275 |
276 |
277 | def parse_args():
278 | parser = argparse.ArgumentParser(description='Train a detector')
279 | parser.add_argument('--cfg', help='train config file path')
280 | parser.add_argument('--prompt', help='the type of prompt')
281 | parser.add_argument('--num_points',type=int, help='the number of points')
282 | parser.add_argument('--out_dir', help='the dir to save logs and models')
283 | parser.add_argument('--load_type', help='the type of model loading')
284 | args = parser.parse_args()
285 | return args
286 |
287 | if __name__ == "__main__":
288 | print(torch.cuda.current_device())
289 | torch.cuda.empty_cache()
290 | torch.set_float32_matmul_precision('high')
291 | args = parse_args()
292 |
293 | exec(f'from {args.cfg} import cfg')
294 |
295 | # transfer the args to a dict
296 | args_dict = vars(args)
297 | cfg.merge_update(args_dict)
298 |
299 | main(cfg)
300 | torch.cuda.empty_cache()
301 |
--------------------------------------------------------------------------------
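
The configure_opt schedule above combines a linear warmup with two step decays. A self-contained sketch of the resulting LR multipliers (the warmup_steps / steps / decay_factor numbers are illustrative, not the repository's config values):

import torch

warmup_steps, steps, decay_factor = 250, (10_000, 20_000), 10.0

def lr_lambda(step):
    if step < warmup_steps:
        return step / warmup_steps        # linear warmup towards the base LR
    elif step < steps[0]:
        return 1.0                        # full base LR
    elif step < steps[1]:
        return 1 / decay_factor           # first step decay
    else:
        return 1 / (decay_factor ** 2)    # second step decay

optimizer = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=1e-4)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

for step in (1, 125, 250, 5_000, 15_000, 25_000):
    print(step, lr_lambda(step))          # 0.004, 0.5, 1.0, 1.0, 0.1, 0.01
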
/train_selftrain.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from box import Box
7 | import lightning as L
8 | from lightning.fabric.fabric import _FabricOptimizer
9 | from lightning.fabric.loggers import TensorBoardLogger, CSVLogger
10 | from torch.utils.data import DataLoader
11 | import cv2
12 | import argparse
13 | from datasets import call_load_dataset
14 | from utils.losses import DiceLoss, FocalLoss, ContraLoss
15 | from utils.model import Model
16 | from utils.eval_utils import AverageMeter, calc_iou, validate, get_prompts, calc_iou_matrix
17 | from utils.tools import copy_model, create_csv, check_grad, momentum_update, reduce_instances
18 |
19 |
20 |
21 | def train_sam(
22 | cfg: Box,
23 | fabric: L.Fabric,
24 | model: Model,
25 | teacher_model: Model,
26 | optimizer: _FabricOptimizer,
27 | scheduler: _FabricOptimizer,
28 | train_dataloader: DataLoader,
29 | val_dataloader: DataLoader,
30 | ):
31 | """The SAM training loop."""
32 |
33 | focal_loss = FocalLoss()
34 | dice_loss = DiceLoss()
35 | max_iou = 0.
36 | max_camo = 0.
37 | max_chame = 0.
38 | max_cod = 0.
39 |
40 | for epoch in range(1, cfg.num_epochs + 1):
41 | batch_time = AverageMeter()
42 | data_time = AverageMeter()
43 | focal_losses = AverageMeter()
44 | dice_losses = AverageMeter()
45 | iou_losses = AverageMeter()
46 | total_losses = AverageMeter()
47 | end = time.time()
48 | num_iter = len(train_dataloader)
49 |
50 | for iter, data in enumerate(train_dataloader):
51 |
52 | data_time.update(time.time() - end)
53 | images_weak, images_strong, bboxes, gt_masks, img_paths = data
54 | batch_size = images_weak.size(0)
55 | num_insts = sum(len(gt_mask) for gt_mask in gt_masks)
56 | if num_insts > 100:
57 | continue
58 | if num_insts > cfg.max_nums:
59 | bboxes, gt_masks = reduce_instances(bboxes, gt_masks, cfg.max_nums)
60 |
61 | prompts = get_prompts(cfg, bboxes, gt_masks)
62 |
63 | soft_image_embeds, soft_masks, soft_iou_predictions, soft_res_masks = teacher_model(images_weak, prompts) # teacher
64 | pred_image_embeds, pred_masks, iou_predictions, pred_res_masks = model(images_strong, prompts) # student
65 |
66 | num_masks = sum(len(pred_mask) for pred_mask in pred_masks)
67 | loss_focal = torch.tensor(0., device=fabric.device)
68 | loss_dice = torch.tensor(0., device=fabric.device)
69 | loss_iou = torch.tensor(0., device=fabric.device)
70 | for pred_mask, soft_mask, iou_prediction in zip(pred_masks, soft_masks, iou_predictions):
71 | soft_mask = (soft_mask > 0.).float()
72 |
73 | loss_focal += focal_loss(pred_mask, soft_mask, num_masks)
74 | loss_dice += dice_loss(pred_mask, soft_mask, num_masks)
75 | batch_iou = calc_iou(pred_mask, soft_mask)
76 | loss_iou += F.mse_loss(iou_prediction, batch_iou, reduction='sum') / num_masks
77 |
78 | loss_total = 20. * loss_focal + loss_dice + loss_iou
79 | fabric.backward(loss_total)
80 |
81 | # if iter + 1 % 5 == 0:
82 | optimizer.step()
83 | scheduler.step()
84 | optimizer.zero_grad()
85 | torch.cuda.empty_cache()
86 |
87 | batch_time.update(time.time() - end)
88 | end = time.time()
89 |
90 | focal_losses.update(loss_focal.item(), batch_size)
91 | dice_losses.update(loss_dice.item(), batch_size)
92 | iou_losses.update(loss_iou.item(), batch_size)
93 | total_losses.update(loss_total.item(), batch_size)
94 | if iter % 50 == 0:
95 | fabric.print(f'Epoch: [{epoch}][{iter + 1}/{len(train_dataloader)}]'
96 | f' | Time [{batch_time.val:.3f}s ({batch_time.avg:.3f}s)]'
97 | f' | Data [{data_time.val:.3f}s ({data_time.avg:.3f}s)]'
98 | f' | Focal Loss [{focal_losses.val:.4f} ({focal_losses.avg:.4f})]'
99 | f' | Dice Loss [{dice_losses.val:.4f} ({dice_losses.avg:.4f})]'
100 | f' | IoU Loss [{iou_losses.val:.4f} ({iou_losses.avg:.4f})]'
101 | f' | Total Loss [{total_losses.val:.4f} ({total_losses.avg:.4f})]')
102 |
103 | # loss_logger = {
104 | # "Focal Loss": focal_losses.avg,
105 | # "Dice Loss": dice_losses.avg,
106 | # "IoU Loss": iou_losses.avg,
107 | # "Total Loss": total_losses.avg
108 | # }
109 | # fabric.log_dict(loss_logger, num_iter * (epoch - 1) + iter)
110 | torch.cuda.empty_cache()
111 |
112 | if epoch % cfg.eval_interval == 0:
113 | iou, f1_score = validate(fabric, cfg, model, val_dataloader, cfg.name, epoch)
114 | if iou > max_iou:
115 | state = {"model": model, "optimizer": optimizer}
116 | fabric.save(os.path.join(cfg.out_dir, "save", "last-ckpt.pth"), state)
117 | max_iou = iou
118 |
119 | def configure_opt(cfg: Box, model: Model):
120 |
121 | def lr_lambda(step):
122 | if step < cfg.opt.warmup_steps:
123 | return step / cfg.opt.warmup_steps
124 | elif step < cfg.opt.steps[0]:
125 | return 1.0
126 | elif step < cfg.opt.steps[1]:
127 | return 1 / cfg.opt.decay_factor
128 | else:
129 | return 1 / (cfg.opt.decay_factor**2)
130 |
131 | optimizer = torch.optim.Adam(model.model.parameters(), lr=cfg.opt.learning_rate, weight_decay=cfg.opt.weight_decay)
132 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
133 |
134 | return optimizer, scheduler
135 |
136 |
137 | def corrupt_main(cfg):
138 | for corrupt in cfg.corruptions:
139 | cfg.corrupt = corrupt
140 | cfg.out_name = corrupt
141 | torch.cuda.empty_cache()
142 | main(cfg)
143 |
144 |
145 | def main(cfg: Box) -> None:
146 | gpu_ids = [str(i) for i in range(torch.cuda.device_count())]
147 | num_devices = len(gpu_ids)
148 | fabric = L.Fabric(accelerator="auto",
149 | devices=num_devices,
150 | strategy="auto",
151 | loggers=[TensorBoardLogger(cfg.out_dir)])
152 | fabric.launch()
153 | fabric.seed_everything(1337 + fabric.global_rank)
154 |
155 | if fabric.global_rank == 0:
156 | os.makedirs(os.path.join(cfg.out_dir, "save"), exist_ok=True)
157 | create_csv(os.path.join(cfg.out_dir, "metrics.csv"), csv_head=cfg.csv_keys)
158 |
159 | with fabric.device:
160 | model = Model(cfg)
161 | model.setup()
162 |
163 | load_datasets = call_load_dataset(cfg)
164 | train_data, val_data = load_datasets(cfg, 1024)
165 | train_data = fabric._setup_dataloader(train_data)
166 | val_data = fabric._setup_dataloader(val_data)
167 | optimizer, scheduler = configure_opt(cfg, model)
168 | model, optimizer = fabric.setup(model, optimizer)
169 |
170 | if cfg.resume and cfg.model.ckpt is not None:
171 | full_checkpoint = fabric.load(cfg.model.ckpt)
172 | model.load_state_dict(full_checkpoint["model"])
173 | optimizer.load_state_dict(full_checkpoint["optimizer"])
174 |
175 | anchor_model = copy_model(model)
176 | print('-'*100)
177 | print('\033[92mDirect test on the original SAM.\033[0m')
178 | print('-'*100)
179 | validate(fabric, cfg, anchor_model, val_data, name=cfg.name, epoch=0)
180 | train_sam(cfg, fabric, model, anchor_model, optimizer, scheduler, train_data, val_data)
181 |
182 | del model, anchor_model, train_data, val_data
183 |
184 |
185 | def parse_args():
186 | parser = argparse.ArgumentParser(description='Train a detector')
187 | parser.add_argument('--cfg', help='train config file path')
188 | parser.add_argument('--prompt', help='the type of prompt')
189 | parser.add_argument('--num_points',type=int, help='the number of points')
190 | parser.add_argument('--out_dir', help='the dir to save logs and models')
191 | parser.add_argument('--load_type', help='the type of model loading')
192 | args = parser.parse_args()
193 | return args
194 |
195 | if __name__ == "__main__":
196 | print(torch.cuda.current_device())
197 | torch.cuda.empty_cache()
198 | torch.set_float32_matmul_precision('high')
199 | args = parse_args()
200 |
201 | exec(f'from {args.cfg} import cfg')
202 |
203 | # transfer the args to a dict
204 | args_dict = vars(args)
205 | cfg.merge_update(args_dict)
206 |
207 | main(cfg)
208 | torch.cuda.empty_cache()
209 |
--------------------------------------------------------------------------------
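
In the loop above the teacher's weak-view outputs are binarised at 0 to form pseudo-labels, and the student trained on the strong view is supervised with 20*focal + dice + IoU regression. A toy, self-contained sketch of the pseudo-label and IoU-regression pieces (illustrative shapes, not the repository's loss classes):

import torch
import torch.nn.functional as F

# Toy tensors: 2 instances with 4x4 masks (illustrative shapes only).
teacher_mask = torch.randn(2, 4, 4)      # teacher (weak view) mask outputs
student_mask = torch.randn(2, 4, 4)      # student (strong view) mask outputs
iou_prediction = torch.rand(2, 1)        # student's predicted IoU per instance

# Teacher outputs -> hard pseudo-labels, as in soft_mask = (soft_mask > 0.).float().
pseudo_mask = (teacher_mask > 0.).float()

# Measured IoU between thresholded student masks and pseudo-labels (mirrors calc_iou).
student_bin = (student_mask >= 0.5).float()
inter = (student_bin * pseudo_mask).sum(dim=(1, 2))
union = student_bin.sum(dim=(1, 2)) + pseudo_mask.sum(dim=(1, 2)) - inter
batch_iou = (inter / (union + 1e-7)).unsqueeze(1)

# The IoU head is regressed towards the measured IoU, summed then averaged over masks.
num_masks = student_mask.size(0)
loss_iou = F.mse_loss(iou_prediction, batch_iou, reduction='sum') / num_masks
print(loss_iou)
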
/train_supervise.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import torch
4 | import lightning as L
5 | import torch.nn.functional as F
6 | import segmentation_models_pytorch as smp
7 | from box import Box
8 | from lightning.fabric.fabric import _FabricOptimizer
9 | from lightning.fabric.loggers import TensorBoardLogger, CSVLogger
10 | from torch.utils.data import DataLoader
11 | import cv2
12 | import argparse
13 |
14 | from datasets import call_load_dataset
15 | from utils.losses import DiceLoss, FocalLoss, ContraLoss
16 | from utils.model import Model
17 | from utils.eval_utils import AverageMeter, calc_iou, validate, get_prompts
18 | from utils.tools import copy_model, create_csv, check_grad, momentum_update, reduce_instances
19 |
20 | def train_sam(
21 | cfg: Box,
22 | fabric: L.Fabric,
23 | model: Model,
24 | optimizer: _FabricOptimizer,
25 | scheduler: _FabricOptimizer,
26 | train_dataloader: DataLoader,
27 | val_dataloader: DataLoader,
28 | ):
29 | """The SAM training loop."""
30 |
31 | focal_loss = FocalLoss()
32 | dice_loss = DiceLoss()
33 | max_iou = 0.
34 |
35 | for epoch in range(1, cfg.num_epochs + 1):
36 | batch_time = AverageMeter()
37 | data_time = AverageMeter()
38 | focal_losses = AverageMeter()
39 | dice_losses = AverageMeter()
40 | iou_losses = AverageMeter()
41 | total_losses = AverageMeter()
42 | end = time.time()
43 | num_iter = len(train_dataloader)
44 | for iter, data in enumerate(train_dataloader):
45 |
46 | data_time.update(time.time() - end)
47 | images, bboxes, gt_masks, img_paths = data
48 |
49 | batch_size = images.size(0)
50 | num_insts = sum(len(gt_mask) for gt_mask in gt_masks)
51 | if num_insts > cfg.max_nums:
52 | bboxes, gt_masks = reduce_instances(bboxes, gt_masks, cfg.max_nums)
53 |
54 | prompts = get_prompts(cfg, bboxes, gt_masks)
55 |
56 | _, pred_masks, iou_predictions, _ = model(images, prompts)
57 | num_masks = sum(len(pred_mask) for pred_mask in pred_masks)
58 | loss_focal = torch.tensor(0., device=fabric.device)
59 | loss_dice = torch.tensor(0., device=fabric.device)
60 | loss_iou = torch.tensor(0., device=fabric.device)
61 | for i,(pred_mask, gt_mask, iou_prediction) in enumerate(zip(pred_masks, gt_masks, iou_predictions)):
62 | gt_mask = gt_mask.to(device=fabric.device)
63 | batch_iou = calc_iou(pred_mask, gt_mask)
64 | loss_focal += focal_loss(pred_mask, gt_mask, num_masks)
65 | loss_dice += dice_loss(pred_mask, gt_mask, num_masks)
66 | loss_iou += F.mse_loss(iou_prediction, batch_iou, reduction='sum') / num_masks
67 |
68 | loss_total = 20. * loss_focal + loss_dice + loss_iou
69 | optimizer.zero_grad()
70 | fabric.backward(loss_total)
71 | optimizer.step()
72 | scheduler.step()
73 | batch_time.update(time.time() - end)
74 | end = time.time()
75 |
76 | focal_losses.update(loss_focal.item(), batch_size)
77 | dice_losses.update(loss_dice.item(), batch_size)
78 | iou_losses.update(loss_iou.item(), batch_size)
79 | total_losses.update(loss_total.item(), batch_size)
80 | if iter % 50 == 0:
81 | fabric.print(f'Epoch: [{epoch}][{iter+1}/{len(train_dataloader)}]'
82 | f' | Time [{batch_time.val:.3f}s ({batch_time.avg:.3f}s)]'
83 | f' | Data [{data_time.val:.3f}s ({data_time.avg:.3f}s)]'
84 | f' | Focal Loss [{focal_losses.val:.4f} ({focal_losses.avg:.4f})]'
85 | f' | Dice Loss [{dice_losses.val:.4f} ({dice_losses.avg:.4f})]'
86 | f' | IoU Loss [{iou_losses.val:.4f} ({iou_losses.avg:.4f})]'
87 | f' | Total Loss [{total_losses.val:.4f} ({total_losses.avg:.4f})]')
88 |
89 | if epoch % cfg.eval_interval == 0:
90 | iou, f1_score = validate(fabric, cfg, model, val_dataloader, cfg.name, epoch)
91 | if iou > max_iou:
92 | state = {"model": model, "optimizer": optimizer}
93 | fabric.save(os.path.join(cfg.out_dir, "save", "last-ckpt.pth"), state)
94 | max_iou = iou
95 |
96 | def configure_opt(cfg: Box, model: Model):
97 |
98 | def lr_lambda(step):
99 | if step < cfg.opt.warmup_steps:
100 | return step / cfg.opt.warmup_steps
101 | elif step < cfg.opt.steps[0]:
102 | return 1.0
103 | elif step < cfg.opt.steps[1]:
104 | return 1 / cfg.opt.decay_factor
105 | else:
106 | return 1 / (cfg.opt.decay_factor**2)
107 |
108 | optimizer = torch.optim.Adam(model.model.parameters(), lr=cfg.opt.learning_rate, weight_decay=cfg.opt.weight_decay)
109 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
110 |
111 | return optimizer, scheduler
112 |
113 |
114 | def corrupt_main(cfg):
115 | for corrupt in cfg.corruptions:
116 | cfg.corrupt = corrupt
117 | cfg.out_name = corrupt
118 | torch.cuda.empty_cache()
119 | main(cfg)
120 |
121 |
122 | def main(cfg: Box) -> None:
123 | gpu_ids = [str(i) for i in range(torch.cuda.device_count())]
124 | num_devices = len(gpu_ids)
125 | fabric = L.Fabric(accelerator="auto",
126 | devices=num_devices,
127 | strategy="auto",
128 | loggers=[TensorBoardLogger(cfg.out_dir)])
129 | fabric.launch()
130 | fabric.seed_everything(1337 + fabric.global_rank)
131 |
132 | if fabric.global_rank == 0:
133 | os.makedirs(os.path.join(cfg.out_dir, "save"), exist_ok=True)
134 | create_csv(os.path.join(cfg.out_dir, "metrics.csv"), csv_head=cfg.csv_keys)
135 |
136 | with fabric.device:
137 | model = Model(cfg)
138 | model.setup()
139 |
140 | load_datasets = call_load_dataset(cfg)
141 | train_data, val_data = load_datasets(cfg, 1024)
142 | train_data = fabric._setup_dataloader(train_data)
143 | val_data = fabric._setup_dataloader(val_data)
144 | optimizer, scheduler = configure_opt(cfg, model)
145 | model, optimizer = fabric.setup(model, optimizer)
146 |
147 | if cfg.resume and cfg.model.ckpt is not None:
148 | full_checkpoint = fabric.load(cfg.model.ckpt)
149 | model.load_state_dict(full_checkpoint["model"])
150 | optimizer.load_state_dict(full_checkpoint["optimizer"])
151 |
152 | anchor_model = copy_model(model)
153 | print('-'*100)
154 | print('\033[92mDirect test on the original SAM.\033[0m')
155 | print('-'*100)
156 | validate(fabric, cfg, anchor_model, val_data, name=cfg.name, epoch=0)
157 | train_sam(cfg, fabric, model, optimizer, scheduler, train_data, val_data)
158 |
159 | del model, anchor_model, train_data, val_data
160 |
161 |
162 | def parse_args():
163 | parser = argparse.ArgumentParser(description='Train a detector')
164 | parser.add_argument('--cfg', help='train config file path')
165 | parser.add_argument('--prompt', help='the type of prompt')
166 | parser.add_argument('--num_points',type=int, help='the number of points')
167 | parser.add_argument('--out_dir', help='the dir to save logs and models')
168 | parser.add_argument('--load_type', help='the type of model loading')
169 | args = parser.parse_args()
170 | return args
171 |
172 | if __name__ == "__main__":
173 | print(torch.cuda.current_device())
174 | torch.cuda.empty_cache()
175 | torch.set_float32_matmul_precision('high')
176 | args = parse_args()
177 |
178 | exec(f'from {args.cfg} import cfg')
179 |
180 | # transfer the args to a dict
181 | args_dict = vars(args)
182 | cfg.merge_update(args_dict)
183 |
184 | main(cfg)
185 | torch.cuda.empty_cache()
186 |
--------------------------------------------------------------------------------
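
All three entry points (train_pointsam.py, train_selftrain.py, train_supervise.py) load their configuration with exec(f'from {args.cfg} import cfg'), so --cfg must be a dotted, importable module path without the .py suffix that exposes an object named cfg. A hedged, exec-free equivalent (the module path and overrides are illustrative):

import importlib

# The repo ships configs/config_hrsid.py, config_nwpu.py and config_whu.py;
# any of them can be addressed as a dotted module path.
cfg = importlib.import_module("configs.config_whu").cfg

# Command-line arguments are then merged into the config, as in cfg.merge_update(vars(args)).
cfg.merge_update({"prompt": "point", "num_points": 1, "out_dir": "./exp/whu_supervise"})
print(cfg.prompt, cfg.num_points)
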
/utils/eval_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from tqdm import tqdm
4 | from scipy.ndimage import map_coordinates
5 |
6 | import torch
7 | import torch.nn.functional as F
8 | from torch.utils.data import DataLoader
9 |
10 | import lightning as L
11 |
12 | import segmentation_models_pytorch as smp
13 |
14 | from box import Box
15 | from utils.model import Model
16 | from utils.sample_utils import get_point_prompts
17 | from utils.tools import write_csv
18 |
19 |
20 | class AverageMeter:
21 | """Computes and stores the average and current value."""
22 |
23 | def __init__(self):
24 | self.reset()
25 |
26 | def reset(self):
27 | self.val = 0
28 | self.avg = 0
29 | self.sum = 0
30 | self.count = 0
31 |
32 | def update(self, val, n=1):
33 | self.val = val
34 | self.sum += val * n
35 | self.count += n
36 | self.avg = self.sum / self.count
37 |
38 |
39 | def calc_iou(pred_mask: torch.Tensor, gt_mask: torch.Tensor):
40 | pred_mask = (pred_mask >= 0.5).float()
41 | intersection = torch.sum(torch.mul(pred_mask, gt_mask), dim=(1, 2))
42 | union = torch.sum(pred_mask, dim=(1, 2)) + torch.sum(gt_mask, dim=(1, 2)) - intersection
43 | epsilon = 1e-7
44 | batch_iou = intersection / (union + epsilon)
45 |
46 | batch_iou = batch_iou.unsqueeze(1)
47 | return batch_iou
48 |
49 | def calc_iou_instance(pred_masks: torch.Tensor, gt_masks: torch.Tensor):
50 | iou_list = []
51 | for pred_mask, gt_mask in zip(pred_masks, gt_masks):
52 | pred_mask = (pred_mask >= 0.5).float()
53 | # print(pred_mask.shape)
54 | intersection = torch.sum(torch.mul(pred_mask, gt_mask), dim=(0, 1))
55 |
56 | union = torch.sum(pred_mask, dim=(0, 1)) + torch.sum(gt_mask, dim=(0, 1)) - intersection
57 | epsilon = 1e-7
58 | iou = intersection / (union + epsilon)
59 | # print(iou)
60 | # batch_iou = batch_iou.unsqueeze(1)
61 | iou_list.append(iou)
62 | return iou_list
63 |
64 |
65 | #intersection
66 | # mask1: mask2: intersection:
67 | # [[1, 1, 0, 0], [[1, 0, 0, 0], [[1, 0, 0, 0],
68 | # [1, 0, 0, 0], [1, 1, 0, 0], [1, 0, 0, 0],
69 | # [1, 1, 0, 0], [0, 1, 0, 0], [0, 1, 0, 0],
70 | # [0, 0, 0, 0]] [0, 0, 0, 0]] [0, 0, 0, 0]]
71 |
72 | #union
73 | # mask1: mask2: union:
74 | # [[1, 1, 0, 0], [[1, 0, 0, 0], [[1, 1, 0, 0],
75 | # [1, 0, 0, 0], [1, 1, 0, 0], [1, 1, 0, 0],
76 | # [1, 1, 0, 0], [0, 1, 0, 0], [1, 1, 0, 0],
77 | # [0, 0, 0, 0]] [0, 0, 0, 0]] [0, 0, 0, 0]]
78 |
79 |
80 | def calculate_iou(mask1, mask2):
81 |
82 | intersection = torch.logical_and(mask1, mask2)
83 | union = torch.logical_or(mask1, mask2)
84 | iou = torch.sum(intersection).float() / torch.sum(union).float()
85 | return iou
86 |
87 | def calc_iou_matrix(mask_list1, mask_list2):
88 |
89 | iou_matrix = torch.zeros((len(mask_list1), len(mask_list2)))
90 | for i, mask1 in enumerate(mask_list1):
91 | for j, mask2 in enumerate(mask_list2):
92 | iou_matrix[i, j] = calculate_iou(mask1, mask2)
93 | return iou_matrix
94 |
95 | def get_prompts(cfg: Box, bboxes, gt_masks):
96 | if cfg.prompt == "box" or cfg.prompt == "coarse":
97 | prompts = bboxes
98 | elif cfg.prompt == "point":
99 | prompts = get_point_prompts(gt_masks, cfg.num_points)
100 | else:
101 | raise ValueError("Prompt Type Error!")
102 | return prompts
103 |
104 | def validate(fabric: L.Fabric, cfg: Box, model: Model, val_dataloader: DataLoader, name: str, epoch: int = 0):
105 | model.eval()
106 | ious = AverageMeter()
107 | f1_scores = AverageMeter()
108 | recall = AverageMeter()
109 | precision = AverageMeter()
110 |
111 | with torch.no_grad():
112 | for iter, data in enumerate(tqdm(val_dataloader, desc='Validation', ncols=100)):
113 | images, bboxes, gt_masks, img_paths = data
114 | num_images = images.size(0)
115 | prompts = get_prompts(cfg, bboxes, gt_masks)
116 |
117 | _, pred_masks, _, _ = model(images, prompts)
118 |
119 | for pred_mask, gt_mask in zip(pred_masks, gt_masks):
120 | batch_stats = smp.metrics.get_stats(
121 | pred_mask,
122 | gt_mask.int(),
123 | mode='binary',
124 | threshold=0.5,
125 | )
126 | batch_recall = smp.metrics.recall(*batch_stats, reduction="micro-imagewise")
127 | batch_precision = smp.metrics.precision(*batch_stats, reduction="micro-imagewise")
128 | batch_iou = smp.metrics.iou_score(*batch_stats, reduction="micro-imagewise")
129 | batch_f1 = smp.metrics.f1_score(*batch_stats, reduction="micro-imagewise")
130 | ious.update(batch_iou, num_images)
131 | f1_scores.update(batch_f1, num_images)
132 | recall.update(batch_recall, num_images)
133 | precision.update(batch_precision, num_images)
134 |
135 | torch.cuda.empty_cache()
136 |
137 | fabric.print(
138 | f'Val: [{epoch}] - [{iter+1}/{len(val_dataloader)}]: IoU: [{ious.avg:.4f}] -- Recall: [{recall.avg:.4f}] -- Precision [{precision.avg:.4f}] -- F1: [{f1_scores.avg:.4f}]'
139 | )
140 | csv_dict = {"Prompt": cfg.prompt, "IoU": f"{ious.avg:.4f}","Recall": f"{recall.avg:.4f}", "Precision": f"{precision.avg:.4f}", "F1": f"{f1_scores.avg:.4f}", "epoch": epoch}
141 |
142 | if fabric.global_rank == 0:
143 | write_csv(os.path.join(cfg.out_dir, "metrics.csv"), csv_dict, csv_head=cfg.csv_keys)
144 | return ious.avg, f1_scores.avg
145 |
146 |
147 |
148 |
149 |
150 |
--------------------------------------------------------------------------------
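
A tiny, self-contained sketch (not part of the repository) of what calc_iou above computes: predictions are thresholded at 0.5 and IoU is taken per instance over the spatial dimensions:

import torch

pred_mask = torch.tensor([[[0.9, 0.2],
                           [0.7, 0.1]]])            # 1 instance, 2x2 predictions
gt_mask = torch.tensor([[[1.0, 0.0],
                         [0.0, 0.0]]])

pred_bin = (pred_mask >= 0.5).float()               # -> [[1, 0], [1, 0]]
inter = torch.sum(pred_bin * gt_mask, dim=(1, 2))   # 1
union = pred_bin.sum(dim=(1, 2)) + gt_mask.sum(dim=(1, 2)) - inter  # 2
print((inter / (union + 1e-7)).unsqueeze(1))        # tensor([[0.5000]])
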
/utils/finch.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | FINCH - First Integer Neighbor Clustering Hierarchy Algorithm
4 | """
5 |
6 | # Author: Eren Cakmak
7 | #
8 | # License: MIT
9 |
10 | import numpy as np
11 | from sklearn.neighbors import NearestNeighbors
12 | from scipy.sparse.csgraph import connected_components
13 | from sklearn.utils import check_array
14 | from sklearn.metrics import silhouette_score
15 |
16 |
17 | class FINCH():
18 | """
19 | A class to perform the FINCH clustering
20 |
21 | Read more in the paper; see the reference below.
22 |
23 | Parameters
24 | ----------
25 | metric : string, default='cosine'
26 | The used distance metric - more options are
27 | ‘braycurtis’, ‘canberra’, ‘chebyshev’, ‘cityblock’, ‘correlation’,
28 | ‘cosine’, ‘dice’, ‘euclidean’, ‘hamming’, ‘jaccard’, ‘jensenshannon’,
29 | ‘kulsinski’, ‘mahalanobis’, ‘matching’, ‘rogerstanimoto’, ‘sqeuclidean’,
30 | ‘russellrao’, ‘seuclidean’, ‘sokalmichener’, ‘sokalsneath’, ‘yule’.
31 |
32 | n_jobs : int or None, default=1
33 | The number of processes to start; -1 means use all processors
34 |
35 | Attributes
36 | ----------
37 | labels : array, shape = [n_samples]
38 | Cluster labels for the data
39 |
40 | partitions : dict
41 | All partitionings with their resulting labels, cluster centers and core indices
42 |
43 | References
44 | ----------
45 | Sarfraz, Saquib, Vivek Sharma, and Rainer Stiefelhagen.
46 | "Efficient parameter-free clustering using first neighbor relations."
47 | Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2019.
48 |
49 | """
50 | def __init__(self, metric='cosine', n_jobs=1, verbose=True):
51 | self.metric = metric
52 | self.n_jobs = n_jobs
53 | self.verbose = verbose
54 |
55 | def _finch(self, X, prev_clusters, prev_cluster_core_indices):
56 | """
57 | Compute the adjacency link matrix as described in the paper Eq.1
58 | Afterwards get connected components and their cluster centroids
59 | ----------
60 | X : ndarray of shape (n_samples, n_features)
61 | The input data samples that should be clustered.
62 |
63 | prev_clusters : list of ndarray of shape (2,)
64 | The cluster centroids of the previous partitioning.
65 |
66 | prev_cluster_core_indices : list
67 | The sample indices belonging to each cluster in prev_clusters
68 | Returns
69 | -------
70 | n_connected_components_ : int
71 | The number of clusters in the partitioning.
72 |
73 | labels_ : list
74 | Cluster labels for all data samples between 0 and n_connected_components_
75 |
76 | cluster_centers_ : list of ndarray of shape (2,)
77 | The cluster centroids of the partitioning.
78 |
79 | cluster_core_indices_ : list
80 | The samples belonging to the cluster of cluster_centers_
81 | """
82 |
83 | # Adjacency link matrix by Eq.1
84 | connectivity = None
85 |
86 | # compute the adjacency link matrix
87 | if not prev_clusters:
88 | # first partitioning
89 | data = X
90 | else:
91 | data = prev_clusters
92 |
93 | # Compute the adjacency link matrix as described in the paper Eq.1
94 | # NN in sklearn
95 | nbrs = NearestNeighbors(n_neighbors=2,
96 | metric=self.metric,
97 | n_jobs=self.n_jobs).fit(data)
98 |
99 | # condition j = k_i - link nearest neighbors
100 | connectivity = nbrs.kneighbors_graph(data)
101 |
102 | # condition k_i = k_j - link same first neighbors
103 | # the dot product forces symmetry, so the condition k_j = i is also covered
104 | connectivity @= connectivity.T
105 |
106 | # remove diagonal
107 | connectivity.setdiag(0)
108 | connectivity.eliminate_zeros()
109 |
110 | # TODO this could also be solved by computing a linkage matrix
111 | # and then just calling the method scipy.cluster.hierarchy.fcluster
112 | # This would probably increase the performance of the method further
113 | #
114 | # set values to one required for the linkage matrix
115 | # connectivity.data[:] = 1
116 |
117 | # get connected components
118 | n_connected_components_, labels_ = connected_components(
119 | csgraph=connectivity)
120 |
121 | # labels remap to previous cluster core indices
122 | # only applies from the second partitioning onwards
123 | if len(labels_) < self.n_samples:
124 | new_labels = np.full(self.n_samples, 0)
125 | for i in range(n_connected_components_):
126 | idx = np.where(labels_ == i)[0]
127 | idx = sum([prev_cluster_core_indices[j] for j in idx], [])
128 | new_labels[idx] = i
129 | labels_ = new_labels
130 |
131 | # list of centroids and sample indices for each cluster
132 | cluster_centers_ = []
133 | cluster_core_indices_ = []
134 |
135 | # compute cluster centers from the label indices
136 | for i in range(n_connected_components_):
137 | # update the cluster core indices
138 | idx = np.where(labels_ == i)[0]
139 | cluster_core_indices_.append(idx.tolist())
140 |
141 | # compute the cluster means
142 | xc_mean = X[idx].mean(axis=0)
143 | cluster_centers_.append(xc_mean)
144 |
145 | return n_connected_components_, labels_, cluster_centers_, cluster_core_indices_
146 |
147 | def fit(self, X):
148 | """
149 | Apply the FINCH algorithm
150 | ----------
151 | X : ndarray of shape (n_samples, n_features)
152 | The data samples that are clustered
153 |
154 | Returns
155 | -------
156 | self
157 | """
158 | # check if input is correct
159 | X = check_array(X)
160 |
161 | self.n_samples = X.shape[0]
162 |
163 | # the results of the partitioning
164 | results = {}
165 |
166 | # intermediate results
167 | cluster_centers_ = None
168 | cluster_core_indices_ = None
169 |
170 | n_connected_components_ = len(X)
171 | if self.verbose:
172 | print('FINCH Partitionings')
173 | print('-------------------')
174 |
175 | i = 0
176 | while n_connected_components_ > 1:
177 | n_connected_components_, labels_, cluster_centers_, cluster_core_indices_ = self._finch(
178 | X, cluster_centers_, cluster_core_indices_)
179 |
180 | if n_connected_components_ == 1:
181 | break
182 | else:
183 | if self.verbose:
184 | print('Clusters in %s partition: %d' %
185 | (i, n_connected_components_))
186 |
187 | results['parition_' + str(i)] = {
188 | 'n_clusters': n_connected_components_,
189 | 'labels': labels_,
190 | 'cluster_centers': cluster_centers_,
191 | 'cluster_core_indices': cluster_core_indices_
192 | }
193 | i += 1
194 |
195 | self.partitions = results
196 |
197 | return self
198 |
199 | def fit_predict(self, X, verbose):
200 | """
201 | Apply the FINCH algorithm and return the labels of a reasonable partitioning chosen by the silhouette coefficient
202 | ----------
203 | X : ndarray of shape (n_samples, n_features)
204 | The data samples that are clustered
205 |
206 | Returns
207 | -------
208 | labels : array of cluster labels for the selected partitioning
209 | """
210 | # check if input is correct
211 | X = check_array(X)
212 |
213 | self.n_samples = X.shape[0]
214 |
215 | # the results of the partitioning
216 | results = {}
217 |
218 | # intermediate results
219 | cluster_centers_ = None
220 | cluster_core_indices_ = None
221 |
222 | # track the maximum silhouette coefficient score
223 | max_sil_score = -1
224 |
225 | n_connected_components_ = len(X)
226 |
227 | print('FINCH Partitionings')
228 | print('-------------------')
229 |
230 | i = 0
231 | while n_connected_components_ > 1:
232 | n_connected_components_, labels_, cluster_centers_, cluster_core_indices_ = self._finch(
233 | X, cluster_centers_, cluster_core_indices_)
234 |
235 | if n_connected_components_ == 1:
236 | break
237 | else:
238 | # in this version the silhouette coefficient is computed
239 | sil_score = silhouette_score(X, labels_, metric=self.metric)
240 | # store the max silhouette coefficient
241 | # do not pick the first partitioning
242 | if max_sil_score <= sil_score and i != 0:
243 | best_labels = labels_
244 | max_sil_score = sil_score
245 |
246 | print(
247 | 'Clusters in %s partition: %d with average silhouette score %0.2f'
248 | % (i, n_connected_components_, sil_score))
249 |
250 | results['parition_' + str(i)] = {
251 | 'n_clusters': n_connected_components_,
252 | 'labels': labels_,
253 | 'cluster_centers': cluster_centers_,
254 | 'cluster_core_indices': cluster_core_indices_,
255 | 'silhouette_coefficient': sil_score
256 | }
257 | i += 1
258 |
259 | self.labels = best_labels
260 | self.partitions = results
261 |
262 | return self.labels
263 |
--------------------------------------------------------------------------------
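
A minimal usage sketch (not part of the repository) mirroring how train_pointsam.py consumes FINCH: fit on a feature matrix and read the cluster centers of the last stored partition. The toy features below are illustrative:

import numpy as np
from utils.finch import FINCH

rng = np.random.default_rng(0)
# Two directionally separated blobs, 32 samples each, 4-D features (toy data).
feats = np.vstack([
    rng.normal(loc=[5.0, 0.0, 0.0, 0.0], scale=0.1, size=(32, 4)),
    rng.normal(loc=[0.0, 5.0, 0.0, 0.0], scale=0.1, size=(32, 4)),
])

fin = FINCH(verbose=False)
results = fin.fit(feats)

last_key = list(results.partitions.keys())[-1]           # coarsest stored partition
centers = results.partitions[last_key]['cluster_centers']
print(len(centers))                                       # number of cluster centroids
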
/utils/losses.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from abc import ABC
6 | from scipy.optimize import linear_sum_assignment
7 |
8 | ALPHA = 0.8
9 | GAMMA = 2
10 |
11 | class FocalLoss(nn.Module):
12 |
13 | def __init__(self, weight=None, size_average=True):
14 | super().__init__()
15 |
16 | # def forward(self, inputs, targets, alpha=ALPHA, gamma=GAMMA, smooth=1):
17 | # inputs = F.sigmoid(inputs)
18 | # inputs = torch.clamp(inputs, min=0, max=1)
19 | # #flatten label and prediction tensors
20 | # inputs = inputs.view(-1)
21 | # targets = targets.view(-1)
22 |
23 | # BCE = F.binary_cross_entropy(inputs, targets, reduction='mean')
24 | # BCE_EXP = torch.exp(-BCE)
25 | # focal_loss = alpha * (1 - BCE_EXP)**gamma * BCE
26 |
27 | # return focal_loss
28 |
29 | def forward(self, inputs, targets, alpha=ALPHA, gamma=GAMMA, smooth=1):
30 | inputs = F.sigmoid(inputs)
31 | inputs = torch.clamp(inputs, min=0, max=1)
32 | #flatten label and prediction tensors
33 | inputs = inputs.view(-1)
34 | targets = targets.view(-1)
35 |
36 | BCE = F.binary_cross_entropy(inputs, targets, reduction='none')
37 | BCE_EXP = torch.exp(-BCE)
38 | focal_loss = alpha * (1 - BCE_EXP)**gamma * BCE
39 | focal_loss = focal_loss.mean()
40 |
41 | return focal_loss
42 |
43 |
44 | def dice_coefficient(x, target):
45 | eps = 1e-5
46 | n_inst = x.size(0)
47 | # x, target: [n_inst, H, W] per-instance masks
48 | x = x.reshape(n_inst, -1)
49 | target = target.reshape(n_inst, -1)
50 | intersection = (x * target).sum(dim=1)
51 | union = (x ** 2.0).sum(dim=1) + (target ** 2.0).sum(dim=1) + eps
52 | loss = 1. - (2 * intersection / union)
53 | return loss
54 |
55 | class DiceLoss(nn.Module):
56 |
57 | def __init__(self, weight=None, size_average=True):
58 | super().__init__()
59 |
60 | def forward(self, inputs, targets, smooth=1):
61 | inputs = F.sigmoid(inputs)
62 | inputs = torch.clamp(inputs, min=0, max=1)
63 | #flatten label and prediction tensors
64 | inputs = inputs.view(-1)
65 | targets = targets.view(-1)
66 |
67 | intersection = (inputs * targets).sum()
68 | dice = (2. * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)
69 |
70 | return 1 - dice
71 |
72 | class ContraLoss(nn.Module):
73 |
74 | def __init__(self, temperature = 0.3, weight=None, size_average=True):
75 | super().__init__()
76 | self.temperature = temperature
77 | self.criterion = torch.nn.CrossEntropyLoss()
78 |
79 | def forward(self, embedd_x: torch.Tensor, embedd_y: torch.Tensor, mask_x: torch.Tensor, mask_y: torch.Tensor):
80 | x_embedding = self.norm_embed(embedd_x) # embedd_x: [256, 64, 64]
81 | y_embedding = self.norm_embed(embedd_y)
82 |
83 | # mask_x = mask_x.float()
84 | # mask_y = mask_y.float()
85 | x_masks = F.interpolate(mask_x, size=x_embedding.shape[-2:], mode="bilinear", align_corners=False).detach()
86 | # x_masks = F.sigmoid(x_masks)
87 | # x_masks = torch.clamp(x_masks, min=0, max=1)
88 | # x_masks = (x_masks > 0.5).float()
89 | sum_x = x_masks.sum(dim=[-1, -2]).clone()
90 | # sum_x[sum_x[:, 0] == 0.] = 1.
91 |
92 | y_masks = F.interpolate(mask_y, size=y_embedding.shape[-2:], mode="bilinear", align_corners=False).detach()
93 | # y_masks = F.sigmoid(y_masks)
94 | # y_masks = torch.clamp(y_masks, min=0, max=1)
95 | # y_masks = (y_masks > 0.5).float()
96 | sum_y = y_masks.sum(dim=[-1, -2]).clone()
97 | # sum_y[sum_y[:, 0] == 0.] = 1.
98 | # [n, 1, H, W]
99 | multi_embedd_x = (x_embedding * x_masks).sum(dim=[-1, -2]) / sum_x # [n, 256, 64, 64] >> [n, 256]
100 | multi_embedd_y = (y_embedding * y_masks).sum(dim=[-1, -2]) / sum_y
101 |
102 | flatten_x = multi_embedd_x.view(multi_embedd_x.size(0), -1)
103 | flatten_y = multi_embedd_y.view(multi_embedd_y.size(0), -1)
104 | # similarity_matrix = torch.matmul(multi_embedd_x, multi_embedd_y.T)
105 | similarity_matrix = F.cosine_similarity(flatten_x.unsqueeze(1), flatten_y.unsqueeze(0), dim=2)
106 |
107 | label_pos = torch.eye(x_masks.size(0)).bool().to(embedd_x.device)
108 | label_nag = ~label_pos
109 |
110 | similarity_matrix = similarity_matrix / self.temperature # [n insts, n insts]
111 | loss = -torch.log(
112 | similarity_matrix.masked_select(label_pos).exp().sum() /
113 | similarity_matrix.exp().sum()
114 | )
115 | # loss = -torch.log(
116 | # similarity_matrix.masked_select(label_pos).exp().sum()
117 | # )
118 | # loss = -torch.log(
119 | # similarity_matrix.masked_select(label_pos).exp().sum() /
120 | # (similarity_matrix.masked_select(label_nag).exp().sum() + 1e-7)
121 | # )
122 | return loss
123 |
124 | def norm_embed(self, embedding: torch.Tensor):
125 | embedding = F.normalize(embedding, dim=0, p=2)
126 | return embedding
127 |
128 |
129 | def cosine_similarity(vec1, vec2, device):
130 | vec1 = torch.tensor(vec1).to(device).float()
131 | vec2 = torch.tensor(vec2).to(device).float()
132 | dot_product = torch.dot(vec1, vec2)
133 | norm1 = torch.norm(vec1)
134 | norm2 = torch.norm(vec2)
135 | similarity = dot_product / (norm1 * norm2)
136 | return similarity
137 |
138 | def Matching_Loss(pred_prototypes, target_prototypes, device):
139 | """
140 | pred_prototypes: list of predicted prototypes
141 | target_prototypes: list of target prototypes
142 | """
143 | num_pred = len(pred_prototypes)
144 | num_target = len(target_prototypes)
145 |
146 | num = min(num_pred,num_target)
147 | cost_matrix = torch.zeros((num_pred, num_target))
148 | for i, pred_proto in enumerate(pred_prototypes):
149 | for j, target_proto in enumerate(target_prototypes):
150 |
151 | cos_sim = cosine_similarity(pred_proto, target_proto, device)
152 | cost_matrix[i, j] = 1 - cos_sim
153 |
154 | row_indices, col_indices = linear_sum_assignment(cost_matrix.numpy())
155 |
156 | total_loss = 0
157 | for row, col in zip(row_indices, col_indices):
158 | total_loss += cost_matrix[row, col].item()
159 |
160 | return total_loss/len(row_indices)
161 |
162 |
--------------------------------------------------------------------------------
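A short smoke test of the losses above (a sketch; the shapes are arbitrary and Matching_Loss is fed random prototype vectors purely for illustration):

    import torch
    from utils.losses import FocalLoss, DiceLoss, Matching_Loss

    logits  = torch.randn(2, 1, 64, 64)                 # raw mask logits
    targets = (torch.rand(2, 1, 64, 64) > 0.5).float()  # binary ground truth

    focal = FocalLoss()(logits, targets)
    dice  = DiceLoss()(logits, targets)

    # Matching_Loss takes two lists of 1-D prototype vectors and Hungarian-matches them
    preds = [torch.randn(256) for _ in range(3)]
    gts   = [torch.randn(256) for _ in range(4)]
    match = Matching_Loss(preds, gts, device="cpu")

    print(focal.item(), dice.item(), match)
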
/utils/model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from segment_anything import sam_model_registry
7 | # from segment_anything_2.model_registry import sam2_model_registry
8 | from .sam_lora import LoRA_Sam
9 |
10 | class Model(nn.Module):
11 |
12 | def __init__(self, cfg):
13 | super().__init__()
14 | self.cfg = cfg
15 | self.image_embeddings = None
16 |
17 | def get_checkpoint(self, model_type):
18 | if model_type == "vit_b":
19 | checkpoint = os.path.join(self.cfg.model.checkpoint, "sam_vit_b_01ec64.pth")
20 | elif model_type == "vit_l":
21 | checkpoint = os.path.join(self.cfg.model.checkpoint, "sam_vit_l_0b3195.pth")
22 | elif model_type == "vit_h":
23 | checkpoint = os.path.join(self.cfg.model.checkpoint, "sam_vit_h_4b8939.pth")
24 |
25 | return checkpoint
26 |
27 | def setup(self):
28 | if self.cfg.model.type in ['vit_b','vit_l','vit_h']:
29 | checkpoint = self.get_checkpoint(self.cfg.model.type)
30 | self.model = sam_model_registry[self.cfg.model.type](checkpoint=checkpoint)
31 | self.base = 'sam'
32 | elif self.cfg.model.type in ["hiera_b",'hiera_l']:
33 | self.model = sam2_model_registry[self.cfg.model.type]()
34 | self.base = 'sam2'
35 | else:
36 | raise ValueError("Model type error!")
37 |
38 | # for param in self.model.parameters():
39 | # param.requires_grad = False
40 | if self.cfg.model.freeze.image_encoder:
41 | for param in self.model.image_encoder.parameters():
42 | param.requires_grad = False
43 | if self.cfg.model.freeze.prompt_encoder:
44 | try:
45 | for param in self.model.prompt_encoder.parameters():
46 | param.requires_grad = False
47 | except AttributeError:
48 | for param in self.model.sam_prompt_encoder.parameters():
49 | param.requires_grad = False
50 |
51 | if self.cfg.model.freeze.mask_decoder:
52 | try:
53 | for param in self.model.mask_decoder.parameters():
54 | param.requires_grad = False
55 | except AttributeError:
56 | for param in self.model.sam_mask_decoder.parameters():
57 | param.requires_grad = False
58 |
59 | self.model.train()
60 | self.finetune()
61 |
62 | def finetune(self):
63 | LoRA_Sam(self.model, self.cfg.lora_rank, lora_layer=list(range(self.cfg.start_lora_layer, len(self.model.image_encoder.blocks))))
64 | # self.set_adapter_layer()
65 | # self.set_norm_layer()
66 | # print(self.model)
67 |
68 | def set_norm_layer(self):
69 | for name, param in self.model.image_encoder.named_parameters():
70 | if "norm" in name:
71 | param.requires_grad = True
72 |
73 | def set_adapter_layer(self):
74 | for block in self.model.image_encoder.blocks:
75 | if hasattr(block, "Space_Adapter"):
76 | for param in block.Space_Adapter.parameters():
77 | param.requires_grad = True
78 | if hasattr(block, "MLP_Adapter"):
79 | for param in block.MLP_Adapter.parameters():
80 | param.requires_grad = True
81 |
82 | def reset_parameters(self) -> None:
83 | for name, param in self.model.named_parameters():
84 | if param.requires_grad == True:
85 | if "linear_a" in name:
86 | nn.init.kaiming_uniform_(param, a=math.sqrt(5))
87 | if "linear_b" in name:
88 | nn.init.zeros_(param)
89 |
90 | def forward(self, images, prompts):
91 | _, _, H, W = images.shape#[n, 3, 1024, 1024]
92 |
93 | image_embeddings = self.model.image_encoder(images)
94 | pred_masks, ious, res_masks = self.decode((H, W), prompts, image_embeddings)
95 | return image_embeddings, pred_masks, ious, res_masks
96 |
97 | # def encode(self, images):
98 | # self.image_embeddings = self.model.image_encoder(images)
99 | # return self.image_embeddings
100 |
101 | def decode(self, image_shape, prompts, image_embeddings):
102 | if self.base == 'sam2':
103 | _bb_feat_sizes = [
104 | (256, 256),
105 | (128, 128),
106 | (64, 64),
107 | ]
108 | _, vision_feats, _, _ = self.model._prepare_backbone_features(image_embeddings)
109 |
110 | feats = [feat.permute(1, 2, 0).view(1, -1, *feat_size)
111 | for feat, feat_size in zip(vision_feats[::-1], _bb_feat_sizes[::-1])][::-1]
112 | self._features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
113 | image_embeddings = feats[-1]
114 | high_res_features = feats[:-1]
115 |
116 | if image_embeddings is None:
117 | raise ValueError("No image embeddings")
118 |
119 | pred_masks = []
120 | ious = []
121 | res_masks = []
122 | for prompt, embedding in zip(prompts, image_embeddings):
123 |
124 | if self.base == "sam":
125 | if isinstance(prompt, torch.Tensor):
126 | prompt = prompt.to(device=embedding.device)
127 | sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
128 | points=None,
129 | boxes=prompt,
130 | masks=None,
131 | )
132 | elif isinstance(prompt, tuple):
133 | sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
134 | points=prompt,
135 | boxes=None,
136 | masks=None,
137 | )
138 | low_res_masks, iou_predictions = self.model.mask_decoder(
139 | image_embeddings=embedding.unsqueeze(0),
140 | image_pe=self.model.prompt_encoder.get_dense_pe(),
141 | sparse_prompt_embeddings=sparse_embeddings,
142 | dense_prompt_embeddings=dense_embeddings,
143 | multimask_output=False,
144 | )
145 | else:
146 | if isinstance(prompt, torch.Tensor):
147 | prompt = prompt.to(device=embedding.device)
148 | sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder(
149 | points=None,
150 | boxes=prompt,
151 | masks=None,
152 | )
153 | elif isinstance(prompt, tuple):
154 | sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder(
155 | points=prompt,
156 | boxes=None,
157 | masks=None,
158 | )
159 | low_res_masks, iou_predictions,_,_ = self.model.sam_mask_decoder(
160 | image_embeddings=embedding.unsqueeze(0),
161 | image_pe=self.model.sam_prompt_encoder.get_dense_pe(),
162 | sparse_prompt_embeddings=sparse_embeddings,
163 | dense_prompt_embeddings=dense_embeddings,
164 | multimask_output=False,
165 | repeat_image = True,
166 | high_res_features=high_res_features,
167 | )
168 |
169 | masks = F.interpolate(
170 | low_res_masks,
171 | image_shape,
172 | mode="bilinear",
173 | align_corners=False,
174 | )
175 | pred_masks.append(masks.squeeze(1))
176 | ious.append(iou_predictions)
177 | res_masks.append(low_res_masks)
178 | return pred_masks, ious, res_masks
--------------------------------------------------------------------------------
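A sketch of how the wrapper above is typically driven (hypothetical config values; the real ones live in configs/*.py, and the SAM vit_b checkpoint is assumed to have been downloaded into pretrain/):

    import torch
    from box import Box
    from utils.model import Model

    cfg = Box({
        "model": {
            "type": "vit_b",
            "checkpoint": "pretrain",   # directory containing sam_vit_b_01ec64.pth
            "freeze": {"image_encoder": True, "prompt_encoder": True, "mask_decoder": False},
        },
        "lora_rank": 4,
        "start_lora_layer": 6,
    })

    model = Model(cfg)
    model.setup()                        # loads SAM and injects LoRA into the image encoder

    images = torch.randn(1, 3, 1024, 1024)
    boxes  = [torch.tensor([[100., 100., 400., 400.]])]   # one box prompt per image
    with torch.no_grad():
        embeds, pred_masks, ious, res_masks = model(images, boxes)
    print(pred_masks[0].shape)           # torch.Size([1, 1024, 1024])
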
/utils/sam_lora.py:
--------------------------------------------------------------------------------
1 | # Sheng Wang at Apr 6 2023
2 | # What a time to be alive (first half of 2023)
3 |
4 | import math
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from torch import Tensor
9 | from torch.nn.parameter import Parameter
10 | from safetensors import safe_open
11 | from safetensors.torch import save_file
12 |
13 |
14 | class _LoRA_qkv(nn.Module):
15 | """In Sam it is implemented as
16 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
17 | B, N, C = x.shape
18 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
19 | q, k, v = qkv.unbind(0)
20 | """
21 |
22 | def __init__(
23 | self,
24 | qkv: nn.Module,
25 | linear_a_q: nn.Module,
26 | linear_b_q: nn.Module,
27 | linear_a_v: nn.Module,
28 | linear_b_v: nn.Module,
29 | ):
30 | super().__init__()
31 | self.qkv = qkv
32 | self.linear_a_q = linear_a_q
33 | self.linear_b_q = linear_b_q
34 | self.linear_a_v = linear_a_v
35 | self.linear_b_v = linear_b_v
36 | self.dim = qkv.in_features
37 | self.w_identity = torch.eye(qkv.in_features)
38 |
39 | def forward(self, x):
40 | # x: [25, 14, 14, 768]; self.qkv: Linear(in_features=768, out_features=2304, bias=True)
41 | qkv = self.qkv(x) # B,N,N,3*org_C
42 | new_q = self.linear_b_q(self.linear_a_q(x))
43 | new_v = self.linear_b_v(self.linear_a_v(x))
44 | qkv[:, :, :, : self.dim] += new_q
45 | qkv[:, :, :, -self.dim :] += new_v
46 | return qkv
47 |
48 |
49 | class LoRA(nn.Module):
50 | def __init__(self, *args, **kwargs) -> None:
51 | super().__init__(*args, **kwargs)
52 |
53 | def save_fc_parameters(self, filename: str) -> None:
54 | r"""Only safetensors is supported now.
55 |
56 | Run `pip install safetensors` if it is not installed yet.
57 | """
58 | assert filename.endswith(".safetensors")
59 | _in = self.lora_vit.head.in_features
60 | _out = self.lora_vit.head.out_features
61 | fc_tensors = {f"fc_{_in}in_{_out}out": self.lora_vit.head.weight}
62 | save_file(fc_tensors, filename)
63 |
64 | def load_fc_parameters(self, filename: str) -> None:
65 | r"""Only safetensors is supported now.
66 |
67 | Run `pip install safetensors` if it is not installed yet.
68 | """
69 |
70 | assert filename.endswith(".safetensors")
71 | _in = self.lora_vit.head.in_features
72 | _out = self.lora_vit.head.out_features
73 | with safe_open(filename, framework="pt") as f:
74 | saved_key = f"fc_{_in}in_{_out}out"
75 | try:
76 | saved_tensor = f.get_tensor(saved_key)
77 | self.lora_vit.head.weight = Parameter(saved_tensor)
78 | except ValueError:
79 | print("this fc weight is not for this model")
80 |
81 | def save_lora_parameters(self, filename: str) -> None:
82 | r"""Only safetensors is supported now.
83 |
84 | Run `pip install safetensors` if it is not installed yet.
85 |
86 | save both lora and fc parameters.
87 | """
88 |
89 | assert filename.endswith(".safetensors")
90 |
91 | num_layer = len(self.w_As)  # two A matrices per adapted block (one for q, one for v)
92 | a_tensors = {f"w_a_{i:03d}": self.w_As[i].weight for i in range(num_layer)}
93 | b_tensors = {f"w_b_{i:03d}": self.w_Bs[i].weight for i in range(num_layer)}
94 |
95 | _in = self.lora_vit.head.in_features
96 | _out = self.lora_vit.head.out_features
97 | fc_tensors = {f"fc_{_in}in_{_out}out": self.lora_vit.head.weight}
98 |
99 | merged_dict = {**a_tensors, **b_tensors, **fc_tensors}
100 | save_file(merged_dict, filename)
101 |
102 | def load_lora_parameters(self, filename: str) -> None:
103 | r"""Only safetensors is supported now.
104 |
105 | Run `pip install safetensors` if it is not installed yet.
106 |
107 | load both lora and fc parameters.
108 | """
109 |
110 | assert filename.endswith(".safetensors")
111 |
112 | with safe_open(filename, framework="pt") as f:
113 | for i, w_A_linear in enumerate(self.w_As):
114 | saved_key = f"w_a_{i:03d}"
115 | saved_tensor = f.get_tensor(saved_key)
116 | w_A_linear.weight = Parameter(saved_tensor)
117 |
118 | for i, w_B_linear in enumerate(self.w_Bs):
119 | saved_key = f"w_b_{i:03d}"
120 | saved_tensor = f.get_tensor(saved_key)
121 | w_B_linear.weight = Parameter(saved_tensor)
122 |
123 | _in = self.lora_vit.head.in_features
124 | _out = self.lora_vit.head.out_features
125 | saved_key = f"fc_{_in}in_{_out}out"
126 | try:
127 | saved_tensor = f.get_tensor(saved_key)
128 | self.lora_vit.head.weight = Parameter(saved_tensor)
129 | except ValueError:
130 | print("this fc weight is not for this model")
131 |
132 | def reset_parameters(self) -> None:
133 | for w_A in self.w_As:
134 | nn.init.kaiming_uniform_(w_A.weight, a=math.sqrt(5))
135 | for w_B in self.w_Bs:
136 | nn.init.zeros_(w_B.weight)
137 |
138 |
139 | class LoRA_Sam(LoRA):
140 | """Applies low-rank adaptation to a Sam model's image encoder.
141 |
142 | Args:
143 | sam_model: a SAM model whose image-encoder blocks will be adapted
144 | r: rank of the LoRA decomposition
145 | lora_layer: indices of the encoder blocks to apply LoRA to; all blocks are adapted when None
146 | 
147 | Examples::
148 | >>> sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
149 | >>> lora_sam = LoRA_Sam(sam, r=4)
150 | >>> feats = sam.image_encoder(torch.randn(1, 3, 1024, 1024))
151 | >>> print(feats.shape)
152 | torch.Size([1, 256, 64, 64])
153 | 
154 | """
155 |
156 | def __init__(self, sam_model, r: int, lora_layer=None):
157 | super(LoRA_Sam, self).__init__()
158 |
159 | assert r > 0
160 | # base_vit_dim = sam_model.image_encoder.patch_embed.proj.out_channels
161 | # dim = base_vit_dim
162 | if lora_layer:
163 | self.lora_layer = lora_layer
164 | flag = 0
165 | else:
166 | try:
167 | self.lora_layer = list(range(len(sam_model.image_encoder.blocks)))
168 | flag = 0
169 | except:
170 | self.lora_layer = list(range(len(sam_model.image_encoder.trunk.blocks)))
171 | flag = 1
172 | # create for storage, then we can init them or load weights
173 | self.w_As = [] # These are linear layers
174 | self.w_Bs = []
175 |
176 | # lets freeze first
177 | for param in sam_model.image_encoder.parameters():
178 | param.requires_grad = False
179 | if flag ==0:
180 | blocks = sam_model.image_encoder.blocks
181 | elif flag ==1:
182 | blocks = sam_model.image_encoder.trunk.blocks
183 | # Here, we do the surgery
184 | for t_layer_i, blk in enumerate(blocks):
185 | # If we only want few lora layer instead of all
186 | if t_layer_i not in self.lora_layer:
187 | continue
188 | w_qkv_linear = blk.attn.qkv
189 |
190 | self.dim = w_qkv_linear.in_features
191 | w_a_linear_q = nn.Linear(self.dim, r, bias=False)
192 | w_b_linear_q = nn.Linear(r, self.dim, bias=False)
193 | w_a_linear_v = nn.Linear(self.dim, r, bias=False)
194 | w_b_linear_v = nn.Linear(r, self.dim, bias=False)
195 | self.w_As.append(w_a_linear_q)
196 | self.w_Bs.append(w_b_linear_q)
197 | self.w_As.append(w_a_linear_v)
198 | self.w_Bs.append(w_b_linear_v)
199 | blk.attn.qkv = _LoRA_qkv(
200 | w_qkv_linear,
201 | w_a_linear_q,
202 | w_b_linear_q,
203 | w_a_linear_v,
204 | w_b_linear_v,
205 | )
206 | self.reset_parameters()
207 | # self.sam = sam_model
208 | self.lora_vit = sam_model
209 |
210 |
--------------------------------------------------------------------------------
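A minimal sketch of applying LoRA_Sam to a plain SAM model and checking which parameters remain trainable (it assumes a local vit_b checkpoint and the safetensors package, which is imported at module level):

    from segment_anything import sam_model_registry
    from utils.sam_lora import LoRA_Sam

    sam = sam_model_registry["vit_b"](checkpoint="pretrain/sam_vit_b_01ec64.pth")
    LoRA_Sam(sam, r=4)   # used for its side effect on sam, exactly as in utils/model.py

    trainable = sum(p.numel() for p in sam.image_encoder.parameters() if p.requires_grad)
    total     = sum(p.numel() for p in sam.image_encoder.parameters())
    print(f"trainable image-encoder params: {trainable} / {total}")   # only the LoRA A/B matrices
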
/utils/sample_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from sklearn.cluster import KMeans
4 |
5 |
6 | def uniform_sampling(masks, N=1):
7 | n_points = []
8 | for i, mask in enumerate(masks):
9 | if not isinstance(mask, np.ndarray):
10 | mask = mask.cpu().numpy()
11 | indices = np.argwhere(mask == 1) # [y, x]
12 | sampled_indices = np.random.choice(len(indices), N, replace=True)
13 | sampled_points = np.flip(indices[sampled_indices], axis=1)
14 | n_points.append(sampled_points.tolist())
15 |
16 | return n_points
17 |
18 |
19 | def get_multi_distance_points(input_point, mask, points_number):
20 | new_points = np.zeros((points_number + 1, 2))
21 | new_points[0] = [input_point[1], input_point[0]]
22 | for i in range(points_number):
23 | new_points[i + 1] = get_next_distance_point(new_points[:i + 1, :], mask)
24 |
25 | new_points = swap_xy(new_points)
26 | return new_points
27 |
28 |
29 | def get_next_distance_point(input_points, mask):
30 | max_distance_point = [0, 0]
31 | max_distance = 0
32 | input_points = np.array(input_points)
33 |
34 | indices = np.argwhere(mask == True)
35 | for x, y in indices:
36 | # print(x,y,input_points)
37 | distance = np.sum(np.sqrt((x - input_points[:, 0]) ** 2 + (y - input_points[:, 1]) ** 2))
38 | if max_distance < distance:
39 | max_distance_point = [x, y]
40 | max_distance = distance
41 | return max_distance_point
42 |
43 |
44 | def swap_xy(points):
45 | new_points = np.zeros((len(points),2))
46 | new_points[:,0] = points[:,1]
47 | new_points[:,1] = points[:,0]
48 | return new_points
49 |
50 |
51 | def k_means_sampling(mask, k):
52 | points = np.argwhere(mask == 1) # [y, x]
53 | points = np.flip(points, axis=1)
54 |
55 | kmeans = KMeans(n_clusters=k)
56 | kmeans.fit(points)
57 | points = kmeans.cluster_centers_
58 | return points
59 |
60 |
61 | def get_point_prompt_max_dist(masks, num_points):
62 | n_points = []
63 | for mask in masks:
64 | mask_np = mask.cpu().numpy()
65 |
66 | indices = np.argwhere(mask_np > 0)
67 | random_index = np.random.choice(len(indices), 1)[0]
68 |
69 | first_point = [indices[random_index][1], indices[random_index][0]]
70 | new_points = get_multi_distance_points(first_point, mask_np, num_points - 1)
71 | n_points.append(new_points)
72 |
73 | return n_points
74 |
75 |
76 | def get_point_prompt_kmeans(masks, num_points):
77 | n_points = []
78 | for mask in masks:
79 | mask_np = mask.cpu().numpy()
80 | points = k_means_sampling(mask_np, num_points)
81 | n_points.append(points.astype(int))
82 | return n_points
83 |
84 |
85 | def get_point_prompts(gt_masks, num_points):
86 | prompts = []
87 | # print('prompt',len(gt_masks))
88 | for mask in gt_masks:
89 | # print('prompt',len(mask))
90 | po_points = uniform_sampling(mask, num_points)
91 | # print('ori_po',po_points)
92 | na_points = uniform_sampling((~mask.to(bool)).to(float), num_points)
93 | # print('na_points',na_points)
94 | # print('na_points',na_points)
95 | po_point_coords = torch.tensor(po_points, device=mask.device)
96 | na_point_coords = torch.tensor(na_points, device=mask.device)
97 | # print('po_point_coords',po_point_coords.shape)
98 | # print('na_point_coords',na_point_coords.shape)
99 | point_coords = torch.cat((po_point_coords, na_point_coords), dim=1)
100 | po_point_labels = torch.ones(po_point_coords.shape[:2], dtype=torch.int, device=po_point_coords.device)
101 | na_point_labels = torch.zeros(na_point_coords.shape[:2], dtype=torch.int, device=na_point_coords.device)
102 | point_labels = torch.cat((po_point_labels, na_point_labels), dim=1)
103 | in_points = (point_coords, point_labels)
104 | prompts.append(in_points)
105 | return prompts
--------------------------------------------------------------------------------
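A small example of the point-prompt sampling above (a sketch; every mask in the per-image tensor must contain at least one foreground pixel, otherwise uniform_sampling has nothing to draw from):

    import torch
    from utils.sample_utils import get_point_prompts

    # one image with two toy instance masks, shape [num_instances, H, W]
    masks = torch.zeros(2, 64, 64)
    masks[0, 10:20, 10:20] = 1
    masks[1, 40:60, 30:50] = 1

    prompts = get_point_prompts([masks], num_points=2)
    point_coords, point_labels = prompts[0]
    print(point_coords.shape, point_labels.shape)  # [2, 4, 2] and [2, 4]: 2 positive + 2 negative points per instance
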
/utils/tools.py:
--------------------------------------------------------------------------------
1 | import os
2 | import csv
3 | import torch
4 | import copy
5 | import numpy as np
6 | from torchsummary import summary
7 | from torch import Tensor
8 |
9 | def freeze(model: torch.nn.Module):
10 | model.eval()
11 | for param in model.parameters():
12 | param.requires_grad = False
13 |
14 |
15 | def momentum_update(student_model, teacher_model, momentum=0.99):
16 | for (src_name, src_param), (tgt_name, tgt_param) in zip(
17 | student_model.named_parameters(), teacher_model.named_parameters()
18 | ):
19 | if src_param.requires_grad:
20 | # print('src_name',src_name)
21 | # print('tgt_name',tgt_name)
22 | tgt_param.data.mul_(momentum).add_(src_param.data, alpha=1 - momentum)
23 |
24 |
25 | def decode_mask(mask):
26 | """
27 | Convert mask with shape [1, h, w] using 1, 2, 3, ... to represent different objects
28 | to a mask with shape [n, h, w] using a new dimension to represent the number of objects.
29 |
30 | Args:
31 | mask (torch.Tensor): Mask tensor with shape [1, h, w] using 1, 2, 3, ... to represent different objects.
32 |
33 | Returns:
34 | torch.Tensor: Mask tensor with shape [n, h, w] using a new dimension to represent the number of objects.
35 | """
36 | unique_labels = torch.unique(mask)
37 | unique_labels = unique_labels[unique_labels != 0]
38 | n_objects = len(unique_labels)
39 | new_mask = torch.zeros((n_objects, *mask.shape[1:]), dtype=torch.int64)
40 | for i, label in enumerate(unique_labels):
41 | new_mask[i] = (mask == label).squeeze(0)
42 | return new_mask
43 |
44 |
45 | def encode_mask(mask):
46 | """
47 | Convert mask with shape [n, h, w] using a new dimension to represent the number of objects
48 | to a mask with shape [1, h, w] using 1, 2, 3, ... to represent different objects.
49 |
50 | Args:
51 | mask (torch.Tensor): Mask tensor with shape [n, h, w] using a new dimension to represent the number of objects.
52 |
53 | Returns:
54 | torch.Tensor: Mask tensor with shape [1, h, w] using 1, 2, 3, ... to represent different objects.
55 | """
56 | n_objects = mask.shape[0]
57 | new_mask = torch.zeros((1, *mask.shape[1:]), dtype=torch.int64)
58 | for i in range(n_objects):
59 | new_mask[0][mask[i] == 1] = i + 1
60 | return new_mask
61 |
62 |
63 | def copy_model(model: torch.nn.Module):
64 | new_model = copy.deepcopy(model)
65 | freeze(new_model)
66 | return new_model
67 |
68 |
69 | def create_csv(filename, csv_head=["corrupt", "Mean IoU", "Mean F1", "epoch"]):
70 | if os.path.exists(filename):
71 | return
72 | with open(filename, 'w') as csvfile:
73 | csv_write = csv.DictWriter(csvfile, fieldnames=csv_head)
74 | csv_write.writeheader()
75 |
76 |
77 | def write_csv(filename, csv_dict, csv_head=["corrupt", "Mean IoU", "Mean F1", "epoch"]):
78 | with open(filename, 'a+') as csvfile:
79 | csv_write = csv.DictWriter(csvfile, fieldnames=csv_head, extrasaction='ignore')
80 | csv_write.writerow(csv_dict)
81 |
82 |
83 | def check_grad(model: torch.nn.Module):
84 | for name, param in model.named_parameters():
85 | print(f"{name}: {param.requires_grad}")
86 |
87 |
88 | def check_model(model):
89 | return summary(model, (3, 1024, 1024), batch_size=1, device="cuda")
90 |
91 |
92 | def reduce_instances(bboxes, gt_masks, max_nums=50):
93 | bboxes_ = []
94 | gt_masks_ = []
95 | for bbox, gt_mask in zip(bboxes, gt_masks):
96 | idx = np.arange(bbox.shape[0])
97 | np.random.shuffle(idx)
98 | bboxes_.append(bbox[idx[:max_nums]])
99 | gt_masks_.append(gt_mask[idx[:max_nums]])
100 |
101 | bboxes = bboxes_
102 | gt_masks = gt_masks_
103 | return bboxes, gt_masks
104 |
105 | def _to_cpu(data):
106 | """transfer all tensors to cpu."""
107 | if isinstance(data, Tensor):
108 | return data.to('cpu')
109 | elif isinstance(data, list):
110 | return [_to_cpu(d) for d in data]
111 | elif isinstance(data, tuple):
112 | return tuple(_to_cpu(d) for d in data)
113 | elif isinstance(data, dict):
114 | return {k: _to_cpu(v) for k, v in data.items()}
115 | else:
116 | return data
117 |
--------------------------------------------------------------------------------
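A round-trip sketch for the mask encoding helpers above (toy label map; note that torchsummary is imported at module level, so it must be installed even if check_model is never called):

    import torch
    from utils.tools import decode_mask, encode_mask

    # label map with 0 = background and 1, 2 = two objects
    label_map = torch.zeros(1, 8, 8, dtype=torch.int64)
    label_map[0, :4, :4] = 1
    label_map[0, 4:, 4:] = 2

    per_instance = decode_mask(label_map)      # [2, 8, 8], one binary mask per object
    restored     = encode_mask(per_instance)   # back to [1, 8, 8]
    assert torch.equal(restored, label_map)
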
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import random
2 | from collections import deque
3 | from tqdm import tqdm
4 | from box import Box
5 | import numpy as np
6 | from scipy.optimize import linear_sum_assignment
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 |
12 | from .finch import FINCH
13 | from .sample_utils import get_point_prompts
14 |
15 | class Store:
16 | def __init__(self, total_num_classes, items_per_class, shuffle=False):
17 | self.shuffle = shuffle
18 | self.items_per_class = items_per_class
19 | self.total_num_classes = total_num_classes
20 | self.store = [deque(maxlen=self.items_per_class) for _ in range(self.total_num_classes)]
21 |
22 | def add(self, items, class_ids):
23 | for idx, class_id in enumerate(class_ids):
24 | self.store[class_id].append(items[idx])
25 |
26 | def retrieve(self, class_id):
27 | if class_id != -1:
28 | items = []
29 | for item in self.store[class_id]:
30 | items.extend(list(item))
31 | if self.shuffle:
32 | random.shuffle(items)
33 | return items
34 | else:
35 | all_items = []
36 | for i in range(self.total_num_classes):
37 | items = []
38 | for item in self.store[i]:
39 | items.append(list(item))
40 | all_items.append(items)
41 | return all_items
42 |
43 | def reset(self):
44 | self.store = [deque(maxlen=self.items_per_class) for _ in range(self.total_num_classes)]
45 |
46 | def __str__(self):
47 | s = self.__class__.__name__ + '('
48 | for idx, item in enumerate(self.store):
49 | s += '\n Class ' + str(idx) + ' --> ' + str(len(list(item))) + ' items'
50 | s = s + ' )'
51 | return s
52 |
53 | def __repr__(self):
54 | return self.__str__()
55 |
56 | def __len__(self):
57 | return sum([len(s) for s in self.store])
58 |
59 | def concatenate_images_with_padding(images, padding=10, color=(255, 255, 255)):
60 | heights = [image.shape[0] for image in images]
61 | widths = [image.shape[1] for image in images]
62 |
63 | total_width = sum(widths) + padding * (len(images) - 1)
64 | max_height = max(heights)
65 |
66 | if len(images[0].shape) == 3:
67 | new_image = np.full((max_height, total_width, 3), color, dtype=np.uint8)
68 | else:
69 | new_image = np.full((max_height, total_width), color[0], dtype=np.uint8)
70 |
71 | x_offset = 0
72 | for image in images:
73 | new_image[0:image.shape[0], x_offset:x_offset + image.shape[1]] = image
74 | x_offset += image.shape[1] + padding
75 |
76 | return new_image
77 |
78 | def calculate_iou(mask1, mask2):
79 | intersection = torch.logical_and(mask1, mask2)
80 | union = torch.logical_or(mask1, mask2)
81 | iou = torch.sum(intersection).float() / torch.sum(union).float()
82 | return iou
83 |
84 | def calc_iou_matrix(mask_list1, mask_list2):
85 | iou_matrix = torch.zeros((len(mask_list1), len(mask_list2)))
86 | for i, mask1 in enumerate(mask_list1):
87 | for j, mask2 in enumerate(mask_list2):
88 | iou_matrix[i, j] = calculate_iou(mask1, mask2)
89 | return iou_matrix
90 |
91 | def cal_mask_ious(
92 | cfg,
93 | model,
94 | images_weak,
95 | prompts,
96 | gt_masks,
97 | ):
98 | with torch.no_grad():
99 | _, soft_masks, _, _ = model(images_weak, prompts)
100 |
101 | for i, (soft_mask, gt_mask) in enumerate(zip(soft_masks, gt_masks)):
102 | soft_mask = (soft_mask > 0).float()
103 | mask_ious = calc_iou_matrix(soft_mask, soft_mask)
104 | indices = torch.arange(mask_ious.size(0))
105 | mask_ious[indices, indices] = 0.0
106 | return mask_ious, soft_mask
107 |
108 |
109 | def neg_prompt_calibration(
110 | cfg,
111 | mask_ious,
112 | prompts,
113 | ):
114 | '''
115 | mask_ious: [num_masks, num_masks] pairwise IoU matrix between the predicted masks
116 | '''
117 | point_list = []
118 | point_labels_list = []
119 | num_points = cfg.num_points
120 | for m in range(len(mask_ious)):
121 |
122 | pos_point_coords = prompts[0][0][m][:num_points].unsqueeze(0)
123 | neg_point = prompts[0][0][m][num_points:].unsqueeze(0)
124 | neg_points_list = []
125 | neg_points_list.extend(neg_point[0])
126 |
127 | indices = torch.nonzero(mask_ious[m] > float(cfg.iou_thr)).squeeze(1)
128 |
129 | if indices.numel() != 0:
130 | # neg_points_list = []
131 | for indice in indices:
132 | neg_points_list.extend(prompts[0][0][indice][:num_points])
133 | neg_points = random.sample(neg_points_list, num_points)
134 | else:
135 | neg_points = neg_points_list
136 |
137 | neg_point_coords = torch.tensor([p.tolist() for p in neg_points], device=neg_point.device).unsqueeze(0)
138 |
139 | point_coords = torch.cat((pos_point_coords, neg_point_coords), dim=1)
140 |
141 | point_list.append(point_coords)
142 | pos_point_labels = torch.ones(pos_point_coords.shape[0:2], dtype=torch.int, device=neg_point.device)
143 | neg_point_labels = torch.zeros(neg_point_coords.shape[0:2], dtype=torch.int, device=neg_point.device)
144 | point_labels = torch.cat((pos_point_labels, neg_point_labels), dim=1)
145 | point_labels_list.append(point_labels)
146 |
147 | point_ = torch.cat(point_list).squeeze(1)
148 | point_labels_ = torch.cat(point_labels_list)
149 | new_prompts = [(point_, point_labels_)]
150 | return new_prompts
151 |
152 | def get_prompts(cfg: Box, bboxes, gt_masks):
153 | if cfg.prompt == "box" or cfg.prompt == "coarse":
154 | prompts = bboxes
155 | elif cfg.prompt == "point":
156 | prompts = get_point_prompts(gt_masks, cfg.num_points)
157 | else:
158 | raise ValueError("Prompt Type Error!")
159 | return prompts
160 |
161 | def generate_predict_feats(cfg, embed, pseudo_label, gts):
162 | coords, lbls = gts
163 | selected_coords = []
164 |
165 | num_insts = len(pseudo_label)
166 | num_points = cfg.num_points
167 | for coord_grp, lbl_grp in zip(coords, lbls):
168 | for coord, lbl in zip(coord_grp, lbl_grp):
169 | if lbl.item() == 1:
170 | selected_coords.append(coord.tolist())
171 |
172 | # Downsample coordinates (SAM's stride is 16)
173 | coords = [[int(c // 16) for c in pair] for pair in selected_coords]
174 |
175 | embed = embed.permute(1, 2, 0) # [H, W, C]
176 |
177 | pos_pts = []
178 |
179 | for index in range(0, num_insts * num_points, num_points):
180 | offset = random.randint(0, num_points - 1)  # pick one random positive point per instance
181 | x, y = coords[index + offset]
182 | pos_pt = embed[x, y]
183 | pos_pts.append(pos_pt)
184 |
185 | predict_feats = torch.stack(pos_pts, dim=0)
186 |
187 | return predict_feats
188 |
189 |
190 | def offline_prototypes_generation(cfg, model, loader):
191 | model.eval()
192 | pts = []
193 | max_iters = 128
194 | num_points = cfg.num_points
195 |
196 | with torch.no_grad():
197 | for i, batch in enumerate(tqdm(loader, desc='Generating target prototypes', ncols=100)):
198 | if i >= max_iters:
199 | break
200 | imgs, boxes, masks, _ = batch
201 | prompts = get_prompts(cfg, boxes, masks)
202 |
203 | embeds, masks, _, _ = model(imgs, prompts)
204 | del _
205 |
206 | if isinstance(embeds, dict):
207 | embeds = embeds['vision_features']
208 |
209 | for embed, prompt, mask in zip(embeds, prompts, masks):
210 | num_insts = len(mask)
211 | embed = embed.permute(1, 2, 0) # [H, W, C]
212 | coords = []
213 |
214 | points, labels = prompt
215 | for point_grp, label_grp in zip(points, labels):
216 | for point, label in zip(point_grp, label_grp):
217 | if label.item() == 1:
218 | coords.append(point.tolist())
219 |
220 | # 16 is the stride of SAM
221 | coords = [[int(pt / 16) for pt in pair] for pair in coords]
222 | for index in range(0, num_insts*num_points, num_points):
223 | x, y = coords[index]
224 | pt = embed[x,y]
225 | pts.append(pt)
226 |
227 | fin = FINCH(verbose=True)
228 | pts = torch.stack(pts).cpu().numpy()
229 | res = fin.fit(pts)
230 |
231 | last_key = list(res.partitions.keys())[-1]
232 | pt_stats = {'target_pts': res.partitions[last_key]['cluster_centers']}
233 | return pt_stats
234 |
235 |
236 |
--------------------------------------------------------------------------------
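A quick sketch of the pairwise-IoU helper used by cal_mask_ious and neg_prompt_calibration above (toy binary masks):

    import torch
    from utils.utils import calc_iou_matrix

    a = torch.zeros(2, 32, 32); a[0, :16, :16] = 1; a[1, 16:, 16:] = 1
    b = torch.zeros(2, 32, 32); b[0, :16, :16] = 1; b[1, :16, 16:] = 1

    ious = calc_iou_matrix(a, b)   # IoU between every mask in a and every mask in b
    print(ious)                    # ious[0, 0] == 1.0, all other entries 0.0
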
/utils/visualize.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import torch
4 | from torchvision.utils import draw_bounding_boxes
5 | from torchvision.utils import draw_segmentation_masks
6 |
7 | from box import Box
8 | from tqdm import tqdm
9 | from model import Model
10 | from datasets.COCO import COCODataset
11 |
12 |
13 | def draw_image(image, masks, boxes, labels, alpha=0.4):
14 | image = torch.from_numpy(image).permute(2, 0, 1)
15 | if boxes is not None:
16 | image = draw_bounding_boxes(image, boxes, colors=['red'] * len(boxes), labels=labels, width=2)
17 | if masks is not None:
18 | image = draw_segmentation_masks(image, masks=masks, colors=['red'] * len(masks), alpha=alpha)
19 | return image.numpy().transpose(1, 2, 0)
20 |
21 |
22 | def visualize(cfg: Box):
23 | model = Model(cfg)
24 | model.setup()
25 | model.eval()
26 | model.cuda()
27 | dataset = COCODataset(root_dir=cfg.dataset.val.root_dir,
28 | annotation_file=cfg.dataset.val.annotation_file,
29 | transform=None)
30 | predictor = model.get_predictor()
31 | os.makedirs(cfg.out_dir, exist_ok=True)
32 |
33 | for image_id in tqdm(dataset.image_ids):
34 | image_info = dataset.coco.loadImgs(image_id)[0]
35 | image_path = os.path.join(dataset.root_dir, image_info['file_name'])
36 | image_output_path = os.path.join(cfg.out_dir, image_info['file_name'])
37 | image = cv2.imread(image_path)
38 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
39 | ann_ids = dataset.coco.getAnnIds(imgIds=image_id)
40 | anns = dataset.coco.loadAnns(ann_ids)
41 | bboxes = []
42 | for ann in anns:
43 | x, y, w, h = ann['bbox']
44 | bboxes.append([x, y, x + w, y + h])
45 | bboxes = torch.as_tensor(bboxes, device=model.model.device)
46 | transformed_boxes = predictor.transform.apply_boxes_torch(bboxes, image.shape[:2])
47 | predictor.set_image(image)
48 | masks, _, _ = predictor.predict_torch(
49 | point_coords=None,
50 | point_labels=None,
51 | boxes=transformed_boxes,
52 | multimask_output=False,
53 | )
54 | image_output = draw_image(image, masks.squeeze(1), boxes=None, labels=None)
55 | cv2.imwrite(image_output_path, cv2.cvtColor(image_output, cv2.COLOR_RGB2BGR))
56 |
57 |
--------------------------------------------------------------------------------
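Note that visualize.py imports a COCODataset module and a top-level model module that are not part of this listing, so importing it directly may fail; the drawing step it performs reduces to the two torchvision helpers below (a standalone sketch of the same call pattern as draw_image):

    import torch
    from torchvision.utils import draw_bounding_boxes, draw_segmentation_masks

    image = torch.zeros(3, 128, 128, dtype=torch.uint8)   # CHW uint8, as draw_image builds internally
    boxes = torch.tensor([[32, 32, 96, 96]])
    masks = torch.zeros(1, 128, 128, dtype=torch.bool)
    masks[0, 32:96, 32:96] = True

    image = draw_bounding_boxes(image, boxes, colors=["red"], width=2)
    image = draw_segmentation_masks(image, masks=masks, colors=["red"], alpha=0.4)
    out = image.numpy().transpose(1, 2, 0)                # HWC; convert RGB->BGR before cv2.imwrite
    print(out.shape)                                      # (128, 128, 3)
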