├── LICENSE
├── README.md
├── config
│   └── semantic_seg.yaml
├── datasets
│   ├── __init__.py
│   ├── detection.py
│   ├── instance_seg.py
│   ├── matting.py
│   ├── semantic_seg.py
│   └── transforms.py
├── extend_sam
│   ├── __init__.py
│   ├── extend_sam.py
│   ├── image_encoder_adapter.py
│   ├── mask_decoder_adapter.py
│   ├── mask_decoder_heads.py
│   ├── mask_decoder_neck.py
│   ├── prompt_encoder_adapter.py
│   ├── runner.py
│   ├── scheduler.py
│   ├── segment_anything_ori
│   │   ├── __init__.py
│   │   ├── automatic_mask_generator.py
│   │   ├── build_sam.py
│   │   ├── modeling
│   │   │   ├── __init__.py
│   │   │   ├── common.py
│   │   │   ├── image_encoder.py
│   │   │   ├── mask_decoder.py
│   │   │   ├── prompt_encoder.py
│   │   │   ├── sam.py
│   │   │   └── transformer.py
│   │   ├── predictor.py
│   │   └── utils
│   │       ├── __init__.py
│   │       ├── amg.py
│   │       ├── onnx.py
│   │       └── transforms.py
│   └── utils.py
├── how_to_use_finetune_anything.md
├── losses
│   ├── __init__.py
│   └── losses.py
├── requirements.txt
└── train.py

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2023 ziqi-jin

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Introduction

The [Segment Anything Model (SAM)](https://github.com/facebookresearch/segment-anything) has revolutionized computer vision. Fine-tuning SAM can address a large number of basic computer vision tasks. We are designing a **class-aware one-stage** tool for training fine-tuned models based on SAM.

Supply the dataset for your task and the name of a [supported task](#Supported-Tasks), and Finetune-Anything (FA) will produce a fine-tuned model for that task. You can also design your own extend-SAM model, and FA supplies the training, testing, and deployment process for it.



## Design
Finetune-Anything wraps the three parts of the original SAM (image encoder, prompt encoder, and mask decoder) as the Image Encoder Adapter, Prompt Encoder Adapter, and Mask Decoder Adapter. We will support a base extend-SAM model for each task. You can also design your own customized modules in each adapter, use FA to design different adapters, and choose whether the parameters of any module are frozen. For modules whose parameters are not frozen, hyperparameters such as `lr` and `weight decay` can be set per module to coordinate the fine-tuning of the model.
Check the details in [How_to_use](https://github.com/ziqi-jin/finetune-anything/blob/main/how_to_use_finetune_anything.md).
For example, MaskDecoder is encapsulated as MaskDecoderAdapter. The current MaskDecoderAdapter contains two parts: DecoderNeck and DecoderHead.
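
Per-module `lr` and `weight decay` settings map naturally onto optimizer parameter groups. The sketch below is a hypothetical helper (`build_param_groups` is not part of FA; the repository's own logic lives in `get_opt_pamams` in `extend_sam/utils.py` and may differ). It assumes that a parameter joins the first group whose key is a substring of its name, that frozen parameters are skipped, and that every other trainable parameter falls into a default group, mirroring the `lr_list`, `group_keys`, `wd_list`, `lr_default`, and `wd_default` fields of `config/semantic_seg.yaml`. It works on `(name, requires_grad)` pairs so it runs without PyTorch; with a real model you would build those pairs from `model.named_parameters()` and put the tensors, not the names, into `params`.

```python
from typing import Dict, List, Sequence, Tuple


def build_param_groups(
    named_params: Sequence[Tuple[str, bool]],  # (parameter name, requires_grad)
    group_keys: List[List[str]],
    lr_list: List[float],
    wd_list: List[float],
    lr_default: float,
    wd_default: float,
) -> List[Dict]:
    """Hypothetical helper: split trainable parameters into optimizer groups.

    A parameter joins the first group whose key is a substring of its name;
    all other trainable parameters go to a default group.
    """
    if not (len(group_keys) == len(lr_list) == len(wd_list)):
        raise ValueError("group_keys, lr_list and wd_list must have the same length")
    groups = [{"params": [], "lr": lr, "weight_decay": wd} for lr, wd in zip(lr_list, wd_list)]
    default_group = {"params": [], "lr": lr_default, "weight_decay": wd_default}
    for name, requires_grad in named_params:
        if not requires_grad:  # frozen modules (e.g. fix_img_en: True) are not optimized
            continue
        for keys, group in zip(group_keys, groups):
            if any(key in name for key in keys):
                group["params"].append(name)
                break
        else:  # no group key matched this parameter
            default_group["params"].append(name)
    # drop groups that ended up without any parameters
    return [g for g in groups + [default_group] if g["params"]]


# Example mirroring opt_params in config/semantic_seg.yaml (parameter names are illustrative)
params = [
    ("img_adapter.sam_img_encoder.patch_embed.proj.weight", False),
    ("mask_adapter.decoder_head.output_hypernetworks_mlps.0.layers.0.weight", True),
    ("mask_adapter.decoder_neck.transformer.layers.0.mlp.lin1.weight", True),
]
for g in build_param_groups(params, group_keys=[["mask_adapter.decoder_head.output_hypernetworks_mlps"]],
                            lr_list=[1e-2], wd_list=[0.0], lr_default=1e-3, wd_default=1e-4):
    print(g["lr"], g["weight_decay"], len(g["params"]))
# 0.01 0.0 1
# 0.001 0.0001 1
```

Matching by substring keeps the config short (one key covers every hypernetwork MLP), at the cost of requiring keys specific enough not to match unrelated modules.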



## Supported Tasks
- [x] Semantic Segmentation
    - [x] train
    - [x] eval
    - [ ] test
- [ ] Matting
- [ ] Instance Segmentation
- [ ] Detection
## Supported Datasets
- [x] TorchVOCSegmentation
- [x] BaseSemantic
- [ ] BaseInstance
- [ ] BaseMatting

## Deploy
- [ ] Onnx export

## Support Plan
FA will be updated in the following order:

- Matting (task)
- Prompt Part (structure)
- [MobileSAM](https://github.com/ChaoningZhang/MobileSAM) (model)
- Instance Segmentation (task)

# Usage
finetune-anything (FA) supports the entire process of fine-tuning a SAM model, including modifying the model structure as well as training, validation, and testing. For details, check [How_to_use](https://github.com/ziqi-jin/finetune-anything/blob/main/how_to_use_finetune_anything.md). The [Quick Start](#Quick-Start) gives an example of quickly using FA to train a custom semantic segmentation model.
## Quick Start
### Install
- Step1
```
git clone https://github.com/ziqi-jin/finetune-anything.git
cd finetune-anything
pip install -r requirements.txt
```
- Step2
Download the SAM weights from the [SAM repository](https://github.com/facebookresearch/segment-anything#model-checkpoints).

- Step3
Modify the YAML file for the specific task in **/config**, e.g., `ckpt_path`, `model_type`, ...

### Train
```
CUDA_VISIBLE_DEVICES=${your GPU number} python train.py --task_name semantic_seg
```

## One more thing

If you need a loss, dataset, or other function that FA does not support yet, please submit an issue and I will help you implement it. Developers are also welcome to contribute new losses, datasets, or other functions to FA by submitting a pull request (PR).
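
Datasets (like models, optimizers, and runners) are selected by name from the YAML config, so contributing a new one mainly means registering its class under a new name. A minimal sketch, assuming plain dicts instead of the OmegaConf config and leaving out the transforms that the real `get_dataset` in `datasets/__init__.py` builds; `MySemanticDataset` and the name `'my_sem'` are hypothetical:

```python
from typing import Any, Dict, Optional, Type


class MySemanticDataset:
    """Hypothetical stand-in for a new dataset class contributed to FA."""

    def __init__(self, root: str, image_set: str = "train"):
        self.root = root
        self.image_set = image_set


# Same idea as the name -> class dict in datasets/__init__.py: adding an
# entry is what makes `dataset: name: 'my_sem'` valid in the YAML config.
segment_datasets: Dict[str, Type] = {"my_sem": MySemanticDataset}


def get_dataset(cfg: Dict[str, Any]):
    """Simplified lookup: plain dict config, no transforms."""
    name = cfg["name"]
    if name not in segment_datasets:
        raise ValueError("{} is not supported, please implement it first.".format(name))
    params: Optional[Dict[str, Any]] = cfg.get("params")
    # `params: ~` in YAML arrives as None; treat it as "no keyword arguments"
    return segment_datasets[name](**(params or {}))


ds = get_dataset({"name": "my_sem", "params": {"root": "/data/my_dataset", "image_set": "val"}})
print(type(ds).__name__, ds.image_set)  # MySemanticDataset val
```

In FA itself the equivalent step is adding an entry to `segment_datasets`, or to `AVAI_MODEL`, `AVAI_OPT`, or `AVAI_RUNNER` in `extend_sam/__init__.py`.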

## Related Resources

- [Documents](https://github.com/ziqi-jin/finetune-anything/blob/main/how_to_use_finetune_anything.md)


--------------------------------------------------------------------------------
/config/semantic_seg.yaml:
--------------------------------------------------------------------------------
train:
  experiment_name: 'semantic_sam'

  # Model
  model:
    sam_name: 'sem_sam'
    params:
      # Freeze part of the SAM parameters
      fix_img_en: True
      fix_prompt_en: True
      fix_mask_de: False
      ckpt_path: 'sam_ckpt/sam_vit_b_01ec64.pth'
      class_num: 21 # 20 classes + 1 background
      model_type: 'vit_b' # type should be one of [vit_h, vit_b, vit_l, default]

  # Dataset
  dataset:
    name: 'torch_voc_sem'
    params:
      root: '/data/jinziqi/DATASETS/'
      year: '2012'
      image_set: 'train'
    transforms:
      resize:
        params:
          size: [1024, 1024]
      to_tensor:
        params: ~
    target_transforms:
      resize:
        params:
          size: [1024, 1024]

  # Losses
  losses:
    ce:
      weight: 0.5
      params: # ~ means None; the initialization params of the loss can be set here
        ignore_index: 255
      label_one_hot: False

  # Optimizer
  opt_params:
    lr_default: 1e-3
    wd_default: 1e-4
    momentum: 0.9
    lr_list: [ 1e-2, ]
    group_keys: [ [ 'mask_adapter.decoder_head.output_hypernetworks_mlps', ], ]
    wd_list: [ 0.0, ]
  opt_name: 'sgd' # 'sgd'
  scheduler_name: 'cosine'

  # Runner
  max_iter: 100000
  log_iter: 20
  eval_iter: 200
  runner_name: 'sem_runner'
  # Dataloader
  bs: 8 # 8
  num_workers: 2
  drop_last: True
  # Logger
  use_tensorboard: True
  tensorboard_folder: './experiment/tensorboard'
  log_folder: './experiment/log'
  model_folder: './experiment/model'

val:
  # Dataset
  dataset:
    name: 'torch_voc_sem'
    params:
      root: '/data/jinziqi/DATASETS/'
      year: '2012'
      image_set: 'train'
    transforms:
      resize:
        params:
          size: [1024, 1024]
      to_tensor:
        params: ~
    target_transforms:
      resize:
        params:
          size: [1024, 1024]

  bs: 8
  num_workers: 2
  drop_last: True


test:
  need_test: False

--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
from .detection import BaseDetectionDataset
from .instance_seg import BaseInstanceDataset
from .semantic_seg import BaseSemanticDataset, VOCSemanticDataset, TorchVOCSegmentation
from .transforms import get_transforms
from torchvision.datasets import VOCSegmentation

segment_datasets = {'base_ins': BaseInstanceDataset, 'base_sem': BaseSemanticDataset,
                    'voc_sem': VOCSemanticDataset, 'torch_voc_sem': TorchVOCSegmentation}
det_dataset = {'base_det': BaseDetectionDataset, }


def get_dataset(cfg):
    name = cfg.name
    assert name in segment_datasets or name in det_dataset, \
        '{name} is not supported, please implement it first.'.format(name=name)
    # TODO customized dataset params:
    # customized dataset params example:
    # if xxx:
    #   param1 = cfg.xxx
    #   param2 = cfg.xxx
    #   return name_dict[name](path, model, param1, param2, ...)
    transform = get_transforms(cfg.transforms)
    if name in det_dataset:
        return det_dataset[name](**cfg.params, transform=transform)
    target_transform = get_transforms(cfg.target_transforms)
    return segment_datasets[name](**cfg.params, transform=transform, target_transform=target_transform)


class Iterator:
    # Wraps a DataLoader so that get() restarts it transparently when it is
    # exhausted, which suits iteration-based (max_iter) training loops.
    def __init__(self, loader):
        self.loader = loader
        self.init()

    def init(self):
        self.iterator = iter(self.loader)

    def get(self):
        try:
            data = next(self.iterator)
        except StopIteration:
            self.init()
            data = next(self.iterator)

        return data

--------------------------------------------------------------------------------
/datasets/detection.py:
--------------------------------------------------------------------------------
from torch.utils.data import Dataset


class BaseDetectionDataset(Dataset):
    def __init__(self):
        raise NotImplementedError('BaseDetectionDataset is not implemented yet.')

    def __getitem__(self, item):
        pass

--------------------------------------------------------------------------------
/datasets/instance_seg.py:
--------------------------------------------------------------------------------
from torch.utils.data import Dataset


class BaseInstanceDataset(Dataset):
    def __init__(self):
        raise NotImplementedError('BaseInstanceDataset is not implemented yet.')

    def __getitem__(self, item):
        pass

--------------------------------------------------------------------------------
/datasets/matting.py:
--------------------------------------------------------------------------------
import os
from PIL import Image
from torch.utils.data import Dataset
from torchvision.datasets import VisionDataset
import numpy as np

class BaseMattingDataset(VisionDataset):
    """
    If you want to customize a new dataset to train the matting task,
    the img, trimap, and ann files need to be arranged in this
    structure.
    ├── data
    │   ├── my_dataset
    │   │   ├── img
    │   │   │   ├── train
    │   │   │   │   ├── xxx{img_suffix}
    │   │   │   │   ├── yyy{img_suffix}
    │   │   │   │   ├── zzz{img_suffix}
    │   │   │   ├── val
    │   │   ├── trimap
    │   │   │   ├── train
    │   │   │   │   ├── xxx{img_suffix}
    │   │   │   │   ├── yyy{img_suffix}
    │   │   │   │   ├── zzz{img_suffix}
    │   │   │   ├── val
    │   │   ├── ann
    │   │   │   ├── train
    │   │   │   │   ├── xxx{ann_suffix}
    │   │   │   │   ├── yyy{ann_suffix}
    │   │   │   │   ├── zzz{ann_suffix}
    │   │   │   ├── val
    """

    def __init__(self, metainfo, dataset_dir, transform, target_transform,
                 trimap_transform=None,
                 image_set='train',
                 img_suffix='.jpg',
                 ann_suffix='.png',
                 trimap_suffix=None,
                 data_prefix: dict = dict(img_path='img', ann_path='ann', trimap_path='trimap'),
                 return_dict=False):
        '''

        :param metainfo: meta data of the original dataset, e.g. class_names
        :param dataset_dir: path of your dataset, e.g. data/my_dataset/ in the structure tree above
        :param image_set: 'train' or 'val'
        :param img_suffix: your image suffix
        :param ann_suffix: your annotation suffix
        :param trimap_suffix: your trimap suffix; if None, trimaps are not loaded
        :param data_prefix: data folder names; for my_dataset in the tree above: img_path='img', ann_path='ann', trimap_path='trimap'
        :param return_dict: return dict() or tuple(img, ann)
        '''
        super(BaseMattingDataset, self).__init__(root=dataset_dir, transform=transform,
                                                 target_transform=target_transform)

        self.class_names = metainfo['class_names']
        self.img_path = os.path.join(dataset_dir, data_prefix['img_path'], image_set)
        self.ann_path = os.path.join(dataset_dir, data_prefix['ann_path'], image_set)

        print('img_folder_name: {img_folder_name}, ann_folder_name: {ann_folder_name}'.format(
            img_folder_name=self.img_path, ann_folder_name=self.ann_path))
        self.img_names = [img_name.split(img_suffix)[0] for img_name in os.listdir(self.img_path) if
                          img_name.endswith(img_suffix)]

        self.has_trimap = trimap_suffix is not None
        if self.has_trimap:
            self.trimap_path = os.path.join(dataset_dir, data_prefix['trimap_path'], image_set)
            print('trimap_folder_name: {trimap_folder_name}'.format(trimap_folder_name=self.trimap_path))
        self.img_suffix = img_suffix
        self.ann_suffix = ann_suffix
        self.trimap_suffix = trimap_suffix
        self.return_dict = return_dict
        self.trimap_transform = trimap_transform

    def __getitem__(self, index):
        img = Image.open(os.path.join(self.img_path, self.img_names[index] + self.img_suffix))
        ann = Image.open(os.path.join(self.ann_path, self.img_names[index] + self.ann_suffix))
        if self.transforms is not None:
            img, ann = self.transforms(img, ann)
        ann = np.array(ann)
        if self.has_trimap:
            # return value when self.has_trimap is True
            trimap = Image.open(os.path.join(self.trimap_path, self.img_names[index] + self.trimap_suffix))
            if self.trimap_transform:
                trimap = self.trimap_transform(trimap)
            else:
                print("Warning: you may need to set a transform function for the trimap input")
            if self.return_dict:
                data = dict(img_name=self.img_names[index], img=img, ann=ann, trimap=trimap,
                            img_path=os.path.join(self.img_path, self.img_names[index] + self.img_suffix),
                            ann_path=os.path.join(self.ann_path, self.img_names[index] + self.ann_suffix),
                            trimap_path=os.path.join(self.trimap_path, self.img_names[index] + self.trimap_suffix))
                return data
            return img, ann, trimap
        else:
            # return value when self.has_trimap is False
            if self.return_dict:
                data = dict(img_name=self.img_names[index], img=img, ann=ann,
                            img_path=os.path.join(self.img_path, self.img_names[index] + self.img_suffix),
                            ann_path=os.path.join(self.ann_path, self.img_names[index] + self.ann_suffix))
                return data
            return img, ann

    def __len__(self):
        return len(self.img_names)

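
Before training on a custom matting dataset, it can help to check that the folder layout drawn in the `BaseMattingDataset` docstring resolves to real files. `list_matting_samples` below is a hypothetical helper, not part of FA; it assumes the folder names `img`, `ann`, and `trimap` from that tree and, like `BaseMattingDataset`, treats `trimap_suffix=None` as "no trimap". It removes the image suffix by slicing rather than with `str.split`, so a file such as `a.jpg.jpg` keeps its full stem `a.jpg`.

```python
import os
from typing import List, Optional, Tuple


def list_matting_samples(dataset_dir: str, image_set: str = "train",
                         img_suffix: str = ".jpg", ann_suffix: str = ".png",
                         trimap_suffix: Optional[str] = None,
                         img_dir: str = "img", ann_dir: str = "ann",
                         trimap_dir: str = "trimap") -> List[Tuple[str, str, Optional[str]]]:
    """Hypothetical helper: return (img_path, ann_path, trimap_path) per sample.

    trimap_path is None when trimap_suffix is None. Raises FileNotFoundError
    if an annotation or trimap file is missing for an image.
    """
    if not img_suffix:
        raise ValueError("img_suffix must be non-empty")
    img_root = os.path.join(dataset_dir, img_dir, image_set)
    ann_root = os.path.join(dataset_dir, ann_dir, image_set)
    trimap_root = os.path.join(dataset_dir, trimap_dir, image_set)
    samples = []
    for file_name in sorted(os.listdir(img_root)):
        if not file_name.endswith(img_suffix):
            continue  # skip files that are not images of the expected type
        stem = file_name[:-len(img_suffix)]  # remove only the trailing suffix
        ann_path = os.path.join(ann_root, stem + ann_suffix)
        if not os.path.isfile(ann_path):
            raise FileNotFoundError(ann_path)
        trimap_path = None
        if trimap_suffix is not None:
            trimap_path = os.path.join(trimap_root, stem + trimap_suffix)
            if not os.path.isfile(trimap_path):
                raise FileNotFoundError(trimap_path)
        samples.append((os.path.join(img_root, file_name), ann_path, trimap_path))
    return samples
```

For example, `list_matting_samples('data/my_dataset', trimap_suffix='.png')` on the layout above would return one triple per `.jpg` file in `img/train`, and fail early on the first image whose annotation or trimap is missing.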
--------------------------------------------------------------------------------
/datasets/semantic_seg.py:
--------------------------------------------------------------------------------
import os
from PIL import Image
from torch.utils.data import Dataset
from torchvision.datasets import VOCSegmentation, VisionDataset
import numpy as np


class BaseSemanticDataset(VisionDataset):
    """
    If you want to customize a new dataset to train the segmentation task,
    the img and mask files need to be arranged in this structure.
    ├── data
    │   ├── my_dataset
    │   │   ├── img
    │   │   │   ├── train
    │   │   │   │   ├── xxx{img_suffix}
    │   │   │   │   ├── yyy{img_suffix}
    │   │   │   │   ├── zzz{img_suffix}
    │   │   │   ├── val
    │   │   ├── ann
    │   │   │   ├── train
    │   │   │   │   ├── xxx{ann_suffix}
    │   │   │   │   ├── yyy{ann_suffix}
    │   │   │   │   ├── zzz{ann_suffix}
    │   │   │   ├── val
    """

    def __init__(self, metainfo, dataset_dir, transform, target_transform,
                 image_set='train',
                 img_suffix='.jpg',
                 ann_suffix='.png',
                 data_prefix: dict = dict(img_path='img', ann_path='ann'),
                 return_dict=False):
        '''

        :param metainfo: meta data of the original dataset, e.g. class_names
        :param dataset_dir: path of your dataset, e.g.
            data/my_dataset/ in the structure tree above
        :param image_set: 'train' or 'val'
        :param img_suffix: your image suffix
        :param ann_suffix: your annotation suffix
        :param data_prefix: data folder names; for my_dataset in the tree above: img_path='img', ann_path='ann'
        :param return_dict: return dict() or tuple(img, ann)
        '''
        super(BaseSemanticDataset, self).__init__(root=dataset_dir, transform=transform,
                                                  target_transform=target_transform)

        self.class_names = metainfo['class_names']
        self.img_path = os.path.join(dataset_dir, data_prefix['img_path'], image_set)
        self.ann_path = os.path.join(dataset_dir, data_prefix['ann_path'], image_set)
        print('img_folder_name: {img_folder_name}, ann_folder_name: {ann_folder_name}'.format(
            img_folder_name=self.img_path, ann_folder_name=self.ann_path))
        self.img_names = [img_name.split(img_suffix)[0] for img_name in os.listdir(self.img_path) if
                          img_name.endswith(img_suffix)]
        self.img_suffix = img_suffix
        self.ann_suffix = ann_suffix
        self.return_dict = return_dict

    def __getitem__(self, index):
        img = Image.open(os.path.join(self.img_path, self.img_names[index] + self.img_suffix))
        ann = Image.open(os.path.join(self.ann_path, self.img_names[index] + self.ann_suffix))
        if self.transforms is not None:
            img, ann = self.transforms(img, ann)
        ann = np.array(ann)

        if self.return_dict:
            data = dict(img_name=self.img_names[index], img=img, ann=ann,
                        img_path=os.path.join(self.img_path, self.img_names[index] + self.img_suffix),
                        ann_path=os.path.join(self.ann_path, self.img_names[index] + self.ann_suffix))
            return data
        return img, ann

    def __len__(self):
        return len(self.img_names)


class VOCSemanticDataset(Dataset):
    def __init__(self, root_dir, domain, transform, with_id=False, with_mask=False):
        super(VOCSemanticDataset, self).__init__()
        # root_dir is concatenated directly with the sub-folder names, so it should end with '/'
        self.root_dir = root_dir

        self.image_dir = self.root_dir + 'JPEGImages/'
        self.xml_dir = self.root_dir + 'Annotations/'
        self.mask_dir = self.root_dir + 'SegmentationClass/'

        with open('./data/%s.txt' % domain) as f:
            self.image_id_list = [image_id.strip() for image_id in f.readlines()]
        self.transform = transform
        self.with_id = with_id
        self.with_mask = with_mask
        self.class_names = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
                            'bus', 'car', 'cat', 'chair', 'cow',
                            'diningtable', 'dog', 'horse', 'motorbike', 'person',
                            'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor']

    def __len__(self):
        return len(self.image_id_list)

    def get_image(self, image_id):
        image = Image.open(self.image_dir + image_id + '.jpg').convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return image

    def get_mask(self, image_id):
        mask_path = self.mask_dir + image_id + '.png'
        if os.path.isfile(mask_path):
            mask = Image.open(mask_path)
        else:
            mask = None
        return mask

    def __getitem__(self, index):
        image_id = self.image_id_list[index]

        data_list = [self.get_image(image_id)]

        if self.with_id:
            data_list.append(image_id)

        if self.with_mask:
            data_list.append(self.get_mask(image_id))

        return data_list


class TorchVOCSegmentation(VOCSegmentation):
    def __init__(self, root, year='2012', image_set='train', download=False, transform=None, target_transform=None):
        super(TorchVOCSegmentation, self).__init__(root=root, year=year, image_set=image_set, download=download,
                                                   transform=transform, target_transform=target_transform)
        self.class_names = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
                            'bus', 'car', 'cat', 'chair', 'cow',
                            'diningtable', 'dog', 'horse', 'motorbike', 'person',
                            'pottedplant', 'sheep', 'sofa', 'train',
                            'tvmonitor']

    def __getitem__(self, index: int):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is the image segmentation.
        """
        img = Image.open(self.images[index]).convert('RGB')
        target = Image.open(self.masks[index])

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        target = np.array(target)
        return img, target

--------------------------------------------------------------------------------
/datasets/transforms.py:
--------------------------------------------------------------------------------
import torchvision.transforms as T
from omegaconf.dictconfig import DictConfig
import torch.nn as nn

AVIAL_TRANSFORM = {'resize': T.Resize, 'to_tensor': T.ToTensor}


def get_transforms(transforms: DictConfig):
    T_list = []
    for t_name in transforms.keys():
        assert t_name in AVIAL_TRANSFORM, "{T_name} is not a supported transform, please implement it and add it to " \
                                          "AVIAL_TRANSFORM first.".format(T_name=t_name)
        if transforms[t_name].params is not None:
            T_list.append(AVIAL_TRANSFORM[t_name](**transforms[t_name].params))
        else:
            T_list.append(AVIAL_TRANSFORM[t_name]())
    return T.Compose(T_list)


class CustomTransform(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self):
        pass

--------------------------------------------------------------------------------
/extend_sam/__init__.py:
--------------------------------------------------------------------------------
# copyright ziqi-jin
import torch
from .extend_sam import BaseExtendSam, SemanticSam
from .runner import BaseRunner, SemRunner
# from .optimizer import BaseOptimizer
from .scheduler import WarmupMultiStepLR
from .utils import get_opt_pamams

AVAI_SCH = ["single_step", "multi_step", "warmup_multi_step", "cosine",
            "linear"]
AVAI_MODEL = {'base_sam': BaseExtendSam, 'sem_sam': SemanticSam}
# AVAI_OPT = {'base_opt': BaseOptimizer, 'sgd': torch.optim.SGD, 'adam': torch.optim.Adam}
AVAI_OPT = {'sgd': torch.optim.SGD, 'adam': torch.optim.Adam, 'adamw': torch.optim.AdamW}
AVAI_RUNNER = {'base_runner': BaseRunner, 'sem_runner': SemRunner}


def get_model(model_name, **kwargs):
    if model_name not in AVAI_MODEL:
        raise ValueError('{} is not a supported model name, please implement it first.'.format(model_name))
    return AVAI_MODEL[model_name](**kwargs).cuda()


def get_optimizer(opt_name, **kwargs):
    if opt_name not in AVAI_OPT:
        raise ValueError('{} is not a supported optimizer name, please implement it first.'.format(opt_name))
    # drop arguments left as None so that the optimizer's own defaults apply
    return AVAI_OPT[opt_name](**{k: v for k, v in kwargs.items() if v is not None})


def get_runner(runner_name):
    if runner_name not in AVAI_RUNNER:
        raise ValueError('{} is not a supported runner name, please implement it first.'.format(runner_name))
    return AVAI_RUNNER[runner_name]


def get_scheduler(
        optimizer,
        lr_scheduler="single_step",
        stepsize=1,
        gamma=0.1,
        warmup_factor=0.01,
        warmup_steps=10,
        max_epoch=1,
        n_epochs_init=50,
        n_epochs_decay=50,

):
    """A function wrapper for building a learning rate scheduler.
    Args:
        optimizer (Optimizer): an Optimizer.
        lr_scheduler (str, optional): learning rate scheduler method. Default is
            single_step.
        stepsize (int or list, optional): step size to decay learning rate.
            When ``lr_scheduler`` is "single_step", ``stepsize`` should be an integer.
            When ``lr_scheduler`` is "multi_step", ``stepsize`` is a list. Default is 1.
        gamma (float, optional): decay rate. Default is 0.1.
        warmup_factor (float, optional): warmup factor passed to WarmupMultiStepLR
            (only for "warmup_multi_step"). Default is 0.01.
        warmup_steps (int, optional): number of warmup iterations (only for
            "warmup_multi_step"). Default is 10.
        max_epoch (int, optional): maximum epoch (for cosine annealing). Default is 1.
        n_epochs_init (int, optional): for "linear", number of epochs kept at the
            initial learning rate. Default is 50.
        n_epochs_decay (int, optional): for "linear", number of epochs over which the
            learning rate then decays linearly. Default is 50.
    Examples::
        >>> # Decay learning rate every 20 epochs.
        >>> scheduler = get_scheduler(
        ...     optimizer, lr_scheduler='single_step', stepsize=20
        ... )
        >>> # Decay learning rate at 30, 50 and 55 epochs.
        >>> scheduler = get_scheduler(
        ...     optimizer, lr_scheduler='multi_step', stepsize=[30, 50, 55]
        ... )
    """
    if lr_scheduler not in AVAI_SCH:
        raise ValueError(
            "Unsupported scheduler: {}. Must be one of {}".format(
                lr_scheduler, AVAI_SCH
            )
        )

    if lr_scheduler == "single_step":
        if isinstance(stepsize, list):
            stepsize = stepsize[-1]

        if not isinstance(stepsize, int):
            raise TypeError(
                "For single_step lr_scheduler, stepsize must "
                "be an integer, but got {}".format(type(stepsize))
            )

        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=stepsize, gamma=gamma
        )

    elif lr_scheduler == "multi_step":
        if not isinstance(stepsize, list):
            raise TypeError(
                "For multi_step lr_scheduler, stepsize must "
                "be a list, but got {}".format(type(stepsize))
            )

        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=stepsize, gamma=gamma
        )

    elif lr_scheduler == "warmup_multi_step":
        if not isinstance(stepsize, list):
            raise TypeError(
                "For warmup multi_step lr_scheduler, stepsize must "
                "be a list, but got {}".format(type(stepsize))
            )

        scheduler = WarmupMultiStepLR(
            optimizer,
            milestones=stepsize,
            gamma=gamma,
            warmup_factor=warmup_factor,
            warmup_iters=warmup_steps,
        )

    elif lr_scheduler == "cosine":
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, int(max_epoch)
        )

    elif lr_scheduler == "linear":
        def lambda_rule(epoch):
            # multiplier stays 1.0 for n_epochs_init epochs, then falls linearly,
            # reaching 0 at epoch n_epochs_init + n_epochs_decay + 1
            lr_l = 1.0 - max(0, epoch - n_epochs_init) / float(n_epochs_decay + 1)
            return lr_l

        scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer,
            lr_lambda=lambda_rule
        )

    return scheduler

--------------------------------------------------------------------------------
/extend_sam/extend_sam.py:
--------------------------------------------------------------------------------
# copyright ziqi-jin
import torch
import torch.nn as nn
from .segment_anything_ori import sam_model_registry
from .image_encoder_adapter import BaseImgEncodeAdapter
from .mask_decoder_adapter import BaseMaskDecoderAdapter, SemMaskDecoderAdapter
from .prompt_encoder_adapter import BasePromptEncodeAdapter


class BaseExtendSam(nn.Module):

    def __init__(self, ckpt_path=None, fix_img_en=False, fix_prompt_en=False, fix_mask_de=False, model_type='vit_b'):
        super(BaseExtendSam, self).__init__()
        assert model_type in ['default', 'vit_b', 'vit_l', 'vit_h'], \
            "Wrong model_type, SAM can only be built as vit_b, vit_l, vit_h or default."
        self.ori_sam = sam_model_registry[model_type](ckpt_path)
        self.img_adapter = BaseImgEncodeAdapter(self.ori_sam, fix=fix_img_en)
        self.prompt_adapter = BasePromptEncodeAdapter(self.ori_sam, fix=fix_prompt_en)
        self.mask_adapter = BaseMaskDecoderAdapter(self.ori_sam, fix=fix_mask_de)

    def forward(self, img):
        x = self.img_adapter(img)
        # prompt-free setting: no points, boxes, or masks are passed to the prompt encoder
        points = None
        boxes = None
        masks = None

        sparse_embeddings, dense_embeddings = self.prompt_adapter(
            points=points,
            boxes=boxes,
            masks=masks,
        )
        multimask_output = True
        low_res_masks, iou_predictions = self.mask_adapter(
            image_embeddings=x,
            prompt_adapter=self.prompt_adapter,
            sparse_embeddings=sparse_embeddings,
            dense_embeddings=dense_embeddings,
            multimask_output=multimask_output,
        )
        return low_res_masks, iou_predictions


class SemanticSam(BaseExtendSam):

    def __init__(self, ckpt_path=None, fix_img_en=False, fix_prompt_en=False, fix_mask_de=False, class_num=20,
                 model_type='vit_b'):
        super().__init__(ckpt_path=ckpt_path, fix_img_en=fix_img_en, fix_prompt_en=fix_prompt_en,
                         fix_mask_de=fix_mask_de, model_type=model_type)
        self.mask_adapter = SemMaskDecoderAdapter(self.ori_sam, fix=fix_mask_de, class_num=class_num)

--------------------------------------------------------------------------------
/extend_sam/image_encoder_adapter.py:
--------------------------------------------------------------------------------
import torch.nn as nn
from .segment_anything_ori.modeling.sam import Sam
from .utils import fix_params


class BaseImgEncodeAdapter(nn.Module):

    def __init__(self, ori_sam: Sam, fix=False):
        super(BaseImgEncodeAdapter, self).__init__()
        self.sam_img_encoder = ori_sam.image_encoder
        if fix:
            fix_params(self.sam_img_encoder)

    def forward(self, x):
        x = self.sam_img_encoder(x)
        return x

--------------------------------------------------------------------------------
/extend_sam/mask_decoder_adapter.py:
--------------------------------------------------------------------------------
# @copyright ziqi-jin

import torch.nn as nn
import torch
from .segment_anything_ori.modeling.sam import Sam
from .utils import fix_params
from .segment_anything_ori.modeling.mask_decoder import MaskDecoder
from typing import List, Tuple
from torch.nn import functional as F
from .mask_decoder_heads import SemSegHead
from .mask_decoder_neck import MaskDecoderNeck


class BaseMaskDecoderAdapter(MaskDecoder):
    '''
    multimask_output (bool): If true, the model will return three masks.
    For ambiguous input prompts (such as a single click), this will often
    produce better masks than a single prediction. If only a single
    mask is needed, the model's predicted quality score can be used
    to select the best mask.
    For non-ambiguous prompts, such as multiple
    input prompts, multimask_output=False can give better results.
    '''

    # optionally freeze the parameters of the original SAM mask decoder
    def __init__(self, ori_sam: Sam, fix=False):
        super(BaseMaskDecoderAdapter, self).__init__(transformer_dim=ori_sam.mask_decoder.transformer_dim,
                                                     transformer=ori_sam.mask_decoder.transformer)
        self.sam_mask_decoder = ori_sam.mask_decoder
        if fix:
            fix_params(self.sam_mask_decoder)  # TODO: move this to the runner

    def forward(self, image_embeddings, prompt_adapter, sparse_embeddings, dense_embeddings, multimask_output=True):
        low_res_masks, iou_predictions = self.sam_mask_decoder(image_embeddings=image_embeddings,
                                                               image_pe=prompt_adapter.sam_prompt_encoder.get_dense_pe(),
                                                               sparse_prompt_embeddings=sparse_embeddings,
                                                               dense_prompt_embeddings=dense_embeddings,
                                                               multimask_output=multimask_output, )
        return low_res_masks, iou_predictions


class SemMaskDecoderAdapter(BaseMaskDecoderAdapter):
    def __init__(self, ori_sam: Sam, fix=False, class_num=20):
        super(SemMaskDecoderAdapter, self).__init__(ori_sam, fix)
        self.decoder_neck = MaskDecoderNeck(transformer_dim=self.sam_mask_decoder.transformer_dim,
                                            transformer=self.sam_mask_decoder.transformer,
                                            num_multimask_outputs=self.sam_mask_decoder.num_multimask_outputs)
        self.decoder_head = SemSegHead(transformer_dim=self.sam_mask_decoder.transformer_dim,
                                       num_multimask_outputs=self.sam_mask_decoder.num_multimask_outputs,
                                       iou_head_depth=self.sam_mask_decoder.iou_head_depth,
                                       iou_head_hidden_dim=self.sam_mask_decoder.iou_head_hidden_dim,
                                       class_num=class_num)
        # initialize the new neck and head with the identically named parameters
        # of the original SAM mask decoder
        self.pair_params(self.decoder_neck)
        self.pair_params(self.decoder_head)

    def forward(self, image_embeddings, prompt_adapter, sparse_embeddings, dense_embeddings, multimask_output=True,
                scale=1):
        src, \
iou_token_out, mask_tokens_out, src_shape = self.decoder_neck(image_embeddings=image_embeddings, 59 | image_pe=prompt_adapter.sam_prompt_encoder.get_dense_pe(), 60 | sparse_prompt_embeddings=sparse_embeddings, 61 | dense_prompt_embeddings=dense_embeddings, 62 | multimask_output=multimask_output, ) 63 | masks, iou_pred = self.decoder_head(src, iou_token_out, mask_tokens_out, src_shape, mask_scale=scale) 64 | return masks, iou_pred 65 | 66 | def pair_params(self, target_model: nn.Module): 67 | src_dict = self.sam_mask_decoder.state_dict() 68 | for name, value in target_model.named_parameters(): 69 | if name in src_dict.keys(): 70 | value.data.copy_(src_dict[name].data) 71 | 72 | 73 | # Lightly adapted from 74 | # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa 75 | class MLP(nn.Module): 76 | def __init__( 77 | self, 78 | input_dim: int, 79 | hidden_dim: int, 80 | output_dim: int, 81 | num_layers: int, 82 | sigmoid_output: bool = False, 83 | ) -> None: 84 | super().__init__() 85 | self.num_layers = num_layers 86 | h = [hidden_dim] * (num_layers - 1) 87 | self.layers = nn.ModuleList( 88 | nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) 89 | ) 90 | self.sigmoid_output = sigmoid_output 91 | 92 | def forward(self, x): 93 | for i, layer in enumerate(self.layers): 94 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) 95 | if self.sigmoid_output: 96 | x = F.sigmoid(x) 97 | return x 98 | -------------------------------------------------------------------------------- /extend_sam/mask_decoder_heads.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | from typing import List, Tuple, Type 6 | 7 | from .segment_anything_ori.modeling.common import LayerNorm2d 8 | 9 | 10 | class OriHead(nn.Module): 11 | 12 | def __init__( 13 | self, 14 | *, 15 | 
transformer_dim: int, 16 | num_multimask_outputs: int = 3, 17 | activation: Type[nn.Module] = nn.GELU, 18 | iou_head_depth: int = 3, 19 | iou_head_hidden_dim: int = 256, 20 | ) -> None: 21 | """ 22 | Predicts masks given an image and prompt embeddings, using a 23 | transformer architecture. 24 | 25 | Arguments: 26 | transformer_dim (int): the channel dimension of the transformer 27 | num_multimask_outputs (int): the number of masks to predict 28 | when disambiguating masks 29 | activation (nn.Module): the type of activation to use when 30 | upscaling masks 31 | iou_head_depth (int): the depth of the MLP used to predict 32 | mask quality 33 | iou_head_hidden_dim (int): the hidden dimension of the MLP 34 | used to predict mask quality 35 | """ 36 | super().__init__() 37 | self.transformer_dim = transformer_dim 38 | 39 | self.num_multimask_outputs = num_multimask_outputs 40 | 41 | self.num_mask_tokens = num_multimask_outputs + 1 42 | 43 | self.output_upscaling = nn.Sequential( 44 | nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), 45 | LayerNorm2d(transformer_dim // 4), 46 | activation(), 47 | nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), 48 | activation(), 49 | ) 50 | self.output_hypernetworks_mlps = nn.ModuleList( 51 | [ 52 | MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) 53 | for i in range(self.num_mask_tokens) 54 | ] 55 | ) 56 | 57 | self.iou_prediction_head = MLP( 58 | transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth 59 | ) 60 | 61 | def forward( 62 | self, 63 | src: torch.Tensor, 64 | iou_token_out: torch.Tensor, 65 | mask_tokens_out: torch.Tensor, 66 | src_shape: torch.Size, 67 | multimask_output: bool, 68 | ) -> Tuple[torch.Tensor, torch.Tensor]: 69 | """Predict masks given image and prompt embeddings.
70 | 71 | Arguments: 72 | src (torch.Tensor): image embedding output by the neck's transformer, shaped (B, HW, C) 73 | iou_token_out (torch.Tensor): IoU prediction token from the neck module 74 | mask_tokens_out (torch.Tensor): mask prediction tokens from the neck module 75 | src_shape (torch.Size): shape (B, C, H, W) of src before the transformer flattened it 76 | multimask_output (bool): Whether to return multiple masks or a single 77 | mask. 78 | 79 | Returns: 80 | torch.Tensor: batched predicted masks 81 | torch.Tensor: batched predictions of mask quality 82 | """ 83 | b, c, h, w = src_shape 84 | 85 | # Upscale mask embeddings and predict masks using the mask tokens 86 | src = src.transpose(1, 2).view(b, c, h, w) 87 | upscaled_embedding = self.output_upscaling(src) 88 | hyper_in_list: List[torch.Tensor] = [] 89 | for i in range(self.num_mask_tokens): 90 | hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])) 91 | hyper_in = torch.stack(hyper_in_list, dim=1) 92 | b, c, h, w = upscaled_embedding.shape 93 | masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) 94 | 95 | # Generate mask quality predictions 96 | iou_pred = self.iou_prediction_head(iou_token_out) 97 | 98 | # Select the correct mask or masks for output 99 | if multimask_output: 100 | mask_slice = slice(1, None) 101 | else: 102 | mask_slice = slice(0, 1) 103 | masks = masks[:, mask_slice, :, :] 104 | iou_pred = iou_pred[:, mask_slice] 105 | 106 | # Prepare output 107 | return masks, iou_pred 108 | 109 | 110 | class SemSegHead(nn.Module): 111 | 112 | def __init__( 113 | self, 114 | *, 115 | transformer_dim: int, 116 | num_multimask_outputs: int = 3, 117 | activation: Type[nn.Module] = nn.GELU, 118 | iou_head_depth: int = 3, 119 | iou_head_hidden_dim: int = 256, 120 | class_num: int = 20, 121 | ) -> None: 122 | """ 123 | Predicts masks given an image and prompt embeddings, using a 124 | transformer architecture.
125 | 126 | Arguments: 127 | transformer_dim (int): the channel dimension of the transformer 128 | num_multimask_outputs (int): the number of masks to predict 129 | when disambiguating masks 130 | activation (nn.Module): the type of activation to use when 131 | upscaling masks 132 | iou_head_depth (int): the depth of the MLP used to predict 133 | mask quality 134 | iou_head_hidden_dim (int): the hidden dimension of the mask-quality MLP 135 | class_num (int): the number of semantic classes (one hypernetwork MLP per class) 136 | """ 137 | super().__init__() 138 | self.transformer_dim = transformer_dim 139 | self.num_multimask_outputs = num_multimask_outputs 140 | self.num_mask_tokens = num_multimask_outputs + 1 141 | self.class_num = class_num 142 | 143 | self.output_upscaling = nn.Sequential( 144 | nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), 145 | LayerNorm2d(transformer_dim // 4), 146 | activation(), 147 | nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), 148 | activation(), 149 | ) 150 | 151 | self.output_hypernetworks_mlps = nn.ModuleList( 152 | [ 153 | MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) 154 | for _ in range(self.class_num) 155 | ] 156 | ) 157 | 158 | self.iou_prediction_head = MLP( 159 | transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth 160 | ) 161 | 162 | def forward( 163 | self, 164 | src: torch.Tensor, 165 | iou_token_out: torch.Tensor, 166 | mask_tokens_out: torch.Tensor, 167 | src_shape, 168 | mask_scale=1, 169 | ) -> Tuple[torch.Tensor, torch.Tensor]: 170 | """ 171 | Predict masks given image and prompt embeddings.
172 | 173 | Arguments: 174 | src (torch.Tensor): image embedding output by the neck's transformer, shaped (B, HW, C) 175 | iou_token_out (torch.Tensor): IoU prediction token from the neck module 176 | mask_tokens_out (torch.Tensor): mask prediction tokens from the neck module 177 | src_shape (torch.Size): shape (B, C, H, W) of src before the transformer flattened it 178 | mask_scale (int): index of the mask token to use. Original SAM outputs 3 masks, ordered from local 179 | to global by default; this class turns the selected token into class-aware semantic segmentation predictions 180 | 181 | Returns: 182 | torch.Tensor: batched predicted semantic masks 183 | torch.Tensor: batched predictions of mask quality 184 | """ 185 | b, c, h, w = src_shape 186 | 187 | # Upscale mask embeddings and predict masks using the mask tokens 188 | src = src.transpose(1, 2).view(b, c, h, w) 189 | upscaled_embedding = self.output_upscaling(src) 190 | hyper_in_list: List[torch.Tensor] = [] 191 | for i in range(self.class_num): 192 | hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, mask_scale, :])) 193 | hyper_in = torch.stack(hyper_in_list, dim=1) 194 | 195 | b, c, h, w = upscaled_embedding.shape 196 | masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) # B x N x H x W, N is the number of categories 197 | 198 | # Generate mask quality predictions 199 | iou_pred = self.iou_prediction_head(iou_token_out) # B x num_mask_tokens, one score per SAM mask token 200 | 201 | return masks, iou_pred 202 | 203 | 204 | # Lightly adapted from 205 | # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa 206 | class MLP(nn.Module): 207 | def __init__( 208 | self, 209 | input_dim: int, 210 | hidden_dim: int, 211 | output_dim: int, 212 | num_layers: int, 213 | sigmoid_output: bool = False, 214 | ) -> None: 215 | super().__init__() 216 | self.num_layers = num_layers 217 | h = [hidden_dim] * (num_layers - 1) 218 | self.layers = nn.ModuleList( 219 | nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) 220 | ) 221 |
self.sigmoid_output = sigmoid_output 222 | 223 | def forward(self, x): 224 | for i, layer in enumerate(self.layers): 225 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) 226 | if self.sigmoid_output: 227 | x = F.sigmoid(x) 228 | return x 229 | -------------------------------------------------------------------------------- /extend_sam/mask_decoder_neck.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | 11 | from typing import List, Tuple, Type 12 | from .segment_anything_ori.modeling.common import LayerNorm2d 13 | 14 | ''' 15 | This file contains the mask decoder's neck class, 16 | i.e. the first part of SAM's original mask decoder. 17 | The heads in mask_decoder_heads are applied on top of this neck. 18 | ''' 19 | 20 | 21 | class MaskDecoderNeck(nn.Module): 22 | def __init__( 23 | self, 24 | *, 25 | transformer_dim: int, 26 | transformer: nn.Module, 27 | num_multimask_outputs: int = 3, 28 | activation: Type[nn.Module] = nn.GELU, 29 | ) -> None: 30 | """ 31 | Predicts masks given an image and prompt embeddings, using a 32 | transformer architecture.
33 | 34 | Arguments: 35 | transformer_dim (int): the channel dimension of the transformer 36 | transformer (nn.Module): the transformer used to predict masks 37 | num_multimask_outputs (int): the number of masks to predict 38 | when disambiguating masks 39 | activation (nn.Module): the type of activation to use when 40 | upscaling masks 41 | """ 42 | super().__init__() 43 | self.transformer_dim = transformer_dim 44 | self.transformer = transformer 45 | 46 | self.num_multimask_outputs = num_multimask_outputs 47 | 48 | self.iou_token = nn.Embedding(1, transformer_dim) 49 | self.num_mask_tokens = num_multimask_outputs + 1 50 | self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) 51 | 52 | self.output_upscaling = nn.Sequential( 53 | nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), 54 | LayerNorm2d(transformer_dim // 4), 55 | activation(), 56 | nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), 57 | activation(), 58 | ) 59 | 60 | def forward( 61 | self, 62 | image_embeddings: torch.Tensor, 63 | image_pe: torch.Tensor, 64 | sparse_prompt_embeddings: torch.Tensor, 65 | dense_prompt_embeddings: torch.Tensor, 66 | multimask_output: bool, 67 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Size]: 68 | """ 69 | Run the transformer on image and prompt embeddings to produce the inputs of a mask decoder head. 70 | 71 | Arguments: 72 | image_embeddings (torch.Tensor): the embeddings from the image encoder 73 | image_pe (torch.Tensor): positional encoding with the shape of image_embeddings 74 | sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes 75 | dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs 76 | multimask_output (bool): Whether to return multiple masks or a single 77 | mask. Not used by the neck itself.
78 | 79 | Returns: 80 | torch.Tensor: image embedding updated by the transformer, shaped (B, HW, C) 81 | torch.Tensor: output token for IoU prediction 82 | torch.Tensor: output tokens for mask prediction 83 | torch.Size: shape (B, C, H, W) of src before the transformer flattened it 84 | """ 85 | # Concatenate output tokens 86 | output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) 87 | output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1) 88 | tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) 89 | 90 | # Expand per-image data in batch direction to be per-mask 91 | src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) 92 | src = src + dense_prompt_embeddings 93 | pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) 94 | src_shape = src.shape 95 | # Run the transformer 96 | hs, src = self.transformer(src, pos_src, tokens) 97 | iou_token_out = hs[:, 0, :] 98 | mask_tokens_out = hs[:, 1: (1 + self.num_mask_tokens), :] 99 | 100 | return src, iou_token_out, mask_tokens_out, src_shape 101 | -------------------------------------------------------------------------------- /extend_sam/prompt_encoder_adapter.py: -------------------------------------------------------------------------------- 1 | # copyright ziqi-jin 2 | 3 | import torch.nn as nn 4 | from .segment_anything_ori.modeling.sam import Sam 5 | from .utils import fix_params 6 | 7 | 8 | class BasePromptEncodeAdapter(nn.Module): 9 | 10 | def __init__(self, ori_sam: Sam, fix=False): 11 | super(BasePromptEncodeAdapter, self).__init__() 12 | 13 | self.sam_prompt_encoder = ori_sam.prompt_encoder 14 | if fix: 15 | fix_params(self.sam_prompt_encoder) 16 | 17 | def forward(self, points=None, boxes=None, masks=None): 18 | sparse_embeddings, dense_embeddings = self.sam_prompt_encoder(points, boxes, masks) 19 | return sparse_embeddings, dense_embeddings 20 | -------------------------------------------------------------------------------- /extend_sam/runner.py:
-------------------------------------------------------------------------------- 1 | from datasets import Iterator 2 | from .utils import Average_Meter, Timer, print_and_save_log, mIoUOnline, get_numpy_from_tensor, save_model, write_log, \ 3 | check_folder, one_hot_embedding_3d 4 | import torch 5 | import cv2 6 | import torch.nn.functional as F 7 | import os 8 | import torch.nn as nn 9 | 10 | 11 | class BaseRunner(): 12 | def __init__(self, model, optimizer, losses, train_loader, val_loader, scheduler): 13 | self.optimizer = optimizer 14 | self.losses = losses 15 | self.train_loader = train_loader 16 | self.val_loader = val_loader 17 | self.model = model 18 | self.scheduler = scheduler 19 | self.train_timer = Timer() 20 | self.eval_timer = Timer() 21 | try: 22 | use_gpu = os.environ['CUDA_VISIBLE_DEVICES'] 23 | except KeyError: 24 | use_gpu = '0' 25 | self.the_number_of_gpu = len(use_gpu.split(',')) 26 | self.original_size = self.model.img_adapter.sam_img_encoder.img_size 27 | if self.the_number_of_gpu > 1: 28 | self.model = nn.DataParallel(self.model) 29 | 30 | 31 | class SemRunner(BaseRunner): 32 | # def __init__(self, **kwargs): 33 | # super().__init__(kwargs) 34 | 35 | def __init__(self, model, optimizer, losses, train_loader, val_loader, scheduler): 36 | super().__init__(model, optimizer, losses, train_loader, val_loader, scheduler) 37 | self.exist_status = ['train', 'eval', 'test'] 38 | 39 | def train(self, cfg): 40 | # initialize loss meters, the data iterator, and output paths 41 | train_meter = Average_Meter(list(self.losses.keys()) + ['total_loss']) 42 | train_iterator = Iterator(self.train_loader) 43 | best_valid_mIoU = -1 44 | model_path = "{cfg.model_folder}/{cfg.experiment_name}/model.pth".format(cfg=cfg) 45 | log_path = "{cfg.log_folder}/{cfg.experiment_name}/log_file.txt".format(cfg=cfg) 46 | check_folder(model_path) 47 | check_folder(log_path) 48 | writer = None 49 | if cfg.use_tensorboard is True: 50 | tensorboard_dir =
"{cfg.tensorboard_folder}/{cfg.experiment_name}/tensorboard/".format(cfg=cfg) 51 | from torch.utils.tensorboard import SummaryWriter 52 | writer = SummaryWriter(tensorboard_dir) 53 | # train 54 | for iteration in range(cfg.max_iter): 55 | images, labels = train_iterator.get() 56 | images, labels = images.cuda(), labels.cuda().long() 57 | masks_pred, iou_pred = self.model(images) 58 | masks_pred = F.interpolate(masks_pred, self.original_size, mode="bilinear", align_corners=False) 59 | 60 | total_loss = torch.zeros(1).cuda() 61 | loss_dict = {} 62 | self._compute_loss(total_loss, loss_dict, masks_pred, labels, cfg) 63 | self.optimizer.zero_grad() 64 | total_loss.backward() 65 | self.optimizer.step() 66 | self.scheduler.step() 67 | loss_dict['total_loss'] = total_loss.item() 68 | train_meter.add(loss_dict) 69 | 70 | # log 71 | if (iteration + 1) % cfg.log_iter == 0: 72 | write_log(iteration=iteration, log_path=log_path, log_data=train_meter.get(clear=True), 73 | status=self.exist_status[0], 74 | writer=writer, timer=self.train_timer) 75 | # eval 76 | if (iteration + 1) % cfg.eval_iter == 0: 77 | mIoU, _ = self._eval() 78 | if best_valid_mIoU == -1 or best_valid_mIoU < mIoU: 79 | best_valid_mIoU = mIoU 80 | save_model(self.model, model_path, parallel=self.the_number_of_gpu > 1) 81 | print_and_save_log("saved model in {model_path}".format(model_path=model_path), path=log_path) 82 | log_data = {'mIoU': mIoU, 'best_valid_mIoU': best_valid_mIoU} 83 | write_log(iteration=iteration, log_path=log_path, log_data=log_data, status=self.exist_status[1], 84 | writer=writer, timer=self.eval_timer) 85 | # final process 86 | save_model(self.model, model_path, is_final=True, parallel=self.the_number_of_gpu > 1) 87 | if writer is not None: 88 | writer.close() 89 | 90 | def test(self): 91 | pass 92 | 93 | def _eval(self): 94 | self.model.eval() 95 | self.eval_timer.start() 96 | class_names = self.val_loader.dataset.class_names 97 | eval_metric = mIoUOnline(class_names=class_names) 98 | 
with torch.no_grad(): 99 | for index, (images, labels) in enumerate(self.val_loader): 100 | images = images.cuda() 101 | labels = labels.cuda() 102 | masks_pred, iou_pred = self.model(images) 103 | predictions = torch.argmax(masks_pred, dim=1) 104 | for batch_index in range(images.size()[0]): 105 | pred_mask = get_numpy_from_tensor(predictions[batch_index]) 106 | gt_mask = get_numpy_from_tensor(labels[batch_index].squeeze(0)) 107 | h, w = pred_mask.shape 108 | gt_mask = cv2.resize(gt_mask, (w, h), interpolation=cv2.INTER_NEAREST) 109 | 110 | eval_metric.add(pred_mask, gt_mask) 111 | self.model.train() 112 | return eval_metric.get(clear=True) 113 | 114 | def _compute_loss(self, total_loss, loss_dict, mask_pred, labels, cfg): 115 | """ 116 | Losses take different inputs (e.g. some expect one-hot labels), so adding a new loss 117 | may require modifying this function. 118 | """ 119 | loss_cfg = cfg.losses 120 | for index, item in enumerate(self.losses.items()): 121 | # item -> (key: loss_name, val: loss) 122 | real_labels = labels 123 | if loss_cfg[item[0]].label_one_hot: 124 | class_num = cfg.model.params.class_num 125 | real_labels = one_hot_embedding_3d(real_labels, class_num=class_num) 126 | tmp_loss = item[1](mask_pred, real_labels) 127 | loss_dict[item[0]] = tmp_loss.item() 128 | total_loss += loss_cfg[item[0]].weight * tmp_loss 129 | -------------------------------------------------------------------------------- /extend_sam/scheduler.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/KaiyangZhou/deep-person-reid/blob/master/torchreid/optim/lr_scheduler.py # noqa 2 | # and https://github.com/JDAI-CV/fast-reid/blob/master/fastreid/solver/lr_scheduler.py 3 | 4 | from bisect import bisect_right 5 | from typing import List 6 | 7 | import torch 8 | from torch.optim.lr_scheduler import _LRScheduler 9 | 10 | 11 | class WarmupMultiStepLR(_LRScheduler): 12 | def __init__( 13 | self, 14 |
optimizer: torch.optim.Optimizer, 15 | milestones: List[int], 16 | gamma: float = 0.1, 17 | warmup_factor: float = 0.001, 18 | warmup_iters: int = 1000, 19 | warmup_method: str = "linear", 20 | last_epoch: int = -1, 21 | **kwargs, 22 | ): 23 | if list(milestones) != sorted(milestones): 24 | raise ValueError( 25 | "Milestones should be a list of increasing integers. " 26 | f"Got {milestones}", 27 | ) 28 | self.milestones = milestones 29 | self.gamma = gamma 30 | self.warmup_factor = warmup_factor 31 | self.warmup_iters = warmup_iters 32 | self.warmup_method = warmup_method 33 | super().__init__(optimizer, last_epoch) 34 | 35 | def get_lr(self) -> List[float]: 36 | warmup_factor = _get_warmup_factor_at_iter( 37 | self.warmup_method, self.last_epoch, self.warmup_iters, self.warmup_factor 38 | ) 39 | return [ 40 | base_lr 41 | * warmup_factor 42 | * self.gamma ** bisect_right(self.milestones, self.last_epoch) 43 | for base_lr in self.base_lrs 44 | ] 45 | 46 | def _compute_values(self) -> List[float]: 47 | # The new interface 48 | return self.get_lr() 49 | 50 | 51 | def _get_warmup_factor_at_iter( 52 | method: str, iter: int, warmup_iters: int, warmup_factor: float 53 | ) -> float: 54 | """ 55 | Return the learning rate warmup factor at a specific iteration. 56 | See https://arxiv.org/abs/1706.02677 for more details. 57 | Args: 58 | method (str): warmup method; either "constant" or "linear". 59 | iter (int): iteration at which to calculate the warmup factor. 60 | warmup_iters (int): the number of warmup iterations. 61 | warmup_factor (float): the base warmup factor (the meaning changes according 62 | to the method used). 63 | Returns: 64 | float: the effective warmup factor at the given iteration.
65 | """ 66 | if iter >= warmup_iters: 67 | return 1.0 68 | 69 | if method == "constant": 70 | return warmup_factor 71 | elif method == "linear": 72 | alpha = iter / warmup_iters 73 | return warmup_factor * (1 - alpha) + alpha 74 | else: 75 | raise ValueError("Unknown warmup method: {}".format(method)) 76 | -------------------------------------------------------------------------------- /extend_sam/segment_anything_ori/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # modified by ziqi-jin 8 | 9 | from .build_sam import ( 10 | build_sam, 11 | build_sam_vit_h, 12 | build_sam_vit_l, 13 | build_sam_vit_b, 14 | sam_model_registry, 15 | ) 16 | from .modeling.sam import Sam 17 | from .predictor import SamPredictor 18 | from .automatic_mask_generator import SamAutomaticMaskGenerator 19 | -------------------------------------------------------------------------------- /extend_sam/segment_anything_ori/automatic_mask_generator.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import numpy as np 8 | import torch 9 | from torchvision.ops.boxes import batched_nms, box_area # type: ignore 10 | 11 | from typing import Any, Dict, List, Optional, Tuple 12 | 13 | from .modeling import Sam 14 | from .predictor import SamPredictor 15 | from .utils.amg import ( 16 | MaskData, 17 | area_from_rle, 18 | batch_iterator, 19 | batched_mask_to_box, 20 | box_xyxy_to_xywh, 21 | build_all_layer_point_grids, 22 | calculate_stability_score, 23 | coco_encode_rle, 24 | generate_crop_boxes, 25 | is_box_near_crop_edge, 26 | mask_to_rle_pytorch, 27 | remove_small_regions, 28 | rle_to_mask, 29 | uncrop_boxes_xyxy, 30 | uncrop_masks, 31 | uncrop_points, 32 | ) 33 | 34 | 35 | class SamAutomaticMaskGenerator: 36 | def __init__( 37 | self, 38 | model: Sam, 39 | points_per_side: Optional[int] = 32, 40 | points_per_batch: int = 64, 41 | pred_iou_thresh: float = 0.88, 42 | stability_score_thresh: float = 0.95, 43 | stability_score_offset: float = 1.0, 44 | box_nms_thresh: float = 0.7, 45 | crop_n_layers: int = 0, 46 | crop_nms_thresh: float = 0.7, 47 | crop_overlap_ratio: float = 512 / 1500, 48 | crop_n_points_downscale_factor: int = 1, 49 | point_grids: Optional[List[np.ndarray]] = None, 50 | min_mask_region_area: int = 0, 51 | output_mode: str = "binary_mask", 52 | ) -> None: 53 | """ 54 | Using a SAM model, generates masks for the entire image. 55 | Generates a grid of point prompts over the image, then filters 56 | low quality and duplicate masks. The default settings are chosen 57 | for SAM with a ViT-H backbone. 58 | 59 | Arguments: 60 | model (Sam): The SAM model to use for mask prediction. 61 | points_per_side (int or None): The number of points to be sampled 62 | along one side of the image. The total number of points is 63 | points_per_side**2. If None, 'point_grids' must provide explicit 64 | point sampling. 65 | points_per_batch (int): Sets the number of points run simultaneously 66 | by the model. Higher numbers may be faster but use more GPU memory. 
67 | pred_iou_thresh (float): A filtering threshold in [0,1], using the 68 | model's predicted mask quality. 69 | stability_score_thresh (float): A filtering threshold in [0,1], using 70 | the stability of the mask under changes to the cutoff used to binarize 71 | the model's mask predictions. 72 | stability_score_offset (float): The amount to shift the cutoff when 73 | calculating the stability score. 74 | box_nms_thresh (float): The box IoU cutoff used by non-maximal 75 | suppression to filter duplicate masks. 76 | crop_n_layers (int): If >0, mask prediction will be run again on 77 | crops of the image. Sets the number of layers to run, where each 78 | layer has 2**i_layer number of image crops. 79 | crop_nms_thresh (float): The box IoU cutoff used by non-maximal 80 | suppression to filter duplicate masks between different crops. 81 | crop_overlap_ratio (float): Sets the degree to which crops overlap. 82 | In the first crop layer, crops will overlap by this fraction of 83 | the image length. Later layers with more crops scale down this overlap. 84 | crop_n_points_downscale_factor (int): The number of points-per-side 85 | sampled in layer n is scaled down by crop_n_points_downscale_factor**n. 86 | point_grids (list(np.ndarray) or None): A list of explicit grids 87 | of points used for sampling, normalized to [0,1]. The nth grid in the 88 | list is used in the nth crop layer. Exclusive with points_per_side. 89 | min_mask_region_area (int): If >0, postprocessing will be applied 90 | to remove disconnected regions and holes in masks with area smaller 91 | than min_mask_region_area. Requires opencv. 92 | output_mode (str): The form masks are returned in. Can be 'binary_mask', 93 | 'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools. 94 | For large resolutions, 'binary_mask' may consume large amounts of 95 | memory.
96 | """ 97 | 98 | assert (points_per_side is None) != ( 99 | point_grids is None 100 | ), "Exactly one of points_per_side or point_grid must be provided." 101 | if points_per_side is not None: 102 | self.point_grids = build_all_layer_point_grids( 103 | points_per_side, 104 | crop_n_layers, 105 | crop_n_points_downscale_factor, 106 | ) 107 | elif point_grids is not None: 108 | self.point_grids = point_grids 109 | else: 110 | raise ValueError("Can't have both points_per_side and point_grid be None.") 111 | 112 | assert output_mode in [ 113 | "binary_mask", 114 | "uncompressed_rle", 115 | "coco_rle", 116 | ], f"Unknown output_mode {output_mode}." 117 | if output_mode == "coco_rle": 118 | from pycocotools import mask as mask_utils # type: ignore # noqa: F401 119 | 120 | if min_mask_region_area > 0: 121 | import cv2 # type: ignore # noqa: F401 122 | 123 | self.predictor = SamPredictor(model) 124 | self.points_per_batch = points_per_batch 125 | self.pred_iou_thresh = pred_iou_thresh 126 | self.stability_score_thresh = stability_score_thresh 127 | self.stability_score_offset = stability_score_offset 128 | self.box_nms_thresh = box_nms_thresh 129 | self.crop_n_layers = crop_n_layers 130 | self.crop_nms_thresh = crop_nms_thresh 131 | self.crop_overlap_ratio = crop_overlap_ratio 132 | self.crop_n_points_downscale_factor = crop_n_points_downscale_factor 133 | self.min_mask_region_area = min_mask_region_area 134 | self.output_mode = output_mode 135 | 136 | @torch.no_grad() 137 | def generate(self, image: np.ndarray) -> List[Dict[str, Any]]: 138 | """ 139 | Generates masks for the given image. 140 | 141 | Arguments: 142 | image (np.ndarray): The image to generate masks for, in HWC uint8 format. 143 | 144 | Returns: 145 | list(dict(str, any)): A list over records for masks. Each record is 146 | a dict containing the following keys: 147 | segmentation (dict(str, any) or np.ndarray): The mask. If 148 | output_mode='binary_mask', is an array of shape HW. 
Otherwise, 149 | is a dictionary containing the RLE. 150 | bbox (list(float)): The box around the mask, in XYWH format. 151 | area (int): The area in pixels of the mask. 152 | predicted_iou (float): The model's own prediction of the mask's 153 | quality. This is filtered by the pred_iou_thresh parameter. 154 | point_coords (list(list(float))): The point coordinates input 155 | to the model to generate this mask. 156 | stability_score (float): A measure of the mask's quality. This 157 | is filtered on using the stability_score_thresh parameter. 158 | crop_box (list(float)): The crop of the image used to generate 159 | the mask, given in XYWH format. 160 | """ 161 | 162 | # Generate masks 163 | mask_data = self._generate_masks(image) 164 | 165 | # Filter small disconnected regions and holes in masks 166 | if self.min_mask_region_area > 0: 167 | mask_data = self.postprocess_small_regions( 168 | mask_data, 169 | self.min_mask_region_area, 170 | max(self.box_nms_thresh, self.crop_nms_thresh), 171 | ) 172 | 173 | # Encode masks 174 | if self.output_mode == "coco_rle": 175 | mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]] 176 | elif self.output_mode == "binary_mask": 177 | mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]] 178 | else: 179 | mask_data["segmentations"] = mask_data["rles"] 180 | 181 | # Write mask records 182 | curr_anns = [] 183 | for idx in range(len(mask_data["segmentations"])): 184 | ann = { 185 | "segmentation": mask_data["segmentations"][idx], 186 | "area": area_from_rle(mask_data["rles"][idx]), 187 | "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(), 188 | "predicted_iou": mask_data["iou_preds"][idx].item(), 189 | "point_coords": [mask_data["points"][idx].tolist()], 190 | "stability_score": mask_data["stability_score"][idx].item(), 191 | "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(), 192 | } 193 | curr_anns.append(ann) 194 | 195 | return curr_anns 196 | 197 | def 
_generate_masks(self, image: np.ndarray) -> MaskData: 198 | orig_size = image.shape[:2] 199 | crop_boxes, layer_idxs = generate_crop_boxes( 200 | orig_size, self.crop_n_layers, self.crop_overlap_ratio 201 | ) 202 | 203 | # Iterate over image crops 204 | data = MaskData() 205 | for crop_box, layer_idx in zip(crop_boxes, layer_idxs): 206 | crop_data = self._process_crop(image, crop_box, layer_idx, orig_size) 207 | data.cat(crop_data) 208 | 209 | # Remove duplicate masks between crops 210 | if len(crop_boxes) > 1: 211 | # Prefer masks from smaller crops 212 | scores = 1 / box_area(data["crop_boxes"]) 213 | scores = scores.to(data["boxes"].device) 214 | keep_by_nms = batched_nms( 215 | data["boxes"].float(), 216 | scores, 217 | torch.zeros(len(data["boxes"])), # categories 218 | iou_threshold=self.crop_nms_thresh, 219 | ) 220 | data.filter(keep_by_nms) 221 | 222 | data.to_numpy() 223 | return data 224 | 225 | def _process_crop( 226 | self, 227 | image: np.ndarray, 228 | crop_box: List[int], 229 | crop_layer_idx: int, 230 | orig_size: Tuple[int, ...], 231 | ) -> MaskData: 232 | # Crop the image and calculate embeddings 233 | x0, y0, x1, y1 = crop_box 234 | cropped_im = image[y0:y1, x0:x1, :] 235 | cropped_im_size = cropped_im.shape[:2] 236 | self.predictor.set_image(cropped_im) 237 | 238 | # Get points for this crop 239 | points_scale = np.array(cropped_im_size)[None, ::-1] 240 | points_for_image = self.point_grids[crop_layer_idx] * points_scale 241 | 242 | # Generate masks for this crop in batches 243 | data = MaskData() 244 | for (points,) in batch_iterator(self.points_per_batch, points_for_image): 245 | batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size) 246 | data.cat(batch_data) 247 | del batch_data 248 | self.predictor.reset_image() 249 | 250 | # Remove duplicates within this crop. 
251 | keep_by_nms = batched_nms( 252 | data["boxes"].float(), 253 | data["iou_preds"], 254 | torch.zeros(len(data["boxes"])), # categories 255 | iou_threshold=self.box_nms_thresh, 256 | ) 257 | data.filter(keep_by_nms) 258 | 259 | # Return to the original image frame 260 | data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box) 261 | data["points"] = uncrop_points(data["points"], crop_box) 262 | data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))]) 263 | 264 | return data 265 | 266 | def _process_batch( 267 | self, 268 | points: np.ndarray, 269 | im_size: Tuple[int, ...], 270 | crop_box: List[int], 271 | orig_size: Tuple[int, ...], 272 | ) -> MaskData: 273 | orig_h, orig_w = orig_size 274 | 275 | # Run model on this batch 276 | transformed_points = self.predictor.transform.apply_coords(points, im_size) 277 | in_points = torch.as_tensor(transformed_points, device=self.predictor.device) 278 | in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device) 279 | masks, iou_preds, _ = self.predictor.predict_torch( 280 | in_points[:, None, :], 281 | in_labels[:, None], 282 | multimask_output=True, 283 | return_logits=True, 284 | ) 285 | 286 | # Serialize predictions and store in MaskData 287 | data = MaskData( 288 | masks=masks.flatten(0, 1), 289 | iou_preds=iou_preds.flatten(0, 1), 290 | points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)), 291 | ) 292 | del masks 293 | 294 | # Filter by predicted IoU 295 | if self.pred_iou_thresh > 0.0: 296 | keep_mask = data["iou_preds"] > self.pred_iou_thresh 297 | data.filter(keep_mask) 298 | 299 | # Calculate stability score 300 | data["stability_score"] = calculate_stability_score( 301 | data["masks"], self.predictor.model.mask_threshold, self.stability_score_offset 302 | ) 303 | if self.stability_score_thresh > 0.0: 304 | keep_mask = data["stability_score"] >= self.stability_score_thresh 305 | data.filter(keep_mask) 306 | 307 | # Threshold masks and calculate boxes 308 | 
data["masks"] = data["masks"] > self.predictor.model.mask_threshold 309 | data["boxes"] = batched_mask_to_box(data["masks"]) 310 | 311 | # Filter boxes that touch crop boundaries 312 | keep_mask = ~is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h]) 313 | if not torch.all(keep_mask): 314 | data.filter(keep_mask) 315 | 316 | # Compress to RLE 317 | data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w) 318 | data["rles"] = mask_to_rle_pytorch(data["masks"]) 319 | del data["masks"] 320 | 321 | return data 322 | 323 | @staticmethod 324 | def postprocess_small_regions( 325 | mask_data: MaskData, min_area: int, nms_thresh: float 326 | ) -> MaskData: 327 | """ 328 | Removes small disconnected regions and holes in masks, then reruns 329 | box NMS to remove any new duplicates. 330 | 331 | Edits mask_data in place. 332 | 333 | Requires open-cv as a dependency. 334 | """ 335 | if len(mask_data["rles"]) == 0: 336 | return mask_data 337 | 338 | # Filter small disconnected regions and holes 339 | new_masks = [] 340 | scores = [] 341 | for rle in mask_data["rles"]: 342 | mask = rle_to_mask(rle) 343 | 344 | mask, changed = remove_small_regions(mask, min_area, mode="holes") 345 | unchanged = not changed 346 | mask, changed = remove_small_regions(mask, min_area, mode="islands") 347 | unchanged = unchanged and not changed 348 | 349 | new_masks.append(torch.as_tensor(mask).unsqueeze(0)) 350 | # Give score=0 to changed masks and score=1 to unchanged masks 351 | # so NMS will prefer ones that didn't need postprocessing 352 | scores.append(float(unchanged)) 353 | 354 | # Recalculate boxes and remove any new duplicates 355 | masks = torch.cat(new_masks, dim=0) 356 | boxes = batched_mask_to_box(masks) 357 | keep_by_nms = batched_nms( 358 | boxes.float(), 359 | torch.as_tensor(scores), 360 | torch.zeros(len(boxes)), # categories 361 | iou_threshold=nms_thresh, 362 | ) 363 | 364 | # Only recalculate RLEs for masks that have changed 365 | for i_mask in 
keep_by_nms: 366 | if scores[i_mask] == 0.0: 367 | mask_torch = masks[i_mask].unsqueeze(0) 368 | mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0] 369 | mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly 370 | mask_data.filter(keep_by_nms) 371 | 372 | return mask_data 373 | -------------------------------------------------------------------------------- /extend_sam/segment_anything_ori/build_sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # modified by ziqi-jin 8 | 9 | import torch 10 | 11 | from functools import partial 12 | 13 | from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer 14 | 15 | 16 | def build_sam_vit_h(checkpoint=None): 17 | return _build_sam( 18 | encoder_embed_dim=1280, 19 | encoder_depth=32, 20 | encoder_num_heads=16, 21 | encoder_global_attn_indexes=[7, 15, 23, 31], 22 | checkpoint=checkpoint, 23 | ) 24 | 25 | 26 | build_sam = build_sam_vit_h 27 | 28 | 29 | def build_sam_vit_l(checkpoint=None): 30 | return _build_sam( 31 | encoder_embed_dim=1024, 32 | encoder_depth=24, 33 | encoder_num_heads=16, 34 | encoder_global_attn_indexes=[5, 11, 17, 23], 35 | checkpoint=checkpoint, 36 | ) 37 | 38 | 39 | def build_sam_vit_b(checkpoint=None): 40 | return _build_sam( 41 | encoder_embed_dim=768, 42 | encoder_depth=12, 43 | encoder_num_heads=12, 44 | encoder_global_attn_indexes=[2, 5, 8, 11], 45 | checkpoint=checkpoint, 46 | ) 47 | 48 | 49 | sam_model_registry = { 50 | "default": build_sam_vit_h, 51 | "vit_h": build_sam_vit_h, 52 | "vit_l": build_sam_vit_l, 53 | "vit_b": build_sam_vit_b, 54 | } 55 | 56 | 57 | def _build_sam( 58 | encoder_embed_dim, 59 | encoder_depth, 60 | encoder_num_heads, 61 | encoder_global_attn_indexes, 62 | 
checkpoint=None, 63 | ): 64 | prompt_embed_dim = 256 65 | image_size = 1024 66 | vit_patch_size = 16 67 | image_embedding_size = image_size // vit_patch_size 68 | sam = Sam( 69 | image_encoder=ImageEncoderViT( 70 | depth=encoder_depth, 71 | embed_dim=encoder_embed_dim, 72 | img_size=image_size, 73 | mlp_ratio=4, 74 | norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), 75 | num_heads=encoder_num_heads, 76 | patch_size=vit_patch_size, 77 | qkv_bias=True, 78 | use_rel_pos=True, 79 | global_attn_indexes=encoder_global_attn_indexes, 80 | window_size=14, 81 | out_chans=prompt_embed_dim, 82 | ), 83 | prompt_encoder=PromptEncoder( 84 | embed_dim=prompt_embed_dim, 85 | image_embedding_size=(image_embedding_size, image_embedding_size), 86 | input_image_size=(image_size, image_size), 87 | mask_in_chans=16, 88 | ), 89 | mask_decoder=MaskDecoder( 90 | num_multimask_outputs=3, 91 | transformer=TwoWayTransformer( 92 | depth=2, 93 | embedding_dim=prompt_embed_dim, 94 | mlp_dim=2048, 95 | num_heads=8, 96 | ), 97 | transformer_dim=prompt_embed_dim, 98 | iou_head_depth=3, 99 | iou_head_hidden_dim=256, 100 | ), 101 | pixel_mean=[123.675, 116.28, 103.53], 102 | pixel_std=[58.395, 57.12, 57.375], 103 | ) 104 | if checkpoint is not None: 105 | with open(checkpoint, "rb") as f: 106 | state_dict = torch.load(f) 107 | sam.load_state_dict(state_dict) 108 | return sam 109 | -------------------------------------------------------------------------------- /extend_sam/segment_anything_ori/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from .sam import Sam 8 | from .image_encoder import ImageEncoderViT 9 | from .mask_decoder import MaskDecoder 10 | from .prompt_encoder import PromptEncoder 11 | from .transformer import TwoWayTransformer 12 | -------------------------------------------------------------------------------- /extend_sam/segment_anything_ori/modeling/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | from typing import Type 11 | 12 | 13 | class MLPBlock(nn.Module): 14 | def __init__( 15 | self, 16 | embedding_dim: int, 17 | mlp_dim: int, 18 | act: Type[nn.Module] = nn.GELU, 19 | ) -> None: 20 | super().__init__() 21 | self.lin1 = nn.Linear(embedding_dim, mlp_dim) 22 | self.lin2 = nn.Linear(mlp_dim, embedding_dim) 23 | self.act = act() 24 | 25 | def forward(self, x: torch.Tensor) -> torch.Tensor: 26 | return self.lin2(self.act(self.lin1(x))) 27 | 28 | 29 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa 30 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa 31 | class LayerNorm2d(nn.Module): 32 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None: 33 | super().__init__() 34 | self.weight = nn.Parameter(torch.ones(num_channels)) 35 | self.bias = nn.Parameter(torch.zeros(num_channels)) 36 | self.eps = eps 37 | 38 | def forward(self, x: torch.Tensor) -> torch.Tensor: 39 | u = x.mean(1, keepdim=True) 40 | s = (x - u).pow(2).mean(1, keepdim=True) 41 | x = (x - u) / torch.sqrt(s + self.eps) 42 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 43 | return x 44 | 
-------------------------------------------------------------------------------- /extend_sam/segment_anything_ori/modeling/image_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | from typing import Optional, Tuple, Type 12 | 13 | from .common import LayerNorm2d, MLPBlock 14 | 15 | 16 | # This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa 17 | class ImageEncoderViT(nn.Module): 18 | def __init__( 19 | self, 20 | img_size: int = 1024, 21 | patch_size: int = 16, 22 | in_chans: int = 3, 23 | embed_dim: int = 768, 24 | depth: int = 12, 25 | num_heads: int = 12, 26 | mlp_ratio: float = 4.0, 27 | out_chans: int = 256, 28 | qkv_bias: bool = True, 29 | norm_layer: Type[nn.Module] = nn.LayerNorm, 30 | act_layer: Type[nn.Module] = nn.GELU, 31 | use_abs_pos: bool = True, 32 | use_rel_pos: bool = False, 33 | rel_pos_zero_init: bool = True, 34 | window_size: int = 0, 35 | global_attn_indexes: Tuple[int, ...] = (), 36 | ) -> None: 37 | """ 38 | Args: 39 | img_size (int): Input image size. 40 | patch_size (int): Patch size. 41 | in_chans (int): Number of input image channels. 42 | embed_dim (int): Patch embedding dimension. 43 | depth (int): Depth of ViT. 44 | num_heads (int): Number of attention heads in each ViT block. 45 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 46 | qkv_bias (bool): If True, add a learnable bias to query, key, value. 47 | norm_layer (nn.Module): Normalization layer. 48 | act_layer (nn.Module): Activation layer. 
49 | use_abs_pos (bool): If True, use absolute positional embeddings. 50 | use_rel_pos (bool): If True, add relative positional embeddings to the attention map. 51 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. 52 | window_size (int): Window size for window attention blocks. 53 | global_attn_indexes (list): Indexes for blocks using global attention. 54 | """ 55 | super().__init__() 56 | self.img_size = img_size 57 | 58 | self.patch_embed = PatchEmbed( 59 | kernel_size=(patch_size, patch_size), 60 | stride=(patch_size, patch_size), 61 | in_chans=in_chans, 62 | embed_dim=embed_dim, 63 | ) 64 | 65 | self.pos_embed: Optional[nn.Parameter] = None 66 | if use_abs_pos: 67 | # Initialize absolute positional embedding with pretrain image size. 68 | self.pos_embed = nn.Parameter( 69 | torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim) 70 | ) 71 | 72 | self.blocks = nn.ModuleList() 73 | for i in range(depth): 74 | block = Block( 75 | dim=embed_dim, 76 | num_heads=num_heads, 77 | mlp_ratio=mlp_ratio, 78 | qkv_bias=qkv_bias, 79 | norm_layer=norm_layer, 80 | act_layer=act_layer, 81 | use_rel_pos=use_rel_pos, 82 | rel_pos_zero_init=rel_pos_zero_init, 83 | window_size=window_size if i not in global_attn_indexes else 0, 84 | input_size=(img_size // patch_size, img_size // patch_size), 85 | ) 86 | self.blocks.append(block) 87 | 88 | self.neck = nn.Sequential( 89 | nn.Conv2d( 90 | embed_dim, 91 | out_chans, 92 | kernel_size=1, 93 | bias=False, 94 | ), 95 | LayerNorm2d(out_chans), 96 | nn.Conv2d( 97 | out_chans, 98 | out_chans, 99 | kernel_size=3, 100 | padding=1, 101 | bias=False, 102 | ), 103 | LayerNorm2d(out_chans), 104 | ) 105 | 106 | def forward(self, x: torch.Tensor) -> torch.Tensor: 107 | x = self.patch_embed(x) 108 | if self.pos_embed is not None: 109 | x = x + self.pos_embed 110 | 111 | for blk in self.blocks: 112 | x = blk(x) 113 | 114 | x = self.neck(x.permute(0, 3, 1, 2)) 115 | 116 | return x 117 | 118 | 119 | 
class Block(nn.Module): 120 | """Transformer blocks with support of window attention and residual propagation blocks""" 121 | 122 | def __init__( 123 | self, 124 | dim: int, 125 | num_heads: int, 126 | mlp_ratio: float = 4.0, 127 | qkv_bias: bool = True, 128 | norm_layer: Type[nn.Module] = nn.LayerNorm, 129 | act_layer: Type[nn.Module] = nn.GELU, 130 | use_rel_pos: bool = False, 131 | rel_pos_zero_init: bool = True, 132 | window_size: int = 0, 133 | input_size: Optional[Tuple[int, int]] = None, 134 | ) -> None: 135 | """ 136 | Args: 137 | dim (int): Number of input channels. 138 | num_heads (int): Number of attention heads in each ViT block. 139 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 140 | qkv_bias (bool): If True, add a learnable bias to query, key, value. 141 | norm_layer (nn.Module): Normalization layer. 142 | act_layer (nn.Module): Activation layer. 143 | use_rel_pos (bool): If True, add relative positional embeddings to the attention map. 144 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. 145 | window_size (int): Window size for window attention blocks. If it equals 0, then 146 | use global attention. 147 | input_size (int or None): Input resolution for calculating the relative positional 148 | parameter size. 
149 | """ 150 | super().__init__() 151 | self.norm1 = norm_layer(dim) 152 | self.attn = Attention( 153 | dim, 154 | num_heads=num_heads, 155 | qkv_bias=qkv_bias, 156 | use_rel_pos=use_rel_pos, 157 | rel_pos_zero_init=rel_pos_zero_init, 158 | input_size=input_size if window_size == 0 else (window_size, window_size), 159 | ) 160 | 161 | self.norm2 = norm_layer(dim) 162 | self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer) 163 | 164 | self.window_size = window_size 165 | 166 | def forward(self, x: torch.Tensor) -> torch.Tensor: 167 | shortcut = x 168 | x = self.norm1(x) 169 | # Window partition 170 | if self.window_size > 0: 171 | H, W = x.shape[1], x.shape[2] 172 | x, pad_hw = window_partition(x, self.window_size) 173 | 174 | x = self.attn(x) 175 | # Reverse window partition 176 | if self.window_size > 0: 177 | x = window_unpartition(x, self.window_size, pad_hw, (H, W)) 178 | 179 | x = shortcut + x 180 | x = x + self.mlp(self.norm2(x)) 181 | 182 | return x 183 | 184 | 185 | class Attention(nn.Module): 186 | """Multi-head Attention block with relative position embeddings.""" 187 | 188 | def __init__( 189 | self, 190 | dim: int, 191 | num_heads: int = 8, 192 | qkv_bias: bool = True, 193 | use_rel_pos: bool = False, 194 | rel_pos_zero_init: bool = True, 195 | input_size: Optional[Tuple[int, int]] = None, 196 | ) -> None: 197 | """ 198 | Args: 199 | dim (int): Number of input channels. 200 | num_heads (int): Number of attention heads. 201 | qkv_bias (bool): If True, add a learnable bias to query, key, value. 202 | use_rel_pos (bool): If True, add relative positional embeddings to the attention map. 203 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. 204 | input_size (int or None): Input resolution for calculating the relative positional 205 | parameter size.
206 | """ 207 | super().__init__() 208 | self.num_heads = num_heads 209 | head_dim = dim // num_heads 210 | self.scale = head_dim**-0.5 211 | 212 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 213 | self.proj = nn.Linear(dim, dim) 214 | 215 | self.use_rel_pos = use_rel_pos 216 | if self.use_rel_pos: 217 | assert ( 218 | input_size is not None 219 | ), "Input size must be provided if using relative positional encoding." 220 | # initialize relative positional embeddings 221 | self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) 222 | self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) 223 | 224 | def forward(self, x: torch.Tensor) -> torch.Tensor: 225 | B, H, W, _ = x.shape 226 | # qkv with shape (3, B, nHead, H * W, C) 227 | qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) 228 | # q, k, v with shape (B * nHead, H * W, C) 229 | q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) 230 | 231 | attn = (q * self.scale) @ k.transpose(-2, -1) 232 | 233 | if self.use_rel_pos: 234 | attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)) 235 | 236 | attn = attn.softmax(dim=-1) 237 | x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1) 238 | x = self.proj(x) 239 | 240 | return x 241 | 242 | 243 | def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]: 244 | """ 245 | Partition into non-overlapping windows with padding if needed. 246 | Args: 247 | x (tensor): input tokens with [B, H, W, C]. 248 | window_size (int): window size. 249 | 250 | Returns: 251 | windows: windows after partition with [B * num_windows, window_size, window_size, C]. 
252 | (Hp, Wp): padded height and width before partition 253 | """ 254 | B, H, W, C = x.shape 255 | 256 | pad_h = (window_size - H % window_size) % window_size 257 | pad_w = (window_size - W % window_size) % window_size 258 | if pad_h > 0 or pad_w > 0: 259 | x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) 260 | Hp, Wp = H + pad_h, W + pad_w 261 | 262 | x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) 263 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 264 | return windows, (Hp, Wp) 265 | 266 | 267 | def window_unpartition( 268 | windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int] 269 | ) -> torch.Tensor: 270 | """ 271 | Unpartition windows into the original sequences and remove padding. 272 | Args: 273 | windows (tensor): input tokens with [B * num_windows, window_size, window_size, C]. 274 | window_size (int): window size. 275 | pad_hw (Tuple): padded height and width (Hp, Wp). 276 | hw (Tuple): original height and width (H, W) before padding. 277 | 278 | Returns: 279 | x: unpartitioned sequences with [B, H, W, C]. 280 | """ 281 | Hp, Wp = pad_hw 282 | H, W = hw 283 | B = windows.shape[0] // (Hp * Wp // window_size // window_size) 284 | x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1) 285 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) 286 | 287 | if Hp > H or Wp > W: 288 | x = x[:, :H, :W, :].contiguous() 289 | return x 290 | 291 | 292 | def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: 293 | """ 294 | Get relative positional embeddings according to the relative positions of 295 | query and key sizes. 296 | Args: 297 | q_size (int): size of query q. 298 | k_size (int): size of key k. 299 | rel_pos (Tensor): relative position embeddings (L, C). 300 | 301 | Returns: 302 | Extracted positional embeddings according to relative positions.
303 | """ 304 | max_rel_dist = int(2 * max(q_size, k_size) - 1) 305 | # Interpolate rel pos if needed. 306 | if rel_pos.shape[0] != max_rel_dist: 307 | # Interpolate rel pos. 308 | rel_pos_resized = F.interpolate( 309 | rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), 310 | size=max_rel_dist, 311 | mode="linear", 312 | ) 313 | rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) 314 | else: 315 | rel_pos_resized = rel_pos 316 | 317 | # Scale the coords with short length if shapes for q and k are different. 318 | q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) 319 | k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) 320 | relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) 321 | 322 | return rel_pos_resized[relative_coords.long()] 323 | 324 | 325 | def add_decomposed_rel_pos( 326 | attn: torch.Tensor, 327 | q: torch.Tensor, 328 | rel_pos_h: torch.Tensor, 329 | rel_pos_w: torch.Tensor, 330 | q_size: Tuple[int, int], 331 | k_size: Tuple[int, int], 332 | ) -> torch.Tensor: 333 | """ 334 | Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. 335 | https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950 336 | Args: 337 | attn (Tensor): attention map. 338 | q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). 339 | rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. 340 | rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. 341 | q_size (Tuple): spatial sequence size of query q with (q_h, q_w). 342 | k_size (Tuple): spatial sequence size of key k with (k_h, k_w). 343 | 344 | Returns: 345 | attn (Tensor): attention map with added relative positional embeddings. 
346 | """ 347 | q_h, q_w = q_size 348 | k_h, k_w = k_size 349 | Rh = get_rel_pos(q_h, k_h, rel_pos_h) 350 | Rw = get_rel_pos(q_w, k_w, rel_pos_w) 351 | 352 | B, _, dim = q.shape 353 | r_q = q.reshape(B, q_h, q_w, dim) 354 | rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) 355 | rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) 356 | 357 | attn = ( 358 | attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] 359 | ).view(B, q_h * q_w, k_h * k_w) 360 | 361 | return attn 362 | 363 | 364 | class PatchEmbed(nn.Module): 365 | """ 366 | Image to Patch Embedding. 367 | """ 368 | 369 | def __init__( 370 | self, 371 | kernel_size: Tuple[int, int] = (16, 16), 372 | stride: Tuple[int, int] = (16, 16), 373 | padding: Tuple[int, int] = (0, 0), 374 | in_chans: int = 3, 375 | embed_dim: int = 768, 376 | ) -> None: 377 | """ 378 | Args: 379 | kernel_size (Tuple): kernel size of the projection layer. 380 | stride (Tuple): stride of the projection layer. 381 | padding (Tuple): padding size of the projection layer. 382 | in_chans (int): Number of input image channels. 383 | embed_dim (int): Patch embedding dimension. 384 | """ 385 | super().__init__() 386 | 387 | self.proj = nn.Conv2d( 388 | in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding 389 | ) 390 | 391 | def forward(self, x: torch.Tensor) -> torch.Tensor: 392 | x = self.proj(x) 393 | # B C H W -> B H W C 394 | x = x.permute(0, 2, 3, 1) 395 | return x 396 | -------------------------------------------------------------------------------- /extend_sam/segment_anything_ori/modeling/mask_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree.
6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | 11 | from typing import List, Tuple, Type 12 | 13 | from .common import LayerNorm2d 14 | 15 | 16 | class MaskDecoder(nn.Module): 17 | def __init__( 18 | self, 19 | *, 20 | transformer_dim: int, 21 | transformer: nn.Module, 22 | num_multimask_outputs: int = 3, 23 | activation: Type[nn.Module] = nn.GELU, 24 | iou_head_depth: int = 3, 25 | iou_head_hidden_dim: int = 256, 26 | ) -> None: 27 | """ 28 | Predicts masks given an image and prompt embeddings, using a 29 | transformer architecture. 30 | 31 | Arguments: 32 | transformer_dim (int): the channel dimension of the transformer 33 | transformer (nn.Module): the transformer used to predict masks 34 | num_multimask_outputs (int): the number of masks to predict 35 | when disambiguating masks 36 | activation (nn.Module): the type of activation to use when 37 | upscaling masks 38 | iou_head_depth (int): the depth of the MLP used to predict 39 | mask quality 40 | iou_head_hidden_dim (int): the hidden dimension of the MLP 41 | used to predict mask quality 42 | """ 43 | super().__init__() 44 | self.transformer_dim = transformer_dim 45 | self.transformer = transformer 46 | 47 | self.num_multimask_outputs = num_multimask_outputs 48 | 49 | self.iou_token = nn.Embedding(1, transformer_dim) 50 | self.num_mask_tokens = num_multimask_outputs + 1 51 | self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) 52 | self.iou_head_depth = iou_head_depth 53 | self.iou_head_hidden_dim = iou_head_hidden_dim 54 | self.output_upscaling = nn.Sequential( 55 | nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), 56 | LayerNorm2d(transformer_dim // 4), 57 | activation(), 58 | nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), 59 | activation(), 60 | ) 61 | self.output_hypernetworks_mlps = nn.ModuleList( 62 | [ 63 | MLP(transformer_dim, transformer_dim, transformer_dim // 8,
3) 64 | for i in range(self.num_mask_tokens) 65 | ] 66 | ) 67 | 68 | self.iou_prediction_head = MLP( 69 | transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth 70 | ) 71 | 72 | def forward( 73 | self, 74 | image_embeddings: torch.Tensor, 75 | image_pe: torch.Tensor, 76 | sparse_prompt_embeddings: torch.Tensor, 77 | dense_prompt_embeddings: torch.Tensor, 78 | multimask_output: bool, 79 | ) -> Tuple[torch.Tensor, torch.Tensor]: 80 | """ 81 | Predict masks given image and prompt embeddings. 82 | 83 | Arguments: 84 | image_embeddings (torch.Tensor): the embeddings from the image encoder 85 | image_pe (torch.Tensor): positional encoding with the shape of image_embeddings 86 | sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes 87 | dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs 88 | multimask_output (bool): Whether to return multiple masks or a single 89 | mask. 90 | 91 | Returns: 92 | torch.Tensor: batched predicted masks 93 | torch.Tensor: batched predictions of mask quality 94 | """ 95 | masks, iou_pred = self.predict_masks( 96 | image_embeddings=image_embeddings, 97 | image_pe=image_pe, 98 | sparse_prompt_embeddings=sparse_prompt_embeddings, 99 | dense_prompt_embeddings=dense_prompt_embeddings, 100 | ) 101 | 102 | # Select the correct mask or masks for output 103 | if multimask_output: 104 | mask_slice = slice(1, None) 105 | else: 106 | mask_slice = slice(0, 1) 107 | masks = masks[:, mask_slice, :, :] 108 | iou_pred = iou_pred[:, mask_slice] 109 | 110 | # Prepare output 111 | return masks, iou_pred 112 | 113 | def predict_masks( 114 | self, 115 | image_embeddings: torch.Tensor, 116 | image_pe: torch.Tensor, 117 | sparse_prompt_embeddings: torch.Tensor, 118 | dense_prompt_embeddings: torch.Tensor, 119 | ) -> Tuple[torch.Tensor, torch.Tensor]: 120 | """Predicts masks.
See 'forward' for more details.""" 121 | # Concatenate output tokens 122 | output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) 123 | output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1) 124 | tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) 125 | 126 | # Expand per-image data in batch direction to be per-mask 127 | src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) 128 | src = src + dense_prompt_embeddings 129 | pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) 130 | b, c, h, w = src.shape 131 | 132 | # Run the transformer 133 | hs, src = self.transformer(src, pos_src, tokens) 134 | iou_token_out = hs[:, 0, :] 135 | mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :] 136 | 137 | # Upscale mask embeddings and predict masks using the mask tokens 138 | src = src.transpose(1, 2).view(b, c, h, w) 139 | upscaled_embedding = self.output_upscaling(src) 140 | hyper_in_list: List[torch.Tensor] = [] 141 | for i in range(self.num_mask_tokens): 142 | hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])) 143 | hyper_in = torch.stack(hyper_in_list, dim=1) 144 | b, c, h, w = upscaled_embedding.shape 145 | masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) 146 | 147 | # Generate mask quality predictions 148 | iou_pred = self.iou_prediction_head(iou_token_out) 149 | 150 | return masks, iou_pred 151 | 152 | 153 | # Lightly adapted from 154 | # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa 155 | class MLP(nn.Module): 156 | def __init__( 157 | self, 158 | input_dim: int, 159 | hidden_dim: int, 160 | output_dim: int, 161 | num_layers: int, 162 | sigmoid_output: bool = False, 163 | ) -> None: 164 | super().__init__() 165 | self.num_layers = num_layers 166 | h = [hidden_dim] * (num_layers - 1) 167 | self.layers = nn.ModuleList( 
168 | nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) 169 | ) 170 | self.sigmoid_output = sigmoid_output 171 | 172 | def forward(self, x): 173 | for i, layer in enumerate(self.layers): 174 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) 175 | if self.sigmoid_output: 176 | x = torch.sigmoid(x) # F.sigmoid is deprecated 177 | return x 178 | -------------------------------------------------------------------------------- /extend_sam/segment_anything_ori/modeling/prompt_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | from torch import nn 10 | 11 | from typing import Any, Optional, Tuple, Type 12 | 13 | from .common import LayerNorm2d 14 | 15 | 16 | class PromptEncoder(nn.Module): 17 | def __init__( 18 | self, 19 | embed_dim: int, 20 | image_embedding_size: Tuple[int, int], 21 | input_image_size: Tuple[int, int], 22 | mask_in_chans: int, 23 | activation: Type[nn.Module] = nn.GELU, 24 | ) -> None: 25 | """ 26 | Encodes prompts for input to SAM's mask decoder. 27 | 28 | Arguments: 29 | embed_dim (int): The prompts' embedding dimension 30 | image_embedding_size (tuple(int, int)): The spatial size of the 31 | image embedding, as (H, W). 32 | input_image_size (tuple(int, int)): The padded size of the image as input 33 | to the image encoder, as (H, W). 34 | mask_in_chans (int): The number of hidden channels used for 35 | encoding input masks. 36 | activation (nn.Module): The activation to use when encoding 37 | input masks.
38 | """ 39 | super().__init__() 40 | self.embed_dim = embed_dim 41 | self.input_image_size = input_image_size 42 | self.image_embedding_size = image_embedding_size 43 | self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) 44 | 45 | self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners 46 | point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)] 47 | self.point_embeddings = nn.ModuleList(point_embeddings) 48 | self.not_a_point_embed = nn.Embedding(1, embed_dim) 49 | 50 | self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1]) 51 | self.mask_downscaling = nn.Sequential( 52 | nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), 53 | LayerNorm2d(mask_in_chans // 4), 54 | activation(), 55 | nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), 56 | LayerNorm2d(mask_in_chans), 57 | activation(), 58 | nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), 59 | ) 60 | self.no_mask_embed = nn.Embedding(1, embed_dim) 61 | 62 | def get_dense_pe(self) -> torch.Tensor: 63 | """ 64 | Returns the positional encoding used to encode point prompts, 65 | applied to a dense set of points the shape of the image encoding. 
66 | 67 | Returns: 68 | torch.Tensor: Positional encoding with shape 69 | 1x(embed_dim)x(embedding_h)x(embedding_w) 70 | """ 71 | return self.pe_layer(self.image_embedding_size).unsqueeze(0) 72 | 73 | def _embed_points( 74 | self, 75 | points: torch.Tensor, 76 | labels: torch.Tensor, 77 | pad: bool, 78 | ) -> torch.Tensor: 79 | """Embeds point prompts.""" 80 | points = points + 0.5 # Shift to center of pixel 81 | if pad: 82 | padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) 83 | padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) 84 | points = torch.cat([points, padding_point], dim=1) 85 | labels = torch.cat([labels, padding_label], dim=1) 86 | point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size) 87 | point_embedding[labels == -1] = 0.0 88 | point_embedding[labels == -1] += self.not_a_point_embed.weight 89 | point_embedding[labels == 0] += self.point_embeddings[0].weight 90 | point_embedding[labels == 1] += self.point_embeddings[1].weight 91 | return point_embedding 92 | 93 | def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: 94 | """Embeds box prompts.""" 95 | boxes = boxes + 0.5 # Shift to center of pixel 96 | coords = boxes.reshape(-1, 2, 2) 97 | corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size) 98 | corner_embedding[:, 0, :] += self.point_embeddings[2].weight 99 | corner_embedding[:, 1, :] += self.point_embeddings[3].weight 100 | return corner_embedding 101 | 102 | def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: 103 | """Embeds mask inputs.""" 104 | mask_embedding = self.mask_downscaling(masks) 105 | return mask_embedding 106 | 107 | def _get_batch_size( 108 | self, 109 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 110 | boxes: Optional[torch.Tensor], 111 | masks: Optional[torch.Tensor], 112 | ) -> int: 113 | """ 114 | Gets the batch size of the output given the batch size of the input prompts. 
115 | """ 116 | if points is not None: 117 | return points[0].shape[0] 118 | elif boxes is not None: 119 | return boxes.shape[0] 120 | elif masks is not None: 121 | return masks.shape[0] 122 | else: 123 | return 1 124 | 125 | def _get_device(self) -> torch.device: 126 | return self.point_embeddings[0].weight.device 127 | 128 | def forward( 129 | self, 130 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 131 | boxes: Optional[torch.Tensor], 132 | masks: Optional[torch.Tensor], 133 | ) -> Tuple[torch.Tensor, torch.Tensor]: 134 | """ 135 | Embeds different types of prompts, returning both sparse and dense 136 | embeddings. 137 | 138 | Arguments: 139 | points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates 140 | and labels to embed. 141 | boxes (torch.Tensor or none): boxes to embed 142 | masks (torch.Tensor or none): masks to embed 143 | 144 | Returns: 145 | torch.Tensor: sparse embeddings for the points and boxes, with shape 146 | BxNx(embed_dim), where N is determined by the number of input points 147 | and boxes. 
148 | torch.Tensor: dense embeddings for the masks, in the shape 149 | Bx(embed_dim)x(embed_H)x(embed_W) 150 | """ 151 | bs = self._get_batch_size(points, boxes, masks) 152 | sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device()) 153 | if points is not None: 154 | coords, labels = points 155 | point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) 156 | sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) 157 | if boxes is not None: 158 | box_embeddings = self._embed_boxes(boxes) 159 | sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) 160 | 161 | if masks is not None: 162 | dense_embeddings = self._embed_masks(masks) 163 | else: 164 | dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( 165 | bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] 166 | ) 167 | 168 | return sparse_embeddings, dense_embeddings 169 | 170 | 171 | class PositionEmbeddingRandom(nn.Module): 172 | """ 173 | Positional encoding using random spatial frequencies. 174 | """ 175 | 176 | def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: 177 | super().__init__() 178 | if scale is None or scale <= 0.0: 179 | scale = 1.0 180 | self.register_buffer( 181 | "positional_encoding_gaussian_matrix", 182 | scale * torch.randn((2, num_pos_feats)), 183 | ) 184 | 185 | def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: 186 | """Positionally encode points that are normalized to [0,1].""" 187 | # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape 188 | coords = 2 * coords - 1 189 | coords = coords @ self.positional_encoding_gaussian_matrix 190 | coords = 2 * np.pi * coords 191 | # outputs d_1 x ... 
x d_n x C shape 192 | return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) 193 | 194 | def forward(self, size: Tuple[int, int]) -> torch.Tensor: 195 | """Generate positional encoding for a grid of the specified size.""" 196 | h, w = size 197 | device: Any = self.positional_encoding_gaussian_matrix.device 198 | grid = torch.ones((h, w), device=device, dtype=torch.float32) 199 | y_embed = grid.cumsum(dim=0) - 0.5 200 | x_embed = grid.cumsum(dim=1) - 0.5 201 | y_embed = y_embed / h 202 | x_embed = x_embed / w 203 | 204 | pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) 205 | return pe.permute(2, 0, 1) # C x H x W 206 | 207 | def forward_with_coords( 208 | self, coords_input: torch.Tensor, image_size: Tuple[int, int] 209 | ) -> torch.Tensor: 210 | """Positionally encode points that are not normalized to [0,1].""" 211 | coords = coords_input.clone() 212 | coords[:, :, 0] = coords[:, :, 0] / image_size[1] 213 | coords[:, :, 1] = coords[:, :, 1] / image_size[0] 214 | return self._pe_encoding(coords.to(torch.float)) # B x N x C 215 | -------------------------------------------------------------------------------- /extend_sam/segment_anything_ori/modeling/sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | # modified by ziqi-jin 8 | 9 | import torch 10 | from torch import nn 11 | from torch.nn import functional as F 12 | 13 | from typing import Any, Dict, List, Tuple 14 | 15 | from .image_encoder import ImageEncoderViT 16 | from .mask_decoder import MaskDecoder 17 | from .prompt_encoder import PromptEncoder 18 | 19 | 20 | class Sam(nn.Module): 21 | mask_threshold: float = 0.0 22 | image_format: str = "RGB" 23 | 24 | def __init__( 25 | self, 26 | image_encoder: ImageEncoderViT, 27 | prompt_encoder: PromptEncoder, 28 | mask_decoder: MaskDecoder, 29 | pixel_mean: List[float] = [123.675, 116.28, 103.53], 30 | pixel_std: List[float] = [58.395, 57.12, 57.375], 31 | ) -> None: 32 | """ 33 | SAM predicts object masks from an image and input prompts. 34 | 35 | Arguments: 36 | image_encoder (ImageEncoderViT): The backbone used to encode the 37 | image into image embeddings that allow for efficient mask prediction. 38 | prompt_encoder (PromptEncoder): Encodes various types of input prompts. 39 | mask_decoder (MaskDecoder): Predicts masks from the image embeddings 40 | and encoded prompts. 41 | pixel_mean (list(float)): Mean values for normalizing pixels in the input image. 42 | pixel_std (list(float)): Std values for normalizing pixels in the input image. 43 | """ 44 | super().__init__() 45 | self.image_encoder = image_encoder 46 | self.prompt_encoder = prompt_encoder 47 | self.mask_decoder = mask_decoder 48 | self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) 49 | self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) 50 | 51 | @property 52 | def device(self) -> Any: 53 | return self.pixel_mean.device 54 | 55 | def forward( 56 | self, 57 | batched_input: List[Dict[str, Any]], 58 | multimask_output: bool, 59 | ) -> List[Dict[str, torch.Tensor]]: 60 | """ 61 | Predicts masks end-to-end from provided images and prompts. 
62 | If prompts are not known in advance, using SamPredictor is 63 | recommended over calling the model directly. 64 | 65 | Arguments: 66 | batched_input (list(dict)): A list over input images, each a 67 | dictionary with the following keys. A prompt key can be 68 | excluded if it is not present. 69 | 'image': The image as a torch tensor in 3xHxW format, 70 | already transformed for input to the model. 71 | 'original_size': (tuple(int, int)) The original size of 72 | the image before transformation, as (H, W). 73 | 'point_coords': (torch.Tensor) Batched point prompts for 74 | this image, with shape BxNx2. Already transformed to the 75 | input frame of the model. 76 | 'point_labels': (torch.Tensor) Batched labels for point prompts, 77 | with shape BxN. 78 | 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. 79 | Already transformed to the input frame of the model. 80 | 'mask_inputs': (torch.Tensor) Batched mask inputs to the model, 81 | in the form Bx1xHxW. 82 | multimask_output (bool): Whether the model should predict multiple 83 | disambiguating masks, or return a single mask. 84 | 85 | Returns: 86 | (list(dict)): A list over input images, where each element is 87 | a dictionary with the following keys. 88 | 'masks': (torch.Tensor) Batched binary mask predictions, 89 | with shape BxCxHxW, where B is the number of input prompts, 90 | C is determined by multimask_output, and (H, W) is the 91 | original size of the image. 92 | 'iou_predictions': (torch.Tensor) The model's predictions 93 | of mask quality, in shape BxC. 94 | 'low_res_logits': (torch.Tensor) Low resolution logits with 95 | shape BxCxHxW, where H=W=256. Can be passed as mask input 96 | to subsequent iterations of prediction.
97 | """ 98 | input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0) 99 | image_embeddings = self.image_encoder(input_images) 100 | 101 | outputs = [] 102 | for image_record, curr_embedding in zip(batched_input, image_embeddings): 103 | if "point_coords" in image_record: 104 | points = (image_record["point_coords"], image_record["point_labels"]) 105 | else: 106 | points = None 107 | sparse_embeddings, dense_embeddings = self.prompt_encoder( 108 | points=points, 109 | boxes=image_record.get("boxes", None), 110 | masks=image_record.get("mask_inputs", None), 111 | ) 112 | low_res_masks, iou_predictions = self.mask_decoder( 113 | image_embeddings=curr_embedding.unsqueeze(0), 114 | image_pe=self.prompt_encoder.get_dense_pe(), 115 | sparse_prompt_embeddings=sparse_embeddings, 116 | dense_prompt_embeddings=dense_embeddings, 117 | multimask_output=multimask_output, 118 | ) 119 | masks = self.postprocess_masks( 120 | low_res_masks, 121 | input_size=image_record["image"].shape[-2:], 122 | original_size=image_record["original_size"], 123 | ) 124 | masks = masks > self.mask_threshold 125 | outputs.append( 126 | { 127 | "masks": masks, 128 | "iou_predictions": iou_predictions, 129 | "low_res_logits": low_res_masks, 130 | } 131 | ) 132 | return outputs 133 | 134 | def postprocess_masks( 135 | self, 136 | masks: torch.Tensor, 137 | input_size: Tuple[int, ...], 138 | original_size: Tuple[int, ...], 139 | ) -> torch.Tensor: 140 | """ 141 | Remove padding and upscale masks to the original image size. 142 | 143 | Arguments: 144 | masks (torch.Tensor): Batched masks from the mask_decoder, 145 | in BxCxHxW format. 146 | input_size (tuple(int, int)): The size of the image input to the 147 | model, in (H, W) format. Used to remove padding. 148 | original_size (tuple(int, int)): The original size of the image 149 | before resizing for input to the model, in (H, W) format. 
150 | 151 | Returns: 152 | (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) 153 | is given by original_size. 154 | """ 155 | masks = F.interpolate( 156 | masks, 157 | (self.image_encoder.img_size, self.image_encoder.img_size), 158 | mode="bilinear", 159 | align_corners=False, 160 | ) 161 | masks = masks[..., : input_size[0], : input_size[1]] 162 | masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False) 163 | return masks 164 | 165 | def preprocess(self, x: torch.Tensor) -> torch.Tensor: 166 | """Normalize pixel values and pad to a square input.""" 167 | # Normalize colors 168 | x = (x - self.pixel_mean) / self.pixel_std 169 | 170 | # Pad 171 | h, w = x.shape[-2:] 172 | padh = self.image_encoder.img_size - h 173 | padw = self.image_encoder.img_size - w 174 | x = F.pad(x, (0, padw, 0, padh)) 175 | return x 176 | -------------------------------------------------------------------------------- /extend_sam/segment_anything_ori/modeling/transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import Tensor, nn 9 | 10 | import math 11 | from typing import Tuple, Type 12 | 13 | from .common import MLPBlock 14 | 15 | 16 | class TwoWayTransformer(nn.Module): 17 | def __init__( 18 | self, 19 | depth: int, 20 | embedding_dim: int, 21 | num_heads: int, 22 | mlp_dim: int, 23 | activation: Type[nn.Module] = nn.ReLU, 24 | attention_downsample_rate: int = 2, 25 | ) -> None: 26 | """ 27 | A transformer decoder that attends to an input image using 28 | queries whose positional embedding is supplied. 
29 | 30 | Args: 31 | depth (int): number of layers in the transformer 32 | embedding_dim (int): the channel dimension for the input embeddings 33 | num_heads (int): the number of heads for multihead attention. Must 34 | divide embedding_dim 35 | mlp_dim (int): the channel dimension internal to the MLP block 36 | activation (nn.Module): the activation to use in the MLP block 37 | """ 38 | super().__init__() 39 | self.depth = depth 40 | self.embedding_dim = embedding_dim 41 | self.num_heads = num_heads 42 | self.mlp_dim = mlp_dim 43 | self.layers = nn.ModuleList() 44 | 45 | for i in range(depth): 46 | self.layers.append( 47 | TwoWayAttentionBlock( 48 | embedding_dim=embedding_dim, 49 | num_heads=num_heads, 50 | mlp_dim=mlp_dim, 51 | activation=activation, 52 | attention_downsample_rate=attention_downsample_rate, 53 | skip_first_layer_pe=(i == 0), 54 | ) 55 | ) 56 | 57 | self.final_attn_token_to_image = Attention( 58 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 59 | ) 60 | self.norm_final_attn = nn.LayerNorm(embedding_dim) 61 | 62 | def forward( 63 | self, 64 | image_embedding: Tensor, 65 | image_pe: Tensor, 66 | point_embedding: Tensor, 67 | ) -> Tuple[Tensor, Tensor]: 68 | """ 69 | Args: 70 | image_embedding (torch.Tensor): image to attend to. Should be shape 71 | B x embedding_dim x h x w for any h and w. 72 | image_pe (torch.Tensor): the positional encoding to add to the image. Must 73 | have the same shape as image_embedding. 74 | point_embedding (torch.Tensor): the embedding to add to the query points. 75 | Must have shape B x N_points x embedding_dim for any N_points. 
76 | 77 | Returns: 78 | torch.Tensor: the processed point_embedding 79 | torch.Tensor: the processed image_embedding 80 | """ 81 | # BxCxHxW -> BxHWxC == B x N_image_tokens x C 82 | bs, c, h, w = image_embedding.shape 83 | image_embedding = image_embedding.flatten(2).permute(0, 2, 1) 84 | image_pe = image_pe.flatten(2).permute(0, 2, 1) 85 | 86 | # Prepare queries 87 | queries = point_embedding 88 | keys = image_embedding 89 | 90 | # Apply transformer blocks and final layernorm 91 | for layer in self.layers: 92 | queries, keys = layer( 93 | queries=queries, 94 | keys=keys, 95 | query_pe=point_embedding, 96 | key_pe=image_pe, 97 | ) 98 | 99 | # Apply the final attention layer from the points to the image 100 | q = queries + point_embedding 101 | k = keys + image_pe 102 | attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) 103 | queries = queries + attn_out 104 | queries = self.norm_final_attn(queries) 105 | 106 | return queries, keys 107 | 108 | 109 | class TwoWayAttentionBlock(nn.Module): 110 | def __init__( 111 | self, 112 | embedding_dim: int, 113 | num_heads: int, 114 | mlp_dim: int = 2048, 115 | activation: Type[nn.Module] = nn.ReLU, 116 | attention_downsample_rate: int = 2, 117 | skip_first_layer_pe: bool = False, 118 | ) -> None: 119 | """ 120 | A transformer block with four layers: (1) self-attention of sparse 121 | inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp 122 | block on sparse inputs, and (4) cross attention of dense inputs to sparse 123 | inputs.
124 | 125 | Arguments: 126 | embedding_dim (int): the channel dimension of the embeddings 127 | num_heads (int): the number of heads in the attention layers 128 | mlp_dim (int): the hidden dimension of the mlp block 129 | activation (nn.Module): the activation of the mlp block 130 | skip_first_layer_pe (bool): skip the PE on the first layer 131 | """ 132 | super().__init__() 133 | self.self_attn = Attention(embedding_dim, num_heads) 134 | self.norm1 = nn.LayerNorm(embedding_dim) 135 | 136 | self.cross_attn_token_to_image = Attention( 137 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 138 | ) 139 | self.norm2 = nn.LayerNorm(embedding_dim) 140 | 141 | self.mlp = MLPBlock(embedding_dim, mlp_dim, activation) 142 | self.norm3 = nn.LayerNorm(embedding_dim) 143 | 144 | self.norm4 = nn.LayerNorm(embedding_dim) 145 | self.cross_attn_image_to_token = Attention( 146 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 147 | ) 148 | 149 | self.skip_first_layer_pe = skip_first_layer_pe 150 | 151 | def forward( 152 | self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor 153 | ) -> Tuple[Tensor, Tensor]: 154 | # Self attention block 155 | if self.skip_first_layer_pe: 156 | queries = self.self_attn(q=queries, k=queries, v=queries) 157 | else: 158 | q = queries + query_pe 159 | attn_out = self.self_attn(q=q, k=q, v=queries) 160 | queries = queries + attn_out 161 | queries = self.norm1(queries) 162 | 163 | # Cross attention block, tokens attending to image embedding 164 | q = queries + query_pe 165 | k = keys + key_pe 166 | attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) 167 | queries = queries + attn_out 168 | queries = self.norm2(queries) 169 | 170 | # MLP block 171 | mlp_out = self.mlp(queries) 172 | queries = queries + mlp_out 173 | queries = self.norm3(queries) 174 | 175 | # Cross attention block, image embedding attending to tokens 176 | q = queries + query_pe 177 | k = keys + key_pe 178 | attn_out = 
self.cross_attn_image_to_token(q=k, k=q, v=queries) 179 | keys = keys + attn_out 180 | keys = self.norm4(keys) 181 | 182 | return queries, keys 183 | 184 | 185 | class Attention(nn.Module): 186 | """ 187 | An attention layer that allows for downscaling the size of the embedding 188 | after projection to queries, keys, and values. 189 | """ 190 | 191 | def __init__( 192 | self, 193 | embedding_dim: int, 194 | num_heads: int, 195 | downsample_rate: int = 1, 196 | ) -> None: 197 | super().__init__() 198 | self.embedding_dim = embedding_dim 199 | self.internal_dim = embedding_dim // downsample_rate 200 | self.num_heads = num_heads 201 | assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim." 202 | 203 | self.q_proj = nn.Linear(embedding_dim, self.internal_dim) 204 | self.k_proj = nn.Linear(embedding_dim, self.internal_dim) 205 | self.v_proj = nn.Linear(embedding_dim, self.internal_dim) 206 | self.out_proj = nn.Linear(self.internal_dim, embedding_dim) 207 | 208 | def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: 209 | b, n, c = x.shape 210 | x = x.reshape(b, n, num_heads, c // num_heads) 211 | return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head 212 | 213 | def _recombine_heads(self, x: Tensor) -> Tensor: 214 | b, n_heads, n_tokens, c_per_head = x.shape 215 | x = x.transpose(1, 2) 216 | return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C 217 | 218 | def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: 219 | # Input projections 220 | q = self.q_proj(q) 221 | k = self.k_proj(k) 222 | v = self.v_proj(v) 223 | 224 | # Separate into heads 225 | q = self._separate_heads(q, self.num_heads) 226 | k = self._separate_heads(k, self.num_heads) 227 | v = self._separate_heads(v, self.num_heads) 228 | 229 | # Attention 230 | _, _, _, c_per_head = q.shape 231 | attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens 232 | attn = attn / math.sqrt(c_per_head) 233 | attn = torch.softmax(attn, 
dim=-1) 234 | 235 | # Get output 236 | out = attn @ v 237 | out = self._recombine_heads(out) 238 | out = self.out_proj(out) 239 | 240 | return out 241 | -------------------------------------------------------------------------------- /extend_sam/segment_anything_ori/predictor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | 10 | from extend_sam.segment_anything_ori.modeling import Sam 11 | 12 | from typing import Optional, Tuple 13 | 14 | from .utils.transforms import ResizeLongestSide 15 | 16 | 17 | class SamPredictor: 18 | def __init__( 19 | self, 20 | sam_model: Sam, 21 | ) -> None: 22 | """ 23 | Uses SAM to calculate the image embedding for an image, and then 24 | allows repeated, efficient mask prediction given prompts. 25 | 26 | Arguments: 27 | sam_model (Sam): The model to use for mask prediction. 28 | """ 29 | super().__init__() 30 | self.model = sam_model 31 | self.transform = ResizeLongestSide(sam_model.image_encoder.img_size) 32 | self.reset_image() 33 | 34 | def set_image( 35 | self, 36 | image: np.ndarray, 37 | image_format: str = "RGB", 38 | ) -> None: 39 | """ 40 | Calculates the image embeddings for the provided image, allowing 41 | masks to be predicted with the 'predict' method. 42 | 43 | Arguments: 44 | image (np.ndarray): The image for calculating masks. Expects an 45 | image in HWC uint8 format, with pixel values in [0, 255]. 46 | image_format (str): The color format of the image, in ['RGB', 'BGR']. 47 | """ 48 | assert image_format in [ 49 | "RGB", 50 | "BGR", 51 | ], f"image_format must be in ['RGB', 'BGR'], is {image_format}."
52 | if image_format != self.model.image_format: 53 | image = image[..., ::-1] 54 | 55 | # Transform the image to the form expected by the model 56 | input_image = self.transform.apply_image(image) 57 | input_image_torch = torch.as_tensor(input_image, device=self.device) 58 | input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :] 59 | 60 | self.set_torch_image(input_image_torch, image.shape[:2]) 61 | 62 | @torch.no_grad() 63 | def set_torch_image( 64 | self, 65 | transformed_image: torch.Tensor, 66 | original_image_size: Tuple[int, ...], 67 | ) -> None: 68 | """ 69 | Calculates the image embeddings for the provided image, allowing 70 | masks to be predicted with the 'predict' method. Expects the input 71 | image to be already transformed to the format expected by the model. 72 | 73 | Arguments: 74 | transformed_image (torch.Tensor): The input image, with shape 75 | 1x3xHxW, which has been transformed with ResizeLongestSide. 76 | original_image_size (tuple(int, int)): The size of the image 77 | before transformation, in (H, W) format. 78 | """ 79 | assert ( 80 | len(transformed_image.shape) == 4 81 | and transformed_image.shape[1] == 3 82 | and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size 83 | ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}." 
84 | self.reset_image() 85 | 86 | self.original_size = original_image_size 87 | self.input_size = tuple(transformed_image.shape[-2:]) 88 | input_image = self.model.preprocess(transformed_image) 89 | self.features = self.model.image_encoder(input_image) 90 | self.is_image_set = True 91 | 92 | def predict( 93 | self, 94 | point_coords: Optional[np.ndarray] = None, 95 | point_labels: Optional[np.ndarray] = None, 96 | box: Optional[np.ndarray] = None, 97 | mask_input: Optional[np.ndarray] = None, 98 | multimask_output: bool = True, 99 | return_logits: bool = False, 100 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: 101 | """ 102 | Predict masks for the given input prompts, using the currently set image. 103 | 104 | Arguments: 105 | point_coords (np.ndarray or None): A Nx2 array of point prompts to the 106 | model. Each point is in (X,Y) in pixels. 107 | point_labels (np.ndarray or None): A length N array of labels for the 108 | point prompts. 1 indicates a foreground point and 0 indicates a 109 | background point. 110 | box (np.ndarray or None): A length 4 array giving a box prompt to the 111 | model, in XYXY format. 112 | mask_input (np.ndarray or None): A low resolution mask input to the model, typically 113 | coming from a previous prediction iteration. Has form 1xHxW, where 114 | for SAM, H=W=256. 115 | multimask_output (bool): If true, the model will return three masks. 116 | For ambiguous input prompts (such as a single click), this will often 117 | produce better masks than a single prediction. If only a single 118 | mask is needed, the model's predicted quality score can be used 119 | to select the best mask. For non-ambiguous prompts, such as multiple 120 | input prompts, multimask_output=False can give better results. 121 | return_logits (bool): If true, returns un-thresholded mask logits 122 | instead of a binary mask.
123 | 124 | Returns: 125 | (np.ndarray): The output masks in CxHxW format, where C is the 126 | number of masks, and (H, W) is the original image size. 127 | (np.ndarray): An array of length C containing the model's 128 | predictions for the quality of each mask. 129 | (np.ndarray): An array of shape CxHxW, where C is the number 130 | of masks and H=W=256. These low resolution logits can be passed to 131 | a subsequent iteration as mask input. 132 | """ 133 | if not self.is_image_set: 134 | raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") 135 | 136 | # Transform input prompts 137 | coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None 138 | if point_coords is not None: 139 | assert ( 140 | point_labels is not None 141 | ), "point_labels must be supplied if point_coords is supplied." 142 | point_coords = self.transform.apply_coords(point_coords, self.original_size) 143 | coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device) 144 | labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device) 145 | coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :] 146 | if box is not None: 147 | box = self.transform.apply_boxes(box, self.original_size) 148 | box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device) 149 | box_torch = box_torch[None, :] 150 | if mask_input is not None: 151 | mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device) 152 | mask_input_torch = mask_input_torch[None, :, :, :] 153 | 154 | masks, iou_predictions, low_res_masks = self.predict_torch( 155 | coords_torch, 156 | labels_torch, 157 | box_torch, 158 | mask_input_torch, 159 | multimask_output, 160 | return_logits=return_logits, 161 | ) 162 | 163 | masks = masks[0].detach().cpu().numpy() 164 | iou_predictions = iou_predictions[0].detach().cpu().numpy() 165 | low_res_masks = low_res_masks[0].detach().cpu().numpy() 166 | 
return masks, iou_predictions, low_res_masks 167 | 168 | @torch.no_grad() 169 | def predict_torch( 170 | self, 171 | point_coords: Optional[torch.Tensor], 172 | point_labels: Optional[torch.Tensor], 173 | boxes: Optional[torch.Tensor] = None, 174 | mask_input: Optional[torch.Tensor] = None, 175 | multimask_output: bool = True, 176 | return_logits: bool = False, 177 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 178 | """ 179 | Predict masks for the given input prompts, using the currently set image. 180 | Input prompts are batched torch tensors and are expected to already be 181 | transformed to the input frame using ResizeLongestSide. 182 | 183 | Arguments: 184 | point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the 185 | model. Each point is in (X,Y) in pixels. 186 | point_labels (torch.Tensor or None): A BxN array of labels for the 187 | point prompts. 1 indicates a foreground point and 0 indicates a 188 | background point. 189 | boxes (torch.Tensor or None): A Bx4 array giving a box prompt to the 190 | model, in XYXY format. 191 | mask_input (torch.Tensor or None): A low resolution mask input to the model, typically 192 | coming from a previous prediction iteration. Has form Bx1xHxW, where 193 | for SAM, H=W=256. Masks returned by a previous iteration of the 194 | predict method do not need further transformation. 195 | multimask_output (bool): If true, the model will return three masks. 196 | For ambiguous input prompts (such as a single click), this will often 197 | produce better masks than a single prediction. If only a single 198 | mask is needed, the model's predicted quality score can be used 199 | to select the best mask. For non-ambiguous prompts, such as multiple 200 | input prompts, multimask_output=False can give better results. 201 | return_logits (bool): If true, returns un-thresholded mask logits 202 | instead of a binary mask.
203 | 204 | Returns: 205 | (torch.Tensor): The output masks in BxCxHxW format, where C is the 206 | number of masks, and (H, W) is the original image size. 207 | (torch.Tensor): An array of shape BxC containing the model's 208 | predictions for the quality of each mask. 209 | (torch.Tensor): An array of shape BxCxHxW, where C is the number 210 | of masks and H=W=256. These low res logits can be passed to 211 | a subsequent iteration as mask input. 212 | """ 213 | if not self.is_image_set: 214 | raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") 215 | 216 | if point_coords is not None: 217 | points = (point_coords, point_labels) 218 | else: 219 | points = None 220 | 221 | # Embed prompts 222 | sparse_embeddings, dense_embeddings = self.model.prompt_encoder( 223 | points=points, 224 | boxes=boxes, 225 | masks=mask_input, 226 | ) 227 | 228 | # Predict masks 229 | low_res_masks, iou_predictions = self.model.mask_decoder( 230 | image_embeddings=self.features, 231 | image_pe=self.model.prompt_encoder.get_dense_pe(), 232 | sparse_prompt_embeddings=sparse_embeddings, 233 | dense_prompt_embeddings=dense_embeddings, 234 | multimask_output=multimask_output, 235 | ) 236 | 237 | # Upscale the masks to the original image resolution 238 | masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size) 239 | 240 | if not return_logits: 241 | masks = masks > self.model.mask_threshold 242 | 243 | return masks, iou_predictions, low_res_masks 244 | 245 | def get_image_embedding(self) -> torch.Tensor: 246 | """ 247 | Returns the image embeddings for the currently set image, with 248 | shape 1xCxHxW, where C is the embedding dimension and (H,W) are 249 | the embedding spatial dimension of SAM (typically C=256, H=W=64). 250 | """ 251 | if not self.is_image_set: 252 | raise RuntimeError( 253 | "An image must be set with .set_image(...) to generate an embedding." 
254 | ) 255 | assert self.features is not None, "Features must exist if an image has been set." 256 | return self.features 257 | 258 | @property 259 | def device(self) -> torch.device: 260 | return self.model.device 261 | 262 | def reset_image(self) -> None: 263 | """Resets the currently set image.""" 264 | self.is_image_set = False 265 | self.features = None 266 | self.orig_h = None 267 | self.orig_w = None 268 | self.input_h = None 269 | self.input_w = None 270 | -------------------------------------------------------------------------------- /extend_sam/segment_anything_ori/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /extend_sam/segment_anything_ori/utils/amg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | 10 | import math 11 | from copy import deepcopy 12 | from itertools import product 13 | from typing import Any, Dict, Generator, ItemsView, List, Tuple 14 | 15 | 16 | class MaskData: 17 | """ 18 | A structure for storing masks and their related data in batched format. 19 | Implements basic filtering and concatenation. 20 | """ 21 | 22 | def __init__(self, **kwargs) -> None: 23 | for v in kwargs.values(): 24 | assert isinstance( 25 | v, (list, np.ndarray, torch.Tensor) 26 | ), "MaskData only supports list, numpy arrays, and torch tensors." 
27 | self._stats = dict(**kwargs) 28 | 29 | def __setitem__(self, key: str, item: Any) -> None: 30 | assert isinstance( 31 | item, (list, np.ndarray, torch.Tensor) 32 | ), "MaskData only supports list, numpy arrays, and torch tensors." 33 | self._stats[key] = item 34 | 35 | def __delitem__(self, key: str) -> None: 36 | del self._stats[key] 37 | 38 | def __getitem__(self, key: str) -> Any: 39 | return self._stats[key] 40 | 41 | def items(self) -> ItemsView[str, Any]: 42 | return self._stats.items() 43 | 44 | def filter(self, keep: torch.Tensor) -> None: 45 | for k, v in self._stats.items(): 46 | if v is None: 47 | self._stats[k] = None 48 | elif isinstance(v, torch.Tensor): 49 | self._stats[k] = v[torch.as_tensor(keep, device=v.device)] 50 | elif isinstance(v, np.ndarray): 51 | self._stats[k] = v[keep.detach().cpu().numpy()] 52 | elif isinstance(v, list) and keep.dtype == torch.bool: 53 | self._stats[k] = [a for i, a in enumerate(v) if keep[i]] 54 | elif isinstance(v, list): 55 | self._stats[k] = [v[i] for i in keep] 56 | else: 57 | raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") 58 | 59 | def cat(self, new_stats: "MaskData") -> None: 60 | for k, v in new_stats.items(): 61 | if k not in self._stats or self._stats[k] is None: 62 | self._stats[k] = deepcopy(v) 63 | elif isinstance(v, torch.Tensor): 64 | self._stats[k] = torch.cat([self._stats[k], v], dim=0) 65 | elif isinstance(v, np.ndarray): 66 | self._stats[k] = np.concatenate([self._stats[k], v], axis=0) 67 | elif isinstance(v, list): 68 | self._stats[k] = self._stats[k] + deepcopy(v) 69 | else: 70 | raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") 71 | 72 | def to_numpy(self) -> None: 73 | for k, v in self._stats.items(): 74 | if isinstance(v, torch.Tensor): 75 | self._stats[k] = v.detach().cpu().numpy() 76 | 77 | 78 | def is_box_near_crop_edge( 79 | boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0 80 | ) -> torch.Tensor: 81 | 
"""Filter masks at the edge of a crop, but not at the edge of the original image.""" 82 | crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device) 83 | orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device) 84 | boxes = uncrop_boxes_xyxy(boxes, crop_box).float() 85 | near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0) 86 | near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0) 87 | near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge) 88 | return torch.any(near_crop_edge, dim=1) 89 | 90 | 91 | def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor: 92 | box_xywh = deepcopy(box_xyxy) 93 | box_xywh[2] = box_xywh[2] - box_xywh[0] 94 | box_xywh[3] = box_xywh[3] - box_xywh[1] 95 | return box_xywh 96 | 97 | 98 | def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]: 99 | assert len(args) > 0 and all( 100 | len(a) == len(args[0]) for a in args 101 | ), "Batched iteration must have inputs of all the same size." 102 | n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0) 103 | for b in range(n_batches): 104 | yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args] 105 | 106 | 107 | def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]: 108 | """ 109 | Encodes masks to an uncompressed RLE, in the format expected by 110 | pycoco tools. 
111 | """ 112 | # Put in fortran order and flatten h,w 113 | b, h, w = tensor.shape 114 | tensor = tensor.permute(0, 2, 1).flatten(1) 115 | 116 | # Compute change indices 117 | diff = tensor[:, 1:] ^ tensor[:, :-1] 118 | change_indices = diff.nonzero() 119 | 120 | # Encode run length 121 | out = [] 122 | for i in range(b): 123 | cur_idxs = change_indices[change_indices[:, 0] == i, 1] 124 | cur_idxs = torch.cat( 125 | [ 126 | torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device), 127 | cur_idxs + 1, 128 | torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device), 129 | ] 130 | ) 131 | btw_idxs = cur_idxs[1:] - cur_idxs[:-1] 132 | counts = [] if tensor[i, 0] == 0 else [0] 133 | counts.extend(btw_idxs.detach().cpu().tolist()) 134 | out.append({"size": [h, w], "counts": counts}) 135 | return out 136 | 137 | 138 | def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray: 139 | """Compute a binary mask from an uncompressed RLE.""" 140 | h, w = rle["size"] 141 | mask = np.empty(h * w, dtype=bool) 142 | idx = 0 143 | parity = False 144 | for count in rle["counts"]: 145 | mask[idx : idx + count] = parity 146 | idx += count 147 | parity ^= True 148 | mask = mask.reshape(w, h) 149 | return mask.transpose() # Put in C order 150 | 151 | 152 | def area_from_rle(rle: Dict[str, Any]) -> int: 153 | return sum(rle["counts"][1::2]) 154 | 155 | 156 | def calculate_stability_score( 157 | masks: torch.Tensor, mask_threshold: float, threshold_offset: float 158 | ) -> torch.Tensor: 159 | """ 160 | Computes the stability score for a batch of masks. The stability 161 | score is the IoU between the binary masks obtained by thresholding 162 | the predicted mask logits at high and low values. 163 | """ 164 | # One mask is always contained inside the other. 
# Save memory by preventing unnecessary cast to torch.int64 166 | intersections = ( 167 | (masks > (mask_threshold + threshold_offset)) 168 | .sum(-1, dtype=torch.int16) 169 | .sum(-1, dtype=torch.int32) 170 | ) 171 | unions = ( 172 | (masks > (mask_threshold - threshold_offset)) 173 | .sum(-1, dtype=torch.int16) 174 | .sum(-1, dtype=torch.int32) 175 | ) 176 | return intersections / unions 177 | 178 | 179 | def build_point_grid(n_per_side: int) -> np.ndarray: 180 | """Generates a 2D grid of points evenly spaced in [0,1]x[0,1].""" 181 | offset = 1 / (2 * n_per_side) 182 | points_one_side = np.linspace(offset, 1 - offset, n_per_side) 183 | points_x = np.tile(points_one_side[None, :], (n_per_side, 1)) 184 | points_y = np.tile(points_one_side[:, None], (1, n_per_side)) 185 | points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2) 186 | return points 187 | 188 | 189 | def build_all_layer_point_grids( 190 | n_per_side: int, n_layers: int, scale_per_layer: int 191 | ) -> List[np.ndarray]: 192 | """Generates point grids for all crop layers.""" 193 | points_by_layer = [] 194 | for i in range(n_layers + 1): 195 | n_points = int(n_per_side / (scale_per_layer**i)) 196 | points_by_layer.append(build_point_grid(n_points)) 197 | return points_by_layer 198 | 199 | 200 | def generate_crop_boxes( 201 | im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float 202 | ) -> Tuple[List[List[int]], List[int]]: 203 | """ 204 | Generates a list of crop boxes of different sizes. Each layer 205 | has (2**i)**2 boxes for the ith layer.
""" 207 | crop_boxes, layer_idxs = [], [] 208 | im_h, im_w = im_size 209 | short_side = min(im_h, im_w) 210 | 211 | # Original image 212 | crop_boxes.append([0, 0, im_w, im_h]) 213 | layer_idxs.append(0) 214 | 215 | def crop_len(orig_len, n_crops, overlap): 216 | return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops)) 217 | 218 | for i_layer in range(n_layers): 219 | n_crops_per_side = 2 ** (i_layer + 1) 220 | overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side)) 221 | 222 | crop_w = crop_len(im_w, n_crops_per_side, overlap) 223 | crop_h = crop_len(im_h, n_crops_per_side, overlap) 224 | 225 | crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)] 226 | crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)] 227 | 228 | # Crops in XYXY format 229 | for x0, y0 in product(crop_box_x0, crop_box_y0): 230 | box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)] 231 | crop_boxes.append(box) 232 | layer_idxs.append(i_layer + 1) 233 | 234 | return crop_boxes, layer_idxs 235 | 236 | 237 | def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor: 238 | x0, y0, _, _ = crop_box 239 | offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device) 240 | # Check if boxes has a channel dimension 241 | if len(boxes.shape) == 3: 242 | offset = offset.unsqueeze(1) 243 | return boxes + offset 244 | 245 | 246 | def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor: 247 | x0, y0, _, _ = crop_box 248 | offset = torch.tensor([[x0, y0]], device=points.device) 249 | # Check if points has a channel dimension 250 | if len(points.shape) == 3: 251 | offset = offset.unsqueeze(1) 252 | return points + offset 253 | 254 | 255 | def uncrop_masks( 256 | masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int 257 | ) -> torch.Tensor: 258 | x0, y0, x1, y1 = crop_box 259 | if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h: 260 | return masks 261 | #
Coordinate transform masks 262 | pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0) 263 | pad = (x0, pad_x - x0, y0, pad_y - y0) 264 | return torch.nn.functional.pad(masks, pad, value=0) 265 | 266 | 267 | def remove_small_regions( 268 | mask: np.ndarray, area_thresh: float, mode: str 269 | ) -> Tuple[np.ndarray, bool]: 270 | """ 271 | Removes small disconnected regions and holes in a mask. Returns the 272 | mask and an indicator of if the mask has been modified. 273 | """ 274 | import cv2 # type: ignore 275 | 276 | assert mode in ["holes", "islands"] 277 | correct_holes = mode == "holes" 278 | working_mask = (correct_holes ^ mask).astype(np.uint8) 279 | n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8) 280 | sizes = stats[:, -1][1:] # Row 0 is background label 281 | small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh] 282 | if len(small_regions) == 0: 283 | return mask, False 284 | fill_labels = [0] + small_regions 285 | if not correct_holes: 286 | fill_labels = [i for i in range(n_labels) if i not in fill_labels] 287 | # If every region is below threshold, keep largest 288 | if len(fill_labels) == 0: 289 | fill_labels = [int(np.argmax(sizes)) + 1] 290 | mask = np.isin(regions, fill_labels) 291 | return mask, True 292 | 293 | 294 | def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]: 295 | from pycocotools import mask as mask_utils # type: ignore 296 | 297 | h, w = uncompressed_rle["size"] 298 | rle = mask_utils.frPyObjects(uncompressed_rle, h, w) 299 | rle["counts"] = rle["counts"].decode("utf-8") # Necessary to serialize with json 300 | return rle 301 | 302 | 303 | def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor: 304 | """ 305 | Calculates boxes in XYXY format around masks. Return [0,0,0,0] for 306 | an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4. 
307 | """ 308 | # torch.max below raises an error on empty inputs, just skip in this case 309 | if torch.numel(masks) == 0: 310 | return torch.zeros(*masks.shape[:-2], 4, device=masks.device) 311 | 312 | # Normalize shape to CxHxW 313 | shape = masks.shape 314 | h, w = shape[-2:] 315 | if len(shape) > 2: 316 | masks = masks.flatten(0, -3) 317 | else: 318 | masks = masks.unsqueeze(0) 319 | 320 | # Get top and bottom edges 321 | in_height, _ = torch.max(masks, dim=-1) 322 | in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :] 323 | bottom_edges, _ = torch.max(in_height_coords, dim=-1) 324 | in_height_coords = in_height_coords + h * (~in_height) 325 | top_edges, _ = torch.min(in_height_coords, dim=-1) 326 | 327 | # Get left and right edges 328 | in_width, _ = torch.max(masks, dim=-2) 329 | in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :] 330 | right_edges, _ = torch.max(in_width_coords, dim=-1) 331 | in_width_coords = in_width_coords + w * (~in_width) 332 | left_edges, _ = torch.min(in_width_coords, dim=-1) 333 | 334 | # If the mask is empty the right edge will be to the left of the left edge. 335 | # Replace these boxes with [0, 0, 0, 0] 336 | empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges) 337 | out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1) 338 | out = out * (~empty_filter).unsqueeze(-1) 339 | 340 | # Return to original shape 341 | if len(shape) > 2: 342 | out = out.reshape(*shape[:-2], 4) 343 | else: 344 | out = out[0] 345 | 346 | return out 347 | -------------------------------------------------------------------------------- /extend_sam/segment_anything_ori/utils/onnx.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch.nn import functional as F 10 | 11 | from typing import Tuple 12 | 13 | from ..modeling import Sam 14 | from .amg import calculate_stability_score 15 | 16 | 17 | class SamOnnxModel(nn.Module): 18 | """ 19 | This model should not be called directly, but is used in ONNX export. 20 | It combines the prompt encoder, mask decoder, and mask postprocessing of Sam, 21 | with some functions modified to enable model tracing. Also supports extra 22 | options controlling what information is returned. See the ONNX export script for details. 23 | """ 24 | 25 | def __init__( 26 | self, 27 | model: Sam, 28 | return_single_mask: bool, 29 | use_stability_score: bool = False, 30 | return_extra_metrics: bool = False, 31 | ) -> None: 32 | super().__init__() 33 | self.mask_decoder = model.mask_decoder 34 | self.model = model 35 | self.img_size = model.image_encoder.img_size 36 | self.return_single_mask = return_single_mask 37 | self.use_stability_score = use_stability_score 38 | self.stability_score_offset = 1.0 39 | self.return_extra_metrics = return_extra_metrics 40 | 41 | @staticmethod 42 | def resize_longest_image_size( 43 | input_image_size: torch.Tensor, longest_side: int 44 | ) -> torch.Tensor: 45 | input_image_size = input_image_size.to(torch.float32) 46 | scale = longest_side / torch.max(input_image_size) 47 | transformed_size = scale * input_image_size 48 | transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64) 49 | return transformed_size 50 | 51 | def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor: 52 | point_coords = point_coords + 0.5 53 | point_coords = point_coords / self.img_size 54 | point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords) 55 | point_labels =
point_labels.unsqueeze(-1).expand_as(point_embedding) 56 | 57 | point_embedding = point_embedding * (point_labels != -1) 58 | point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * ( 59 | point_labels == -1 60 | ) 61 | 62 | for i in range(self.model.prompt_encoder.num_point_embeddings): 63 | point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[ 64 | i 65 | ].weight * (point_labels == i) 66 | 67 | return point_embedding 68 | 69 | def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor: 70 | mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask) 71 | mask_embedding = mask_embedding + ( 72 | 1 - has_mask_input 73 | ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1) 74 | return mask_embedding 75 | 76 | def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor: 77 | masks = F.interpolate( 78 | masks, 79 | size=(self.img_size, self.img_size), 80 | mode="bilinear", 81 | align_corners=False, 82 | ) 83 | 84 | prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(torch.int64) 85 | masks = masks[..., : prepadded_size[0], : prepadded_size[1]] 86 | 87 | orig_im_size = orig_im_size.to(torch.int64) 88 | h, w = orig_im_size[0], orig_im_size[1] 89 | masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False) 90 | return masks 91 | 92 | def select_masks( 93 | self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int 94 | ) -> Tuple[torch.Tensor, torch.Tensor]: 95 | # Determine if we should return the multiclick mask or not from the number of points. 96 | # The reweighting is used to avoid control flow. 
97 | score_reweight = torch.tensor( 98 | [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)] 99 | ).to(iou_preds.device) 100 | score = iou_preds + (num_points - 2.5) * score_reweight 101 | best_idx = torch.argmax(score, dim=1) 102 | masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1) 103 | iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1) 104 | 105 | return masks, iou_preds 106 | 107 | @torch.no_grad() 108 | def forward( 109 | self, 110 | image_embeddings: torch.Tensor, 111 | point_coords: torch.Tensor, 112 | point_labels: torch.Tensor, 113 | mask_input: torch.Tensor, 114 | has_mask_input: torch.Tensor, 115 | orig_im_size: torch.Tensor, 116 | ): 117 | sparse_embedding = self._embed_points(point_coords, point_labels) 118 | dense_embedding = self._embed_masks(mask_input, has_mask_input) 119 | 120 | masks, scores = self.model.mask_decoder.predict_masks( 121 | image_embeddings=image_embeddings, 122 | image_pe=self.model.prompt_encoder.get_dense_pe(), 123 | sparse_prompt_embeddings=sparse_embedding, 124 | dense_prompt_embeddings=dense_embedding, 125 | ) 126 | 127 | if self.use_stability_score: 128 | scores = calculate_stability_score( 129 | masks, self.model.mask_threshold, self.stability_score_offset 130 | ) 131 | 132 | if self.return_single_mask: 133 | masks, scores = self.select_masks(masks, scores, point_coords.shape[1]) 134 | 135 | upscaled_masks = self.mask_postprocessing(masks, orig_im_size) 136 | 137 | if self.return_extra_metrics: 138 | stability_scores = calculate_stability_score( 139 | upscaled_masks, self.model.mask_threshold, self.stability_score_offset 140 | ) 141 | areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1) 142 | return upscaled_masks, scores, stability_scores, areas, masks 143 | 144 | return upscaled_masks, scores, masks 145 | -------------------------------------------------------------------------------- /extend_sam/segment_anything_ori/utils/transforms.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | from torch.nn import functional as F 10 | from torchvision.transforms.functional import resize, to_pil_image # type: ignore 11 | 12 | from copy import deepcopy 13 | from typing import Tuple 14 | 15 | 16 | class ResizeLongestSide: 17 | """ 18 | Resizes images to longest side 'target_length', as well as provides 19 | methods for resizing coordinates and boxes. Provides methods for 20 | transforming both numpy array and batched torch tensors. 21 | """ 22 | 23 | def __init__(self, target_length: int) -> None: 24 | self.target_length = target_length 25 | 26 | def apply_image(self, image: np.ndarray) -> np.ndarray: 27 | """ 28 | Expects a numpy array with shape HxWxC in uint8 format. 29 | """ 30 | target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) 31 | return np.array(resize(to_pil_image(image), target_size)) 32 | 33 | def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 34 | """ 35 | Expects a numpy array of length 2 in the final dimension. Requires the 36 | original image size in (H, W) format. 37 | """ 38 | old_h, old_w = original_size 39 | new_h, new_w = self.get_preprocess_shape( 40 | original_size[0], original_size[1], self.target_length 41 | ) 42 | coords = deepcopy(coords).astype(float) 43 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 44 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 45 | return coords 46 | 47 | def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 48 | """ 49 | Expects a numpy array shape Bx4. Requires the original image size 50 | in (H, W) format. 
""" 52 | boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size) 53 | return boxes.reshape(-1, 4) 54 | 55 | def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor: 56 | """ 57 | Expects batched images with shape BxCxHxW and float format. This 58 | transformation may not exactly match apply_image. apply_image is 59 | the transformation expected by the model. 60 | """ 61 | # Expects an image in BCHW format, so H and W are dims 2 and 3. May not exactly match apply_image. 62 | target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length) 63 | return F.interpolate( 64 | image, target_size, mode="bilinear", align_corners=False, antialias=True 65 | ) 66 | 67 | def apply_coords_torch( 68 | self, coords: torch.Tensor, original_size: Tuple[int, ...] 69 | ) -> torch.Tensor: 70 | """ 71 | Expects a torch tensor with length 2 in the last dimension. Requires the 72 | original image size in (H, W) format. 73 | """ 74 | old_h, old_w = original_size 75 | new_h, new_w = self.get_preprocess_shape( 76 | original_size[0], original_size[1], self.target_length 77 | ) 78 | coords = deepcopy(coords).to(torch.float) 79 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 80 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 81 | return coords 82 | 83 | def apply_boxes_torch( 84 | self, boxes: torch.Tensor, original_size: Tuple[int, ...] 85 | ) -> torch.Tensor: 86 | """ 87 | Expects a torch tensor with shape Bx4. Requires the original image 88 | size in (H, W) format. 89 | """ 90 | boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size) 91 | return boxes.reshape(-1, 4) 92 | 93 | @staticmethod 94 | def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]: 95 | """ 96 | Compute the output size given input size and target long side length.
""" 98 | scale = long_side_length * 1.0 / max(oldh, oldw) 99 | newh, neww = oldh * scale, oldw * scale 100 | neww = int(neww + 0.5) 101 | newh = int(newh + 0.5) 102 | return (newh, neww) 103 | -------------------------------------------------------------------------------- /extend_sam/utils.py: -------------------------------------------------------------------------------- 1 | ''' 2 | @copyright ziqi-jin 3 | ''' 4 | import time 5 | import numpy as np 6 | import torch 7 | import torch.nn.functional as F 8 | import os.path as osp 9 | import os 10 | 11 | 12 | def fix_params(model): 13 | for name, param in model.named_parameters(): 14 | param.requires_grad = False 15 | 16 | 17 | def load_params(model, params): 18 | pass 19 | 20 | 21 | def get_opt_pamams(model, lr_list, group_keys, wd_list): 22 | ''' 23 | 24 | :param model: model 25 | :param lr_list: list containing the lr for each param group 26 | :param wd_list: list containing the weight decay for each param group 27 | :param group_keys: list of lists; a parameter is put into group i if its name contains any key in group_keys[i] 28 | :return: list of dict 29 | ''' 30 | assert len(lr_list) == len(group_keys), "lr_list should have the same length as group_keys" 31 | assert len(lr_list) == len(wd_list), "lr_list should have the same length as wd_list" 32 | params_group = [[] for _ in range(len(lr_list))] 33 | for name, value in model.named_parameters(): 34 | for index, g_keys in enumerate(group_keys): 35 | for g_key in g_keys: 36 | if g_key in name: 37 | params_group[index].append(value) 38 | return [{'params': params_group[i], 'lr': lr_list[i], 'weight_decay': wd_list[i]} for i in range(len(lr_list))] 39 | 40 | 41 | class Timer: 42 | 43 | def __init__(self): 44 | self.start_time = 0.0 45 | self.end_time = 0.0 46 | 47 | self.start() 48 | 49 | def start(self): 50 | self.start_time = time.time() 51 | 52 | def end(self, ms=False, clear=False): 53 | self.end_time = time.time() 54 | 55 | if ms: 56 | duration = int((self.end_time -
self.start_time) * 1000) 57 | else: 58 | duration = int(self.end_time - self.start_time) 59 | 60 | if clear: 61 | self.start() 62 | 63 | return duration 64 | 65 | 66 | class Average_Meter: 67 | def __init__(self, keys): 68 | self.keys = keys 69 | self.clear() 70 | 71 | def add(self, dic): 72 | for key, value in dic.items(): 73 | self.data_dic[key].append(value) 74 | 75 | def get(self, keys=None, clear=False): 76 | if keys is None: 77 | keys = self.keys 78 | 79 | dataset = {} 80 | for key in keys: 81 | dataset[key] = float(np.mean(self.data_dic[key])) 82 | 83 | if clear: 84 | self.clear() 85 | 86 | return dataset 87 | 88 | def clear(self): 89 | self.data_dic = {key: [] for key in self.keys} 90 | 91 | 92 | def print_and_save_log(message, path): 93 | print(message) 94 | 95 | with open(path, 'a+') as f: 96 | f.write(message + '\n') 97 | 98 | 99 | class mIoUOnline: 100 | def __init__(self, class_names): 101 | self.class_names = ['background'] + class_names 102 | self.class_num = len(self.class_names) 103 | 104 | self.clear() 105 | 106 | def get_data(self, pred_mask, gt_mask): 107 | obj_mask = gt_mask < 255 108 | correct_mask = (pred_mask == gt_mask) * obj_mask 109 | 110 | P_list, T_list, TP_list = [], [], [] 111 | for i in range(self.class_num): 112 | P_list.append(np.sum((pred_mask == i) * obj_mask)) 113 | T_list.append(np.sum((gt_mask == i) * obj_mask)) 114 | TP_list.append(np.sum((gt_mask == i) * correct_mask)) 115 | 116 | return (P_list, T_list, TP_list) 117 | 118 | def add_using_data(self, data): 119 | P_list, T_list, TP_list = data 120 | for i in range(self.class_num): 121 | self.P[i] += P_list[i] 122 | self.T[i] += T_list[i] 123 | self.TP[i] += TP_list[i] 124 | 125 | def add(self, pred_mask, gt_mask): 126 | obj_mask = gt_mask < 255 127 | correct_mask = (pred_mask == gt_mask) * obj_mask 128 | 129 | for i in range(self.class_num): 130 | self.P[i] += np.sum((pred_mask == i) * obj_mask) 131 | self.T[i] += np.sum((gt_mask == i) * obj_mask) 132 | self.TP[i] += 
np.sum((gt_mask == i) * correct_mask) 133 | 134 | def get(self, detail=False, clear=True): 135 | IoU_dic = {} 136 | IoU_list = [] 137 | 138 | FP_list = [] # over activation 139 | FN_list = [] # under activation 140 | 141 | for i in range(self.class_num): 142 | IoU = self.TP[i] / (self.T[i] + self.P[i] - self.TP[i] + 1e-10) * 100 143 | FP = (self.P[i] - self.TP[i]) / (self.T[i] + self.P[i] - self.TP[i] + 1e-10) 144 | FN = (self.T[i] - self.TP[i]) / (self.T[i] + self.P[i] - self.TP[i] + 1e-10) 145 | 146 | IoU_dic[self.class_names[i]] = IoU 147 | 148 | IoU_list.append(IoU) 149 | FP_list.append(FP) 150 | FN_list.append(FN) 151 | 152 | mIoU = np.mean(np.asarray(IoU_list)) 153 | mIoU_foreground = np.mean(np.asarray(IoU_list)[1:]) 154 | 155 | FP = np.mean(np.asarray(FP_list)) 156 | FN = np.mean(np.asarray(FN_list)) 157 | 158 | if clear: 159 | self.clear() 160 | 161 | if detail: 162 | return mIoU, mIoU_foreground, IoU_dic, FP, FN 163 | else: 164 | return mIoU, mIoU_foreground 165 | 166 | def clear(self): 167 | self.TP = [] 168 | self.P = [] 169 | self.T = [] 170 | 171 | for _ in range(self.class_num): 172 | self.TP.append(0) 173 | self.P.append(0) 174 | self.T.append(0) 175 | 176 | 177 | def get_numpy_from_tensor(tensor): 178 | return tensor.cpu().detach().numpy() 179 | 180 | 181 | def save_model(model, model_path, parallel=False, is_final=False): 182 | if is_final: 183 | model_path_split = model_path.split('.') 184 | model_path = model_path_split[0] + "_final.pth" 185 | if parallel: 186 | torch.save(model.module.state_dict(), model_path) 187 | else: 188 | torch.save(model.state_dict(), model_path) 189 | 190 | 191 | def write_log(iteration, log_path, log_data, status, writer, timer): 192 | log_data['iteration'] = iteration 193 | log_data['time'] = timer.end(clear=True) 194 | message = "iteration : {val}, ".format(val=log_data['iteration']) 195 | for key, value in log_data.items(): 196 | if key == 'iteration': 197 | continue 198 | message += "{key} : {val}, 
".format(key=key, val=value) 199 | message = message[:-2] # + '\n' 200 | print_and_save_log(message, log_path) 201 | # visualize 202 | if writer is not None: 203 | for key, value in log_data.items(): 204 | writer.add_scalar("{status}/{key}".format(status=status, key=key), value, iteration) 205 | 206 | 207 | def check_folder(file_path, is_folder=False): 208 | ''' 209 | 210 | :param file_path: by default, a complete file name including its directory path. 211 | :param is_folder: set to True if file_path is a directory rather than a file name 212 | :return: no return; this function checks whether the target directory exists and creates it if needed. 213 | ''' 214 | if is_folder: 215 | if not osp.exists(file_path): 216 | os.makedirs(file_path) 217 | 218 | else: 219 | splits = file_path.split("/") 220 | folder_name = "/".join(splits[:-1]) 221 | if not osp.exists(folder_name): 222 | os.makedirs(folder_name) 223 | 224 | 225 | def one_hot_embedding_3d(labels, class_num=21): 226 | ''' 227 | 228 | :param labels: B H W 229 | :param class_num: N 230 | :return: B N H W 231 | ''' 232 | one_hot_labels = labels.clone() 233 | one_hot_labels[one_hot_labels == 255] = 0 # 0 is background 234 | return F.one_hot(one_hot_labels, num_classes=class_num).permute(0, 3, 1, 2).contiguous().float() 235 | -------------------------------------------------------------------------------- /how_to_use_finetune_anything.md: -------------------------------------------------------------------------------- 1 | # How to use finetune-anything 2 | finetune-anything (FA) is intended as a tool to help users quickly build extended SAM models. It not only supports the built-in basic tasks and basic models, but also supports user-defined extensions of different modules, training processes, and datasets for the extended SAM.
3 | 4 | - Content 5 | - [Structure](#Structure) 6 | - [Model](#Model) 7 | - [Datasets](#Datasets) 8 | - [Losses](#Losses) 9 | - [Optimizer](#Optimizer) 10 | - [Runner](#Runner) 11 | - [Logger](#Logger) 12 | - [One more thing](#One-more-thing) 13 | 14 | 15 | ## Structure 16 | Using FA involves two parts: training and testing. The training part includes [Model](#Model), [Datasets](#Datasets), [Losses](#Losses), [Optimizer](#Optimizer), [Logger](#Logger), and [Runner](#Runner). 17 | All of these are configured through a yaml file in `config`. 18 | - Tasks already supported by FA can be trained and tested directly by passing `--task_name`. 19 | ``` 20 | CUDA_VISIBLE_DEVICES=${your GPU number} python train.py --task_name ${one of supported task names} 21 | ``` 22 | - Custom configurations can be trained and tested by passing a yaml file with `--cfg`. 23 | ``` 24 | CUDA_VISIBLE_DEVICES=${your GPU number} python train.py --cfg config/${yaml file name} 25 | ``` 26 | The testing part is coming soon ~ 27 | 28 | ## Model 29 | The SAM model includes an image encoder, a prompt encoder, and a mask decoder. FA further encapsulates these components and defines an Extend-SAM model that consists of an image encoder adapter, a prompt encoder adapter, and a mask decoder adapter. The initialization process of Extend-SAM is shown below. 30 | 31 | 32 | Users can choose which adapters are frozen or learned during fine-tuning.
This option is configured in the `model` part of the yaml file, as in the following example: 33 | 34 | ```yaml 35 | model: 36 | sam_name: 'extend sam name' # e.g., 'sem_sam'; custom SAM model name, you must implement this model ('sem_sam') first 37 | params: 38 | # Freeze part of the SAM parameters 39 | fix_img_en: True # freeze image encoder adapter parameters 40 | fix_prompt_en: True # freeze prompt encoder adapter parameters 41 | fix_mask_de: False # do not freeze mask decoder adapter parameters, so they are learned 42 | ckpt_path: 'your original sam weights' # e.g., 'sam_ckpt/sam_vit_b_01ec64.pth' 43 | class_num: 21 # number of classes in your dataset (20) + background (1) 44 | model_type: 'vit_b' # must be one of [vit_h, vit_b, vit_l, default], the original SAM type; 45 | # it must match the original SAM checkpoint given in ckpt_path 46 | ``` 47 | ### Customized Model 48 | If you need to redesign the structure of a SAM module, write code following the three steps below. Take [SemanticSAM](https://github.com/ziqi-jin/finetune-anything/blob/350c1fbf7f122a8525e7ffdecc40f259b262983f/extend_sam/extend_sam.py#L43) as an example. 49 | - step1 50 | 51 | First, inherit the corresponding adapter base class in `extend_sam/xxx_(encoder or decoder)_adapter.py`, then implement the adapter's `__init__` and `forward` functions. 52 | ```python 53 | class SemMaskDecoderAdapter(BaseMaskDecoderAdapter): 54 | def __init__(self, ori_sam: Sam, fix=False, class_num=20): 55 | super(SemMaskDecoderAdapter, self).__init__(ori_sam, fix) # init super class 56 | self.decoder_neck = MaskDecoderNeck(...) # custom module 57 | self.decoder_head = SemSegHead(...)
# custom module 58 | # pair the params between ori mask_decoder and new mask_decoder_adapter 59 | self.pair_params(self.decoder_neck) # give the weights which are with the same name in original SAM to customized module 60 | self.pair_params(self.decoder_head) 61 | 62 | def forward(self, ...): 63 | ... = self.decoder_neck(...) 64 | masks, iou_pred = self.decoder_head(...) 65 | return masks, iou_pred 66 | ``` 67 | - step2 68 | 69 | First inherit the BaseExtendSAM base class in [extend_sam.py](https://github.com/ziqi-jin/finetune-anything/blob/350c1fbf7f122a8525e7ffdecc40f259b262983f/extend_sam/extend_sam.py#L43), and make necessary modifications to `__init__` function. 70 | ```python 71 | class SemanticSam(BaseExtendSam): 72 | 73 | def __init__(self, ...): 74 | super().__init__(...) # init super class 75 | self.mask_adapter = SemMaskDecoderAdapter(...) # replace original Adapter as the new identified customized Adapter 76 | ``` 77 | - step3 78 | 79 | Add new Extend-SAM class to [AVAI_MODEL](https://github.com/ziqi-jin/finetune-anything/blob/350c1fbf7f122a8525e7ffdecc40f259b262983f/extend_sam/__init__.py#L10) dict and give it a key. 80 | then you can train this new model by modify the `sam_name` in config file. 81 | 82 | ## Datasets 83 | 84 | FA comes with datasets for multiple tasks, and also supports custom datasets, and sets the training and test datasets separately. Takes `torch_voc_sem` as an example, the configuration file of the dataset part is as follows, 85 | The dataset part includes `name`, `params`, `transforms` and `target_transforms`, 86 | The `params` which is a `dict` include the key and value your want to set about the init function's parameters of corresponding dataset. make sure the dataset has parameters with the same names as the key. 87 | `transforms` and `target_transforms` respectively correspond to the input image and Ground Truth for transform processing. 
Each entry under `transforms`/`target_transforms` names an implemented transform together with its `params` (again a `dict`), and the transforms are applied to the dataset in the order they are listed in the configuration file.
```yaml
# Dataset
dataset:
  name: 'torch_voc_sem'
  params:
    root: '/your/dataset/path/'
    year: '2012'
    image_set: 'train'
  transforms:
    resize:
      params:
        size: [1024, 1024]
    to_tensor:
      params: ~ # no parameters, set to '~'
  target_transforms:
    resize:
      params:
        size: [1024, 1024]
```

### Customized Dataset

### Customized Transform

If you want to customize a transform, follow these three steps:

- step1

  - Torch-supported transform: skip this step.

  - Torch-unsupported transform

    Create it in [datasets/transforms.py](https://github.com/ziqi-jin/finetune-anything/blob/main/datasets/transforms.py) and implement the `__init__` and `forward` functions.

```python
import torch.nn as nn


class CustomTransform(nn.Module):
    def __init__(self):
        super().__init__()  # initialize nn.Module before adding any attributes
        # define your init process here

    def forward(self, x):
        # define your transform process here; this placeholder returns the input unchanged
        return x
```


- step2

  Import the torch-supported transform you want, or the torch-unsupported transform you defined, in [datasets/transforms.py](https://github.com/ziqi-jin/finetune-anything/blob/main/datasets/transforms.py).
  Then add it to the `AVIAL_TRANSFORM` dict: give it a key such as `resize`, with the transform class as the value.

```python
import torchvision.transforms as T
# T.XXX is a placeholder for a torchvision transform class, e.g. T.Resize; each key must be unique
AVIAL_TRANSFORM = {'your_transform_name_1': T.XXX, 'your_transform_name_2': CustomTransform}
```

- step3

  Set the transform in your config file.
```yaml
transforms:
  your_transform_name:
    params: # parameters of the transform's __init__ function, if any; otherwise set to '~'
      params_1: xxx
      params_2: xxx
```

## Losses

FA supports multiple torch loss functions and also allows users to define custom ones. The loss part of the configuration looks like this:
```yaml
losses:
  ce:
    weight: 0.5
    params: # the initial params of the loss can be set here
      ignore_index: 255
    label_one_hot: False
  mse:
    weight: 5.0
    params: ~ # no parameters, set '~'
    label_one_hot: True
```
The loss part currently has the keys `weight`, `params`, and `label_one_hot`. `weight` controls the weight of each loss in the total loss. Taking the config above as an example, denote the `ce` loss as $Loss_{ce}$ and the `mse` loss as $Loss_{mse}$; the total loss is then

$$
Loss_{total} = weight_{ce} \times Loss_{ce} + weight_{mse} \times Loss_{mse} = 0.5 \times Loss_{ce} + 5.0 \times Loss_{mse}
$$

`params` is a `dict` of keys and values for the corresponding loss function's parameters; make sure the loss function has parameters with the same names as the keys. If you don't need to set any params, set `params` to `~`.
For the semantic segmentation task, if your loss function needs a one-hot label, set `label_one_hot` to `True`.


### Customized Losses

If you want to customize the loss function, follow these three steps:

- step1

  - Torch-supported loss: skip this step.

  - Torch-unsupported loss

    Create it in [losses.py](https://github.com/ziqi-jin/finetune-anything/blob/main/losses/losses.py) and implement the `__init__` and `forward` functions.
```python
import torch.nn as nn


class CustormLoss(nn.Module):
    def __init__(self, xxx):
        super().__init__()  # initialize nn.Module before adding any attributes
        # define your init process here

    def forward(self, x, y, xxx):
        # define your forward process here and return the loss
        raise NotImplementedError
```


- step2

  Import the torch-supported loss you want, or the torch-unsupported loss you defined, in [losses/\_\_init\_\_.py](https://github.com/ziqi-jin/finetune-anything/blob/26b9ebd1b035a2f0ec8ce4e358eac79de7e263a2/losses/__init__.py#L2).
  Then add it to the `AVAI_LOSS` dict: give it a key such as `ce`, with the loss class as the value.

```python
import torch.nn as nn
from .losses import YourCustomLoss
# nn.xxxLoss is a placeholder for a torch loss class, e.g. nn.MSELoss; each key must be unique
AVAI_LOSS = {'your_loss_key_1': YourCustomLoss, 'your_loss_key_2': nn.xxxLoss}
```

- step3

  Set the loss in your config file.

```yaml
losses:
  your_loss_key:
    weight: your_weight # float
    params:
      your_loss_param1: xx
      your_loss_param2: xx
    label_one_hot: False
```

## Optimizer
FA's optimizer supports setting the learning rate (`lr`) and weight decay (`wd`) for any module in the adapters that is not fixed.
Use the keyword `sgd`, `adam`, or `adamw` to choose the optimizer; `opt_params` holds the parameters each kind of optimizer needs.
- Normal module setting

  `lr_default` is the default learning rate for all unfixed params, `wd_default` is the default weight decay for all unfixed params,
  and `momentum` is the optimizer's momentum. If the chosen optimizer has no such parameter (e.g., `adam` has no `momentum`), set `momentum` to `~`.
- Specific module setting

  The remaining three params, `group_keys`, `lr_list`, and `wd_list`, are for specific modules.
  They are lists of the same length, holding the module names, learning rates, and weight decays, respectively.
  For example, to give the `mask_adapter.decoder_head.output_hypernetworks_mlps` module specific optimization parameters, first put it into a list in `group_keys`, then put the corresponding learning rate and weight decay into `lr_list` and `wd_list`.
  If multiple modules should share the same specific setting, add their keys to the same list in `group_keys`; for example, `modulexxx` is added to the first list of `group_keys` below.
```yaml
# Optimizer
opt_params:
  lr_default: 1e-3
  wd_default: 1e-4
  momentum: 0.9
  group_keys: [ [ 'mask_adapter.decoder_head.output_hypernetworks_mlps', 'modulexxx' ], ['second_module'], ]
  lr_list: [ 1e-2, 1e-4, ]
  wd_list: [ 0.0, 0.1, ]
opt_name: 'sgd' # 'sgd', 'adam', or 'adamw'
scheduler_name: 'cosine'
```
FA also supports multiple schedulers, selected through `scheduler_name` with the keyword `single_step`, `multi_step`, `warmup_multi_step`, `cosine`, or `linear`.

## Runner

## Logger
As shown in the config file, FA provides two kinds of logs: the default log, saved in `log_folder`, and a TensorBoard log, saved in `tensorboard_folder` when `use_tensorboard` is `True`.
The best model is saved in `model_folder`.
```yaml
# Logger
use_tensorboard: True
tensorboard_folder: './experiment/tensorboard'
log_folder: './experiment/log'
model_folder: './experiment/model'
```

## One more thing

If you need a loss, dataset, or other function that FA does not support yet, please open an issue and I will help you implement it. Developers are also welcome to contribute new losses, datasets, or other features to FA by submitting a pull request (PR).
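The weighted sum described in the [Losses](#Losses) section can be sketched in plain Python as follows. This is an illustrative sketch under stated assumptions, not FA's runner code: floats stand in for loss tensors, a plain `dict` stands in for the OmegaConf `losses` node, and `total_loss` and `one_hot` are hypothetical helper names. `one_hot` only illustrates the kind of target a loss such as `mse` expects when `label_one_hot: True` is set.

```python
from typing import List, Mapping


def total_loss(loss_values: Mapping[str, float], loss_cfg: Mapping[str, dict]) -> float:
    """Hypothetical helper: sum of weight * value over every loss in the config."""
    total = 0.0
    for name, cfg in loss_cfg.items():
        if name not in loss_values:
            raise KeyError(f"no computed value for configured loss '{name}'")
        total += cfg["weight"] * loss_values[name]
    return total


def one_hot(labels: List[int], num_classes: int) -> List[List[float]]:
    """Hypothetical helper: encode integer class labels as one-hot rows."""
    return [[1.0 if c == label else 0.0 for c in range(num_classes)] for label in labels]


# Mirrors the `losses` config in the Losses section, keeping only `weight` and `label_one_hot`.
loss_cfg = {
    "ce": {"weight": 0.5, "label_one_hot": False},
    "mse": {"weight": 5.0, "label_one_hot": True},
}
# Made-up per-loss values for one batch.
print(total_loss({"ce": 2.0, "mse": 0.1}, loss_cfg))  # 0.5 * 2.0 + 5.0 * 0.1 -> 1.5
print(one_hot([0, 2], num_classes=3))  # [[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]
```

Raising `KeyError` when a configured loss has no computed value surfaces config typos early instead of silently dropping a term from the total.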
--------------------------------------------------------------------------------
/losses/__init__.py:
--------------------------------------------------------------------------------
import torch.nn as nn
from .losses import CustormLoss

AVAI_LOSS = {'ce': nn.CrossEntropyLoss, 'multi_label_soft_margin': nn.MultiLabelSoftMarginLoss,
             'test_custom': CustormLoss, 'mse': nn.MSELoss}


def get_losses(losses):
    loss_dict = {}
    for name in losses:
        # fail with a readable message if the loss key is not registered in AVAI_LOSS
        assert name in AVAI_LOSS, '{name} is not supported, please implement it first.'.format(name=name)
        if losses[name].params is not None:
            # build the loss with the params given in the config
            loss_dict[name] = AVAI_LOSS[name](**losses[name].params)
        else:
            loss_dict[name] = AVAI_LOSS[name]()
    return loss_dict

--------------------------------------------------------------------------------
/losses/losses.py:
--------------------------------------------------------------------------------
'''
@copyright ziqi-jin
You can create a custom loss function in this file, then import it in ./__init__.py and add it to AVAI_LOSS.
'''
import torch.nn as nn


# example
class CustormLoss(nn.Module):
    def __init__(self):
        super().__init__()  # required so that nn.Module is initialized correctly

    def forward(self, x, y):
        pass
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
numpy==1.24.2
omegaconf==2.3.0
opencv_python==4.7.0.72
pandas==2.0.1
Pillow==9.5.0
torch==1.7.1
torchvision==0.8.2
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
'''
@copyright ziqi-jin
'''
import argparse
from omegaconf import OmegaConf
from torch.utils.data import DataLoader
from datasets import get_dataset
from losses import get_losses
from extend_sam import get_model, get_optimizer, get_scheduler, get_opt_pamams, get_runner

supported_tasks = ['detection', 'semantic_seg', 'instance_seg']
parser = argparse.ArgumentParser()
parser.add_argument('--task_name', default='semantic_seg', type=str)
parser.add_argument('--cfg', default=None, type=str)

if __name__ == '__main__':
    args = parser.parse_args()
    task_name = args.task_name
    if args.cfg is not None:
        # a custom config file takes precedence over --task_name
        config = OmegaConf.load(args.cfg)
    else:
        assert task_name in supported_tasks, "Please input a supported task name."
        config = OmegaConf.load("./config/{task_name}.yaml".format(task_name=args.task_name))

    train_cfg = config.train
    val_cfg = config.val
    test_cfg = config.test

    train_dataset = get_dataset(train_cfg.dataset)
    train_loader = DataLoader(train_dataset, batch_size=train_cfg.bs, shuffle=True, num_workers=train_cfg.num_workers,
                              drop_last=train_cfg.drop_last)
    val_dataset = get_dataset(val_cfg.dataset)
    val_loader = DataLoader(val_dataset, batch_size=val_cfg.bs, shuffle=False, num_workers=val_cfg.num_workers,
                            drop_last=val_cfg.drop_last)
    losses = get_losses(losses=train_cfg.losses)
    # get the adapted model according to the model name
    model = get_model(model_name=train_cfg.model.sam_name, **train_cfg.model.params)
    opt_params = get_opt_pamams(model, lr_list=train_cfg.opt_params.lr_list, group_keys=train_cfg.opt_params.group_keys,
                                wd_list=train_cfg.opt_params.wd_list)
    optimizer = get_optimizer(opt_name=train_cfg.opt_name, params=opt_params, lr=train_cfg.opt_params.lr_default,
                              momentum=train_cfg.opt_params.momentum, weight_decay=train_cfg.opt_params.wd_default)
    scheduler = get_scheduler(optimizer=optimizer, lr_scheduler=train_cfg.scheduler_name)
    runner = get_runner(train_cfg.runner_name)(model, optimizer, losses, train_loader, val_loader, scheduler)
    # train step
    runner.train(train_cfg)
    if test_cfg.need_test:
        runner.test(test_cfg)

--------------------------------------------------------------------------------
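`train.py` passes `group_keys`, `lr_list`, and `wd_list` to `get_opt_pamams`, whose implementation in `extend_sam` is not part of this excerpt. The sketch below shows one plausible way such settings could be turned into the per-group dicts (`params`, `lr`, `weight_decay`) that `torch.optim` optimizers accept. It rests on assumptions, not on FA's code: a parameter joins a group when any key of that group is a substring of its name, parameters matching no key fall into a default group that inherits the optimizer-level `lr`/`weight_decay` (here `lr_default`/`wd_default`), fixed parameters are already filtered out, and `build_param_groups` is a hypothetical name. Plain strings stand in for tensors so the sketch runs without torch.

```python
from typing import Any, Dict, List, Sequence, Tuple


def build_param_groups(
    named_params: Sequence[Tuple[str, Any]],
    group_keys: Sequence[Sequence[str]],
    lr_list: Sequence[float],
    wd_list: Sequence[float],
) -> List[Dict[str, Any]]:
    """Hypothetical sketch: split named parameters into optimizer param groups."""
    if not (len(group_keys) == len(lr_list) == len(wd_list)):
        raise ValueError("group_keys, lr_list and wd_list must have the same length")
    # One group per entry of group_keys, each with its own lr and weight decay.
    groups: List[Dict[str, Any]] = [
        {"params": [], "lr": lr, "weight_decay": wd} for lr, wd in zip(lr_list, wd_list)
    ]
    # Parameters that match no key. This group sets no "lr"/"weight_decay",
    # so a torch optimizer would fall back to its constructor defaults
    # (lr_default / wd_default in the config).
    default_group: Dict[str, Any] = {"params": []}
    for name, param in named_params:
        for idx, keys in enumerate(group_keys):
            if any(key in name for key in keys):  # substring match (assumption)
                groups[idx]["params"].append(param)
                break  # the first matching group wins
        else:  # no break: the parameter matched no group
            default_group["params"].append(param)
    # Keep only groups that actually received parameters.
    return [g for g in [default_group, *groups] if g["params"]]


# Hypothetical parameter names; strings stand in for tensors.
named_params = [
    ("mask_adapter.decoder_head.output_hypernetworks_mlps.0.layers.0.weight", "w_mlp"),
    ("mask_adapter.decoder_neck.conv.weight", "w_neck"),
    ("second_module.bias", "b_second"),
]
groups = build_param_groups(
    named_params,
    group_keys=[["mask_adapter.decoder_head.output_hypernetworks_mlps", "modulexxx"], ["second_module"]],
    lr_list=[1e-2, 1e-4],
    wd_list=[0.0, 0.1],
)
for group in groups:
    print(group)
# {'params': ['w_neck']}
# {'params': ['w_mlp'], 'lr': 0.01, 'weight_decay': 0.0}
# {'params': ['b_second'], 'lr': 0.0001, 'weight_decay': 0.1}
```

With a real model, `named_params` might come from `[(n, p) for n, p in model.named_parameters() if p.requires_grad]`, and the result could be passed to an optimizer such as `torch.optim.SGD(groups, lr=lr_default, momentum=momentum, weight_decay=wd_default)`. Substring matching is simple but can over-match (for example, `second_module` also matches `my_second_module`), so prefix matching may be a safer choice in practice.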