├── .gitignore
├── INSTALL.md
├── LICENSE
├── PREPARE.md
├── README.md
├── adaptation.py
├── asserts
│   ├── Pipeline.webp
│   ├── VISUAL.webp
│   └── teaser.webp
├── configs
│   ├── base_config.py
│   └── config.py
├── datasets
│   ├── CAMO.py
│   ├── COCO.py
│   ├── COCONut.py
│   ├── COD10K.py
│   ├── GDD.py
│   ├── ISIC.py
│   ├── ISTD.py
│   ├── MSD.py
│   ├── OCID.py
│   ├── OSD.py
│   ├── PascalVOC.py
│   ├── Polyp.py
│   ├── SA_1B.py
│   ├── __init__.py
│   └── tools.py
├── losses.py
├── model.py
├── requirements.txt
├── sam_lora.py
├── segment_anything
│   ├── __init__.py
│   ├── automatic_mask_generator.py
│   ├── build_sam.py
│   ├── modeling
│   │   ├── __init__.py
│   │   ├── common.py
│   │   ├── image_encoder.py
│   │   ├── mask_decoder.py
│   │   ├── prompt_encoder.py
│   │   ├── sam.py
│   │   └── transformer.py
│   ├── predictor.py
│   └── utils
│       ├── __init__.py
│       ├── amg.py
│       ├── onnx.py
│       └── transforms.py
├── utils
│   ├── eval_utils.py
│   ├── sample_utils.py
│   ├── tools.py
│   └── visualize.py
└── validate.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | data
7 | output
8 | *.sh
9 |
10 | # C extensions
11 | *.so
12 | *.pth
13 | # Distribution / packaging
14 | .Python
15 | build/
16 | develop-eggs/
17 | dist/
18 | downloads/
19 | eggs/
20 | .eggs/
21 | lib/
22 | lib64/
23 | parts/
24 | sdist/
25 | var/
26 | wheels/
27 | share/python-wheels/
28 | *.egg-info/
29 | .installed.cfg
30 | *.egg
31 | MANIFEST
32 |
33 | # PyInstaller
34 | # Usually these files are written by a python script from a template
35 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
36 | *.manifest
37 | *.spec
38 |
39 | # Installer logs
40 | pip-log.txt
41 | pip-delete-this-directory.txt
42 |
43 | # Unit test / coverage reports
44 | htmlcov/
45 | .tox/
46 | .nox/
47 | .coverage
48 | .coverage.*
49 | .cache
50 | nosetests.xml
51 | coverage.xml
52 | *.cover
53 | *.py,cover
54 | .hypothesis/
55 | .pytest_cache/
56 | cover/
57 |
58 | # Translations
59 | *.mo
60 | *.pot
61 |
62 | # Django stuff:
63 | *.log
64 | local_settings.py
65 | db.sqlite3
66 | db.sqlite3-journal
67 |
68 | # Flask stuff:
69 | instance/
70 | .webassets-cache
71 |
72 | # Scrapy stuff:
73 | .scrapy
74 |
75 | # Sphinx documentation
76 | docs/_build/
77 |
78 | # PyBuilder
79 | .pybuilder/
80 | target/
81 |
82 | # Jupyter Notebook
83 | .ipynb_checkpoints
84 |
85 | # IPython
86 | profile_default/
87 | ipython_config.py
88 |
89 | # pyenv
90 | # For a library or package, you might want to ignore these files since the code is
91 | # intended to run in multiple environments; otherwise, check them in:
92 | # .python-version
93 |
94 | # pipenv
95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
98 | # install all needed dependencies.
99 | #Pipfile.lock
100 |
101 | # poetry
102 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
103 | # This is especially recommended for binary packages to ensure reproducibility, and is more
104 | # commonly ignored for libraries.
105 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
106 | #poetry.lock
107 |
108 | # pdm
109 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
110 | #pdm.lock
111 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
112 | # in version control.
113 | # https://pdm.fming.dev/#use-with-ide
114 | .pdm.toml
115 |
116 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
117 | __pypackages__/
118 |
119 | # Celery stuff
120 | celerybeat-schedule
121 | celerybeat.pid
122 |
123 | # SageMath parsed files
124 | *.sage.py
125 |
126 | # Environments
127 | .env
128 | .venv
129 | env/
130 | venv/
131 | ENV/
132 | env.bak/
133 | venv.bak/
134 |
135 | # Spyder project settings
136 | .spyderproject
137 | .spyproject
138 |
139 | # Rope project settings
140 | .ropeproject
141 |
142 | # mkdocs documentation
143 | /site
144 |
145 | # mypy
146 | .mypy_cache/
147 | .dmypy.json
148 | dmypy.json
149 |
150 | # Pyre type checker
151 | .pyre/
152 |
153 | # pytype static type analyzer
154 | .pytype/
155 |
156 | # Cython debug symbols
157 | cython_debug/
158 |
159 | # PyCharm
160 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
161 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
162 | # and can be added to the global gitignore or merged into this file. For a more nuclear
163 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
164 | #.idea/
165 |
--------------------------------------------------------------------------------
/INSTALL.md:
--------------------------------------------------------------------------------
1 | ## Installation
2 |
3 |
4 | ### Requirements
5 | - Linux or macOS with Python ≥ 3.8
6 | - PyTorch ≥ 1.13.1 and [torchvision](https://github.com/pytorch/vision/) that matches the PyTorch installation.
7 | - Install [PyTorch Lightning](https://lightning.ai/pytorch-lightning) (the `lightning` package) matching the PyTorch installation.
8 | - `pip install -r requirements.txt`
9 |
10 |
11 | ### Example conda environment setup
12 | ```bash
13 | conda create --name wesam python=3.8
14 | conda activate wesam
15 |
16 | # CUDA 11.7
17 | conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.7 -c pytorch -c nvidia
18 |
19 | git clone https://github.com/zhang-haojie/wesam.git
20 | cd wesam
21 | pip install -r requirements.txt
22 | ```
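
After installation, a quick sanity check (an optional snippet, not part of the repository) verifies that PyTorch sees CUDA and that Lightning imports cleanly:

```python
import torch
import lightning as L

print(torch.__version__, torch.version.cuda)  # e.g. 2.0.1 and 11.7
print(torch.cuda.is_available())              # should be True for GPU training
print(L.__version__)                          # PyTorch Lightning version
```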
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Haojie Zhang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/PREPARE.md:
--------------------------------------------------------------------------------
1 | ## Prepare
2 |
3 |
4 | ### Download Dataset
5 |
6 | *Natural images*
7 |
8 | - [COCO Dataset](https://cocodataset.org/)
9 |
10 | - [PascalVOC Dataset](http://host.robots.ox.ac.uk/pascal/VOC/)
11 |
12 |
13 | *Medical images*
14 |
15 | - [Kvasir-SEG Dataset](https://datasets.simula.no/kvasir-seg/)
16 |
17 | - [ISIC Dataset](https://challenge.isic-archive.com/data/)
18 |
19 |
20 | *Camouflaged Objects*
21 |
22 | - [COD10k Dataset](https://drive.google.com/file/d/1pVq1rWXCwkMbEZpTt4-yUQ3NsnQd_DNY/view?usp=sharing) - [Camouflaged Object Detection](https://github.com/DengPingFan/SINet/)
23 |
24 | - [CAMO](https://drive.google.com/open?id=1h-OqZdwkuPhBvGcVAwmh0f1NGqlH_4B6) - [Project](https://sites.google.com/view/ltnghia/research/camo)
25 |
26 |
27 | *Robotic Images*
28 |
29 | - [OCID](https://www.acin.tuwien.ac.at/en/vision-for-robotics/software-tools/object-clutter-indoor-dataset/): *Object Clutter Indoor Dataset*
30 |
31 | - [OSD](https://www.acin.tuwien.ac.at/en/vision-for-robotics/software-tools/osd/): *Object Segmentation Database*
32 |
33 |
34 | *Corrupted Images*
35 |
36 | To generate corrupted copies, uncomment the line containing `corrupt_image` in `datasets/COCO.py` (shown below), comment out line 192 of `adaptation.py`, and run `adaptation.py`.
37 |
38 | ```python
39 | def __getitem__(self, idx):
40 | image_id = self.image_ids[idx]
41 | image_info = self.coco.loadImgs(image_id)[0]
42 | image_path = os.path.join(self.root_dir, image_info["file_name"])
43 | if self.cfg.corrupt in self.cfg.corruptions:
44 | image_path = image_path.replace("val2017", os.path.join("corruption", self.cfg.corrupt))
45 | image = cv2.imread(image_path)
46 |
47 | # corrupt_image(image, image_path)
48 |
49 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
50 | ```
51 | ```
52 |
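The corruption names listed in `configs/base_config.py` are the fifteen standard ImageNet-C corruptions, so the corrupted copies can also be produced offline. Below is a minimal sketch using the third-party `imagecorruptions` package, mirroring the `val2017 -> corruption/<name>` path convention above; this is an illustration only, and the repository's own `corrupt_image` helper may work differently:

```python
import os
import cv2
import numpy as np
from imagecorruptions import corrupt  # third-party: pip install imagecorruptions

def corrupt_and_save(image_path, corruption_names, severity=3):
    """Write corrupted copies of a val2017 image under .../corruption/<name>/."""
    bgr = cv2.imread(image_path)
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    for name in corruption_names:
        out_path = image_path.replace("val2017", os.path.join("corruption", name))
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        corrupted = corrupt(rgb, corruption_name=name, severity=severity)
        corrupted = np.asarray(corrupted, dtype=np.uint8)
        cv2.imwrite(out_path, cv2.cvtColor(corrupted, cv2.COLOR_RGB2BGR))
```
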
53 | *Glass*
54 | - [GDD](http://gdd.dluticcd.com/) -
55 | [Don't Hit Me! Glass Detection in Real-world Scenes](https://github.com/Charmve/Mirror-Glass-Detection/tree/master/CVPR2020_GDNet)
56 | (*Glass Detection Dataset*; access requires an application)
57 |
58 | - [GSD](https://drive.google.com/file/d/1pSEUs-8I-4YHOTJ9J0wxiEpkdbHjMQQo/view?usp=sharing) -
59 | [Exploiting Semantic Relations for Glass Surface Detection](https://jiaying.link/neurips2022-gsds/)
60 | (*Glass Surface Dataset*)
61 |
62 |
63 | *Mirror*
64 | - [MSD](https://drive.google.com/file/d/1Znw92fO6lCKfXejjSSyMyL1qtFepgjPI/view?usp=sharing) -
65 | [Where is My Mirror?](https://github.com/Charmve/Mirror-Glass-Detection/tree/master/ICCV2019_MirrorNet)
66 | (*Mirror Segmentation Dataset*; access requires an application)
67 |
68 |
69 | *Shadow Detection*
70 | - [ISTD](https://drive.google.com/file/d/1I0qw-65KBA6np8vIZzO6oeiOvcDBttAY/view) -
71 | [Stacked Conditional Generative Adversarial Networks for Jointly Learning Shadow Detection and Shadow Removal](https://github.com/DeepInsight-PCALab/ST-CGAN)
72 | (*Image Shadow Triplets Dataset*; access requires an application)
73 |
74 |
75 | ### Download Checkpoints
76 |
77 | Click the links below to download the checkpoint for the corresponding model type.
78 |
79 | - `vit_h`: [ViT-H SAM model.](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth)
80 | - `vit_l`: [ViT-L SAM model.](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth)
81 | - `vit_b`: [ViT-B SAM model.](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth)
82 |
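For the default `vit_b` model type, for example, the checkpoint can be fetched straight into the `checkpoints/` directory created in the Prepare step below (an illustrative command; any download method works):

```bash
mkdir -p checkpoints
wget -P checkpoints https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth
```
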
83 |
84 | ### Prepare
85 |
86 | ```bash
87 | cd wesam/
88 |
89 | mkdir data
90 | mkdir checkpoints
91 |
92 | mv DATASETS ./data          # DATASETS: the dataset folders downloaded above
93 |
94 | mv VIT_B ./checkpoints      # VIT_B: the SAM checkpoint, e.g. sam_vit_b_01ec64.pth
95 | ```
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # Improving the Generalization of Segmentation Foundation Model under Distribution Shift via Weakly Supervised Adaptation
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 | ## 🎈 News
12 |
13 | - [2024.2.27] Our work has been accepted to CVPR 2024 🎉
14 | - [2024.3.1] Training and inference code released
15 |
16 | ## 🚀 Introduction
17 |
18 |
19 | ![teaser](asserts/teaser.webp)
20 |
21 |
22 | The Segment Anything Model (SAM) was pre-trained on a large-scale dataset, yet its performance degrades on diverse downstream segmentation tasks. We adapt SAM through weak supervision to enhance its generalization.
23 |
24 |
25 | ## 📻 Overview
26 |
27 |
28 | ![Pipeline](asserts/Pipeline.webp)
29 |
30 |
31 | The proposed self-training architecture with anchor network regularization and contrastive loss regularization. Red arrows indicate the backpropagation flow.
32 |
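As a reading aid, here is a self-contained toy sketch of the loss structure described above (shapes are arbitrary, and focal loss is replaced by plain BCE for brevity; the actual implementation, including the IoU-prediction and contrastive terms, is in `adaptation.py` shown later in this listing and in `losses.py`):

```python
import torch
import torch.nn.functional as F

# Toy tensors standing in for mask logits from the three branches.
pred_logits   = torch.randn(2, 1, 64, 64, requires_grad=True)  # student: strongly augmented view
soft_logits   = torch.randn(2, 1, 64, 64)                      # teacher: weakly augmented view
anchor_logits = torch.randn(2, 1, 64, 64)                      # frozen original SAM (anchor network)

def dice_loss(logits, target, eps=1e-6):
    p, t = logits.sigmoid().flatten(1), target.flatten(1)
    return 1 - ((2 * (p * t).sum(1) + eps) / (p.sum(1) + t.sum(1) + eps)).mean()

pseudo = (soft_logits > 0).float()    # teacher predictions serve as pseudo ground truth
anchor = (anchor_logits > 0).float()  # anchor predictions regularize both views

loss_self   = 20. * F.binary_cross_entropy_with_logits(pred_logits, pseudo) + dice_loss(pred_logits, pseudo)
loss_anchor = 0.5 * dice_loss(pred_logits, anchor) + 0.5 * dice_loss(soft_logits, anchor)
(loss_self + loss_anchor).backward()  # in the repo, gradients update only the LoRA-adapted weights
```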
33 |
34 | ## 📆 TODO
35 |
36 | - [x] Release code
37 |
38 | ## 🎮 Getting Started
39 |
40 | ### 1. Install Environment
41 |
42 | See [INSTALL](INSTALL.md).
43 |
44 | ### 2. Prepare Dataset and Checkpoints
45 |
46 | See [PREPARE](PREPARE.md).
47 |
48 | ### 3. Adapt with Weak Supervision
49 |
50 | ```bash
51 | # 1. Modify configs/config.py
52 | #    (set the prompt type: box, point, or coarse)
53 |
54 | # 2. Run the adaptation
55 | python adaptation.py
56 | ```
57 |
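For example, to adapt on CAMO with point prompts instead of the default COCO/box setup, the relevant entries of the `config` dict in `configs/config.py` (shown in full later in this listing) would be edited roughly as follows; the comments are our reading of the fields, and the dataset name is assumed to match the module name in `datasets/`:

```python
config = {
    # ...other keys unchanged...
    "dataset": "CAMO",                   # dataset loader to use (a module in datasets/)
    "prompt": "point",                   # weak supervision type: box, point, or coarse
    "num_points": 5,                     # point prompts sampled per instance when prompt == "point"
    "out_dir": "output/benchmark/CAMO",  # where logs, configs, and checkpoints are written
}
```
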
58 | ### 4. Validation
59 |
60 | ```bash
61 | python validate.py --ckpt /path/to/checkpoint
62 | ```
63 |
64 |
65 | ## 🖼️ Visualization
66 |
67 |
68 |

69 |
70 |
71 |
72 | ## 🎫 License
73 |
74 | The content of this project itself is released under the MIT [LICENSE](LICENSE).
75 |
76 | ## 💡 Acknowledgement
77 |
78 | - [SAM](https://github.com/facebookresearch/segment-anything)
79 |
80 | - [lightning-sam](https://github.com/luca-medeiros/lightning-sam)
81 |
82 | - [SAM-LoRA](https://github.com/JamesQFreeman/Sam_LoRA)
83 |
84 | ## 🖊️ Citation
85 |
86 | If you find this project useful in your research, please consider citing it:
87 |
88 | ```BibTeX
89 | @inproceedings{zhang2024improving,
90 | title={Improving the generalization of segmentation foundation model under distribution shift via weakly supervised adaptation},
91 | author={Zhang, Haojie and Su, Yongyi and Xu, Xun and Jia, Kui},
92 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
93 | pages={23385--23395},
94 | year={2024}
95 | }
96 | ```
97 |
--------------------------------------------------------------------------------
/adaptation.py:
--------------------------------------------------------------------------------
1 | import os
2 | import yaml
3 | import time
4 | import torch
5 | import lightning as L
6 | import torch.nn.functional as F
7 | import segmentation_models_pytorch as smp
8 | from box import Box
9 | from lightning.fabric.fabric import _FabricOptimizer
10 | from lightning.fabric.loggers import TensorBoardLogger, CSVLogger
11 | from torch.utils.data import DataLoader
12 |
13 | from configs.config import cfg
14 | from losses import DiceLoss, FocalLoss, ContraLoss
15 | from datasets import call_load_dataset
16 |
17 | from model import Model
18 | from sam_lora import LoRA_Sam
19 | from utils.eval_utils import AverageMeter, calc_iou, validate, get_prompts
20 | from utils.tools import copy_model, create_csv, check_grad, momentum_update, reduce_instances
21 |
22 |
23 | def train_sam(
24 | cfg: Box,
25 | fabric: L.Fabric,
26 | model: Model,
27 | anchor_model: Model,
28 | optimizer: _FabricOptimizer,
29 | scheduler: _FabricOptimizer,
30 | train_dataloader: DataLoader,
31 | val_dataloader: DataLoader,
32 | num_iters: int,
33 | ):
34 | """The SAM training loop."""
35 | batch_time = AverageMeter()
36 | data_time = AverageMeter()
37 | focal_losses = AverageMeter()
38 | dice_losses = AverageMeter()
39 | iou_losses = AverageMeter()
40 | anchor_losses = AverageMeter()
41 | contra_losses = AverageMeter()
42 | total_losses = AverageMeter()
43 | focal_loss = FocalLoss()
44 | dice_loss = DiceLoss()
45 | contra_loss = ContraLoss()
46 | end = time.time()
47 | max_iou = 0.
48 | num_epochs = cfg.num_iters // num_iters + 1
49 |
50 | for epoch in range(1, num_epochs):
51 |
52 | for iter, data in enumerate(train_dataloader):
53 |
54 | data_time.update(time.time() - end)
55 | images_weak, images_strong, bboxes, gt_masks = data
56 | batch_size = images_weak.size(0)
57 | num_insts = sum(len(gt_mask) for gt_mask in gt_masks)
58 | if num_insts > cfg.max_nums:
59 | print(num_insts)
60 | bboxes, gt_masks = reduce_instances(bboxes, gt_masks, cfg.max_nums)
61 |
62 | prompts = get_prompts(cfg, bboxes, gt_masks)
63 |
64 | with torch.no_grad():
65 | anchor_image_embeds, anchor_masks, anchor_iou_predictions, anchor_res_masks = anchor_model(images_weak, prompts)
66 |
67 | soft_image_embeds, soft_masks, soft_iou_predictions, soft_res_masks = model(images_weak, prompts) # teacher
68 | pred_image_embeds, pred_masks, iou_predictions, pred_res_masks = model(images_strong, prompts) # student
69 |
70 | num_masks = sum(len(pred_mask) for pred_mask in pred_masks)
71 | loss_focal = torch.tensor(0., device=fabric.device)
72 | loss_dice = torch.tensor(0., device=fabric.device)
73 | loss_iou = torch.tensor(0., device=fabric.device)
74 | loss_anchor = torch.tensor(0., device=fabric.device)
75 | loss_contra = torch.tensor(0., device=fabric.device)
76 |
77 | for i, (pred_mask, soft_mask, anchor_mask, iou_prediction) in enumerate(zip(pred_masks, soft_masks, anchor_masks, iou_predictions)):
78 | anchor_mask = (anchor_mask > 0.).float()
79 | loss_contra += contra_loss(soft_image_embeds[i], anchor_image_embeds[i], soft_res_masks[i].clone().detach(), anchor_res_masks[i].clone().detach())
80 | # loss_contra += contra_loss(pred_image_embeds[i], anchor_image_embeds[i], pred_res_masks[i].clone().detach(), anchor_res_masks[i].clone().detach())
81 |
82 | loss_anchor += (0.5 * dice_loss(pred_mask, anchor_mask) + 0.5 * dice_loss(soft_mask, anchor_mask))
83 |
84 | soft_mask = (soft_mask > 0.).float()
85 | loss_focal += focal_loss(pred_mask, soft_mask, num_masks)
86 | loss_dice += dice_loss(pred_mask, soft_mask, num_masks)
87 | batch_iou = calc_iou(pred_mask, soft_mask)
88 | loss_iou += F.mse_loss(iou_prediction, batch_iou, reduction='sum') / num_masks
89 |
90 | loss_total = 20. * loss_focal + loss_dice + loss_iou + loss_anchor + loss_contra  # focal term weighted 20:1 vs. dice, as in SAM
91 | fabric.backward(loss_total)
92 |
93 | optimizer.step()
94 | scheduler.step()
95 | optimizer.zero_grad()
96 | torch.cuda.empty_cache()
97 |
98 | batch_time.update(time.time() - end)
99 | end = time.time()
100 |
101 | # momentum_update(model, anchor_model, momentum=cfg.ema_rate)
102 |
103 | focal_losses.update(loss_focal.item(), batch_size)
104 | dice_losses.update(loss_dice.item(), batch_size)
105 | iou_losses.update(loss_iou.item(), batch_size)
106 | anchor_losses.update(loss_anchor.item(), batch_size)
107 | contra_losses.update(loss_contra.item(), batch_size)
108 | total_losses.update(loss_total.item(), batch_size)
109 |
110 | fabric.print(f'Epoch: [{epoch}][{iter+1}/{len(train_dataloader)}]'
111 | f' | Dataset: [{cfg.dataset} - {cfg.prompt}]'
112 | f' | Time [{batch_time.val:.3f}s ({batch_time.avg:.3f}s)]'
113 | f' | Data [{data_time.val:.3f}s ({data_time.avg:.3f}s)]'
114 | f' | Focal Loss [{focal_losses.val:.4f} ({focal_losses.avg:.4f})]'
115 | f' | Dice Loss [{dice_losses.val:.4f} ({dice_losses.avg:.4f})]'
116 | f' | IoU Loss [{iou_losses.val:.4f} ({iou_losses.avg:.4f})]'
117 | f' | Anchor Loss [{anchor_losses.val:.4f} ({anchor_losses.avg:.4f})]'
118 | f' | Contrast Loss [{contra_losses.val:.4f} ({contra_losses.avg:.4f})]'
119 | f' | Total Loss [{total_losses.val:.4f} ({total_losses.avg:.4f})]')
120 |
121 | loss_logger = {"Focal Loss": focal_losses.avg, "Dice Loss": dice_losses.avg,
122 | "IoU Loss": iou_losses.avg, "Anchor Loss": anchor_losses.avg,
123 | "Contrast Loss": contra_losses.avg, "Total Loss": total_losses.avg}
124 | fabric.log_dict(loss_logger)
125 | torch.cuda.empty_cache()
126 |
127 | if epoch % cfg.eval_interval == 0:
128 | iou, f1_score = validate(fabric, cfg, model, val_dataloader, cfg.name, epoch * num_iters)
129 | if iou > max_iou:
130 | state = {"model": model, "optimizer": optimizer}
131 | fabric.save(os.path.join(cfg.out_dir, "save", f"{cfg.dataset}-{cfg.prompt}-last-ckpt.pth"), state)
132 | max_iou = iou
133 |
134 |
135 | def configure_opt(cfg: Box, model: Model):
136 |
137 | def lr_lambda(step):
138 | if step < cfg.opt.warmup_steps:
139 | return step / cfg.opt.warmup_steps
140 | elif step < cfg.opt.steps[0]:
141 | return 1.0
142 | elif step < cfg.opt.steps[1]:
143 | return 1 / cfg.opt.decay_factor
144 | else:
145 | return 1 / (cfg.opt.decay_factor**2)
146 |
147 | optimizer = torch.optim.Adam(model.parameters(), lr=cfg.opt.learning_rate, weight_decay=cfg.opt.weight_decay)
148 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
149 |
150 | return optimizer, scheduler
151 |
152 |
153 | def corrupt_main(cfg):
154 | for corrupt in cfg.corruptions:
155 | cfg.corrupt = corrupt
156 | cfg.name = corrupt
157 | torch.cuda.empty_cache()
158 | main(cfg)
159 |
160 |
161 | def multi_main(cfg):
162 | prompts = ["box", "point"]
163 | for prompt in prompts:
164 | cfg.prompt = prompt
165 | torch.cuda.empty_cache()
166 | main(cfg)
167 |
168 |
169 | def main(cfg: Box, ckpt: str = None) -> None:
170 | gpu_ids = cfg.gpu_ids.split(',')
171 | num_devices = len(gpu_ids)
172 |
173 | fabric = L.Fabric(accelerator="auto",
174 | devices=num_devices,
175 | strategy="auto",
176 | loggers=[TensorBoardLogger(cfg.out_dir, name=f"{cfg.dataset}-{cfg.prompt}")])
177 | fabric.launch()
178 | fabric.seed_everything(1337 + fabric.global_rank)
179 |
180 | if fabric.global_rank == 0:
181 | cfg_dict = cfg.to_dict()
182 | os.makedirs(os.path.join(cfg.out_dir, "configs"), exist_ok=True)
183 | cfg_dict_path = os.path.join(cfg.out_dir, "configs", f"{cfg.dataset}-{cfg.prompt}.yaml")
184 | with open(cfg_dict_path, "w") as file:
185 | yaml.dump(cfg_dict, file)
186 |
187 | os.makedirs(os.path.join(cfg.out_dir, "save"), exist_ok=True)
188 | create_csv(os.path.join(cfg.out_dir, f"{cfg.dataset}-{cfg.prompt}.csv"), csv_head=cfg.csv_keys)
189 |
190 | with fabric.device:
191 | model = Model(cfg)
192 | model.setup()
193 | anchor_model = copy_model(model)
194 | LoRA_Sam(model.model, 4)
195 |
196 | load_datasets = call_load_dataset(cfg)
197 | train_data, val_data = load_datasets(cfg, model.model.image_encoder.img_size)
198 | optimizer, scheduler = configure_opt(cfg, model.model)
199 |
200 | fabric.print(f"Train Data: {len(train_data) * cfg.batch_size}; Val Data: {len(val_data) * cfg.val_batchsize}")
201 | num_iters = len(train_data) * cfg.batch_size
202 |
203 | if ckpt is not None:
204 | full_checkpoint = fabric.load(ckpt)
205 | model.load_state_dict(full_checkpoint["model"])
206 | # optimizer.load_state_dict(full_checkpoint["optimizer"])
207 |
208 | train_data = fabric._setup_dataloader(train_data)
209 | val_data = fabric._setup_dataloader(val_data)
210 | model, optimizer = fabric.setup(model, optimizer)
211 |
212 | validate(fabric, cfg, anchor_model, val_data, name=cfg.name, iters=0)
213 | train_sam(cfg, fabric, model, anchor_model, optimizer, scheduler, train_data, val_data, num_iters)
214 |
215 | del model, anchor_model, train_data, val_data
216 |
217 |
218 | if __name__ == "__main__":
219 | torch.cuda.empty_cache()
220 | torch.set_float32_matmul_precision('high')
221 | os.environ["CUDA_VISIBLE_DEVICES"] = cfg.gpu_ids
222 |
223 | main(cfg)
224 | # multi_main(cfg)
225 | torch.cuda.empty_cache()
226 |
--------------------------------------------------------------------------------
/asserts/Pipeline.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhang-haojie/wesam/5ac493ff4cbc52b2efbc2c530b71718a9e5b1ea1/asserts/Pipeline.webp
--------------------------------------------------------------------------------
/asserts/VISUAL.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhang-haojie/wesam/5ac493ff4cbc52b2efbc2c530b71718a9e5b1ea1/asserts/VISUAL.webp
--------------------------------------------------------------------------------
/asserts/teaser.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhang-haojie/wesam/5ac493ff4cbc52b2efbc2c530b71718a9e5b1ea1/asserts/teaser.webp
--------------------------------------------------------------------------------
/configs/base_config.py:
--------------------------------------------------------------------------------
1 | base_config = {
2 | "eval_interval": 1,
3 | "ema_rate": 0.9999,
4 | "get_prompt": False,
5 | "split": True,
6 | "csv_keys": ["Name", "Prompt", "Mean IoU", "Mean F1", "iters", "loss"],
7 | "opt": {
8 | "learning_rate": 1e-4,
9 | "weight_decay": 1e-4,
10 | "decay_factor": 10,
11 | "steps": [60000, 86666],
12 | "warmup_steps": 250,
13 | },
14 | "corruptions": [
15 | "gaussian_noise",
16 | "shot_noise",
17 | "impulse_noise",
18 | "defocus_blur",
19 | "glass_blur",
20 | "motion_blur",
21 | "zoom_blur",
22 | "snow",
23 | "frost",
24 | "fog",
25 | "brightness",
26 | "contrast",
27 | "elastic_transform",
28 | "pixelate",
29 | "jpeg_compression",
30 | ],
31 | "model": {
32 | "type": "vit_b",
33 | "checkpoint": "./checkpoints/",
34 | "ckpt": "",
35 | "freeze": {
36 | "image_encoder": True,
37 | "prompt_encoder": True,
38 | "mask_decoder": True,
39 | },
40 | },
41 | "datasets": {
42 | "coco": {
43 | "root_dir": "./data/coco2017/val2017",
44 | "annotation_file": "./data/coco2017/annotations/instances_val2017.json",
45 | },
46 | "coconut": {
47 | "root_dir": "./data/coconut/val2017",
48 | "annotation_file": "./data/coconut/coconut_dataset/annotations/annotations/relabeled_instances_val.json",
49 | },
50 | "PascalVOC": {
51 | "root_dir": "./data/VOC2012/",
52 | },
53 | "sa": {
54 | "root_dir": "./data/SA-1B",
55 | },
56 | "Polyp":{
57 | "root_dir": "./data/polyp/Kvasir-SEG",
58 | "annotation_file": "./data/polyp/Kvasir-SEG/kavsir_bboxes.json"
59 | },
60 | "ISIC": {
61 | "root_dir": "./data/ISIC/",
62 | "train_list": "./data/ISIC/ISBI2016_ISIC_Part1_Training_GroundTruth.csv",
63 | "test_list": "./data/ISIC/ISBI2016_ISIC_Part1_Test_GroundTruth.csv"
64 | },
65 | "ISTD": {
66 | "train": "./data/ISTD/train/train_A",
67 | "test": "./data/ISTD/test/test_A",
68 | },
69 | "MSD": {
70 | "train": "./data/MSD/train/image",
71 | "test": "./data/MSD/test/image",
72 | },
73 | "GDD": {
74 | "train": "./data/GDD/train/image",
75 | "test": "./data/GDD/test/image",
76 | },
77 | "CAMO":{
78 | "GT": "./data/CAMO-V.1.0-CVIU2019/GT",
79 | "train": "./data/CAMO-V.1.0-CVIU2019/Images/Train",
80 | "test": "./data/CAMO-V.1.0-CVIU2019/Images/Test",
81 | },
82 | "COD10K":{
83 | "GT": "./data/COD10K-v2/Test/GT_Object",
84 | "test": "./data/COD10K-v2/Test/Image",
85 | },
86 | "robot": {
87 | "OCID": "./data/OCID-dataset",
88 | "OSD": "./data/OSD-0.2-depth"
89 | },
90 | },
91 | }
92 |
--------------------------------------------------------------------------------
/configs/config.py:
--------------------------------------------------------------------------------
1 | from box import Box
2 | from configs.base_config import base_config
3 |
4 |
5 | config = {
6 | "gpu_ids": "0,1,2,3",
7 | "batch_size": 1,
8 | "val_batchsize": 4,
9 | "num_workers": 4,
10 | "num_iters": 40000,
11 | "max_nums": 40,
12 | "num_points": 5,
13 | "eval_interval": 1,
14 | "dataset": "COCO",
15 | "prompt": "box",
16 | "out_dir": "output/benchmark/COCO",
17 | "name": "baseline",
18 | "augment": True,
19 | "corrupt": None,
20 | "visual": False,
21 | "opt": {
22 | "learning_rate": 1e-4,
23 | },
24 | "model": {
25 | "type": "vit_b",
26 | },
27 | }
28 |
29 | cfg = Box(base_config)
30 | cfg.merge_update(config)
31 |
--------------------------------------------------------------------------------
/datasets/CAMO.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import random
4 | import torch
5 | import numpy as np
6 | from PIL import Image
7 | from torch.utils.data import DataLoader
8 | from torch.utils.data import Dataset
9 | from skimage.draw import polygon2mask
10 | from datasets.tools import ResizeAndPad, soft_transform, collate_fn, collate_fn_, decode_mask
11 |
12 |
13 | class CAMODataset(Dataset):
14 | def __init__(self, cfg, image_root, gt_root, transform=None, if_self_training=False):
15 | self.cfg = cfg
16 | self.root_dir = image_root
17 | self.transform = transform
18 | images = [os.path.join(image_root, f) for f in os.listdir(image_root) if f.endswith('.jpg')]
19 | images = sorted(images)
20 |
21 | self.images = images
22 | self.gts = [os.path.join(gt_root, os.path.basename(image_path.replace(".jpg", ".png"))) for image_path in self.images]
23 | self.filter_files()
24 |
25 | self.if_self_training = if_self_training
26 |
27 | def filter_files(self):
28 | assert len(self.images) == len(self.gts)
29 | images = []
30 | gts = []
31 | for img_path, gt_path in zip(self.images, self.gts):
32 | img = Image.open(img_path)
33 | gt = Image.open(gt_path)
34 | if img.size == gt.size:
35 | images.append(img_path)
36 | gts.append(gt_path)
37 | self.images = images
38 | self.gts = gts
39 |
40 | def rgb_loader(self, path):
41 | with open(path, 'rb') as f:
42 | img = Image.open(f)
43 | return img.convert('RGB')
44 |
45 | def binary_loader(self, path):
46 | with open(path, 'rb') as f:
47 | img = Image.open(f)
48 | return img.convert('L')
49 |
50 | def __len__(self):
51 | return len(self.images)
52 |
53 | def __getitem__(self, idx):
54 | image = np.array(self.rgb_loader(self.images[idx]))
55 | gt_mask = np.array(self.binary_loader(self.gts[idx]))
56 |
57 | # mask = gt_mask.astype(bool).astype(np.uint8)
58 | if self.cfg.get_prompt:
59 | image_info = {}
60 | height, width, _ = image.shape
61 | image_info["file_path"] = self.images[idx]
62 | image_info["height"] = height
63 | image_info["width"] = width
64 | return idx, image_info, image
65 |
66 | bboxes = []
67 | masks = []
68 | categories = []
69 | gt_masks = decode_mask(torch.tensor(gt_mask[None, :, :])).numpy().astype(np.uint8)
70 | assert gt_masks.sum() == (gt_mask > 0).sum()
71 | for mask in gt_masks:
72 | if np.all(mask == 0):
73 | continue
74 | masks.append(mask)
75 | x, y, w, h = cv2.boundingRect(mask)
76 | bboxes.append([x, y, x + w, y + h])
77 | categories.append("0")
78 |
79 | if self.if_self_training:
80 | image_weak, bboxes_weak, masks_weak, image_strong = soft_transform(image, bboxes, masks, categories)
81 |
82 | if self.transform:
83 | image_weak, masks_weak, bboxes_weak = self.transform(image_weak, masks_weak, np.array(bboxes_weak))
84 | image_strong = self.transform.transform_image(image_strong)
85 |
86 | bboxes_weak = np.stack(bboxes_weak, axis=0)
87 | masks_weak = np.stack(masks_weak, axis=0)
88 | return image_weak, image_strong, torch.tensor(bboxes_weak), torch.tensor(masks_weak).float()
89 |
90 | elif self.cfg.visual:
91 | image_name = os.path.splitext(os.path.basename(self.images[idx]))[0]
92 | origin_image = image
93 | origin_bboxes = bboxes
94 | origin_masks = masks
95 | if self.transform:
96 | padding, image, masks, bboxes = self.transform(image, masks, np.array(bboxes), True)
97 |
98 | bboxes = np.stack(bboxes, axis=0)
99 | masks = np.stack(masks, axis=0)
100 | origin_bboxes = np.stack(origin_bboxes, axis=0)
101 | origin_masks = np.stack(origin_masks, axis=0)
102 | return image_name, padding, origin_image, origin_bboxes, origin_masks, image, torch.tensor(bboxes), torch.tensor(masks).float()
103 |
104 | else:
105 | if self.transform:
106 | image, masks, bboxes = self.transform(image, masks, np.array(bboxes))
107 |
108 | bboxes = np.stack(bboxes, axis=0)
109 | masks = np.stack(masks, axis=0)
110 | return image, torch.tensor(bboxes), torch.tensor(masks).float()
111 |
112 |
113 | class CAMODatasetwithCoarse(CAMODataset):
114 |
115 | def __getitem__(self, idx):
116 | image = np.array(self.rgb_loader(self.images[idx]))
117 | gt_mask = np.array(self.binary_loader(self.gts[idx]))
118 |
119 | bboxes = []
120 | masks = []
121 | coarse_masks = []
122 | categories = []
123 | approxes = []
124 |
125 | gt_masks = decode_mask(torch.tensor(gt_mask[None, :, :])).numpy().astype(np.uint8)
126 | assert gt_masks.sum() == (gt_mask > 0).sum()
127 | for mask in gt_masks:
128 | if np.all(mask == 0.):
129 | continue
130 | contours, hierarchy = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
131 | num_vertices = 0.05 * cv2.arcLength(contours[0], True)  # despite the name, this is the approxPolyDP epsilon (5% of the contour perimeter)
132 | num_vertices = num_vertices if num_vertices > 3 else 3
133 | approx = cv2.approxPolyDP(contours[0], num_vertices, True) # [x, y]
134 | approx = approx.squeeze(1)
135 |
136 | coordinates = np.array(approx)
137 | x_max, x_min = max(coordinates[:, 0]), min(coordinates[:, 0])
138 | y_max, y_min = max(coordinates[:, 1]), min(coordinates[:, 1])
139 | coarse_mask = polygon2mask(mask.shape, coordinates).astype(mask.dtype)
140 | if x_min == x_max or y_min == y_max:
141 | x, y, w, h = cv2.boundingRect(mask)
142 | bboxes.append([x, y, x + w, y + h])
143 | else:
144 | bboxes.append([x_min, y_min, x_max, y_max])
145 |
146 | masks.append(mask)
147 | coarse_masks.append(coarse_mask)
148 | approxes.append(approx)
149 | categories.append("0")
150 |
151 | if self.if_self_training:
152 | image_weak, bboxes_weak, masks_weak, image_strong = soft_transform(image, bboxes, masks, categories)
153 |
154 | if self.transform:
155 | image_weak, masks_weak, bboxes_weak = self.transform(image_weak, masks_weak, np.array(bboxes_weak))
156 | image_strong = self.transform.transform_image(image_strong)
157 |
158 | bboxes_weak = np.stack(bboxes_weak, axis=0)
159 | masks_weak = np.stack(masks_weak, axis=0)
160 | return image_weak, image_strong, torch.tensor(bboxes_weak), torch.tensor(masks_weak).float()
161 |
162 | elif self.cfg.visual:
163 | image_name = os.path.splitext(os.path.basename(self.images[idx]))[0]
164 | origin_image = image
165 | origin_approxes = approxes
166 | origin_masks = masks
167 | if self.transform:
168 | padding, image, masks, bboxes = self.transform(image, masks, np.array(bboxes), self.cfg.visual)
169 |
170 | bboxes = np.stack(bboxes, axis=0)
171 | masks = np.stack(masks, axis=0)
172 | # origin_approxes = np.stack(origin_approxes, axis=0)
173 | origin_masks = np.stack(origin_masks, axis=0)
174 | return image_name, padding, origin_image, origin_approxes, origin_masks, image, torch.tensor(bboxes), torch.tensor(masks).float()
175 |
176 | else:
177 | if self.transform:
178 | image, masks, bboxes = self.transform(image, masks, np.array(bboxes))
179 |
180 | bboxes = np.stack(bboxes, axis=0)
181 | masks = np.stack(masks, axis=0)
182 | return image, torch.tensor(bboxes), torch.tensor(masks).float()
183 |
184 |
185 | def load_datasets(cfg, img_size):
186 | transform = ResizeAndPad(img_size)
187 | val = CAMODataset(
188 | cfg,
189 | image_root=cfg.datasets.CAMO.test,
190 | gt_root=cfg.datasets.CAMO.GT,
191 | transform=transform,
192 | )
193 | train = CAMODataset(
194 | cfg,
195 | image_root=cfg.datasets.CAMO.train,
196 | gt_root=cfg.datasets.CAMO.GT,
197 | transform=transform,
198 | if_self_training=cfg.augment,
199 | )
200 | val_dataloader = DataLoader(
201 | val,
202 | batch_size=cfg.val_batchsize,
203 | shuffle=False,
204 | num_workers=cfg.num_workers,
205 | collate_fn=collate_fn,
206 | )
207 | train_dataloader = DataLoader(
208 | train,
209 | batch_size=cfg.batch_size,
210 | shuffle=True,
211 | num_workers=cfg.num_workers,
212 | collate_fn=collate_fn,
213 | )
214 | return train_dataloader, val_dataloader
215 |
216 |
217 | def load_datasets_coarse(cfg, img_size):
218 | transform = ResizeAndPad(img_size)
219 | val = CAMODatasetwithCoarse(
220 | cfg,
221 | image_root=cfg.datasets.CAMO.test,
222 | gt_root=cfg.datasets.CAMO.GT,
223 | transform=transform,
224 | )
225 | train = CAMODatasetwithCoarse(
226 | cfg,
227 | image_root=cfg.datasets.CAMO.train,
228 | gt_root=cfg.datasets.CAMO.GT,
229 | transform=transform,
230 | if_self_training=cfg.augment,
231 | )
232 | val_dataloader = DataLoader(
233 | val,
234 | batch_size=cfg.val_batchsize,
235 | shuffle=False,
236 | num_workers=cfg.num_workers,
237 | collate_fn=collate_fn,
238 | )
239 | train_dataloader = DataLoader(
240 | train,
241 | batch_size=cfg.batch_size,
242 | shuffle=True,
243 | num_workers=cfg.num_workers,
244 | collate_fn=collate_fn,
245 | )
246 | return train_dataloader, val_dataloader
247 |
248 |
249 | def load_datasets_visual(cfg, img_size):
250 | transform = ResizeAndPad(img_size)
251 | val = CAMODataset(
252 | cfg,
253 | image_root=cfg.datasets.CAMO.test,
254 | gt_root=cfg.datasets.CAMO.GT,
255 | transform=transform,
256 | )
257 | val_dataloader = DataLoader(
258 | val,
259 | batch_size=cfg.val_batchsize,
260 | shuffle=False,
261 | num_workers=cfg.num_workers,
262 | collate_fn=collate_fn_,
263 | )
264 | return val_dataloader
265 |
266 |
267 | def load_datasets_visual_coarse(cfg, img_size):
268 | transform = ResizeAndPad(img_size)
269 | val = CAMODatasetwithCoarse(
270 | cfg,
271 | image_root=cfg.datasets.CAMO.test,
272 | gt_root=cfg.datasets.CAMO.GT,
273 | transform=transform,
274 | )
275 | val_dataloader = DataLoader(
276 | val,
277 | batch_size=cfg.val_batchsize,
278 | shuffle=False,
279 | num_workers=cfg.num_workers,
280 | collate_fn=collate_fn_,
281 | )
282 | return val_dataloader
283 |
284 |
285 | def load_datasets_prompt(cfg, img_size):
286 | transform = ResizeAndPad(img_size)
287 | train = CAMODataset(
288 | cfg,
289 | image_root=cfg.datasets.CAMO.train,
290 | gt_root=cfg.datasets.CAMO.GT,
291 | transform=transform,
292 | if_self_training=cfg.augment,
293 | )
294 | train_dataloader = DataLoader(
295 | train,
296 | batch_size=cfg.batch_size,
297 | shuffle=True,
298 | num_workers=cfg.num_workers,
299 | collate_fn=collate_fn_,
300 | )
301 | return train_dataloader
--------------------------------------------------------------------------------
/datasets/GDD.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import random
4 | import glob
5 | import json
6 | import torch
7 | import numpy as np
8 | import pandas as pd
9 | from torch.utils.data import Dataset, DataLoader
10 | from skimage.draw import polygon2mask
11 |
12 | from datasets.tools import ResizeAndPad, soft_transform, collate_fn, collate_fn_, decode_mask
13 |
14 |
15 | class GDDDataset(Dataset):
16 | def __init__(self, cfg, root_dir, transform=None, if_self_training=False):
17 | self.cfg = cfg
18 | self.root_dir = root_dir
19 | self.transform = transform
20 |
21 | images = [os.path.join(root_dir, f) for f in os.listdir(root_dir)]
22 | images = sorted(images)
23 |
24 | self.images = images
25 | self.gts = [image_path.replace("image", "mask").replace("jpg", "png") for image_path in self.images]
26 |
27 | self.if_self_training = if_self_training
28 |
29 | def __len__(self):
30 | return len(self.images)
31 |
32 | def __getitem__(self, idx):
33 | image_path = self.images[idx]
34 | image = cv2.imread(image_path)
35 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
36 |
37 | if self.cfg.get_prompt:
38 | image_info = {}
39 | height, width, _ = image.shape
40 | image_info["file_path"] = image_path
41 | image_info["height"] = height
42 | image_info["width"] = width
43 | return idx, image_info, image
44 |
45 | gt_path = self.gts[idx]
46 | gt_mask = cv2.imread(gt_path, cv2.IMREAD_GRAYSCALE)
47 |
48 | masks = []
49 | bboxes = []
50 | categories = []
51 | gt_masks = decode_mask(torch.tensor(gt_mask[None, :, :])).numpy().astype(np.uint8)
52 | assert gt_masks.sum() == (gt_mask > 0).sum()
53 | for mask in gt_masks:
54 | masks.append(mask)
55 | x, y, w, h = cv2.boundingRect(mask)
56 | bboxes.append([x, y, x + w, y + h])
57 | categories.append("0")
58 |
59 | if self.if_self_training:
60 | image_weak, bboxes_weak, masks_weak, image_strong = soft_transform(image, bboxes, masks, categories)
61 |
62 | if self.transform:
63 | image_weak, masks_weak, bboxes_weak = self.transform(image_weak, masks_weak, np.array(bboxes_weak))
64 | image_strong = self.transform.transform_image(image_strong)
65 |
66 | bboxes_weak = np.stack(bboxes_weak, axis=0)
67 | masks_weak = np.stack(masks_weak, axis=0)
68 | return image_weak, image_strong, torch.tensor(bboxes_weak), torch.tensor(masks_weak).float()
69 |
70 | elif self.cfg.visual:
71 | image_name = os.path.splitext(os.path.basename(self.images[idx]))[0]
72 | origin_image = image
73 | origin_bboxes = bboxes
74 | origin_masks = masks
75 | if self.transform:
76 | padding, image, masks, bboxes = self.transform(image, masks, np.array(bboxes), True)
77 |
78 | bboxes = np.stack(bboxes, axis=0)
79 | masks = np.stack(masks, axis=0)
80 | origin_bboxes = np.stack(origin_bboxes, axis=0)
81 | origin_masks = np.stack(origin_masks, axis=0)
82 | return image_name, padding, origin_image, origin_bboxes, origin_masks, image, torch.tensor(bboxes), torch.tensor(masks).float()
83 |
84 | else:
85 | if self.transform:
86 | image, masks, bboxes = self.transform(image, masks, np.array(bboxes))
87 |
88 | bboxes = np.stack(bboxes, axis=0)
89 | masks = np.stack(masks, axis=0)
90 | return image, torch.tensor(bboxes), torch.tensor(masks).float()
91 |
92 |
93 | class GDDDatasetwithCoarse(GDDDataset):
94 |
95 | def __getitem__(self, idx):
96 | image_path = self.images[idx]
97 | image = cv2.imread(image_path)
98 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
99 |
100 | gt_path = self.gts[idx]
101 | gt_mask = cv2.imread(gt_path, cv2.IMREAD_GRAYSCALE)
102 |
103 | masks = []
104 | bboxes = []
105 | approxes = []
106 | categories = []
107 | gt_masks = decode_mask(torch.tensor(gt_mask[None, :, :])).numpy().astype(np.uint8)
108 | assert gt_masks.sum() == (gt_mask > 0).sum()
109 | for mask in gt_masks:
110 | contours, hierarchy = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
111 | num_vertices = 0.05 * cv2.arcLength(contours[0], True)
112 | num_vertices = num_vertices if num_vertices > 3 else 3
113 | approx = cv2.approxPolyDP(contours[0], num_vertices, True) # [x, y]
114 | approx = approx.squeeze(1)
115 |
116 | coordinates = np.array(approx)
117 | x_max, x_min = max(coordinates[:, 0]), min(coordinates[:, 0])
118 | y_max, y_min = max(coordinates[:, 1]), min(coordinates[:, 1])
119 | coarse_mask = polygon2mask(mask.shape, coordinates).astype(mask.dtype)
120 | if x_min == x_max or y_min == y_max:
121 | x, y, w, h = cv2.boundingRect(mask)
122 | bboxes.append([x, y, x + w, y + h])
123 | else:
124 | bboxes.append([x_min, y_min, x_max, y_max])
125 |
126 | masks.append(mask)
127 | categories.append("0")
128 | approxes.append(approx)
129 |
130 | if self.if_self_training:
131 | image_weak, bboxes_weak, masks_weak, image_strong = soft_transform(image, bboxes, masks, categories)
132 |
133 | if self.transform:
134 | image_weak, masks_weak, bboxes_weak = self.transform(image_weak, masks_weak, np.array(bboxes_weak))
135 | image_strong = self.transform.transform_image(image_strong)
136 |
137 | bboxes_weak = np.stack(bboxes_weak, axis=0)
138 | masks_weak = np.stack(masks_weak, axis=0)
139 | return image_weak, image_strong, torch.tensor(bboxes_weak), torch.tensor(masks_weak).float()
140 |
141 | elif self.cfg.visual:
142 | image_name = os.path.splitext(os.path.basename(self.images[idx]))[0]
143 | origin_image = image
144 | origin_approxes = approxes
145 | origin_masks = masks
146 | if self.transform:
147 | padding, image, masks, bboxes = self.transform(image, masks, np.array(bboxes), self.cfg.visual)
148 |
149 | bboxes = np.stack(bboxes, axis=0)
150 | masks = np.stack(masks, axis=0)
151 | origin_masks = np.stack(origin_masks, axis=0)
152 | return image_name, padding, origin_image, origin_approxes, origin_masks, image, torch.tensor(bboxes), torch.tensor(masks).float()
153 |
154 | else:
155 | if self.transform:
156 | image, masks, bboxes = self.transform(image, masks, np.array(bboxes))
157 |
158 | bboxes = np.stack(bboxes, axis=0)
159 | masks = np.stack(masks, axis=0)
160 | return image, torch.tensor(bboxes), torch.tensor(masks).float()
161 |
162 |
163 | def load_datasets(cfg, img_size):
164 | transform = ResizeAndPad(img_size)
165 | val = GDDDataset(
166 | cfg,
167 | root_dir=cfg.datasets.GDD.test,
168 | transform=transform,
169 | )
170 | train = GDDDataset(
171 | cfg,
172 | root_dir=cfg.datasets.GDD.train,
173 | transform=transform,
174 | if_self_training=cfg.augment,
175 | )
176 | val_dataloader = DataLoader(
177 | val,
178 | batch_size=cfg.val_batchsize,
179 | shuffle=False,
180 | num_workers=cfg.num_workers,
181 | collate_fn=collate_fn,
182 | )
183 | train_dataloader = DataLoader(
184 | train,
185 | batch_size=cfg.batch_size,
186 | shuffle=True,
187 | num_workers=cfg.num_workers,
188 | collate_fn=collate_fn,
189 | )
190 | return train_dataloader, val_dataloader
191 |
192 |
193 | def load_datasets_coarse(cfg, img_size):
194 | transform = ResizeAndPad(img_size)
195 | val = GDDDatasetwithCoarse(
196 | cfg,
197 | root_dir=cfg.datasets.GDD.test,
198 | transform=transform,
199 | )
200 | train = GDDDatasetwithCoarse(
201 | cfg,
202 | root_dir=cfg.datasets.GDD.train,
203 | transform=transform,
204 | if_self_training=cfg.augment,
205 | )
206 | val_dataloader = DataLoader(
207 | val,
208 | batch_size=cfg.val_batchsize,
209 | shuffle=False,
210 | num_workers=cfg.num_workers,
211 | collate_fn=collate_fn,
212 | )
213 | train_dataloader = DataLoader(
214 | train,
215 | batch_size=cfg.batch_size,
216 | shuffle=True,
217 | num_workers=cfg.num_workers,
218 | collate_fn=collate_fn,
219 | )
220 | return train_dataloader, val_dataloader
221 |
222 |
223 | def load_datasets_visual(cfg, img_size):
224 | transform = ResizeAndPad(img_size)
225 | val = GDDDataset(
226 | cfg,
227 | root_dir=cfg.datasets.GDD.test,
228 | transform=transform,
229 | )
230 | val_dataloader = DataLoader(
231 | val,
232 | batch_size=cfg.val_batchsize,
233 | shuffle=False,
234 | num_workers=cfg.num_workers,
235 | collate_fn=collate_fn_,
236 | )
237 | return val_dataloader
238 |
239 |
240 | def load_datasets_visual_coarse(cfg, img_size):
241 | transform = ResizeAndPad(img_size)
242 | val = GDDDatasetwithCoarse(
243 | cfg,
244 | root_dir=cfg.datasets.GDD.test,
245 | transform=transform,
246 | )
247 | val_dataloader = DataLoader(
248 | val,
249 | batch_size=cfg.val_batchsize,
250 | shuffle=False,
251 | num_workers=cfg.num_workers,
252 | collate_fn=collate_fn_,
253 | )
254 | return val_dataloader
255 |
256 |
257 | def load_datasets_prompt(cfg, img_size):
258 | transform = ResizeAndPad(img_size)
259 | train = GDDDataset(
260 | cfg,
261 | root_dir=cfg.datasets.GDD.train,
262 | transform=transform,
263 | if_self_training=cfg.augment,
264 | )
265 | train_dataloader = DataLoader(
266 | train,
267 | batch_size=cfg.batch_size,
268 | shuffle=True,
269 | num_workers=cfg.num_workers,
270 | collate_fn=collate_fn_,
271 | )
272 | return train_dataloader
273 |
--------------------------------------------------------------------------------
/datasets/ISIC.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import random
4 | import glob
5 | import json
6 | import torch
7 | import numpy as np
8 | import pandas as pd
9 | from torch.utils.data import Dataset, DataLoader
10 | from skimage.draw import polygon2mask
11 |
12 | from datasets.tools import ResizeAndPad, soft_transform, collate_fn, collate_fn_, decode_mask
13 |
14 |
15 | class ISICDataset(Dataset):
16 | def __init__(self, cfg, root_dir, list_file, transform=None, if_self_training=False):
17 | self.cfg = cfg
18 | df = pd.read_csv(os.path.join(list_file), encoding='gbk')
19 | self.name_list = df.iloc[:,1].tolist()
20 | self.label_list = df.iloc[:,2].tolist()
21 | self.root_dir = root_dir
22 | self.transform = transform
23 |
24 | self.if_self_training = if_self_training
25 |
26 | def __len__(self):
27 | return len(self.name_list)
28 |
29 | def __getitem__(self, idx):
30 | name = self.name_list[idx]
31 | image_path = os.path.join(self.root_dir, name)
32 | image = cv2.imread(image_path)
33 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
34 |
35 | if self.cfg.get_prompt:
36 | image_info = {}
37 | height, width, _ = image.shape
38 | image_info["file_path"] = image_path
39 | image_info["height"] = height
40 | image_info["width"] = width
41 | return idx, image_info, image
42 |
43 | label_name = self.label_list[idx]
44 | gt_path = os.path.join(self.root_dir, label_name)
45 | gt_mask = cv2.imread(gt_path, cv2.IMREAD_GRAYSCALE)
46 |
47 | masks = []
48 | bboxes = []
49 | categories = []
50 | gt_masks = decode_mask(torch.tensor(gt_mask[None, :, :])).numpy().astype(np.uint8)
51 | assert gt_masks.sum() == (gt_mask > 0).sum()
52 | for mask in gt_masks:
53 | masks.append(mask)
54 | x, y, w, h = cv2.boundingRect(mask)
55 | bboxes.append([x, y, x + w, y + h])
56 | categories.append("0")
57 |
58 | if self.if_self_training:
59 | image_weak, bboxes_weak, masks_weak, image_strong = soft_transform(image, bboxes, masks, categories)
60 |
61 | if self.transform:
62 | image_weak, masks_weak, bboxes_weak = self.transform(image_weak, masks_weak, np.array(bboxes_weak))
63 | image_strong = self.transform.transform_image(image_strong)
64 |
65 | bboxes_weak = np.stack(bboxes_weak, axis=0)
66 | masks_weak = np.stack(masks_weak, axis=0)
67 | return image_weak, image_strong, torch.tensor(bboxes_weak), torch.tensor(masks_weak).float()
68 |
69 | elif self.cfg.visual:
70 | file_name = os.path.splitext(os.path.basename(name))[0]
71 | origin_image = image
72 | origin_bboxes = bboxes
73 | origin_masks = masks
74 | if self.transform:
75 | padding, image, masks, bboxes = self.transform(image, masks, np.array(bboxes), True)
76 |
77 | bboxes = np.stack(bboxes, axis=0)
78 | masks = np.stack(masks, axis=0)
79 | origin_bboxes = np.stack(origin_bboxes, axis=0)
80 | origin_masks = np.stack(origin_masks, axis=0)
81 | return file_name, padding, origin_image, origin_bboxes, origin_masks, image, torch.tensor(bboxes), torch.tensor(masks).float()
82 |
83 | else:
84 | if self.transform:
85 | image, masks, bboxes = self.transform(image, masks, np.array(bboxes))
86 |
87 | bboxes = np.stack(bboxes, axis=0)
88 | masks = np.stack(masks, axis=0)
89 | return image, torch.tensor(bboxes), torch.tensor(masks).float()
90 |
91 |
92 | class ISICDatasetwithCoarse(ISICDataset):
93 |
94 | def __getitem__(self, idx):
95 | name = self.name_list[idx]
96 | image_path = os.path.join(self.root_dir, name)
97 | image = cv2.imread(image_path)
98 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
99 |
100 | label_name = self.label_list[idx]
101 | gt_path = os.path.join(self.root_dir, label_name)
102 | gt_mask = cv2.imread(gt_path, cv2.IMREAD_GRAYSCALE)
103 |
104 | masks = []
105 | bboxes = []
106 | approxes = []
107 | categories = []
108 | gt_masks = decode_mask(torch.tensor(gt_mask[None, :, :])).numpy().astype(np.uint8)
109 | assert gt_masks.sum() == (gt_mask > 0).sum()
110 | for mask in gt_masks:
111 | contours, hierarchy = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
112 | num_vertices = 0.05 * cv2.arcLength(contours[0], True)
113 | num_vertices = num_vertices if num_vertices > 3 else 3
114 | approx = cv2.approxPolyDP(contours[0], num_vertices, True) # [x, y]
115 | approx = approx.squeeze(1)
116 |
117 | coordinates = np.array(approx)
118 | x_max, x_min = max(coordinates[:, 0]), min(coordinates[:, 0])
119 | y_max, y_min = max(coordinates[:, 1]), min(coordinates[:, 1])
120 | coarse_mask = polygon2mask(mask.shape, coordinates).astype(mask.dtype)
121 | if x_min == x_max or y_min == y_max:
122 | x, y, w, h = cv2.boundingRect(mask)
123 | bboxes.append([x, y, x + w, y + h])
124 | else:
125 | bboxes.append([x_min, y_min, x_max, y_max])
126 |
127 | masks.append(mask)
128 | categories.append("0")
129 | approxes.append(approx)
130 |
131 | if self.if_self_training:
132 | image_weak, bboxes_weak, masks_weak, image_strong = soft_transform(image, bboxes, masks, categories)
133 |
134 | if self.transform:
135 | image_weak, masks_weak, bboxes_weak = self.transform(image_weak, masks_weak, np.array(bboxes_weak))
136 | image_strong = self.transform.transform_image(image_strong)
137 |
138 | bboxes_weak = np.stack(bboxes_weak, axis=0)
139 | masks_weak = np.stack(masks_weak, axis=0)
140 | return image_weak, image_strong, torch.tensor(bboxes_weak), torch.tensor(masks_weak).float()
141 |
142 | elif self.cfg.visual:
143 | file_name = os.path.splitext(os.path.basename(name))[0]
144 | origin_image = image
145 | origin_approxes = approxes
146 | origin_masks = masks
147 | if self.transform:
148 | padding, image, masks, bboxes = self.transform(image, masks, np.array(bboxes), self.cfg.visual)
149 |
150 | bboxes = np.stack(bboxes, axis=0)
151 | masks = np.stack(masks, axis=0)
152 | origin_masks = np.stack(origin_masks, axis=0)
153 | return file_name, padding, origin_image, origin_approxes, origin_masks, image, torch.tensor(bboxes), torch.tensor(masks).float()
154 |
155 | else:
156 | if self.transform:
157 | image, masks, bboxes = self.transform(image, masks, np.array(bboxes))
158 |
159 | bboxes = np.stack(bboxes, axis=0)
160 | masks = np.stack(masks, axis=0)
161 | return image, torch.tensor(bboxes), torch.tensor(masks).float()
162 |
163 |
164 | def load_datasets(cfg, img_size):
165 | transform = ResizeAndPad(img_size)
166 | val = ISICDataset(
167 | cfg,
168 | root_dir=cfg.datasets.ISIC.root_dir,
169 | list_file=cfg.datasets.ISIC.test_list,
170 | transform=transform,
171 | )
172 | train = ISICDataset(
173 | cfg,
174 | root_dir=cfg.datasets.ISIC.root_dir,
175 | list_file=cfg.datasets.ISIC.train_list,
176 | transform=transform,
177 | if_self_training=cfg.augment,
178 | )
179 | val_dataloader = DataLoader(
180 | val,
181 | batch_size=cfg.val_batchsize,
182 | shuffle=False,
183 | num_workers=cfg.num_workers,
184 | collate_fn=collate_fn,
185 | )
186 | train_dataloader = DataLoader(
187 | train,
188 | batch_size=cfg.batch_size,
189 | shuffle=True,
190 | num_workers=cfg.num_workers,
191 | collate_fn=collate_fn,
192 | )
193 | return train_dataloader, val_dataloader
194 |
195 |
196 | def load_datasets_coarse(cfg, img_size):
197 | transform = ResizeAndPad(img_size)
198 | val = ISICDatasetwithCoarse(
199 | cfg,
200 | root_dir=cfg.datasets.ISIC.root_dir,
201 | list_file=cfg.datasets.ISIC.test_list,
202 | transform=transform,
203 | )
204 | train = ISICDatasetwithCoarse(
205 | cfg,
206 | root_dir=cfg.datasets.ISIC.root_dir,
207 | list_file=cfg.datasets.ISIC.train_list,
208 | transform=transform,
209 | if_self_training=cfg.augment,
210 | )
211 | val_dataloader = DataLoader(
212 | val,
213 | batch_size=cfg.val_batchsize,
214 | shuffle=False,
215 | num_workers=cfg.num_workers,
216 | collate_fn=collate_fn,
217 | )
218 | train_dataloader = DataLoader(
219 | train,
220 | batch_size=cfg.batch_size,
221 | shuffle=True,
222 | num_workers=cfg.num_workers,
223 | collate_fn=collate_fn,
224 | )
225 | return train_dataloader, val_dataloader
226 |
227 |
228 | def load_datasets_visual(cfg, img_size):
229 | transform = ResizeAndPad(img_size)
230 | val = ISICDataset(
231 | cfg,
232 | root_dir=cfg.datasets.ISIC.root_dir,
233 | list_file=cfg.datasets.ISIC.test_list,
234 | transform=transform,
235 | )
236 | val_dataloader = DataLoader(
237 | val,
238 | batch_size=cfg.val_batchsize,
239 | shuffle=False,
240 | num_workers=cfg.num_workers,
241 | collate_fn=collate_fn_,
242 | )
243 | return val_dataloader
244 |
245 |
246 | def load_datasets_visual_coarse(cfg, img_size):
247 | transform = ResizeAndPad(img_size)
248 | val = ISICDatasetwithCoarse(
249 | cfg,
250 | root_dir=cfg.datasets.ISIC.root_dir,
251 | list_file=cfg.datasets.ISIC.test_list,
252 | transform=transform,
253 | )
254 | val_dataloader = DataLoader(
255 | val,
256 | batch_size=cfg.val_batchsize,
257 | shuffle=False,
258 | num_workers=cfg.num_workers,
259 | collate_fn=collate_fn_,
260 | )
261 | return val_dataloader
262 |
263 |
264 | def load_datasets_prompt(cfg, img_size):
265 | transform = ResizeAndPad(img_size)
266 | train = ISICDataset(
267 | cfg,
268 | root_dir=cfg.datasets.ISIC.root_dir,
269 | list_file=cfg.datasets.ISIC.train_list,
270 | transform=transform,
271 | if_self_training=cfg.augment,
272 | )
273 | train_dataloader = DataLoader(
274 | train,
275 | batch_size=cfg.batch_size,
276 | shuffle=True,
277 | num_workers=cfg.num_workers,
278 | collate_fn=collate_fn_,
279 | )
280 | return train_dataloader
281 |
--------------------------------------------------------------------------------
/datasets/ISTD.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import random
4 | import torch
5 | import numpy as np
6 | from PIL import Image
7 | from torch.utils.data import DataLoader
8 | from torch.utils.data import Dataset
9 | from skimage.draw import polygon2mask
10 | from datasets.tools import ResizeAndPad, soft_transform, collate_fn, collate_fn_, decode_mask
11 |
12 |
13 | class ISTDDataset(Dataset):
14 | def __init__(self, cfg, image_root, transform=None, if_self_training=False):
15 | self.cfg = cfg
16 | self.root_dir = image_root
17 | self.transform = transform
18 | images = [os.path.join(image_root, f) for f in os.listdir(image_root)]
19 | images = sorted(images)
20 |
21 | self.images = images
22 | self.gts = [image_path.replace("A", "B") for image_path in self.images]
23 |
24 | self.if_self_training = if_self_training
25 |
26 | def rgb_loader(self, path):
27 | with open(path, 'rb') as f:
28 | img = Image.open(f)
29 | return img.convert('RGB')
30 |
31 | def binary_loader(self, path):
32 | with open(path, 'rb') as f:
33 | img = Image.open(f)
34 | return img.convert('L')
35 |
36 | def __len__(self):
37 | return len(self.images)
38 |
39 | def __getitem__(self, idx):
40 | image = np.array(self.rgb_loader(self.images[idx]))
41 | gt_mask = np.array(self.binary_loader(self.gts[idx]))
42 |
43 | if self.cfg.get_prompt:
44 | image_info = {}
45 | height, width, _ = image.shape
46 | image_info["file_path"] = self.images[idx]
47 | image_info["height"] = height
48 | image_info["width"] = width
49 | return idx, image_info, image
50 |
51 | bboxes = []
52 | masks = []
53 | categories = []
54 | gt_masks = decode_mask(torch.tensor(gt_mask[None, :, :])).numpy().astype(np.uint8)
55 | assert gt_masks.sum() == (gt_mask > 0).sum()
56 | for mask in gt_masks:
57 | if np.all(mask == 0):
58 | continue
59 | masks.append(mask)
60 | x, y, w, h = cv2.boundingRect(mask)
61 | bboxes.append([x, y, x + w, y + h])
62 | categories.append("0")
63 |
64 | if self.if_self_training:
65 | image_weak, bboxes_weak, masks_weak, image_strong = soft_transform(image, bboxes, masks, categories)
66 |
67 | if self.transform:
68 | image_weak, masks_weak, bboxes_weak = self.transform(image_weak, masks_weak, np.array(bboxes_weak))
69 | image_strong = self.transform.transform_image(image_strong)
70 |
71 | bboxes_weak = np.stack(bboxes_weak, axis=0)
72 | masks_weak = np.stack(masks_weak, axis=0)
73 | return image_weak, image_strong, torch.tensor(bboxes_weak), torch.tensor(masks_weak).float()
74 |
75 | elif self.cfg.visual:
76 | image_name = os.path.splitext(os.path.basename(self.images[idx]))[0]
77 | origin_image = image
78 | origin_bboxes = bboxes
79 | origin_masks = masks
80 | if self.transform:
81 | padding, image, masks, bboxes = self.transform(image, masks, np.array(bboxes), True)
82 |
83 | bboxes = np.stack(bboxes, axis=0)
84 | masks = np.stack(masks, axis=0)
85 | origin_bboxes = np.stack(origin_bboxes, axis=0)
86 | origin_masks = np.stack(origin_masks, axis=0)
87 | return image_name, padding, origin_image, origin_bboxes, origin_masks, image, torch.tensor(bboxes), torch.tensor(masks).float()
88 |
89 | else:
90 | if self.transform:
91 | image, masks, bboxes = self.transform(image, masks, np.array(bboxes))
92 |
93 | bboxes = np.stack(bboxes, axis=0)
94 | masks = np.stack(masks, axis=0)
95 | return image, torch.tensor(bboxes), torch.tensor(masks).float()
96 |
97 |
98 | class ISTDDatasetwithCoarse(ISTDDataset):
99 |
100 | def __getitem__(self, idx):
101 |         image = np.array(self.rgb_loader(self.images[idx]))
102 |         gt_mask = np.array(self.binary_loader(self.gts[idx]))
103 |
104 | bboxes = []
105 | masks = []
106 | coarse_masks = []
107 | categories = []
108 | approxes = []
109 |
110 | gt_masks = decode_mask(torch.tensor(gt_mask[None, :, :])).numpy().astype(np.uint8)
111 | assert gt_masks.sum() == (gt_mask > 0).sum()
112 | for mask in gt_masks:
113 | if np.all(mask == 0.):
114 | continue
115 | contours, hierarchy = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
116 | num_vertices = 0.05 * cv2.arcLength(contours[0], True)
117 | num_vertices = num_vertices if num_vertices > 3 else 3
118 | approx = cv2.approxPolyDP(contours[0], num_vertices, True) # [x, y]
119 | approx = approx.squeeze(1)
120 |
121 | coordinates = np.array(approx)
122 | x_max, x_min = max(coordinates[:, 0]), min(coordinates[:, 0])
123 | y_max, y_min = max(coordinates[:, 1]), min(coordinates[:, 1])
124 | coarse_mask = polygon2mask(mask.shape, coordinates).astype(mask.dtype)
125 | if x_min == x_max or y_min == y_max:
126 | x, y, w, h = cv2.boundingRect(mask)
127 | bboxes.append([x, y, x + w, y + h])
128 | else:
129 | bboxes.append([x_min, y_min, x_max, y_max])
130 |
131 | masks.append(mask)
132 | coarse_masks.append(coarse_mask)
133 | approxes.append(approx)
134 | categories.append("0")
135 |
136 | if self.if_self_training:
137 | image_weak, bboxes_weak, masks_weak, image_strong = soft_transform(image, bboxes, masks, categories)
138 |
139 | if self.transform:
140 | image_weak, masks_weak, bboxes_weak = self.transform(image_weak, masks_weak, np.array(bboxes_weak))
141 | image_strong = self.transform.transform_image(image_strong)
142 |
143 | bboxes_weak = np.stack(bboxes_weak, axis=0)
144 | masks_weak = np.stack(masks_weak, axis=0)
145 | return image_weak, image_strong, torch.tensor(bboxes_weak), torch.tensor(masks_weak).float()
146 |
147 | elif self.cfg.visual:
148 | image_name = os.path.splitext(os.path.basename(self.images[idx]))[0]
149 | origin_image = image
150 | origin_approxes = approxes
151 | origin_masks = masks
152 | if self.transform:
153 | padding, image, masks, bboxes = self.transform(image, masks, np.array(bboxes), self.cfg.visual)
154 |
155 | bboxes = np.stack(bboxes, axis=0)
156 | masks = np.stack(masks, axis=0)
157 | origin_masks = np.stack(origin_masks, axis=0)
158 | return image_name, padding, origin_image, origin_approxes, origin_masks, image, torch.tensor(bboxes), torch.tensor(masks).float()
159 |
160 | else:
161 | if self.transform:
162 | image, masks, bboxes = self.transform(image, masks, np.array(bboxes))
163 |
164 | bboxes = np.stack(bboxes, axis=0)
165 | masks = np.stack(masks, axis=0)
166 | return image, torch.tensor(bboxes), torch.tensor(masks).float()
167 |
168 |
169 | def load_datasets(cfg, img_size):
170 | transform = ResizeAndPad(img_size)
171 | val = ISTDDataset(
172 | cfg,
173 | image_root=cfg.datasets.ISTD.test,
174 | transform=transform,
175 | )
176 | train = ISTDDataset(
177 | cfg,
178 | image_root=cfg.datasets.ISTD.train,
179 | transform=transform,
180 | if_self_training=cfg.augment,
181 | )
182 | val_dataloader = DataLoader(
183 | val,
184 | batch_size=cfg.val_batchsize,
185 | shuffle=False,
186 | num_workers=cfg.num_workers,
187 | collate_fn=collate_fn,
188 | )
189 | train_dataloader = DataLoader(
190 | train,
191 | batch_size=cfg.batch_size,
192 | shuffle=True,
193 | num_workers=cfg.num_workers,
194 | collate_fn=collate_fn,
195 | )
196 | return train_dataloader, val_dataloader
197 |
198 |
199 | def load_datasets_coarse(cfg, img_size):
200 | transform = ResizeAndPad(img_size)
201 | val = ISTDDatasetwithCoarse(
202 | cfg,
203 | image_root=cfg.datasets.ISTD.test,
204 | transform=transform,
205 | )
206 | train = ISTDDatasetwithCoarse(
207 | cfg,
208 | image_root=cfg.datasets.ISTD.train,
209 | transform=transform,
210 | )
211 | val_dataloader = DataLoader(
212 | val,
213 | batch_size=cfg.val_batchsize,
214 | shuffle=False,
215 | num_workers=cfg.num_workers,
216 | collate_fn=collate_fn,
217 | )
218 | train_dataloader = DataLoader(
219 | train,
220 | batch_size=cfg.batch_size,
221 | shuffle=True,
222 | num_workers=cfg.num_workers,
223 | collate_fn=collate_fn,
224 | )
225 | return train_dataloader, val_dataloader
226 |
227 |
228 | def load_datasets_visual(cfg, img_size):
229 | transform = ResizeAndPad(img_size)
230 | val = ISTDDataset(
231 | cfg,
232 | image_root=cfg.datasets.ISTD.test,
233 | transform=transform,
234 | )
235 | val_dataloader = DataLoader(
236 | val,
237 | batch_size=cfg.val_batchsize,
238 | shuffle=False,
239 | num_workers=cfg.num_workers,
240 | collate_fn=collate_fn_,
241 | )
242 | return val_dataloader
243 |
244 |
245 | def load_datasets_visual_coarse(cfg, img_size):
246 | transform = ResizeAndPad(img_size)
247 | val = ISTDDatasetwithCoarse(
248 | cfg,
249 | image_root=cfg.datasets.ISTD.test,
250 | transform=transform,
251 | )
252 | val_dataloader = DataLoader(
253 | val,
254 | batch_size=cfg.val_batchsize,
255 | shuffle=False,
256 | num_workers=cfg.num_workers,
257 | collate_fn=collate_fn_,
258 | )
259 | return val_dataloader
260 |
261 |
262 | def load_datasets_prompt(cfg, img_size):
263 | transform = ResizeAndPad(img_size)
264 | train = ISTDDataset(
265 | cfg,
266 | image_root=cfg.datasets.ISTD.train,
267 | transform=transform,
268 | if_self_training=cfg.augment,
269 | )
270 | train_dataloader = DataLoader(
271 | train,
272 | batch_size=cfg.batch_size,
273 | shuffle=True,
274 | num_workers=cfg.num_workers,
275 | collate_fn=collate_fn_,
276 | )
277 | return train_dataloader
278 |
--------------------------------------------------------------------------------
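# A minimal sketch of the coarse-prompt construction repeated in the *withCoarse
# datasets above (ISTDDatasetwithCoarse here, and the MSD/OSD/PascalVOC variants
# below): the first contour of a binary instance mask is simplified with
# cv2.approxPolyDP -- the 0.05 * arcLength value acts as the approximation
# epsilon -- and the simplified polygon's extent becomes the box prompt, falling
# back to the exact boundingRect when the polygon collapses to a line.
# coarse_prompt_box is a hypothetical helper, not part of the repository.
import cv2
import numpy as np

def coarse_prompt_box(mask: np.ndarray) -> list:
    """mask: binary uint8 [H, W]; returns a coarse [x_min, y_min, x_max, y_max] box."""
    contours, _ = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
    epsilon = max(0.05 * cv2.arcLength(contours[0], True), 3)
    approx = cv2.approxPolyDP(contours[0], epsilon, True).squeeze(1)  # (N, 2) points as (x, y)
    x_min, x_max = approx[:, 0].min(), approx[:, 0].max()
    y_min, y_max = approx[:, 1].min(), approx[:, 1].max()
    if x_min == x_max or y_min == y_max:  # degenerate polygon: fall back to the tight box
        x, y, w, h = cv2.boundingRect(mask)
        return [x, y, x + w, y + h]
    return [int(x_min), int(y_min), int(x_max), int(y_max)]

# Example on a filled 64x64 circle:
demo = np.zeros((64, 64), dtype=np.uint8)
cv2.circle(demo, (32, 32), 20, 1, -1)
print(coarse_prompt_box(demo))  # roughly [12, 12, 52, 52]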
/datasets/MSD.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import random
4 | import glob
5 | import json
6 | import torch
7 | import numpy as np
8 | import pandas as pd
9 | from torch.utils.data import Dataset, DataLoader
10 | from skimage.draw import polygon2mask
11 |
12 | from datasets.tools import ResizeAndPad, soft_transform, collate_fn, collate_fn_, decode_mask
13 |
14 |
15 | class MSDDataset(Dataset):
16 | def __init__(self, cfg, root_dir, transform=None, if_self_training=False):
17 | self.cfg = cfg
18 | self.root_dir = root_dir
19 | self.transform = transform
20 |
21 | images = [os.path.join(root_dir, f) for f in os.listdir(root_dir)]
22 | images = sorted(images)
23 |
24 | self.images = images
25 | self.gts = [image_path.replace("image", "mask").replace("jpg", "png") for image_path in self.images]
26 |
27 | self.if_self_training = if_self_training
28 |
29 | def __len__(self):
30 | return len(self.images)
31 |
32 | def __getitem__(self, idx):
33 | image_path = self.images[idx]
34 | image = cv2.imread(image_path)
35 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
36 |
37 | if self.cfg.get_prompt:
38 | image_info = {}
39 | height, width, _ = image.shape
40 | image_info["file_path"] = image_path
41 | image_info["height"] = height
42 | image_info["width"] = width
43 | return idx, image_info, image
44 |
45 | gt_path = self.gts[idx]
46 | gt_mask = cv2.imread(gt_path, cv2.IMREAD_GRAYSCALE)
47 |
48 | masks = []
49 | bboxes = []
50 | categories = []
51 | gt_masks = decode_mask(torch.tensor(gt_mask[None, :, :])).numpy().astype(np.uint8)
52 | assert gt_masks.sum() == (gt_mask > 0).sum()
53 | for mask in gt_masks:
54 | masks.append(mask)
55 | x, y, w, h = cv2.boundingRect(mask)
56 | bboxes.append([x, y, x + w, y + h])
57 | categories.append("0")
58 |
59 | if self.if_self_training:
60 | image_weak, bboxes_weak, masks_weak, image_strong = soft_transform(image, bboxes, masks, categories)
61 |
62 | if self.transform:
63 | image_weak, masks_weak, bboxes_weak = self.transform(image_weak, masks_weak, np.array(bboxes_weak))
64 | image_strong = self.transform.transform_image(image_strong)
65 |
66 | bboxes_weak = np.stack(bboxes_weak, axis=0)
67 | masks_weak = np.stack(masks_weak, axis=0)
68 | return image_weak, image_strong, torch.tensor(bboxes_weak), torch.tensor(masks_weak).float()
69 |
70 | elif self.cfg.visual:
71 | image_name = os.path.splitext(os.path.basename(self.images[idx]))[0]
72 | origin_image = image
73 | origin_bboxes = bboxes
74 | origin_masks = masks
75 | if self.transform:
76 | padding, image, masks, bboxes = self.transform(image, masks, np.array(bboxes), True)
77 |
78 | bboxes = np.stack(bboxes, axis=0)
79 | masks = np.stack(masks, axis=0)
80 | origin_bboxes = np.stack(origin_bboxes, axis=0)
81 | origin_masks = np.stack(origin_masks, axis=0)
82 | return image_name, padding, origin_image, origin_bboxes, origin_masks, image, torch.tensor(bboxes), torch.tensor(masks).float()
83 |
84 | else:
85 | if self.transform:
86 | image, masks, bboxes = self.transform(image, masks, np.array(bboxes))
87 |
88 | bboxes = np.stack(bboxes, axis=0)
89 | masks = np.stack(masks, axis=0)
90 | return image, torch.tensor(bboxes), torch.tensor(masks).float()
91 |
92 |
93 | class MSDDatasetwithCoarse(MSDDataset):
94 |
95 | def __getitem__(self, idx):
96 | image_path = self.images[idx]
97 | image = cv2.imread(image_path)
98 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
99 |
100 | gt_path = self.gts[idx]
101 | gt_mask = cv2.imread(gt_path, cv2.IMREAD_GRAYSCALE)
102 |
103 | masks = []
104 | bboxes = []
105 | approxes = []
106 | categories = []
107 | gt_masks = decode_mask(torch.tensor(gt_mask[None, :, :])).numpy().astype(np.uint8)
108 | assert gt_masks.sum() == (gt_mask > 0).sum()
109 | for mask in gt_masks:
110 | contours, hierarchy = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
111 | num_vertices = 0.05 * cv2.arcLength(contours[0], True)
112 | num_vertices = num_vertices if num_vertices > 3 else 3
113 | approx = cv2.approxPolyDP(contours[0], num_vertices, True) # [x, y]
114 | approx = approx.squeeze(1)
115 |
116 | coordinates = np.array(approx)
117 | x_max, x_min = max(coordinates[:, 0]), min(coordinates[:, 0])
118 | y_max, y_min = max(coordinates[:, 1]), min(coordinates[:, 1])
119 | coarse_mask = polygon2mask(mask.shape, coordinates).astype(mask.dtype)
120 | if x_min == x_max or y_min == y_max:
121 | x, y, w, h = cv2.boundingRect(mask)
122 | bboxes.append([x, y, x + w, y + h])
123 | else:
124 | bboxes.append([x_min, y_min, x_max, y_max])
125 |
126 | masks.append(mask)
127 | categories.append("0")
128 | approxes.append(approx)
129 |
130 | if self.if_self_training:
131 | image_weak, bboxes_weak, masks_weak, image_strong = soft_transform(image, bboxes, masks, categories)
132 |
133 | if self.transform:
134 | image_weak, masks_weak, bboxes_weak = self.transform(image_weak, masks_weak, np.array(bboxes_weak))
135 | image_strong = self.transform.transform_image(image_strong)
136 |
137 | bboxes_weak = np.stack(bboxes_weak, axis=0)
138 | masks_weak = np.stack(masks_weak, axis=0)
139 | return image_weak, image_strong, torch.tensor(bboxes_weak), torch.tensor(masks_weak).float()
140 |
141 | elif self.cfg.visual:
142 | image_name = os.path.splitext(os.path.basename(self.images[idx]))[0]
143 | origin_image = image
144 | origin_approxes = approxes
145 | origin_masks = masks
146 | if self.transform:
147 | padding, image, masks, bboxes = self.transform(image, masks, np.array(bboxes), self.cfg.visual)
148 |
149 | bboxes = np.stack(bboxes, axis=0)
150 | masks = np.stack(masks, axis=0)
151 | origin_masks = np.stack(origin_masks, axis=0)
152 | return image_name, padding, origin_image, origin_approxes, origin_masks, image, torch.tensor(bboxes), torch.tensor(masks).float()
153 |
154 | else:
155 | if self.transform:
156 | image, masks, bboxes = self.transform(image, masks, np.array(bboxes))
157 |
158 | bboxes = np.stack(bboxes, axis=0)
159 | masks = np.stack(masks, axis=0)
160 | return image, torch.tensor(bboxes), torch.tensor(masks).float()
161 |
162 |
163 | def load_datasets(cfg, img_size):
164 | transform = ResizeAndPad(img_size)
165 | val = MSDDataset(
166 | cfg,
167 | root_dir=cfg.datasets.MSD.test,
168 | transform=transform,
169 | )
170 | train = MSDDataset(
171 | cfg,
172 | root_dir=cfg.datasets.MSD.train,
173 | transform=transform,
174 | if_self_training=cfg.augment,
175 | )
176 | val_dataloader = DataLoader(
177 | val,
178 | batch_size=cfg.val_batchsize,
179 | shuffle=False,
180 | num_workers=cfg.num_workers,
181 | collate_fn=collate_fn,
182 | )
183 | train_dataloader = DataLoader(
184 | train,
185 | batch_size=cfg.batch_size,
186 | shuffle=True,
187 | num_workers=cfg.num_workers,
188 | collate_fn=collate_fn,
189 | )
190 | return train_dataloader, val_dataloader
191 |
192 |
193 | def load_datasets_coarse(cfg, img_size):
194 | transform = ResizeAndPad(img_size)
195 | val = MSDDatasetwithCoarse(
196 | cfg,
197 | root_dir=cfg.datasets.MSD.test,
198 | transform=transform,
199 | )
200 | train = MSDDatasetwithCoarse(
201 | cfg,
202 | root_dir=cfg.datasets.MSD.train,
203 | transform=transform,
204 | if_self_training=cfg.augment,
205 | )
206 | val_dataloader = DataLoader(
207 | val,
208 | batch_size=cfg.val_batchsize,
209 | shuffle=False,
210 | num_workers=cfg.num_workers,
211 | collate_fn=collate_fn,
212 | )
213 | train_dataloader = DataLoader(
214 | train,
215 | batch_size=cfg.batch_size,
216 | shuffle=True,
217 | num_workers=cfg.num_workers,
218 | collate_fn=collate_fn,
219 | )
220 | return train_dataloader, val_dataloader
221 |
222 |
223 | def load_datasets_visual(cfg, img_size):
224 | transform = ResizeAndPad(img_size)
225 | val = MSDDataset(
226 | cfg,
227 | root_dir=cfg.datasets.MSD.test,
228 | transform=transform,
229 | )
230 | val_dataloader = DataLoader(
231 | val,
232 | batch_size=cfg.val_batchsize,
233 | shuffle=False,
234 | num_workers=cfg.num_workers,
235 | collate_fn=collate_fn_,
236 | )
237 | return val_dataloader
238 |
239 |
240 | def load_datasets_visual_coarse(cfg, img_size):
241 | transform = ResizeAndPad(img_size)
242 | val = MSDDatasetwithCoarse(
243 | cfg,
244 | root_dir=cfg.datasets.MSD.test,
245 | transform=transform,
246 | )
247 | val_dataloader = DataLoader(
248 | val,
249 | batch_size=cfg.val_batchsize,
250 | shuffle=False,
251 | num_workers=cfg.num_workers,
252 | collate_fn=collate_fn_,
253 | )
254 | return val_dataloader
255 |
256 |
257 | def load_datasets_prompt(cfg, img_size):
258 | transform = ResizeAndPad(img_size)
259 | train = MSDDataset(
260 | cfg,
261 | root_dir=cfg.datasets.MSD.train,
262 | transform=transform,
263 | if_self_training=cfg.augment,
264 | )
265 | train_dataloader = DataLoader(
266 | train,
267 | batch_size=cfg.batch_size,
268 | shuffle=True,
269 | num_workers=cfg.num_workers,
270 | collate_fn=collate_fn_,
271 | )
272 | return train_dataloader
273 |
--------------------------------------------------------------------------------
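# A self-contained illustration of the label-map -> per-instance split that every
# dataset above performs via datasets.tools.decode_mask, together with the
# invariant each __getitem__ asserts afterwards: the stacked binary masks cover
# exactly the non-zero pixels of the original label map.
# split_instances is a hypothetical numpy stand-in used only for this sketch.
import numpy as np

def split_instances(label_map: np.ndarray) -> np.ndarray:
    """[H, W] label map with values 0, 1, 2, ... -> [N, H, W] binary instance masks."""
    labels = np.unique(label_map)
    labels = labels[labels != 0]
    return np.stack([(label_map == k).astype(np.uint8) for k in labels], axis=0)

gt_mask = np.zeros((4, 6), dtype=np.uint8)
gt_mask[0:2, 0:2] = 1          # instance 1
gt_mask[2:4, 3:6] = 2          # instance 2
gt_masks = split_instances(gt_mask)
print(gt_masks.shape)                          # (2, 4, 6)
assert gt_masks.sum() == (gt_mask > 0).sum()   # the invariant asserted in every dataset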
/datasets/OSD.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import random
4 | import glob
5 | import numpy as np
6 | import torch
7 | from torch.utils.data import DataLoader
8 | from torch.utils.data import Dataset
9 | from skimage.draw import polygon2mask
10 | from pathlib import Path
11 | from PIL import Image
12 | from datasets.tools import ResizeAndPad, soft_transform, collate_fn, decode_mask, collate_fn_
13 |
14 |
15 | class OSDObject(Dataset):
16 | def __init__(self, cfg, root_dir, transform=None, split=False, training=False, if_self_training=False):
17 | self.cfg = cfg
18 | self._osd_object_path = root_dir
19 | self.transform = transform
20 | # get all images
21 | data_path = os.path.join(self._osd_object_path, 'image_color')
22 | all_image_paths = sorted(glob.glob(data_path + '/*.png'))
23 | all_image_paths = self.check_empty(all_image_paths)
24 |
25 | if split:
26 | train_image_paths = []
27 | eval_image_paths = []
28 | while all_image_paths:
29 | for _ in range(6):
30 | if all_image_paths:
31 | train_image_paths.append(all_image_paths.pop(0))
32 | if all_image_paths:
33 | eval_image_paths.append(all_image_paths.pop(0))
34 |
35 | if training:
36 | random.shuffle(train_image_paths)
37 | image_paths = train_image_paths
38 | else:
39 | random.shuffle(eval_image_paths)
40 | image_paths = eval_image_paths
41 | else:
42 | image_paths = all_image_paths
43 |
44 | self.image_files = image_paths
45 | # self.image_files = all_image_paths
46 |
47 | self.if_self_training = if_self_training
48 | assert os.path.exists(self._osd_object_path), \
49 | 'osd_object path does not exist: {}'.format(self._osd_object_path)
50 |
51 | def process_label(self, foreground_labels):
52 | """ Process foreground_labels
53 | - Map the foreground_labels to {0, 1, ..., K-1}
54 |
55 | @param foreground_labels: a [H x W] numpy array of labels
56 |
57 | @return: foreground_labels
58 | """
59 | # Find the unique (nonnegative) foreground_labels, map them to {0, ..., K-1}
60 | unique_nonnegative_indices = np.unique(foreground_labels)
61 | mapped_labels = foreground_labels.copy()
62 | for k in range(unique_nonnegative_indices.shape[0]):
63 | mapped_labels[foreground_labels == unique_nonnegative_indices[k]] = k
64 | foreground_labels = mapped_labels
65 | return foreground_labels
66 |
67 | def __len__(self):
68 | return len(self.image_files)
69 |
70 | def check_empty(self, image_paths):
71 | new_image_paths = []
72 | for filename in image_paths:
73 | labels_filename = str(filename).replace('image_color', 'annotation')
74 | annotation = Image.open(labels_filename)
75 | foreground_labels = np.array(annotation)
76 | # mask table as background
77 | foreground_labels[foreground_labels == 1] = 0
78 | if 'table' in labels_filename:
79 | foreground_labels[foreground_labels == 2] = 0
80 | gt_mask = self.process_label(foreground_labels)
81 | if not np.all(gt_mask == 0):
82 | new_image_paths.append(filename)
83 | return new_image_paths
84 |
85 | def __getitem__(self, idx):
86 | filename = self.image_files[idx]
87 |
88 | image = cv2.imread(filename)
89 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
90 |
91 | if self.cfg.get_prompt:
92 | image_info = {}
93 | height, width, _ = image.shape
94 | image_info["file_path"] = filename
95 | image_info["height"] = height
96 | image_info["width"] = width
97 | return idx, image_info, image
98 |
99 | labels_filename = filename.replace('image_color', 'annotation')
100 | annotation = Image.open(labels_filename)
101 | foreground_labels = np.array(annotation)
102 |
103 | # mask table as background
104 | foreground_labels[foreground_labels == 1] = 0
105 | if 'table' in labels_filename:
106 | foreground_labels[foreground_labels == 2] = 0
107 | gt_mask = self.process_label(foreground_labels)
108 |
109 | bboxes = []
110 | masks = []
111 | categories = []
112 | gt_masks = decode_mask(torch.tensor(gt_mask[None, :, :])).numpy().astype(np.uint8)
113 | assert gt_masks.sum() == (gt_mask > 0).sum()
114 | for mask in gt_masks:
115 | masks.append(mask)
116 | x, y, w, h = cv2.boundingRect(mask)
117 | bboxes.append([x, y, x + w, y + h])
118 | categories.append("0")
119 |
120 | if self.if_self_training:
121 | image_weak, bboxes_weak, masks_weak, image_strong = soft_transform(image, bboxes, masks, categories)
122 |
123 | if self.transform:
124 | image_weak, masks_weak, bboxes_weak = self.transform(image_weak, masks_weak, np.array(bboxes_weak))
125 | image_strong = self.transform.transform_image(image_strong)
126 |
127 | bboxes_weak = np.stack(bboxes_weak, axis=0)
128 | masks_weak = np.stack(masks_weak, axis=0)
129 | return image_weak, image_strong, torch.tensor(bboxes_weak), torch.tensor(masks_weak).float()
130 | else:
131 | if self.transform:
132 | image, masks, bboxes = self.transform(image, masks, np.array(bboxes))
133 |
134 | bboxes = np.stack(bboxes, axis=0)
135 | masks = np.stack(masks, axis=0)
136 | return image, torch.tensor(bboxes), torch.tensor(masks).float()
137 |
138 |
139 | class OSDObjectwithCoarse(OSDObject):
140 |
141 | def __getitem__(self, idx):
142 | filename = self.image_files[idx]
143 |
144 | image = cv2.imread(filename)
145 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
146 |
147 | labels_filename = filename.replace('image_color', 'annotation')
148 | annotation = Image.open(labels_filename)
149 | foreground_labels = np.array(annotation)
150 |
151 | # mask table as background
152 | foreground_labels[foreground_labels == 1] = 0
153 | if 'table' in labels_filename:
154 | foreground_labels[foreground_labels == 2] = 0
155 | gt_mask = self.process_label(foreground_labels)
156 |
157 | bboxes = []
158 | masks = []
159 | categories = []
160 | gt_masks = decode_mask(torch.tensor(gt_mask[None, :, :])).numpy().astype(np.uint8)
161 | assert gt_masks.sum() == (gt_mask > 0).sum()
162 | for mask in gt_masks:
163 | contours, hierarchy = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
164 | num_vertices = 0.05 * cv2.arcLength(contours[0], True)
165 | num_vertices = num_vertices if num_vertices > 3 else 3
166 | approx = cv2.approxPolyDP(contours[0], num_vertices, True) # [x, y]
167 | approx = approx.squeeze(1)
168 |
169 | coordinates = np.array(approx)
170 | x_max, x_min = max(coordinates[:, 0]), min(coordinates[:, 0])
171 | y_max, y_min = max(coordinates[:, 1]), min(coordinates[:, 1])
172 | coarse_mask = polygon2mask(mask.shape, coordinates).astype(mask.dtype)
173 | if x_min == x_max or y_min == y_max:
174 | x, y, w, h = cv2.boundingRect(mask)
175 | bboxes.append([x, y, x + w, y + h])
176 | else:
177 | bboxes.append([x_min, y_min, x_max, y_max])
178 |
179 | masks.append(mask)
180 | categories.append("0")
181 |
182 | if self.if_self_training:
183 | image_weak, bboxes_weak, masks_weak, image_strong = soft_transform(image, bboxes, masks, categories)
184 |
185 | if self.transform:
186 | image_weak, masks_weak, bboxes_weak = self.transform(image_weak, masks_weak, np.array(bboxes_weak))
187 | image_strong = self.transform.transform_image(image_strong)
188 |
189 | bboxes_weak = np.stack(bboxes_weak, axis=0)
190 | masks_weak = np.stack(masks_weak, axis=0)
191 | return image_weak, image_strong, torch.tensor(bboxes_weak), torch.tensor(masks_weak).float()
192 | else:
193 | if self.transform:
194 | image, masks, bboxes = self.transform(image, masks, np.array(bboxes))
195 |
196 | bboxes = np.stack(bboxes, axis=0)
197 | masks = np.stack(masks, axis=0)
198 | return image, torch.tensor(bboxes), torch.tensor(masks).float()
199 |
200 |
201 | def load_datasets(cfg, img_size):
202 | transform = ResizeAndPad(img_size)
203 | train = OSDObject(
204 | cfg,
205 | root_dir=cfg.datasets.robot.OSD,
206 | transform=transform,
207 | split=cfg.split,
208 | training=True,
209 | if_self_training=cfg.augment,
210 | )
211 | val = OSDObject(
212 | cfg,
213 | root_dir=cfg.datasets.robot.OSD,
214 | transform=transform,
215 | split=cfg.split,
216 | )
217 | train_dataloader = DataLoader(
218 | train,
219 | batch_size=cfg.batch_size,
220 | shuffle=True,
221 | num_workers=cfg.num_workers,
222 | collate_fn=collate_fn,
223 | )
224 | val_dataloader = DataLoader(
225 | val,
226 | batch_size=cfg.val_batchsize,
227 | shuffle=False,
228 | num_workers=cfg.num_workers,
229 | collate_fn=collate_fn,
230 | )
231 | return train_dataloader, val_dataloader
232 |
233 |
234 | def load_datasets_coarse(cfg, img_size):
235 | transform = ResizeAndPad(img_size)
236 | train = OSDObjectwithCoarse(
237 | cfg,
238 | root_dir=cfg.datasets.robot.OSD,
239 | transform=transform,
240 | split=cfg.split,
241 | training=True,
242 | if_self_training=cfg.augment,
243 | )
244 | val = OSDObjectwithCoarse(
245 | cfg,
246 | root_dir=cfg.datasets.robot.OSD,
247 | transform=transform,
248 | split=cfg.split,
249 | )
250 | train_dataloader = DataLoader(
251 | train,
252 | batch_size=cfg.batch_size,
253 | shuffle=True,
254 | num_workers=cfg.num_workers,
255 | collate_fn=collate_fn,
256 | )
257 | val_dataloader = DataLoader(
258 | val,
259 | batch_size=cfg.val_batchsize,
260 | shuffle=False,
261 | num_workers=cfg.num_workers,
262 | collate_fn=collate_fn,
263 | )
264 | return train_dataloader, val_dataloader
265 |
266 |
267 | def load_datasets_prompt(cfg, img_size):
268 | transform = ResizeAndPad(img_size)
269 | train = OSDObject(
270 | cfg,
271 | root_dir=cfg.datasets.robot.OSD,
272 | transform=transform,
273 | split=cfg.split,
274 | training=True,
275 | if_self_training=cfg.augment,
276 | )
277 | train_dataloader = DataLoader(
278 | train,
279 | batch_size=cfg.batch_size,
280 | shuffle=True,
281 | num_workers=cfg.num_workers,
282 | collate_fn=collate_fn_,
283 | )
284 | return train_dataloader
285 |
--------------------------------------------------------------------------------
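# The OSD loaders above (and the PascalVOC loaders below) build their train/eval
# lists with an interleaved split: six consecutive sorted samples go to training,
# then one to evaluation, until the list is exhausted. A small sketch of that
# pattern; interleaved_split is a hypothetical helper, not repository code.
def interleaved_split(paths, train_chunk=6):
    paths = list(paths)
    train, evaluation = [], []
    while paths:
        for _ in range(train_chunk):
            if paths:
                train.append(paths.pop(0))
        if paths:
            evaluation.append(paths.pop(0))
    return train, evaluation

train, evaluation = interleaved_split([f"img_{i:03d}.png" for i in range(20)])
print(len(train), len(evaluation))  # 18 2 -> roughly a 6:1 split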
/datasets/PascalVOC.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import random
4 | import numpy as np
5 | import torch
6 | from torch.utils.data import DataLoader
7 | from torch.utils.data import Dataset
8 | from skimage.draw import polygon2mask
9 | from datasets.tools import ResizeAndPad, soft_transform, collate_fn, collate_fn_, decode_mask
10 |
11 |
12 | class PascalVOCDataset(Dataset):
13 | def __init__(self, cfg, root_dir, transform=None, split=False, training=False, if_self_training=False):
14 | self.cfg = cfg
15 | self.root_dir = root_dir
16 | self.transform = transform
17 |
18 | segment_root = os.path.join(root_dir, "SegmentationObject")
19 | all_anns = [os.path.join(segment_root, f) for f in os.listdir(segment_root) if f.endswith('.png')]
20 | all_anns = sorted(all_anns)
21 |
22 | if split:
23 | train_list = []
24 | eval_list = []
25 | while all_anns:
26 | for _ in range(6):
27 | if all_anns:
28 | train_list.append(all_anns.pop(0))
29 | if all_anns:
30 | eval_list.append(all_anns.pop(0))
31 |
32 | if training:
33 | random.shuffle(train_list)
34 | image_ids = train_list
35 | else:
36 | random.shuffle(eval_list)
37 | image_ids = eval_list
38 | else:
39 | image_ids = all_anns
40 |
41 | self.image_ids = image_ids
42 |
43 | self.if_self_training = if_self_training
44 |
45 | def __len__(self):
46 | return len(self.image_ids)
47 |
48 | def get_pascal_labels(self):
49 | """Load the mapping that associates pascal classes with label colors
50 |
51 | Returns:
52 | np.ndarray with dimensions (21, 3)
53 | """
54 | return np.asarray(
55 | [
56 | [0, 0, 0],
57 | [128, 0, 0],
58 | [0, 128, 0],
59 | [128, 128, 0],
60 | [0, 0, 128],
61 | [128, 0, 128],
62 | [0, 128, 128],
63 | [128, 128, 128],
64 | [64, 0, 0],
65 | [192, 0, 0],
66 | [64, 128, 0],
67 | [192, 128, 0],
68 | [64, 0, 128],
69 | [192, 0, 128],
70 | [64, 128, 128],
71 | [192, 128, 128],
72 | [0, 64, 0],
73 | [128, 64, 0],
74 | [0, 192, 0],
75 | [128, 192, 0],
76 | [0, 64, 128],
77 | ]
78 | )
79 |
80 | def encode_segmap(self, mask):
81 | """Encode segmentation label images as pascal classes
82 |
83 | Args:
84 | mask (np.ndarray): raw segmentation label image of dimension
85 | (M, N, 3), in which the Pascal classes are encoded as colours.
86 |
87 | Returns:
88 | (np.ndarray): class map with dimensions (M,N), where the value at
89 | a given location is the integer denoting the class index.
90 | """
91 | mask = mask.astype(int)
92 | label_mask = np.zeros((mask.shape[0], mask.shape[1]), dtype=np.int16)
93 | for ii, label in enumerate(self.get_pascal_labels()):
94 | label_mask[np.where(np.all(mask == label, axis=-1))[:2]] = ii
95 | label_mask = label_mask.astype(int)
96 | return label_mask
97 |
98 | def __getitem__(self, idx):
99 |
100 | anno_path = self.image_ids[idx]
101 | image_path = anno_path.replace("SegmentationObject", "JPEGImages").replace(".png", ".jpg")
102 |
103 | image = cv2.imread(image_path)
104 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
105 | gt_mask = cv2.imread(anno_path)
106 | gt_labels = self.encode_segmap(gt_mask)
107 |
108 | if self.cfg.get_prompt:
109 | image_info = {}
110 | height, width, _ = image.shape
111 | image_info["file_path"] = image_path
112 | image_info["height"] = height
113 | image_info["width"] = width
114 | return idx, image_info, image
115 |
116 | masks = []
117 | bboxes = []
118 | categories = []
119 | gt_masks = decode_mask(torch.tensor(gt_labels[None, :, :])).numpy().astype(np.uint8)
120 | assert gt_masks.sum() == (gt_labels > 0).sum()
121 | for mask in gt_masks:
122 | masks.append(mask)
123 | x, y, w, h = cv2.boundingRect(mask)
124 | bboxes.append([x, y, x + w, y + h])
125 | categories.append("0")
126 |
127 | if self.if_self_training:
128 | image_weak, bboxes_weak, masks_weak, image_strong = soft_transform(image, bboxes, masks, categories)
129 |
130 | if self.transform:
131 | image_weak, masks_weak, bboxes_weak = self.transform(image_weak, masks_weak, np.array(bboxes_weak))
132 | image_strong = self.transform.transform_image(image_strong)
133 |
134 | bboxes_weak = np.stack(bboxes_weak, axis=0)
135 | masks_weak = np.stack(masks_weak, axis=0)
136 | return image_weak, image_strong, torch.tensor(bboxes_weak), torch.tensor(masks_weak).float()
137 | else:
138 | if self.transform:
139 | image, masks, bboxes = self.transform(image, masks, np.array(bboxes))
140 |
141 | bboxes = np.stack(bboxes, axis=0)
142 | masks = np.stack(masks, axis=0)
143 | return image, torch.tensor(bboxes), torch.tensor(masks).float()
144 |
145 |
146 | class PascalVOCDatasetwithCoarse(PascalVOCDataset):
147 |
148 | def __getitem__(self, idx):
149 | anno_path = self.image_ids[idx]
150 | image_path = anno_path.replace("SegmentationObject", "JPEGImages").replace(".png", ".jpg")
151 |
152 | image = cv2.imread(image_path)
153 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
154 | gt_mask = cv2.imread(anno_path)
155 | gt_labels = self.encode_segmap(gt_mask)
156 |
157 | masks = []
158 | bboxes = []
159 |         approxes = []
160 | categories = []
161 | gt_masks = decode_mask(torch.tensor(gt_labels[None, :, :])).numpy().astype(np.uint8)
162 | assert gt_masks.sum() == (gt_labels > 0).sum()
163 | for mask in gt_masks:
164 | contours, hierarchy = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
165 | num_vertices = 0.05 * cv2.arcLength(contours[0], True)
166 | num_vertices = num_vertices if num_vertices > 5 else 5
167 | approx = cv2.approxPolyDP(contours[0], num_vertices, True) # [x, y]
168 | approx = approx.squeeze(1)
169 | coordinates = np.array(approx)
170 | x_max, x_min = max(coordinates[:, 0]), min(coordinates[:, 0])
171 | y_max, y_min = max(coordinates[:, 1]), min(coordinates[:, 1])
172 | coarse_mask = polygon2mask(mask.shape, coordinates).astype(mask.dtype)
173 | if x_min == x_max or y_min == y_max:
174 | x, y, w, h = cv2.boundingRect(mask)
175 | bboxes.append([x, y, x + w, y + h])
176 | else:
177 | bboxes.append([x_min, y_min, x_max, y_max])
178 | masks.append(mask)
179 | categories.append("0")
180 | approxes.append(approx)
181 |
182 | if self.if_self_training:
183 | image_weak, bboxes_weak, masks_weak, image_strong = soft_transform(image, bboxes, masks, categories)
184 |
185 | if self.transform:
186 | image_weak, masks_weak, bboxes_weak = self.transform(image_weak, masks_weak, np.array(bboxes_weak))
187 | image_strong = self.transform.transform_image(image_strong)
188 |
189 | bboxes_weak = np.stack(bboxes_weak, axis=0)
190 | masks_weak = np.stack(masks_weak, axis=0)
191 | return image_weak, image_strong, torch.tensor(bboxes_weak), torch.tensor(masks_weak).float()
192 |
193 | elif self.cfg.visual:
194 | image_name = os.path.splitext(os.path.basename(image_path))[0]
195 |
196 | origin_image = image
197 | origin_approxes = approxes
198 | origin_masks = masks
199 | if self.transform:
200 | padding, image, masks, bboxes = self.transform(image, masks, np.array(bboxes), self.cfg.visual)
201 |
202 | bboxes = np.stack(bboxes, axis=0)
203 | masks = np.stack(masks, axis=0)
204 | origin_masks = np.stack(origin_masks, axis=0)
205 | return image_name, padding, origin_image, origin_approxes, origin_masks, image, torch.tensor(bboxes), torch.tensor(masks).float()
206 |
207 | else:
208 | if self.transform:
209 | image, masks, bboxes = self.transform(image, masks, np.array(bboxes))
210 |
211 | bboxes = np.stack(bboxes, axis=0)
212 | masks = np.stack(masks, axis=0)
213 | return image, torch.tensor(bboxes), torch.tensor(masks).float()
214 |
215 |
216 | def load_datasets(cfg, img_size):
217 | transform = ResizeAndPad(img_size)
218 | val = PascalVOCDataset(
219 | cfg,
220 | root_dir=cfg.datasets.PascalVOC.root_dir,
221 | transform=transform,
222 | split=cfg.split,
223 | )
224 | train = PascalVOCDataset(
225 | cfg,
226 | root_dir=cfg.datasets.PascalVOC.root_dir,
227 | transform=transform,
228 | split=cfg.split,
229 | training=True,
230 | if_self_training=cfg.augment,
231 | )
232 | val_dataloader = DataLoader(
233 | val,
234 | batch_size=cfg.val_batchsize,
235 | shuffle=False,
236 | num_workers=cfg.num_workers,
237 | collate_fn=collate_fn,
238 | )
239 | train_dataloader = DataLoader(
240 | train,
241 | batch_size=cfg.batch_size,
242 | shuffle=True,
243 | num_workers=cfg.num_workers,
244 | collate_fn=collate_fn,
245 | )
246 | return train_dataloader, val_dataloader
247 |
248 |
249 | def load_datasets_coarse(cfg, img_size):
250 | transform = ResizeAndPad(img_size)
251 | val = PascalVOCDatasetwithCoarse(
252 | cfg,
253 | root_dir=cfg.datasets.PascalVOC.root_dir,
254 | transform=transform,
255 | split=cfg.split,
256 | )
257 | train = PascalVOCDatasetwithCoarse(
258 | cfg,
259 | root_dir=cfg.datasets.PascalVOC.root_dir,
260 | transform=transform,
261 | split=cfg.split,
262 | training=True,
263 | if_self_training=cfg.augment,
264 | )
265 | val_dataloader = DataLoader(
266 | val,
267 | batch_size=cfg.val_batchsize,
268 | shuffle=False,
269 | num_workers=cfg.num_workers,
270 | collate_fn=collate_fn,
271 | )
272 | train_dataloader = DataLoader(
273 | train,
274 | batch_size=cfg.batch_size,
275 | shuffle=True,
276 | num_workers=cfg.num_workers,
277 | collate_fn=collate_fn,
278 | )
279 | return train_dataloader, val_dataloader
280 |
281 |
282 | def load_datasets_prompt(cfg, img_size):
283 | transform = ResizeAndPad(img_size)
284 | train = PascalVOCDataset(
285 | cfg,
286 | root_dir=cfg.datasets.PascalVOC.root_dir,
287 | transform=transform,
288 | split=cfg.split,
289 | training=True,
290 | if_self_training=cfg.augment,
291 | )
292 | train_dataloader = DataLoader(
293 | train,
294 | batch_size=cfg.batch_size,
295 | shuffle=True,
296 | num_workers=cfg.num_workers,
297 | collate_fn=collate_fn_,
298 | )
299 | return train_dataloader
--------------------------------------------------------------------------------
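# A tiny check of the colour -> class-index encoding performed by
# PascalVOCDataset.encode_segmap: every RGB palette colour collapses to its row
# index in get_pascal_labels(), after which decode_mask can separate instances.
# The three-colour palette here is a truncated stand-in for the full 21-colour map.
import numpy as np

palette = np.array([[0, 0, 0], [128, 0, 0], [0, 128, 0]])  # background + first two classes
rgb = np.zeros((2, 4, 3), dtype=np.uint8)
rgb[:, 2:] = [0, 128, 0]                                   # right half painted with class-2 colour

label_mask = np.zeros(rgb.shape[:2], dtype=np.int16)
for idx, colour in enumerate(palette):
    label_mask[np.all(rgb == colour, axis=-1)] = idx
print(np.unique(label_mask))                               # [0 2]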
/datasets/Polyp.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import random
4 | import glob
5 | import json
6 | import torch
7 | import numpy as np
8 | import pandas as pd
9 | from torch.utils.data import Dataset, DataLoader
10 | from skimage.draw import polygon2mask
11 |
12 | from datasets.tools import ResizeAndPad, soft_transform, collate_fn, decode_mask, collate_fn_
13 |
14 |
15 | class PolypDataset(Dataset):
16 | def __init__(self, cfg, root_dir, annotation_file, transform=None, split=False, training=False, if_self_training=False):
17 | self.cfg = cfg
18 | self.root_dir = root_dir
19 | self.transform = transform
20 | with open(annotation_file, "r") as ann_file:
21 | anns = json.load(ann_file)
22 | all_images = list(anns.keys())
23 |
24 | if split:
25 | train_images = []
26 | eval_images = []
27 | while all_images:
28 | for _ in range(5):
29 | if all_images:
30 | train_images.append(all_images.pop(0))
31 | if all_images:
32 | eval_images.append(all_images.pop(0))
33 |
34 | if training:
35 | random.shuffle(train_images)
36 | images = train_images
37 | else:
38 | random.shuffle(eval_images)
39 | images = eval_images
40 | else:
41 | images = all_images
42 |
43 | self.images = images
44 | self.anns = anns
45 | self.if_self_training = if_self_training
46 |
47 | def __len__(self):
48 | return len(self.images)
49 |
50 | def find_points_outside_bbox(self, mask, bboxes):
51 | points_outside_bbox = np.where(mask != 0)
52 | for bbox in bboxes:
53 | x_min, y_min, x_max, y_max = bbox
54 | points_outside_bbox = (points_outside_bbox[0][(points_outside_bbox[0] < y_min) | (points_outside_bbox[0] >= y_max)],
55 | points_outside_bbox[1][(points_outside_bbox[1] < x_min) | (points_outside_bbox[1] >= x_max)])
56 | return points_outside_bbox
57 |
58 | def __getitem__(self, idx):
59 | name = self.images[idx]
60 | image_path = os.path.join(self.root_dir, "images", name + ".jpg")
61 |
62 | image = cv2.imread(image_path)
63 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
64 |
65 | if self.cfg.get_prompt:
66 | image_info = {}
67 | height, width, _ = image.shape
68 | image_info["file_path"] = image_path
69 | image_info["height"] = height
70 | image_info["width"] = width
71 | return idx, image_info, image
72 |
73 | gt_path = image_path.replace("images", "masks")
74 | gt_mask = cv2.imread(gt_path, cv2.IMREAD_GRAYSCALE)
75 | gt_mask[gt_mask > 0] = 255
76 |
77 | masks = []
78 | bboxes = []
79 | categories = []
80 | anns = self.anns[name]
81 | ann_bboxes = anns["bbox"]
82 |
83 | for i, bbox in enumerate(ann_bboxes):
84 | x_min = bbox["xmin"]
85 | x_max = bbox["xmax"]
86 | y_min = bbox["ymin"]
87 | y_max = bbox["ymax"]
88 | gt_mask[y_min:y_max, x_min:x_max][gt_mask[y_min:y_max, x_min:x_max] > 0] = i + 1
89 | bboxes.append([x_min, y_min, x_max, y_max])
90 | categories.append(bbox["label"])
91 |
92 | gt_mask[gt_mask > i + 1] = 0
93 | gt_masks = decode_mask(torch.tensor(gt_mask[None, :, :])).numpy().astype(np.uint8)
94 | assert gt_masks.sum() == (gt_mask > 0).sum()
95 | assert len(ann_bboxes) == gt_masks.shape[0]
96 | masks = [mask for mask in gt_masks]
97 |
98 | if self.if_self_training:
99 | image_weak, bboxes_weak, masks_weak, image_strong = soft_transform(image, bboxes, masks, categories)
100 |
101 | if self.transform:
102 | image_weak, masks_weak, bboxes_weak = self.transform(image_weak, masks_weak, np.array(bboxes_weak))
103 | image_strong = self.transform.transform_image(image_strong)
104 |
105 | bboxes_weak = np.stack(bboxes_weak, axis=0)
106 | masks_weak = np.stack(masks_weak, axis=0)
107 | return image_weak, image_strong, torch.tensor(bboxes_weak), torch.tensor(masks_weak).float()
108 | else:
109 | if self.transform:
110 | image, masks, bboxes = self.transform(image, masks, np.array(bboxes))
111 |
112 | bboxes = np.stack(bboxes, axis=0)
113 | masks = np.stack(masks, axis=0)
114 | return image, torch.tensor(bboxes), torch.tensor(masks).float()
115 |
116 |
117 | class PolypDatasetwithCoarse(PolypDataset):
118 |
119 | def __getitem__(self, idx):
120 | name = self.images[idx]
121 | image_path = os.path.join(self.root_dir, "images", name + ".jpg")
122 |
123 | image = cv2.imread(image_path)
124 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
125 |
126 | gt_path = image_path.replace("images", "masks")
127 | gt_mask = cv2.imread(gt_path, cv2.IMREAD_GRAYSCALE)
128 | gt_mask[gt_mask > 0] = 255
129 |
130 | masks = []
131 | bboxes = []
132 | categories = []
133 | anns = self.anns[name]
134 | ann_bboxes = anns["bbox"]
135 |
136 | for i, bbox in enumerate(ann_bboxes):
137 | x_min = bbox["xmin"]
138 | x_max = bbox["xmax"]
139 | y_min = bbox["ymin"]
140 | y_max = bbox["ymax"]
141 | gt_mask[y_min:y_max, x_min:x_max][gt_mask[y_min:y_max, x_min:x_max] > 0] = i + 1
142 | # bboxes.append([x_min, y_min, x_max, y_max])
143 | categories.append(bbox["label"])
144 |
145 | gt_mask[gt_mask > i + 1] = 0
146 | gt_masks = decode_mask(torch.tensor(gt_mask[None, :, :])).numpy().astype(np.uint8)
147 | assert gt_masks.sum() == (gt_mask > 0).sum()
148 | assert len(ann_bboxes) == gt_masks.shape[0]
149 |
150 | for mask in gt_masks:
151 | contours, hierarchy = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
152 | num_vertices = 0.05 * cv2.arcLength(contours[0], True)
153 | num_vertices = num_vertices if num_vertices > 3 else 3
154 | approx = cv2.approxPolyDP(contours[0], num_vertices, True) # [x, y]
155 | approx = approx.squeeze(1)
156 |
157 | coordinates = np.array(approx)
158 | x_max, x_min = max(coordinates[:, 0]), min(coordinates[:, 0])
159 | y_max, y_min = max(coordinates[:, 1]), min(coordinates[:, 1])
160 | if x_min == x_max or y_min == y_max:
161 | x, y, w, h = cv2.boundingRect(mask)
162 | bboxes.append([x, y, x + w, y + h])
163 | else:
164 | bboxes.append([x_min, y_min, x_max, y_max])
165 |
166 | coarse_mask = polygon2mask(mask.shape, coordinates).astype(mask.dtype)
167 |
168 | masks.append(mask)
169 |
170 | masks = [mask for mask in gt_masks]
171 |
172 | if self.if_self_training:
173 | image_weak, bboxes_weak, masks_weak, image_strong = soft_transform(image, bboxes, masks, categories)
174 |
175 | if self.transform:
176 | image_weak, masks_weak, bboxes_weak = self.transform(image_weak, masks_weak, np.array(bboxes_weak))
177 | image_strong = self.transform.transform_image(image_strong)
178 |
179 | bboxes_weak = np.stack(bboxes_weak, axis=0)
180 | masks_weak = np.stack(masks_weak, axis=0)
181 | return image_weak, image_strong, torch.tensor(bboxes_weak), torch.tensor(masks_weak).float()
182 | else:
183 | if self.transform:
184 | image, masks, bboxes = self.transform(image, masks, np.array(bboxes))
185 |
186 | bboxes = np.stack(bboxes, axis=0)
187 | masks = np.stack(masks, axis=0)
188 | return image, torch.tensor(bboxes), torch.tensor(masks).float()
189 |
190 |
191 | def load_datasets(cfg, img_size):
192 | transform = ResizeAndPad(img_size)
193 | val = PolypDataset(
194 | cfg,
195 | root_dir=cfg.datasets.Polyp.root_dir,
196 | annotation_file=cfg.datasets.Polyp.annotation_file,
197 | transform=transform,
198 | split=cfg.split,
199 | )
200 | train = PolypDataset(
201 | cfg,
202 | root_dir=cfg.datasets.Polyp.root_dir,
203 | annotation_file=cfg.datasets.Polyp.annotation_file,
204 | transform=transform,
205 | split=cfg.split,
206 | training=True,
207 | if_self_training=cfg.augment,
208 | )
209 | val_dataloader = DataLoader(
210 | val,
211 | batch_size=cfg.val_batchsize,
212 | shuffle=False,
213 | num_workers=cfg.num_workers,
214 | collate_fn=collate_fn,
215 | )
216 | train_dataloader = DataLoader(
217 | train,
218 | batch_size=cfg.batch_size,
219 | shuffle=True,
220 | num_workers=cfg.num_workers,
221 | collate_fn=collate_fn,
222 | )
223 | return train_dataloader, val_dataloader
224 |
225 |
226 | def load_datasets_coarse(cfg, img_size):
227 | transform = ResizeAndPad(img_size)
228 | val = PolypDatasetwithCoarse(
229 | cfg,
230 | root_dir=cfg.datasets.Polyp.root_dir,
231 | annotation_file=cfg.datasets.Polyp.annotation_file,
232 | transform=transform,
233 | split=cfg.split,
234 | )
235 | train = PolypDatasetwithCoarse(
236 | cfg,
237 | root_dir=cfg.datasets.Polyp.root_dir,
238 | annotation_file=cfg.datasets.Polyp.annotation_file,
239 | transform=transform,
240 | split=cfg.split,
241 | training=True,
242 | if_self_training=cfg.augment,
243 | )
244 | val_dataloader = DataLoader(
245 | val,
246 | batch_size=cfg.val_batchsize,
247 | shuffle=False,
248 | num_workers=cfg.num_workers,
249 | collate_fn=collate_fn,
250 | )
251 | train_dataloader = DataLoader(
252 | train,
253 | batch_size=cfg.batch_size,
254 | shuffle=True,
255 | num_workers=cfg.num_workers,
256 | collate_fn=collate_fn,
257 | )
258 | return train_dataloader, val_dataloader
259 |
260 |
261 | def load_datasets_prompt(cfg, img_size):
262 | transform = ResizeAndPad(img_size)
263 | train = PolypDataset(
264 | cfg,
265 | root_dir=cfg.datasets.Polyp.root_dir,
266 | annotation_file=cfg.datasets.Polyp.annotation_file,
267 | transform=transform,
268 | split=cfg.split,
269 | training=True,
270 | if_self_training=cfg.augment,
271 | )
272 | train_dataloader = DataLoader(
273 | train,
274 | batch_size=cfg.batch_size,
275 | shuffle=True,
276 | num_workers=cfg.num_workers,
277 | collate_fn=collate_fn_,
278 | )
279 | return train_dataloader
280 |
--------------------------------------------------------------------------------
/datasets/SA_1B.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import json
4 | import random
5 | import numpy as np
6 | import torch
7 | from torch.utils.data import DataLoader
8 | from torch.utils.data import Dataset
9 | from pycocotools.coco import COCO
10 | from pycocotools import mask as mask_utils
11 | from skimage.draw import polygon2mask
12 | from datasets.tools import ResizeAndPad, soft_transform, collate_fn, collate_fn_
13 |
14 |
15 | class SADataset(Dataset):
16 |     def __init__(self, cfg, root_dir, transform=None, training=False, if_self_training=False):
17 |         self.cfg, self.root_dir = cfg, root_dir
18 | self.transform = transform
19 | self.image_list = []
20 | self.sa_list = [
21 | sa for sa in os.listdir(root_dir)
22 | if os.path.isdir(os.path.join(root_dir, sa))
23 | ]
24 |
25 | for sa in self.sa_list:
26 | sa_dir = os.path.join(root_dir, sa)
27 | self.image_list.extend(
28 | [
29 | os.path.join(sa_dir, f)
30 | for f in os.listdir(sa_dir)
31 | if f.endswith(".jpg")
32 | ]
33 | )
34 |
35 | self.image_list = random.sample(self.image_list, 200)
36 |         self.if_self_training = if_self_training
37 |
38 | def __len__(self):
39 | return len(self.image_list)
40 |
41 | def __getitem__(self, idx):
42 | image_path = self.image_list[idx]
43 | image = cv2.imread(image_path)
44 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
45 |
46 | if self.cfg.get_prompt:
47 | image_info = {}
48 | height, width, _ = image.shape
49 | image_info["file_path"] = image_path
50 | image_info["height"] = height
51 | image_info["width"] = width
52 | return idx, image_info, image
53 |
54 | json_path = image_path.replace(".jpg", ".json")
55 | with open(json_path, "r") as f:
56 | annotations = json.load(f)
57 |
58 | bboxes = []
59 | masks = []
60 | categories = []
61 | for anno in annotations["annotations"]:
62 | x, y, w, h = anno["bbox"]
63 | bboxes.append([x, y, x + w, y + h])
64 | mask = mask_utils.decode(anno["segmentation"])
65 | masks.append(mask)
66 | categories.append("0")
67 |
68 | if self.if_self_training:
69 | image_weak, bboxes_weak, masks_weak, image_strong = soft_transform(image, bboxes, masks, categories)
70 |
71 | if self.transform:
72 | image_weak, masks_weak, bboxes_weak = self.transform(image_weak, masks_weak, np.array(bboxes_weak))
73 | image_strong = self.transform.transform_image(image_strong)
74 |
75 | bboxes_weak = np.stack(bboxes_weak, axis=0)
76 | masks_weak = np.stack(masks_weak, axis=0)
77 | return image_weak, image_strong, torch.tensor(bboxes_weak), torch.tensor(masks_weak).float()
78 | else:
79 | if self.transform:
80 | image, masks, bboxes = self.transform(image, masks, np.array(bboxes))
81 |
82 | bboxes = np.stack(bboxes, axis=0)
83 | masks = np.stack(masks, axis=0)
84 | return image, torch.tensor(bboxes), torch.tensor(masks).float()
85 |
86 |
87 | class SADatasetwithCoarse(SADataset):
88 | def __getitem__(self, idx):
89 | image_path = self.image_list[idx]
90 | image = cv2.imread(image_path)
91 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
92 |
93 | json_path = image_path.replace(".jpg", ".json")
94 | with open(json_path, "r") as f:
95 | annotations = json.load(f)
96 |
97 | bboxes = []
98 | masks = []
99 | categories = []
100 | for anno in annotations["annotations"]:
101 | mask = mask_utils.decode(anno["segmentation"])
102 |
103 | contours, hierarchy = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
104 | num_vertices = 0.05 * cv2.arcLength(contours[0], True)
105 | num_vertices = num_vertices if num_vertices > 5 else 5
106 | approx = cv2.approxPolyDP(contours[0], num_vertices, True) # [x, y]
107 | approx = approx.squeeze(1)
108 | coordinates = np.array(approx)
109 | x_max, x_min = max(coordinates[:, 0]), min(coordinates[:, 0])
110 | y_max, y_min = max(coordinates[:, 1]), min(coordinates[:, 1])
111 | coarse_mask = polygon2mask(mask.shape, coordinates).astype(mask.dtype)
112 | if x_min == x_max or y_min == y_max:
113 | x, y, w, h = cv2.boundingRect(mask)
114 | bboxes.append([x, y, x + w, y + h])
115 | else:
116 | bboxes.append([x_min, y_min, x_max, y_max])
117 |
118 | masks.append(mask)
119 | categories.append("0")
120 |
121 | if self.if_self_training:
122 | image_weak, bboxes_weak, masks_weak, image_strong = soft_transform(image, bboxes, masks, categories)
123 |
124 | if self.transform:
125 | image_weak, masks_weak, bboxes_weak = self.transform(image_weak, masks_weak, np.array(bboxes_weak))
126 | image_strong = self.transform.transform_image(image_strong)
127 |
128 | bboxes_weak = np.stack(bboxes_weak, axis=0)
129 | masks_weak = np.stack(masks_weak, axis=0)
130 | return image_weak, image_strong, torch.tensor(bboxes_weak), torch.tensor(masks_weak).float()
131 | else:
132 | if self.transform:
133 | image, masks, bboxes = self.transform(image, masks, np.array(bboxes))
134 |
135 | bboxes = np.stack(bboxes, axis=0)
136 | masks = np.stack(masks, axis=0)
137 | return image, torch.tensor(bboxes), torch.tensor(masks).float()
138 |
139 |
140 | def load_datasets(cfg, img_size):
141 | transform = ResizeAndPad(img_size)
142 | train = SADataset(
143 | cfg,
144 | root_dir=cfg.datasets.sa.root_dir,
145 | transform=transform,
146 | training=True,
147 | if_self_training=cfg.augment,
148 | )
149 | val = SADataset(
150 | cfg,
151 | root_dir=cfg.datasets.sa.root_dir,
152 | transform=transform,
153 | )
154 | train_dataloader = DataLoader(
155 | train,
156 | batch_size=cfg.batch_size,
157 | shuffle=True,
158 | num_workers=cfg.num_workers,
159 | collate_fn=collate_fn,
160 | )
161 | val_dataloader = DataLoader(
162 | val,
163 | batch_size=cfg.batch_size,
164 | shuffle=False,
165 | num_workers=cfg.num_workers,
166 | collate_fn=collate_fn,
167 | )
168 | return train_dataloader, val_dataloader
169 |
170 |
171 | def load_datasets_coarse(cfg, img_size):
172 | transform = ResizeAndPad(img_size)
173 | train = SADatasetwithCoarse(
174 | cfg,
175 | root_dir=cfg.datasets.sa.root_dir,
176 | transform=transform,
177 | training=True,
178 | if_self_training=cfg.augment,
179 | )
180 | val = SADatasetwithCoarse(
181 | cfg,
182 | root_dir=cfg.datasets.sa.root_dir,
183 | transform=transform,
184 | )
185 | train_dataloader = DataLoader(
186 | train,
187 | batch_size=cfg.batch_size,
188 | shuffle=True,
189 | num_workers=cfg.num_workers,
190 | collate_fn=collate_fn,
191 | )
192 | val_dataloader = DataLoader(
193 | val,
194 | batch_size=cfg.batch_size,
195 | shuffle=False,
196 | num_workers=cfg.num_workers,
197 | collate_fn=collate_fn,
198 | )
199 | return train_dataloader, val_dataloader
200 |
201 |
202 | def load_datasets_prompt(cfg, img_size):
203 | transform = ResizeAndPad(img_size)
204 | train = SADataset(
205 | cfg,
206 | root_dir=cfg.datasets.sa.root_dir,
207 | transform=transform,
208 | training=True,
209 | if_self_training=cfg.augment,
210 | )
211 | train_dataloader = DataLoader(
212 | train,
213 | batch_size=cfg.batch_size,
214 | shuffle=True,
215 | num_workers=cfg.num_workers,
216 | collate_fn=collate_fn_,
217 | )
218 | return train_dataloader
219 |
--------------------------------------------------------------------------------
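# Round-trip sketch of the COCO RLE handling used in SADataset.__getitem__:
# each SA-1B annotation stores its mask as an RLE dict, and
# pycocotools.mask.decode turns it back into a binary [H, W] array (the encode
# call here only exists to build a stand-in for anno["segmentation"]).
import numpy as np
from pycocotools import mask as mask_utils

binary = np.zeros((8, 8), dtype=np.uint8)
binary[2:6, 3:7] = 1
rle = mask_utils.encode(np.asfortranarray(binary))  # same structure as anno["segmentation"]
decoded = mask_utils.decode(rle)
assert (decoded == binary).all()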
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | def call_load_dataset(cfg):
2 | name = cfg.dataset
3 |
4 | key = name.split("-")[0]
5 | module_name = f"datasets.{key}"
6 | function_name = "load_datasets"
7 |
8 | if cfg.visual:
9 | function_name = function_name + "_" + "visual"
10 |
11 | if cfg.prompt == "coarse":
12 | function_name = function_name + "_" + "coarse"
13 |
14 | exec(f"from {module_name} import {function_name}")
15 | func = eval(function_name)
16 | return func
17 |
18 |
19 | def call_load_dataset_prompt(cfg):
20 | name = cfg.dataset
21 |
22 | key = name.split("-")[0]
23 | module_name = f"datasets.{key}"
24 | function_name = "load_datasets"
25 |
26 | function_name = function_name + "_" + "prompt"
27 |
28 | exec(f"from {module_name} import {function_name}")
29 | func = eval(function_name)
30 | return func
31 |
32 |
33 | def call_load_dataset_val(cfg):
34 | name = cfg.dataset
35 |
36 | key = name.split("-")[0]
37 | module_name = f"datasets.{key}"
38 | function_name = "load_datasets"
39 |
40 | function_name = function_name + "_" + "val"
41 |
42 | exec(f"from {module_name} import {function_name}")
43 | func = eval(function_name)
44 | return func
45 |
--------------------------------------------------------------------------------
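# The loader-lookup functions above rely on exec/eval over dynamically built
# strings. A hedged importlib-based alternative with the same resolution rule
# ("datasets.<Key>.load_datasets" plus an optional suffix); resolve_loader is a
# sketch, not the repository's API.
import importlib

def resolve_loader(cfg, suffix=""):
    key = cfg.dataset.split("-")[0]           # e.g. "ISIC-2017" -> "ISIC"
    module = importlib.import_module(f"datasets.{key}")
    return getattr(module, "load_datasets" + suffix)

# resolve_loader(cfg)             -> datasets.<Key>.load_datasets
# resolve_loader(cfg, "_prompt")  -> datasets.<Key>.load_datasets_prompt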
/datasets/tools.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import torch
4 | import numpy as np
5 | import albumentations as A
6 | import torchvision.transforms as transforms
7 | from segment_anything.utils.transforms import ResizeLongestSide
8 | from imagecorruptions import corrupt, get_corruption_names
9 |
10 | # A.RandomCropNearBBox()
11 | # A.BBoxSafeRandomCrop()
12 | # A.RandomSizedBBoxSafeCrop()
13 |
14 | weak_transforms = A.Compose(
15 | [
16 | A.Flip(),
17 | A.HorizontalFlip(),
18 | A.VerticalFlip(),
19 | # A.BBoxSafeRandomCrop()
20 | ],
21 | bbox_params=A.BboxParams(format="pascal_voc", label_fields=["category_ids"]),
22 | # keypoint_params=A.KeypointParams(format='xy')
23 | )
24 |
25 |
26 | strong_transforms = A.Compose(
27 | [
28 | A.Posterize(),
29 | A.Equalize(),
30 | A.Sharpen(),
31 | A.Solarize(),
32 | A.RandomBrightnessContrast(),
33 | A.RandomShadow(),
34 | ]
35 | )
36 |
37 |
38 | class ResizeAndPad:
39 |
40 | def __init__(self, target_size):
41 | self.target_size = target_size
42 | self.transform = ResizeLongestSide(target_size)
43 | self.to_tensor = transforms.ToTensor()
44 |
45 | def __call__(self, image, masks, bboxes=None, visual=False):
46 | # Resize image and masks
47 | og_h, og_w, _ = image.shape
48 | image = self.transform.apply_image(image)
49 | masks = [torch.tensor(self.transform.apply_image(mask)) for mask in masks]
50 | image = self.to_tensor(image)
51 |
52 | # Pad image and masks to form a square
53 | _, h, w = image.shape
54 | max_dim = max(w, h)
55 | pad_w = (max_dim - w) // 2
56 | pad_h = (max_dim - h) // 2
57 |
58 | padding = (pad_w, pad_h, max_dim - w - pad_w, max_dim - h - pad_h)
59 | image = transforms.Pad(padding)(image)
60 | masks = [transforms.Pad(padding)(mask) for mask in masks]
61 |
62 | # Adjust bounding boxes
63 | if bboxes is not None:
64 | bboxes = self.transform.apply_boxes(bboxes, (og_h, og_w))
65 | bboxes = [
66 | [bbox[0] + pad_w, bbox[1] + pad_h, bbox[2] + pad_w, bbox[3] + pad_h]
67 | for bbox in bboxes
68 | ]
69 | if visual:
70 | return padding, image, masks, bboxes
71 | else:
72 | return image, masks, bboxes
73 | else:
74 | if visual:
75 | return padding, image, masks
76 | else:
77 | return image, masks
78 |
79 | def transform_image(self, image):
80 | # Resize image and masks
81 | image = self.transform.apply_image(image)
82 | image = self.to_tensor(image)
83 |
84 | # Pad image and masks to form a square
85 | _, h, w = image.shape
86 | max_dim = max(w, h)
87 | pad_w = (max_dim - w) // 2
88 | pad_h = (max_dim - h) // 2
89 |
90 | padding = (pad_w, pad_h, max_dim - w - pad_w, max_dim - h - pad_h)
91 | image = transforms.Pad(padding)(image)
92 | return image
93 |
94 | def transform_coord(self, points, image):
95 | og_h, og_w, _ = image.shape
96 | coords = points.reshape(1, -1, 2)
97 | points = self.transform.apply_coords(coords, (og_h, og_w))
98 | return points.reshape(-1, 2)
99 |
100 | def transform_coords(self, points, image, n):
101 | og_h, og_w, _ = image.shape
102 | coords = points.reshape(-1, n, 2)
103 | points = self.transform.apply_coords(coords, (og_h, og_w))
104 | return points.reshape(-1, n, 2)
105 |
106 |
107 | def corrupt_image(image, filename):
108 | file_name = os.path.basename(os.path.abspath(filename))
109 | file_path = os.path.dirname(os.path.abspath(filename))
110 | for corruption in get_corruption_names():
111 | corrupted = corrupt(image, severity=5, corruption_name=corruption)
112 | corrupt_path = file_path.replace(
113 | "val2017", os.path.join("corruption", corruption)
114 | )
115 | if not os.path.exists(corrupt_path):
116 | os.makedirs(corrupt_path, exist_ok=True)
117 | cv2.imwrite(os.path.join(corrupt_path, file_name), corrupted)
118 |
119 |
120 | def soft_transform(image: np.ndarray, bboxes: list, masks: list, categories: list):
121 | weak_transformed = weak_transforms(
122 | image=image, bboxes=bboxes, masks=masks, category_ids=categories
123 | )
124 | image_weak = weak_transformed["image"]
125 | bboxes_weak = weak_transformed["bboxes"]
126 | masks_weak = weak_transformed["masks"]
127 |
128 | strong_transformed = strong_transforms(image=image_weak)
129 | image_strong = strong_transformed["image"]
130 | return image_weak, bboxes_weak, masks_weak, image_strong
131 |
132 |
133 | def soft_transform_all(
134 | image: np.ndarray, bboxes: list, masks: list, points: list, categories: list
135 | ):
136 | weak_transformed = weak_transforms(
137 | image=image,
138 | bboxes=bboxes,
139 | masks=masks,
140 | category_ids=categories,
141 | keypoints=points,
142 | )
143 | image_weak = weak_transformed["image"]
144 | bboxes_weak = weak_transformed["bboxes"]
145 | masks_weak = weak_transformed["masks"]
146 | keypoints_weak = weak_transformed["keypoints"]
147 |
148 | strong_transformed = strong_transforms(image=image_weak)
149 | image_strong = strong_transformed["image"]
150 | return image_weak, bboxes_weak, masks_weak, keypoints_weak, image_strong
151 |
152 |
153 | def collate_fn(batch):
154 | if len(batch[0]) == 3:
155 | images, bboxes, masks = zip(*batch)
156 | images = torch.stack(images)
157 | return images, bboxes, masks
158 | elif len(batch[0]) == 4:
159 | images_soft, images, bboxes, masks = zip(*batch)
160 | images = torch.stack(images)
161 | images_soft = torch.stack(images_soft)
162 | return images_soft, images, bboxes, masks
163 | else:
164 | raise ValueError("Unexpected batch format")
165 |
166 |
167 | def collate_fn_(batch):
168 | return zip(*batch)
169 |
170 |
171 | def decode_mask(mask):
172 | """
173 |     Convert a label mask of shape [1, h, w], where values 1, 2, 3, ... index different
174 |     objects, into a binary mask of shape [n, h, w] with one channel per object.
175 | 
176 |     Args:
177 |         mask (torch.Tensor): Label mask of shape [1, h, w] with values 1, 2, 3, ... marking different objects.
178 | 
179 |     Returns:
180 |         torch.Tensor: Binary mask of shape [n, h, w] with one channel per object.
181 | """
182 | unique_labels = torch.unique(mask)
183 | unique_labels = unique_labels[unique_labels != 0]
184 | n_objects = len(unique_labels)
185 | new_mask = torch.zeros((n_objects, *mask.shape[1:]), dtype=torch.int64)
186 | for i, label in enumerate(unique_labels):
187 | new_mask[i] = (mask == label).squeeze(0)
188 | return new_mask
189 |
190 |
191 | def encode_mask(mask):
192 | """
193 |     Convert a binary mask of shape [n, h, w], with one channel per object, into a label
194 |     mask of shape [1, h, w] where values 1, 2, 3, ... index different objects.
195 | 
196 |     Args:
197 |         mask (torch.Tensor): Binary mask of shape [n, h, w] with one channel per object.
198 | 
199 |     Returns:
200 |         torch.Tensor: Label mask of shape [1, h, w] with values 1, 2, 3, ... marking different objects.
201 | """
202 | n_objects = mask.shape[0]
203 | new_mask = torch.zeros((1, *mask.shape[1:]), dtype=torch.int64)
204 | for i in range(n_objects):
205 | new_mask[0][mask[i] == 1] = i + 1
206 | return new_mask
207 |
208 |
209 | if __name__ == "__main__":
210 | mask_encode = np.array([[[0, 0, 1], [2, 0, 2], [0, 3, 3]]])
211 | mask_decode = np.array(
212 | [
213 | [[0, 0, 1], [0, 0, 0], [0, 0, 0]],
214 | [[0, 0, 0], [1, 0, 1], [0, 0, 0]],
215 | [[0, 0, 0], [0, 0, 0], [0, 1, 1]],
216 | ]
217 | )
218 | encoded_mask = encode_mask(torch.tensor(mask_decode))
219 | decoded_mask = decode_mask(torch.tensor(mask_encode))
220 |
--------------------------------------------------------------------------------
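
As a quick sanity check of the mask helpers in datasets/tools.py, here is a small round-trip sketch (torch only; the 3x3 example mirrors the `__main__` block above):

    import torch
    from datasets.tools import decode_mask, encode_mask

    label_mask = torch.tensor([[[0, 0, 1], [2, 0, 2], [0, 3, 3]]])  # [1, h, w], objects labelled 1..3
    binary_masks = decode_mask(label_mask)    # [3, h, w], one binary channel per object
    restored = encode_mask(binary_masks)      # back to [1, h, w] with labels 1..3
    assert torch.equal(restored, label_mask)
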
/losses.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from abc import ABC
6 |
7 | ALPHA = 0.8
8 | GAMMA = 2
9 |
10 |
11 | class FocalLoss(nn.Module):
12 |
13 | def __init__(self, weight=None, size_average=True):
14 | super().__init__()
15 |
16 | def forward(self, inputs, targets, alpha=ALPHA, gamma=GAMMA, smooth=1):
17 | inputs = F.sigmoid(inputs)
18 | inputs = torch.clamp(inputs, min=0, max=1)
19 | #flatten label and prediction tensors
20 | inputs = inputs.view(-1)
21 | targets = targets.view(-1)
22 |
23 | BCE = F.binary_cross_entropy(inputs, targets, reduction='none')
24 | BCE_EXP = torch.exp(-BCE)
25 | focal_loss = alpha * (1 - BCE_EXP)**gamma * BCE
26 | focal_loss = focal_loss.mean()
27 |
28 | return focal_loss
29 |
30 |
31 | class DiceLoss(nn.Module):
32 |
33 | def __init__(self, weight=None, size_average=True):
34 | super().__init__()
35 |
36 | def forward(self, inputs, targets, smooth=1):
37 | inputs = F.sigmoid(inputs)
38 | inputs = torch.clamp(inputs, min=0, max=1)
39 | #flatten label and prediction tensors
40 | inputs = inputs.view(-1)
41 | targets = targets.view(-1)
42 |
43 | intersection = (inputs * targets).sum()
44 | dice = (2. * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)
45 |
46 | return 1 - dice
47 |
48 |
49 | class ContraLoss(nn.Module):
50 |
51 | def __init__(self, temperature = 0.3, weight=None, size_average=True):
52 | super().__init__()
53 | self.temperature = temperature
54 | self.criterion = torch.nn.CrossEntropyLoss()
55 |
56 | def forward(self, embedd_x: torch.Tensor, embedd_y: torch.Tensor, mask_x: torch.Tensor, mask_y: torch.Tensor):
57 | x_embedding = self.norm_embed(embedd_x) # embedd_x: [256, 64, 64]
58 | y_embedding = self.norm_embed(embedd_y)
59 |
60 | x_masks = F.interpolate(mask_x, size=x_embedding.shape[-2:], mode="bilinear", align_corners=False).detach()
61 | y_masks = F.interpolate(mask_y, size=y_embedding.shape[-2:], mode="bilinear", align_corners=False).detach()
62 |
63 | x_masks = F.sigmoid(x_masks)
64 | x_masks = torch.clamp(x_masks, min=0, max=1)
65 | x_masks = x_masks > 0.5
66 | y_masks = F.sigmoid(y_masks)
67 | y_masks = torch.clamp(y_masks, min=0, max=1)
68 | y_masks = y_masks > 0.5
69 |
70 | # x_masks = self.add_background(x_masks)
71 | # y_masks = self.add_background(y_masks)
72 |
73 | sum_x = x_masks.sum(dim=[-1, -2]).clone()
74 | sum_y = y_masks.sum(dim=[-1, -2]).clone()
75 | sum_x[sum_x[:, 0] == 0.] = 1.
76 | sum_y[sum_y[:, 0] == 0.] = 1.
77 |
78 | multi_embedd_x = (x_embedding * x_masks).sum(dim=[-1, -2]) / sum_x # [n, 256, 64, 64] >> [n, 256]
79 | multi_embedd_y = (y_embedding * y_masks).sum(dim=[-1, -2]) / sum_y
80 |
81 | flatten_x = multi_embedd_x.view(multi_embedd_x.size(0), -1) # [n, 256]
82 | flatten_y = multi_embedd_y.view(multi_embedd_y.size(0), -1)
83 | # similarity_matrix = torch.matmul(multi_embedd_x, multi_embedd_y.T)
84 | similarity_matrix = F.cosine_similarity(flatten_x.unsqueeze(1), flatten_y.unsqueeze(0), dim=2) # [n, n]
85 |
86 | label_pos = torch.eye(x_masks.size(0)).bool().to(embedd_x.device)
87 |         label_neg = ~label_pos
88 |
89 | similarity_matrix = similarity_matrix / self.temperature # [n insts, n insts]
90 | loss = -torch.log(
91 | similarity_matrix.masked_select(label_pos).exp().sum() /
92 | similarity_matrix.exp().sum()
93 | )
94 | # loss = -torch.log(
95 | # similarity_matrix.masked_select(label_pos).exp().sum()
96 | # )
97 | return loss
98 |
99 | def norm_embed(self, embedding: torch.Tensor):
100 | embedding = F.normalize(embedding, dim=0, p=2)
101 | return embedding
102 |
103 | def add_background(self, masks):
104 | mask_union = torch.max(masks, dim=0).values
105 | mask_complement = ~mask_union
106 | concatenated_masks = torch.cat((masks, mask_complement.unsqueeze(0)), dim=0)
107 | return concatenated_masks
--------------------------------------------------------------------------------
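
A minimal sketch of combining the criteria in losses.py on raw mask logits; the 20:1 weighting is purely illustrative and not taken from the training code:

    import torch
    from losses import FocalLoss, DiceLoss

    focal, dice = FocalLoss(), DiceLoss()
    logits = torch.randn(2, 1, 256, 256, requires_grad=True)   # pre-sigmoid predictions
    targets = torch.randint(0, 2, (2, 1, 256, 256)).float()    # binary ground-truth masks
    loss = 20.0 * focal(logits, targets) + dice(logits, targets)
    loss.backward()
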
/model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
7 | from sam_lora import LoRA_Sam
8 |
9 |
10 | class Model(nn.Module):
11 |
12 | def __init__(self, cfg):
13 | super().__init__()
14 | self.cfg = cfg
15 |
16 | def get_checkpoint(self, model_type):
17 | if model_type == "vit_b":
18 | checkpoint = os.path.join(self.cfg.model.checkpoint, "sam_vit_b_01ec64.pth")
19 | elif model_type == "vit_l":
20 | checkpoint = os.path.join(self.cfg.model.checkpoint, "sam_vit_l_0b3195.pth")
21 | elif model_type == "vit_h":
22 | checkpoint = os.path.join(self.cfg.model.checkpoint, "sam_vit_h_4b8939.pth")
23 | else:
24 | raise ValueError("Model type error!")
25 | return checkpoint
26 |
27 | def setup(self):
28 | checkpoint = self.get_checkpoint(self.cfg.model.type)
29 | self.model = sam_model_registry[self.cfg.model.type](checkpoint=checkpoint)
30 |
31 | self.model.train()
32 | if self.cfg.model.freeze.image_encoder:
33 | for param in self.model.image_encoder.parameters():
34 | param.requires_grad = False
35 | if self.cfg.model.freeze.prompt_encoder:
36 | for param in self.model.prompt_encoder.parameters():
37 | param.requires_grad = False
38 | if self.cfg.model.freeze.mask_decoder:
39 | for param in self.model.mask_decoder.parameters():
40 | param.requires_grad = False
41 |
42 | # self.finetune()
43 |
44 | def finetune(self):
45 | LoRA_Sam(self.model, 4)
46 | # self.set_norm_layer()
47 | # self.set_evp_adaptor_layer()
48 | # self.set_prompt_layer()
49 |
50 | def set_norm_layer(self):
51 | for name, param in self.model.image_encoder.named_parameters():
52 | if "norm" in name:
53 | param.requires_grad = True
54 |
55 | def set_evp_adaptor_layer(self):
56 | for param in self.model.image_encoder.prompt_generator.parameters():
57 | param.requires_grad = True
58 |
59 | def set_prompt_layer(self):
60 | self.model.image_encoder.Prompt_Tokens.requires_grad = True
61 |
62 | def reset_parameters(self) -> None:
63 | for name, param in self.model.named_parameters():
64 |             if param.requires_grad:
65 | if "linear_a" in name:
66 | nn.init.kaiming_uniform_(param, a=math.sqrt(5))
67 | if "linear_b" in name:
68 | nn.init.zeros_(param)
69 |
70 | def forward(self, images, prompts):
71 | image_embeddings = self.encode(images)
72 | pred_masks, ious, res_masks = self.decode(prompts, image_embeddings)
73 | return image_embeddings, pred_masks, ious, res_masks
74 |
75 | def encode(self, images):
76 | _, _, H, W = images.shape
77 | self.image_shape = (H, W)
78 | image_embeddings = self.model.image_encoder(images)
79 | return image_embeddings
80 |
81 | def decode(self, prompts, image_embeddings):
82 | pred_masks = []
83 | ious = []
84 | res_masks = []
85 | for prompt, embedding in zip(prompts, image_embeddings):
86 | if isinstance(prompt, torch.Tensor):
87 | prompt = prompt.to(device=embedding.device)
88 | sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
89 | points=None,
90 | boxes=prompt,
91 | masks=None,
92 | )
93 | elif isinstance(prompt, tuple):
94 | sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
95 | points=prompt,
96 | boxes=None,
97 | masks=None,
98 | )
99 |
100 | low_res_masks, iou_predictions = self.model.mask_decoder(
101 | image_embeddings=embedding.unsqueeze(0),
102 | image_pe=self.model.prompt_encoder.get_dense_pe(),
103 | sparse_prompt_embeddings=sparse_embeddings,
104 | dense_prompt_embeddings=dense_embeddings,
105 | multimask_output=False,
106 | )
107 |
108 | masks = F.interpolate(
109 | low_res_masks,
110 | self.image_shape,
111 | mode="bilinear",
112 | align_corners=False,
113 | )
114 | pred_masks.append(masks.squeeze(1))
115 | ious.append(iou_predictions)
116 | res_masks.append(low_res_masks)
117 | return pred_masks, ious, res_masks
118 |
119 | def get_predictor(self):
120 | return SamPredictor(self.model)
121 |
122 | def get_generator(self, output_mode):
123 | return SamAutomaticMaskGenerator(self.model, output_mode=output_mode)
--------------------------------------------------------------------------------
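
A rough usage sketch for `Model` in model.py with a single box prompt. The config fields mirror what `get_checkpoint` and `setup` read above, but the checkpoint directory is a placeholder and running this needs the matching SAM weights on disk:

    import torch
    from box import Box
    from model import Model

    cfg = Box({
        "model": {
            "type": "vit_b",
            "checkpoint": "checkpoints",   # directory holding sam_vit_b_01ec64.pth
            "freeze": {"image_encoder": True, "prompt_encoder": True, "mask_decoder": False},
        }
    })
    model = Model(cfg)
    model.setup()

    images = torch.randn(1, 3, 1024, 1024)                      # already resized and padded
    prompts = [torch.tensor([[100.0, 100.0, 400.0, 400.0]])]    # one XYXY box per image
    _, pred_masks, ious, _ = model(images, prompts)
    print(pred_masks[0].shape, ious[0].shape)                   # (1, 1024, 1024) and (1, 1)
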
/requirements.txt:
--------------------------------------------------------------------------------
1 | python-box==7.0.1
2 | pycocotools==2.0.6
3 | numpy==1.24.2
4 | opencv_python==4.7.0.72
5 | Pillow==9.3.0
6 | lightning==2.0.1
7 | segmentation-models-pytorch==0.3.2
8 | albumentations==1.3.1
9 | imagecorruptions==1.1.2
10 | safetensors==0.4.1
11 | torchsummary==1.5.1
12 | tensorboard==2.14.0
13 | einops
--------------------------------------------------------------------------------
/sam_lora.py:
--------------------------------------------------------------------------------
1 | # Sheng Wang at Apr 6 2023
2 | # What a time to be alive (first half of 2023)
3 |
4 | from segment_anything.modeling import Sam
5 | from segment_anything import build_sam, SamPredictor
6 | from segment_anything import sam_model_registry
7 |
8 | import math
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | from torch import Tensor
13 | from torch.nn.parameter import Parameter
14 | from segment_anything.modeling import Sam
15 | from safetensors import safe_open
16 | from safetensors.torch import save_file
17 |
18 |
19 |
20 | class _LoRA_qkv(nn.Module):
21 | """In Sam it is implemented as
22 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
23 | B, N, C = x.shape
24 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
25 | q, k, v = qkv.unbind(0)
26 | """
27 |
28 | def __init__(
29 | self,
30 | qkv: nn.Module,
31 | linear_a_q: nn.Module,
32 | linear_b_q: nn.Module,
33 | linear_a_v: nn.Module,
34 | linear_b_v: nn.Module,
35 | ):
36 | super().__init__()
37 | self.qkv = qkv
38 | self.linear_a_q = linear_a_q
39 | self.linear_b_q = linear_b_q
40 | self.linear_a_v = linear_a_v
41 | self.linear_b_v = linear_b_v
42 | self.dim = qkv.in_features
43 | self.w_identity = torch.eye(qkv.in_features)
44 |
45 | def forward(self, x):
46 | # x: [25, 14, 14, 768]; self.qkv: Linear(in_features=768, out_features=2304, bias=True)
47 | qkv = self.qkv(x) # B,N,N,3*org_C
48 | new_q = self.linear_b_q(self.linear_a_q(x))
49 | new_v = self.linear_b_v(self.linear_a_v(x))
50 | qkv[:, :, :, : self.dim] += new_q
51 | qkv[:, :, :, -self.dim :] += new_v
52 | return qkv
53 |
54 |
55 | class LoRA(nn.Module):
56 | def __init__(self, *args, **kwargs) -> None:
57 | super().__init__(*args, **kwargs)
58 |
59 | def save_fc_parameters(self, filename: str) -> None:
60 | r"""Only safetensors is supported now.
61 |
62 |         pip install safetensors if you do not have it installed yet.
63 | """
64 | assert filename.endswith(".safetensors")
65 | _in = self.lora_vit.head.in_features
66 | _out = self.lora_vit.head.out_features
67 | fc_tensors = {f"fc_{_in}in_{_out}out": self.lora_vit.head.weight}
68 | save_file(fc_tensors, filename)
69 |
70 | def load_fc_parameters(self, filename: str) -> None:
71 | r"""Only safetensors is supported now.
72 |
73 |         pip install safetensors if you do not have it installed yet.
74 | """
75 |
76 | assert filename.endswith(".safetensors")
77 | _in = self.lora_vit.head.in_features
78 | _out = self.lora_vit.head.out_features
79 | with safe_open(filename, framework="pt") as f:
80 | saved_key = f"fc_{_in}in_{_out}out"
81 | try:
82 | saved_tensor = f.get_tensor(saved_key)
83 | self.lora_vit.head.weight = Parameter(saved_tensor)
84 | except ValueError:
85 | print("this fc weight is not for this model")
86 |
87 | def save_lora_parameters(self, filename: str) -> None:
88 | r"""Only safetensors is supported now.
89 |
90 |         pip install safetensors if you do not have it installed yet.
91 |
92 | save both lora and fc parameters.
93 | """
94 |
95 | assert filename.endswith(".safetensors")
96 |
97 |         num_layer = len(self.w_As)  # two adapters (for q and for v) are stored per adapted block
98 | a_tensors = {f"w_a_{i:03d}": self.w_As[i].weight for i in range(num_layer)}
99 | b_tensors = {f"w_b_{i:03d}": self.w_Bs[i].weight for i in range(num_layer)}
100 |
101 | _in = self.lora_vit.head.in_features
102 | _out = self.lora_vit.head.out_features
103 | fc_tensors = {f"fc_{_in}in_{_out}out": self.lora_vit.head.weight}
104 |
105 | merged_dict = {**a_tensors, **b_tensors, **fc_tensors}
106 | save_file(merged_dict, filename)
107 |
108 | def load_lora_parameters(self, filename: str) -> None:
109 | r"""Only safetensors is supported now.
110 |
111 |         pip install safetensors if you do not have it installed yet.
112 |
113 | load both lora and fc parameters.
114 | """
115 |
116 | assert filename.endswith(".safetensors")
117 |
118 | with safe_open(filename, framework="pt") as f:
119 | for i, w_A_linear in enumerate(self.w_As):
120 | saved_key = f"w_a_{i:03d}"
121 | saved_tensor = f.get_tensor(saved_key)
122 | w_A_linear.weight = Parameter(saved_tensor)
123 |
124 | for i, w_B_linear in enumerate(self.w_Bs):
125 | saved_key = f"w_b_{i:03d}"
126 | saved_tensor = f.get_tensor(saved_key)
127 | w_B_linear.weight = Parameter(saved_tensor)
128 |
129 | _in = self.lora_vit.head.in_features
130 | _out = self.lora_vit.head.out_features
131 | saved_key = f"fc_{_in}in_{_out}out"
132 | try:
133 | saved_tensor = f.get_tensor(saved_key)
134 | self.lora_vit.head.weight = Parameter(saved_tensor)
135 | except ValueError:
136 | print("this fc weight is not for this model")
137 |
138 | def reset_parameters(self) -> None:
139 | for w_A in self.w_As:
140 | nn.init.kaiming_uniform_(w_A.weight, a=math.sqrt(5))
141 | for w_B in self.w_Bs:
142 | nn.init.zeros_(w_B.weight)
143 |
144 |
145 | class LoRA_Sam(LoRA):
146 | """Applies low-rank adaptation to a Sam model's image encoder.
147 |
148 |     Args:
149 |         sam_model: the Sam model whose image encoder will be adapted
150 |         r: rank of the low-rank update matrices (must be > 0)
151 |         lora_layer: indices of the encoder blocks to adapt; defaults to all blocks
152 | 
153 |     Examples::
154 |         >>> sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
155 |         >>> lora_sam = LoRA_Sam(sam, r=4)
156 |         >>> trainable = [p for p in sam.image_encoder.parameters() if p.requires_grad]
157 |         >>> len(trainable) > 0
158 |         True
159 | 
160 |     """
161 |
162 | def __init__(self, sam_model: Sam, r: int, lora_layer=None):
163 | super(LoRA_Sam, self).__init__()
164 |
165 | assert r > 0
166 | # base_vit_dim = sam_model.image_encoder.patch_embed.proj.out_channels
167 | # dim = base_vit_dim
168 | if lora_layer:
169 | self.lora_layer = lora_layer
170 | else:
171 | self.lora_layer = list(range(len(sam_model.image_encoder.blocks)))
172 | # create for storage, then we can init them or load weights
173 | self.w_As = [] # These are linear layers
174 | self.w_Bs = []
175 |
176 | # lets freeze first
177 | for param in sam_model.image_encoder.parameters():
178 | param.requires_grad = False
179 |
180 | # Here, we do the surgery
181 | for t_layer_i, blk in enumerate(sam_model.image_encoder.blocks):
182 | # If we only want few lora layer instead of all
183 | if t_layer_i not in self.lora_layer:
184 | continue
185 | w_qkv_linear = blk.attn.qkv
186 | self.dim = w_qkv_linear.in_features
187 | w_a_linear_q = nn.Linear(self.dim, r, bias=False)
188 | w_b_linear_q = nn.Linear(r, self.dim, bias=False)
189 | w_a_linear_v = nn.Linear(self.dim, r, bias=False)
190 | w_b_linear_v = nn.Linear(r, self.dim, bias=False)
191 | self.w_As.append(w_a_linear_q)
192 | self.w_Bs.append(w_b_linear_q)
193 | self.w_As.append(w_a_linear_v)
194 | self.w_Bs.append(w_b_linear_v)
195 | blk.attn.qkv = _LoRA_qkv(
196 | w_qkv_linear,
197 | w_a_linear_q,
198 | w_b_linear_q,
199 | w_a_linear_v,
200 | w_b_linear_v,
201 | )
202 | self.reset_parameters()
203 | # self.sam = sam_model
204 | self.lora_vit = sam_model
205 |
206 |
--------------------------------------------------------------------------------
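
A quick sketch of the LoRA surgery in sam_lora.py (the checkpoint path is a placeholder). After wrapping, the original encoder weights stay frozen and only the rank-r adapters in the qkv projections remain trainable:

    from segment_anything import sam_model_registry
    from sam_lora import LoRA_Sam

    sam = sam_model_registry["vit_b"](checkpoint="checkpoints/sam_vit_b_01ec64.pth")
    lora = LoRA_Sam(sam, r=4)    # replaces every blk.attn.qkv with a _LoRA_qkv wrapper

    trainable = sum(p.numel() for p in sam.image_encoder.parameters() if p.requires_grad)
    print(f"trainable image-encoder parameters after LoRA: {trainable}")
    # Note: save_lora_parameters()/load_lora_parameters() also read lora_vit.head,
    # an attribute Sam does not define, so the fc-related entries do not apply here.
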
/segment_anything/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from .build_sam import (
8 | build_sam,
9 | build_sam_vit_h,
10 | build_sam_vit_l,
11 | build_sam_vit_b,
12 | sam_model_registry,
13 | )
14 | from .predictor import SamPredictor
15 | from .automatic_mask_generator import SamAutomaticMaskGenerator
16 |
--------------------------------------------------------------------------------
/segment_anything/build_sam.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 |
9 | from functools import partial
10 |
11 | from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer
12 |
13 |
14 | def build_sam_vit_h(checkpoint=None):
15 | return _build_sam(
16 | encoder_embed_dim=1280,
17 | encoder_depth=32,
18 | encoder_num_heads=16,
19 | encoder_global_attn_indexes=[7, 15, 23, 31],
20 | checkpoint=checkpoint,
21 | )
22 |
23 |
24 | build_sam = build_sam_vit_h
25 |
26 |
27 | def build_sam_vit_l(checkpoint=None):
28 | return _build_sam(
29 | encoder_embed_dim=1024,
30 | encoder_depth=24,
31 | encoder_num_heads=16,
32 | encoder_global_attn_indexes=[5, 11, 17, 23],
33 | checkpoint=checkpoint,
34 | )
35 |
36 |
37 | def build_sam_vit_b(checkpoint=None):
38 | return _build_sam(
39 | encoder_embed_dim=768,
40 | encoder_depth=12,
41 | encoder_num_heads=12,
42 | encoder_global_attn_indexes=[2, 5, 8, 11],
43 | checkpoint=checkpoint,
44 | )
45 |
46 |
47 | sam_model_registry = {
48 | "default": build_sam_vit_h,
49 | "vit_h": build_sam_vit_h,
50 | "vit_l": build_sam_vit_l,
51 | "vit_b": build_sam_vit_b,
52 | }
53 |
54 |
55 | def _build_sam(
56 | encoder_embed_dim,
57 | encoder_depth,
58 | encoder_num_heads,
59 | encoder_global_attn_indexes,
60 | checkpoint=None,
61 | ):
62 | prompt_embed_dim = 256
63 | image_size = 1024
64 | vit_patch_size = 16
65 | image_embedding_size = image_size // vit_patch_size
66 | sam = Sam(
67 | image_encoder=ImageEncoderViT(
68 | depth=encoder_depth,
69 | embed_dim=encoder_embed_dim,
70 | img_size=image_size,
71 | mlp_ratio=4,
72 | norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
73 | num_heads=encoder_num_heads,
74 | patch_size=vit_patch_size,
75 | qkv_bias=True,
76 | use_rel_pos=True,
77 | global_attn_indexes=encoder_global_attn_indexes,
78 | window_size=14,
79 | out_chans=prompt_embed_dim,
80 | ),
81 | prompt_encoder=PromptEncoder(
82 | embed_dim=prompt_embed_dim,
83 | image_embedding_size=(image_embedding_size, image_embedding_size),
84 | input_image_size=(image_size, image_size),
85 | mask_in_chans=16,
86 | ),
87 | mask_decoder=MaskDecoder(
88 | num_multimask_outputs=3,
89 | transformer=TwoWayTransformer(
90 | depth=2,
91 | embedding_dim=prompt_embed_dim,
92 | mlp_dim=2048,
93 | num_heads=8,
94 | ),
95 | transformer_dim=prompt_embed_dim,
96 | iou_head_depth=3,
97 | iou_head_hidden_dim=256,
98 | ),
99 | pixel_mean=[123.675, 116.28, 103.53],
100 | pixel_std=[58.395, 57.12, 57.375],
101 | )
102 | sam.eval()
103 | if checkpoint is not None:
104 | with open(checkpoint, "rb") as f:
105 | state_dict = torch.load(f)
106 | sam.load_state_dict(state_dict)
107 | return sam
108 |
--------------------------------------------------------------------------------
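
For reference, a short sketch of building a model through `sam_model_registry`; `checkpoint=None` gives randomly initialised weights, which is enough to inspect the embedding shape:

    import torch
    from segment_anything import sam_model_registry

    sam = sam_model_registry["vit_b"](checkpoint=None)      # random weights, no file needed
    with torch.no_grad():
        embeddings = sam.image_encoder(torch.randn(1, 3, 1024, 1024))
    print(embeddings.shape)                                 # torch.Size([1, 256, 64, 64])
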
/segment_anything/modeling/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from .sam import Sam
8 | from .image_encoder import ImageEncoderViT
9 | from .mask_decoder import MaskDecoder
10 | from .prompt_encoder import PromptEncoder
11 | from .transformer import TwoWayTransformer
12 |
--------------------------------------------------------------------------------
/segment_anything/modeling/common.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | import torch.nn as nn
9 |
10 | from typing import Type
11 |
12 |
13 | class MLPBlock(nn.Module):
14 | def __init__(
15 | self,
16 | embedding_dim: int,
17 | mlp_dim: int,
18 | act: Type[nn.Module] = nn.GELU,
19 | ) -> None:
20 | super().__init__()
21 | self.lin1 = nn.Linear(embedding_dim, mlp_dim)
22 | self.lin2 = nn.Linear(mlp_dim, embedding_dim)
23 | self.act = act()
24 |
25 | def forward(self, x: torch.Tensor) -> torch.Tensor:
26 | return self.lin2(self.act(self.lin1(x)))
27 |
28 |
29 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
30 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
31 | class LayerNorm2d(nn.Module):
32 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
33 | super().__init__()
34 | self.weight = nn.Parameter(torch.ones(num_channels))
35 | self.bias = nn.Parameter(torch.zeros(num_channels))
36 | self.eps = eps
37 |
38 | def forward(self, x: torch.Tensor) -> torch.Tensor:
39 | u = x.mean(1, keepdim=True)
40 | s = (x - u).pow(2).mean(1, keepdim=True)
41 | x = (x - u) / torch.sqrt(s + self.eps)
42 | x = self.weight[:, None, None] * x + self.bias[:, None, None]
43 | return x
44 |
--------------------------------------------------------------------------------
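
A small self-check (my own sketch, not part of the library) showing that `LayerNorm2d` is layer normalisation over the channel dimension at each spatial location:

    import torch
    import torch.nn.functional as F
    from segment_anything.modeling.common import LayerNorm2d

    ln = LayerNorm2d(8)
    x = torch.randn(2, 8, 4, 4)
    ref = F.layer_norm(x.permute(0, 2, 3, 1), (8,), ln.weight, ln.bias, ln.eps)
    assert torch.allclose(ln(x), ref.permute(0, 3, 1, 2), atol=1e-6)
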
/segment_anything/modeling/mask_decoder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | from torch import nn
9 | from torch.nn import functional as F
10 |
11 | from typing import List, Tuple, Type
12 |
13 | from .common import LayerNorm2d
14 |
15 |
16 | class MaskDecoder(nn.Module):
17 | def __init__(
18 | self,
19 | *,
20 | transformer_dim: int,
21 | transformer: nn.Module,
22 | num_multimask_outputs: int = 3,
23 | activation: Type[nn.Module] = nn.GELU,
24 | iou_head_depth: int = 3,
25 | iou_head_hidden_dim: int = 256,
26 | ) -> None:
27 | """
28 | Predicts masks given an image and prompt embeddings, using a
29 | transformer architecture.
30 |
31 | Arguments:
32 | transformer_dim (int): the channel dimension of the transformer
33 | transformer (nn.Module): the transformer used to predict masks
34 | num_multimask_outputs (int): the number of masks to predict
35 | when disambiguating masks
36 | activation (nn.Module): the type of activation to use when
37 | upscaling masks
38 | iou_head_depth (int): the depth of the MLP used to predict
39 | mask quality
40 | iou_head_hidden_dim (int): the hidden dimension of the MLP
41 | used to predict mask quality
42 | """
43 | super().__init__()
44 | self.transformer_dim = transformer_dim
45 | self.transformer = transformer
46 |
47 | self.num_multimask_outputs = num_multimask_outputs
48 |
49 | self.iou_token = nn.Embedding(1, transformer_dim)
50 | self.num_mask_tokens = num_multimask_outputs + 1
51 | self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
52 |
53 | self.output_upscaling = nn.Sequential(
54 | nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
55 | LayerNorm2d(transformer_dim // 4),
56 | activation(),
57 | nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
58 | activation(),
59 | )
60 | self.output_hypernetworks_mlps = nn.ModuleList(
61 | [
62 | MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
63 | for i in range(self.num_mask_tokens)
64 | ]
65 | )
66 |
67 | self.iou_prediction_head = MLP(
68 | transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth
69 | )
70 |
71 | def forward(
72 | self,
73 | image_embeddings: torch.Tensor,
74 | image_pe: torch.Tensor,
75 | sparse_prompt_embeddings: torch.Tensor,
76 | dense_prompt_embeddings: torch.Tensor,
77 | multimask_output: bool,
78 | ) -> Tuple[torch.Tensor, torch.Tensor]:
79 | """
80 | Predict masks given image and prompt embeddings.
81 |
82 | Arguments:
83 | image_embeddings (torch.Tensor): the embeddings from the image encoder
84 | image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
85 | sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
86 | dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
87 | multimask_output (bool): Whether to return multiple masks or a single
88 | mask.
89 |
90 | Returns:
91 | torch.Tensor: batched predicted masks
92 | torch.Tensor: batched predictions of mask quality
93 | """
94 | masks, iou_pred = self.predict_masks(
95 | image_embeddings=image_embeddings,
96 | image_pe=image_pe,
97 | sparse_prompt_embeddings=sparse_prompt_embeddings,
98 | dense_prompt_embeddings=dense_prompt_embeddings,
99 | )
100 |
101 | # Select the correct mask or masks for output
102 | if multimask_output:
103 | mask_slice = slice(1, None)
104 | else:
105 | mask_slice = slice(0, 1)
106 | masks = masks[:, mask_slice, :, :]
107 | iou_pred = iou_pred[:, mask_slice]
108 |
109 | # Prepare output
110 | return masks, iou_pred
111 |
112 | def predict_masks(
113 | self,
114 | image_embeddings: torch.Tensor,
115 | image_pe: torch.Tensor,
116 | sparse_prompt_embeddings: torch.Tensor,
117 | dense_prompt_embeddings: torch.Tensor,
118 | ) -> Tuple[torch.Tensor, torch.Tensor]:
119 | """Predicts masks. See 'forward' for more details."""
120 | # Concatenate output tokens
121 | output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
122 | output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
123 | tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
124 |
125 | # Expand per-image data in batch direction to be per-mask
126 | src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
127 | src = src + dense_prompt_embeddings
128 | pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
129 | b, c, h, w = src.shape
130 |
131 | # Run the transformer
132 | hs, src = self.transformer(src, pos_src, tokens)
133 | iou_token_out = hs[:, 0, :]
134 | mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]
135 |
136 | # Upscale mask embeddings and predict masks using the mask tokens
137 | src = src.transpose(1, 2).view(b, c, h, w)
138 | upscaled_embedding = self.output_upscaling(src)
139 | hyper_in_list: List[torch.Tensor] = []
140 | for i in range(self.num_mask_tokens):
141 | hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
142 | hyper_in = torch.stack(hyper_in_list, dim=1)
143 | b, c, h, w = upscaled_embedding.shape
144 | masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
145 |
146 | # Generate mask quality predictions
147 | iou_pred = self.iou_prediction_head(iou_token_out)
148 |
149 | return masks, iou_pred
150 |
151 |
152 | # Lightly adapted from
153 | # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
154 | class MLP(nn.Module):
155 | def __init__(
156 | self,
157 | input_dim: int,
158 | hidden_dim: int,
159 | output_dim: int,
160 | num_layers: int,
161 | sigmoid_output: bool = False,
162 | ) -> None:
163 | super().__init__()
164 | self.num_layers = num_layers
165 | h = [hidden_dim] * (num_layers - 1)
166 | self.layers = nn.ModuleList(
167 | nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
168 | )
169 | self.sigmoid_output = sigmoid_output
170 |
171 | def forward(self, x):
172 | for i, layer in enumerate(self.layers):
173 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
174 | if self.sigmoid_output:
175 | x = F.sigmoid(x)
176 | return x
177 |
--------------------------------------------------------------------------------
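
A shape sketch for `MaskDecoder` with the dimensions used in `_build_sam`; the embeddings are random and the two sparse tokens stand in for one box prompt:

    import torch
    from segment_anything.modeling import MaskDecoder, TwoWayTransformer

    decoder = MaskDecoder(
        transformer_dim=256,
        transformer=TwoWayTransformer(depth=2, embedding_dim=256, mlp_dim=2048, num_heads=8),
    )
    image_embeddings = torch.randn(1, 256, 64, 64)
    image_pe = torch.randn(1, 256, 64, 64)
    sparse = torch.randn(1, 2, 256)                    # e.g. two box-corner embeddings
    dense = torch.randn(1, 256, 64, 64)
    masks, iou_pred = decoder(image_embeddings, image_pe, sparse, dense, multimask_output=True)
    print(masks.shape, iou_pred.shape)                 # [1, 3, 256, 256] and [1, 3]
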
/segment_anything/modeling/prompt_encoder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import numpy as np
8 | import torch
9 | from torch import nn
10 |
11 | from typing import Any, Optional, Tuple, Type
12 |
13 | from .common import LayerNorm2d
14 |
15 |
16 | class PromptEncoder(nn.Module):
17 | def __init__(
18 | self,
19 | embed_dim: int,
20 | image_embedding_size: Tuple[int, int],
21 | input_image_size: Tuple[int, int],
22 | mask_in_chans: int,
23 | activation: Type[nn.Module] = nn.GELU,
24 | ) -> None:
25 | """
26 | Encodes prompts for input to SAM's mask decoder.
27 |
28 | Arguments:
29 | embed_dim (int): The prompts' embedding dimension
30 | image_embedding_size (tuple(int, int)): The spatial size of the
31 | image embedding, as (H, W).
32 |           input_image_size (tuple(int, int)): The padded size of the image as input
33 | to the image encoder, as (H, W).
34 | mask_in_chans (int): The number of hidden channels used for
35 | encoding input masks.
36 | activation (nn.Module): The activation to use when encoding
37 | input masks.
38 | """
39 | super().__init__()
40 | self.embed_dim = embed_dim
41 | self.input_image_size = input_image_size
42 | self.image_embedding_size = image_embedding_size
43 | self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
44 |
45 | self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners
46 | point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)]
47 | self.point_embeddings = nn.ModuleList(point_embeddings)
48 | self.not_a_point_embed = nn.Embedding(1, embed_dim)
49 |
50 | self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1])
51 | self.mask_downscaling = nn.Sequential(
52 | nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
53 | LayerNorm2d(mask_in_chans // 4),
54 | activation(),
55 | nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
56 | LayerNorm2d(mask_in_chans),
57 | activation(),
58 | nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
59 | )
60 | self.no_mask_embed = nn.Embedding(1, embed_dim)
61 |
62 | def get_dense_pe(self) -> torch.Tensor:
63 | """
64 | Returns the positional encoding used to encode point prompts,
65 | applied to a dense set of points the shape of the image encoding.
66 |
67 | Returns:
68 | torch.Tensor: Positional encoding with shape
69 | 1x(embed_dim)x(embedding_h)x(embedding_w)
70 | """
71 | return self.pe_layer(self.image_embedding_size).unsqueeze(0)
72 |
73 | def _embed_points(
74 | self,
75 | points: torch.Tensor,
76 | labels: torch.Tensor,
77 | pad: bool,
78 | ) -> torch.Tensor:
79 | """Embeds point prompts."""
80 | points = points + 0.5 # Shift to center of pixel
81 | if pad:
82 | padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
83 | padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
84 | points = torch.cat([points, padding_point], dim=1)
85 | labels = torch.cat([labels, padding_label], dim=1)
86 | point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)
87 | point_embedding[labels == -1] = 0.0
88 | point_embedding[labels == -1] += self.not_a_point_embed.weight
89 | point_embedding[labels == 0] += self.point_embeddings[0].weight
90 | point_embedding[labels == 1] += self.point_embeddings[1].weight
91 | return point_embedding
92 |
93 | def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
94 | """Embeds box prompts."""
95 | boxes = boxes + 0.5 # Shift to center of pixel
96 | coords = boxes.reshape(-1, 2, 2)
97 | corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)
98 | corner_embedding[:, 0, :] += self.point_embeddings[2].weight
99 | corner_embedding[:, 1, :] += self.point_embeddings[3].weight
100 | return corner_embedding
101 |
102 | def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
103 | """Embeds mask inputs."""
104 | mask_embedding = self.mask_downscaling(masks)
105 | return mask_embedding
106 |
107 | def _get_batch_size(
108 | self,
109 | points: Optional[Tuple[torch.Tensor, torch.Tensor]],
110 | boxes: Optional[torch.Tensor],
111 | masks: Optional[torch.Tensor],
112 | ) -> int:
113 | """
114 | Gets the batch size of the output given the batch size of the input prompts.
115 | """
116 | if points is not None:
117 | return points[0].shape[0]
118 | elif boxes is not None:
119 | return boxes.shape[0]
120 | elif masks is not None:
121 | return masks.shape[0]
122 | else:
123 | return 1
124 |
125 | def _get_device(self) -> torch.device:
126 | return self.point_embeddings[0].weight.device
127 |
128 | def forward(
129 | self,
130 | points: Optional[Tuple[torch.Tensor, torch.Tensor]],
131 | boxes: Optional[torch.Tensor],
132 | masks: Optional[torch.Tensor],
133 | ) -> Tuple[torch.Tensor, torch.Tensor]:
134 | """
135 | Embeds different types of prompts, returning both sparse and dense
136 | embeddings.
137 |
138 | Arguments:
139 | points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates
140 | and labels to embed.
141 | boxes (torch.Tensor or none): boxes to embed
142 | masks (torch.Tensor or none): masks to embed
143 |
144 | Returns:
145 | torch.Tensor: sparse embeddings for the points and boxes, with shape
146 | BxNx(embed_dim), where N is determined by the number of input points
147 | and boxes.
148 | torch.Tensor: dense embeddings for the masks, in the shape
149 | Bx(embed_dim)x(embed_H)x(embed_W)
150 | """
151 | bs = self._get_batch_size(points, boxes, masks)
152 | sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())
153 | if points is not None:
154 | coords, labels = points
155 | point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
156 | sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
157 | if boxes is not None:
158 | box_embeddings = self._embed_boxes(boxes)
159 | sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
160 |
161 | if masks is not None:
162 | dense_embeddings = self._embed_masks(masks)
163 | else:
164 | dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
165 | bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
166 | )
167 |
168 | return sparse_embeddings, dense_embeddings
169 |
170 |
171 | class PositionEmbeddingRandom(nn.Module):
172 | """
173 | Positional encoding using random spatial frequencies.
174 | """
175 |
176 | def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
177 | super().__init__()
178 | if scale is None or scale <= 0.0:
179 | scale = 1.0
180 | self.register_buffer(
181 | "positional_encoding_gaussian_matrix",
182 | scale * torch.randn((2, num_pos_feats)),
183 | )
184 |
185 | def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
186 | """Positionally encode points that are normalized to [0,1]."""
187 | # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
188 | coords = 2 * coords - 1
189 | coords = coords @ self.positional_encoding_gaussian_matrix
190 | coords = 2 * np.pi * coords
191 | # outputs d_1 x ... x d_n x C shape
192 | return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
193 |
194 | def forward(self, size: Tuple[int, int]) -> torch.Tensor:
195 | """Generate positional encoding for a grid of the specified size."""
196 | h, w = size
197 | device: Any = self.positional_encoding_gaussian_matrix.device
198 | grid = torch.ones((h, w), device=device, dtype=torch.float32)
199 | y_embed = grid.cumsum(dim=0) - 0.5
200 | x_embed = grid.cumsum(dim=1) - 0.5
201 | y_embed = y_embed / h
202 | x_embed = x_embed / w
203 |
204 | pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
205 | return pe.permute(2, 0, 1) # C x H x W
206 |
207 | def forward_with_coords(
208 | self, coords_input: torch.Tensor, image_size: Tuple[int, int]
209 | ) -> torch.Tensor:
210 | """Positionally encode points that are not normalized to [0,1]."""
211 | coords = coords_input.clone()
212 | coords[:, :, 0] = coords[:, :, 0] / image_size[1]
213 | coords[:, :, 1] = coords[:, :, 1] / image_size[0]
214 | return self._pe_encoding(coords.to(torch.float)) # B x N x C
215 |
--------------------------------------------------------------------------------
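
A quick sketch with the sizes used in `_build_sam` (1024-pixel inputs, a 64x64 embedding grid): a single box produces two sparse corner tokens plus the no-mask dense embedding:

    import torch
    from segment_anything.modeling import PromptEncoder

    prompt_encoder = PromptEncoder(
        embed_dim=256,
        image_embedding_size=(64, 64),
        input_image_size=(1024, 1024),
        mask_in_chans=16,
    )
    boxes = torch.tensor([[100.0, 100.0, 400.0, 400.0]])   # one XYXY box
    sparse, dense = prompt_encoder(points=None, boxes=boxes, masks=None)
    print(sparse.shape, dense.shape)    # [1, 2, 256] and [1, 256, 64, 64]
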
/segment_anything/modeling/sam.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | from torch import nn
9 | from torch.nn import functional as F
10 |
11 | from typing import Any, Dict, List, Tuple
12 |
13 | from .image_encoder import ImageEncoderViT
14 | from .mask_decoder import MaskDecoder
15 | from .prompt_encoder import PromptEncoder
16 |
17 |
18 | class Sam(nn.Module):
19 | mask_threshold: float = 0.0
20 | image_format: str = "RGB"
21 |
22 | def __init__(
23 | self,
24 | image_encoder: ImageEncoderViT,
25 | prompt_encoder: PromptEncoder,
26 | mask_decoder: MaskDecoder,
27 | pixel_mean: List[float] = [123.675, 116.28, 103.53],
28 | pixel_std: List[float] = [58.395, 57.12, 57.375],
29 | ) -> None:
30 | """
31 | SAM predicts object masks from an image and input prompts.
32 |
33 | Arguments:
34 | image_encoder (ImageEncoderViT): The backbone used to encode the
35 | image into image embeddings that allow for efficient mask prediction.
36 | prompt_encoder (PromptEncoder): Encodes various types of input prompts.
37 | mask_decoder (MaskDecoder): Predicts masks from the image embeddings
38 | and encoded prompts.
39 | pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
40 | pixel_std (list(float)): Std values for normalizing pixels in the input image.
41 | """
42 | super().__init__()
43 | self.image_encoder = image_encoder
44 | self.prompt_encoder = prompt_encoder
45 | self.mask_decoder = mask_decoder
46 | self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
47 | self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
48 |
49 | @property
50 | def device(self) -> Any:
51 | return self.pixel_mean.device
52 |
53 | @torch.no_grad()
54 | def forward(
55 | self,
56 | batched_input: List[Dict[str, Any]],
57 | multimask_output: bool,
58 | ) -> List[Dict[str, torch.Tensor]]:
59 | """
60 | Predicts masks end-to-end from provided images and prompts.
61 | If prompts are not known in advance, using SamPredictor is
62 | recommended over calling the model directly.
63 |
64 | Arguments:
65 | batched_input (list(dict)): A list over input images, each a
66 | dictionary with the following keys. A prompt key can be
67 | excluded if it is not present.
68 | 'image': The image as a torch tensor in 3xHxW format,
69 | already transformed for input to the model.
70 | 'original_size': (tuple(int, int)) The original size of
71 | the image before transformation, as (H, W).
72 | 'point_coords': (torch.Tensor) Batched point prompts for
73 | this image, with shape BxNx2. Already transformed to the
74 | input frame of the model.
75 | 'point_labels': (torch.Tensor) Batched labels for point prompts,
76 | with shape BxN.
77 | 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4.
78 | Already transformed to the input frame of the model.
79 | 'mask_inputs': (torch.Tensor) Batched mask inputs to the model,
80 | in the form Bx1xHxW.
81 | multimask_output (bool): Whether the model should predict multiple
82 | disambiguating masks, or return a single mask.
83 |
84 | Returns:
85 | (list(dict)): A list over input images, where each element is
86 |             a dictionary with the following keys.
87 | 'masks': (torch.Tensor) Batched binary mask predictions,
88 | with shape BxCxHxW, where B is the number of input prompts,
89 | C is determined by multimask_output, and (H, W) is the
90 | original size of the image.
91 | 'iou_predictions': (torch.Tensor) The model's predictions
92 | of mask quality, in shape BxC.
93 | 'low_res_logits': (torch.Tensor) Low resolution logits with
94 | shape BxCxHxW, where H=W=256. Can be passed as mask input
95 | to subsequent iterations of prediction.
96 | """
97 | input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0)
98 | image_embeddings = self.image_encoder(input_images)
99 |
100 | outputs = []
101 | for image_record, curr_embedding in zip(batched_input, image_embeddings):
102 | if "point_coords" in image_record:
103 | points = (image_record["point_coords"], image_record["point_labels"])
104 | else:
105 | points = None
106 | sparse_embeddings, dense_embeddings = self.prompt_encoder(
107 | points=points,
108 | boxes=image_record.get("boxes", None),
109 | masks=image_record.get("mask_inputs", None),
110 | )
111 | low_res_masks, iou_predictions = self.mask_decoder(
112 | image_embeddings=curr_embedding.unsqueeze(0),
113 | image_pe=self.prompt_encoder.get_dense_pe(),
114 | sparse_prompt_embeddings=sparse_embeddings,
115 | dense_prompt_embeddings=dense_embeddings,
116 | multimask_output=multimask_output,
117 | )
118 | masks = self.postprocess_masks(
119 | low_res_masks,
120 | input_size=image_record["image"].shape[-2:],
121 | original_size=image_record["original_size"],
122 | )
123 | masks = masks > self.mask_threshold
124 | outputs.append(
125 | {
126 | "masks": masks,
127 | "iou_predictions": iou_predictions,
128 | "low_res_logits": low_res_masks,
129 | }
130 | )
131 | return outputs
132 |
133 | def postprocess_masks(
134 | self,
135 | masks: torch.Tensor,
136 | input_size: Tuple[int, ...],
137 | original_size: Tuple[int, ...],
138 | ) -> torch.Tensor:
139 | """
140 | Remove padding and upscale masks to the original image size.
141 |
142 | Arguments:
143 | masks (torch.Tensor): Batched masks from the mask_decoder,
144 | in BxCxHxW format.
145 | input_size (tuple(int, int)): The size of the image input to the
146 | model, in (H, W) format. Used to remove padding.
147 | original_size (tuple(int, int)): The original size of the image
148 | before resizing for input to the model, in (H, W) format.
149 |
150 | Returns:
151 | (torch.Tensor): Batched masks in BxCxHxW format, where (H, W)
152 | is given by original_size.
153 | """
154 | masks = F.interpolate(
155 | masks,
156 | (self.image_encoder.img_size, self.image_encoder.img_size),
157 | mode="bilinear",
158 | align_corners=False,
159 | )
160 | masks = masks[..., : input_size[0], : input_size[1]]
161 | masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
162 | return masks
163 |
164 | def preprocess(self, x: torch.Tensor) -> torch.Tensor:
165 | """Normalize pixel values and pad to a square input."""
166 | # Normalize colors
167 | x = (x - self.pixel_mean) / self.pixel_std
168 |
169 | # Pad
170 | h, w = x.shape[-2:]
171 | padh = self.image_encoder.img_size - h
172 | padw = self.image_encoder.img_size - w
173 | x = F.pad(x, (0, padw, 0, padh))
174 | return x
175 |
--------------------------------------------------------------------------------
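
A minimal end-to-end sketch of the batched interface documented in `Sam.forward` (random weights via `checkpoint=None`; the box is given in the resized 1024-pixel input frame and the original size is arbitrary):

    import torch
    from segment_anything import sam_model_registry

    sam = sam_model_registry["vit_b"](checkpoint=None)
    batched_input = [{
        "image": torch.randint(0, 256, (3, 1024, 768)).float(),   # 3xHxW, already resized
        "original_size": (1500, 1125),
        "boxes": torch.tensor([[100.0, 100.0, 600.0, 500.0]]),
    }]
    outputs = sam(batched_input, multimask_output=False)
    print(outputs[0]["masks"].shape)    # torch.Size([1, 1, 1500, 1125])
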
/segment_anything/modeling/transformer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | from torch import Tensor, nn
9 |
10 | import math
11 | from typing import Tuple, Type
12 |
13 | from .common import MLPBlock
14 |
15 |
16 | class TwoWayTransformer(nn.Module):
17 | def __init__(
18 | self,
19 | depth: int,
20 | embedding_dim: int,
21 | num_heads: int,
22 | mlp_dim: int,
23 | activation: Type[nn.Module] = nn.ReLU,
24 | attention_downsample_rate: int = 2,
25 | ) -> None:
26 | """
27 | A transformer decoder that attends to an input image using
28 | queries whose positional embedding is supplied.
29 |
30 | Args:
31 | depth (int): number of layers in the transformer
32 | embedding_dim (int): the channel dimension for the input embeddings
33 | num_heads (int): the number of heads for multihead attention. Must
34 | divide embedding_dim
35 | mlp_dim (int): the channel dimension internal to the MLP block
36 | activation (nn.Module): the activation to use in the MLP block
37 | """
38 | super().__init__()
39 | self.depth = depth
40 | self.embedding_dim = embedding_dim
41 | self.num_heads = num_heads
42 | self.mlp_dim = mlp_dim
43 | self.layers = nn.ModuleList()
44 |
45 | for i in range(depth):
46 | self.layers.append(
47 | TwoWayAttentionBlock(
48 | embedding_dim=embedding_dim,
49 | num_heads=num_heads,
50 | mlp_dim=mlp_dim,
51 | activation=activation,
52 | attention_downsample_rate=attention_downsample_rate,
53 | skip_first_layer_pe=(i == 0),
54 | )
55 | )
56 |
57 | self.final_attn_token_to_image = Attention(
58 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate
59 | )
60 | self.norm_final_attn = nn.LayerNorm(embedding_dim)
61 |
62 | def forward(
63 | self,
64 | image_embedding: Tensor,
65 | image_pe: Tensor,
66 | point_embedding: Tensor,
67 | ) -> Tuple[Tensor, Tensor]:
68 | """
69 | Args:
70 | image_embedding (torch.Tensor): image to attend to. Should be shape
71 | B x embedding_dim x h x w for any h and w.
72 | image_pe (torch.Tensor): the positional encoding to add to the image. Must
73 | have the same shape as image_embedding.
74 | point_embedding (torch.Tensor): the embedding to add to the query points.
75 | Must have shape B x N_points x embedding_dim for any N_points.
76 |
77 | Returns:
78 | torch.Tensor: the processed point_embedding
79 | torch.Tensor: the processed image_embedding
80 | """
81 | # BxCxHxW -> BxHWxC == B x N_image_tokens x C
82 | bs, c, h, w = image_embedding.shape
83 | image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
84 | image_pe = image_pe.flatten(2).permute(0, 2, 1)
85 |
86 | # Prepare queries
87 | queries = point_embedding
88 | keys = image_embedding
89 |
90 | # Apply transformer blocks and final layernorm
91 | for layer in self.layers:
92 | queries, keys = layer(
93 | queries=queries,
94 | keys=keys,
95 | query_pe=point_embedding,
96 | key_pe=image_pe,
97 | )
98 |
99 | # Apply the final attention layer from the points to the image
100 | q = queries + point_embedding
101 | k = keys + image_pe
102 | attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
103 | queries = queries + attn_out
104 | queries = self.norm_final_attn(queries)
105 |
106 | return queries, keys
107 |
108 |
109 | class TwoWayAttentionBlock(nn.Module):
110 | def __init__(
111 | self,
112 | embedding_dim: int,
113 | num_heads: int,
114 | mlp_dim: int = 2048,
115 | activation: Type[nn.Module] = nn.ReLU,
116 | attention_downsample_rate: int = 2,
117 | skip_first_layer_pe: bool = False,
118 | ) -> None:
119 | """
120 | A transformer block with four layers: (1) self-attention of sparse
121 | inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
122 | block on sparse inputs, and (4) cross attention of dense inputs to sparse
123 | inputs.
124 |
125 | Arguments:
126 | embedding_dim (int): the channel dimension of the embeddings
127 | num_heads (int): the number of heads in the attention layers
128 | mlp_dim (int): the hidden dimension of the mlp block
129 | activation (nn.Module): the activation of the mlp block
130 | skip_first_layer_pe (bool): skip the PE on the first layer
131 | """
132 | super().__init__()
133 | self.self_attn = Attention(embedding_dim, num_heads)
134 | self.norm1 = nn.LayerNorm(embedding_dim)
135 |
136 | self.cross_attn_token_to_image = Attention(
137 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate
138 | )
139 | self.norm2 = nn.LayerNorm(embedding_dim)
140 |
141 | self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
142 | self.norm3 = nn.LayerNorm(embedding_dim)
143 |
144 | self.norm4 = nn.LayerNorm(embedding_dim)
145 | self.cross_attn_image_to_token = Attention(
146 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate
147 | )
148 |
149 | self.skip_first_layer_pe = skip_first_layer_pe
150 |
151 | def forward(
152 | self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
153 | ) -> Tuple[Tensor, Tensor]:
154 | # Self attention block
155 | if self.skip_first_layer_pe:
156 | queries = self.self_attn(q=queries, k=queries, v=queries)
157 | else:
158 | q = queries + query_pe
159 | attn_out = self.self_attn(q=q, k=q, v=queries)
160 | queries = queries + attn_out
161 | queries = self.norm1(queries)
162 |
163 | # Cross attention block, tokens attending to image embedding
164 | q = queries + query_pe
165 | k = keys + key_pe
166 | attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
167 | queries = queries + attn_out
168 | queries = self.norm2(queries)
169 |
170 | # MLP block
171 | mlp_out = self.mlp(queries)
172 | queries = queries + mlp_out
173 | queries = self.norm3(queries)
174 |
175 | # Cross attention block, image embedding attending to tokens
176 | q = queries + query_pe
177 | k = keys + key_pe
178 | attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
179 | keys = keys + attn_out
180 | keys = self.norm4(keys)
181 |
182 | return queries, keys
183 |
184 |
185 | class Attention(nn.Module):
186 | """
187 | An attention layer that allows for downscaling the size of the embedding
188 | after projection to queries, keys, and values.
189 | """
190 |
191 | def __init__(
192 | self,
193 | embedding_dim: int,
194 | num_heads: int,
195 | downsample_rate: int = 1,
196 | ) -> None:
197 | super().__init__()
198 | self.embedding_dim = embedding_dim
199 | self.internal_dim = embedding_dim // downsample_rate
200 | self.num_heads = num_heads
201 | assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."
202 |
203 | self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
204 | self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
205 | self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
206 | self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
207 |
208 | def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
209 | b, n, c = x.shape
210 | x = x.reshape(b, n, num_heads, c // num_heads)
211 | return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head
212 |
213 | def _recombine_heads(self, x: Tensor) -> Tensor:
214 | b, n_heads, n_tokens, c_per_head = x.shape
215 | x = x.transpose(1, 2)
216 | return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C
217 |
218 | def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
219 | # Input projections
220 | q = self.q_proj(q)
221 | k = self.k_proj(k)
222 | v = self.v_proj(v)
223 |
224 | # Separate into heads
225 | q = self._separate_heads(q, self.num_heads)
226 | k = self._separate_heads(k, self.num_heads)
227 | v = self._separate_heads(v, self.num_heads)
228 |
229 | # Attention
230 | _, _, _, c_per_head = q.shape
231 | attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens
232 | attn = attn / math.sqrt(c_per_head)
233 | attn = torch.softmax(attn, dim=-1)
234 |
235 | # Get output
236 | out = attn @ v
237 | out = self._recombine_heads(out)
238 | out = self.out_proj(out)
239 |
240 | return out
241 |
--------------------------------------------------------------------------------
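
A tiny shape check for the downsampling `Attention` layer above (internal dimension 256 / 2 = 128 split across 8 heads):

    import torch
    from segment_anything.modeling.transformer import Attention

    attn = Attention(embedding_dim=256, num_heads=8, downsample_rate=2)
    q = torch.randn(1, 7, 256)            # e.g. output tokens plus sparse prompt tokens
    k = v = torch.randn(1, 4096, 256)     # flattened 64x64 image embedding
    out = attn(q=q, k=k, v=v)
    print(out.shape)                      # torch.Size([1, 7, 256])
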
/segment_anything/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/segment_anything/utils/onnx.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | import torch.nn as nn
9 | from torch.nn import functional as F
10 |
11 | from typing import Tuple
12 |
13 | from ..modeling import Sam
14 | from .amg import calculate_stability_score
15 |
16 |
17 | class SamOnnxModel(nn.Module):
18 | """
19 | This model should not be called directly, but is used in ONNX export.
20 | It combines the prompt encoder, mask decoder, and mask postprocessing of Sam,
21 | with some functions modified to enable model tracing. Also supports extra
22 |     options controlling what information is returned. See the ONNX export script for details.
23 | """
24 |
25 | def __init__(
26 | self,
27 | model: Sam,
28 | return_single_mask: bool,
29 | use_stability_score: bool = False,
30 | return_extra_metrics: bool = False,
31 | ) -> None:
32 | super().__init__()
33 | self.mask_decoder = model.mask_decoder
34 | self.model = model
35 | self.img_size = model.image_encoder.img_size
36 | self.return_single_mask = return_single_mask
37 | self.use_stability_score = use_stability_score
38 | self.stability_score_offset = 1.0
39 | self.return_extra_metrics = return_extra_metrics
40 |
41 | @staticmethod
42 | def resize_longest_image_size(
43 | input_image_size: torch.Tensor, longest_side: int
44 | ) -> torch.Tensor:
45 | input_image_size = input_image_size.to(torch.float32)
46 | scale = longest_side / torch.max(input_image_size)
47 | transformed_size = scale * input_image_size
48 | transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64)
49 | return transformed_size
50 |
51 | def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor:
52 | point_coords = point_coords + 0.5
53 | point_coords = point_coords / self.img_size
54 | point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords)
55 | point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding)
56 |
57 | point_embedding = point_embedding * (point_labels != -1)
58 | point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * (
59 | point_labels == -1
60 | )
61 |
62 | for i in range(self.model.prompt_encoder.num_point_embeddings):
63 | point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[
64 | i
65 | ].weight * (point_labels == i)
66 |
67 | return point_embedding
68 |
69 | def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor:
70 | mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask)
71 | mask_embedding = mask_embedding + (
72 | 1 - has_mask_input
73 | ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1)
74 | return mask_embedding
75 |
76 | def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor:
77 | masks = F.interpolate(
78 | masks,
79 | size=(self.img_size, self.img_size),
80 | mode="bilinear",
81 | align_corners=False,
82 | )
83 |
84 | prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(torch.int64)
85 | masks = masks[..., : prepadded_size[0], : prepadded_size[1]] # type: ignore
86 |
87 | orig_im_size = orig_im_size.to(torch.int64)
88 | h, w = orig_im_size[0], orig_im_size[1]
89 | masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False)
90 | return masks
91 |
92 | def select_masks(
93 | self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int
94 | ) -> Tuple[torch.Tensor, torch.Tensor]:
95 | # Determine if we should return the multiclick mask or not from the number of points.
96 | # The reweighting is used to avoid control flow.
97 | score_reweight = torch.tensor(
98 | [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)]
99 | ).to(iou_preds.device)
100 | score = iou_preds + (num_points - 2.5) * score_reweight
101 | best_idx = torch.argmax(score, dim=1)
102 | masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1)
103 | iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1)
104 |
105 | return masks, iou_preds
106 |
107 | @torch.no_grad()
108 | def forward(
109 | self,
110 | image_embeddings: torch.Tensor,
111 | point_coords: torch.Tensor,
112 | point_labels: torch.Tensor,
113 | mask_input: torch.Tensor,
114 | has_mask_input: torch.Tensor,
115 | orig_im_size: torch.Tensor,
116 | ):
117 | sparse_embedding = self._embed_points(point_coords, point_labels)
118 | dense_embedding = self._embed_masks(mask_input, has_mask_input)
119 |
120 | masks, scores = self.model.mask_decoder.predict_masks(
121 | image_embeddings=image_embeddings,
122 | image_pe=self.model.prompt_encoder.get_dense_pe(),
123 | sparse_prompt_embeddings=sparse_embedding,
124 | dense_prompt_embeddings=dense_embedding,
125 | )
126 |
127 | if self.use_stability_score:
128 | scores = calculate_stability_score(
129 | masks, self.model.mask_threshold, self.stability_score_offset
130 | )
131 |
132 | if self.return_single_mask:
133 | masks, scores = self.select_masks(masks, scores, point_coords.shape[1])
134 |
135 | upscaled_masks = self.mask_postprocessing(masks, orig_im_size)
136 |
137 | if self.return_extra_metrics:
138 | stability_scores = calculate_stability_score(
139 | upscaled_masks, self.model.mask_threshold, self.stability_score_offset
140 | )
141 | areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1)
142 | return upscaled_masks, scores, stability_scores, areas, masks
143 |
144 | return upscaled_masks, scores, masks
145 |
--------------------------------------------------------------------------------
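A minimal export sketch for `SamOnnxModel`, assuming a ViT-H checkpoint at `sam_vit_h.pth` and opset 17 (both are illustrative assumptions; the upstream SAM export script additionally configures dynamic axes and other options):

```python
import torch
from segment_anything import sam_model_registry
from segment_anything.utils.onnx import SamOnnxModel

# Assumed checkpoint path; substitute your own.
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h.pth")
onnx_model = SamOnnxModel(sam, return_single_mask=True)

embed_dim = sam.prompt_encoder.embed_dim
embed_size = sam.prompt_encoder.image_embedding_size  # (64, 64) for the default encoder
dummy_inputs = {
    "image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float),
    "point_coords": torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float),
    "point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float),
    "mask_input": torch.randn(1, 1, 4 * embed_size[0], 4 * embed_size[1], dtype=torch.float),
    "has_mask_input": torch.tensor([1], dtype=torch.float),
    "orig_im_size": torch.tensor([1500, 2250], dtype=torch.float),
}
torch.onnx.export(
    onnx_model,
    tuple(dummy_inputs.values()),
    "sam_decoder.onnx",
    input_names=list(dummy_inputs.keys()),
    output_names=["masks", "iou_predictions", "low_res_masks"],
    opset_version=17,
)
```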
/segment_anything/utils/transforms.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import numpy as np
8 | import torch
9 | from torch.nn import functional as F
10 | from torchvision.transforms.functional import resize, to_pil_image # type: ignore
11 |
12 | from copy import deepcopy
13 | from typing import Tuple
14 |
15 |
16 | class ResizeLongestSide:
17 | """
18 |     Resizes images so that the longest side matches 'target_length', and
19 |     provides methods for rescaling coordinates and boxes accordingly. Supports
20 |     both numpy arrays and batched torch tensors.
21 | """
22 |
23 | def __init__(self, target_length: int) -> None:
24 | self.target_length = target_length
25 |
26 | def apply_image(self, image: np.ndarray) -> np.ndarray:
27 | """
28 | Expects a numpy array with shape HxWxC in uint8 format.
29 | """
30 | target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length)
31 | return np.array(resize(to_pil_image(image), target_size))
32 |
33 | def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
34 | """
35 | Expects a numpy array of length 2 in the final dimension. Requires the
36 | original image size in (H, W) format.
37 | """
38 | old_h, old_w = original_size
39 | new_h, new_w = self.get_preprocess_shape(
40 | original_size[0], original_size[1], self.target_length
41 | )
42 | coords = deepcopy(coords).astype(float)
43 | coords[..., 0] = coords[..., 0] * (new_w / old_w)
44 | coords[..., 1] = coords[..., 1] * (new_h / old_h)
45 | return coords
46 |
47 | def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
48 | """
49 |         Expects a numpy array of shape Bx4. Requires the original image size
50 | in (H, W) format.
51 | """
52 | boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size)
53 | return boxes.reshape(-1, 4)
54 |
55 | def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor:
56 | """
57 | Expects batched images with shape BxCxHxW and float format. This
58 | transformation may not exactly match apply_image. apply_image is
59 | the transformation expected by the model.
60 | """
61 | # Expects an image in BCHW format. May not exactly match apply_image.
62 | target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length)
63 | return F.interpolate(
64 | image, target_size, mode="bilinear", align_corners=False, antialias=True
65 | )
66 |
67 | def apply_coords_torch(
68 | self, coords: torch.Tensor, original_size: Tuple[int, ...]
69 | ) -> torch.Tensor:
70 | """
71 | Expects a torch tensor with length 2 in the last dimension. Requires the
72 | original image size in (H, W) format.
73 | """
74 | old_h, old_w = original_size
75 | new_h, new_w = self.get_preprocess_shape(
76 | original_size[0], original_size[1], self.target_length
77 | )
78 | coords = deepcopy(coords).to(torch.float)
79 | coords[..., 0] = coords[..., 0] * (new_w / old_w)
80 | coords[..., 1] = coords[..., 1] * (new_h / old_h)
81 | return coords
82 |
83 | def apply_boxes_torch(
84 | self, boxes: torch.Tensor, original_size: Tuple[int, ...]
85 | ) -> torch.Tensor:
86 | """
87 | Expects a torch tensor with shape Bx4. Requires the original image
88 | size in (H, W) format.
89 | """
90 | boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size)
91 | return boxes.reshape(-1, 4)
92 |
93 | @staticmethod
94 | def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
95 | """
96 | Compute the output size given input size and target long side length.
97 | """
98 | scale = long_side_length * 1.0 / max(oldh, oldw)
99 | newh, neww = oldh * scale, oldw * scale
100 | neww = int(neww + 0.5)
101 | newh = int(newh + 0.5)
102 | return (newh, neww)
103 |
--------------------------------------------------------------------------------
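For reference, a small usage sketch of `ResizeLongestSide`; the 600x800 input size is an arbitrary assumption:

```python
import numpy as np
from segment_anything.utils.transforms import ResizeLongestSide

transform = ResizeLongestSide(target_length=1024)

# A dummy 600x800 RGB image: the longest side (800) is scaled to 1024,
# so the resized image is 768x1024.
image = np.zeros((600, 800, 3), dtype=np.uint8)
print(ResizeLongestSide.get_preprocess_shape(600, 800, 1024))  # (768, 1024)
print(transform.apply_image(image).shape)                      # (768, 1024, 3)

# Coordinates and boxes are rescaled by the same factors.
coords = np.array([[400.0, 300.0]])
print(transform.apply_coords(coords, original_size=(600, 800)))  # [[512. 384.]]
```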
/utils/eval_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import lightning as L
4 | import segmentation_models_pytorch as smp
5 | from box import Box
6 | from torch.utils.data import DataLoader
7 | from model import Model
8 | from utils.sample_utils import get_point_prompts
9 | from utils.tools import write_csv
10 |
11 |
12 | class AverageMeter:
13 | """Computes and stores the average and current value."""
14 |
15 | def __init__(self):
16 | self.reset()
17 |
18 | def reset(self):
19 | self.val = 0
20 | self.avg = 0
21 | self.sum = 0
22 | self.count = 0
23 |
24 | def update(self, val, n=1):
25 | self.val = val
26 | self.sum += val * n
27 | self.count += n
28 | self.avg = self.sum / self.count
29 |
30 |
31 | def calc_iou(pred_mask: torch.Tensor, gt_mask: torch.Tensor):
32 | pred_mask = (pred_mask >= 0.5).float()
33 | intersection = torch.sum(torch.mul(pred_mask, gt_mask), dim=(1, 2))
34 | union = torch.sum(pred_mask, dim=(1, 2)) + torch.sum(gt_mask, dim=(1, 2)) - intersection
35 | epsilon = 1e-7
36 | batch_iou = intersection / (union + epsilon)
37 |
38 | batch_iou = batch_iou.unsqueeze(1)
39 | return batch_iou
40 |
41 |
42 | def get_prompts(cfg: Box, bboxes, gt_masks):
43 | if cfg.prompt == "box" or cfg.prompt == "coarse":
44 | prompts = bboxes
45 | elif cfg.prompt == "point":
46 | prompts = get_point_prompts(gt_masks, cfg.num_points)
47 | else:
48 | raise ValueError("Prompt Type Error!")
49 | return prompts
50 |
51 |
52 | def validate(fabric: L.Fabric, cfg: Box, model: Model, val_dataloader: DataLoader, name: str, iters: int = 0):
53 | model.eval()
54 | ious = AverageMeter()
55 | f1_scores = AverageMeter()
56 |
57 | with torch.no_grad():
58 | for iter, data in enumerate(val_dataloader):
59 | images, bboxes, gt_masks = data
60 | num_images = images.size(0)
61 |
62 | prompts = get_prompts(cfg, bboxes, gt_masks)
63 |
64 | _, pred_masks, _, _ = model(images, prompts)
65 | for pred_mask, gt_mask in zip(pred_masks, gt_masks):
66 | batch_stats = smp.metrics.get_stats(
67 | pred_mask,
68 | gt_mask.int(),
69 | mode='binary',
70 | threshold=0.5,
71 | )
72 | batch_iou = smp.metrics.iou_score(*batch_stats, reduction="micro-imagewise")
73 | batch_f1 = smp.metrics.f1_score(*batch_stats, reduction="micro-imagewise")
74 | ious.update(batch_iou, num_images)
75 | f1_scores.update(batch_f1, num_images)
76 | fabric.print(
77 | f'Val: [{iters}] - [{iter}/{len(val_dataloader)}]: Mean IoU: [{ious.avg:.4f}] -- Mean F1: [{f1_scores.avg:.4f}]'
78 | )
79 | torch.cuda.empty_cache()
80 |
81 | fabric.print(f'Validation [{iters}]: Mean IoU: [{ious.avg:.4f}] -- Mean F1: [{f1_scores.avg:.4f}]')
82 | csv_dict = {"Name": name, "Prompt": cfg.prompt, "Mean IoU": f"{ious.avg:.4f}", "Mean F1": f"{f1_scores.avg:.4f}", "iters": iters}
83 |
84 | if fabric.global_rank == 0:
85 | write_csv(os.path.join(cfg.out_dir, f"{cfg.dataset}-{cfg.prompt}.csv"), csv_dict, csv_head=cfg.csv_keys)
86 | model.train()
87 | return ious.avg, f1_scores.avg
88 |
89 |
90 | def unspervised_validate(fabric: L.Fabric, cfg: Box, model: Model, val_dataloader: DataLoader, name: str, iters: int = 0):
91 | init_prompt = cfg.prompt
92 | cfg.prompt = "box"
93 | iou_box, f1_box = validate(fabric, cfg, model, val_dataloader, name, iters)
94 | cfg.prompt = "point"
95 | iou_point, f1_point = validate(fabric, cfg, model, val_dataloader, name, iters)
96 | # cfg.prompt = "coarse"
97 | # validate(fabric, cfg, model, val_dataloader, name, iters)
98 | cfg.prompt = init_prompt
99 | return iou_box, f1_box, iou_point, f1_point
100 |
101 |
102 | def contrast_validate(fabric: L.Fabric, cfg: Box, model: Model, val_dataloader: DataLoader, name: str, iters: int = 0, loss: float = 0.):
103 | model.eval()
104 | ious = AverageMeter()
105 | f1_scores = AverageMeter()
106 |
107 | with torch.no_grad():
108 | for iter, data in enumerate(val_dataloader):
109 | images, bboxes, gt_masks = data
110 | num_images = images.size(0)
111 |
112 | prompts = get_prompts(cfg, bboxes, gt_masks)
113 |
114 | _, pred_masks, _, _ = model(images, prompts)
115 | for pred_mask, gt_mask in zip(pred_masks, gt_masks):
116 | batch_stats = smp.metrics.get_stats(
117 | pred_mask,
118 | gt_mask.int(),
119 | mode='binary',
120 | threshold=0.5,
121 | )
122 | batch_iou = smp.metrics.iou_score(*batch_stats, reduction="micro-imagewise")
123 | batch_f1 = smp.metrics.f1_score(*batch_stats, reduction="micro-imagewise")
124 | ious.update(batch_iou, num_images)
125 | f1_scores.update(batch_f1, num_images)
126 | fabric.print(
127 | f'Val: [{iters}] - [{iter}/{len(val_dataloader)}]: Mean IoU: [{ious.avg:.4f}] -- Mean F1: [{f1_scores.avg:.4f}]'
128 | )
129 | torch.cuda.empty_cache()
130 |
131 | fabric.print(f'Validation [{iters}]: Mean IoU: [{ious.avg:.4f}] -- Mean F1: [{f1_scores.avg:.4f}]')
132 | csv_dict = {"Name": name, "Prompt": cfg.prompt, "Mean IoU": f"{ious.avg:.4f}", "Mean F1": f"{f1_scores.avg:.4f}", "iters": iters, "loss": loss}
133 |
134 | if fabric.global_rank == 0:
135 | write_csv(os.path.join(cfg.out_dir, f"metrics-{cfg.prompt}.csv"), csv_dict, csv_head=cfg.csv_keys)
136 | model.train()
137 | return ious.avg, f1_scores.avg
138 |
--------------------------------------------------------------------------------
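A toy check of `calc_iou`; the 2x2 masks below are made up purely to show the shapes and the expected value:

```python
import torch
from utils.eval_utils import calc_iou

# One predicted probability mask against one binary ground-truth mask.
pred = torch.tensor([[[0.9, 0.2],
                      [0.8, 0.1]]])  # thresholded at 0.5 -> 2 positive pixels
gt = torch.tensor([[[1.0, 1.0],
                    [1.0, 0.0]]])    # 3 positive pixels

# intersection = 2, union = 2 + 3 - 2 = 3, so the IoU is roughly 0.667.
print(calc_iou(pred, gt))  # tensor([[0.6667]])
```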
/utils/sample_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from sklearn.cluster import KMeans
4 |
5 |
6 | def uniform_sampling(masks, N=5):
7 | n_points = []
8 | for mask in masks:
9 | if not isinstance(mask, np.ndarray):
10 | mask = mask.cpu().numpy()
11 |
12 | indices = np.argwhere(mask == 1) # [y, x]
13 | sampled_indices = np.random.choice(len(indices), N, replace=True)
14 | sampled_points = np.flip(indices[sampled_indices], axis=1)
15 | n_points.append(sampled_points.tolist())
16 |
17 | return n_points
18 |
19 |
20 | def get_multi_distance_points(input_point, mask, points_number):
21 |     new_points = np.zeros((points_number + 1, 2))
22 |     new_points[0] = [input_point[1], input_point[0]]
23 |     for i in range(points_number):
24 |         new_points[i + 1] = get_next_distance_point(new_points[:i + 1, :], mask)
25 |
26 | new_points = swap_xy(new_points)
27 | return new_points
28 |
29 |
30 | def get_next_distance_point(input_points, mask):
31 | max_distance_point = [0, 0]
32 | max_distance = 0
33 | input_points = np.array(input_points)
34 |
35 | indices = np.argwhere(mask == True)
36 | for x, y in indices:
37 | # print(x,y,input_points)
38 | distance = np.sum(np.sqrt((x - input_points[:, 0]) ** 2 + (y - input_points[:, 1]) ** 2))
39 | if max_distance < distance:
40 | max_distance_point = [x, y]
41 | max_distance = distance
42 | return max_distance_point
43 |
44 |
45 | def swap_xy(points):
46 | new_points = np.zeros((len(points),2))
47 | new_points[:,0] = points[:,1]
48 | new_points[:,1] = points[:,0]
49 | return new_points
50 |
51 |
52 | def k_means_sampling(mask, k):
53 | points = np.argwhere(mask == 1) # [y, x]
54 | points = np.flip(points, axis=1)
55 |
56 | kmeans = KMeans(n_clusters=k)
57 | kmeans.fit(points)
58 | points = kmeans.cluster_centers_
59 | return points
60 |
61 |
62 | def get_point_prompt_max_dist(masks, num_points):
63 | n_points = []
64 | for mask in masks:
65 | mask_np = mask.cpu().numpy()
66 |
67 | indices = np.argwhere(mask_np > 0)
68 | random_index = np.random.choice(len(indices), 1)[0]
69 |
70 | first_point = [indices[random_index][1], indices[random_index][0]]
71 | new_points = get_multi_distance_points(first_point, mask_np, num_points - 1)
72 | n_points.append(new_points)
73 |
74 | return n_points
75 |
76 |
77 | def get_point_prompt_kmeans(masks, num_points):
78 | n_points = []
79 | for mask in masks:
80 | mask_np = mask.cpu().numpy()
81 | points = k_means_sampling(mask_np, num_points)
82 | n_points.append(points.astype(int))
83 | return n_points
84 |
85 |
86 | def get_point_prompts(gt_masks, num_points):
87 | prompts = []
88 | for mask in gt_masks:
89 | po_points = uniform_sampling(mask, num_points)
90 | na_points = uniform_sampling((~mask.to(bool)).to(float), num_points)
91 | po_point_coords = torch.tensor(po_points, device=mask.device)
92 | na_point_coords = torch.tensor(na_points, device=mask.device)
93 | point_coords = torch.cat((po_point_coords, na_point_coords), dim=1)
94 | po_point_labels = torch.ones(po_point_coords.shape[:2], dtype=torch.int, device=po_point_coords.device)
95 | na_point_labels = torch.zeros(na_point_coords.shape[:2], dtype=torch.int, device=na_point_coords.device)
96 | point_labels = torch.cat((po_point_labels, na_point_labels), dim=1)
97 | in_points = (point_coords, point_labels)
98 | prompts.append(in_points)
99 | return prompts
100 |
--------------------------------------------------------------------------------
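A minimal sketch of how `get_point_prompts` assembles positive and negative point prompts; the two 8x8 instance masks are invented for illustration:

```python
import torch
from utils.sample_utils import get_point_prompts

# Two hypothetical instance masks for a single image, each 8x8.
masks = torch.zeros(2, 8, 8)
masks[0, 1:4, 1:4] = 1  # first object
masks[1, 5:8, 5:8] = 1  # second object

# gt_masks is a list with one entry per image.
prompts = get_point_prompts([masks], num_points=3)
point_coords, point_labels = prompts[0]
print(point_coords.shape)  # torch.Size([2, 6, 2]): 3 foreground + 3 background points per instance
print(point_labels)        # 1 marks foreground samples, 0 marks background samples
```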
/utils/tools.py:
--------------------------------------------------------------------------------
1 | import os
2 | import csv
3 | import torch
4 | import copy
5 | import numpy as np
6 | from torchsummary import summary
7 |
8 | def freeze(model: torch.nn.Module):
9 | model.eval()
10 | for param in model.parameters():
11 | param.requires_grad = False
12 |
13 |
14 | def momentum_update(student_model, teacher_model, momentum=0.99):
15 | for (src_name, src_param), (tgt_name, tgt_param) in zip(
16 | student_model.named_parameters(), teacher_model.named_parameters()
17 | ):
18 | if src_param.requires_grad:
19 | tgt_param.data.mul_(momentum).add_(src_param.data, alpha=1 - momentum)
20 |
21 |
22 | def decode_mask(mask):
23 | """
24 | Convert mask with shape [1, h, w] using 1, 2, 3, ... to represent different objects
25 | to a mask with shape [n, h, w] using a new dimension to represent the number of objects.
26 |
27 | Args:
28 | mask (torch.Tensor): Mask tensor with shape [1, h, w] using 1, 2, 3, ... to represent different objects.
29 |
30 | Returns:
31 | torch.Tensor: Mask tensor with shape [n, h, w] using a new dimension to represent the number of objects.
32 | """
33 | unique_labels = torch.unique(mask)
34 | unique_labels = unique_labels[unique_labels != 0]
35 | n_objects = len(unique_labels)
36 | new_mask = torch.zeros((n_objects, *mask.shape[1:]), dtype=torch.int64)
37 | for i, label in enumerate(unique_labels):
38 | new_mask[i] = (mask == label).squeeze(0)
39 | return new_mask
40 |
41 |
42 | def encode_mask(mask):
43 | """
44 | Convert mask with shape [n, h, w] using a new dimension to represent the number of objects
45 | to a mask with shape [1, h, w] using 1, 2, 3, ... to represent different objects.
46 |
47 | Args:
48 | mask (torch.Tensor): Mask tensor with shape [n, h, w] using a new dimension to represent the number of objects.
49 |
50 | Returns:
51 | torch.Tensor: Mask tensor with shape [1, h, w] using 1, 2, 3, ... to represent different objects.
52 | """
53 | n_objects = mask.shape[0]
54 | new_mask = torch.zeros((1, *mask.shape[1:]), dtype=torch.int64)
55 | for i in range(n_objects):
56 | new_mask[0][mask[i] == 1] = i + 1
57 | return new_mask
58 |
59 |
60 | def copy_model(model: torch.nn.Module):
61 | new_model = copy.deepcopy(model)
62 | freeze(new_model)
63 | return new_model
64 |
65 |
66 | def create_csv(filename, csv_head=["corrupt", "Mean IoU", "Mean F1", "epoch"]):
67 | if os.path.exists(filename):
68 | return
69 | with open(filename, 'w') as csvfile:
70 | csv_write = csv.DictWriter(csvfile, fieldnames=csv_head)
71 | csv_write.writeheader()
72 |
73 |
74 | def write_csv(filename, csv_dict, csv_head=["corrupt", "Mean IoU", "Mean F1", "epoch"]):
75 | with open(filename, 'a+') as csvfile:
76 | csv_write = csv.DictWriter(csvfile, fieldnames=csv_head, extrasaction='ignore')
77 | csv_write.writerow(csv_dict)
78 |
79 |
80 | def check_grad(model: torch.nn.Module):
81 | for name, param in model.named_parameters():
82 | print(f"{name}: {param.requires_grad}")
83 |
84 |
85 | def check_equal(model1: torch.nn.Module, model2: torch.nn.Module):
86 | for (name1, param1), (name2, param2) in zip(model1.named_parameters(), model2.named_parameters()):
87 | if name1 == name2:
88 |             if not torch.allclose(param1, param2):
89 |                 print(f"{name1} is different")
90 |             else:
91 |                 print(f"{name1} is the same")
92 | else:
93 | print("The models have different structures")
94 |
95 |
96 | def check_model(model):
97 | return summary(model, (3, 1024, 1024), batch_size=1, device="cuda")
98 |
99 |
100 | def reduce_instances(bboxes, gt_masks, max_nums=50):
101 | bboxes_ = []
102 | gt_masks_ = []
103 | for bbox, gt_mask in zip(bboxes, gt_masks):
104 | idx = np.arange(bbox.shape[0])
105 | np.random.shuffle(idx)
106 | bboxes_.append(bbox[idx[:max_nums]])
107 | gt_masks_.append(gt_mask[idx[:max_nums]])
108 |
109 | bboxes = bboxes_
110 | gt_masks = gt_masks_
111 | return bboxes, gt_masks
112 |
--------------------------------------------------------------------------------
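A quick round trip through `decode_mask` and `encode_mask` on a toy label map (values chosen only for illustration):

```python
import torch
from utils.tools import decode_mask, encode_mask

# A 1x4x4 label map with two objects labeled 1 and 2.
label_map = torch.tensor([[[0, 1, 1, 0],
                           [0, 1, 1, 0],
                           [0, 0, 2, 2],
                           [0, 0, 2, 2]]])

binary = decode_mask(label_map)   # shape [2, 4, 4], one binary channel per object
restored = encode_mask(binary)    # back to shape [1, 4, 4]
print(binary.shape, torch.equal(restored, label_map))  # torch.Size([2, 4, 4]) True
```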
/utils/visualize.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import torch
4 | from torchvision.utils import draw_bounding_boxes
5 | from torchvision.utils import draw_segmentation_masks
6 |
7 | from box import Box
8 | from tqdm import tqdm
9 | from model import Model
10 | from datasets.COCO import COCODataset
11 |
12 |
13 | def draw_image(image, masks, boxes, labels, alpha=0.4):
14 | image = torch.from_numpy(image).permute(2, 0, 1)
15 | if boxes is not None:
16 | image = draw_bounding_boxes(image, boxes, colors=['red'] * len(boxes), labels=labels, width=2)
17 | if masks is not None:
18 | image = draw_segmentation_masks(image, masks=masks, colors=['red'] * len(masks), alpha=alpha)
19 | return image.numpy().transpose(1, 2, 0)
20 |
21 |
22 | def visualize(cfg: Box):
23 | model = Model(cfg)
24 | model.setup()
25 | model.eval()
26 | model.cuda()
27 | dataset = COCODataset(root_dir=cfg.dataset.val.root_dir,
28 | annotation_file=cfg.dataset.val.annotation_file,
29 | transform=None)
30 | predictor = model.get_predictor()
31 | os.makedirs(cfg.out_dir, exist_ok=True)
32 |
33 | for image_id in tqdm(dataset.image_ids):
34 | image_info = dataset.coco.loadImgs(image_id)[0]
35 | image_path = os.path.join(dataset.root_dir, image_info['file_name'])
36 | image_output_path = os.path.join(cfg.out_dir, image_info['file_name'])
37 | image = cv2.imread(image_path)
38 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
39 | ann_ids = dataset.coco.getAnnIds(imgIds=image_id)
40 | anns = dataset.coco.loadAnns(ann_ids)
41 | bboxes = []
42 | for ann in anns:
43 | x, y, w, h = ann['bbox']
44 | bboxes.append([x, y, x + w, y + h])
45 | bboxes = torch.as_tensor(bboxes, device=model.model.device)
46 | transformed_boxes = predictor.transform.apply_boxes_torch(bboxes, image.shape[:2])
47 | predictor.set_image(image)
48 | masks, _, _ = predictor.predict_torch(
49 | point_coords=None,
50 | point_labels=None,
51 | boxes=transformed_boxes,
52 | multimask_output=False,
53 | )
54 | image_output = draw_image(image, masks.squeeze(1), boxes=None, labels=None)
55 |         cv2.imwrite(image_output_path, cv2.cvtColor(image_output, cv2.COLOR_RGB2BGR))  # back to BGR for OpenCV
56 |
57 |
--------------------------------------------------------------------------------
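`draw_image` expects an HxWxC uint8 numpy image plus torch boxes and boolean masks; a minimal sketch with made-up inputs:

```python
import numpy as np
import torch
from utils.visualize import draw_image

image = np.zeros((64, 64, 3), dtype=np.uint8)  # dummy RGB image
boxes = torch.tensor([[8, 8, 40, 40]])         # one box in xyxy format
masks = torch.zeros(1, 64, 64, dtype=torch.bool)
masks[0, 16:32, 16:32] = True

out = draw_image(image, masks, boxes, labels=None)
print(out.shape, out.dtype)  # (64, 64, 3) uint8
```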
/validate.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import torch
4 | import lightning as L
5 | import segmentation_models_pytorch as smp
6 | from box import Box
7 | from lightning.fabric.loggers import TensorBoardLogger
8 |
9 | from configs.config import cfg
10 | from datasets import call_load_dataset
11 | from model import Model
12 | from utils.eval_utils import AverageMeter, get_prompts, validate
13 | from utils.tools import copy_model, create_csv
14 | from model import Model
15 | from sam_lora import LoRA_Sam
16 |
17 |
18 | def multi_main(cfg: Box, ckpt: str = None):
19 |     prompts = ["box", "point"]
20 |     for prompt in prompts:
21 |         cfg.prompt = prompt
22 |         torch.cuda.empty_cache()
23 |         main(cfg, ckpt)
24 |
25 |
26 | def main(cfg: Box, ckpt: str = None) -> None:
27 | gpu_ids = cfg.gpu_ids.split(',')
28 | num_devices = len(gpu_ids)
29 |
30 | fabric = L.Fabric(accelerator="auto",
31 | devices=num_devices,
32 | strategy="auto",
33 | loggers=[TensorBoardLogger(cfg.out_dir)])
34 | fabric.launch()
35 | fabric.seed_everything(1337 + fabric.global_rank)
36 |
37 | if fabric.global_rank == 0:
38 | os.makedirs(cfg.out_dir, exist_ok=True)
39 | create_csv(os.path.join(cfg.out_dir, f"{cfg.dataset}-{cfg.prompt}.csv"), csv_head=cfg.csv_keys)
40 |
41 | with fabric.device:
42 | model = Model(cfg)
43 | model.setup()
44 | LoRA_Sam(model.model, 4)
45 |
46 | load_datasets = call_load_dataset(cfg)
47 | _, val_data = load_datasets(cfg, model.model.image_encoder.img_size)
48 |
49 | fabric.print(f"Val Data: {len(val_data) * cfg.val_batchsize}")
50 | val_data = fabric._setup_dataloader(val_data)
51 |
52 | if ckpt is not None:
53 | full_checkpoint = fabric.load(ckpt)
54 | model.load_state_dict(full_checkpoint["model"])
55 |
56 | validate(fabric, cfg, model, val_data, name=cfg.name, iters=0)
57 |
58 | del model, val_data
59 |
60 |
61 | if __name__ == "__main__":
62 | torch.cuda.empty_cache()
63 | torch.set_float32_matmul_precision('high')
64 | os.environ["CUDA_VISIBLE_DEVICES"] = cfg.gpu_ids
65 |
66 | parser = argparse.ArgumentParser()
67 | parser.add_argument("--ckpt", default=None, type=str)
68 | args = parser.parse_args()
69 |
70 | main(cfg, args.ckpt)
71 | # multi_main(cfg, args.ckpt)
72 | torch.cuda.empty_cache()
73 |
--------------------------------------------------------------------------------
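Typical usage, assuming a trained checkpoint saved under `output/` (the path is an example, not a file shipped with the repo):

```python
# Roughly equivalent to: python validate.py --ckpt output/last-ckpt.pth
from validate import cfg, main

main(cfg, ckpt="output/last-ckpt.pth")
```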