├── .gitignore ├── INSTALL.md ├── LICENSE ├── PREPARE.md ├── README.md ├── adaptation.py ├── asserts ├── Pipeline.webp ├── VISUAL.webp └── teaser.webp ├── configs ├── base_config.py └── config.py ├── datasets ├── CAMO.py ├── COCO.py ├── COCONut.py ├── COD10K.py ├── GDD.py ├── ISIC.py ├── ISTD.py ├── MSD.py ├── OCID.py ├── OSD.py ├── PascalVOC.py ├── Polyp.py ├── SA_1B.py ├── __init__.py └── tools.py ├── losses.py ├── model.py ├── requirements.txt ├── sam_lora.py ├── segment_anything ├── __init__.py ├── automatic_mask_generator.py ├── build_sam.py ├── modeling │ ├── __init__.py │ ├── common.py │ ├── image_encoder.py │ ├── mask_decoder.py │ ├── prompt_encoder.py │ ├── sam.py │ └── transformer.py ├── predictor.py └── utils │ ├── __init__.py │ ├── amg.py │ ├── onnx.py │ └── transforms.py ├── utils ├── eval_utils.py ├── sample_utils.py ├── tools.py └── visualize.py └── validate.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | data 7 | output 8 | *.sh 9 | 10 | # C extensions 11 | *.so 12 | *.pth 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | cover/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | .pybuilder/ 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # IPython 86 | profile_default/ 87 | ipython_config.py 88 | 89 | # pyenv 90 | # For a library or package, you might want to ignore these files since the code is 91 | # intended to run in multiple environments; otherwise, check them in: 92 | # .python-version 93 | 94 | # pipenv 95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 98 | # install all needed dependencies. 99 | #Pipfile.lock 100 | 101 | # poetry 102 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 103 | # This is especially recommended for binary packages to ensure reproducibility, and is more 104 | # commonly ignored for libraries. 105 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 106 | #poetry.lock 107 | 108 | # pdm 109 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
110 | #pdm.lock 111 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 112 | # in version control. 113 | # https://pdm.fming.dev/#use-with-ide 114 | .pdm.toml 115 | 116 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 117 | __pypackages__/ 118 | 119 | # Celery stuff 120 | celerybeat-schedule 121 | celerybeat.pid 122 | 123 | # SageMath parsed files 124 | *.sage.py 125 | 126 | # Environments 127 | .env 128 | .venv 129 | env/ 130 | venv/ 131 | ENV/ 132 | env.bak/ 133 | venv.bak/ 134 | 135 | # Spyder project settings 136 | .spyderproject 137 | .spyproject 138 | 139 | # Rope project settings 140 | .ropeproject 141 | 142 | # mkdocs documentation 143 | /site 144 | 145 | # mypy 146 | .mypy_cache/ 147 | .dmypy.json 148 | dmypy.json 149 | 150 | # Pyre type checker 151 | .pyre/ 152 | 153 | # pytype static type analyzer 154 | .pytype/ 155 | 156 | # Cython debug symbols 157 | cython_debug/ 158 | 159 | # PyCharm 160 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 161 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 162 | # and can be added to the global gitignore or merged into this file. For a more nuclear 163 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 164 | #.idea/ 165 | -------------------------------------------------------------------------------- /INSTALL.md: -------------------------------------------------------------------------------- 1 | ## Installation 2 | 3 | 4 | ### Requirements 5 | - Linux or macOS with Python ≥ 3.8 6 | - PyTorch ≥ 1.13.1 and [torchvision](https://github.com/pytorch/vision/) that matches the PyTorch installation. 7 | - Install pytorch [lightning](https://lightning.ai/pytorch-lightning) that matches the PyTorch installation. 8 | - `pip install -r requirements.txt` 9 | 10 | 11 | ### Example conda environment setup 12 | ```bash 13 | conda create --name wesam python=3.8 14 | conda activate wesam 15 | 16 | # CUDA 11.7 17 | conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.7 -c pytorch -c nvidia 18 | 19 | git clone https://github.com/zhang-haojie/wesam.git 20 | cd wesam 21 | pip install -r requirements.txt 22 | ``` -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Haojie Zhang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /PREPARE.md: -------------------------------------------------------------------------------- 1 | ## Prepare 2 | 3 | 4 | ### Download Dataset 5 | 6 | *Natural images* 7 | 8 | - [COCO Dataset](https://cocodataset.org/) 9 | 10 | - [PascalVOC Dataset](http://host.robots.ox.ac.uk/pascal/VOC/) 11 | 12 | 13 | *Medical images* 14 | 15 | - [Kvasir-SEG Dataset](https://datasets.simula.no/kvasir-seg/) 16 | 17 | - [ISIC Dataset](https://challenge.isic-archive.com/data/) 18 | 19 | 20 | *Camouflaged Objects* 21 | 22 | - [COD10k Dataset](https://drive.google.com/file/d/1pVq1rWXCwkMbEZpTt4-yUQ3NsnQd_DNY/view?usp=sharing) - [Camouflaged Object Detection](https://github.com/DengPingFan/SINet/) 23 | 24 | - [CAMO](https://drive.google.com/open?id=1h-OqZdwkuPhBvGcVAwmh0f1NGqlH_4B6) - [Project](https://sites.google.com/view/ltnghia/research/camo) 25 | 26 | 27 | *Robotic Images* 28 | 29 | - [OCID](https://www.acin.tuwien.ac.at/en/vision-for-robotics/software-tools/object-clutter-indoor-dataset/): *Object Clutter Indoor Dataset* 30 | 31 | - [OSD](https://www.acin.tuwien.ac.at/en/vision-for-robotics/software-tools/osd/): *Object Segmentation Database* 32 | 33 | 34 | *Corrupted Images* 35 | 36 | In `datasets/COCO.py`, uncomment the line that includes `corrupt_image`. Then comment line 192 of `adaptation.py` and run it. 37 | 38 | ``` 39 | def __getitem__(self, idx): 40 | image_id = self.image_ids[idx] 41 | image_info = self.coco.loadImgs(image_id)[0] 42 | image_path = os.path.join(self.root_dir, image_info["file_name"]) 43 | if self.cfg.corrupt in self.cfg.corruptions: 44 | image_path = image_path.replace("val2017", os.path.join("corruption", self.cfg.corrupt)) 45 | image = cv2.imread(image_path) 46 | 47 | # corrupt_image(image, image_path) 48 | 49 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 50 | ``` 51 | 52 | 53 | *Glass* 54 | - [GDD](http://gdd.dluticcd.com/) - 55 | [Don't Hit Me! Glass Detection in Real-world Scenes](https://github.com/Charmve/Mirror-Glass-Detection/tree/master/CVPR2020_GDNet) 56 | (*Glass Detection Dataset* Need Apply!) 57 | 58 | - [GSD](https://drive.google.com/file/d/1pSEUs-8I-4YHOTJ9J0wxiEpkdbHjMQQo/view?usp=sharing) - 59 | [Exploiting Semantic Relations for Glass Surface Detection](https://jiaying.link/neurips2022-gsds/) 60 | (*Glass Surface Dataset*) 61 | 62 | 63 | *Mirror* 64 | - [MSD](https://drive.google.com/file/d/1Znw92fO6lCKfXejjSSyMyL1qtFepgjPI/view?usp=sharing) - 65 | [Where is My Mirror?](https://github.com/Charmve/Mirror-Glass-Detection/tree/master/ICCV2019_MirrorNet) 66 | (*mirror segmentation dataset* Need Apply!) 67 | 68 | 69 | *Shadow Detection* 70 | - [ISTD](https://drive.google.com/file/d/1I0qw-65KBA6np8vIZzO6oeiOvcDBttAY/view) - 71 | [tacked Conditional Generative Adversarial Networks for Jointly Learning Shadow Detection and Shadow Removal](https://github.com/DeepInsight-PCALab/ST-CGAN) 72 | (*Image Shadow Triplets dataset* Need Apply!) 73 | 74 | 75 | ### Download Checkpoints 76 | 77 | Click the links below to download the checkpoint for the corresponding model type. 
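If you prefer the command line, the same files can be fetched with `wget` (assuming it is available) and placed in the `checkpoints/` folder created in the Prepare step below, e.g. for `vit_b`:

```
mkdir -p checkpoints
wget -P checkpoints https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth
```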
78 | 79 | - `vit_h`: [ViT-H SAM model.](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth) 80 | - `vit_l`: [ViT-L SAM model.](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth) 81 | - `vit_b`: [ViT-B SAM model.](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth) 82 | 83 | 84 | ### Prepare 85 | 86 | ``` 87 | cd wesam/ 88 | 89 | mkdir data 90 | mkdir checkpoints 91 | 92 | mv DATASETS ./data 93 | 94 | mv VIT_B ./checkpoints 95 | ``` -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 |

Improving the Generalization of Segmentation Foundation Model under Distribution Shift via Weakly Supervised Adaptation

4 | 5 | 6 | 7 | 8 |
9 | 10 | 11 | ## 🎈 News 12 | 13 | - [2024.2.27] Our work has been accepted to CVPR 2024 🎉 14 | - [2024.3.1] Training and inference code released 15 | 16 | ## 🚀 Introduction 17 | 18 |
19 | image 20 |
21 | 22 | Segment Anything Model was pre-trained on a large-scale dataset but exhibits awkward performance on diverse downstream segmentation tasks. We adapt SAM through weak supervision to enhance its generalization capabilities. 23 | 24 | 25 | ## 📻 Overview 26 | 27 |
28 | image 29 |
30 | 31 | The proposed self-training architecture with anchor network regularization and contrastive loss regularization. Red arrows indicates the backpropagation flow. 32 | 33 | 34 | ## 📆 TODO 35 | 36 | - [x] Release code 37 | 38 | ## 🎮 Getting Started 39 | 40 | ### 1. Install Environment 41 | 42 | see [INSTALL](INSTALL.md). 43 | 44 | ### 2. Prepare Dataset and Checkpoints 45 | 46 | see [PREPARE](PREPARE.md). 47 | 48 | ### 3. Adapt with Weak Supervision 49 | 50 | ``` 51 | # 1 modify configs/config.py 52 | # Prompt type: box, point, coarse 53 | 54 | # 2 adapt 55 | python adaptation.py 56 | ``` 57 | 58 | ### 4. Validation 59 | 60 | ``` 61 | python validate.py --ckpt /path/to/checkpoint 62 | ``` 63 | 64 | 65 | ## 🖼️ Visualization 66 | 67 |
68 | image 69 |
70 | 71 | 72 | ## 🎫 License 73 | 74 | The content of this project itself is licensed under [LICENSE](LICENSE). 75 | 76 | ## 💡 Acknowledgement 77 | 78 | - [SAM](https://github.com/facebookresearch/segment-anything) 79 | 80 | - [lightning-sam](https://github.com/luca-medeiros/lightning-sam) 81 | 82 | - [SAM-LoRA](https://github.com/JamesQFreeman/Sam_LoRA) 83 | 84 | ## 🖊️ Citation 85 | 86 | If you find this project useful in your research, please consider cite: 87 | 88 | ```BibTeX 89 | @inproceedings{zhang2024improving, 90 | title={Improving the generalization of segmentation foundation model under distribution shift via weakly supervised adaptation}, 91 | author={Zhang, Haojie and Su, Yongyi and Xu, Xun and Jia, Kui}, 92 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 93 | pages={23385--23395}, 94 | year={2024} 95 | } 96 | ``` 97 | -------------------------------------------------------------------------------- /adaptation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | import time 4 | import torch 5 | import lightning as L 6 | import torch.nn.functional as F 7 | import segmentation_models_pytorch as smp 8 | from box import Box 9 | from lightning.fabric.fabric import _FabricOptimizer 10 | from lightning.fabric.loggers import TensorBoardLogger, CSVLogger 11 | from torch.utils.data import DataLoader 12 | 13 | from configs.config import cfg 14 | from losses import DiceLoss, FocalLoss, ContraLoss 15 | from datasets import call_load_dataset 16 | 17 | from model import Model 18 | from sam_lora import LoRA_Sam 19 | from utils.eval_utils import AverageMeter, calc_iou, validate, get_prompts 20 | from utils.tools import copy_model, create_csv, check_grad, momentum_update, reduce_instances 21 | 22 | 23 | def train_sam( 24 | cfg: Box, 25 | fabric: L.Fabric, 26 | model: Model, 27 | anchor_model: Model, 28 | optimizer: _FabricOptimizer, 29 | scheduler: _FabricOptimizer, 30 | train_dataloader: DataLoader, 31 | val_dataloader: DataLoader, 32 | num_iters: int, 33 | ): 34 | """The SAM training loop.""" 35 | batch_time = AverageMeter() 36 | data_time = AverageMeter() 37 | focal_losses = AverageMeter() 38 | dice_losses = AverageMeter() 39 | iou_losses = AverageMeter() 40 | anchor_losses = AverageMeter() 41 | contra_losses = AverageMeter() 42 | total_losses = AverageMeter() 43 | focal_loss = FocalLoss() 44 | dice_loss = DiceLoss() 45 | contra_loss = ContraLoss() 46 | end = time.time() 47 | max_iou = 0. 
48 | num_epochs = cfg.num_iters // num_iters + 1 49 | 50 | for epoch in range(1, num_epochs): 51 | 52 | for iter, data in enumerate(train_dataloader): 53 | 54 | data_time.update(time.time() - end) 55 | images_weak, images_strong, bboxes, gt_masks = data 56 | batch_size = images_weak.size(0) 57 | num_insts = sum(len(gt_mask) for gt_mask in gt_masks) 58 | if num_insts > cfg.max_nums: 59 | print(num_insts) 60 | bboxes, gt_masks = reduce_instances(bboxes, gt_masks, cfg.max_nums) 61 | 62 | prompts = get_prompts(cfg, bboxes, gt_masks) 63 | 64 | with torch.no_grad(): 65 | anchor_image_embeds, anchor_masks, anchor_iou_predictions, anchor_res_masks = anchor_model(images_weak, prompts) 66 | 67 | soft_image_embeds, soft_masks, soft_iou_predictions, soft_res_masks = model(images_weak, prompts) # teacher 68 | pred_image_embeds, pred_masks, iou_predictions, pred_res_masks = model(images_strong, prompts) # student 69 | 70 | num_masks = sum(len(pred_mask) for pred_mask in pred_masks) 71 | loss_focal = torch.tensor(0., device=fabric.device) 72 | loss_dice = torch.tensor(0., device=fabric.device) 73 | loss_iou = torch.tensor(0., device=fabric.device) 74 | loss_anchor = torch.tensor(0., device=fabric.device) 75 | loss_contra = torch.tensor(0., device=fabric.device) 76 | 77 | for i, (pred_mask, soft_mask, anchor_mask, iou_prediction) in enumerate(zip(pred_masks, soft_masks, anchor_masks, iou_predictions)): 78 | anchor_mask = (anchor_mask > 0.).float() 79 | loss_contra += contra_loss(soft_image_embeds[i], anchor_image_embeds[i], soft_res_masks[i].clone().detach(), anchor_res_masks[i].clone().detach()) 80 | # loss_contra += contra_loss(pred_image_embeds[i], anchor_image_embeds[i], pred_res_masks[i].clone().detach(), anchor_res_masks[i].clone().detach()) 81 | 82 | loss_anchor += (0.5 * dice_loss(pred_mask, anchor_mask) + 0.5 * dice_loss(soft_mask, anchor_mask)) 83 | 84 | soft_mask = (soft_mask > 0.).float() 85 | loss_focal += focal_loss(pred_mask, soft_mask, num_masks) 86 | loss_dice += dice_loss(pred_mask, soft_mask, num_masks) 87 | batch_iou = calc_iou(pred_mask, soft_mask) 88 | loss_iou += F.mse_loss(iou_prediction, batch_iou, reduction='sum') / num_masks 89 | 90 | loss_total = 20. 
* loss_focal + loss_dice + loss_iou + loss_anchor + loss_contra 91 | fabric.backward(loss_total) 92 | 93 | optimizer.step() 94 | scheduler.step() 95 | optimizer.zero_grad() 96 | torch.cuda.empty_cache() 97 | 98 | batch_time.update(time.time() - end) 99 | end = time.time() 100 | 101 | # momentum_update(model, anchor_model, momentum=cfg.ema_rate) 102 | 103 | focal_losses.update(loss_focal.item(), batch_size) 104 | dice_losses.update(loss_dice.item(), batch_size) 105 | iou_losses.update(loss_iou.item(), batch_size) 106 | anchor_losses.update(loss_anchor.item(), batch_size) 107 | contra_losses.update(loss_contra.item(), batch_size) 108 | total_losses.update(loss_total.item(), batch_size) 109 | 110 | fabric.print(f'Epoch: [{epoch}][{iter+1}/{len(train_dataloader)}]' 111 | f' | Dataset: [{cfg.dataset} - {cfg.prompt}]' 112 | f' | Time [{batch_time.val:.3f}s ({batch_time.avg:.3f}s)]' 113 | f' | Data [{data_time.val:.3f}s ({data_time.avg:.3f}s)]' 114 | f' | Focal Loss [{focal_losses.val:.4f} ({focal_losses.avg:.4f})]' 115 | f' | Dice Loss [{dice_losses.val:.4f} ({dice_losses.avg:.4f})]' 116 | f' | IoU Loss [{iou_losses.val:.4f} ({iou_losses.avg:.4f})]' 117 | f' | Anchor Loss [{anchor_losses.val:.4f} ({anchor_losses.avg:.4f})]' 118 | f' | Contrast Loss [{contra_losses.val:.4f} ({contra_losses.avg:.4f})]' 119 | f' | Total Loss [{total_losses.val:.4f} ({total_losses.avg:.4f})]') 120 | 121 | loss_logger = {"Focal Loss": focal_losses.avg, "Dice Loss": dice_losses.avg, 122 | "IoU Loss": iou_losses.avg, "Anchor Loss": anchor_losses.avg, 123 | "Contrast Loss": contra_losses.avg, "Total Loss": total_losses.avg} 124 | fabric.log_dict(loss_logger) 125 | torch.cuda.empty_cache() 126 | 127 | if epoch % cfg.eval_interval == 0: 128 | iou, f1_score = validate(fabric, cfg, model, val_dataloader, cfg.name, epoch * num_iters) 129 | if iou > max_iou: 130 | state = {"model": model, "optimizer": optimizer} 131 | fabric.save(os.path.join(cfg.out_dir, "save", f"{cfg.dataset}-{cfg.prompt}-last-ckpt.pth"), state) 132 | max_iou = iou 133 | 134 | 135 | def configure_opt(cfg: Box, model: Model): 136 | 137 | def lr_lambda(step): 138 | if step < cfg.opt.warmup_steps: 139 | return step / cfg.opt.warmup_steps 140 | elif step < cfg.opt.steps[0]: 141 | return 1.0 142 | elif step < cfg.opt.steps[1]: 143 | return 1 / cfg.opt.decay_factor 144 | else: 145 | return 1 / (cfg.opt.decay_factor**2) 146 | 147 | optimizer = torch.optim.Adam(model.parameters(), lr=cfg.opt.learning_rate, weight_decay=cfg.opt.weight_decay) 148 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) 149 | 150 | return optimizer, scheduler 151 | 152 | 153 | def corrupt_main(cfg): 154 | for corrupt in cfg.corruptions: 155 | cfg.corrupt = corrupt 156 | cfg.name = corrupt 157 | torch.cuda.empty_cache() 158 | main(cfg) 159 | 160 | 161 | def multi_main(cfg): 162 | prompts = ["box", "point"] 163 | for prompt in prompts: 164 | cfg.prompt = prompt 165 | torch.cuda.empty_cache() 166 | main(cfg) 167 | 168 | 169 | def main(cfg: Box, ckpt: str = None) -> None: 170 | gpu_ids = cfg.gpu_ids.split(',') 171 | num_devices = len(gpu_ids) 172 | 173 | fabric = L.Fabric(accelerator="auto", 174 | devices=num_devices, 175 | strategy="auto", 176 | loggers=[TensorBoardLogger(cfg.out_dir, name=f"{cfg.dataset}-{cfg.prompt}")]) 177 | fabric.launch() 178 | fabric.seed_everything(1337 + fabric.global_rank) 179 | 180 | if fabric.global_rank == 0: 181 | cfg_dict = cfg.to_dict() 182 | os.makedirs(os.path.join(cfg.out_dir, "configs"), exist_ok=True) 183 | cfg_dict_path = 
os.path.join(cfg.out_dir, "configs", f"{cfg.dataset}-{cfg.prompt}.yaml") 184 | with open(cfg_dict_path, "w") as file: 185 | yaml.dump(cfg_dict, file) 186 | 187 | os.makedirs(os.path.join(cfg.out_dir, "save"), exist_ok=True) 188 | create_csv(os.path.join(cfg.out_dir, f"{cfg.dataset}-{cfg.prompt}.csv"), csv_head=cfg.csv_keys) 189 | 190 | with fabric.device: 191 | model = Model(cfg) 192 | model.setup() 193 | anchor_model = copy_model(model) 194 | LoRA_Sam(model.model, 4) 195 | 196 | load_datasets = call_load_dataset(cfg) 197 | train_data, val_data = load_datasets(cfg, model.model.image_encoder.img_size) 198 | optimizer, scheduler = configure_opt(cfg, model.model) 199 | 200 | fabric.print(f"Train Data: {len(train_data) * cfg.batch_size}; Val Data: {len(val_data) * cfg.val_batchsize}") 201 | num_iters = len(train_data) * cfg.batch_size 202 | 203 | if ckpt is not None: 204 | full_checkpoint = fabric.load(ckpt) 205 | model.load_state_dict(full_checkpoint["model"]) 206 | # optimizer.load_state_dict(full_checkpoint["optimizer"]) 207 | 208 | train_data = fabric._setup_dataloader(train_data) 209 | val_data = fabric._setup_dataloader(val_data) 210 | model, optimizer = fabric.setup(model, optimizer) 211 | 212 | validate(fabric, cfg, anchor_model, val_data, name=cfg.name, iters=0) 213 | train_sam(cfg, fabric, model, anchor_model, optimizer, scheduler, train_data, val_data, num_iters) 214 | 215 | del model, anchor_model, train_data, val_data 216 | 217 | 218 | if __name__ == "__main__": 219 | torch.cuda.empty_cache() 220 | torch.set_float32_matmul_precision('high') 221 | os.environ["CUDA_VISIBLE_DEVICES"] = cfg.gpu_ids 222 | 223 | main(cfg) 224 | # multi_main(cfg) 225 | torch.cuda.empty_cache() 226 | -------------------------------------------------------------------------------- /asserts/Pipeline.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhang-haojie/wesam/5ac493ff4cbc52b2efbc2c530b71718a9e5b1ea1/asserts/Pipeline.webp -------------------------------------------------------------------------------- /asserts/VISUAL.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhang-haojie/wesam/5ac493ff4cbc52b2efbc2c530b71718a9e5b1ea1/asserts/VISUAL.webp -------------------------------------------------------------------------------- /asserts/teaser.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhang-haojie/wesam/5ac493ff4cbc52b2efbc2c530b71718a9e5b1ea1/asserts/teaser.webp -------------------------------------------------------------------------------- /configs/base_config.py: -------------------------------------------------------------------------------- 1 | base_config = { 2 | "eval_interval": 1, 3 | "ema_rate": 0.9999, 4 | "get_prompt": False, 5 | "split": True, 6 | "csv_keys": ["Name", "Prompt", "Mean IoU", "Mean F1", "iters", "loss"], 7 | "opt": { 8 | "learning_rate": 1e-4, 9 | "weight_decay": 1e-4, 10 | "decay_factor": 10, 11 | "steps": [60000, 86666], 12 | "warmup_steps": 250, 13 | }, 14 | "corruptions": [ 15 | "gaussian_noise", 16 | "shot_noise", 17 | "impulse_noise", 18 | "defocus_blur", 19 | "glass_blur", 20 | "motion_blur", 21 | "zoom_blur", 22 | "snow", 23 | "frost", 24 | "fog", 25 | "brightness", 26 | "contrast", 27 | "elastic_transform", 28 | "pixelate", 29 | "jpeg_compression", 30 | ], 31 | "model": { 32 | "type": "vit_b", 33 | "checkpoint": "./checkpoints/", 34 | "ckpt": "", 35 | 
"freeze": { 36 | "image_encoder": True, 37 | "prompt_encoder": True, 38 | "mask_decoder": True, 39 | }, 40 | }, 41 | "datasets": { 42 | "coco": { 43 | "root_dir": "./data/coco2017/val2017", 44 | "annotation_file": "./data/coco2017/annotations/instances_val2017.json", 45 | }, 46 | "coconut": { 47 | "root_dir": "./data/coconut/val2017", 48 | "annotation_file": "./data/coconut/coconut_dataset/annotations/annotations/relabeled_instances_val.json", 49 | }, 50 | "PascalVOC": { 51 | "root_dir": "./data/VOC2012/", 52 | }, 53 | "sa": { 54 | "root_dir": "./data/SA-1B", 55 | }, 56 | "Polyp":{ 57 | "root_dir": "./data/polyp/Kvasir-SEG", 58 | "annotation_file": "./data/polyp/Kvasir-SEG/kavsir_bboxes.json" 59 | }, 60 | "ISIC": { 61 | "root_dir": "./data/ISIC/", 62 | "train_list": "./data/ISIC/ISBI2016_ISIC_Part1_Training_GroundTruth.csv", 63 | "test_list": "./data/ISIC/ISBI2016_ISIC_Part1_Test_GroundTruth.csv" 64 | }, 65 | "ISTD": { 66 | "train": "./data/ISTD/train/train_A", 67 | "test": "./data/ISTD/test/test_A", 68 | }, 69 | "MSD": { 70 | "train": "./data/MSD/train/image", 71 | "test": "./data/MSD/test/image", 72 | }, 73 | "GDD": { 74 | "train": "./data/GDD/train/image", 75 | "test": "./data/GDD/test/image", 76 | }, 77 | "CAMO":{ 78 | "GT": "./data/CAMO-V.1.0-CVIU2019/GT", 79 | "train": "./data/CAMO-V.1.0-CVIU2019/Images/Train", 80 | "test": "./data/CAMO-V.1.0-CVIU2019/Images/Test", 81 | }, 82 | "COD10K":{ 83 | "GT": "./data/COD10K-v2/Test/GT_Object", 84 | "test": "./data/COD10K-v2/Test/Image", 85 | }, 86 | "robot": { 87 | "OCID": "./data/OCID-dataset", 88 | "OSD": "./data/OSD-0.2-depth" 89 | }, 90 | }, 91 | } 92 | -------------------------------------------------------------------------------- /configs/config.py: -------------------------------------------------------------------------------- 1 | from box import Box 2 | from configs.base_config import base_config 3 | 4 | 5 | config = { 6 | "gpu_ids": "0,1,2,3", 7 | "batch_size": 1, 8 | "val_batchsize": 4, 9 | "num_workers": 4, 10 | "num_iters": 40000, 11 | "max_nums": 40, 12 | "num_points": 5, 13 | "eval_interval": 1, 14 | "dataset": "COCO", 15 | "prompt": "box", 16 | "out_dir": "output/benchmark/COCO", 17 | "name": "baseline", 18 | "augment": True, 19 | "corrupt": None, 20 | "visual": False, 21 | "opt": { 22 | "learning_rate": 1e-4, 23 | }, 24 | "model": { 25 | "type": "vit_b", 26 | }, 27 | } 28 | 29 | cfg = Box(base_config) 30 | cfg.merge_update(config) 31 | -------------------------------------------------------------------------------- /datasets/CAMO.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import random 4 | import torch 5 | import numpy as np 6 | from PIL import Image 7 | from torch.utils.data import DataLoader 8 | from torch.utils.data import Dataset 9 | from skimage.draw import polygon2mask 10 | from datasets.tools import ResizeAndPad, soft_transform, collate_fn, collate_fn_, decode_mask 11 | 12 | 13 | class CAMODataset(Dataset): 14 | def __init__(self, cfg, image_root, gt_root, transform=None, if_self_training=False): 15 | self.cfg = cfg 16 | self.root_dir = image_root 17 | self.transform = transform 18 | images = [os.path.join(image_root, f) for f in os.listdir(image_root) if f.endswith('.jpg')] 19 | images = sorted(images) 20 | 21 | self.images = images 22 | self.gts = [os.path.join(gt_root, os.path.basename(image_path.replace(".jpg", ".png"))) for image_path in self.images] 23 | self.filter_files() 24 | 25 | self.if_self_training = if_self_training 26 | 27 | def 
filter_files(self): 28 | assert len(self.images) == len(self.gts) 29 | images = [] 30 | gts = [] 31 | for img_path, gt_path in zip(self.images, self.gts): 32 | img = Image.open(img_path) 33 | gt = Image.open(gt_path) 34 | if img.size == gt.size: 35 | images.append(img_path) 36 | gts.append(gt_path) 37 | self.images = images 38 | self.gts = gts 39 | 40 | def rgb_loader(self, path): 41 | with open(path, 'rb') as f: 42 | img = Image.open(f) 43 | return img.convert('RGB') 44 | 45 | def binary_loader(self, path): 46 | with open(path, 'rb') as f: 47 | img = Image.open(f) 48 | return img.convert('L') 49 | 50 | def __len__(self): 51 | return len(self.images) 52 | 53 | def __getitem__(self, idx): 54 | image = np.array(self.rgb_loader(self.images[idx])) 55 | gt_mask = np.array(self.binary_loader(self.gts[idx])) 56 | 57 | # mask = gt_mask.astype(bool).astype(np.uint8) 58 | if self.cfg.get_prompt: 59 | image_info = {} 60 | height, width, _ = image.shape 61 | image_info["file_path"] = self.images[idx] 62 | image_info["height"] = height 63 | image_info["width"] = width 64 | return idx, image_info, image 65 | 66 | bboxes = [] 67 | masks = [] 68 | categories = [] 69 | gt_masks = decode_mask(torch.tensor(gt_mask[None, :, :])).numpy().astype(np.uint8) 70 | assert gt_masks.sum() == (gt_mask > 0).sum() 71 | for mask in gt_masks: 72 | if np.all(mask == 0): 73 | continue 74 | masks.append(mask) 75 | x, y, w, h = cv2.boundingRect(mask) 76 | bboxes.append([x, y, x + w, y + h]) 77 | categories.append("0") 78 | 79 | if self.if_self_training: 80 | image_weak, bboxes_weak, masks_weak, image_strong = soft_transform(image, bboxes, masks, categories) 81 | 82 | if self.transform: 83 | image_weak, masks_weak, bboxes_weak = self.transform(image_weak, masks_weak, np.array(bboxes_weak)) 84 | image_strong = self.transform.transform_image(image_strong) 85 | 86 | bboxes_weak = np.stack(bboxes_weak, axis=0) 87 | masks_weak = np.stack(masks_weak, axis=0) 88 | return image_weak, image_strong, torch.tensor(bboxes_weak), torch.tensor(masks_weak).float() 89 | 90 | elif self.cfg.visual: 91 | image_name = os.path.splitext(os.path.basename(self.images[idx]))[0] 92 | origin_image = image 93 | origin_bboxes = bboxes 94 | origin_masks = masks 95 | if self.transform: 96 | padding, image, masks, bboxes = self.transform(image, masks, np.array(bboxes), True) 97 | 98 | bboxes = np.stack(bboxes, axis=0) 99 | masks = np.stack(masks, axis=0) 100 | origin_bboxes = np.stack(origin_bboxes, axis=0) 101 | origin_masks = np.stack(origin_masks, axis=0) 102 | return image_name, padding, origin_image, origin_bboxes, origin_masks, image, torch.tensor(bboxes), torch.tensor(masks).float() 103 | 104 | else: 105 | if self.transform: 106 | image, masks, bboxes = self.transform(image, masks, np.array(bboxes)) 107 | 108 | bboxes = np.stack(bboxes, axis=0) 109 | masks = np.stack(masks, axis=0) 110 | return image, torch.tensor(bboxes), torch.tensor(masks).float() 111 | 112 | 113 | class CAMODatasetwithCoarse(CAMODataset): 114 | 115 | def __getitem__(self, idx): 116 | image = np.array(self.rgb_loader(self.images[idx])) 117 | gt_mask = np.array(self.binary_loader(self.gts[idx])) 118 | 119 | bboxes = [] 120 | masks = [] 121 | coarse_masks = [] 122 | categories = [] 123 | approxes = [] 124 | 125 | gt_masks = decode_mask(torch.tensor(gt_mask[None, :, :])).numpy().astype(np.uint8) 126 | assert gt_masks.sum() == (gt_mask > 0).sum() 127 | for mask in gt_masks: 128 | if np.all(mask == 0.): 129 | continue 130 | contours, hierarchy = cv2.findContours(mask, cv2.RETR_TREE, 
cv2.CHAIN_APPROX_SIMPLE) 131 | num_vertices = 0.05 * cv2.arcLength(contours[0], True) 132 | num_vertices = num_vertices if num_vertices > 3 else 3 133 | approx = cv2.approxPolyDP(contours[0], num_vertices, True) # [x, y] 134 | approx = approx.squeeze(1) 135 | 136 | coordinates = np.array(approx) 137 | x_max, x_min = max(coordinates[:, 0]), min(coordinates[:, 0]) 138 | y_max, y_min = max(coordinates[:, 1]), min(coordinates[:, 1]) 139 | coarse_mask = polygon2mask(mask.shape, coordinates).astype(mask.dtype) 140 | if x_min == x_max or y_min == y_max: 141 | x, y, w, h = cv2.boundingRect(mask) 142 | bboxes.append([x, y, x + w, y + h]) 143 | else: 144 | bboxes.append([x_min, y_min, x_max, y_max]) 145 | 146 | masks.append(mask) 147 | coarse_masks.append(coarse_mask) 148 | approxes.append(approx) 149 | categories.append("0") 150 | 151 | if self.if_self_training: 152 | image_weak, bboxes_weak, masks_weak, image_strong = soft_transform(image, bboxes, masks, categories) 153 | 154 | if self.transform: 155 | image_weak, masks_weak, bboxes_weak = self.transform(image_weak, masks_weak, np.array(bboxes_weak)) 156 | image_strong = self.transform.transform_image(image_strong) 157 | 158 | bboxes_weak = np.stack(bboxes_weak, axis=0) 159 | masks_weak = np.stack(masks_weak, axis=0) 160 | return image_weak, image_strong, torch.tensor(bboxes_weak), torch.tensor(masks_weak).float() 161 | 162 | elif self.cfg.visual: 163 | image_name = os.path.splitext(os.path.basename(self.images[idx]))[0] 164 | origin_image = image 165 | origin_approxes = approxes 166 | origin_masks = masks 167 | if self.transform: 168 | padding, image, masks, bboxes = self.transform(image, masks, np.array(bboxes), self.cfg.visual) 169 | 170 | bboxes = np.stack(bboxes, axis=0) 171 | masks = np.stack(masks, axis=0) 172 | # origin_approxes = np.stack(origin_approxes, axis=0) 173 | origin_masks = np.stack(origin_masks, axis=0) 174 | return image_name, padding, origin_image, origin_approxes, origin_masks, image, torch.tensor(bboxes), torch.tensor(masks).float() 175 | 176 | else: 177 | if self.transform: 178 | image, masks, bboxes = self.transform(image, masks, np.array(bboxes)) 179 | 180 | bboxes = np.stack(bboxes, axis=0) 181 | masks = np.stack(masks, axis=0) 182 | return image, torch.tensor(bboxes), torch.tensor(masks).float() 183 | 184 | 185 | def load_datasets(cfg, img_size): 186 | transform = ResizeAndPad(img_size) 187 | val = CAMODataset( 188 | cfg, 189 | image_root=cfg.datasets.CAMO.test, 190 | gt_root=cfg.datasets.CAMO.GT, 191 | transform=transform, 192 | ) 193 | train = CAMODataset( 194 | cfg, 195 | image_root=cfg.datasets.CAMO.train, 196 | gt_root=cfg.datasets.CAMO.GT, 197 | transform=transform, 198 | if_self_training=cfg.augment, 199 | ) 200 | val_dataloader = DataLoader( 201 | val, 202 | batch_size=cfg.val_batchsize, 203 | shuffle=False, 204 | num_workers=cfg.num_workers, 205 | collate_fn=collate_fn, 206 | ) 207 | train_dataloader = DataLoader( 208 | train, 209 | batch_size=cfg.batch_size, 210 | shuffle=True, 211 | num_workers=cfg.num_workers, 212 | collate_fn=collate_fn, 213 | ) 214 | return train_dataloader, val_dataloader 215 | 216 | 217 | def load_datasets_coarse(cfg, img_size): 218 | transform = ResizeAndPad(img_size) 219 | val = CAMODatasetwithCoarse( 220 | cfg, 221 | image_root=cfg.datasets.CAMO.test, 222 | gt_root=cfg.datasets.CAMO.GT, 223 | transform=transform, 224 | ) 225 | train = CAMODatasetwithCoarse( 226 | cfg, 227 | image_root=cfg.datasets.CAMO.train, 228 | gt_root=cfg.datasets.CAMO.GT, 229 | transform=transform, 230 | 
if_self_training=cfg.augment, 231 | ) 232 | val_dataloader = DataLoader( 233 | val, 234 | batch_size=cfg.val_batchsize, 235 | shuffle=False, 236 | num_workers=cfg.num_workers, 237 | collate_fn=collate_fn, 238 | ) 239 | train_dataloader = DataLoader( 240 | train, 241 | batch_size=cfg.batch_size, 242 | shuffle=True, 243 | num_workers=cfg.num_workers, 244 | collate_fn=collate_fn, 245 | ) 246 | return train_dataloader, val_dataloader 247 | 248 | 249 | def load_datasets_visual(cfg, img_size): 250 | transform = ResizeAndPad(img_size) 251 | val = CAMODataset( 252 | cfg, 253 | image_root=cfg.datasets.CAMO.test, 254 | gt_root=cfg.datasets.CAMO.GT, 255 | transform=transform, 256 | ) 257 | val_dataloader = DataLoader( 258 | val, 259 | batch_size=cfg.val_batchsize, 260 | shuffle=False, 261 | num_workers=cfg.num_workers, 262 | collate_fn=collate_fn_, 263 | ) 264 | return val_dataloader 265 | 266 | 267 | def load_datasets_visual_coarse(cfg, img_size): 268 | transform = ResizeAndPad(img_size) 269 | val = CAMODatasetwithCoarse( 270 | cfg, 271 | image_root=cfg.datasets.CAMO.test, 272 | gt_root=cfg.datasets.CAMO.GT, 273 | transform=transform, 274 | ) 275 | val_dataloader = DataLoader( 276 | val, 277 | batch_size=cfg.val_batchsize, 278 | shuffle=False, 279 | num_workers=cfg.num_workers, 280 | collate_fn=collate_fn_, 281 | ) 282 | return val_dataloader 283 | 284 | 285 | def load_datasets_prompt(cfg, img_size): 286 | transform = ResizeAndPad(img_size) 287 | train = CAMODataset( 288 | cfg, 289 | image_root=cfg.datasets.CAMO.train, 290 | gt_root=cfg.datasets.CAMO.GT, 291 | transform=transform, 292 | if_self_training=cfg.augment, 293 | ) 294 | train_dataloader = DataLoader( 295 | train, 296 | batch_size=cfg.batch_size, 297 | shuffle=True, 298 | num_workers=cfg.num_workers, 299 | collate_fn=collate_fn_, 300 | ) 301 | return train_dataloader -------------------------------------------------------------------------------- /datasets/GDD.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import random 4 | import glob 5 | import json 6 | import torch 7 | import numpy as np 8 | import pandas as pd 9 | from torch.utils.data import Dataset, DataLoader 10 | from skimage.draw import polygon2mask 11 | 12 | from datasets.tools import ResizeAndPad, soft_transform, collate_fn, collate_fn_, decode_mask 13 | 14 | 15 | class GDDDataset(Dataset): 16 | def __init__(self, cfg, root_dir, transform=None, if_self_training=False): 17 | self.cfg = cfg 18 | self.root_dir = root_dir 19 | self.transform = transform 20 | 21 | images = [os.path.join(root_dir, f) for f in os.listdir(root_dir)] 22 | images = sorted(images) 23 | 24 | self.images = images 25 | self.gts = [image_path.replace("image", "mask").replace("jpg", "png") for image_path in self.images] 26 | 27 | self.if_self_training = if_self_training 28 | 29 | def __len__(self): 30 | return len(self.images) 31 | 32 | def __getitem__(self, idx): 33 | image_path = self.images[idx] 34 | image = cv2.imread(image_path) 35 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 36 | 37 | if self.cfg.get_prompt: 38 | image_info = {} 39 | height, width, _ = image.shape 40 | image_info["file_path"] = image_path 41 | image_info["height"] = height 42 | image_info["width"] = width 43 | return idx, image_info, image 44 | 45 | gt_path = self.gts[idx] 46 | gt_mask = cv2.imread(gt_path, cv2.IMREAD_GRAYSCALE) 47 | 48 | masks = [] 49 | bboxes = [] 50 | categories = [] 51 | gt_masks = decode_mask(torch.tensor(gt_mask[None, :, 
:])).numpy().astype(np.uint8) 52 | assert gt_masks.sum() == (gt_mask > 0).sum() 53 | for mask in gt_masks: 54 | masks.append(mask) 55 | x, y, w, h = cv2.boundingRect(mask) 56 | bboxes.append([x, y, x + w, y + h]) 57 | categories.append("0") 58 | 59 | if self.if_self_training: 60 | image_weak, bboxes_weak, masks_weak, image_strong = soft_transform(image, bboxes, masks, categories) 61 | 62 | if self.transform: 63 | image_weak, masks_weak, bboxes_weak = self.transform(image_weak, masks_weak, np.array(bboxes_weak)) 64 | image_strong = self.transform.transform_image(image_strong) 65 | 66 | bboxes_weak = np.stack(bboxes_weak, axis=0) 67 | masks_weak = np.stack(masks_weak, axis=0) 68 | return image_weak, image_strong, torch.tensor(bboxes_weak), torch.tensor(masks_weak).float() 69 | 70 | elif self.cfg.visual: 71 | image_name = os.path.splitext(os.path.basename(self.images[idx]))[0] 72 | origin_image = image 73 | origin_bboxes = bboxes 74 | origin_masks = masks 75 | if self.transform: 76 | padding, image, masks, bboxes = self.transform(image, masks, np.array(bboxes), True) 77 | 78 | bboxes = np.stack(bboxes, axis=0) 79 | masks = np.stack(masks, axis=0) 80 | origin_bboxes = np.stack(origin_bboxes, axis=0) 81 | origin_masks = np.stack(origin_masks, axis=0) 82 | return image_name, padding, origin_image, origin_bboxes, origin_masks, image, torch.tensor(bboxes), torch.tensor(masks).float() 83 | 84 | else: 85 | if self.transform: 86 | image, masks, bboxes = self.transform(image, masks, np.array(bboxes)) 87 | 88 | bboxes = np.stack(bboxes, axis=0) 89 | masks = np.stack(masks, axis=0) 90 | return image, torch.tensor(bboxes), torch.tensor(masks).float() 91 | 92 | 93 | class GDDDatasetwithCoarse(GDDDataset): 94 | 95 | def __getitem__(self, idx): 96 | image_path = self.images[idx] 97 | image = cv2.imread(image_path) 98 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 99 | 100 | gt_path = self.gts[idx] 101 | gt_mask = cv2.imread(gt_path, cv2.IMREAD_GRAYSCALE) 102 | 103 | masks = [] 104 | bboxes = [] 105 | approxes = [] 106 | categories = [] 107 | gt_masks = decode_mask(torch.tensor(gt_mask[None, :, :])).numpy().astype(np.uint8) 108 | assert gt_masks.sum() == (gt_mask > 0).sum() 109 | for mask in gt_masks: 110 | contours, hierarchy = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) 111 | num_vertices = 0.05 * cv2.arcLength(contours[0], True) 112 | num_vertices = num_vertices if num_vertices > 3 else 3 113 | approx = cv2.approxPolyDP(contours[0], num_vertices, True) # [x, y] 114 | approx = approx.squeeze(1) 115 | 116 | coordinates = np.array(approx) 117 | x_max, x_min = max(coordinates[:, 0]), min(coordinates[:, 0]) 118 | y_max, y_min = max(coordinates[:, 1]), min(coordinates[:, 1]) 119 | coarse_mask = polygon2mask(mask.shape, coordinates).astype(mask.dtype) 120 | if x_min == x_max or y_min == y_max: 121 | x, y, w, h = cv2.boundingRect(mask) 122 | bboxes.append([x, y, x + w, y + h]) 123 | else: 124 | bboxes.append([x_min, y_min, x_max, y_max]) 125 | 126 | masks.append(mask) 127 | categories.append("0") 128 | approxes.append(approx) 129 | 130 | if self.if_self_training: 131 | image_weak, bboxes_weak, masks_weak, image_strong = soft_transform(image, bboxes, masks, categories) 132 | 133 | if self.transform: 134 | image_weak, masks_weak, bboxes_weak = self.transform(image_weak, masks_weak, np.array(bboxes_weak)) 135 | image_strong = self.transform.transform_image(image_strong) 136 | 137 | bboxes_weak = np.stack(bboxes_weak, axis=0) 138 | masks_weak = np.stack(masks_weak, axis=0) 139 | return 
image_weak, image_strong, torch.tensor(bboxes_weak), torch.tensor(masks_weak).float() 140 | 141 | elif self.cfg.visual: 142 | image_name = os.path.splitext(os.path.basename(self.images[idx]))[0] 143 | origin_image = image 144 | origin_approxes = approxes 145 | origin_masks = masks 146 | if self.transform: 147 | padding, image, masks, bboxes = self.transform(image, masks, np.array(bboxes), self.cfg.visual) 148 | 149 | bboxes = np.stack(bboxes, axis=0) 150 | masks = np.stack(masks, axis=0) 151 | origin_masks = np.stack(origin_masks, axis=0) 152 | return image_name, padding, origin_image, origin_approxes, origin_masks, image, torch.tensor(bboxes), torch.tensor(masks).float() 153 | 154 | else: 155 | if self.transform: 156 | image, masks, bboxes = self.transform(image, masks, np.array(bboxes)) 157 | 158 | bboxes = np.stack(bboxes, axis=0) 159 | masks = np.stack(masks, axis=0) 160 | return image, torch.tensor(bboxes), torch.tensor(masks).float() 161 | 162 | 163 | def load_datasets(cfg, img_size): 164 | transform = ResizeAndPad(img_size) 165 | val = GDDDataset( 166 | cfg, 167 | root_dir=cfg.datasets.GDD.test, 168 | transform=transform, 169 | ) 170 | train = GDDDataset( 171 | cfg, 172 | root_dir=cfg.datasets.GDD.train, 173 | transform=transform, 174 | if_self_training=cfg.augment, 175 | ) 176 | val_dataloader = DataLoader( 177 | val, 178 | batch_size=cfg.val_batchsize, 179 | shuffle=False, 180 | num_workers=cfg.num_workers, 181 | collate_fn=collate_fn, 182 | ) 183 | train_dataloader = DataLoader( 184 | train, 185 | batch_size=cfg.batch_size, 186 | shuffle=True, 187 | num_workers=cfg.num_workers, 188 | collate_fn=collate_fn, 189 | ) 190 | return train_dataloader, val_dataloader 191 | 192 | 193 | def load_datasets_coarse(cfg, img_size): 194 | transform = ResizeAndPad(img_size) 195 | val = GDDDatasetwithCoarse( 196 | cfg, 197 | root_dir=cfg.datasets.GDD.test, 198 | transform=transform, 199 | ) 200 | train = GDDDatasetwithCoarse( 201 | cfg, 202 | root_dir=cfg.datasets.GDD.train, 203 | transform=transform, 204 | if_self_training=cfg.augment, 205 | ) 206 | val_dataloader = DataLoader( 207 | val, 208 | batch_size=cfg.val_batchsize, 209 | shuffle=False, 210 | num_workers=cfg.num_workers, 211 | collate_fn=collate_fn, 212 | ) 213 | train_dataloader = DataLoader( 214 | train, 215 | batch_size=cfg.batch_size, 216 | shuffle=True, 217 | num_workers=cfg.num_workers, 218 | collate_fn=collate_fn, 219 | ) 220 | return train_dataloader, val_dataloader 221 | 222 | 223 | def load_datasets_visual(cfg, img_size): 224 | transform = ResizeAndPad(img_size) 225 | val = GDDDataset( 226 | cfg, 227 | root_dir=cfg.datasets.GDD.test, 228 | transform=transform, 229 | ) 230 | val_dataloader = DataLoader( 231 | val, 232 | batch_size=cfg.val_batchsize, 233 | shuffle=False, 234 | num_workers=cfg.num_workers, 235 | collate_fn=collate_fn_, 236 | ) 237 | return val_dataloader 238 | 239 | 240 | def load_datasets_visual_coarse(cfg, img_size): 241 | transform = ResizeAndPad(img_size) 242 | val = GDDDatasetwithCoarse( 243 | cfg, 244 | root_dir=cfg.datasets.GDD.test, 245 | transform=transform, 246 | ) 247 | val_dataloader = DataLoader( 248 | val, 249 | batch_size=cfg.val_batchsize, 250 | shuffle=False, 251 | num_workers=cfg.num_workers, 252 | collate_fn=collate_fn_, 253 | ) 254 | return val_dataloader 255 | 256 | 257 | def load_datasets_prompt(cfg, img_size): 258 | transform = ResizeAndPad(img_size) 259 | train = GDDDataset( 260 | cfg, 261 | root_dir=cfg.datasets.GDD.train, 262 | transform=transform, 263 | if_self_training=cfg.augment, 264 | 
) 265 | train_dataloader = DataLoader( 266 | train, 267 | batch_size=cfg.batch_size, 268 | shuffle=True, 269 | num_workers=cfg.num_workers, 270 | collate_fn=collate_fn_, 271 | ) 272 | return train_dataloader 273 | -------------------------------------------------------------------------------- /datasets/ISIC.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import random 4 | import glob 5 | import json 6 | import torch 7 | import numpy as np 8 | import pandas as pd 9 | from torch.utils.data import Dataset, DataLoader 10 | from skimage.draw import polygon2mask 11 | 12 | from datasets.tools import ResizeAndPad, soft_transform, collate_fn, collate_fn_, decode_mask 13 | 14 | 15 | class ISICDataset(Dataset): 16 | def __init__(self, cfg, root_dir, list_file, transform=None, if_self_training=False): 17 | self.cfg = cfg 18 | df = pd.read_csv(os.path.join(list_file), encoding='gbk') 19 | self.name_list = df.iloc[:,1].tolist() 20 | self.label_list = df.iloc[:,2].tolist() 21 | self.root_dir = root_dir 22 | self.transform = transform 23 | 24 | self.if_self_training = if_self_training 25 | 26 | def __len__(self): 27 | return len(self.name_list) 28 | 29 | def __getitem__(self, idx): 30 | name = self.name_list[idx] 31 | image_path = os.path.join(self.root_dir, name) 32 | image = cv2.imread(image_path) 33 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 34 | 35 | if self.cfg.get_prompt: 36 | image_info = {} 37 | height, width, _ = image.shape 38 | image_info["file_path"] = image_path 39 | image_info["height"] = height 40 | image_info["width"] = width 41 | return idx, image_info, image 42 | 43 | label_name = self.label_list[idx] 44 | gt_path = os.path.join(self.root_dir, label_name) 45 | gt_mask = cv2.imread(gt_path, cv2.IMREAD_GRAYSCALE) 46 | 47 | masks = [] 48 | bboxes = [] 49 | categories = [] 50 | gt_masks = decode_mask(torch.tensor(gt_mask[None, :, :])).numpy().astype(np.uint8) 51 | assert gt_masks.sum() == (gt_mask > 0).sum() 52 | for mask in gt_masks: 53 | masks.append(mask) 54 | x, y, w, h = cv2.boundingRect(mask) 55 | bboxes.append([x, y, x + w, y + h]) 56 | categories.append("0") 57 | 58 | if self.if_self_training: 59 | image_weak, bboxes_weak, masks_weak, image_strong = soft_transform(image, bboxes, masks, categories) 60 | 61 | if self.transform: 62 | image_weak, masks_weak, bboxes_weak = self.transform(image_weak, masks_weak, np.array(bboxes_weak)) 63 | image_strong = self.transform.transform_image(image_strong) 64 | 65 | bboxes_weak = np.stack(bboxes_weak, axis=0) 66 | masks_weak = np.stack(masks_weak, axis=0) 67 | return image_weak, image_strong, torch.tensor(bboxes_weak), torch.tensor(masks_weak).float() 68 | 69 | elif self.cfg.visual: 70 | file_name = os.path.splitext(os.path.basename(name))[0] 71 | origin_image = image 72 | origin_bboxes = bboxes 73 | origin_masks = masks 74 | if self.transform: 75 | padding, image, masks, bboxes = self.transform(image, masks, np.array(bboxes), True) 76 | 77 | bboxes = np.stack(bboxes, axis=0) 78 | masks = np.stack(masks, axis=0) 79 | origin_bboxes = np.stack(origin_bboxes, axis=0) 80 | origin_masks = np.stack(origin_masks, axis=0) 81 | return file_name, padding, origin_image, origin_bboxes, origin_masks, image, torch.tensor(bboxes), torch.tensor(masks).float() 82 | 83 | else: 84 | if self.transform: 85 | image, masks, bboxes = self.transform(image, masks, np.array(bboxes)) 86 | 87 | bboxes = np.stack(bboxes, axis=0) 88 | masks = np.stack(masks, axis=0) 89 | return image, 
torch.tensor(bboxes), torch.tensor(masks).float() 90 | 91 | 92 | class ISICDatasetwithCoarse(ISICDataset): 93 | 94 | def __getitem__(self, idx): 95 | name = self.name_list[idx] 96 | image_path = os.path.join(self.root_dir, name) 97 | image = cv2.imread(image_path) 98 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 99 | 100 | label_name = self.label_list[idx] 101 | gt_path = os.path.join(self.root_dir, label_name) 102 | gt_mask = cv2.imread(gt_path, cv2.IMREAD_GRAYSCALE) 103 | 104 | masks = [] 105 | bboxes = [] 106 | approxes = [] 107 | categories = [] 108 | gt_masks = decode_mask(torch.tensor(gt_mask[None, :, :])).numpy().astype(np.uint8) 109 | assert gt_masks.sum() == (gt_mask > 0).sum() 110 | for mask in gt_masks: 111 | contours, hierarchy = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) 112 | num_vertices = 0.05 * cv2.arcLength(contours[0], True) 113 | num_vertices = num_vertices if num_vertices > 3 else 3 114 | approx = cv2.approxPolyDP(contours[0], num_vertices, True) # [x, y] 115 | approx = approx.squeeze(1) 116 | 117 | coordinates = np.array(approx) 118 | x_max, x_min = max(coordinates[:, 0]), min(coordinates[:, 0]) 119 | y_max, y_min = max(coordinates[:, 1]), min(coordinates[:, 1]) 120 | coarse_mask = polygon2mask(mask.shape, coordinates).astype(mask.dtype) 121 | if x_min == x_max or y_min == y_max: 122 | x, y, w, h = cv2.boundingRect(mask) 123 | bboxes.append([x, y, x + w, y + h]) 124 | else: 125 | bboxes.append([x_min, y_min, x_max, y_max]) 126 | 127 | masks.append(mask) 128 | categories.append("0") 129 | approxes.append(approx) 130 | 131 | if self.if_self_training: 132 | image_weak, bboxes_weak, masks_weak, image_strong = soft_transform(image, bboxes, masks, categories) 133 | 134 | if self.transform: 135 | image_weak, masks_weak, bboxes_weak = self.transform(image_weak, masks_weak, np.array(bboxes_weak)) 136 | image_strong = self.transform.transform_image(image_strong) 137 | 138 | bboxes_weak = np.stack(bboxes_weak, axis=0) 139 | masks_weak = np.stack(masks_weak, axis=0) 140 | return image_weak, image_strong, torch.tensor(bboxes_weak), torch.tensor(masks_weak).float() 141 | 142 | elif self.cfg.visual: 143 | file_name = os.path.splitext(os.path.basename(name))[0] 144 | origin_image = image 145 | origin_approxes = approxes 146 | origin_masks = masks 147 | if self.transform: 148 | padding, image, masks, bboxes = self.transform(image, masks, np.array(bboxes), self.cfg.visual) 149 | 150 | bboxes = np.stack(bboxes, axis=0) 151 | masks = np.stack(masks, axis=0) 152 | origin_masks = np.stack(origin_masks, axis=0) 153 | return file_name, padding, origin_image, origin_approxes, origin_masks, image, torch.tensor(bboxes), torch.tensor(masks).float() 154 | 155 | else: 156 | if self.transform: 157 | image, masks, bboxes = self.transform(image, masks, np.array(bboxes)) 158 | 159 | bboxes = np.stack(bboxes, axis=0) 160 | masks = np.stack(masks, axis=0) 161 | return image, torch.tensor(bboxes), torch.tensor(masks).float() 162 | 163 | 164 | def load_datasets(cfg, img_size): 165 | transform = ResizeAndPad(img_size) 166 | val = ISICDataset( 167 | cfg, 168 | root_dir=cfg.datasets.ISIC.root_dir, 169 | list_file=cfg.datasets.ISIC.test_list, 170 | transform=transform, 171 | ) 172 | train = ISICDataset( 173 | cfg, 174 | root_dir=cfg.datasets.ISIC.root_dir, 175 | list_file=cfg.datasets.ISIC.train_list, 176 | transform=transform, 177 | if_self_training=cfg.augment, 178 | ) 179 | val_dataloader = DataLoader( 180 | val, 181 | batch_size=cfg.val_batchsize, 182 | shuffle=False, 183 | 
num_workers=cfg.num_workers, 184 | collate_fn=collate_fn, 185 | ) 186 | train_dataloader = DataLoader( 187 | train, 188 | batch_size=cfg.batch_size, 189 | shuffle=True, 190 | num_workers=cfg.num_workers, 191 | collate_fn=collate_fn, 192 | ) 193 | return train_dataloader, val_dataloader 194 | 195 | 196 | def load_datasets_coarse(cfg, img_size): 197 | transform = ResizeAndPad(img_size) 198 | val = ISICDatasetwithCoarse( 199 | cfg, 200 | root_dir=cfg.datasets.ISIC.root_dir, 201 | list_file=cfg.datasets.ISIC.test_list, 202 | transform=transform, 203 | ) 204 | train = ISICDatasetwithCoarse( 205 | cfg, 206 | root_dir=cfg.datasets.ISIC.root_dir, 207 | list_file=cfg.datasets.ISIC.train_list, 208 | transform=transform, 209 | if_self_training=cfg.augment, 210 | ) 211 | val_dataloader = DataLoader( 212 | val, 213 | batch_size=cfg.val_batchsize, 214 | shuffle=False, 215 | num_workers=cfg.num_workers, 216 | collate_fn=collate_fn, 217 | ) 218 | train_dataloader = DataLoader( 219 | train, 220 | batch_size=cfg.batch_size, 221 | shuffle=True, 222 | num_workers=cfg.num_workers, 223 | collate_fn=collate_fn, 224 | ) 225 | return train_dataloader, val_dataloader 226 | 227 | 228 | def load_datasets_visual(cfg, img_size): 229 | transform = ResizeAndPad(img_size) 230 | val = ISICDataset( 231 | cfg, 232 | root_dir=cfg.datasets.ISIC.root_dir, 233 | list_file=cfg.datasets.ISIC.test_list, 234 | transform=transform, 235 | ) 236 | val_dataloader = DataLoader( 237 | val, 238 | batch_size=cfg.val_batchsize, 239 | shuffle=False, 240 | num_workers=cfg.num_workers, 241 | collate_fn=collate_fn_, 242 | ) 243 | return val_dataloader 244 | 245 | 246 | def load_datasets_visual_coarse(cfg, img_size): 247 | transform = ResizeAndPad(img_size) 248 | val = ISICDatasetwithCoarse( 249 | cfg, 250 | root_dir=cfg.datasets.ISIC.root_dir, 251 | list_file=cfg.datasets.ISIC.test_list, 252 | transform=transform, 253 | ) 254 | val_dataloader = DataLoader( 255 | val, 256 | batch_size=cfg.val_batchsize, 257 | shuffle=False, 258 | num_workers=cfg.num_workers, 259 | collate_fn=collate_fn_, 260 | ) 261 | return val_dataloader 262 | 263 | 264 | def load_datasets_prompt(cfg, img_size): 265 | transform = ResizeAndPad(img_size) 266 | train = ISICDataset( 267 | cfg, 268 | root_dir=cfg.datasets.ISIC.root_dir, 269 | list_file=cfg.datasets.ISIC.train_list, 270 | transform=transform, 271 | if_self_training=cfg.augment, 272 | ) 273 | train_dataloader = DataLoader( 274 | train, 275 | batch_size=cfg.batch_size, 276 | shuffle=True, 277 | num_workers=cfg.num_workers, 278 | collate_fn=collate_fn_, 279 | ) 280 | return train_dataloader 281 | -------------------------------------------------------------------------------- /datasets/ISTD.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import random 4 | import torch 5 | import numpy as np 6 | from PIL import Image 7 | from torch.utils.data import DataLoader 8 | from torch.utils.data import Dataset 9 | from skimage.draw import polygon2mask 10 | from datasets.tools import ResizeAndPad, soft_transform, collate_fn, collate_fn_, decode_mask 11 | 12 | 13 | class ISTDDataset(Dataset): 14 | def __init__(self, cfg, image_root, transform=None, if_self_training=False): 15 | self.cfg = cfg 16 | self.root_dir = image_root 17 | self.transform = transform 18 | images = [os.path.join(image_root, f) for f in os.listdir(image_root)] 19 | images = sorted(images) 20 | 21 | self.images = images 22 | self.gts = [image_path.replace("A", "B") for image_path in self.images] 
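        # ISTD stores the shadow images under *_A and the corresponding shadow
        # masks under *_B, so swapping "A" for "B" in each image path above
        # yields the ground-truth mask path.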
23 | 24 | self.if_self_training = if_self_training 25 | 26 | def rgb_loader(self, path): 27 | with open(path, 'rb') as f: 28 | img = Image.open(f) 29 | return img.convert('RGB') 30 | 31 | def binary_loader(self, path): 32 | with open(path, 'rb') as f: 33 | img = Image.open(f) 34 | return img.convert('L') 35 | 36 | def __len__(self): 37 | return len(self.images) 38 | 39 | def __getitem__(self, idx): 40 | image = np.array(self.rgb_loader(self.images[idx])) 41 | gt_mask = np.array(self.binary_loader(self.gts[idx])) 42 | 43 | if self.cfg.get_prompt: 44 | image_info = {} 45 | height, width, _ = image.shape 46 | image_info["file_path"] = self.images[idx] 47 | image_info["height"] = height 48 | image_info["width"] = width 49 | return idx, image_info, image 50 | 51 | bboxes = [] 52 | masks = [] 53 | categories = [] 54 | gt_masks = decode_mask(torch.tensor(gt_mask[None, :, :])).numpy().astype(np.uint8) 55 | assert gt_masks.sum() == (gt_mask > 0).sum() 56 | for mask in gt_masks: 57 | if np.all(mask == 0): 58 | continue 59 | masks.append(mask) 60 | x, y, w, h = cv2.boundingRect(mask) 61 | bboxes.append([x, y, x + w, y + h]) 62 | categories.append("0") 63 | 64 | if self.if_self_training: 65 | image_weak, bboxes_weak, masks_weak, image_strong = soft_transform(image, bboxes, masks, categories) 66 | 67 | if self.transform: 68 | image_weak, masks_weak, bboxes_weak = self.transform(image_weak, masks_weak, np.array(bboxes_weak)) 69 | image_strong = self.transform.transform_image(image_strong) 70 | 71 | bboxes_weak = np.stack(bboxes_weak, axis=0) 72 | masks_weak = np.stack(masks_weak, axis=0) 73 | return image_weak, image_strong, torch.tensor(bboxes_weak), torch.tensor(masks_weak).float() 74 | 75 | elif self.cfg.visual: 76 | image_name = os.path.splitext(os.path.basename(self.images[idx]))[0] 77 | origin_image = image 78 | origin_bboxes = bboxes 79 | origin_masks = masks 80 | if self.transform: 81 | padding, image, masks, bboxes = self.transform(image, masks, np.array(bboxes), True) 82 | 83 | bboxes = np.stack(bboxes, axis=0) 84 | masks = np.stack(masks, axis=0) 85 | origin_bboxes = np.stack(origin_bboxes, axis=0) 86 | origin_masks = np.stack(origin_masks, axis=0) 87 | return image_name, padding, origin_image, origin_bboxes, origin_masks, image, torch.tensor(bboxes), torch.tensor(masks).float() 88 | 89 | else: 90 | if self.transform: 91 | image, masks, bboxes = self.transform(image, masks, np.array(bboxes)) 92 | 93 | bboxes = np.stack(bboxes, axis=0) 94 | masks = np.stack(masks, axis=0) 95 | return image, torch.tensor(bboxes), torch.tensor(masks).float() 96 | 97 | 98 | class ISTDDatasetwithCoarse(ISTDDataset): 99 | 100 | def __getitem__(self, idx): 101 | image = np. array(self.rgb_loader(self.images[idx])) 102 | gt_mask = np. 
array(self.binary_loader(self.gts[idx])) 103 | 104 | bboxes = [] 105 | masks = [] 106 | coarse_masks = [] 107 | categories = [] 108 | approxes = [] 109 | 110 | gt_masks = decode_mask(torch.tensor(gt_mask[None, :, :])).numpy().astype(np.uint8) 111 | assert gt_masks.sum() == (gt_mask > 0).sum() 112 | for mask in gt_masks: 113 | if np.all(mask == 0.): 114 | continue 115 | contours, hierarchy = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) 116 | num_vertices = 0.05 * cv2.arcLength(contours[0], True) 117 | num_vertices = num_vertices if num_vertices > 3 else 3 118 | approx = cv2.approxPolyDP(contours[0], num_vertices, True) # [x, y] 119 | approx = approx.squeeze(1) 120 | 121 | coordinates = np.array(approx) 122 | x_max, x_min = max(coordinates[:, 0]), min(coordinates[:, 0]) 123 | y_max, y_min = max(coordinates[:, 1]), min(coordinates[:, 1]) 124 | coarse_mask = polygon2mask(mask.shape, coordinates).astype(mask.dtype) 125 | if x_min == x_max or y_min == y_max: 126 | x, y, w, h = cv2.boundingRect(mask) 127 | bboxes.append([x, y, x + w, y + h]) 128 | else: 129 | bboxes.append([x_min, y_min, x_max, y_max]) 130 | 131 | masks.append(mask) 132 | coarse_masks.append(coarse_mask) 133 | approxes.append(approx) 134 | categories.append("0") 135 | 136 | if self.if_self_training: 137 | image_weak, bboxes_weak, masks_weak, image_strong = soft_transform(image, bboxes, masks, categories) 138 | 139 | if self.transform: 140 | image_weak, masks_weak, bboxes_weak = self.transform(image_weak, masks_weak, np.array(bboxes_weak)) 141 | image_strong = self.transform.transform_image(image_strong) 142 | 143 | bboxes_weak = np.stack(bboxes_weak, axis=0) 144 | masks_weak = np.stack(masks_weak, axis=0) 145 | return image_weak, image_strong, torch.tensor(bboxes_weak), torch.tensor(masks_weak).float() 146 | 147 | elif self.cfg.visual: 148 | image_name = os.path.splitext(os.path.basename(self.images[idx]))[0] 149 | origin_image = image 150 | origin_approxes = approxes 151 | origin_masks = masks 152 | if self.transform: 153 | padding, image, masks, bboxes = self.transform(image, masks, np.array(bboxes), self.cfg.visual) 154 | 155 | bboxes = np.stack(bboxes, axis=0) 156 | masks = np.stack(masks, axis=0) 157 | origin_masks = np.stack(origin_masks, axis=0) 158 | return image_name, padding, origin_image, origin_approxes, origin_masks, image, torch.tensor(bboxes), torch.tensor(masks).float() 159 | 160 | else: 161 | if self.transform: 162 | image, masks, bboxes = self.transform(image, masks, np.array(bboxes)) 163 | 164 | bboxes = np.stack(bboxes, axis=0) 165 | masks = np.stack(masks, axis=0) 166 | return image, torch.tensor(bboxes), torch.tensor(masks).float() 167 | 168 | 169 | def load_datasets(cfg, img_size): 170 | transform = ResizeAndPad(img_size) 171 | val = ISTDDataset( 172 | cfg, 173 | image_root=cfg.datasets.ISTD.test, 174 | transform=transform, 175 | ) 176 | train = ISTDDataset( 177 | cfg, 178 | image_root=cfg.datasets.ISTD.train, 179 | transform=transform, 180 | if_self_training=cfg.augment, 181 | ) 182 | val_dataloader = DataLoader( 183 | val, 184 | batch_size=cfg.val_batchsize, 185 | shuffle=False, 186 | num_workers=cfg.num_workers, 187 | collate_fn=collate_fn, 188 | ) 189 | train_dataloader = DataLoader( 190 | train, 191 | batch_size=cfg.batch_size, 192 | shuffle=True, 193 | num_workers=cfg.num_workers, 194 | collate_fn=collate_fn, 195 | ) 196 | return train_dataloader, val_dataloader 197 | 198 | 199 | def load_datasets_coarse(cfg, img_size): 200 | transform = ResizeAndPad(img_size) 201 | val = 
ISTDDatasetwithCoarse( 202 | cfg, 203 | image_root=cfg.datasets.ISTD.test, 204 | transform=transform, 205 | ) 206 | train = ISTDDatasetwithCoarse( 207 | cfg, 208 | image_root=cfg.datasets.ISTD.train, 209 | transform=transform, 210 | ) 211 | val_dataloader = DataLoader( 212 | val, 213 | batch_size=cfg.val_batchsize, 214 | shuffle=False, 215 | num_workers=cfg.num_workers, 216 | collate_fn=collate_fn, 217 | ) 218 | train_dataloader = DataLoader( 219 | train, 220 | batch_size=cfg.batch_size, 221 | shuffle=True, 222 | num_workers=cfg.num_workers, 223 | collate_fn=collate_fn, 224 | ) 225 | return train_dataloader, val_dataloader 226 | 227 | 228 | def load_datasets_visual(cfg, img_size): 229 | transform = ResizeAndPad(img_size) 230 | val = ISTDDataset( 231 | cfg, 232 | image_root=cfg.datasets.ISTD.test, 233 | transform=transform, 234 | ) 235 | val_dataloader = DataLoader( 236 | val, 237 | batch_size=cfg.val_batchsize, 238 | shuffle=False, 239 | num_workers=cfg.num_workers, 240 | collate_fn=collate_fn_, 241 | ) 242 | return val_dataloader 243 | 244 | 245 | def load_datasets_visual_coarse(cfg, img_size): 246 | transform = ResizeAndPad(img_size) 247 | val = ISTDDatasetwithCoarse( 248 | cfg, 249 | image_root=cfg.datasets.ISTD.test, 250 | transform=transform, 251 | ) 252 | val_dataloader = DataLoader( 253 | val, 254 | batch_size=cfg.val_batchsize, 255 | shuffle=False, 256 | num_workers=cfg.num_workers, 257 | collate_fn=collate_fn_, 258 | ) 259 | return val_dataloader 260 | 261 | 262 | def load_datasets_prompt(cfg, img_size): 263 | transform = ResizeAndPad(img_size) 264 | train = ISTDDataset( 265 | cfg, 266 | image_root=cfg.datasets.ISTD.train, 267 | transform=transform, 268 | if_self_training=cfg.augment, 269 | ) 270 | train_dataloader = DataLoader( 271 | train, 272 | batch_size=cfg.batch_size, 273 | shuffle=True, 274 | num_workers=cfg.num_workers, 275 | collate_fn=collate_fn_, 276 | ) 277 | return train_dataloader 278 | -------------------------------------------------------------------------------- /datasets/MSD.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import random 4 | import glob 5 | import json 6 | import torch 7 | import numpy as np 8 | import pandas as pd 9 | from torch.utils.data import Dataset, DataLoader 10 | from skimage.draw import polygon2mask 11 | 12 | from datasets.tools import ResizeAndPad, soft_transform, collate_fn, collate_fn_, decode_mask 13 | 14 | 15 | class MSDDataset(Dataset): 16 | def __init__(self, cfg, root_dir, transform=None, if_self_training=False): 17 | self.cfg = cfg 18 | self.root_dir = root_dir 19 | self.transform = transform 20 | 21 | images = [os.path.join(root_dir, f) for f in os.listdir(root_dir)] 22 | images = sorted(images) 23 | 24 | self.images = images 25 | self.gts = [image_path.replace("image", "mask").replace("jpg", "png") for image_path in self.images] 26 | 27 | self.if_self_training = if_self_training 28 | 29 | def __len__(self): 30 | return len(self.images) 31 | 32 | def __getitem__(self, idx): 33 | image_path = self.images[idx] 34 | image = cv2.imread(image_path) 35 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 36 | 37 | if self.cfg.get_prompt: 38 | image_info = {} 39 | height, width, _ = image.shape 40 | image_info["file_path"] = image_path 41 | image_info["height"] = height 42 | image_info["width"] = width 43 | return idx, image_info, image 44 | 45 | gt_path = self.gts[idx] 46 | gt_mask = cv2.imread(gt_path, cv2.IMREAD_GRAYSCALE) 47 | 48 | masks = [] 49 | bboxes = [] 50 | 
categories = [] 51 | gt_masks = decode_mask(torch.tensor(gt_mask[None, :, :])).numpy().astype(np.uint8) 52 | assert gt_masks.sum() == (gt_mask > 0).sum() 53 | for mask in gt_masks: 54 | masks.append(mask) 55 | x, y, w, h = cv2.boundingRect(mask) 56 | bboxes.append([x, y, x + w, y + h]) 57 | categories.append("0") 58 | 59 | if self.if_self_training: 60 | image_weak, bboxes_weak, masks_weak, image_strong = soft_transform(image, bboxes, masks, categories) 61 | 62 | if self.transform: 63 | image_weak, masks_weak, bboxes_weak = self.transform(image_weak, masks_weak, np.array(bboxes_weak)) 64 | image_strong = self.transform.transform_image(image_strong) 65 | 66 | bboxes_weak = np.stack(bboxes_weak, axis=0) 67 | masks_weak = np.stack(masks_weak, axis=0) 68 | return image_weak, image_strong, torch.tensor(bboxes_weak), torch.tensor(masks_weak).float() 69 | 70 | elif self.cfg.visual: 71 | image_name = os.path.splitext(os.path.basename(self.images[idx]))[0] 72 | origin_image = image 73 | origin_bboxes = bboxes 74 | origin_masks = masks 75 | if self.transform: 76 | padding, image, masks, bboxes = self.transform(image, masks, np.array(bboxes), True) 77 | 78 | bboxes = np.stack(bboxes, axis=0) 79 | masks = np.stack(masks, axis=0) 80 | origin_bboxes = np.stack(origin_bboxes, axis=0) 81 | origin_masks = np.stack(origin_masks, axis=0) 82 | return image_name, padding, origin_image, origin_bboxes, origin_masks, image, torch.tensor(bboxes), torch.tensor(masks).float() 83 | 84 | else: 85 | if self.transform: 86 | image, masks, bboxes = self.transform(image, masks, np.array(bboxes)) 87 | 88 | bboxes = np.stack(bboxes, axis=0) 89 | masks = np.stack(masks, axis=0) 90 | return image, torch.tensor(bboxes), torch.tensor(masks).float() 91 | 92 | 93 | class MSDDatasetwithCoarse(MSDDataset): 94 | 95 | def __getitem__(self, idx): 96 | image_path = self.images[idx] 97 | image = cv2.imread(image_path) 98 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 99 | 100 | gt_path = self.gts[idx] 101 | gt_mask = cv2.imread(gt_path, cv2.IMREAD_GRAYSCALE) 102 | 103 | masks = [] 104 | bboxes = [] 105 | approxes = [] 106 | categories = [] 107 | gt_masks = decode_mask(torch.tensor(gt_mask[None, :, :])).numpy().astype(np.uint8) 108 | assert gt_masks.sum() == (gt_mask > 0).sum() 109 | for mask in gt_masks: 110 | contours, hierarchy = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) 111 | num_vertices = 0.05 * cv2.arcLength(contours[0], True) 112 | num_vertices = num_vertices if num_vertices > 3 else 3 113 | approx = cv2.approxPolyDP(contours[0], num_vertices, True) # [x, y] 114 | approx = approx.squeeze(1) 115 | 116 | coordinates = np.array(approx) 117 | x_max, x_min = max(coordinates[:, 0]), min(coordinates[:, 0]) 118 | y_max, y_min = max(coordinates[:, 1]), min(coordinates[:, 1]) 119 | coarse_mask = polygon2mask(mask.shape, coordinates).astype(mask.dtype) 120 | if x_min == x_max or y_min == y_max: 121 | x, y, w, h = cv2.boundingRect(mask) 122 | bboxes.append([x, y, x + w, y + h]) 123 | else: 124 | bboxes.append([x_min, y_min, x_max, y_max]) 125 | 126 | masks.append(mask) 127 | categories.append("0") 128 | approxes.append(approx) 129 | 130 | if self.if_self_training: 131 | image_weak, bboxes_weak, masks_weak, image_strong = soft_transform(image, bboxes, masks, categories) 132 | 133 | if self.transform: 134 | image_weak, masks_weak, bboxes_weak = self.transform(image_weak, masks_weak, np.array(bboxes_weak)) 135 | image_strong = self.transform.transform_image(image_strong) 136 | 137 | bboxes_weak = np.stack(bboxes_weak, 
axis=0) 138 | masks_weak = np.stack(masks_weak, axis=0) 139 | return image_weak, image_strong, torch.tensor(bboxes_weak), torch.tensor(masks_weak).float() 140 | 141 | elif self.cfg.visual: 142 | image_name = os.path.splitext(os.path.basename(self.images[idx]))[0] 143 | origin_image = image 144 | origin_approxes = approxes 145 | origin_masks = masks 146 | if self.transform: 147 | padding, image, masks, bboxes = self.transform(image, masks, np.array(bboxes), self.cfg.visual) 148 | 149 | bboxes = np.stack(bboxes, axis=0) 150 | masks = np.stack(masks, axis=0) 151 | origin_masks = np.stack(origin_masks, axis=0) 152 | return image_name, padding, origin_image, origin_approxes, origin_masks, image, torch.tensor(bboxes), torch.tensor(masks).float() 153 | 154 | else: 155 | if self.transform: 156 | image, masks, bboxes = self.transform(image, masks, np.array(bboxes)) 157 | 158 | bboxes = np.stack(bboxes, axis=0) 159 | masks = np.stack(masks, axis=0) 160 | return image, torch.tensor(bboxes), torch.tensor(masks).float() 161 | 162 | 163 | def load_datasets(cfg, img_size): 164 | transform = ResizeAndPad(img_size) 165 | val = MSDDataset( 166 | cfg, 167 | root_dir=cfg.datasets.MSD.test, 168 | transform=transform, 169 | ) 170 | train = MSDDataset( 171 | cfg, 172 | root_dir=cfg.datasets.MSD.train, 173 | transform=transform, 174 | if_self_training=cfg.augment, 175 | ) 176 | val_dataloader = DataLoader( 177 | val, 178 | batch_size=cfg.val_batchsize, 179 | shuffle=False, 180 | num_workers=cfg.num_workers, 181 | collate_fn=collate_fn, 182 | ) 183 | train_dataloader = DataLoader( 184 | train, 185 | batch_size=cfg.batch_size, 186 | shuffle=True, 187 | num_workers=cfg.num_workers, 188 | collate_fn=collate_fn, 189 | ) 190 | return train_dataloader, val_dataloader 191 | 192 | 193 | def load_datasets_coarse(cfg, img_size): 194 | transform = ResizeAndPad(img_size) 195 | val = MSDDatasetwithCoarse( 196 | cfg, 197 | root_dir=cfg.datasets.MSD.test, 198 | transform=transform, 199 | ) 200 | train = MSDDatasetwithCoarse( 201 | cfg, 202 | root_dir=cfg.datasets.MSD.train, 203 | transform=transform, 204 | if_self_training=cfg.augment, 205 | ) 206 | val_dataloader = DataLoader( 207 | val, 208 | batch_size=cfg.val_batchsize, 209 | shuffle=False, 210 | num_workers=cfg.num_workers, 211 | collate_fn=collate_fn, 212 | ) 213 | train_dataloader = DataLoader( 214 | train, 215 | batch_size=cfg.batch_size, 216 | shuffle=True, 217 | num_workers=cfg.num_workers, 218 | collate_fn=collate_fn, 219 | ) 220 | return train_dataloader, val_dataloader 221 | 222 | 223 | def load_datasets_visual(cfg, img_size): 224 | transform = ResizeAndPad(img_size) 225 | val = MSDDataset( 226 | cfg, 227 | root_dir=cfg.datasets.MSD.test, 228 | transform=transform, 229 | ) 230 | val_dataloader = DataLoader( 231 | val, 232 | batch_size=cfg.val_batchsize, 233 | shuffle=False, 234 | num_workers=cfg.num_workers, 235 | collate_fn=collate_fn_, 236 | ) 237 | return val_dataloader 238 | 239 | 240 | def load_datasets_visual_coarse(cfg, img_size): 241 | transform = ResizeAndPad(img_size) 242 | val = MSDDatasetwithCoarse( 243 | cfg, 244 | root_dir=cfg.datasets.MSD.test, 245 | transform=transform, 246 | ) 247 | val_dataloader = DataLoader( 248 | val, 249 | batch_size=cfg.val_batchsize, 250 | shuffle=False, 251 | num_workers=cfg.num_workers, 252 | collate_fn=collate_fn_, 253 | ) 254 | return val_dataloader 255 | 256 | 257 | def load_datasets_prompt(cfg, img_size): 258 | transform = ResizeAndPad(img_size) 259 | train = MSDDataset( 260 | cfg, 261 | root_dir=cfg.datasets.MSD.train, 
262 | transform=transform, 263 | if_self_training=cfg.augment, 264 | ) 265 | train_dataloader = DataLoader( 266 | train, 267 | batch_size=cfg.batch_size, 268 | shuffle=True, 269 | num_workers=cfg.num_workers, 270 | collate_fn=collate_fn_, 271 | ) 272 | return train_dataloader 273 | -------------------------------------------------------------------------------- /datasets/OSD.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import random 4 | import glob 5 | import numpy as np 6 | import torch 7 | from torch.utils.data import DataLoader 8 | from torch.utils.data import Dataset 9 | from skimage.draw import polygon2mask 10 | from pathlib import Path 11 | from PIL import Image 12 | from datasets.tools import ResizeAndPad, soft_transform, collate_fn, decode_mask, collate_fn_ 13 | 14 | 15 | class OSDObject(Dataset): 16 | def __init__(self, cfg, root_dir, transform=None, split=False, training=False, if_self_training=False): 17 | self.cfg = cfg 18 | self._osd_object_path = root_dir 19 | self.transform = transform 20 | # get all images 21 | data_path = os.path.join(self._osd_object_path, 'image_color') 22 | all_image_paths = sorted(glob.glob(data_path + '/*.png')) 23 | all_image_paths = self.check_empty(all_image_paths) 24 | 25 | if split: 26 | train_image_paths = [] 27 | eval_image_paths = [] 28 | while all_image_paths: 29 | for _ in range(6): 30 | if all_image_paths: 31 | train_image_paths.append(all_image_paths.pop(0)) 32 | if all_image_paths: 33 | eval_image_paths.append(all_image_paths.pop(0)) 34 | 35 | if training: 36 | random.shuffle(train_image_paths) 37 | image_paths = train_image_paths 38 | else: 39 | random.shuffle(eval_image_paths) 40 | image_paths = eval_image_paths 41 | else: 42 | image_paths = all_image_paths 43 | 44 | self.image_files = image_paths 45 | # self.image_files = all_image_paths 46 | 47 | self.if_self_training = if_self_training 48 | assert os.path.exists(self._osd_object_path), \ 49 | 'osd_object path does not exist: {}'.format(self._osd_object_path) 50 | 51 | def process_label(self, foreground_labels): 52 | """ Process foreground_labels 53 | - Map the foreground_labels to {0, 1, ..., K-1} 54 | 55 | @param foreground_labels: a [H x W] numpy array of labels 56 | 57 | @return: foreground_labels 58 | """ 59 | # Find the unique (nonnegative) foreground_labels, map them to {0, ..., K-1} 60 | unique_nonnegative_indices = np.unique(foreground_labels) 61 | mapped_labels = foreground_labels.copy() 62 | for k in range(unique_nonnegative_indices.shape[0]): 63 | mapped_labels[foreground_labels == unique_nonnegative_indices[k]] = k 64 | foreground_labels = mapped_labels 65 | return foreground_labels 66 | 67 | def __len__(self): 68 | return len(self.image_files) 69 | 70 | def check_empty(self, image_paths): 71 | new_image_paths = [] 72 | for filename in image_paths: 73 | labels_filename = str(filename).replace('image_color', 'annotation') 74 | annotation = Image.open(labels_filename) 75 | foreground_labels = np.array(annotation) 76 | # mask table as background 77 | foreground_labels[foreground_labels == 1] = 0 78 | if 'table' in labels_filename: 79 | foreground_labels[foreground_labels == 2] = 0 80 | gt_mask = self.process_label(foreground_labels) 81 | if not np.all(gt_mask == 0): 82 | new_image_paths.append(filename) 83 | return new_image_paths 84 | 85 | def __getitem__(self, idx): 86 | filename = self.image_files[idx] 87 | 88 | image = cv2.imread(filename) 89 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 90 | 91 
| if self.cfg.get_prompt: 92 | image_info = {} 93 | height, width, _ = image.shape 94 | image_info["file_path"] = filename 95 | image_info["height"] = height 96 | image_info["width"] = width 97 | return idx, image_info, image 98 | 99 | labels_filename = filename.replace('image_color', 'annotation') 100 | annotation = Image.open(labels_filename) 101 | foreground_labels = np.array(annotation) 102 | 103 | # mask table as background 104 | foreground_labels[foreground_labels == 1] = 0 105 | if 'table' in labels_filename: 106 | foreground_labels[foreground_labels == 2] = 0 107 | gt_mask = self.process_label(foreground_labels) 108 | 109 | bboxes = [] 110 | masks = [] 111 | categories = [] 112 | gt_masks = decode_mask(torch.tensor(gt_mask[None, :, :])).numpy().astype(np.uint8) 113 | assert gt_masks.sum() == (gt_mask > 0).sum() 114 | for mask in gt_masks: 115 | masks.append(mask) 116 | x, y, w, h = cv2.boundingRect(mask) 117 | bboxes.append([x, y, x + w, y + h]) 118 | categories.append("0") 119 | 120 | if self.if_self_training: 121 | image_weak, bboxes_weak, masks_weak, image_strong = soft_transform(image, bboxes, masks, categories) 122 | 123 | if self.transform: 124 | image_weak, masks_weak, bboxes_weak = self.transform(image_weak, masks_weak, np.array(bboxes_weak)) 125 | image_strong = self.transform.transform_image(image_strong) 126 | 127 | bboxes_weak = np.stack(bboxes_weak, axis=0) 128 | masks_weak = np.stack(masks_weak, axis=0) 129 | return image_weak, image_strong, torch.tensor(bboxes_weak), torch.tensor(masks_weak).float() 130 | else: 131 | if self.transform: 132 | image, masks, bboxes = self.transform(image, masks, np.array(bboxes)) 133 | 134 | bboxes = np.stack(bboxes, axis=0) 135 | masks = np.stack(masks, axis=0) 136 | return image, torch.tensor(bboxes), torch.tensor(masks).float() 137 | 138 | 139 | class OSDObjectwithCoarse(OSDObject): 140 | 141 | def __getitem__(self, idx): 142 | filename = self.image_files[idx] 143 | 144 | image = cv2.imread(filename) 145 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 146 | 147 | labels_filename = filename.replace('image_color', 'annotation') 148 | annotation = Image.open(labels_filename) 149 | foreground_labels = np.array(annotation) 150 | 151 | # mask table as background 152 | foreground_labels[foreground_labels == 1] = 0 153 | if 'table' in labels_filename: 154 | foreground_labels[foreground_labels == 2] = 0 155 | gt_mask = self.process_label(foreground_labels) 156 | 157 | bboxes = [] 158 | masks = [] 159 | categories = [] 160 | gt_masks = decode_mask(torch.tensor(gt_mask[None, :, :])).numpy().astype(np.uint8) 161 | assert gt_masks.sum() == (gt_mask > 0).sum() 162 | for mask in gt_masks: 163 | contours, hierarchy = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) 164 | num_vertices = 0.05 * cv2.arcLength(contours[0], True) 165 | num_vertices = num_vertices if num_vertices > 3 else 3 166 | approx = cv2.approxPolyDP(contours[0], num_vertices, True) # [x, y] 167 | approx = approx.squeeze(1) 168 | 169 | coordinates = np.array(approx) 170 | x_max, x_min = max(coordinates[:, 0]), min(coordinates[:, 0]) 171 | y_max, y_min = max(coordinates[:, 1]), min(coordinates[:, 1]) 172 | coarse_mask = polygon2mask(mask.shape, coordinates).astype(mask.dtype) 173 | if x_min == x_max or y_min == y_max: 174 | x, y, w, h = cv2.boundingRect(mask) 175 | bboxes.append([x, y, x + w, y + h]) 176 | else: 177 | bboxes.append([x_min, y_min, x_max, y_max]) 178 | 179 | masks.append(mask) 180 | categories.append("0") 181 | 182 | if self.if_self_training: 183 | 
image_weak, bboxes_weak, masks_weak, image_strong = soft_transform(image, bboxes, masks, categories) 184 | 185 | if self.transform: 186 | image_weak, masks_weak, bboxes_weak = self.transform(image_weak, masks_weak, np.array(bboxes_weak)) 187 | image_strong = self.transform.transform_image(image_strong) 188 | 189 | bboxes_weak = np.stack(bboxes_weak, axis=0) 190 | masks_weak = np.stack(masks_weak, axis=0) 191 | return image_weak, image_strong, torch.tensor(bboxes_weak), torch.tensor(masks_weak).float() 192 | else: 193 | if self.transform: 194 | image, masks, bboxes = self.transform(image, masks, np.array(bboxes)) 195 | 196 | bboxes = np.stack(bboxes, axis=0) 197 | masks = np.stack(masks, axis=0) 198 | return image, torch.tensor(bboxes), torch.tensor(masks).float() 199 | 200 | 201 | def load_datasets(cfg, img_size): 202 | transform = ResizeAndPad(img_size) 203 | train = OSDObject( 204 | cfg, 205 | root_dir=cfg.datasets.robot.OSD, 206 | transform=transform, 207 | split=cfg.split, 208 | training=True, 209 | if_self_training=cfg.augment, 210 | ) 211 | val = OSDObject( 212 | cfg, 213 | root_dir=cfg.datasets.robot.OSD, 214 | transform=transform, 215 | split=cfg.split, 216 | ) 217 | train_dataloader = DataLoader( 218 | train, 219 | batch_size=cfg.batch_size, 220 | shuffle=True, 221 | num_workers=cfg.num_workers, 222 | collate_fn=collate_fn, 223 | ) 224 | val_dataloader = DataLoader( 225 | val, 226 | batch_size=cfg.val_batchsize, 227 | shuffle=False, 228 | num_workers=cfg.num_workers, 229 | collate_fn=collate_fn, 230 | ) 231 | return train_dataloader, val_dataloader 232 | 233 | 234 | def load_datasets_coarse(cfg, img_size): 235 | transform = ResizeAndPad(img_size) 236 | train = OSDObjectwithCoarse( 237 | cfg, 238 | root_dir=cfg.datasets.robot.OSD, 239 | transform=transform, 240 | split=cfg.split, 241 | training=True, 242 | if_self_training=cfg.augment, 243 | ) 244 | val = OSDObjectwithCoarse( 245 | cfg, 246 | root_dir=cfg.datasets.robot.OSD, 247 | transform=transform, 248 | split=cfg.split, 249 | ) 250 | train_dataloader = DataLoader( 251 | train, 252 | batch_size=cfg.batch_size, 253 | shuffle=True, 254 | num_workers=cfg.num_workers, 255 | collate_fn=collate_fn, 256 | ) 257 | val_dataloader = DataLoader( 258 | val, 259 | batch_size=cfg.val_batchsize, 260 | shuffle=False, 261 | num_workers=cfg.num_workers, 262 | collate_fn=collate_fn, 263 | ) 264 | return train_dataloader, val_dataloader 265 | 266 | 267 | def load_datasets_prompt(cfg, img_size): 268 | transform = ResizeAndPad(img_size) 269 | train = OSDObject( 270 | cfg, 271 | root_dir=cfg.datasets.robot.OSD, 272 | transform=transform, 273 | split=cfg.split, 274 | training=True, 275 | if_self_training=cfg.augment, 276 | ) 277 | train_dataloader = DataLoader( 278 | train, 279 | batch_size=cfg.batch_size, 280 | shuffle=True, 281 | num_workers=cfg.num_workers, 282 | collate_fn=collate_fn_, 283 | ) 284 | return train_dataloader 285 | -------------------------------------------------------------------------------- /datasets/PascalVOC.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import random 4 | import numpy as np 5 | import torch 6 | from torch.utils.data import DataLoader 7 | from torch.utils.data import Dataset 8 | from skimage.draw import polygon2mask 9 | from datasets.tools import ResizeAndPad, soft_transform, collate_fn, collate_fn_, decode_mask 10 | 11 | 12 | class PascalVOCDataset(Dataset): 13 | def __init__(self, cfg, root_dir, transform=None, split=False, training=False, 
if_self_training=False): 14 | self.cfg = cfg 15 | self.root_dir = root_dir 16 | self.transform = transform 17 | 18 | segment_root = os.path.join(root_dir, "SegmentationObject") 19 | all_anns = [os.path.join(segment_root, f) for f in os.listdir(segment_root) if f.endswith('.png')] 20 | all_anns = sorted(all_anns) 21 | 22 | if split: 23 | train_list = [] 24 | eval_list = [] 25 | while all_anns: 26 | for _ in range(6): 27 | if all_anns: 28 | train_list.append(all_anns.pop(0)) 29 | if all_anns: 30 | eval_list.append(all_anns.pop(0)) 31 | 32 | if training: 33 | random.shuffle(train_list) 34 | image_ids = train_list 35 | else: 36 | random.shuffle(eval_list) 37 | image_ids = eval_list 38 | else: 39 | image_ids = all_anns 40 | 41 | self.image_ids = image_ids 42 | 43 | self.if_self_training = if_self_training 44 | 45 | def __len__(self): 46 | return len(self.image_ids) 47 | 48 | def get_pascal_labels(self): 49 | """Load the mapping that associates pascal classes with label colors 50 | 51 | Returns: 52 | np.ndarray with dimensions (21, 3) 53 | """ 54 | return np.asarray( 55 | [ 56 | [0, 0, 0], 57 | [128, 0, 0], 58 | [0, 128, 0], 59 | [128, 128, 0], 60 | [0, 0, 128], 61 | [128, 0, 128], 62 | [0, 128, 128], 63 | [128, 128, 128], 64 | [64, 0, 0], 65 | [192, 0, 0], 66 | [64, 128, 0], 67 | [192, 128, 0], 68 | [64, 0, 128], 69 | [192, 0, 128], 70 | [64, 128, 128], 71 | [192, 128, 128], 72 | [0, 64, 0], 73 | [128, 64, 0], 74 | [0, 192, 0], 75 | [128, 192, 0], 76 | [0, 64, 128], 77 | ] 78 | ) 79 | 80 | def encode_segmap(self, mask): 81 | """Encode segmentation label images as pascal classes 82 | 83 | Args: 84 | mask (np.ndarray): raw segmentation label image of dimension 85 | (M, N, 3), in which the Pascal classes are encoded as colours. 86 | 87 | Returns: 88 | (np.ndarray): class map with dimensions (M,N), where the value at 89 | a given location is the integer denoting the class index. 
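
        Example:
            A pixel coloured (128, 0, 0) - the second entry in get_pascal_labels(),
            i.e. the VOC "aeroplane" colour - maps to class index 1, while a
            background pixel (0, 0, 0) maps to 0:

                rgb = np.array([[[0, 0, 0], [128, 0, 0]]])  # a 1 x 2 RGB label image
                self.encode_segmap(rgb)                     # -> array([[0, 1]])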
90 | """ 91 | mask = mask.astype(int) 92 | label_mask = np.zeros((mask.shape[0], mask.shape[1]), dtype=np.int16) 93 | for ii, label in enumerate(self.get_pascal_labels()): 94 | label_mask[np.where(np.all(mask == label, axis=-1))[:2]] = ii 95 | label_mask = label_mask.astype(int) 96 | return label_mask 97 | 98 | def __getitem__(self, idx): 99 | 100 | anno_path = self.image_ids[idx] 101 | image_path = anno_path.replace("SegmentationObject", "JPEGImages").replace(".png", ".jpg") 102 | 103 | image = cv2.imread(image_path) 104 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 105 | gt_mask = cv2.imread(anno_path) 106 | gt_labels = self.encode_segmap(gt_mask) 107 | 108 | if self.cfg.get_prompt: 109 | image_info = {} 110 | height, width, _ = image.shape 111 | image_info["file_path"] = image_path 112 | image_info["height"] = height 113 | image_info["width"] = width 114 | return idx, image_info, image 115 | 116 | masks = [] 117 | bboxes = [] 118 | categories = [] 119 | gt_masks = decode_mask(torch.tensor(gt_labels[None, :, :])).numpy().astype(np.uint8) 120 | assert gt_masks.sum() == (gt_labels > 0).sum() 121 | for mask in gt_masks: 122 | masks.append(mask) 123 | x, y, w, h = cv2.boundingRect(mask) 124 | bboxes.append([x, y, x + w, y + h]) 125 | categories.append("0") 126 | 127 | if self.if_self_training: 128 | image_weak, bboxes_weak, masks_weak, image_strong = soft_transform(image, bboxes, masks, categories) 129 | 130 | if self.transform: 131 | image_weak, masks_weak, bboxes_weak = self.transform(image_weak, masks_weak, np.array(bboxes_weak)) 132 | image_strong = self.transform.transform_image(image_strong) 133 | 134 | bboxes_weak = np.stack(bboxes_weak, axis=0) 135 | masks_weak = np.stack(masks_weak, axis=0) 136 | return image_weak, image_strong, torch.tensor(bboxes_weak), torch.tensor(masks_weak).float() 137 | else: 138 | if self.transform: 139 | image, masks, bboxes = self.transform(image, masks, np.array(bboxes)) 140 | 141 | bboxes = np.stack(bboxes, axis=0) 142 | masks = np.stack(masks, axis=0) 143 | return image, torch.tensor(bboxes), torch.tensor(masks).float() 144 | 145 | 146 | class PascalVOCDatasetwithCoarse(PascalVOCDataset): 147 | 148 | def __getitem__(self, idx): 149 | anno_path = self.image_ids[idx] 150 | image_path = anno_path.replace("SegmentationObject", "JPEGImages").replace(".png", ".jpg") 151 | 152 | image = cv2.imread(image_path) 153 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 154 | gt_mask = cv2.imread(anno_path) 155 | gt_labels = self.encode_segmap(gt_mask) 156 | 157 | masks = [] 158 | bboxes = [] 159 | approxes =[] 160 | categories = [] 161 | gt_masks = decode_mask(torch.tensor(gt_labels[None, :, :])).numpy().astype(np.uint8) 162 | assert gt_masks.sum() == (gt_labels > 0).sum() 163 | for mask in gt_masks: 164 | contours, hierarchy = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) 165 | num_vertices = 0.05 * cv2.arcLength(contours[0], True) 166 | num_vertices = num_vertices if num_vertices > 5 else 5 167 | approx = cv2.approxPolyDP(contours[0], num_vertices, True) # [x, y] 168 | approx = approx.squeeze(1) 169 | coordinates = np.array(approx) 170 | x_max, x_min = max(coordinates[:, 0]), min(coordinates[:, 0]) 171 | y_max, y_min = max(coordinates[:, 1]), min(coordinates[:, 1]) 172 | coarse_mask = polygon2mask(mask.shape, coordinates).astype(mask.dtype) 173 | if x_min == x_max or y_min == y_max: 174 | x, y, w, h = cv2.boundingRect(mask) 175 | bboxes.append([x, y, x + w, y + h]) 176 | else: 177 | bboxes.append([x_min, y_min, x_max, y_max]) 178 | 
masks.append(mask) 179 | categories.append("0") 180 | approxes.append(approx) 181 | 182 | if self.if_self_training: 183 | image_weak, bboxes_weak, masks_weak, image_strong = soft_transform(image, bboxes, masks, categories) 184 | 185 | if self.transform: 186 | image_weak, masks_weak, bboxes_weak = self.transform(image_weak, masks_weak, np.array(bboxes_weak)) 187 | image_strong = self.transform.transform_image(image_strong) 188 | 189 | bboxes_weak = np.stack(bboxes_weak, axis=0) 190 | masks_weak = np.stack(masks_weak, axis=0) 191 | return image_weak, image_strong, torch.tensor(bboxes_weak), torch.tensor(masks_weak).float() 192 | 193 | elif self.cfg.visual: 194 | image_name = os.path.splitext(os.path.basename(image_path))[0] 195 | 196 | origin_image = image 197 | origin_approxes = approxes 198 | origin_masks = masks 199 | if self.transform: 200 | padding, image, masks, bboxes = self.transform(image, masks, np.array(bboxes), self.cfg.visual) 201 | 202 | bboxes = np.stack(bboxes, axis=0) 203 | masks = np.stack(masks, axis=0) 204 | origin_masks = np.stack(origin_masks, axis=0) 205 | return image_name, padding, origin_image, origin_approxes, origin_masks, image, torch.tensor(bboxes), torch.tensor(masks).float() 206 | 207 | else: 208 | if self.transform: 209 | image, masks, bboxes = self.transform(image, masks, np.array(bboxes)) 210 | 211 | bboxes = np.stack(bboxes, axis=0) 212 | masks = np.stack(masks, axis=0) 213 | return image, torch.tensor(bboxes), torch.tensor(masks).float() 214 | 215 | 216 | def load_datasets(cfg, img_size): 217 | transform = ResizeAndPad(img_size) 218 | val = PascalVOCDataset( 219 | cfg, 220 | root_dir=cfg.datasets.PascalVOC.root_dir, 221 | transform=transform, 222 | split=cfg.split, 223 | ) 224 | train = PascalVOCDataset( 225 | cfg, 226 | root_dir=cfg.datasets.PascalVOC.root_dir, 227 | transform=transform, 228 | split=cfg.split, 229 | training=True, 230 | if_self_training=cfg.augment, 231 | ) 232 | val_dataloader = DataLoader( 233 | val, 234 | batch_size=cfg.val_batchsize, 235 | shuffle=False, 236 | num_workers=cfg.num_workers, 237 | collate_fn=collate_fn, 238 | ) 239 | train_dataloader = DataLoader( 240 | train, 241 | batch_size=cfg.batch_size, 242 | shuffle=True, 243 | num_workers=cfg.num_workers, 244 | collate_fn=collate_fn, 245 | ) 246 | return train_dataloader, val_dataloader 247 | 248 | 249 | def load_datasets_coarse(cfg, img_size): 250 | transform = ResizeAndPad(img_size) 251 | val = PascalVOCDatasetwithCoarse( 252 | cfg, 253 | root_dir=cfg.datasets.PascalVOC.root_dir, 254 | transform=transform, 255 | split=cfg.split, 256 | ) 257 | train = PascalVOCDatasetwithCoarse( 258 | cfg, 259 | root_dir=cfg.datasets.PascalVOC.root_dir, 260 | transform=transform, 261 | split=cfg.split, 262 | training=True, 263 | if_self_training=cfg.augment, 264 | ) 265 | val_dataloader = DataLoader( 266 | val, 267 | batch_size=cfg.val_batchsize, 268 | shuffle=False, 269 | num_workers=cfg.num_workers, 270 | collate_fn=collate_fn, 271 | ) 272 | train_dataloader = DataLoader( 273 | train, 274 | batch_size=cfg.batch_size, 275 | shuffle=True, 276 | num_workers=cfg.num_workers, 277 | collate_fn=collate_fn, 278 | ) 279 | return train_dataloader, val_dataloader 280 | 281 | 282 | def load_datasets_prompt(cfg, img_size): 283 | transform = ResizeAndPad(img_size) 284 | train = PascalVOCDataset( 285 | cfg, 286 | root_dir=cfg.datasets.PascalVOC.root_dir, 287 | transform=transform, 288 | split=cfg.split, 289 | training=True, 290 | if_self_training=cfg.augment, 291 | ) 292 | train_dataloader = DataLoader( 293 
| train, 294 | batch_size=cfg.batch_size, 295 | shuffle=True, 296 | num_workers=cfg.num_workers, 297 | collate_fn=collate_fn_, 298 | ) 299 | return train_dataloader -------------------------------------------------------------------------------- /datasets/Polyp.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import random 4 | import glob 5 | import json 6 | import torch 7 | import numpy as np 8 | import pandas as pd 9 | from torch.utils.data import Dataset, DataLoader 10 | from skimage.draw import polygon2mask 11 | 12 | from datasets.tools import ResizeAndPad, soft_transform, collate_fn, decode_mask, collate_fn_ 13 | 14 | 15 | class PolypDataset(Dataset): 16 | def __init__(self, cfg, root_dir, annotation_file, transform=None, split=False, training=False, if_self_training=False): 17 | self.cfg = cfg 18 | self.root_dir = root_dir 19 | self.transform = transform 20 | with open(annotation_file, "r") as ann_file: 21 | anns = json.load(ann_file) 22 | all_images = list(anns.keys()) 23 | 24 | if split: 25 | train_images = [] 26 | eval_images = [] 27 | while all_images: 28 | for _ in range(5): 29 | if all_images: 30 | train_images.append(all_images.pop(0)) 31 | if all_images: 32 | eval_images.append(all_images.pop(0)) 33 | 34 | if training: 35 | random.shuffle(train_images) 36 | images = train_images 37 | else: 38 | random.shuffle(eval_images) 39 | images = eval_images 40 | else: 41 | images = all_images 42 | 43 | self.images = images 44 | self.anns = anns 45 | self.if_self_training = if_self_training 46 | 47 | def __len__(self): 48 | return len(self.images) 49 | 50 | def find_points_outside_bbox(self, mask, bboxes): 51 | points_outside_bbox = np.where(mask != 0) 52 | for bbox in bboxes: 53 | x_min, y_min, x_max, y_max = bbox 54 | points_outside_bbox = (points_outside_bbox[0][(points_outside_bbox[0] < y_min) | (points_outside_bbox[0] >= y_max)], 55 | points_outside_bbox[1][(points_outside_bbox[1] < x_min) | (points_outside_bbox[1] >= x_max)]) 56 | return points_outside_bbox 57 | 58 | def __getitem__(self, idx): 59 | name = self.images[idx] 60 | image_path = os.path.join(self.root_dir, "images", name + ".jpg") 61 | 62 | image = cv2.imread(image_path) 63 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 64 | 65 | if self.cfg.get_prompt: 66 | image_info = {} 67 | height, width, _ = image.shape 68 | image_info["file_path"] = image_path 69 | image_info["height"] = height 70 | image_info["width"] = width 71 | return idx, image_info, image 72 | 73 | gt_path = image_path.replace("images", "masks") 74 | gt_mask = cv2.imread(gt_path, cv2.IMREAD_GRAYSCALE) 75 | gt_mask[gt_mask > 0] = 255 76 | 77 | masks = [] 78 | bboxes = [] 79 | categories = [] 80 | anns = self.anns[name] 81 | ann_bboxes = anns["bbox"] 82 | 83 | for i, bbox in enumerate(ann_bboxes): 84 | x_min = bbox["xmin"] 85 | x_max = bbox["xmax"] 86 | y_min = bbox["ymin"] 87 | y_max = bbox["ymax"] 88 | gt_mask[y_min:y_max, x_min:x_max][gt_mask[y_min:y_max, x_min:x_max] > 0] = i + 1 89 | bboxes.append([x_min, y_min, x_max, y_max]) 90 | categories.append(bbox["label"]) 91 | 92 | gt_mask[gt_mask > i + 1] = 0 93 | gt_masks = decode_mask(torch.tensor(gt_mask[None, :, :])).numpy().astype(np.uint8) 94 | assert gt_masks.sum() == (gt_mask > 0).sum() 95 | assert len(ann_bboxes) == gt_masks.shape[0] 96 | masks = [mask for mask in gt_masks] 97 | 98 | if self.if_self_training: 99 | image_weak, bboxes_weak, masks_weak, image_strong = soft_transform(image, bboxes, masks, categories) 100 | 101 | if 
self.transform: 102 | image_weak, masks_weak, bboxes_weak = self.transform(image_weak, masks_weak, np.array(bboxes_weak)) 103 | image_strong = self.transform.transform_image(image_strong) 104 | 105 | bboxes_weak = np.stack(bboxes_weak, axis=0) 106 | masks_weak = np.stack(masks_weak, axis=0) 107 | return image_weak, image_strong, torch.tensor(bboxes_weak), torch.tensor(masks_weak).float() 108 | else: 109 | if self.transform: 110 | image, masks, bboxes = self.transform(image, masks, np.array(bboxes)) 111 | 112 | bboxes = np.stack(bboxes, axis=0) 113 | masks = np.stack(masks, axis=0) 114 | return image, torch.tensor(bboxes), torch.tensor(masks).float() 115 | 116 | 117 | class PolypDatasetwithCoarse(PolypDataset): 118 | 119 | def __getitem__(self, idx): 120 | name = self.images[idx] 121 | image_path = os.path.join(self.root_dir, "images", name + ".jpg") 122 | 123 | image = cv2.imread(image_path) 124 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 125 | 126 | gt_path = image_path.replace("images", "masks") 127 | gt_mask = cv2.imread(gt_path, cv2.IMREAD_GRAYSCALE) 128 | gt_mask[gt_mask > 0] = 255 129 | 130 | masks = [] 131 | bboxes = [] 132 | categories = [] 133 | anns = self.anns[name] 134 | ann_bboxes = anns["bbox"] 135 | 136 | for i, bbox in enumerate(ann_bboxes): 137 | x_min = bbox["xmin"] 138 | x_max = bbox["xmax"] 139 | y_min = bbox["ymin"] 140 | y_max = bbox["ymax"] 141 | gt_mask[y_min:y_max, x_min:x_max][gt_mask[y_min:y_max, x_min:x_max] > 0] = i + 1 142 | # bboxes.append([x_min, y_min, x_max, y_max]) 143 | categories.append(bbox["label"]) 144 | 145 | gt_mask[gt_mask > i + 1] = 0 146 | gt_masks = decode_mask(torch.tensor(gt_mask[None, :, :])).numpy().astype(np.uint8) 147 | assert gt_masks.sum() == (gt_mask > 0).sum() 148 | assert len(ann_bboxes) == gt_masks.shape[0] 149 | 150 | for mask in gt_masks: 151 | contours, hierarchy = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) 152 | num_vertices = 0.05 * cv2.arcLength(contours[0], True) 153 | num_vertices = num_vertices if num_vertices > 3 else 3 154 | approx = cv2.approxPolyDP(contours[0], num_vertices, True) # [x, y] 155 | approx = approx.squeeze(1) 156 | 157 | coordinates = np.array(approx) 158 | x_max, x_min = max(coordinates[:, 0]), min(coordinates[:, 0]) 159 | y_max, y_min = max(coordinates[:, 1]), min(coordinates[:, 1]) 160 | if x_min == x_max or y_min == y_max: 161 | x, y, w, h = cv2.boundingRect(mask) 162 | bboxes.append([x, y, x + w, y + h]) 163 | else: 164 | bboxes.append([x_min, y_min, x_max, y_max]) 165 | 166 | coarse_mask = polygon2mask(mask.shape, coordinates).astype(mask.dtype) 167 | 168 | masks.append(mask) 169 | 170 | masks = [mask for mask in gt_masks] 171 | 172 | if self.if_self_training: 173 | image_weak, bboxes_weak, masks_weak, image_strong = soft_transform(image, bboxes, masks, categories) 174 | 175 | if self.transform: 176 | image_weak, masks_weak, bboxes_weak = self.transform(image_weak, masks_weak, np.array(bboxes_weak)) 177 | image_strong = self.transform.transform_image(image_strong) 178 | 179 | bboxes_weak = np.stack(bboxes_weak, axis=0) 180 | masks_weak = np.stack(masks_weak, axis=0) 181 | return image_weak, image_strong, torch.tensor(bboxes_weak), torch.tensor(masks_weak).float() 182 | else: 183 | if self.transform: 184 | image, masks, bboxes = self.transform(image, masks, np.array(bboxes)) 185 | 186 | bboxes = np.stack(bboxes, axis=0) 187 | masks = np.stack(masks, axis=0) 188 | return image, torch.tensor(bboxes), torch.tensor(masks).float() 189 | 190 | 191 | def load_datasets(cfg, img_size): 
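    """Build the Polyp train/val dataloaders.

    The annotation JSON maps each image name to its bounding boxes; PolypDataset
    loads images from <root_dir>/images/<name>.jpg, masks from the parallel
    "masks" directory, and splits each mask into one binary mask per annotated
    box. With cfg.split set, images are partitioned train/eval in a 5:1
    round-robin before shuffling.
    """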
192 | transform = ResizeAndPad(img_size) 193 | val = PolypDataset( 194 | cfg, 195 | root_dir=cfg.datasets.Polyp.root_dir, 196 | annotation_file=cfg.datasets.Polyp.annotation_file, 197 | transform=transform, 198 | split=cfg.split, 199 | ) 200 | train = PolypDataset( 201 | cfg, 202 | root_dir=cfg.datasets.Polyp.root_dir, 203 | annotation_file=cfg.datasets.Polyp.annotation_file, 204 | transform=transform, 205 | split=cfg.split, 206 | training=True, 207 | if_self_training=cfg.augment, 208 | ) 209 | val_dataloader = DataLoader( 210 | val, 211 | batch_size=cfg.val_batchsize, 212 | shuffle=False, 213 | num_workers=cfg.num_workers, 214 | collate_fn=collate_fn, 215 | ) 216 | train_dataloader = DataLoader( 217 | train, 218 | batch_size=cfg.batch_size, 219 | shuffle=True, 220 | num_workers=cfg.num_workers, 221 | collate_fn=collate_fn, 222 | ) 223 | return train_dataloader, val_dataloader 224 | 225 | 226 | def load_datasets_coarse(cfg, img_size): 227 | transform = ResizeAndPad(img_size) 228 | val = PolypDatasetwithCoarse( 229 | cfg, 230 | root_dir=cfg.datasets.Polyp.root_dir, 231 | annotation_file=cfg.datasets.Polyp.annotation_file, 232 | transform=transform, 233 | split=cfg.split, 234 | ) 235 | train = PolypDatasetwithCoarse( 236 | cfg, 237 | root_dir=cfg.datasets.Polyp.root_dir, 238 | annotation_file=cfg.datasets.Polyp.annotation_file, 239 | transform=transform, 240 | split=cfg.split, 241 | training=True, 242 | if_self_training=cfg.augment, 243 | ) 244 | val_dataloader = DataLoader( 245 | val, 246 | batch_size=cfg.val_batchsize, 247 | shuffle=False, 248 | num_workers=cfg.num_workers, 249 | collate_fn=collate_fn, 250 | ) 251 | train_dataloader = DataLoader( 252 | train, 253 | batch_size=cfg.batch_size, 254 | shuffle=True, 255 | num_workers=cfg.num_workers, 256 | collate_fn=collate_fn, 257 | ) 258 | return train_dataloader, val_dataloader 259 | 260 | 261 | def load_datasets_prompt(cfg, img_size): 262 | transform = ResizeAndPad(img_size) 263 | train = PolypDataset( 264 | cfg, 265 | root_dir=cfg.datasets.Polyp.root_dir, 266 | annotation_file=cfg.datasets.Polyp.annotation_file, 267 | transform=transform, 268 | split=cfg.split, 269 | training=True, 270 | if_self_training=cfg.augment, 271 | ) 272 | train_dataloader = DataLoader( 273 | train, 274 | batch_size=cfg.batch_size, 275 | shuffle=True, 276 | num_workers=cfg.num_workers, 277 | collate_fn=collate_fn_, 278 | ) 279 | return train_dataloader 280 | -------------------------------------------------------------------------------- /datasets/SA_1B.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import json 4 | import random 5 | import numpy as np 6 | import torch 7 | from torch.utils.data import DataLoader 8 | from torch.utils.data import Dataset 9 | from pycocotools.coco import COCO 10 | from pycocotools import mask as mask_utils 11 | from skimage.draw import polygon2mask 12 | from datasets.tools import ResizeAndPad, soft_transform, collate_fn, collate_fn_ 13 | 14 | 15 | class SADataset(Dataset): 16 | def __init__(self, cfg, root_dir, transform=None, training=False): 17 | self.root_dir = root_dir 18 | self.transform = transform 19 | self.image_list = [] 20 | self.sa_list = [ 21 | sa for sa in os.listdir(root_dir) 22 | if os.path.isdir(os.path.join(root_dir, sa)) 23 | ] 24 | 25 | for sa in self.sa_list: 26 | sa_dir = os.path.join(root_dir, sa) 27 | self.image_list.extend( 28 | [ 29 | os.path.join(sa_dir, f) 30 | for f in os.listdir(sa_dir) 31 | if f.endswith(".jpg") 32 | ] 33 | ) 34 | 35 | 
self.image_list = random.sample(self.image_list, 200) 36 | self.if_self_training = training 37 | 38 | def __len__(self): 39 | return len(self.image_list) 40 | 41 | def __getitem__(self, idx): 42 | image_path = self.image_list[idx] 43 | image = cv2.imread(image_path) 44 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 45 | 46 | if self.cfg.get_prompt: 47 | image_info = {} 48 | height, width, _ = image.shape 49 | image_info["file_path"] = image_path 50 | image_info["height"] = height 51 | image_info["width"] = width 52 | return idx, image_info, image 53 | 54 | json_path = image_path.replace(".jpg", ".json") 55 | with open(json_path, "r") as f: 56 | annotations = json.load(f) 57 | 58 | bboxes = [] 59 | masks = [] 60 | categories = [] 61 | for anno in annotations["annotations"]: 62 | x, y, w, h = anno["bbox"] 63 | bboxes.append([x, y, x + w, y + h]) 64 | mask = mask_utils.decode(anno["segmentation"]) 65 | masks.append(mask) 66 | categories.append("0") 67 | 68 | if self.if_self_training: 69 | image_weak, bboxes_weak, masks_weak, image_strong = soft_transform(image, bboxes, masks, categories) 70 | 71 | if self.transform: 72 | image_weak, masks_weak, bboxes_weak = self.transform(image_weak, masks_weak, np.array(bboxes_weak)) 73 | image_strong = self.transform.transform_image(image_strong) 74 | 75 | bboxes_weak = np.stack(bboxes_weak, axis=0) 76 | masks_weak = np.stack(masks_weak, axis=0) 77 | return image_weak, image_strong, torch.tensor(bboxes_weak), torch.tensor(masks_weak).float() 78 | else: 79 | if self.transform: 80 | image, masks, bboxes = self.transform(image, masks, np.array(bboxes)) 81 | 82 | bboxes = np.stack(bboxes, axis=0) 83 | masks = np.stack(masks, axis=0) 84 | return image, torch.tensor(bboxes), torch.tensor(masks).float() 85 | 86 | 87 | class SADatasetwithCoarse(SADataset): 88 | def __getitem__(self, idx): 89 | image_path = self.image_list[idx] 90 | image = cv2.imread(image_path) 91 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 92 | 93 | json_path = image_path.replace(".jpg", ".json") 94 | with open(json_path, "r") as f: 95 | annotations = json.load(f) 96 | 97 | bboxes = [] 98 | masks = [] 99 | categories = [] 100 | for anno in annotations["annotations"]: 101 | mask = mask_utils.decode(anno["segmentation"]) 102 | 103 | contours, hierarchy = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) 104 | num_vertices = 0.05 * cv2.arcLength(contours[0], True) 105 | num_vertices = num_vertices if num_vertices > 5 else 5 106 | approx = cv2.approxPolyDP(contours[0], num_vertices, True) # [x, y] 107 | approx = approx.squeeze(1) 108 | coordinates = np.array(approx) 109 | x_max, x_min = max(coordinates[:, 0]), min(coordinates[:, 0]) 110 | y_max, y_min = max(coordinates[:, 1]), min(coordinates[:, 1]) 111 | coarse_mask = polygon2mask(mask.shape, coordinates).astype(mask.dtype) 112 | if x_min == x_max or y_min == y_max: 113 | x, y, w, h = cv2.boundingRect(mask) 114 | bboxes.append([x, y, x + w, y + h]) 115 | else: 116 | bboxes.append([x_min, y_min, x_max, y_max]) 117 | 118 | masks.append(mask) 119 | categories.append("0") 120 | 121 | if self.if_self_training: 122 | image_weak, bboxes_weak, masks_weak, image_strong = soft_transform(image, bboxes, masks, categories) 123 | 124 | if self.transform: 125 | image_weak, masks_weak, bboxes_weak = self.transform(image_weak, masks_weak, np.array(bboxes_weak)) 126 | image_strong = self.transform.transform_image(image_strong) 127 | 128 | bboxes_weak = np.stack(bboxes_weak, axis=0) 129 | masks_weak = np.stack(masks_weak, axis=0) 130 | return 
image_weak, image_strong, torch.tensor(bboxes_weak), torch.tensor(masks_weak).float() 131 | else: 132 | if self.transform: 133 | image, masks, bboxes = self.transform(image, masks, np.array(bboxes)) 134 | 135 | bboxes = np.stack(bboxes, axis=0) 136 | masks = np.stack(masks, axis=0) 137 | return image, torch.tensor(bboxes), torch.tensor(masks).float() 138 | 139 | 140 | def load_datasets(cfg, img_size): 141 | transform = ResizeAndPad(img_size) 142 | train = SADataset( 143 | cfg, 144 | root_dir=cfg.datasets.sa.root_dir, 145 | transform=transform, 146 | training=True, 147 | if_self_training=cfg.augment, 148 | ) 149 | val = SADataset( 150 | cfg, 151 | root_dir=cfg.datasets.sa.root_dir, 152 | transform=transform, 153 | ) 154 | train_dataloader = DataLoader( 155 | train, 156 | batch_size=cfg.batch_size, 157 | shuffle=True, 158 | num_workers=cfg.num_workers, 159 | collate_fn=collate_fn, 160 | ) 161 | val_dataloader = DataLoader( 162 | val, 163 | batch_size=cfg.batch_size, 164 | shuffle=False, 165 | num_workers=cfg.num_workers, 166 | collate_fn=collate_fn, 167 | ) 168 | return train_dataloader, val_dataloader 169 | 170 | 171 | def load_datasets_coarse(cfg, img_size): 172 | transform = ResizeAndPad(img_size) 173 | train = SADatasetwithCoarse( 174 | cfg, 175 | root_dir=cfg.datasets.sa.root_dir, 176 | transform=transform, 177 | training=True, 178 | if_self_training=cfg.augment, 179 | ) 180 | val = SADatasetwithCoarse( 181 | cfg, 182 | root_dir=cfg.datasets.sa.root_dir, 183 | transform=transform, 184 | ) 185 | train_dataloader = DataLoader( 186 | train, 187 | batch_size=cfg.batch_size, 188 | shuffle=True, 189 | num_workers=cfg.num_workers, 190 | collate_fn=collate_fn, 191 | ) 192 | val_dataloader = DataLoader( 193 | val, 194 | batch_size=cfg.batch_size, 195 | shuffle=False, 196 | num_workers=cfg.num_workers, 197 | collate_fn=collate_fn, 198 | ) 199 | return train_dataloader, val_dataloader 200 | 201 | 202 | def load_datasets(cfg, img_size): 203 | transform = ResizeAndPad(img_size) 204 | train = SADataset( 205 | cfg, 206 | root_dir=cfg.datasets.sa.root_dir, 207 | transform=transform, 208 | training=True, 209 | if_self_training=cfg.augment, 210 | ) 211 | train_dataloader = DataLoader( 212 | train, 213 | batch_size=cfg.batch_size, 214 | shuffle=True, 215 | num_workers=cfg.num_workers, 216 | collate_fn=collate_fn_, 217 | ) 218 | return train_dataloader 219 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | def call_load_dataset(cfg): 2 | name = cfg.dataset 3 | 4 | key = name.split("-")[0] 5 | module_name = f"datasets.{key}" 6 | function_name = "load_datasets" 7 | 8 | if cfg.visual: 9 | function_name = function_name + "_" + "visual" 10 | 11 | if cfg.prompt == "coarse": 12 | function_name = function_name + "_" + "coarse" 13 | 14 | exec(f"from {module_name} import {function_name}") 15 | func = eval(function_name) 16 | return func 17 | 18 | 19 | def call_load_dataset_prompt(cfg): 20 | name = cfg.dataset 21 | 22 | key = name.split("-")[0] 23 | module_name = f"datasets.{key}" 24 | function_name = "load_datasets" 25 | 26 | function_name = function_name + "_" + "prompt" 27 | 28 | exec(f"from {module_name} import {function_name}") 29 | func = eval(function_name) 30 | return func 31 | 32 | 33 | def call_load_dataset_val(cfg): 34 | name = cfg.dataset 35 | 36 | key = name.split("-")[0] 37 | module_name = f"datasets.{key}" 38 | function_name = "load_datasets" 39 | 40 | 
function_name = function_name + "_" + "val" 41 | 42 | exec(f"from {module_name} import {function_name}") 43 | func = eval(function_name) 44 | return func 45 | -------------------------------------------------------------------------------- /datasets/tools.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import torch 4 | import numpy as np 5 | import albumentations as A 6 | import torchvision.transforms as transforms 7 | from segment_anything.utils.transforms import ResizeLongestSide 8 | from imagecorruptions import corrupt, get_corruption_names 9 | 10 | # A.RandomCropNearBBox() 11 | # A.BBoxSafeRandomCrop() 12 | # A.RandomSizedBBoxSafeCrop() 13 | 14 | weak_transforms = A.Compose( 15 | [ 16 | A.Flip(), 17 | A.HorizontalFlip(), 18 | A.VerticalFlip(), 19 | # A.BBoxSafeRandomCrop() 20 | ], 21 | bbox_params=A.BboxParams(format="pascal_voc", label_fields=["category_ids"]), 22 | # keypoint_params=A.KeypointParams(format='xy') 23 | ) 24 | 25 | 26 | strong_transforms = A.Compose( 27 | [ 28 | A.Posterize(), 29 | A.Equalize(), 30 | A.Sharpen(), 31 | A.Solarize(), 32 | A.RandomBrightnessContrast(), 33 | A.RandomShadow(), 34 | ] 35 | ) 36 | 37 | 38 | class ResizeAndPad: 39 | 40 | def __init__(self, target_size): 41 | self.target_size = target_size 42 | self.transform = ResizeLongestSide(target_size) 43 | self.to_tensor = transforms.ToTensor() 44 | 45 | def __call__(self, image, masks, bboxes=None, visual=False): 46 | # Resize image and masks 47 | og_h, og_w, _ = image.shape 48 | image = self.transform.apply_image(image) 49 | masks = [torch.tensor(self.transform.apply_image(mask)) for mask in masks] 50 | image = self.to_tensor(image) 51 | 52 | # Pad image and masks to form a square 53 | _, h, w = image.shape 54 | max_dim = max(w, h) 55 | pad_w = (max_dim - w) // 2 56 | pad_h = (max_dim - h) // 2 57 | 58 | padding = (pad_w, pad_h, max_dim - w - pad_w, max_dim - h - pad_h) 59 | image = transforms.Pad(padding)(image) 60 | masks = [transforms.Pad(padding)(mask) for mask in masks] 61 | 62 | # Adjust bounding boxes 63 | if bboxes is not None: 64 | bboxes = self.transform.apply_boxes(bboxes, (og_h, og_w)) 65 | bboxes = [ 66 | [bbox[0] + pad_w, bbox[1] + pad_h, bbox[2] + pad_w, bbox[3] + pad_h] 67 | for bbox in bboxes 68 | ] 69 | if visual: 70 | return padding, image, masks, bboxes 71 | else: 72 | return image, masks, bboxes 73 | else: 74 | if visual: 75 | return padding, image, masks 76 | else: 77 | return image, masks 78 | 79 | def transform_image(self, image): 80 | # Resize image and masks 81 | image = self.transform.apply_image(image) 82 | image = self.to_tensor(image) 83 | 84 | # Pad image and masks to form a square 85 | _, h, w = image.shape 86 | max_dim = max(w, h) 87 | pad_w = (max_dim - w) // 2 88 | pad_h = (max_dim - h) // 2 89 | 90 | padding = (pad_w, pad_h, max_dim - w - pad_w, max_dim - h - pad_h) 91 | image = transforms.Pad(padding)(image) 92 | return image 93 | 94 | def transform_coord(self, points, image): 95 | og_h, og_w, _ = image.shape 96 | coords = points.reshape(1, -1, 2) 97 | points = self.transform.apply_coords(coords, (og_h, og_w)) 98 | return points.reshape(-1, 2) 99 | 100 | def transform_coords(self, points, image, n): 101 | og_h, og_w, _ = image.shape 102 | coords = points.reshape(-1, n, 2) 103 | points = self.transform.apply_coords(coords, (og_h, og_w)) 104 | return points.reshape(-1, n, 2) 105 | 106 | 107 | def corrupt_image(image, filename): 108 | file_name = os.path.basename(os.path.abspath(filename)) 109 | 
file_path = os.path.dirname(os.path.abspath(filename)) 110 | for corruption in get_corruption_names(): 111 | corrupted = corrupt(image, severity=5, corruption_name=corruption) 112 | corrupt_path = file_path.replace( 113 | "val2017", os.path.join("corruption", corruption) 114 | ) 115 | if not os.path.exists(corrupt_path): 116 | os.makedirs(corrupt_path, exist_ok=True) 117 | cv2.imwrite(os.path.join(corrupt_path, file_name), corrupted) 118 | 119 | 120 | def soft_transform(image: np.ndarray, bboxes: list, masks: list, categories: list): 121 | weak_transformed = weak_transforms( 122 | image=image, bboxes=bboxes, masks=masks, category_ids=categories 123 | ) 124 | image_weak = weak_transformed["image"] 125 | bboxes_weak = weak_transformed["bboxes"] 126 | masks_weak = weak_transformed["masks"] 127 | 128 | strong_transformed = strong_transforms(image=image_weak) 129 | image_strong = strong_transformed["image"] 130 | return image_weak, bboxes_weak, masks_weak, image_strong 131 | 132 | 133 | def soft_transform_all( 134 | image: np.ndarray, bboxes: list, masks: list, points: list, categories: list 135 | ): 136 | weak_transformed = weak_transforms( 137 | image=image, 138 | bboxes=bboxes, 139 | masks=masks, 140 | category_ids=categories, 141 | keypoints=points, 142 | ) 143 | image_weak = weak_transformed["image"] 144 | bboxes_weak = weak_transformed["bboxes"] 145 | masks_weak = weak_transformed["masks"] 146 | keypoints_weak = weak_transformed["keypoints"] 147 | 148 | strong_transformed = strong_transforms(image=image_weak) 149 | image_strong = strong_transformed["image"] 150 | return image_weak, bboxes_weak, masks_weak, keypoints_weak, image_strong 151 | 152 | 153 | def collate_fn(batch): 154 | if len(batch[0]) == 3: 155 | images, bboxes, masks = zip(*batch) 156 | images = torch.stack(images) 157 | return images, bboxes, masks 158 | elif len(batch[0]) == 4: 159 | images_soft, images, bboxes, masks = zip(*batch) 160 | images = torch.stack(images) 161 | images_soft = torch.stack(images_soft) 162 | return images_soft, images, bboxes, masks 163 | else: 164 | raise ValueError("Unexpected batch format") 165 | 166 | 167 | def collate_fn_(batch): 168 | return zip(*batch) 169 | 170 | 171 | def decode_mask(mask): 172 | """ 173 | Convert mask with shape [1, h, w] using 1, 2, 3, ... to represent different objects 174 | to a mask with shape [n, h, w] using a new dimension to represent the number of objects. 175 | 176 | Args: 177 | mask (torch.Tensor): Mask tensor with shape [1, h, w] using 1, 2, 3, ... to represent different objects. 178 | 179 | Returns: 180 | torch.Tensor: Mask tensor with shape [n, h, w] using a new dimension to represent the number of objects. 181 | """ 182 | unique_labels = torch.unique(mask) 183 | unique_labels = unique_labels[unique_labels != 0] 184 | n_objects = len(unique_labels) 185 | new_mask = torch.zeros((n_objects, *mask.shape[1:]), dtype=torch.int64) 186 | for i, label in enumerate(unique_labels): 187 | new_mask[i] = (mask == label).squeeze(0) 188 | return new_mask 189 | 190 | 191 | def encode_mask(mask): 192 | """ 193 | Convert mask with shape [n, h, w] using a new dimension to represent the number of objects 194 | to a mask with shape [1, h, w] using 1, 2, 3, ... to represent different objects. 195 | 196 | Args: 197 | mask (torch.Tensor): Mask tensor with shape [n, h, w] using a new dimension to represent the number of objects. 198 | 199 | Returns: 200 | torch.Tensor: Mask tensor with shape [1, h, w] using 1, 2, 3, ... to represent different objects. 
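
    Example:
        The toy arrays in the __main__ block at the bottom of this file show both
        layouts: the three stacked binary masks in mask_decode encode back to the
        single-channel mask_encode, i.e.

            encode_mask(torch.tensor(mask_decode))  # -> tensor([[[0, 0, 1], [2, 0, 2], [0, 3, 3]]])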
201 | """ 202 | n_objects = mask.shape[0] 203 | new_mask = torch.zeros((1, *mask.shape[1:]), dtype=torch.int64) 204 | for i in range(n_objects): 205 | new_mask[0][mask[i] == 1] = i + 1 206 | return new_mask 207 | 208 | 209 | if __name__ == "__main__": 210 | mask_encode = np.array([[[0, 0, 1], [2, 0, 2], [0, 3, 3]]]) 211 | mask_decode = np.array( 212 | [ 213 | [[0, 0, 1], [0, 0, 0], [0, 0, 0]], 214 | [[0, 0, 0], [1, 0, 1], [0, 0, 0]], 215 | [[0, 0, 0], [0, 0, 0], [0, 1, 1]], 216 | ] 217 | ) 218 | encoded_mask = encode_mask(torch.tensor(mask_decode)) 219 | decoded_mask = decode_mask(torch.tensor(mask_encode)) 220 | -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from abc import ABC 6 | 7 | ALPHA = 0.8 8 | GAMMA = 2 9 | 10 | 11 | class FocalLoss(nn.Module): 12 | 13 | def __init__(self, weight=None, size_average=True): 14 | super().__init__() 15 | 16 | def forward(self, inputs, targets, alpha=ALPHA, gamma=GAMMA, smooth=1): 17 | inputs = F.sigmoid(inputs) 18 | inputs = torch.clamp(inputs, min=0, max=1) 19 | #flatten label and prediction tensors 20 | inputs = inputs.view(-1) 21 | targets = targets.view(-1) 22 | 23 | BCE = F.binary_cross_entropy(inputs, targets, reduction='none') 24 | BCE_EXP = torch.exp(-BCE) 25 | focal_loss = alpha * (1 - BCE_EXP)**gamma * BCE 26 | focal_loss = focal_loss.mean() 27 | 28 | return focal_loss 29 | 30 | 31 | class DiceLoss(nn.Module): 32 | 33 | def __init__(self, weight=None, size_average=True): 34 | super().__init__() 35 | 36 | def forward(self, inputs, targets, smooth=1): 37 | inputs = F.sigmoid(inputs) 38 | inputs = torch.clamp(inputs, min=0, max=1) 39 | #flatten label and prediction tensors 40 | inputs = inputs.view(-1) 41 | targets = targets.view(-1) 42 | 43 | intersection = (inputs * targets).sum() 44 | dice = (2. * intersection + smooth) / (inputs.sum() + targets.sum() + smooth) 45 | 46 | return 1 - dice 47 | 48 | 49 | class ContraLoss(nn.Module): 50 | 51 | def __init__(self, temperature = 0.3, weight=None, size_average=True): 52 | super().__init__() 53 | self.temperature = temperature 54 | self.criterion = torch.nn.CrossEntropyLoss() 55 | 56 | def forward(self, embedd_x: torch.Tensor, embedd_y: torch.Tensor, mask_x: torch.Tensor, mask_y: torch.Tensor): 57 | x_embedding = self.norm_embed(embedd_x) # embedd_x: [256, 64, 64] 58 | y_embedding = self.norm_embed(embedd_y) 59 | 60 | x_masks = F.interpolate(mask_x, size=x_embedding.shape[-2:], mode="bilinear", align_corners=False).detach() 61 | y_masks = F.interpolate(mask_y, size=y_embedding.shape[-2:], mode="bilinear", align_corners=False).detach() 62 | 63 | x_masks = F.sigmoid(x_masks) 64 | x_masks = torch.clamp(x_masks, min=0, max=1) 65 | x_masks = x_masks > 0.5 66 | y_masks = F.sigmoid(y_masks) 67 | y_masks = torch.clamp(y_masks, min=0, max=1) 68 | y_masks = y_masks > 0.5 69 | 70 | # x_masks = self.add_background(x_masks) 71 | # y_masks = self.add_background(y_masks) 72 | 73 | sum_x = x_masks.sum(dim=[-1, -2]).clone() 74 | sum_y = y_masks.sum(dim=[-1, -2]).clone() 75 | sum_x[sum_x[:, 0] == 0.] = 1. 76 | sum_y[sum_y[:, 0] == 0.] = 1. 
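        # (Editor-added explanatory comments for the steps below.)
        # 1) Masked average pooling: each instance mask selects its region of the
        #    [256, 64, 64] embedding, yielding one 256-d vector per instance.
        # 2) Cosine similarity between the n instance vectors of the two views
        #    gives an [n, n] matrix whose diagonal holds the matching pairs.
        # 3) InfoNCE-style objective: after dividing by the temperature,
        #    loss = -log(sum(exp(diagonal)) / sum(exp(all entries))),
        #    which pulls corresponding instances of the two views together.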
77 | 78 | multi_embedd_x = (x_embedding * x_masks).sum(dim=[-1, -2]) / sum_x # [n, 256, 64, 64] >> [n, 256] 79 | multi_embedd_y = (y_embedding * y_masks).sum(dim=[-1, -2]) / sum_y 80 | 81 | flatten_x = multi_embedd_x.view(multi_embedd_x.size(0), -1) # [n, 256] 82 | flatten_y = multi_embedd_y.view(multi_embedd_y.size(0), -1) 83 | # similarity_matrix = torch.matmul(multi_embedd_x, multi_embedd_y.T) 84 | similarity_matrix = F.cosine_similarity(flatten_x.unsqueeze(1), flatten_y.unsqueeze(0), dim=2) # [n, n] 85 | 86 | label_pos = torch.eye(x_masks.size(0)).bool().to(embedd_x.device) 87 | label_nag = ~label_pos 88 | 89 | similarity_matrix = similarity_matrix / self.temperature # [n insts, n insts] 90 | loss = -torch.log( 91 | similarity_matrix.masked_select(label_pos).exp().sum() / 92 | similarity_matrix.exp().sum() 93 | ) 94 | # loss = -torch.log( 95 | # similarity_matrix.masked_select(label_pos).exp().sum() 96 | # ) 97 | return loss 98 | 99 | def norm_embed(self, embedding: torch.Tensor): 100 | embedding = F.normalize(embedding, dim=0, p=2) 101 | return embedding 102 | 103 | def add_background(self, masks): 104 | mask_union = torch.max(masks, dim=0).values 105 | mask_complement = ~mask_union 106 | concatenated_masks = torch.cat((masks, mask_complement.unsqueeze(0)), dim=0) 107 | return concatenated_masks -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator 7 | from sam_lora import LoRA_Sam 8 | 9 | 10 | class Model(nn.Module): 11 | 12 | def __init__(self, cfg): 13 | super().__init__() 14 | self.cfg = cfg 15 | 16 | def get_checkpoint(self, model_type): 17 | if model_type == "vit_b": 18 | checkpoint = os.path.join(self.cfg.model.checkpoint, "sam_vit_b_01ec64.pth") 19 | elif model_type == "vit_l": 20 | checkpoint = os.path.join(self.cfg.model.checkpoint, "sam_vit_l_0b3195.pth") 21 | elif model_type == "vit_h": 22 | checkpoint = os.path.join(self.cfg.model.checkpoint, "sam_vit_h_4b8939.pth") 23 | else: 24 | raise ValueError("Model type error!") 25 | return checkpoint 26 | 27 | def setup(self): 28 | checkpoint = self.get_checkpoint(self.cfg.model.type) 29 | self.model = sam_model_registry[self.cfg.model.type](checkpoint=checkpoint) 30 | 31 | self.model.train() 32 | if self.cfg.model.freeze.image_encoder: 33 | for param in self.model.image_encoder.parameters(): 34 | param.requires_grad = False 35 | if self.cfg.model.freeze.prompt_encoder: 36 | for param in self.model.prompt_encoder.parameters(): 37 | param.requires_grad = False 38 | if self.cfg.model.freeze.mask_decoder: 39 | for param in self.model.mask_decoder.parameters(): 40 | param.requires_grad = False 41 | 42 | # self.finetune() 43 | 44 | def finetune(self): 45 | LoRA_Sam(self.model, 4) 46 | # self.set_norm_layer() 47 | # self.set_evp_adaptor_layer() 48 | # self.set_prompt_layer() 49 | 50 | def set_norm_layer(self): 51 | for name, param in self.model.image_encoder.named_parameters(): 52 | if "norm" in name: 53 | param.requires_grad = True 54 | 55 | def set_evp_adaptor_layer(self): 56 | for param in self.model.image_encoder.prompt_generator.parameters(): 57 | param.requires_grad = True 58 | 59 | def set_prompt_layer(self): 60 | self.model.image_encoder.Prompt_Tokens.requires_grad = True 61 | 62 | def reset_parameters(self) -> 
None: 63 | for name, param in self.model.named_parameters(): 64 | if param.requires_grad == True: 65 | if "linear_a" in name: 66 | nn.init.kaiming_uniform_(param, a=math.sqrt(5)) 67 | if "linear_b" in name: 68 | nn.init.zeros_(param) 69 | 70 | def forward(self, images, prompts): 71 | image_embeddings = self.encode(images) 72 | pred_masks, ious, res_masks = self.decode(prompts, image_embeddings) 73 | return image_embeddings, pred_masks, ious, res_masks 74 | 75 | def encode(self, images): 76 | _, _, H, W = images.shape 77 | self.image_shape = (H, W) 78 | image_embeddings = self.model.image_encoder(images) 79 | return image_embeddings 80 | 81 | def decode(self, prompts, image_embeddings): 82 | pred_masks = [] 83 | ious = [] 84 | res_masks = [] 85 | for prompt, embedding in zip(prompts, image_embeddings): 86 | if isinstance(prompt, torch.Tensor): 87 | prompt = prompt.to(device=embedding.device) 88 | sparse_embeddings, dense_embeddings = self.model.prompt_encoder( 89 | points=None, 90 | boxes=prompt, 91 | masks=None, 92 | ) 93 | elif isinstance(prompt, tuple): 94 | sparse_embeddings, dense_embeddings = self.model.prompt_encoder( 95 | points=prompt, 96 | boxes=None, 97 | masks=None, 98 | ) 99 | 100 | low_res_masks, iou_predictions = self.model.mask_decoder( 101 | image_embeddings=embedding.unsqueeze(0), 102 | image_pe=self.model.prompt_encoder.get_dense_pe(), 103 | sparse_prompt_embeddings=sparse_embeddings, 104 | dense_prompt_embeddings=dense_embeddings, 105 | multimask_output=False, 106 | ) 107 | 108 | masks = F.interpolate( 109 | low_res_masks, 110 | self.image_shape, 111 | mode="bilinear", 112 | align_corners=False, 113 | ) 114 | pred_masks.append(masks.squeeze(1)) 115 | ious.append(iou_predictions) 116 | res_masks.append(low_res_masks) 117 | return pred_masks, ious, res_masks 118 | 119 | def get_predictor(self): 120 | return SamPredictor(self.model) 121 | 122 | def get_generator(self, output_mode): 123 | return SamAutomaticMaskGenerator(self.model, output_mode=output_mode) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | python-box==7.0.1 2 | pycocotools==2.0.6 3 | numpy==1.24.2 4 | opencv_python==4.7.0.72 5 | Pillow==9.3.0 6 | lightning==2.0.1 7 | segmentation-models-pytorch==0.3.2 8 | albumentations==1.3.1 9 | imagecorruptions==1.1.2 10 | safetensors==0.4.1 11 | torchsummary==1.5.1 12 | tensorboard==2.14.0 13 | einops -------------------------------------------------------------------------------- /sam_lora.py: -------------------------------------------------------------------------------- 1 | # Sheng Wang at Apr 6 2023 2 | # What a time to be alive (first half of 2023) 3 | 4 | from segment_anything.modeling import Sam 5 | from segment_anything import build_sam, SamPredictor 6 | from segment_anything import sam_model_registry 7 | 8 | import math 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from torch import Tensor 13 | from torch.nn.parameter import Parameter 14 | from segment_anything.modeling import Sam 15 | from safetensors import safe_open 16 | from safetensors.torch import save_file 17 | from timm.models.vision_transformer import VisionTransformer as timm_ViT 18 | 19 | 20 | class _LoRA_qkv(nn.Module): 21 | """In Sam it is implemented as 22 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 23 | B, N, C = x.shape 24 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) 25 
| q, k, v = qkv.unbind(0) 26 | """ 27 | 28 | def __init__( 29 | self, 30 | qkv: nn.Module, 31 | linear_a_q: nn.Module, 32 | linear_b_q: nn.Module, 33 | linear_a_v: nn.Module, 34 | linear_b_v: nn.Module, 35 | ): 36 | super().__init__() 37 | self.qkv = qkv 38 | self.linear_a_q = linear_a_q 39 | self.linear_b_q = linear_b_q 40 | self.linear_a_v = linear_a_v 41 | self.linear_b_v = linear_b_v 42 | self.dim = qkv.in_features 43 | self.w_identity = torch.eye(qkv.in_features) 44 | 45 | def forward(self, x): 46 | # x: [25, 14, 14, 768]; self.qkv: Linear(in_features=768, out_features=2304, bias=True) 47 | qkv = self.qkv(x) # B,N,N,3*org_C 48 | new_q = self.linear_b_q(self.linear_a_q(x)) 49 | new_v = self.linear_b_v(self.linear_a_v(x)) 50 | qkv[:, :, :, : self.dim] += new_q 51 | qkv[:, :, :, -self.dim :] += new_v 52 | return qkv 53 | 54 | 55 | class LoRA(nn.Module): 56 | def __init__(self, *args, **kwargs) -> None: 57 | super().__init__(*args, **kwargs) 58 | 59 | def save_fc_parameters(self, filename: str) -> None: 60 | r"""Only safetensors is supported now. 61 | 62 | pip install safetensor if you do not have one installed yet. 63 | """ 64 | assert filename.endswith(".safetensors") 65 | _in = self.lora_vit.head.in_features 66 | _out = self.lora_vit.head.out_features 67 | fc_tensors = {f"fc_{_in}in_{_out}out": self.lora_vit.head.weight} 68 | save_file(fc_tensors, filename) 69 | 70 | def load_fc_parameters(self, filename: str) -> None: 71 | r"""Only safetensors is supported now. 72 | 73 | pip install safetensor if you do not have one installed yet. 74 | """ 75 | 76 | assert filename.endswith(".safetensors") 77 | _in = self.lora_vit.head.in_features 78 | _out = self.lora_vit.head.out_features 79 | with safe_open(filename, framework="pt") as f: 80 | saved_key = f"fc_{_in}in_{_out}out" 81 | try: 82 | saved_tensor = f.get_tensor(saved_key) 83 | self.lora_vit.head.weight = Parameter(saved_tensor) 84 | except ValueError: 85 | print("this fc weight is not for this model") 86 | 87 | def save_lora_parameters(self, filename: str) -> None: 88 | r"""Only safetensors is supported now. 89 | 90 | pip install safetensor if you do not have one installed yet. 91 | 92 | save both lora and fc parameters. 93 | """ 94 | 95 | assert filename.endswith(".safetensors") 96 | 97 | num_layer = len(self.w_As) # actually, it is half 98 | a_tensors = {f"w_a_{i:03d}": self.w_As[i].weight for i in range(num_layer)} 99 | b_tensors = {f"w_b_{i:03d}": self.w_Bs[i].weight for i in range(num_layer)} 100 | 101 | _in = self.lora_vit.head.in_features 102 | _out = self.lora_vit.head.out_features 103 | fc_tensors = {f"fc_{_in}in_{_out}out": self.lora_vit.head.weight} 104 | 105 | merged_dict = {**a_tensors, **b_tensors, **fc_tensors} 106 | save_file(merged_dict, filename) 107 | 108 | def load_lora_parameters(self, filename: str) -> None: 109 | r"""Only safetensors is supported now. 110 | 111 | pip install safetensor if you do not have one installed yet.\ 112 | 113 | load both lora and fc parameters. 
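        Editor's note: save_lora_parameters and load_lora_parameters appear to
        assume the wrapped model exposes a classifier ``head`` (as in the timm
        ViT this class was adapted from). The ``Sam`` model wrapped by
        ``LoRA_Sam`` defines no ``head`` attribute, so the fc-related lines here
        would raise an AttributeError unless such a head is attached; only the
        ``w_a_*`` / ``w_b_*`` LoRA tensors apply to SAM.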
114 | """ 115 | 116 | assert filename.endswith(".safetensors") 117 | 118 | with safe_open(filename, framework="pt") as f: 119 | for i, w_A_linear in enumerate(self.w_As): 120 | saved_key = f"w_a_{i:03d}" 121 | saved_tensor = f.get_tensor(saved_key) 122 | w_A_linear.weight = Parameter(saved_tensor) 123 | 124 | for i, w_B_linear in enumerate(self.w_Bs): 125 | saved_key = f"w_b_{i:03d}" 126 | saved_tensor = f.get_tensor(saved_key) 127 | w_B_linear.weight = Parameter(saved_tensor) 128 | 129 | _in = self.lora_vit.head.in_features 130 | _out = self.lora_vit.head.out_features 131 | saved_key = f"fc_{_in}in_{_out}out" 132 | try: 133 | saved_tensor = f.get_tensor(saved_key) 134 | self.lora_vit.head.weight = Parameter(saved_tensor) 135 | except ValueError: 136 | print("this fc weight is not for this model") 137 | 138 | def reset_parameters(self) -> None: 139 | for w_A in self.w_As: 140 | nn.init.kaiming_uniform_(w_A.weight, a=math.sqrt(5)) 141 | for w_B in self.w_Bs: 142 | nn.init.zeros_(w_B.weight) 143 | 144 | 145 | class LoRA_Sam(LoRA): 146 | """Applies low-rank adaptation to a Sam model's image encoder. 147 | 148 | Args: 149 | sam_model: a vision transformer model, see base_vit.py 150 | r: rank of LoRA 151 | num_classes: how many classes the model output, default to the vit model 152 | lora_layer: which layer we apply LoRA. 153 | 154 | Examples:: 155 | >>> model = ViT('B_16_imagenet1k') 156 | >>> lora_model = LoRA_ViT(model, r=4) 157 | >>> preds = lora_model(img) 158 | >>> print(preds.shape) 159 | torch.Size([1, 1000]) 160 | """ 161 | 162 | def __init__(self, sam_model: Sam, r: int, lora_layer=None): 163 | super(LoRA_Sam, self).__init__() 164 | 165 | assert r > 0 166 | # base_vit_dim = sam_model.image_encoder.patch_embed.proj.out_channels 167 | # dim = base_vit_dim 168 | if lora_layer: 169 | self.lora_layer = lora_layer 170 | else: 171 | self.lora_layer = list(range(len(sam_model.image_encoder.blocks))) 172 | # create for storage, then we can init them or load weights 173 | self.w_As = [] # These are linear layers 174 | self.w_Bs = [] 175 | 176 | # lets freeze first 177 | for param in sam_model.image_encoder.parameters(): 178 | param.requires_grad = False 179 | 180 | # Here, we do the surgery 181 | for t_layer_i, blk in enumerate(sam_model.image_encoder.blocks): 182 | # If we only want few lora layer instead of all 183 | if t_layer_i not in self.lora_layer: 184 | continue 185 | w_qkv_linear = blk.attn.qkv 186 | self.dim = w_qkv_linear.in_features 187 | w_a_linear_q = nn.Linear(self.dim, r, bias=False) 188 | w_b_linear_q = nn.Linear(r, self.dim, bias=False) 189 | w_a_linear_v = nn.Linear(self.dim, r, bias=False) 190 | w_b_linear_v = nn.Linear(r, self.dim, bias=False) 191 | self.w_As.append(w_a_linear_q) 192 | self.w_Bs.append(w_b_linear_q) 193 | self.w_As.append(w_a_linear_v) 194 | self.w_Bs.append(w_b_linear_v) 195 | blk.attn.qkv = _LoRA_qkv( 196 | w_qkv_linear, 197 | w_a_linear_q, 198 | w_b_linear_q, 199 | w_a_linear_v, 200 | w_b_linear_v, 201 | ) 202 | self.reset_parameters() 203 | # self.sam = sam_model 204 | self.lora_vit = sam_model 205 | 206 | -------------------------------------------------------------------------------- /segment_anything/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
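# (Editor-added usage sketch; not part of the original file. The checkpoint
# filename and the image/point variables below are illustrative.)
#
#   import numpy as np
#   from segment_anything import sam_model_registry, SamPredictor
#
#   sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
#   predictor = SamPredictor(sam)
#   predictor.set_image(image)                      # HxWx3 uint8 RGB array
#   masks, scores, low_res_logits = predictor.predict(
#       point_coords=np.array([[x, y]]),            # one foreground click
#       point_labels=np.array([1]),
#       multimask_output=True,
#   )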
6 | 7 | from .build_sam import ( 8 | build_sam, 9 | build_sam_vit_h, 10 | build_sam_vit_l, 11 | build_sam_vit_b, 12 | sam_model_registry, 13 | ) 14 | from .predictor import SamPredictor 15 | from .automatic_mask_generator import SamAutomaticMaskGenerator 16 | -------------------------------------------------------------------------------- /segment_anything/build_sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | 9 | from functools import partial 10 | 11 | from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer 12 | 13 | 14 | def build_sam_vit_h(checkpoint=None): 15 | return _build_sam( 16 | encoder_embed_dim=1280, 17 | encoder_depth=32, 18 | encoder_num_heads=16, 19 | encoder_global_attn_indexes=[7, 15, 23, 31], 20 | checkpoint=checkpoint, 21 | ) 22 | 23 | 24 | build_sam = build_sam_vit_h 25 | 26 | 27 | def build_sam_vit_l(checkpoint=None): 28 | return _build_sam( 29 | encoder_embed_dim=1024, 30 | encoder_depth=24, 31 | encoder_num_heads=16, 32 | encoder_global_attn_indexes=[5, 11, 17, 23], 33 | checkpoint=checkpoint, 34 | ) 35 | 36 | 37 | def build_sam_vit_b(checkpoint=None): 38 | return _build_sam( 39 | encoder_embed_dim=768, 40 | encoder_depth=12, 41 | encoder_num_heads=12, 42 | encoder_global_attn_indexes=[2, 5, 8, 11], 43 | checkpoint=checkpoint, 44 | ) 45 | 46 | 47 | sam_model_registry = { 48 | "default": build_sam_vit_h, 49 | "vit_h": build_sam_vit_h, 50 | "vit_l": build_sam_vit_l, 51 | "vit_b": build_sam_vit_b, 52 | } 53 | 54 | 55 | def _build_sam( 56 | encoder_embed_dim, 57 | encoder_depth, 58 | encoder_num_heads, 59 | encoder_global_attn_indexes, 60 | checkpoint=None, 61 | ): 62 | prompt_embed_dim = 256 63 | image_size = 1024 64 | vit_patch_size = 16 65 | image_embedding_size = image_size // vit_patch_size 66 | sam = Sam( 67 | image_encoder=ImageEncoderViT( 68 | depth=encoder_depth, 69 | embed_dim=encoder_embed_dim, 70 | img_size=image_size, 71 | mlp_ratio=4, 72 | norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), 73 | num_heads=encoder_num_heads, 74 | patch_size=vit_patch_size, 75 | qkv_bias=True, 76 | use_rel_pos=True, 77 | global_attn_indexes=encoder_global_attn_indexes, 78 | window_size=14, 79 | out_chans=prompt_embed_dim, 80 | ), 81 | prompt_encoder=PromptEncoder( 82 | embed_dim=prompt_embed_dim, 83 | image_embedding_size=(image_embedding_size, image_embedding_size), 84 | input_image_size=(image_size, image_size), 85 | mask_in_chans=16, 86 | ), 87 | mask_decoder=MaskDecoder( 88 | num_multimask_outputs=3, 89 | transformer=TwoWayTransformer( 90 | depth=2, 91 | embedding_dim=prompt_embed_dim, 92 | mlp_dim=2048, 93 | num_heads=8, 94 | ), 95 | transformer_dim=prompt_embed_dim, 96 | iou_head_depth=3, 97 | iou_head_hidden_dim=256, 98 | ), 99 | pixel_mean=[123.675, 116.28, 103.53], 100 | pixel_std=[58.395, 57.12, 57.375], 101 | ) 102 | sam.eval() 103 | if checkpoint is not None: 104 | with open(checkpoint, "rb") as f: 105 | state_dict = torch.load(f) 106 | sam.load_state_dict(state_dict) 107 | return sam 108 | -------------------------------------------------------------------------------- /segment_anything/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .sam import Sam 8 | from .image_encoder import ImageEncoderViT 9 | from .mask_decoder import MaskDecoder 10 | from .prompt_encoder import PromptEncoder 11 | from .transformer import TwoWayTransformer 12 | -------------------------------------------------------------------------------- /segment_anything/modeling/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | from typing import Type 11 | 12 | 13 | class MLPBlock(nn.Module): 14 | def __init__( 15 | self, 16 | embedding_dim: int, 17 | mlp_dim: int, 18 | act: Type[nn.Module] = nn.GELU, 19 | ) -> None: 20 | super().__init__() 21 | self.lin1 = nn.Linear(embedding_dim, mlp_dim) 22 | self.lin2 = nn.Linear(mlp_dim, embedding_dim) 23 | self.act = act() 24 | 25 | def forward(self, x: torch.Tensor) -> torch.Tensor: 26 | return self.lin2(self.act(self.lin1(x))) 27 | 28 | 29 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa 30 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa 31 | class LayerNorm2d(nn.Module): 32 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None: 33 | super().__init__() 34 | self.weight = nn.Parameter(torch.ones(num_channels)) 35 | self.bias = nn.Parameter(torch.zeros(num_channels)) 36 | self.eps = eps 37 | 38 | def forward(self, x: torch.Tensor) -> torch.Tensor: 39 | u = x.mean(1, keepdim=True) 40 | s = (x - u).pow(2).mean(1, keepdim=True) 41 | x = (x - u) / torch.sqrt(s + self.eps) 42 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 43 | return x 44 | -------------------------------------------------------------------------------- /segment_anything/modeling/mask_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | 11 | from typing import List, Tuple, Type 12 | 13 | from .common import LayerNorm2d 14 | 15 | 16 | class MaskDecoder(nn.Module): 17 | def __init__( 18 | self, 19 | *, 20 | transformer_dim: int, 21 | transformer: nn.Module, 22 | num_multimask_outputs: int = 3, 23 | activation: Type[nn.Module] = nn.GELU, 24 | iou_head_depth: int = 3, 25 | iou_head_hidden_dim: int = 256, 26 | ) -> None: 27 | """ 28 | Predicts masks given an image and prompt embeddings, using a 29 | transformer architecture. 
30 | 31 | Arguments: 32 | transformer_dim (int): the channel dimension of the transformer 33 | transformer (nn.Module): the transformer used to predict masks 34 | num_multimask_outputs (int): the number of masks to predict 35 | when disambiguating masks 36 | activation (nn.Module): the type of activation to use when 37 | upscaling masks 38 | iou_head_depth (int): the depth of the MLP used to predict 39 | mask quality 40 | iou_head_hidden_dim (int): the hidden dimension of the MLP 41 | used to predict mask quality 42 | """ 43 | super().__init__() 44 | self.transformer_dim = transformer_dim 45 | self.transformer = transformer 46 | 47 | self.num_multimask_outputs = num_multimask_outputs 48 | 49 | self.iou_token = nn.Embedding(1, transformer_dim) 50 | self.num_mask_tokens = num_multimask_outputs + 1 51 | self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) 52 | 53 | self.output_upscaling = nn.Sequential( 54 | nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), 55 | LayerNorm2d(transformer_dim // 4), 56 | activation(), 57 | nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), 58 | activation(), 59 | ) 60 | self.output_hypernetworks_mlps = nn.ModuleList( 61 | [ 62 | MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) 63 | for i in range(self.num_mask_tokens) 64 | ] 65 | ) 66 | 67 | self.iou_prediction_head = MLP( 68 | transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth 69 | ) 70 | 71 | def forward( 72 | self, 73 | image_embeddings: torch.Tensor, 74 | image_pe: torch.Tensor, 75 | sparse_prompt_embeddings: torch.Tensor, 76 | dense_prompt_embeddings: torch.Tensor, 77 | multimask_output: bool, 78 | ) -> Tuple[torch.Tensor, torch.Tensor]: 79 | """ 80 | Predict masks given image and prompt embeddings. 81 | 82 | Arguments: 83 | image_embeddings (torch.Tensor): the embeddings from the image encoder 84 | image_pe (torch.Tensor): positional encoding with the shape of image_embeddings 85 | sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes 86 | dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs 87 | multimask_output (bool): Whether to return multiple masks or a single 88 | mask. 89 | 90 | Returns: 91 | torch.Tensor: batched predicted masks 92 | torch.Tensor: batched predictions of mask quality 93 | """ 94 | masks, iou_pred = self.predict_masks( 95 | image_embeddings=image_embeddings, 96 | image_pe=image_pe, 97 | sparse_prompt_embeddings=sparse_prompt_embeddings, 98 | dense_prompt_embeddings=dense_prompt_embeddings, 99 | ) 100 | 101 | # Select the correct mask or masks for output 102 | if multimask_output: 103 | mask_slice = slice(1, None) 104 | else: 105 | mask_slice = slice(0, 1) 106 | masks = masks[:, mask_slice, :, :] 107 | iou_pred = iou_pred[:, mask_slice] 108 | 109 | # Prepare output 110 | return masks, iou_pred 111 | 112 | def predict_masks( 113 | self, 114 | image_embeddings: torch.Tensor, 115 | image_pe: torch.Tensor, 116 | sparse_prompt_embeddings: torch.Tensor, 117 | dense_prompt_embeddings: torch.Tensor, 118 | ) -> Tuple[torch.Tensor, torch.Tensor]: 119 | """Predicts masks. 
See 'forward' for more details.""" 120 | # Concatenate output tokens 121 | output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) 122 | output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1) 123 | tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) 124 | 125 | # Expand per-image data in batch direction to be per-mask 126 | src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) 127 | src = src + dense_prompt_embeddings 128 | pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) 129 | b, c, h, w = src.shape 130 | 131 | # Run the transformer 132 | hs, src = self.transformer(src, pos_src, tokens) 133 | iou_token_out = hs[:, 0, :] 134 | mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :] 135 | 136 | # Upscale mask embeddings and predict masks using the mask tokens 137 | src = src.transpose(1, 2).view(b, c, h, w) 138 | upscaled_embedding = self.output_upscaling(src) 139 | hyper_in_list: List[torch.Tensor] = [] 140 | for i in range(self.num_mask_tokens): 141 | hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])) 142 | hyper_in = torch.stack(hyper_in_list, dim=1) 143 | b, c, h, w = upscaled_embedding.shape 144 | masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) 145 | 146 | # Generate mask quality predictions 147 | iou_pred = self.iou_prediction_head(iou_token_out) 148 | 149 | return masks, iou_pred 150 | 151 | 152 | # Lightly adapted from 153 | # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa 154 | class MLP(nn.Module): 155 | def __init__( 156 | self, 157 | input_dim: int, 158 | hidden_dim: int, 159 | output_dim: int, 160 | num_layers: int, 161 | sigmoid_output: bool = False, 162 | ) -> None: 163 | super().__init__() 164 | self.num_layers = num_layers 165 | h = [hidden_dim] * (num_layers - 1) 166 | self.layers = nn.ModuleList( 167 | nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) 168 | ) 169 | self.sigmoid_output = sigmoid_output 170 | 171 | def forward(self, x): 172 | for i, layer in enumerate(self.layers): 173 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) 174 | if self.sigmoid_output: 175 | x = F.sigmoid(x) 176 | return x 177 | -------------------------------------------------------------------------------- /segment_anything/modeling/prompt_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | from torch import nn 10 | 11 | from typing import Any, Optional, Tuple, Type 12 | 13 | from .common import LayerNorm2d 14 | 15 | 16 | class PromptEncoder(nn.Module): 17 | def __init__( 18 | self, 19 | embed_dim: int, 20 | image_embedding_size: Tuple[int, int], 21 | input_image_size: Tuple[int, int], 22 | mask_in_chans: int, 23 | activation: Type[nn.Module] = nn.GELU, 24 | ) -> None: 25 | """ 26 | Encodes prompts for input to SAM's mask decoder. 27 | 28 | Arguments: 29 | embed_dim (int): The prompts' embedding dimension 30 | image_embedding_size (tuple(int, int)): The spatial size of the 31 | image embedding, as (H, W). 32 | input_image_size (int): The padded size of the image as input 33 | to the image encoder, as (H, W). 
34 | mask_in_chans (int): The number of hidden channels used for 35 | encoding input masks. 36 | activation (nn.Module): The activation to use when encoding 37 | input masks. 38 | """ 39 | super().__init__() 40 | self.embed_dim = embed_dim 41 | self.input_image_size = input_image_size 42 | self.image_embedding_size = image_embedding_size 43 | self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) 44 | 45 | self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners 46 | point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)] 47 | self.point_embeddings = nn.ModuleList(point_embeddings) 48 | self.not_a_point_embed = nn.Embedding(1, embed_dim) 49 | 50 | self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1]) 51 | self.mask_downscaling = nn.Sequential( 52 | nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), 53 | LayerNorm2d(mask_in_chans // 4), 54 | activation(), 55 | nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), 56 | LayerNorm2d(mask_in_chans), 57 | activation(), 58 | nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), 59 | ) 60 | self.no_mask_embed = nn.Embedding(1, embed_dim) 61 | 62 | def get_dense_pe(self) -> torch.Tensor: 63 | """ 64 | Returns the positional encoding used to encode point prompts, 65 | applied to a dense set of points the shape of the image encoding. 66 | 67 | Returns: 68 | torch.Tensor: Positional encoding with shape 69 | 1x(embed_dim)x(embedding_h)x(embedding_w) 70 | """ 71 | return self.pe_layer(self.image_embedding_size).unsqueeze(0) 72 | 73 | def _embed_points( 74 | self, 75 | points: torch.Tensor, 76 | labels: torch.Tensor, 77 | pad: bool, 78 | ) -> torch.Tensor: 79 | """Embeds point prompts.""" 80 | points = points + 0.5 # Shift to center of pixel 81 | if pad: 82 | padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) 83 | padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) 84 | points = torch.cat([points, padding_point], dim=1) 85 | labels = torch.cat([labels, padding_label], dim=1) 86 | point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size) 87 | point_embedding[labels == -1] = 0.0 88 | point_embedding[labels == -1] += self.not_a_point_embed.weight 89 | point_embedding[labels == 0] += self.point_embeddings[0].weight 90 | point_embedding[labels == 1] += self.point_embeddings[1].weight 91 | return point_embedding 92 | 93 | def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: 94 | """Embeds box prompts.""" 95 | boxes = boxes + 0.5 # Shift to center of pixel 96 | coords = boxes.reshape(-1, 2, 2) 97 | corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size) 98 | corner_embedding[:, 0, :] += self.point_embeddings[2].weight 99 | corner_embedding[:, 1, :] += self.point_embeddings[3].weight 100 | return corner_embedding 101 | 102 | def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: 103 | """Embeds mask inputs.""" 104 | mask_embedding = self.mask_downscaling(masks) 105 | return mask_embedding 106 | 107 | def _get_batch_size( 108 | self, 109 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 110 | boxes: Optional[torch.Tensor], 111 | masks: Optional[torch.Tensor], 112 | ) -> int: 113 | """ 114 | Gets the batch size of the output given the batch size of the input prompts. 
115 | """ 116 | if points is not None: 117 | return points[0].shape[0] 118 | elif boxes is not None: 119 | return boxes.shape[0] 120 | elif masks is not None: 121 | return masks.shape[0] 122 | else: 123 | return 1 124 | 125 | def _get_device(self) -> torch.device: 126 | return self.point_embeddings[0].weight.device 127 | 128 | def forward( 129 | self, 130 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 131 | boxes: Optional[torch.Tensor], 132 | masks: Optional[torch.Tensor], 133 | ) -> Tuple[torch.Tensor, torch.Tensor]: 134 | """ 135 | Embeds different types of prompts, returning both sparse and dense 136 | embeddings. 137 | 138 | Arguments: 139 | points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates 140 | and labels to embed. 141 | boxes (torch.Tensor or none): boxes to embed 142 | masks (torch.Tensor or none): masks to embed 143 | 144 | Returns: 145 | torch.Tensor: sparse embeddings for the points and boxes, with shape 146 | BxNx(embed_dim), where N is determined by the number of input points 147 | and boxes. 148 | torch.Tensor: dense embeddings for the masks, in the shape 149 | Bx(embed_dim)x(embed_H)x(embed_W) 150 | """ 151 | bs = self._get_batch_size(points, boxes, masks) 152 | sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device()) 153 | if points is not None: 154 | coords, labels = points 155 | point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) 156 | sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) 157 | if boxes is not None: 158 | box_embeddings = self._embed_boxes(boxes) 159 | sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) 160 | 161 | if masks is not None: 162 | dense_embeddings = self._embed_masks(masks) 163 | else: 164 | dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( 165 | bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] 166 | ) 167 | 168 | return sparse_embeddings, dense_embeddings 169 | 170 | 171 | class PositionEmbeddingRandom(nn.Module): 172 | """ 173 | Positional encoding using random spatial frequencies. 174 | """ 175 | 176 | def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: 177 | super().__init__() 178 | if scale is None or scale <= 0.0: 179 | scale = 1.0 180 | self.register_buffer( 181 | "positional_encoding_gaussian_matrix", 182 | scale * torch.randn((2, num_pos_feats)), 183 | ) 184 | 185 | def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: 186 | """Positionally encode points that are normalized to [0,1].""" 187 | # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape 188 | coords = 2 * coords - 1 189 | coords = coords @ self.positional_encoding_gaussian_matrix 190 | coords = 2 * np.pi * coords 191 | # outputs d_1 x ... 
x d_n x C shape 192 | return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) 193 | 194 | def forward(self, size: Tuple[int, int]) -> torch.Tensor: 195 | """Generate positional encoding for a grid of the specified size.""" 196 | h, w = size 197 | device: Any = self.positional_encoding_gaussian_matrix.device 198 | grid = torch.ones((h, w), device=device, dtype=torch.float32) 199 | y_embed = grid.cumsum(dim=0) - 0.5 200 | x_embed = grid.cumsum(dim=1) - 0.5 201 | y_embed = y_embed / h 202 | x_embed = x_embed / w 203 | 204 | pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) 205 | return pe.permute(2, 0, 1) # C x H x W 206 | 207 | def forward_with_coords( 208 | self, coords_input: torch.Tensor, image_size: Tuple[int, int] 209 | ) -> torch.Tensor: 210 | """Positionally encode points that are not normalized to [0,1].""" 211 | coords = coords_input.clone() 212 | coords[:, :, 0] = coords[:, :, 0] / image_size[1] 213 | coords[:, :, 1] = coords[:, :, 1] / image_size[0] 214 | return self._pe_encoding(coords.to(torch.float)) # B x N x C 215 | -------------------------------------------------------------------------------- /segment_anything/modeling/sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | 11 | from typing import Any, Dict, List, Tuple 12 | 13 | from .image_encoder import ImageEncoderViT 14 | from .mask_decoder import MaskDecoder 15 | from .prompt_encoder import PromptEncoder 16 | 17 | 18 | class Sam(nn.Module): 19 | mask_threshold: float = 0.0 20 | image_format: str = "RGB" 21 | 22 | def __init__( 23 | self, 24 | image_encoder: ImageEncoderViT, 25 | prompt_encoder: PromptEncoder, 26 | mask_decoder: MaskDecoder, 27 | pixel_mean: List[float] = [123.675, 116.28, 103.53], 28 | pixel_std: List[float] = [58.395, 57.12, 57.375], 29 | ) -> None: 30 | """ 31 | SAM predicts object masks from an image and input prompts. 32 | 33 | Arguments: 34 | image_encoder (ImageEncoderViT): The backbone used to encode the 35 | image into image embeddings that allow for efficient mask prediction. 36 | prompt_encoder (PromptEncoder): Encodes various types of input prompts. 37 | mask_decoder (MaskDecoder): Predicts masks from the image embeddings 38 | and encoded prompts. 39 | pixel_mean (list(float)): Mean values for normalizing pixels in the input image. 40 | pixel_std (list(float)): Std values for normalizing pixels in the input image. 41 | """ 42 | super().__init__() 43 | self.image_encoder = image_encoder 44 | self.prompt_encoder = prompt_encoder 45 | self.mask_decoder = mask_decoder 46 | self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) 47 | self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) 48 | 49 | @property 50 | def device(self) -> Any: 51 | return self.pixel_mean.device 52 | 53 | @torch.no_grad() 54 | def forward( 55 | self, 56 | batched_input: List[Dict[str, Any]], 57 | multimask_output: bool, 58 | ) -> List[Dict[str, torch.Tensor]]: 59 | """ 60 | Predicts masks end-to-end from provided images and prompts. 61 | If prompts are not known in advance, using SamPredictor is 62 | recommended over calling the model directly. 
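        Editor-added call sketch (variable names are illustrative): with the
        image already resized by ResizeLongestSide and converted to a 3xHxW
        tensor, and boxes given in the model's input frame,

            batched_input = [{
                "image": image,
                "original_size": (orig_h, orig_w),
                "boxes": boxes,                      # Bx4
            }]
            outputs = sam(batched_input, multimask_output=False)
            masks = outputs[0]["masks"]              # Bx1xHxW booleans at original size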
63 | 64 | Arguments: 65 | batched_input (list(dict)): A list over input images, each a 66 | dictionary with the following keys. A prompt key can be 67 | excluded if it is not present. 68 | 'image': The image as a torch tensor in 3xHxW format, 69 | already transformed for input to the model. 70 | 'original_size': (tuple(int, int)) The original size of 71 | the image before transformation, as (H, W). 72 | 'point_coords': (torch.Tensor) Batched point prompts for 73 | this image, with shape BxNx2. Already transformed to the 74 | input frame of the model. 75 | 'point_labels': (torch.Tensor) Batched labels for point prompts, 76 | with shape BxN. 77 | 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. 78 | Already transformed to the input frame of the model. 79 | 'mask_inputs': (torch.Tensor) Batched mask inputs to the model, 80 | in the form Bx1xHxW. 81 | multimask_output (bool): Whether the model should predict multiple 82 | disambiguating masks, or return a single mask. 83 | 84 | Returns: 85 | (list(dict)): A list over input images, where each element is 86 | as dictionary with the following keys. 87 | 'masks': (torch.Tensor) Batched binary mask predictions, 88 | with shape BxCxHxW, where B is the number of input prompts, 89 | C is determined by multimask_output, and (H, W) is the 90 | original size of the image. 91 | 'iou_predictions': (torch.Tensor) The model's predictions 92 | of mask quality, in shape BxC. 93 | 'low_res_logits': (torch.Tensor) Low resolution logits with 94 | shape BxCxHxW, where H=W=256. Can be passed as mask input 95 | to subsequent iterations of prediction. 96 | """ 97 | input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0) 98 | image_embeddings = self.image_encoder(input_images) 99 | 100 | outputs = [] 101 | for image_record, curr_embedding in zip(batched_input, image_embeddings): 102 | if "point_coords" in image_record: 103 | points = (image_record["point_coords"], image_record["point_labels"]) 104 | else: 105 | points = None 106 | sparse_embeddings, dense_embeddings = self.prompt_encoder( 107 | points=points, 108 | boxes=image_record.get("boxes", None), 109 | masks=image_record.get("mask_inputs", None), 110 | ) 111 | low_res_masks, iou_predictions = self.mask_decoder( 112 | image_embeddings=curr_embedding.unsqueeze(0), 113 | image_pe=self.prompt_encoder.get_dense_pe(), 114 | sparse_prompt_embeddings=sparse_embeddings, 115 | dense_prompt_embeddings=dense_embeddings, 116 | multimask_output=multimask_output, 117 | ) 118 | masks = self.postprocess_masks( 119 | low_res_masks, 120 | input_size=image_record["image"].shape[-2:], 121 | original_size=image_record["original_size"], 122 | ) 123 | masks = masks > self.mask_threshold 124 | outputs.append( 125 | { 126 | "masks": masks, 127 | "iou_predictions": iou_predictions, 128 | "low_res_logits": low_res_masks, 129 | } 130 | ) 131 | return outputs 132 | 133 | def postprocess_masks( 134 | self, 135 | masks: torch.Tensor, 136 | input_size: Tuple[int, ...], 137 | original_size: Tuple[int, ...], 138 | ) -> torch.Tensor: 139 | """ 140 | Remove padding and upscale masks to the original image size. 141 | 142 | Arguments: 143 | masks (torch.Tensor): Batched masks from the mask_decoder, 144 | in BxCxHxW format. 145 | input_size (tuple(int, int)): The size of the image input to the 146 | model, in (H, W) format. Used to remove padding. 147 | original_size (tuple(int, int)): The original size of the image 148 | before resizing for input to the model, in (H, W) format. 
149 | 150 | Returns: 151 | (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) 152 | is given by original_size. 153 | """ 154 | masks = F.interpolate( 155 | masks, 156 | (self.image_encoder.img_size, self.image_encoder.img_size), 157 | mode="bilinear", 158 | align_corners=False, 159 | ) 160 | masks = masks[..., : input_size[0], : input_size[1]] 161 | masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False) 162 | return masks 163 | 164 | def preprocess(self, x: torch.Tensor) -> torch.Tensor: 165 | """Normalize pixel values and pad to a square input.""" 166 | # Normalize colors 167 | x = (x - self.pixel_mean) / self.pixel_std 168 | 169 | # Pad 170 | h, w = x.shape[-2:] 171 | padh = self.image_encoder.img_size - h 172 | padw = self.image_encoder.img_size - w 173 | x = F.pad(x, (0, padw, 0, padh)) 174 | return x 175 | -------------------------------------------------------------------------------- /segment_anything/modeling/transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import Tensor, nn 9 | 10 | import math 11 | from typing import Tuple, Type 12 | 13 | from .common import MLPBlock 14 | 15 | 16 | class TwoWayTransformer(nn.Module): 17 | def __init__( 18 | self, 19 | depth: int, 20 | embedding_dim: int, 21 | num_heads: int, 22 | mlp_dim: int, 23 | activation: Type[nn.Module] = nn.ReLU, 24 | attention_downsample_rate: int = 2, 25 | ) -> None: 26 | """ 27 | A transformer decoder that attends to an input image using 28 | queries whose positional embedding is supplied. 29 | 30 | Args: 31 | depth (int): number of layers in the transformer 32 | embedding_dim (int): the channel dimension for the input embeddings 33 | num_heads (int): the number of heads for multihead attention. Must 34 | divide embedding_dim 35 | mlp_dim (int): the channel dimension internal to the MLP block 36 | activation (nn.Module): the activation to use in the MLP block 37 | """ 38 | super().__init__() 39 | self.depth = depth 40 | self.embedding_dim = embedding_dim 41 | self.num_heads = num_heads 42 | self.mlp_dim = mlp_dim 43 | self.layers = nn.ModuleList() 44 | 45 | for i in range(depth): 46 | self.layers.append( 47 | TwoWayAttentionBlock( 48 | embedding_dim=embedding_dim, 49 | num_heads=num_heads, 50 | mlp_dim=mlp_dim, 51 | activation=activation, 52 | attention_downsample_rate=attention_downsample_rate, 53 | skip_first_layer_pe=(i == 0), 54 | ) 55 | ) 56 | 57 | self.final_attn_token_to_image = Attention( 58 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 59 | ) 60 | self.norm_final_attn = nn.LayerNorm(embedding_dim) 61 | 62 | def forward( 63 | self, 64 | image_embedding: Tensor, 65 | image_pe: Tensor, 66 | point_embedding: Tensor, 67 | ) -> Tuple[Tensor, Tensor]: 68 | """ 69 | Args: 70 | image_embedding (torch.Tensor): image to attend to. Should be shape 71 | B x embedding_dim x h x w for any h and w. 72 | image_pe (torch.Tensor): the positional encoding to add to the image. Must 73 | have the same shape as image_embedding. 74 | point_embedding (torch.Tensor): the embedding to add to the query points. 75 | Must have shape B x N_points x embedding_dim for any N_points. 
76 | 77 | Returns: 78 | torch.Tensor: the processed point_embedding 79 | torch.Tensor: the processed image_embedding 80 | """ 81 | # BxCxHxW -> BxHWxC == B x N_image_tokens x C 82 | bs, c, h, w = image_embedding.shape 83 | image_embedding = image_embedding.flatten(2).permute(0, 2, 1) 84 | image_pe = image_pe.flatten(2).permute(0, 2, 1) 85 | 86 | # Prepare queries 87 | queries = point_embedding 88 | keys = image_embedding 89 | 90 | # Apply transformer blocks and final layernorm 91 | for layer in self.layers: 92 | queries, keys = layer( 93 | queries=queries, 94 | keys=keys, 95 | query_pe=point_embedding, 96 | key_pe=image_pe, 97 | ) 98 | 99 | # Apply the final attention layer from the points to the image 100 | q = queries + point_embedding 101 | k = keys + image_pe 102 | attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) 103 | queries = queries + attn_out 104 | queries = self.norm_final_attn(queries) 105 | 106 | return queries, keys 107 | 108 | 109 | class TwoWayAttentionBlock(nn.Module): 110 | def __init__( 111 | self, 112 | embedding_dim: int, 113 | num_heads: int, 114 | mlp_dim: int = 2048, 115 | activation: Type[nn.Module] = nn.ReLU, 116 | attention_downsample_rate: int = 2, 117 | skip_first_layer_pe: bool = False, 118 | ) -> None: 119 | """ 120 | A transformer block with four layers: (1) self-attention of sparse 121 | inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp 122 | block on sparse inputs, and (4) cross attention of dense inputs to sparse 123 | inputs. 124 | 125 | Arguments: 126 | embedding_dim (int): the channel dimension of the embeddings 127 | num_heads (int): the number of heads in the attention layers 128 | mlp_dim (int): the hidden dimension of the mlp block 129 | activation (nn.Module): the activation of the mlp block 130 | skip_first_layer_pe (bool): skip the PE on the first layer 131 | """ 132 | super().__init__() 133 | self.self_attn = Attention(embedding_dim, num_heads) 134 | self.norm1 = nn.LayerNorm(embedding_dim) 135 | 136 | self.cross_attn_token_to_image = Attention( 137 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 138 | ) 139 | self.norm2 = nn.LayerNorm(embedding_dim) 140 | 141 | self.mlp = MLPBlock(embedding_dim, mlp_dim, activation) 142 | self.norm3 = nn.LayerNorm(embedding_dim) 143 | 144 | self.norm4 = nn.LayerNorm(embedding_dim) 145 | self.cross_attn_image_to_token = Attention( 146 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 147 | ) 148 | 149 | self.skip_first_layer_pe = skip_first_layer_pe 150 | 151 | def forward( 152 | self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor 153 | ) -> Tuple[Tensor, Tensor]: 154 | # Self attention block 155 | if self.skip_first_layer_pe: 156 | queries = self.self_attn(q=queries, k=queries, v=queries) 157 | else: 158 | q = queries + query_pe 159 | attn_out = self.self_attn(q=q, k=q, v=queries) 160 | queries = queries + attn_out 161 | queries = self.norm1(queries) 162 | 163 | # Cross attention block, tokens attending to image embedding 164 | q = queries + query_pe 165 | k = keys + key_pe 166 | attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) 167 | queries = queries + attn_out 168 | queries = self.norm2(queries) 169 | 170 | # MLP block 171 | mlp_out = self.mlp(queries) 172 | queries = queries + mlp_out 173 | queries = self.norm3(queries) 174 | 175 | # Cross attention block, image embedding attending to tokens 176 | q = queries + query_pe 177 | k = keys + key_pe 178 | attn_out = self.cross_attn_image_to_token(q=k, k=q, 
v=queries) 179 | keys = keys + attn_out 180 | keys = self.norm4(keys) 181 | 182 | return queries, keys 183 | 184 | 185 | class Attention(nn.Module): 186 | """ 187 | An attention layer that allows for downscaling the size of the embedding 188 | after projection to queries, keys, and values. 189 | """ 190 | 191 | def __init__( 192 | self, 193 | embedding_dim: int, 194 | num_heads: int, 195 | downsample_rate: int = 1, 196 | ) -> None: 197 | super().__init__() 198 | self.embedding_dim = embedding_dim 199 | self.internal_dim = embedding_dim // downsample_rate 200 | self.num_heads = num_heads 201 | assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim." 202 | 203 | self.q_proj = nn.Linear(embedding_dim, self.internal_dim) 204 | self.k_proj = nn.Linear(embedding_dim, self.internal_dim) 205 | self.v_proj = nn.Linear(embedding_dim, self.internal_dim) 206 | self.out_proj = nn.Linear(self.internal_dim, embedding_dim) 207 | 208 | def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: 209 | b, n, c = x.shape 210 | x = x.reshape(b, n, num_heads, c // num_heads) 211 | return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head 212 | 213 | def _recombine_heads(self, x: Tensor) -> Tensor: 214 | b, n_heads, n_tokens, c_per_head = x.shape 215 | x = x.transpose(1, 2) 216 | return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C 217 | 218 | def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: 219 | # Input projections 220 | q = self.q_proj(q) 221 | k = self.k_proj(k) 222 | v = self.v_proj(v) 223 | 224 | # Separate into heads 225 | q = self._separate_heads(q, self.num_heads) 226 | k = self._separate_heads(k, self.num_heads) 227 | v = self._separate_heads(v, self.num_heads) 228 | 229 | # Attention 230 | _, _, _, c_per_head = q.shape 231 | attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens 232 | attn = attn / math.sqrt(c_per_head) 233 | attn = torch.softmax(attn, dim=-1) 234 | 235 | # Get output 236 | out = attn @ v 237 | out = self._recombine_heads(out) 238 | out = self.out_proj(out) 239 | 240 | return out 241 | -------------------------------------------------------------------------------- /segment_anything/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /segment_anything/utils/onnx.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch.nn import functional as F 10 | 11 | from typing import Tuple 12 | 13 | from ..modeling import Sam 14 | from .amg import calculate_stability_score 15 | 16 | 17 | class SamOnnxModel(nn.Module): 18 | """ 19 | This model should not be called directly, but is used in ONNX export. 20 | It combines the prompt encoder, mask decoder, and mask postprocessing of Sam, 21 | with some functions modified to enable model tracing. Also supports extra 22 | options controlling what information. See the ONNX export script for details. 
23 | """ 24 | 25 | def __init__( 26 | self, 27 | model: Sam, 28 | return_single_mask: bool, 29 | use_stability_score: bool = False, 30 | return_extra_metrics: bool = False, 31 | ) -> None: 32 | super().__init__() 33 | self.mask_decoder = model.mask_decoder 34 | self.model = model 35 | self.img_size = model.image_encoder.img_size 36 | self.return_single_mask = return_single_mask 37 | self.use_stability_score = use_stability_score 38 | self.stability_score_offset = 1.0 39 | self.return_extra_metrics = return_extra_metrics 40 | 41 | @staticmethod 42 | def resize_longest_image_size( 43 | input_image_size: torch.Tensor, longest_side: int 44 | ) -> torch.Tensor: 45 | input_image_size = input_image_size.to(torch.float32) 46 | scale = longest_side / torch.max(input_image_size) 47 | transformed_size = scale * input_image_size 48 | transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64) 49 | return transformed_size 50 | 51 | def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor: 52 | point_coords = point_coords + 0.5 53 | point_coords = point_coords / self.img_size 54 | point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords) 55 | point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding) 56 | 57 | point_embedding = point_embedding * (point_labels != -1) 58 | point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * ( 59 | point_labels == -1 60 | ) 61 | 62 | for i in range(self.model.prompt_encoder.num_point_embeddings): 63 | point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[ 64 | i 65 | ].weight * (point_labels == i) 66 | 67 | return point_embedding 68 | 69 | def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor: 70 | mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask) 71 | mask_embedding = mask_embedding + ( 72 | 1 - has_mask_input 73 | ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1) 74 | return mask_embedding 75 | 76 | def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor: 77 | masks = F.interpolate( 78 | masks, 79 | size=(self.img_size, self.img_size), 80 | mode="bilinear", 81 | align_corners=False, 82 | ) 83 | 84 | prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(torch.int64) 85 | masks = masks[..., : prepadded_size[0], : prepadded_size[1]] # type: ignore 86 | 87 | orig_im_size = orig_im_size.to(torch.int64) 88 | h, w = orig_im_size[0], orig_im_size[1] 89 | masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False) 90 | return masks 91 | 92 | def select_masks( 93 | self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int 94 | ) -> Tuple[torch.Tensor, torch.Tensor]: 95 | # Determine if we should return the multiclick mask or not from the number of points. 96 | # The reweighting is used to avoid control flow. 
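        # (Editor-added explanation.) score_reweight biases only the single-mask
        # token (index 0): when (num_points - 2.5) is positive the bias becomes a
        # large bonus, so argmax selects the single-output mask; when it is
        # negative the bias becomes a large penalty, so one of the multimask
        # outputs is selected instead. This replaces an if/else on num_points
        # with arithmetic that traces cleanly for ONNX export.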
97 | score_reweight = torch.tensor( 98 | [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)] 99 | ).to(iou_preds.device) 100 | score = iou_preds + (num_points - 2.5) * score_reweight 101 | best_idx = torch.argmax(score, dim=1) 102 | masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1) 103 | iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1) 104 | 105 | return masks, iou_preds 106 | 107 | @torch.no_grad() 108 | def forward( 109 | self, 110 | image_embeddings: torch.Tensor, 111 | point_coords: torch.Tensor, 112 | point_labels: torch.Tensor, 113 | mask_input: torch.Tensor, 114 | has_mask_input: torch.Tensor, 115 | orig_im_size: torch.Tensor, 116 | ): 117 | sparse_embedding = self._embed_points(point_coords, point_labels) 118 | dense_embedding = self._embed_masks(mask_input, has_mask_input) 119 | 120 | masks, scores = self.model.mask_decoder.predict_masks( 121 | image_embeddings=image_embeddings, 122 | image_pe=self.model.prompt_encoder.get_dense_pe(), 123 | sparse_prompt_embeddings=sparse_embedding, 124 | dense_prompt_embeddings=dense_embedding, 125 | ) 126 | 127 | if self.use_stability_score: 128 | scores = calculate_stability_score( 129 | masks, self.model.mask_threshold, self.stability_score_offset 130 | ) 131 | 132 | if self.return_single_mask: 133 | masks, scores = self.select_masks(masks, scores, point_coords.shape[1]) 134 | 135 | upscaled_masks = self.mask_postprocessing(masks, orig_im_size) 136 | 137 | if self.return_extra_metrics: 138 | stability_scores = calculate_stability_score( 139 | upscaled_masks, self.model.mask_threshold, self.stability_score_offset 140 | ) 141 | areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1) 142 | return upscaled_masks, scores, stability_scores, areas, masks 143 | 144 | return upscaled_masks, scores, masks 145 | -------------------------------------------------------------------------------- /segment_anything/utils/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | from torch.nn import functional as F 10 | from torchvision.transforms.functional import resize, to_pil_image # type: ignore 11 | 12 | from copy import deepcopy 13 | from typing import Tuple 14 | 15 | 16 | class ResizeLongestSide: 17 | """ 18 | Resizes images to the longest side 'target_length', as well as provides 19 | methods for resizing coordinates and boxes. Provides methods for 20 | transforming both numpy array and batched torch tensors. 21 | """ 22 | 23 | def __init__(self, target_length: int) -> None: 24 | self.target_length = target_length 25 | 26 | def apply_image(self, image: np.ndarray) -> np.ndarray: 27 | """ 28 | Expects a numpy array with shape HxWxC in uint8 format. 29 | """ 30 | target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) 31 | return np.array(resize(to_pil_image(image), target_size)) 32 | 33 | def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 34 | """ 35 | Expects a numpy array of length 2 in the final dimension. Requires the 36 | original image size in (H, W) format. 
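        Example (added for illustration): with target_length=1024 and
        original_size=(512, 768), get_preprocess_shape returns (683, 1024),
        so a coordinate (x=384, y=256) is mapped to roughly (512.0, 341.5).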
37 | """ 38 | old_h, old_w = original_size 39 | new_h, new_w = self.get_preprocess_shape( 40 | original_size[0], original_size[1], self.target_length 41 | ) 42 | coords = deepcopy(coords).astype(float) 43 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 44 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 45 | return coords 46 | 47 | def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 48 | """ 49 | Expects a numpy array shape Bx4. Requires the original image size 50 | in (H, W) format. 51 | """ 52 | boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size) 53 | return boxes.reshape(-1, 4) 54 | 55 | def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor: 56 | """ 57 | Expects batched images with shape BxCxHxW and float format. This 58 | transformation may not exactly match apply_image. apply_image is 59 | the transformation expected by the model. 60 | """ 61 | # Expects an image in BCHW format. May not exactly match apply_image. 62 | target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length) 63 | return F.interpolate( 64 | image, target_size, mode="bilinear", align_corners=False, antialias=True 65 | ) 66 | 67 | def apply_coords_torch( 68 | self, coords: torch.Tensor, original_size: Tuple[int, ...] 69 | ) -> torch.Tensor: 70 | """ 71 | Expects a torch tensor with length 2 in the last dimension. Requires the 72 | original image size in (H, W) format. 73 | """ 74 | old_h, old_w = original_size 75 | new_h, new_w = self.get_preprocess_shape( 76 | original_size[0], original_size[1], self.target_length 77 | ) 78 | coords = deepcopy(coords).to(torch.float) 79 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 80 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 81 | return coords 82 | 83 | def apply_boxes_torch( 84 | self, boxes: torch.Tensor, original_size: Tuple[int, ...] 85 | ) -> torch.Tensor: 86 | """ 87 | Expects a torch tensor with shape Bx4. Requires the original image 88 | size in (H, W) format. 89 | """ 90 | boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size) 91 | return boxes.reshape(-1, 4) 92 | 93 | @staticmethod 94 | def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]: 95 | """ 96 | Compute the output size given input size and target long side length. 
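        Example (added for illustration): get_preprocess_shape(480, 640, 1024)
        returns (768, 1024); the longest side is scaled to long_side_length and
        the other side is scaled by the same factor, rounded to the nearest pixel.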
97 | """ 98 | scale = long_side_length * 1.0 / max(oldh, oldw) 99 | newh, neww = oldh * scale, oldw * scale 100 | neww = int(neww + 0.5) 101 | newh = int(newh + 0.5) 102 | return (newh, neww) 103 | -------------------------------------------------------------------------------- /utils/eval_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import lightning as L 4 | import segmentation_models_pytorch as smp 5 | from box import Box 6 | from torch.utils.data import DataLoader 7 | from model import Model 8 | from utils.sample_utils import get_point_prompts 9 | from utils.tools import write_csv 10 | 11 | 12 | class AverageMeter: 13 | """Computes and stores the average and current value.""" 14 | 15 | def __init__(self): 16 | self.reset() 17 | 18 | def reset(self): 19 | self.val = 0 20 | self.avg = 0 21 | self.sum = 0 22 | self.count = 0 23 | 24 | def update(self, val, n=1): 25 | self.val = val 26 | self.sum += val * n 27 | self.count += n 28 | self.avg = self.sum / self.count 29 | 30 | 31 | def calc_iou(pred_mask: torch.Tensor, gt_mask: torch.Tensor): 32 | pred_mask = (pred_mask >= 0.5).float() 33 | intersection = torch.sum(torch.mul(pred_mask, gt_mask), dim=(1, 2)) 34 | union = torch.sum(pred_mask, dim=(1, 2)) + torch.sum(gt_mask, dim=(1, 2)) - intersection 35 | epsilon = 1e-7 36 | batch_iou = intersection / (union + epsilon) 37 | 38 | batch_iou = batch_iou.unsqueeze(1) 39 | return batch_iou 40 | 41 | 42 | def get_prompts(cfg: Box, bboxes, gt_masks): 43 | if cfg.prompt == "box" or cfg.prompt == "coarse": 44 | prompts = bboxes 45 | elif cfg.prompt == "point": 46 | prompts = get_point_prompts(gt_masks, cfg.num_points) 47 | else: 48 | raise ValueError("Prompt Type Error!") 49 | return prompts 50 | 51 | 52 | def validate(fabric: L.Fabric, cfg: Box, model: Model, val_dataloader: DataLoader, name: str, iters: int = 0): 53 | model.eval() 54 | ious = AverageMeter() 55 | f1_scores = AverageMeter() 56 | 57 | with torch.no_grad(): 58 | for iter, data in enumerate(val_dataloader): 59 | images, bboxes, gt_masks = data 60 | num_images = images.size(0) 61 | 62 | prompts = get_prompts(cfg, bboxes, gt_masks) 63 | 64 | _, pred_masks, _, _ = model(images, prompts) 65 | for pred_mask, gt_mask in zip(pred_masks, gt_masks): 66 | batch_stats = smp.metrics.get_stats( 67 | pred_mask, 68 | gt_mask.int(), 69 | mode='binary', 70 | threshold=0.5, 71 | ) 72 | batch_iou = smp.metrics.iou_score(*batch_stats, reduction="micro-imagewise") 73 | batch_f1 = smp.metrics.f1_score(*batch_stats, reduction="micro-imagewise") 74 | ious.update(batch_iou, num_images) 75 | f1_scores.update(batch_f1, num_images) 76 | fabric.print( 77 | f'Val: [{iters}] - [{iter}/{len(val_dataloader)}]: Mean IoU: [{ious.avg:.4f}] -- Mean F1: [{f1_scores.avg:.4f}]' 78 | ) 79 | torch.cuda.empty_cache() 80 | 81 | fabric.print(f'Validation [{iters}]: Mean IoU: [{ious.avg:.4f}] -- Mean F1: [{f1_scores.avg:.4f}]') 82 | csv_dict = {"Name": name, "Prompt": cfg.prompt, "Mean IoU": f"{ious.avg:.4f}", "Mean F1": f"{f1_scores.avg:.4f}", "iters": iters} 83 | 84 | if fabric.global_rank == 0: 85 | write_csv(os.path.join(cfg.out_dir, f"{cfg.dataset}-{cfg.prompt}.csv"), csv_dict, csv_head=cfg.csv_keys) 86 | model.train() 87 | return ious.avg, f1_scores.avg 88 | 89 | 90 | def unspervised_validate(fabric: L.Fabric, cfg: Box, model: Model, val_dataloader: DataLoader, name: str, iters: int = 0): 91 | init_prompt = cfg.prompt 92 | cfg.prompt = "box" 93 | iou_box, f1_box = validate(fabric, cfg, model, 
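                              # Annotation (added): this helper evaluates the same model with both
                              # box and point prompts by temporarily overriding cfg.prompt around
                              # each validate() call; the original cfg.prompt is restored below.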
val_dataloader, name, iters) 94 | cfg.prompt = "point" 95 | iou_point, f1_point = validate(fabric, cfg, model, val_dataloader, name, iters) 96 | # cfg.prompt = "coarse" 97 | # validate(fabric, cfg, model, val_dataloader, name, iters) 98 | cfg.prompt = init_prompt 99 | return iou_box, f1_box, iou_point, f1_point 100 | 101 | 102 | def contrast_validate(fabric: L.Fabric, cfg: Box, model: Model, val_dataloader: DataLoader, name: str, iters: int = 0, loss: float = 0.): 103 | model.eval() 104 | ious = AverageMeter() 105 | f1_scores = AverageMeter() 106 | 107 | with torch.no_grad(): 108 | for iter, data in enumerate(val_dataloader): 109 | images, bboxes, gt_masks = data 110 | num_images = images.size(0) 111 | 112 | prompts = get_prompts(cfg, bboxes, gt_masks) 113 | 114 | _, pred_masks, _, _ = model(images, prompts) 115 | for pred_mask, gt_mask in zip(pred_masks, gt_masks): 116 | batch_stats = smp.metrics.get_stats( 117 | pred_mask, 118 | gt_mask.int(), 119 | mode='binary', 120 | threshold=0.5, 121 | ) 122 | batch_iou = smp.metrics.iou_score(*batch_stats, reduction="micro-imagewise") 123 | batch_f1 = smp.metrics.f1_score(*batch_stats, reduction="micro-imagewise") 124 | ious.update(batch_iou, num_images) 125 | f1_scores.update(batch_f1, num_images) 126 | fabric.print( 127 | f'Val: [{iters}] - [{iter}/{len(val_dataloader)}]: Mean IoU: [{ious.avg:.4f}] -- Mean F1: [{f1_scores.avg:.4f}]' 128 | ) 129 | torch.cuda.empty_cache() 130 | 131 | fabric.print(f'Validation [{iters}]: Mean IoU: [{ious.avg:.4f}] -- Mean F1: [{f1_scores.avg:.4f}]') 132 | csv_dict = {"Name": name, "Prompt": cfg.prompt, "Mean IoU": f"{ious.avg:.4f}", "Mean F1": f"{f1_scores.avg:.4f}", "iters": iters, "loss": loss} 133 | 134 | if fabric.global_rank == 0: 135 | write_csv(os.path.join(cfg.out_dir, f"metrics-{cfg.prompt}.csv"), csv_dict, csv_head=cfg.csv_keys) 136 | model.train() 137 | return ious.avg, f1_scores.avg 138 | -------------------------------------------------------------------------------- /utils/sample_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from sklearn.cluster import KMeans 4 | 5 | 6 | def uniform_sampling(masks, N=5): 7 | n_points = [] 8 | for mask in masks: 9 | if not isinstance(mask, np.ndarray): 10 | mask = mask.cpu().numpy() 11 | 12 | indices = np.argwhere(mask == 1) # [y, x] 13 | sampled_indices = np.random.choice(len(indices), N, replace=True) 14 | sampled_points = np.flip(indices[sampled_indices], axis=1) 15 | n_points.append(sampled_points.tolist()) 16 | 17 | return n_points 18 | 19 | 20 | def get_multi_distance_points(input_point, mask, points_nubmer): 21 | new_points = np.zeros((points_nubmer + 1, 2)) 22 | new_points[0] = [input_point[1], input_point[0]] 23 | for i in range(points_nubmer): 24 | new_points[i + 1] = get_next_distance_point(new_points[:i + 1, :], mask) 25 | 26 | new_points = swap_xy(new_points) 27 | return new_points 28 | 29 | 30 | def get_next_distance_point(input_points, mask): 31 | max_distance_point = [0, 0] 32 | max_distance = 0 33 | input_points = np.array(input_points) 34 | 35 | indices = np.argwhere(mask == True) 36 | for x, y in indices: 37 | # print(x,y,input_points) 38 | distance = np.sum(np.sqrt((x - input_points[:, 0]) ** 2 + (y - input_points[:, 1]) ** 2)) 39 | if max_distance < distance: 40 | max_distance_point = [x, y] 41 | max_distance = distance 42 | return max_distance_point 43 | 44 | 45 | def swap_xy(points): 46 | new_points = np.zeros((len(points),2)) 47 | new_points[:,0] = 
points[:,1] 48 | new_points[:,1] = points[:,0] 49 | return new_points 50 | 51 | 52 | def k_means_sampling(mask, k): 53 | points = np.argwhere(mask == 1) # [y, x] 54 | points = np.flip(points, axis=1) 55 | 56 | kmeans = KMeans(n_clusters=k) 57 | kmeans.fit(points) 58 | points = kmeans.cluster_centers_ 59 | return points 60 | 61 | 62 | def get_point_prompt_max_dist(masks, num_points): 63 | n_points = [] 64 | for mask in masks: 65 | mask_np = mask.cpu().numpy() 66 | 67 | indices = np.argwhere(mask_np > 0) 68 | random_index = np.random.choice(len(indices), 1)[0] 69 | 70 | first_point = [indices[random_index][1], indices[random_index][0]] 71 | new_points = get_multi_distance_points(first_point, mask_np, num_points - 1) 72 | n_points.append(new_points) 73 | 74 | return n_points 75 | 76 | 77 | def get_point_prompt_kmeans(masks, num_points): 78 | n_points = [] 79 | for mask in masks: 80 | mask_np = mask.cpu().numpy() 81 | points = k_means_sampling(mask_np, num_points) 82 | n_points.append(points.astype(int)) 83 | return n_points 84 | 85 | 86 | def get_point_prompts(gt_masks, num_points): 87 | prompts = [] 88 | for mask in gt_masks: 89 | po_points = uniform_sampling(mask, num_points) 90 | na_points = uniform_sampling((~mask.to(bool)).to(float), num_points) 91 | po_point_coords = torch.tensor(po_points, device=mask.device) 92 | na_point_coords = torch.tensor(na_points, device=mask.device) 93 | point_coords = torch.cat((po_point_coords, na_point_coords), dim=1) 94 | po_point_labels = torch.ones(po_point_coords.shape[:2], dtype=torch.int, device=po_point_coords.device) 95 | na_point_labels = torch.zeros(na_point_coords.shape[:2], dtype=torch.int, device=na_point_coords.device) 96 | point_labels = torch.cat((po_point_labels, na_point_labels), dim=1) 97 | in_points = (point_coords, point_labels) 98 | prompts.append(in_points) 99 | return prompts 100 | -------------------------------------------------------------------------------- /utils/tools.py: -------------------------------------------------------------------------------- 1 | import os 2 | import csv 3 | import torch 4 | import copy 5 | import numpy as np 6 | from torchsummary import summary 7 | 8 | def freeze(model: torch.nn.Module): 9 | model.eval() 10 | for param in model.parameters(): 11 | param.requires_grad = False 12 | 13 | 14 | def momentum_update(student_model, teacher_model, momentum=0.99): 15 | for (src_name, src_param), (tgt_name, tgt_param) in zip( 16 | student_model.named_parameters(), teacher_model.named_parameters() 17 | ): 18 | if src_param.requires_grad: 19 | tgt_param.data.mul_(momentum).add_(src_param.data, alpha=1 - momentum) 20 | 21 | 22 | def decode_mask(mask): 23 | """ 24 | Convert mask with shape [1, h, w] using 1, 2, 3, ... to represent different objects 25 | to a mask with shape [n, h, w] using a new dimension to represent the number of objects. 26 | 27 | Args: 28 | mask (torch.Tensor): Mask tensor with shape [1, h, w] using 1, 2, 3, ... to represent different objects. 29 | 30 | Returns: 31 | torch.Tensor: Mask tensor with shape [n, h, w] using a new dimension to represent the number of objects. 
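    Example (added for illustration):
        decode_mask(torch.tensor([[[1, 2],
                                   [0, 2]]]))
        -> tensor([[[1, 0],
                    [0, 0]],
                   [[0, 1],
                    [0, 1]]])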
32 | """ 33 | unique_labels = torch.unique(mask) 34 | unique_labels = unique_labels[unique_labels != 0] 35 | n_objects = len(unique_labels) 36 | new_mask = torch.zeros((n_objects, *mask.shape[1:]), dtype=torch.int64) 37 | for i, label in enumerate(unique_labels): 38 | new_mask[i] = (mask == label).squeeze(0) 39 | return new_mask 40 | 41 | 42 | def encode_mask(mask): 43 | """ 44 | Convert mask with shape [n, h, w] using a new dimension to represent the number of objects 45 | to a mask with shape [1, h, w] using 1, 2, 3, ... to represent different objects. 46 | 47 | Args: 48 | mask (torch.Tensor): Mask tensor with shape [n, h, w] using a new dimension to represent the number of objects. 49 | 50 | Returns: 51 | torch.Tensor: Mask tensor with shape [1, h, w] using 1, 2, 3, ... to represent different objects. 52 | """ 53 | n_objects = mask.shape[0] 54 | new_mask = torch.zeros((1, *mask.shape[1:]), dtype=torch.int64) 55 | for i in range(n_objects): 56 | new_mask[0][mask[i] == 1] = i + 1 57 | return new_mask 58 | 59 | 60 | def copy_model(model: torch.nn.Module): 61 | new_model = copy.deepcopy(model) 62 | freeze(new_model) 63 | return new_model 64 | 65 | 66 | def create_csv(filename, csv_head=["corrupt", "Mean IoU", "Mean F1", "epoch"]): 67 | if os.path.exists(filename): 68 | return 69 | with open(filename, 'w') as csvfile: 70 | csv_write = csv.DictWriter(csvfile, fieldnames=csv_head) 71 | csv_write.writeheader() 72 | 73 | 74 | def write_csv(filename, csv_dict, csv_head=["corrupt", "Mean IoU", "Mean F1", "epoch"]): 75 | with open(filename, 'a+') as csvfile: 76 | csv_write = csv.DictWriter(csvfile, fieldnames=csv_head, extrasaction='ignore') 77 | csv_write.writerow(csv_dict) 78 | 79 | 80 | def check_grad(model: torch.nn.Module): 81 | for name, param in model.named_parameters(): 82 | print(f"{name}: {param.requires_grad}") 83 | 84 | 85 | def check_equal(model1: torch.nn.Module, model2: torch.nn.Module): 86 | for (name1, param1), (name2, param2) in zip(model1.named_parameters(), model2.named_parameters()): 87 | if name1 == name2: 88 | if not torch.allclose(param1, param2): 89 | print(f"{name1} is different") 90 | else: 91 | print(f"same") 92 | else: 93 | print("The models have different structures") 94 | 95 | 96 | def check_model(model): 97 | return summary(model, (3, 1024, 1024), batch_size=1, device="cuda") 98 | 99 | 100 | def reduce_instances(bboxes, gt_masks, max_nums=50): 101 | bboxes_ = [] 102 | gt_masks_ = [] 103 | for bbox, gt_mask in zip(bboxes, gt_masks): 104 | idx = np.arange(bbox.shape[0]) 105 | np.random.shuffle(idx) 106 | bboxes_.append(bbox[idx[:max_nums]]) 107 | gt_masks_.append(gt_mask[idx[:max_nums]]) 108 | 109 | bboxes = bboxes_ 110 | gt_masks = gt_masks_ 111 | return bboxes, gt_masks 112 | -------------------------------------------------------------------------------- /utils/visualize.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import torch 4 | from torchvision.utils import draw_bounding_boxes 5 | from torchvision.utils import draw_segmentation_masks 6 | 7 | from box import Box 8 | from tqdm import tqdm 9 | from model import Model 10 | from datasets.COCO import COCODataset 11 | 12 | 13 | def draw_image(image, masks, boxes, labels, alpha=0.4): 14 | image = torch.from_numpy(image).permute(2, 0, 1) 15 | if boxes is not None: 16 | image = draw_bounding_boxes(image, boxes, colors=['red'] * len(boxes), labels=labels, width=2) 17 | if masks is not None: 18 | image = draw_segmentation_masks(image, masks=masks, 
masks, colors=['red'] * len(masks), alpha=alpha) 19 | return image.numpy().transpose(1, 2, 0) 20 | 21 | 22 | def visualize(cfg: Box): 23 | model = Model(cfg) 24 | model.setup() 25 | model.eval() 26 | model.cuda() 27 | dataset = COCODataset(root_dir=cfg.dataset.val.root_dir, 28 | annotation_file=cfg.dataset.val.annotation_file, 29 | transform=None) 30 | predictor = model.get_predictor() 31 | os.makedirs(cfg.out_dir, exist_ok=True) 32 | 33 | for image_id in tqdm(dataset.image_ids): 34 | image_info = dataset.coco.loadImgs(image_id)[0] 35 | image_path = os.path.join(dataset.root_dir, image_info['file_name']) 36 | image_output_path = os.path.join(cfg.out_dir, image_info['file_name']) 37 | image = cv2.imread(image_path) 38 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 39 | ann_ids = dataset.coco.getAnnIds(imgIds=image_id) 40 | anns = dataset.coco.loadAnns(ann_ids) 41 | bboxes = [] 42 | for ann in anns: 43 | x, y, w, h = ann['bbox'] 44 | bboxes.append([x, y, x + w, y + h]) 45 | bboxes = torch.as_tensor(bboxes, device=model.model.device) 46 | transformed_boxes = predictor.transform.apply_boxes_torch(bboxes, image.shape[:2]) 47 | predictor.set_image(image) 48 | masks, _, _ = predictor.predict_torch( 49 | point_coords=None, 50 | point_labels=None, 51 | boxes=transformed_boxes, 52 | multimask_output=False, 53 | ) 54 | image_output = draw_image(image, masks.squeeze(1), boxes=None, labels=None) 55 | cv2.imwrite(image_output_path, image_output) 56 | 57 | -------------------------------------------------------------------------------- /validate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import torch 4 | import lightning as L 5 | import segmentation_models_pytorch as smp 6 | from box import Box 7 | from lightning.fabric.loggers import TensorBoardLogger 8 | 9 | from configs.config_validate import cfg 10 | from datasets import call_load_dataset 11 | from model import Model 12 | from utils.eval_utils import AverageMeter, get_prompts, validate 13 | from utils.tools import copy_model, create_csv 14 | from model import Model 15 | from sam_lora import LoRA_Sam 16 | 17 | 18 | def multi_main(cfg): 19 | prompts = ["box", "point"] 20 | for prompt in prompts: 21 | cfg.prompt = prompt 22 | torch.cuda.empty_cache() 23 | main(cfg) 24 | 25 | 26 | def main(cfg: Box, ckpt: str = None) -> None: 27 | gpu_ids = cfg.gpu_ids.split(',') 28 | num_devices = len(gpu_ids) 29 | 30 | fabric = L.Fabric(accelerator="auto", 31 | devices=num_devices, 32 | strategy="auto", 33 | loggers=[TensorBoardLogger(cfg.out_dir)]) 34 | fabric.launch() 35 | fabric.seed_everything(1337 + fabric.global_rank) 36 | 37 | if fabric.global_rank == 0: 38 | os.makedirs(cfg.out_dir, exist_ok=True) 39 | create_csv(os.path.join(cfg.out_dir, f"{cfg.dataset}-{cfg.prompt}.csv"), csv_head=cfg.csv_keys) 40 | 41 | with fabric.device: 42 | model = Model(cfg) 43 | model.setup() 44 | LoRA_Sam(model.model, 4) 45 | 46 | load_datasets = call_load_dataset(cfg) 47 | _, val_data = load_datasets(cfg, model.model.image_encoder.img_size) 48 | 49 | fabric.print(f"Val Data: {len(val_data) * cfg.val_batchsize}") 50 | val_data = fabric._setup_dataloader(val_data) 51 | 52 | if ckpt is not None: 53 | full_checkpoint = fabric.load(ckpt) 54 | model.load_state_dict(full_checkpoint["model"]) 55 | 56 | validate(fabric, cfg, model, val_data, name=cfg.name, iters=0) 57 | 58 | del model, val_data 59 | 60 | 61 | if __name__ == "__main__": 62 | torch.cuda.empty_cache() 63 | torch.set_float32_matmul_precision('high') 64 |
os.environ["CUDA_VISIBLE_DEVICES"] = cfg.gpu_ids 65 | 66 | parser = argparse.ArgumentParser() 67 | parser.add_argument("--ckpt", default=None, type=str) 68 | args = parser.parse_args() 69 | 70 | main(cfg, args.ckpt) 71 | # multi_main(cfg, args.ckpt) 72 | torch.cuda.empty_cache() 73 | --------------------------------------------------------------------------------
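
Usage note (added annotation, not part of the repository files above): validate.py reads its
settings (gpu_ids, dataset, prompt, out_dir, val_batchsize, csv_keys, name, ...) from
configs.config_validate, wraps the SAM model with LoRA adapters via LoRA_Sam(model.model, 4),
optionally restores a Lightning Fabric checkpoint passed on the command line, and then runs
validate() once over the validation split. A typical invocation, with an illustrative
checkpoint path, is:

    python validate.py --ckpt output/last-ckpt.pth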