├── LICENSE
├── README.md
├── config
│   └── semantic_seg.yaml
├── datasets
│   ├── __init__.py
│   ├── detection.py
│   ├── instance_seg.py
│   ├── matting.py
│   ├── semantic_seg.py
│   └── transforms.py
├── extend_sam
│   ├── __init__.py
│   ├── extend_sam.py
│   ├── image_encoder_adapter.py
│   ├── mask_decoder_adapter.py
│   ├── mask_decoder_heads.py
│   ├── mask_decoder_neck.py
│   ├── prompt_encoder_adapter.py
│   ├── runner.py
│   ├── scheduler.py
│   ├── segment_anything_ori
│   │   ├── __init__.py
│   │   ├── automatic_mask_generator.py
│   │   ├── build_sam.py
│   │   ├── modeling
│   │   │   ├── __init__.py
│   │   │   ├── common.py
│   │   │   ├── image_encoder.py
│   │   │   ├── mask_decoder.py
│   │   │   ├── prompt_encoder.py
│   │   │   ├── sam.py
│   │   │   └── transformer.py
│   │   ├── predictor.py
│   │   └── utils
│   │       ├── __init__.py
│   │       ├── amg.py
│   │       ├── onnx.py
│   │       └── transforms.py
│   └── utils.py
├── how_to_use_finetune_anything.md
├── losses
│   ├── __init__.py
│   └── losses.py
├── requirements.txt
└── train.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 ziqi-jin
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Introduction
2 |
3 | The [Segment Anything Model (SAM)](https://github.com/facebookresearch/segment-anything) has revolutionized computer vision. Fine-tuning SAM makes it possible to solve a large number of basic computer vision tasks. We are designing a **class-aware one-stage** tool for training fine-tuned models based on SAM.
4 |
5 | You supply the dataset for your task and the name of a [supported task](#Supported-Tasks), and this tool helps you obtain a fine-tuned model for that task. You can also design your own extend-SAM model; FA supplies the training, testing, and deployment process for you.
6 |
7 |
8 |
9 | ## Design
10 | Finetune-Anything further encapsulates the three parts of the original SAM, i.e., the Image Encoder Adapter, Prompt Encoder Adapter, and Mask Decoder Adapter. We will support a base extend-SAM model for each task. Users can also design their own customized modules in each adapter, use FA to combine different adapters, and set whether the parameters of any module are fixed. For modules with unfixed parameters, hyperparameters such as `lr` and `weight decay` can be set to coordinate with the fine-tuning of the model.
11 | Check the details in [How_to_use](https://github.com/ziqi-jin/finetune-anything/blob/main/how_to_use_finetune_anything.md).
12 | For example, MaskDecoder is encapsulated as MaskDecoderAdapter. The current MaskDecoderAdapter contains two parts, DecoderNeck and DecoderHead.
13 |
14 |
15 |
16 | ## Supported Tasks
17 | - [x] Semantic Segmentation
18 | - [x] train
19 | - [x] eval
20 | - [ ] test
21 | - [ ] Matting
22 | - [ ] Instance Segmentation
23 | - [ ] Detection
24 | ## Supported Datasets
25 | - [x] TorchVOCSegmentation
26 | - [x] BaseSemantic
27 | - [ ] BaseInstance
28 | - [ ] BaseMatting
29 |
30 | ## Deploy
31 | - [ ] Onnx export
32 |
33 | ## Support Plan
34 | FA will be updated in the following order,
35 |
36 | - Matting (task)
37 | - Prompt Part (structure)
38 | - [MobileSAM](https://github.com/ChaoningZhang/MobileSAM) (model)
39 | - Instance Segmentation (task)
40 |
41 | # Usage
42 | finetune-anything (FA) supports the entire SAM fine-tuning workflow, including modification of the model structure as well as training, validation, and testing. For details, check the [How_to_use](https://github.com/ziqi-jin/finetune-anything/blob/main/how_to_use_finetune_anything.md); the [Quick Start](#Quick-Start) below gives an example of quickly using FA to train a custom semantic segmentation model.
43 | ## Quick Start
44 | ### Install
45 | - Step1
46 | ```
47 | git clone https://github.com/ziqi-jin/finetune-anything.git
48 | cd finetune-anything
49 | pip install -r requirements.txt
50 | ```
51 | - Step2
52 | Download the SAM weights from [SAM repository](https://github.com/facebookresearch/segment-anything#model-checkpoints)
53 |
54 | - Step3
55 | Modify the contents of the YAML file for the specific task in **/config**, e.g., ckpt_path, model_type ...
56 |
57 | ### Train
58 | ```
59 | CUDA_VISIBLE_DEVICES=${your GPU number} python train.py --task_name semantic_seg
60 | ```
61 |
62 | ## One more thing
63 |
64 | If you need a loss, dataset, or other function that FA does not yet support, please submit an issue and I will help you implement it. Developers are also welcome to contribute new losses, datasets, or other features to FA; please submit your PR (pull request).
65 |
66 | ## Related Resources
67 |
68 | - [Documents](https://github.com/ziqi-jin/finetune-anything/blob/main/how_to_use_finetune_anything.md)
69 |
70 |
--------------------------------------------------------------------------------
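The Design section of the README above describes swapping custom modules into each adapter and fixing parts of SAM. Below is a minimal sketch of that idea, assuming the adapter classes from `extend_sam/` are importable; the class name `MyImgEncodeAdapter` and the extra projection layer are illustrative only, not part of the repository.

```python
# Illustrative sketch: a custom image-encoder adapter that keeps SAM's ViT frozen
# and adds a small trainable projection on top of its 256-channel embeddings.
import torch.nn as nn
from extend_sam.image_encoder_adapter import BaseImgEncodeAdapter


class MyImgEncodeAdapter(BaseImgEncodeAdapter):
    def __init__(self, ori_sam, fix=True, out_dim=256):
        super().__init__(ori_sam, fix=fix)      # fix=True freezes the SAM image encoder
        self.proj = nn.Conv2d(256, out_dim, kernel_size=1)  # trainable, illustrative layer

    def forward(self, x):
        x = self.sam_img_encoder(x)             # frozen SAM features
        return self.proj(x)                     # only this layer receives gradients
```

Such an adapter can then replace `self.img_adapter` in a `BaseExtendSam` subclass, in the same way `SemanticSam` swaps in `SemMaskDecoderAdapter` (see `extend_sam/extend_sam.py` below).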
/config/semantic_seg.yaml:
--------------------------------------------------------------------------------
1 | train:
2 | experiment_name: 'semantic_sam'
3 |
4 | # Model
5 | model:
6 | sam_name: 'sem_sam'
7 | params:
8 |       # Fix a part of the parameters in SAM
9 | fix_img_en: True
10 | fix_prompt_en: True
11 | fix_mask_de: False
12 | ckpt_path: 'sam_ckpt/sam_vit_b_01ec64.pth'
13 | class_num: 21 # 20 + 1
14 | model_type: 'vit_b' # type should be in [vit_h, vit_b, vit_l, default]
15 |
16 | # Dataset
17 | dataset:
18 | name: 'torch_voc_sem'
19 | params:
20 | root: '/data/jinziqi/DATASETS/'
21 | year: '2012'
22 | image_set: 'train'
23 | transforms:
24 | resize:
25 | params:
26 | size: [1024, 1024]
27 | to_tensor:
28 | params: ~
29 | target_transforms:
30 | resize:
31 | params:
32 | size: [1024, 1024]
33 |
34 | # Losses
35 | losses:
36 | ce:
37 | weight: 0.5
38 |       params: # ~ means None; the initial params of the loss can be specified here
39 | ignore_index: 255
40 | label_one_hot: False
41 |
42 | # Optimizer
43 | opt_params:
44 | lr_default: 1e-3
45 | wd_default: 1e-4
46 | momentum: 0.9
47 | lr_list: [ 1e-2, ]
48 | group_keys: [ [ 'mask_adapter.decoder_head.output_hypernetworks_mlps', ], ]
49 | wd_list: [ 0.0, ]
50 | opt_name: 'sgd' # 'sgd'
51 | scheduler_name: 'cosine'
52 |
53 | # Runner
54 | max_iter: 100000
55 | log_iter: 20
56 | eval_iter: 200
57 | runner_name: 'sem_runner'
58 | # Dataloader
59 | bs: 8 # 8
60 | num_workers: 2
61 | drop_last: True
62 | # Logger
63 | use_tensorboard: True
64 | tensorboard_folder: './experiment/tensorboard'
65 | log_folder: './experiment/log'
66 | model_folder: './experiment/model'
67 |
68 | val:
69 | # Dataset
70 | dataset:
71 | name: 'torch_voc_sem'
72 | params:
73 | root: '/data/jinziqi/DATASETS/'
74 | year: '2012'
75 | image_set: 'train'
76 | transforms:
77 | resize:
78 | params:
79 | size: [1024, 1024]
80 | to_tensor:
81 | params: ~
82 | target_transforms:
83 | resize:
84 | params:
85 | size: [1024, 1024]
86 |
87 | bs: 8
88 | num_workers: 2
89 | drop_last: True
90 |
91 |
92 | test:
93 | need_test: False
94 |
95 |
--------------------------------------------------------------------------------
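`train.py` is not reproduced in this listing; as a rough sketch (not the repository's exact loading code), a config like the one above can be read with OmegaConf, which is already a dependency via `datasets/transforms.py`.

```python
# Sketch: reading config/semantic_seg.yaml with OmegaConf and accessing a few fields.
from omegaconf import OmegaConf

cfg = OmegaConf.load('config/semantic_seg.yaml')   # path assumed from the repo layout
train_cfg = cfg.train

print(train_cfg.model.sam_name)            # 'sem_sam'
print(train_cfg.model.params.class_num)    # 21
print(list(train_cfg.losses.keys()))       # ['ce']
```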
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .detection import BaseDetectionDataset
2 | from .instance_seg import BaseInstanceDataset
3 | from .semantic_seg import BaseSemanticDataset, VOCSemanticDataset, TorchVOCSegmentation
4 | from .transforms import get_transforms
5 | from torchvision.datasets import VOCSegmentation
6 |
7 | segment_datasets = {'base_ins': BaseInstanceDataset, 'base_sem': BaseSemanticDataset,
8 | 'voc_sem': VOCSemanticDataset, 'torch_voc_sem': TorchVOCSegmentation}
9 | det_dataset = {'base_det': BaseDetectionDataset, }
10 |
11 |
12 | def get_dataset(cfg):
13 | name = cfg.name
14 |     assert name in segment_datasets or name in det_dataset, \
15 |         '{name} is not supported, please implement it first.'.format(name=name)
16 | # TODO customized dataset params:
17 | # customized dataset params example:
18 | # if xxx:
19 | # param1 = cfg.xxx
20 | # param2 = cfg.xxx
21 | # return name_dict[name](path, model, param1, param2, ...)
22 | transform = get_transforms(cfg.transforms)
23 | if name in det_dataset:
24 | return det_dataset[name](**cfg.params, transform=transform)
25 | target_transform = get_transforms(cfg.target_transforms)
26 | return segment_datasets[name](**cfg.params, transform=transform, target_transform=target_transform)
27 |
28 |
29 | class Iterator:
30 | def __init__(self, loader):
31 | self.loader = loader
32 | self.init()
33 |
34 | def init(self):
35 | self.iterator = iter(self.loader)
36 |
37 | def get(self):
38 | try:
39 | data = next(self.iterator)
40 | except StopIteration:
41 | self.init()
42 | data = next(self.iterator)
43 |
44 | return data
45 |
--------------------------------------------------------------------------------
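The `Iterator` wrapper above simply restarts its `DataLoader` when it is exhausted, which is what the iteration-based training loop in `extend_sam/runner.py` relies on. A small self-contained sketch with a toy dataset:

```python
# Sketch: Iterator transparently wraps around when the DataLoader is exhausted.
import torch
from torch.utils.data import DataLoader, TensorDataset
from datasets import Iterator

loader = DataLoader(TensorDataset(torch.arange(10).float()), batch_size=4)
it = Iterator(loader)
for step in range(5):           # more steps than one pass over the data
    (batch,) = it.get()         # restarts the loader after the 3rd batch
    print(step, batch.tolist())
```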
/datasets/detection.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset
2 |
3 |
4 | class BaseDetectionDataset(Dataset):
5 | def __init__(self):
 6 |         raise NotImplementedError('BaseDetectionDataset is not implemented yet.')
7 |
8 | def __getitem__(self, item):
9 | pass
10 |
--------------------------------------------------------------------------------
/datasets/instance_seg.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset
2 |
3 |
4 | class BaseInstanceDataset(Dataset):
5 | def __init__(self):
 6 |         raise NotImplementedError('BaseInstanceDataset is not implemented yet.')
7 |
8 | def __getitem__(self, item):
9 | pass
10 |
--------------------------------------------------------------------------------
/datasets/matting.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | from torch.utils.data import Dataset
4 | from torchvision.datasets import VisionDataset
5 | import numpy as np
6 |
7 | class BaseMattingDataset(VisionDataset):
8 | """
9 | if you want to customize a new dataset to train the matting task,
10 |     the img and mask files need to be arranged in this structure.
11 | ├── data
12 | │ ├── my_dataset
13 | │ │ ├── img
14 | │ │ │ ├── train
15 | │ │ │ │ ├── xxx{img_suffix}
16 | │ │ │ │ ├── yyy{img_suffix}
17 | │ │ │ │ ├── zzz{img_suffix}
18 | │ │ │ ├── val
19 | │ │ ├── trimap
20 | │ │ │ ├── train
21 | │ │ │ │ ├── xxx{img_suffix}
22 | │ │ │ │ ├── yyy{img_suffix}
23 | │ │ │ │ ├── zzz{img_suffix}
24 | │ │ │ ├── val
25 | │ │ ├── ann
26 | │ │ │ ├── train
27 | │ │ │ │ ├── xxx{ann_suffix}
28 | │ │ │ │ ├── yyy{ann_suffix}
29 | │ │ │ │ ├── zzz{ann_suffix}
30 | │ │ │ ├── val
31 | """
32 |
33 | def __init__(self, metainfo, dataset_dir, transform, target_transform,
34 | trimap_transform=None,
35 | image_set='train',
36 | img_suffix='.jpg',
37 | ann_suffix='.png',
38 | trimap_suffix=None,
39 |                  data_prefix: dict = dict(img_path='img', ann_path='ann', trimap_path='trimap'),
40 | return_dict=False):
41 | '''
42 |
43 | :param metainfo: meta data in original dataset, e.g. class_names
44 |         :param dataset_dir: the path of your dataset, e.g. data/my_dataset/ by the structure tree above
45 | :param image_set: 'train' or 'val'
46 | :param img_suffix: your image suffix
47 | :param ann_suffix: your annotation suffix
48 | :param data_prefix: data folder name, as the tree shows above, the data_prefix of my_dataset: img_path='img' , ann_path='ann'
49 | :param return_dict: return dict() or tuple(img, ann)
50 | '''
51 | super(BaseMattingDataset, self).__init__(root=dataset_dir, transform=transform,
52 | target_transform=target_transform)
53 |
54 | self.class_names = metainfo['class_names']
55 | self.img_path = os.path.join(dataset_dir, data_prefix['img_path'], image_set)
56 | self.ann_path = os.path.join(dataset_dir, data_prefix['ann_path'], image_set)
57 |
58 | print('img_folder_name: {img_folder_name}, ann_folder_name: {ann_folder_name}'.format(
59 | img_folder_name=self.img_path, ann_folder_name=self.ann_path))
60 | self.img_names = [img_name.split(img_suffix)[0] for img_name in os.listdir(self.img_path) if
61 | img_name.endswith(img_suffix)]
62 |
63 | self.has_trimap = trimap_suffix is not None
64 | if self.has_trimap:
65 |             self.trimap_path = os.path.join(dataset_dir, data_prefix['trimap_path'], image_set)
66 | print('trimap_folder_name: {trimap_folder_name}'.format(trimap_folder_name=self.trimap_path))
67 | self.img_suffix = img_suffix
68 |         self.ann_suffix, self.trimap_suffix = ann_suffix, trimap_suffix  # keep the trimap suffix for __getitem__
69 | self.return_dict = return_dict
70 | self.trimap_transform = trimap_transform
71 |
72 | def __getitem__(self, index):
73 | img = Image.open(os.path.join(self.img_path, self.img_names[index] + self.img_suffix))
74 | ann = Image.open(os.path.join(self.ann_path, self.img_names[index] + self.ann_suffix))
75 | if self.transforms is not None:
76 | img, ann = self.transforms(img, ann)
77 | ann = np.array(ann)
78 | if self.has_trimap:
79 |             ## branch for self.has_trimap == True
80 | trimap = Image.open(os.path.join(self.trimap_path, self.img_names[index] + self.trimap_suffix))
81 | if self.trimap_transform:
82 | trimap = self.trimap_transform(trimap)
83 | else:
84 |                 print("Warning: you may need to set a transform function for the trimap input")
85 | if self.return_dict:
86 | data = dict(img_name=self.img_names[index], img=img, ann=ann, trimap=trimap,
87 | img_path=os.path.join(self.img_path, self.img_names[index] + self.img_suffix),
88 | ann_path=os.path.join(self.ann_path, self.img_names[index] + self.ann_suffix),
89 | trimap_path=os.path.join(self.trimap_path, self.img_names[index] + self.trimap_suffix))
90 | return data
91 | return img, ann, trimap
92 | else:
93 |             ## branch for self.has_trimap == False
94 | if self.return_dict:
95 | data = dict(img_name=self.img_names[index], img=img, ann=ann,
96 | img_path=os.path.join(self.img_path, self.img_names[index] + self.img_suffix),
97 | ann_path=os.path.join(self.ann_path, self.img_names[index] + self.ann_suffix))
98 | return data
99 | return img, ann
100 |
101 | def __len__(self):
102 | return len(self.img_names)
103 |
104 |
--------------------------------------------------------------------------------
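A hedged example of instantiating `BaseMattingDataset` against the folder layout documented in its docstring; the paths, class names, and transforms below are placeholders, not values taken from the repository.

```python
# Sketch: wiring BaseMattingDataset to the documented data/my_dataset layout.
import torchvision.transforms as T
from datasets.matting import BaseMattingDataset

dataset = BaseMattingDataset(
    metainfo={'class_names': ['foreground']},            # placeholder metainfo
    dataset_dir='data/my_dataset',                       # root shown in the docstring tree
    transform=T.Compose([T.Resize((1024, 1024)), T.ToTensor()]),
    target_transform=T.Resize((1024, 1024)),
    trimap_suffix='.png',                                # enables trimap loading
    trimap_transform=T.ToTensor(),
    image_set='train',
)
img, ann, trimap = dataset[0]                            # return_dict=False, has_trimap=True
```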
/datasets/semantic_seg.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | from torch.utils.data import Dataset
4 | from torchvision.datasets import VOCSegmentation, VisionDataset
5 | import numpy as np
6 |
7 |
8 | class BaseSemanticDataset(VisionDataset):
9 | """
10 | if you want to customize a new dataset to train the segmentation task,
11 |     the img and mask files need to be arranged in this structure.
12 | ├── data
13 | │ ├── my_dataset
14 | │ │ ├── img
15 | │ │ │ ├── train
16 | │ │ │ │ ├── xxx{img_suffix}
17 | │ │ │ │ ├── yyy{img_suffix}
18 | │ │ │ │ ├── zzz{img_suffix}
19 | │ │ │ ├── val
20 | │ │ ├── ann
21 | │ │ │ ├── train
22 | │ │ │ │ ├── xxx{ann_suffix}
23 | │ │ │ │ ├── yyy{ann_suffix}
24 | │ │ │ │ ├── zzz{ann_suffix}
25 | │ │ │ ├── val
26 | """
27 |
28 | def __init__(self, metainfo, dataset_dir, transform, target_transform,
29 | image_set='train',
30 | img_suffix='.jpg',
31 | ann_suffix='.png',
32 | data_prefix: dict = dict(img_path='img', ann_path='ann'),
33 | return_dict=False):
34 | '''
35 |
36 | :param metainfo: meta data in original dataset, e.g. class_names
37 |         :param dataset_dir: the path of your dataset, e.g. data/my_dataset/ by the structure tree above
38 | :param image_set: 'train' or 'val'
39 | :param img_suffix: your image suffix
40 | :param ann_suffix: your annotation suffix
41 | :param data_prefix: data folder name, as the tree shows above, the data_prefix of my_dataset: img_path='img' , ann_path='ann'
42 | :param return_dict: return dict() or tuple(img, ann)
43 | '''
44 | super(BaseSemanticDataset, self).__init__(root=dataset_dir, transform=transform,
45 | target_transform=target_transform)
46 |
47 | self.class_names = metainfo['class_names']
48 | self.img_path = os.path.join(dataset_dir, data_prefix['img_path'], image_set)
49 | self.ann_path = os.path.join(dataset_dir, data_prefix['ann_path'], image_set)
50 | print('img_folder_name: {img_folder_name}, ann_folder_name: {ann_folder_name}'.format(
51 | img_folder_name=self.img_path, ann_folder_name=self.ann_path))
52 | self.img_names = [img_name.split(img_suffix)[0] for img_name in os.listdir(self.img_path) if
53 | img_name.endswith(img_suffix)]
54 | self.img_suffix = img_suffix
55 | self.ann_suffix = ann_suffix
56 | self.return_dict = return_dict
57 |
58 | def __getitem__(self, index):
59 | img = Image.open(os.path.join(self.img_path, self.img_names[index] + self.img_suffix))
60 | ann = Image.open(os.path.join(self.ann_path, self.img_names[index] + self.ann_suffix))
61 | if self.transforms is not None:
62 | img, ann = self.transforms(img, ann)
63 | ann = np.array(ann)
64 |
65 | if self.return_dict:
66 | data = dict(img_name=self.img_names[index], img=img, ann=ann,
67 | img_path=os.path.join(self.img_path, self.img_names[index] + self.img_suffix),
68 | ann_path=os.path.join(self.ann_path, self.img_names[index] + self.ann_suffix))
69 | return data
70 | return img, ann
71 |
72 | def __len__(self):
73 | return len(self.img_names)
74 |
75 |
76 | class VOCSemanticDataset(Dataset):
77 | def __init__(self, root_dir, domain, transform, with_id=False, with_mask=False):
78 | super(VOCSemanticDataset, self).__init__()
79 | self.root_dir = root_dir
80 |
81 | self.image_dir = self.root_dir + 'JPEGImages/'
82 | self.xml_dir = self.root_dir + 'Annotations/'
83 | self.mask_dir = self.root_dir + 'SegmentationClass/'
84 |
85 | self.image_id_list = [image_id.strip() for image_id in open('./data/%s.txt' % domain).readlines()]
86 | self.transform = transform
87 | self.with_id = with_id
88 | self.with_mask = with_mask
89 | self.class_names = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
90 | 'bus', 'car', 'cat', 'chair', 'cow',
91 | 'diningtable', 'dog', 'horse', 'motorbike', 'person',
92 | 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor']
93 |
94 | def __len__(self):
95 | return len(self.image_id_list)
96 |
97 | def get_image(self, image_id):
98 | image = Image.open(self.image_dir + image_id + '.jpg').convert('RGB')
99 | if self.transform is not None:
100 | image = self.transform(image)
101 | return image
102 |
103 | def get_mask(self, image_id):
104 | mask_path = self.mask_dir + image_id + '.png'
105 | if os.path.isfile(mask_path):
106 | mask = Image.open(mask_path)
107 | else:
108 | mask = None
109 | return mask
110 |
111 | def __getitem__(self, index):
112 | image_id = self.image_id_list[index]
113 |
114 | data_list = [self.get_image(image_id)]
115 |
116 | if self.with_id:
117 | data_list.append(image_id)
118 |
119 | if self.with_mask:
120 | data_list.append(self.get_mask(image_id))
121 |
122 | return data_list
123 |
124 |
125 | class TorchVOCSegmentation(VOCSegmentation):
126 | def __init__(self, root, year='2012', image_set='train', download=False, transform=None, target_transform=None):
127 | super(TorchVOCSegmentation, self).__init__(root=root, year=year, image_set=image_set, download=download,
128 | transform=transform, target_transform=target_transform)
129 | self.class_names = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
130 | 'bus', 'car', 'cat', 'chair', 'cow',
131 | 'diningtable', 'dog', 'horse', 'motorbike', 'person',
132 | 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor']
133 |
134 | def __getitem__(self, index: int):
135 | """
136 | Args:
137 | index (int): Index
138 |
139 | Returns:
140 | tuple: (image, target) where target is the image segmentation.
141 | """
142 | img = Image.open(self.images[index]).convert('RGB')
143 | target = Image.open(self.masks[index])
144 |
145 | if self.transforms is not None:
146 | img, target = self.transforms(img, target)
147 |
148 | target = np.array(target)
149 | return img, target
150 |
--------------------------------------------------------------------------------
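`TorchVOCSegmentation` is the `'torch_voc_sem'` entry used in `config/semantic_seg.yaml`. A hedged instantiation roughly matching that config, assuming VOC2012 is already extracted under `root`:

```python
# Sketch: the 'torch_voc_sem' dataset roughly as configured in config/semantic_seg.yaml.
import torchvision.transforms as T
from datasets.semantic_seg import TorchVOCSegmentation

dataset = TorchVOCSegmentation(
    root='/data/DATASETS/',          # placeholder; parent folder of VOCdevkit
    year='2012',
    image_set='train',
    download=False,
    transform=T.Compose([T.Resize((1024, 1024)), T.ToTensor()]),
    target_transform=T.Resize((1024, 1024)),
)
img, target = dataset[0]             # img: 3x1024x1024 tensor, target: HxW numpy array
```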
/datasets/transforms.py:
--------------------------------------------------------------------------------
1 | import torchvision.transforms as T
2 | from omegaconf.dictconfig import DictConfig
3 | import torch.nn as nn
4 |
5 | AVIAL_TRANSFORM = {'resize': T.Resize, 'to_tensor': T.ToTensor}
6 |
7 |
8 | def get_transforms(transforms: DictConfig):
9 | T_list = []
10 | for t_name in transforms.keys():
11 | assert t_name in AVIAL_TRANSFORM, "{T_name} is not supported transform, please implement it and add it to " \
12 | "AVIAL_TRANSFORM first.".format(T_name=t_name)
13 | if transforms[t_name].params is not None:
14 | T_list.append(AVIAL_TRANSFORM[t_name](**transforms[t_name].params))
15 | else:
16 | T_list.append(AVIAL_TRANSFORM[t_name]())
17 | return T.Compose(T_list)
18 |
19 |
20 | class CustomTransform(nn.Module):
21 | def __init__(self):
22 |         super(CustomTransform, self).__init__()
23 |
24 | def forward(self):
25 | pass
26 |
--------------------------------------------------------------------------------
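`get_transforms` expects the same nested `transforms:` structure used in the YAML config. A minimal sketch building that structure directly with OmegaConf instead of loading it from a file:

```python
# Sketch: building the transform pipeline from an OmegaConf fragment.
from omegaconf import OmegaConf
from datasets.transforms import get_transforms

transform_cfg = OmegaConf.create({
    'resize': {'params': {'size': [1024, 1024]}},
    'to_tensor': {'params': None},      # None params -> transform built with defaults
})
transform = get_transforms(transform_cfg)   # T.Compose([T.Resize(...), T.ToTensor()])
```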
/extend_sam/__init__.py:
--------------------------------------------------------------------------------
1 | # copyright ziqi-jin
2 | import torch
3 | from .extend_sam import BaseExtendSam, SemanticSam
4 | from .runner import BaseRunner, SemRunner
5 | # from .optimizer import BaseOptimizer
6 | from .scheduler import WarmupMultiStepLR
7 | from .utils import get_opt_pamams
8 |
9 | AVAI_SCH = ["single_step", "multi_step", "warmup_multi_step", "cosine", "linear"]
10 | AVAI_MODEL = {'base_sam': BaseExtendSam, 'sem_sam': SemanticSam}
11 | # AVAI_OPT = {'base_opt': BaseOptimizer, 'sgd': torch.optim.SGD, 'adam': torch.optim.Adam}
12 | AVAI_OPT = {'sgd': torch.optim.SGD, 'adam': torch.optim.Adam, 'adamw': torch.optim.AdamW}
13 | AVAI_RUNNER = {'base_runner': BaseRunner, 'sem_runner': SemRunner}
14 |
15 |
16 | def get_model(model_name, **kwargs):
17 | if model_name not in AVAI_MODEL:
18 | print('not supported model name, please implement it first.')
19 | return AVAI_MODEL[model_name](**kwargs).cuda()
20 |
21 |
22 | def get_optimizer(opt_name, **kwargs):
23 | if opt_name not in AVAI_OPT:
24 | print('not supported optimizer name, please implement it first.')
25 | return AVAI_OPT[opt_name](**{k: v for k, v in kwargs.items() if v is not None})
26 |
27 |
28 | def get_runner(runner_name):
29 | if runner_name not in AVAI_RUNNER:
30 | print('not supported runner name, please implement it first.')
31 | return AVAI_RUNNER[runner_name]
32 |
33 |
34 | def get_scheduler(
35 | optimizer,
36 | lr_scheduler="single_step",
37 | stepsize=1,
38 | gamma=0.1,
39 | warmup_factor=0.01,
40 | warmup_steps=10,
41 | max_epoch=1,
42 | n_epochs_init=50,
43 | n_epochs_decay=50,
44 |
45 | ):
46 | """A function wrapper for building a learning rate scheduler.
47 | Args:
48 | optimizer (Optimizer): an Optimizer.
49 | lr_scheduler (str, optional): learning rate scheduler method. Default is
50 | single_step.
51 | stepsize (int or list, optional): step size to decay learning rate.
52 | When ``lr_scheduler`` is "single_step", ``stepsize`` should be an integer.
53 | When ``lr_scheduler`` is "multi_step", ``stepsize`` is a list. Default is 1.
54 | gamma (float, optional): decay rate. Default is 0.1.
55 | max_epoch (int, optional): maximum epoch (for cosine annealing). Default is 1.
56 | Examples::
57 | >>> # Decay learning rate by every 20 epochs.
58 | >>> scheduler = get_scheduler(
59 | >>> optimizer, lr_scheduler='single_step', stepsize=20
60 | >>> )
61 | >>> # Decay learning rate at 30, 50 and 55 epochs.
62 | >>> scheduler = get_scheduler(
63 | >>> optimizer, lr_scheduler='multi_step', stepsize=[30, 50, 55]
64 | >>> )
65 | """
66 | if lr_scheduler not in AVAI_SCH:
67 | raise ValueError(
68 | "Unsupported scheduler: {}. Must be one of {}".format(
69 | lr_scheduler, AVAI_SCH
70 | )
71 | )
72 |
73 | if lr_scheduler == "single_step":
74 | if isinstance(stepsize, list):
75 | stepsize = stepsize[-1]
76 |
77 | if not isinstance(stepsize, int):
78 | raise TypeError(
79 | "For single_step lr_scheduler, stepsize must "
80 | "be an integer, but got {}".format(type(stepsize))
81 | )
82 |
83 | scheduler = torch.optim.lr_scheduler.StepLR(
84 | optimizer, step_size=stepsize, gamma=gamma
85 | )
86 |
87 | elif lr_scheduler == "multi_step":
88 | if not isinstance(stepsize, list):
89 | raise TypeError(
90 | "For multi_step lr_scheduler, stepsize must "
91 | "be a list, but got {}".format(type(stepsize))
92 | )
93 |
94 | scheduler = torch.optim.lr_scheduler.MultiStepLR(
95 | optimizer, milestones=stepsize, gamma=gamma
96 | )
97 |
98 | elif lr_scheduler == "warmup_multi_step":
99 | if not isinstance(stepsize, list):
100 | raise TypeError(
101 | "For warmup multi_step lr_scheduler, stepsize must "
102 | "be a list, but got {}".format(type(stepsize))
103 | )
104 |
105 | scheduler = WarmupMultiStepLR(
106 | optimizer,
107 | milestones=stepsize,
108 | gamma=gamma,
109 | warmup_factor=warmup_factor,
110 | warmup_iters=warmup_steps,
111 | )
112 |
113 | elif lr_scheduler == "cosine":
114 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
115 | optimizer, int(max_epoch)
116 | )
117 |
118 | elif lr_scheduler == "linear":
119 | def lambda_rule(epoch):
120 | lr_l = 1.0 - max(0, epoch - n_epochs_init) / float(n_epochs_decay + 1)
121 | return lr_l
122 |
123 | scheduler = torch.optim.lr_scheduler.LambdaLR(
124 | optimizer, lr_lambda=lambda_rule
125 | )
126 |
127 | return scheduler
128 |
--------------------------------------------------------------------------------
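A sketch of how these factory helpers fit together, using the hyper-parameters from `config/semantic_seg.yaml`. The per-group learning rates and weight decays are normally built by `get_opt_pamams` in `extend_sam/utils.py` (not shown in this listing), so plain `model.parameters()` is used here as a stand-in.

```python
# Sketch: assembling model, optimizer and scheduler from the factory functions above.
from extend_sam import get_model, get_optimizer, get_scheduler

model = get_model('sem_sam',                      # get_model moves the model to CUDA
                  ckpt_path='sam_ckpt/sam_vit_b_01ec64.pth',
                  fix_img_en=True, fix_prompt_en=True, fix_mask_de=False,
                  class_num=21, model_type='vit_b')
optimizer = get_optimizer('sgd', params=model.parameters(),
                          lr=1e-3, weight_decay=1e-4, momentum=0.9)
scheduler = get_scheduler(optimizer, lr_scheduler='cosine', max_epoch=100000)
```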
/extend_sam/extend_sam.py:
--------------------------------------------------------------------------------
1 | # copyright ziqi-jin
2 | import torch
3 | import torch.nn as nn
4 | from .segment_anything_ori import sam_model_registry
5 | from .image_encoder_adapter import BaseImgEncodeAdapter
6 | from .mask_decoder_adapter import BaseMaskDecoderAdapter, SemMaskDecoderAdapter
7 | from .prompt_encoder_adapter import BasePromptEncodeAdapter
8 |
9 |
10 | class BaseExtendSam(nn.Module):
11 |
12 | def __init__(self, ckpt_path=None, fix_img_en=False, fix_prompt_en=False, fix_mask_de=False, model_type='vit_b'):
13 | super(BaseExtendSam, self).__init__()
14 |         assert model_type in ['default', 'vit_b', 'vit_l', 'vit_h'], \
15 |             "Wrong model_type, SAM can only be built as vit_b, vit_l, vit_h or default."
16 | self.ori_sam = sam_model_registry[model_type](ckpt_path)
17 | self.img_adapter = BaseImgEncodeAdapter(self.ori_sam, fix=fix_img_en)
18 | self.prompt_adapter = BasePromptEncodeAdapter(self.ori_sam, fix=fix_prompt_en)
19 | self.mask_adapter = BaseMaskDecoderAdapter(self.ori_sam, fix=fix_mask_de)
20 |
21 | def forward(self, img):
22 | x = self.img_adapter(img)
23 | points = None
24 | boxes = None
25 | masks = None
26 |
27 | sparse_embeddings, dense_embeddings = self.prompt_adapter(
28 | points=points,
29 | boxes=boxes,
30 | masks=masks,
31 | )
32 | multimask_output = True
33 | low_res_masks, iou_predictions = self.mask_adapter(
34 | image_embeddings=x,
35 | prompt_adapter=self.prompt_adapter,
36 | sparse_embeddings=sparse_embeddings,
37 | dense_embeddings=dense_embeddings,
38 | multimask_output=multimask_output,
39 | )
40 | return low_res_masks, iou_predictions
41 |
42 |
43 | class SemanticSam(BaseExtendSam):
44 |
45 | def __init__(self, ckpt_path=None, fix_img_en=False, fix_prompt_en=False, fix_mask_de=False, class_num=20, model_type='vit_b'):
46 | super().__init__(ckpt_path=ckpt_path, fix_img_en=fix_img_en, fix_prompt_en=fix_prompt_en,
47 | fix_mask_de=fix_mask_de, model_type=model_type)
48 | self.mask_adapter = SemMaskDecoderAdapter(self.ori_sam, fix=fix_mask_de, class_num=class_num)
49 |
--------------------------------------------------------------------------------
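A minimal shape check for `SemanticSam`, assuming a GPU and a downloaded ViT-B checkpoint at the path used in the config; with `class_num=21` the output mask logits have one channel per class.

```python
# Sketch: a single forward pass through SemanticSam (shape check only).
import torch
from extend_sam.extend_sam import SemanticSam

model = SemanticSam(ckpt_path='sam_ckpt/sam_vit_b_01ec64.pth',   # adjust to your checkpoint
                    fix_img_en=True, fix_prompt_en=True, fix_mask_de=False,
                    class_num=21, model_type='vit_b').cuda().eval()

with torch.no_grad():
    img = torch.randn(1, 3, 1024, 1024).cuda()   # SAM expects 1024x1024 inputs
    masks, iou_pred = model(img)
print(masks.shape)                               # low-resolution logits, 21 channels here
```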
/extend_sam/image_encoder_adapter.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from .segment_anything_ori.modeling.sam import Sam
3 | from .utils import fix_params
4 |
5 |
6 | class BaseImgEncodeAdapter(nn.Module):
7 |
8 | def __init__(self, ori_sam: Sam, fix=False):
9 | super(BaseImgEncodeAdapter, self).__init__()
10 | self.sam_img_encoder = ori_sam.image_encoder
11 | if fix:
12 | fix_params(self.sam_img_encoder)
13 |
14 | def forward(self, x):
15 | x = self.sam_img_encoder(x)
16 | return x
17 |
--------------------------------------------------------------------------------
/extend_sam/mask_decoder_adapter.py:
--------------------------------------------------------------------------------
1 | # @copyright ziqi-jin
2 |
3 | import torch.nn as nn
4 | import torch
5 | from .segment_anything_ori.modeling.sam import Sam
6 | from .utils import fix_params
7 | from .segment_anything_ori.modeling.mask_decoder import MaskDecoder
8 | from typing import List, Tuple
9 | from torch.nn import functional as F
10 | from .mask_decoder_heads import SemSegHead
11 | from .mask_decoder_neck import MaskDecoderNeck
12 |
13 |
14 | class BaseMaskDecoderAdapter(MaskDecoder):
15 | '''
16 | multimask_output (bool): If true, the model will return three masks.
17 | For ambiguous input prompts (such as a single click), this will often
18 | produce better masks than a single prediction. If only a single
19 | mask is needed, the model's predicted quality score can be used
20 | to select the best mask. For non-ambiguous prompts, such as multiple
21 | input prompts, multimask_output=False can give better results.
22 | '''
23 |
24 | # is fix and load params
25 | def __init__(self, ori_sam: Sam, fix=False):
26 | super(BaseMaskDecoderAdapter, self).__init__(transformer_dim=ori_sam.mask_decoder.transformer_dim,
27 | transformer=ori_sam.mask_decoder.transformer)
28 | self.sam_mask_decoder = ori_sam.mask_decoder
29 | if fix:
30 | fix_params(self.sam_mask_decoder) # move to runner to implement
31 |
32 | def forward(self, image_embeddings, prompt_adapter, sparse_embeddings, dense_embeddings, multimask_output=True):
33 | low_res_masks, iou_predictions = self.sam_mask_decoder(image_embeddings=image_embeddings,
34 | image_pe=prompt_adapter.sam_prompt_encoder.get_dense_pe(),
35 | sparse_prompt_embeddings=sparse_embeddings,
36 | dense_prompt_embeddings=dense_embeddings,
37 | multimask_output=multimask_output, )
38 | return low_res_masks, iou_predictions
39 |
40 |
41 | class SemMaskDecoderAdapter(BaseMaskDecoderAdapter):
42 | def __init__(self, ori_sam: Sam, fix=False, class_num=20):
43 | super(SemMaskDecoderAdapter, self).__init__(ori_sam, fix)
44 | self.decoder_neck = MaskDecoderNeck(transformer_dim=self.sam_mask_decoder.transformer_dim,
45 | transformer=self.sam_mask_decoder.transformer,
46 | num_multimask_outputs=self.sam_mask_decoder.num_multimask_outputs)
47 | self.decoder_head = SemSegHead(transformer_dim=self.sam_mask_decoder.transformer_dim,
48 | num_multimask_outputs=self.sam_mask_decoder.num_multimask_outputs,
49 | iou_head_depth=self.sam_mask_decoder.iou_head_depth,
50 | iou_head_hidden_dim=self.sam_mask_decoder.iou_head_hidden_dim,
51 | class_num=class_num)
52 | # pair the params between ori mask_decoder and new mask_decoder_adapter
53 | self.pair_params(self.decoder_neck)
54 | self.pair_params(self.decoder_head)
55 |
56 | def forward(self, image_embeddings, prompt_adapter, sparse_embeddings, dense_embeddings, multimask_output=True,
57 | scale=1):
58 | src, iou_token_out, mask_tokens_out, src_shape = self.decoder_neck(image_embeddings=image_embeddings,
59 | image_pe=prompt_adapter.sam_prompt_encoder.get_dense_pe(),
60 | sparse_prompt_embeddings=sparse_embeddings,
61 | dense_prompt_embeddings=dense_embeddings,
62 | multimask_output=multimask_output, )
63 | masks, iou_pred = self.decoder_head(src, iou_token_out, mask_tokens_out, src_shape, mask_scale=scale)
64 | return masks, iou_pred
65 |
66 | def pair_params(self, target_model: nn.Module):
67 | src_dict = self.sam_mask_decoder.state_dict()
68 | for name, value in target_model.named_parameters():
69 | if name in src_dict.keys():
70 | value.data.copy_(src_dict[name].data)
71 |
72 |
73 | # Lightly adapted from
74 | # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
75 | class MLP(nn.Module):
76 | def __init__(
77 | self,
78 | input_dim: int,
79 | hidden_dim: int,
80 | output_dim: int,
81 | num_layers: int,
82 | sigmoid_output: bool = False,
83 | ) -> None:
84 | super().__init__()
85 | self.num_layers = num_layers
86 | h = [hidden_dim] * (num_layers - 1)
87 | self.layers = nn.ModuleList(
88 | nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
89 | )
90 | self.sigmoid_output = sigmoid_output
91 |
92 | def forward(self, x):
93 | for i, layer in enumerate(self.layers):
94 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
95 | if self.sigmoid_output:
96 | x = F.sigmoid(x)
97 | return x
98 |
--------------------------------------------------------------------------------
/extend_sam/mask_decoder_heads.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 |
5 | from typing import List, Tuple, Type
6 |
7 | from .segment_anything_ori.modeling.common import LayerNorm2d
8 |
9 |
10 | class OriHead(nn.Module):
11 |
12 | def __init__(
13 | self,
14 | *,
15 | transformer_dim: int,
16 | num_multimask_outputs: int = 3,
17 | activation: Type[nn.Module] = nn.GELU,
18 | iou_head_depth: int = 3,
19 | iou_head_hidden_dim: int = 256,
20 | ) -> None:
21 | """
22 | Predicts masks given an image and prompt embeddings, using a
23 |         transformer architecture.
24 |
25 | Arguments:
26 | transformer_dim (int): the channel dimension of the transformer
27 | num_multimask_outputs (int): the number of masks to predict
28 | when disambiguating masks
29 | activation (nn.Module): the type of activation to use when
30 | upscaling masks
31 | iou_head_depth (int): the depth of the MLP used to predict
32 | mask quality
33 | iou_head_hidden_dim (int): the hidden dimension of the MLP
34 | used to predict mask quality
35 | """
36 | super().__init__()
37 | self.transformer_dim = transformer_dim
38 |
39 | self.num_multimask_outputs = num_multimask_outputs
40 |
41 | self.num_mask_tokens = num_multimask_outputs + 1
42 |
43 | self.output_upscaling = nn.Sequential(
44 | nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
45 | LayerNorm2d(transformer_dim // 4),
46 | activation(),
47 | nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
48 | activation(),
49 | )
50 | self.output_hypernetworks_mlps = nn.ModuleList(
51 | [
52 | MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
53 | for i in range(self.num_mask_tokens)
54 | ]
55 | )
56 |
57 | self.iou_prediction_head = MLP(
58 | transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth
59 | )
60 |
61 | def forward(
62 | self,
63 | src: torch.Tensor,
64 | iou_token_out: torch.Tensor,
65 | mask_tokens_out: torch.Tensor,
66 | multimask_output: bool,
67 | ) -> Tuple[torch.Tensor, torch.Tensor]:
68 | """
69 | Predict masks given image and prompt embeddings.
70 |
71 | Arguments:
72 |           src (torch.Tensor): the tensor containing the image embedding and prompt
73 |             embeddings, produced by the neck module
74 |           iou_token_out (torch.Tensor): tokens for IoU prediction from the neck module
75 |           mask_tokens_out (torch.Tensor): tokens for mask prediction from the neck module
76 |           multimask_output (bool): Whether to return multiple masks or a single
77 |             mask.
78 |
79 | Returns:
80 | torch.Tensor: batched predicted masks
81 | torch.Tensor: batched predictions of mask quality
82 | """
83 | b, c, h, w = src.shape
84 |
85 | # Upscale mask embeddings and predict masks using the mask tokens
86 | src = src.transpose(1, 2).view(b, c, h, w)
87 | upscaled_embedding = self.output_upscaling(src)
88 | hyper_in_list: List[torch.Tensor] = []
89 | for i in range(self.num_mask_tokens):
90 | hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
91 | hyper_in = torch.stack(hyper_in_list, dim=1)
92 | b, c, h, w = upscaled_embedding.shape
93 | masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
94 |
95 | # Generate mask quality predictions
96 | iou_pred = self.iou_prediction_head(iou_token_out)
97 |
98 |         # Select the correct mask or masks for output
99 | if multimask_output:
100 | mask_slice = slice(1, None)
101 | else:
102 | mask_slice = slice(0, 1)
103 | masks = masks[:, mask_slice, :, :]
104 | iou_pred = iou_pred[:, mask_slice]
105 |
106 | # Prepare output
107 | return masks, iou_pred
108 |
109 |
110 | class SemSegHead(nn.Module):
111 |
112 | def __init__(
113 | self,
114 | *,
115 | transformer_dim: int,
116 | num_multimask_outputs: int = 3,
117 | activation: Type[nn.Module] = nn.GELU,
118 | iou_head_depth: int = 3,
119 | iou_head_hidden_dim: int = 256,
120 | class_num: int = 20,
121 | ) -> None:
122 | """
123 | Predicts masks given an image and prompt embeddings, using a
124 |         transformer architecture.
125 |
126 | Arguments:
127 | transformer_dim (int): the channel dimension of the transformer
128 | num_multimask_outputs (int): the number of masks to predict
129 | when disambiguating masks
130 | activation (nn.Module): the type of activation to use when
131 | upscaling masks
132 | iou_head_depth (int): the depth of the MLP used to predict
133 | mask quality
134 | iou_head_hidden_dim (int): the hidden dimension of the MLP
135 | used to predict mask quality
136 | """
137 | super().__init__()
138 | self.transformer_dim = transformer_dim
139 | self.num_multimask_outputs = num_multimask_outputs
140 | self.num_mask_tokens = num_multimask_outputs + 1
141 | self.class_num = class_num
142 |
143 | self.output_upscaling = nn.Sequential(
144 | nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
145 | LayerNorm2d(transformer_dim // 4),
146 | activation(),
147 | nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
148 | activation(),
149 | )
150 |
151 | self.output_hypernetworks_mlps = nn.ModuleList(
152 | [
153 | MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
154 | for _ in range(self.class_num)
155 | ]
156 | )
157 |
158 | self.iou_prediction_head = MLP(
159 | transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth
160 | )
161 |
162 | def forward(
163 | self,
164 | src: torch.Tensor,
165 | iou_token_out: torch.Tensor,
166 | mask_tokens_out: torch.Tensor,
167 | src_shape,
168 | mask_scale=1,
169 | ) -> Tuple[torch.Tensor, torch.Tensor]:
170 | """
171 | Predict masks given image and prompt embeddings.
172 |
173 | Arguments:
174 | src (torch.Tensor): The tensor contains image embedding and sparse prompt embedding
175 | iou_token_out (torch.Tensor): Tokens of iou prediction from neck module
176 |           mask_tokens_out (torch.Tensor): Tokens of mask prediction from the neck module
177 |           mask_scale (int): The original SAM outputs 3 masks, ordered from local to global by default.
178 |             This class uses one of the three mask tokens to produce the class-aware
179 |             semantic segmentation prediction.
180 |
181 | Returns:
182 | torch.Tensor: batched predicted semantic masks
183 | torch.Tensor: batched predictions of mask quality
184 | """
185 | b, c, h, w = src_shape
186 |
187 | # Upscale mask embeddings and predict masks using the mask tokens
188 | src = src.transpose(1, 2).view(b, c, h, w)
189 | upscaled_embedding = self.output_upscaling(src)
190 | hyper_in_list: List[torch.Tensor] = []
191 | for i in range(self.class_num):
192 | hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, mask_scale, :]))
193 | hyper_in = torch.stack(hyper_in_list, dim=1)
194 |
195 | b, c, h, w = upscaled_embedding.shape
196 | masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) # B N H W, N is num of category
197 |
198 | # Generate mask quality predictions
199 |         iou_pred = self.iou_prediction_head(iou_token_out)  # B x num_mask_tokens
200 |
201 | return masks, iou_pred
202 |
203 |
204 | # Lightly adapted from
205 | # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
206 | class MLP(nn.Module):
207 | def __init__(
208 | self,
209 | input_dim: int,
210 | hidden_dim: int,
211 | output_dim: int,
212 | num_layers: int,
213 | sigmoid_output: bool = False,
214 | ) -> None:
215 | super().__init__()
216 | self.num_layers = num_layers
217 | h = [hidden_dim] * (num_layers - 1)
218 | self.layers = nn.ModuleList(
219 | nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
220 | )
221 | self.sigmoid_output = sigmoid_output
222 |
223 | def forward(self, x):
224 | for i, layer in enumerate(self.layers):
225 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
226 | if self.sigmoid_output:
227 | x = F.sigmoid(x)
228 | return x
229 |
--------------------------------------------------------------------------------
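The key difference from `OriHead` is that `SemSegHead` allocates one hypernetwork MLP per class, so the output has `class_num` mask channels instead of SAM's 3+1 mask proposals. A small shape check with random tensors (dimensions follow the ViT-B defaults):

```python
# Sketch: SemSegHead turns one of SAM's mask tokens into class_num mask channels.
import torch
from extend_sam.mask_decoder_heads import SemSegHead

head = SemSegHead(transformer_dim=256, num_multimask_outputs=3, class_num=21)
b, c, h, w = 1, 256, 64, 64
src = torch.randn(b, h * w, c)              # flattened image/prompt features from the neck
iou_token_out = torch.randn(b, c)
mask_tokens_out = torch.randn(b, 4, c)      # num_multimask_outputs + 1 tokens
masks, iou_pred = head(src, iou_token_out, mask_tokens_out, (b, c, h, w), mask_scale=1)
print(masks.shape)                          # torch.Size([1, 21, 256, 256])
print(iou_pred.shape)                       # torch.Size([1, 4])
```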
/extend_sam/mask_decoder_neck.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | from torch import nn
9 | from torch.nn import functional as F
10 |
11 | from typing import List, Tuple, Type
12 | from .segment_anything_ori.modeling.common import LayerNorm2d
13 |
14 | '''
15 | This file save the mask_decoder's neck class,
16 | which is the former part of original mask decoder of SAM.
17 | Then the mask_decoder_heads can be used with the neck.
18 | '''
19 |
20 |
21 | class MaskDecoderNeck(nn.Module):
22 | def __init__(
23 | self,
24 | *,
25 | transformer_dim: int,
26 | transformer: nn.Module,
27 | num_multimask_outputs: int = 3,
28 | activation: Type[nn.Module] = nn.GELU,
29 | ) -> None:
30 | """
31 |         Computes mask and IoU tokens given an image and prompt embeddings, using a
32 |         transformer architecture.
33 |
34 | Arguments:
35 | transformer_dim (int): the channel dimension of the transformer
36 | transformer (nn.Module): the transformer used to predict masks
37 | num_multimask_outputs (int): the number of masks to predict
38 | when disambiguating masks
39 | activation (nn.Module): the type of activation to use when
40 | upscaling masks
41 | """
42 | super().__init__()
43 | self.transformer_dim = transformer_dim
44 | self.transformer = transformer
45 |
46 | self.num_multimask_outputs = num_multimask_outputs
47 |
48 | self.iou_token = nn.Embedding(1, transformer_dim)
49 | self.num_mask_tokens = num_multimask_outputs + 1
50 | self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
51 |
52 | self.output_upscaling = nn.Sequential(
53 | nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
54 | LayerNorm2d(transformer_dim // 4),
55 | activation(),
56 | nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
57 | activation(),
58 | )
59 |
60 | def forward(
61 | self,
62 | image_embeddings: torch.Tensor,
63 | image_pe: torch.Tensor,
64 | sparse_prompt_embeddings: torch.Tensor,
65 | dense_prompt_embeddings: torch.Tensor,
66 | multimask_output: bool,
67 | ) -> Tuple[torch.Tensor, torch.Tensor]:
68 | """
69 | Predict masks given image and prompt embeddings.
70 |
71 | Arguments:
72 | image_embeddings (torch.Tensor): the embeddings from the image encoder
73 | image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
74 | sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
75 | dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
76 | multimask_output (bool): Whether to return multiple masks or a single
77 | mask.
78 |
79 | Returns:
80 | torch.Tensor: The tensor contains image embedding and sparse prompt embedding
81 | torch.Tensor: Tokens of iou prediction
82 | torch.Tensor: Tokens of mask prediction
83 | """
84 | # Concatenate output tokens
85 | output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
86 | output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
87 | tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
88 |
89 | # Expand per-image data in batch direction to be per-mask
90 | src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
91 | src = src + dense_prompt_embeddings
92 | pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
93 | src_shape = src.shape
94 | # Run the transformer
95 | hs, src = self.transformer(src, pos_src, tokens)
96 | iou_token_out = hs[:, 0, :]
97 | mask_tokens_out = hs[:, 1: (1 + self.num_mask_tokens), :]
98 |
99 | return src, iou_token_out, mask_tokens_out, src_shape
100 |
--------------------------------------------------------------------------------
/extend_sam/prompt_encoder_adapter.py:
--------------------------------------------------------------------------------
1 | # copyright ziqi-jin
2 |
3 | import torch.nn as nn
4 | from .segment_anything_ori.modeling.sam import Sam
5 | from .utils import fix_params
6 |
7 |
8 | class BasePromptEncodeAdapter(nn.Module):
9 |
10 | def __init__(self, ori_sam: Sam, fix=False):
11 | super(BasePromptEncodeAdapter, self).__init__()
12 |
13 | self.sam_prompt_encoder = ori_sam.prompt_encoder
14 | if fix:
15 | fix_params(self.sam_prompt_encoder)
16 |
17 | def forward(self, points=None, boxes=None, masks=None):
18 | sparse_embeddings, dense_embeddings = self.sam_prompt_encoder(points, boxes, masks)
19 | return sparse_embeddings, dense_embeddings
20 |
--------------------------------------------------------------------------------
/extend_sam/runner.py:
--------------------------------------------------------------------------------
1 | from datasets import Iterator
2 | from .utils import Average_Meter, Timer, print_and_save_log, mIoUOnline, get_numpy_from_tensor, save_model, write_log, \
3 | check_folder, one_hot_embedding_3d
4 | import torch
5 | import cv2
6 | import torch.nn.functional as F
7 | import os
8 | import torch.nn as nn
9 |
10 |
11 | class BaseRunner():
12 | def __init__(self, model, optimizer, losses, train_loader, val_loader, scheduler):
13 | self.optimizer = optimizer
14 | self.losses = losses
15 | self.train_loader = train_loader
16 | self.val_loader = val_loader
17 | self.model = model
18 | self.scheduler = scheduler
19 | self.train_timer = Timer()
20 | self.eval_timer = Timer()
21 | try:
22 | use_gpu = os.environ['CUDA_VISIBLE_DEVICES']
23 | except KeyError:
24 | use_gpu = '0'
25 | self.the_number_of_gpu = len(use_gpu.split(','))
26 | self.original_size = self.model.img_adapter.sam_img_encoder.img_size
27 | if self.the_number_of_gpu > 1:
28 | self.model = nn.DataParallel(self.model)
29 |
30 |
31 | class SemRunner(BaseRunner):
32 | # def __init__(self, **kwargs):
33 | # super().__init__(kwargs)
34 |
35 | def __init__(self, model, optimizer, losses, train_loader, val_loader, scheduler):
36 | super().__init__(model, optimizer, losses, train_loader, val_loader, scheduler)
37 | self.exist_status = ['train', 'eval', 'test']
38 |
39 | def train(self, cfg):
40 | # initial identify
41 | train_meter = Average_Meter(list(self.losses.keys()) + ['total_loss'])
42 | train_iterator = Iterator(self.train_loader)
43 | best_valid_mIoU = -1
44 | model_path = "{cfg.model_folder}/{cfg.experiment_name}/model.pth".format(cfg=cfg)
45 | log_path = "{cfg.log_folder}/{cfg.experiment_name}/log_file.txt".format(cfg=cfg)
46 | check_folder(model_path)
47 | check_folder(log_path)
48 | writer = None
49 | if cfg.use_tensorboard is True:
50 | tensorboard_dir = "{cfg.tensorboard_folder}/{cfg.experiment_name}/tensorboard/".format(cfg=cfg)
51 | from torch.utils.tensorboard import SummaryWriter
52 | writer = SummaryWriter(tensorboard_dir)
53 | # train
54 | for iteration in range(cfg.max_iter):
55 | images, labels = train_iterator.get()
56 | images, labels = images.cuda(), labels.cuda().long()
57 | masks_pred, iou_pred = self.model(images)
58 | masks_pred = F.interpolate(masks_pred, self.original_size, mode="bilinear", align_corners=False)
59 |
60 | total_loss = torch.zeros(1).cuda()
61 | loss_dict = {}
62 | self._compute_loss(total_loss, loss_dict, masks_pred, labels, cfg)
63 | self.optimizer.zero_grad()
64 | total_loss.backward()
65 | self.optimizer.step()
66 | self.scheduler.step()
67 | loss_dict['total_loss'] = total_loss.item()
68 | train_meter.add(loss_dict)
69 |
70 | # log
71 | if (iteration + 1) % cfg.log_iter == 0:
72 | write_log(iteration=iteration, log_path=log_path, log_data=train_meter.get(clear=True),
73 | status=self.exist_status[0],
74 | writer=writer, timer=self.train_timer)
75 | # eval
76 | if (iteration + 1) % cfg.eval_iter == 0:
77 | mIoU, _ = self._eval()
78 | if best_valid_mIoU == -1 or best_valid_mIoU < mIoU:
79 | best_valid_mIoU = mIoU
80 | save_model(self.model, model_path, parallel=self.the_number_of_gpu > 1)
81 | print_and_save_log("saved model in {model_path}".format(model_path=model_path), path=log_path)
82 | log_data = {'mIoU': mIoU, 'best_valid_mIoU': best_valid_mIoU}
83 | write_log(iteration=iteration, log_path=log_path, log_data=log_data, status=self.exist_status[1],
84 | writer=writer, timer=self.eval_timer)
85 | # final process
86 | save_model(self.model, model_path, is_final=True, parallel=self.the_number_of_gpu > 1)
87 | if writer is not None:
88 | writer.close()
89 |
90 | def test(self):
91 | pass
92 |
93 | def _eval(self):
94 | self.model.eval()
95 | self.eval_timer.start()
96 | class_names = self.val_loader.dataset.class_names
97 | eval_metric = mIoUOnline(class_names=class_names)
98 | with torch.no_grad():
99 | for index, (images, labels) in enumerate(self.val_loader):
100 | images = images.cuda()
101 | labels = labels.cuda()
102 | masks_pred, iou_pred = self.model(images)
103 | predictions = torch.argmax(masks_pred, dim=1)
104 | for batch_index in range(images.size()[0]):
105 | pred_mask = get_numpy_from_tensor(predictions[batch_index])
106 | gt_mask = get_numpy_from_tensor(labels[batch_index].squeeze(0))
107 | h, w = pred_mask.shape
108 | gt_mask = cv2.resize(gt_mask, (w, h), interpolation=cv2.INTER_NEAREST)
109 |
110 | eval_metric.add(pred_mask, gt_mask)
111 | self.model.train()
112 | return eval_metric.get(clear=True)
113 |
114 | def _compute_loss(self, total_loss, loss_dict, mask_pred, labels, cfg):
115 | """
116 |         Because the different losses take different inputs, you may need to modify
117 |         the process in this function when adding new losses.
118 | """
119 | loss_cfg = cfg.losses
120 | for index, item in enumerate(self.losses.items()):
121 | # item -> (key: loss_name, val: loss)
122 | real_labels = labels
123 | if loss_cfg[item[0]].label_one_hot:
124 | class_num = cfg.model.params.class_num
125 | real_labels = one_hot_embedding_3d(real_labels, class_num=class_num)
126 | tmp_loss = item[1](mask_pred, real_labels)
127 | loss_dict[item[0]] = tmp_loss.item()
128 | total_loss += loss_cfg[item[0]].weight * tmp_loss
129 |
--------------------------------------------------------------------------------
/extend_sam/scheduler.py:
--------------------------------------------------------------------------------
1 | # Modified from https://github.com/KaiyangZhou/deep-person-reid/blob/master/torchreid/optim/lr_scheduler.py # noqa
2 | # and https://github.com/JDAI-CV/fast-reid/blob/master/fastreid/solver/lr_scheduler.py
3 |
4 | from bisect import bisect_right
5 | from typing import List
6 |
7 | import torch
8 | from torch.optim.lr_scheduler import _LRScheduler
9 |
10 |
11 | class WarmupMultiStepLR(_LRScheduler):
12 | def __init__(
13 | self,
14 | optimizer: torch.optim.Optimizer,
15 | milestones: List[int],
16 | gamma: float = 0.1,
17 | warmup_factor: float = 0.001,
18 | warmup_iters: int = 1000,
19 | warmup_method: str = "linear",
20 | last_epoch: int = -1,
21 | **kwargs,
22 | ):
23 | if not list(milestones) == sorted(milestones):
24 | raise ValueError(
25 |                 "Milestones should be a list of "
26 |                 "increasing integers. Got {}".format(milestones)
27 | )
28 | self.milestones = milestones
29 | self.gamma = gamma
30 | self.warmup_factor = warmup_factor
31 | self.warmup_iters = warmup_iters
32 | self.warmup_method = warmup_method
33 | super().__init__(optimizer, last_epoch)
34 |
35 | def get_lr(self) -> List[float]:
36 | warmup_factor = _get_warmup_factor_at_iter(
37 | self.warmup_method, self.last_epoch, self.warmup_iters, self.warmup_factor
38 | )
39 | return [
40 | base_lr
41 | * warmup_factor
42 | * self.gamma ** bisect_right(self.milestones, self.last_epoch)
43 | for base_lr in self.base_lrs
44 | ]
45 |
46 | def _compute_values(self) -> List[float]:
47 | # The new interface
48 | return self.get_lr()
49 |
50 |
51 | def _get_warmup_factor_at_iter(
52 | method: str, iter: int, warmup_iters: int, warmup_factor: float
53 | ) -> float:
54 | """
55 | Return the learning rate warmup factor at a specific iteration.
56 | See https://arxiv.org/abs/1706.02677 for more details.
57 | Args:
58 | method (str): warmup method; either "constant" or "linear".
59 | iter (int): iteration at which to calculate the warmup factor.
60 | warmup_iters (int): the number of warmup iterations.
61 | warmup_factor (float): the base warmup factor (the meaning changes according
62 | to the method used).
63 | Returns:
64 | float: the effective warmup factor at the given iteration.
65 | """
66 | if iter >= warmup_iters:
67 | return 1.0
68 |
69 | if method == "constant":
70 | return warmup_factor
71 | elif method == "linear":
72 | alpha = iter / warmup_iters
73 | return warmup_factor * (1 - alpha) + alpha
74 | else:
75 | raise ValueError("Unknown warmup method: {}".format(method))
76 |
--------------------------------------------------------------------------------
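A quick illustration of the warmup behaviour: the learning rate ramps linearly from `warmup_factor * base_lr` to `base_lr` over `warmup_iters` steps, then decays by `gamma` at each milestone. Toy numbers only:

```python
# Sketch: observing WarmupMultiStepLR's warmup ramp and milestone decays.
import torch
from extend_sam.scheduler import WarmupMultiStepLR

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=0.1)
scheduler = WarmupMultiStepLR(optimizer, milestones=[30, 40], gamma=0.1,
                              warmup_factor=0.001, warmup_iters=10)
for step in range(50):
    optimizer.step()
    scheduler.step()
    if step in (0, 5, 9, 30, 40):
        print(step, scheduler.get_last_lr())   # ramps up to 0.1, then drops to 0.01 and 0.001
```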
/extend_sam/segment_anything_ori/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # modified by ziqi-jin
8 |
9 | from .build_sam import (
10 | build_sam,
11 | build_sam_vit_h,
12 | build_sam_vit_l,
13 | build_sam_vit_b,
14 | sam_model_registry,
15 | )
16 | from .modeling.sam import Sam
17 | from .predictor import SamPredictor
18 | from .automatic_mask_generator import SamAutomaticMaskGenerator
19 |
--------------------------------------------------------------------------------
/extend_sam/segment_anything_ori/automatic_mask_generator.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import numpy as np
8 | import torch
9 | from torchvision.ops.boxes import batched_nms, box_area # type: ignore
10 |
11 | from typing import Any, Dict, List, Optional, Tuple
12 |
13 | from .modeling import Sam
14 | from .predictor import SamPredictor
15 | from .utils.amg import (
16 | MaskData,
17 | area_from_rle,
18 | batch_iterator,
19 | batched_mask_to_box,
20 | box_xyxy_to_xywh,
21 | build_all_layer_point_grids,
22 | calculate_stability_score,
23 | coco_encode_rle,
24 | generate_crop_boxes,
25 | is_box_near_crop_edge,
26 | mask_to_rle_pytorch,
27 | remove_small_regions,
28 | rle_to_mask,
29 | uncrop_boxes_xyxy,
30 | uncrop_masks,
31 | uncrop_points,
32 | )
33 |
34 |
35 | class SamAutomaticMaskGenerator:
36 | def __init__(
37 | self,
38 | model: Sam,
39 | points_per_side: Optional[int] = 32,
40 | points_per_batch: int = 64,
41 | pred_iou_thresh: float = 0.88,
42 | stability_score_thresh: float = 0.95,
43 | stability_score_offset: float = 1.0,
44 | box_nms_thresh: float = 0.7,
45 | crop_n_layers: int = 0,
46 | crop_nms_thresh: float = 0.7,
47 | crop_overlap_ratio: float = 512 / 1500,
48 | crop_n_points_downscale_factor: int = 1,
49 | point_grids: Optional[List[np.ndarray]] = None,
50 | min_mask_region_area: int = 0,
51 | output_mode: str = "binary_mask",
52 | ) -> None:
53 | """
54 | Using a SAM model, generates masks for the entire image.
55 | Generates a grid of point prompts over the image, then filters
56 | low quality and duplicate masks. The default settings are chosen
57 | for SAM with a ViT-H backbone.
58 |
59 | Arguments:
60 | model (Sam): The SAM model to use for mask prediction.
61 | points_per_side (int or None): The number of points to be sampled
62 | along one side of the image. The total number of points is
63 | points_per_side**2. If None, 'point_grids' must provide explicit
64 | point sampling.
65 | points_per_batch (int): Sets the number of points run simultaneously
66 | by the model. Higher numbers may be faster but use more GPU memory.
67 | pred_iou_thresh (float): A filtering threshold in [0,1], using the
68 | model's predicted mask quality.
69 | stability_score_thresh (float): A filtering threshold in [0,1], using
70 | the stability of the mask under changes to the cutoff used to binarize
71 | the model's mask predictions.
72 | stability_score_offset (float): The amount to shift the cutoff when
73 |                 calculating the stability score.
74 | box_nms_thresh (float): The box IoU cutoff used by non-maximal
75 | suppression to filter duplicate masks.
76 |             crop_n_layers (int): If >0, mask prediction will be run again on
77 | crops of the image. Sets the number of layers to run, where each
78 | layer has 2**i_layer number of image crops.
79 |             crop_nms_thresh (float): The box IoU cutoff used by non-maximal
80 | suppression to filter duplicate masks between different crops.
81 | crop_overlap_ratio (float): Sets the degree to which crops overlap.
82 | In the first crop layer, crops will overlap by this fraction of
83 | the image length. Later layers with more crops scale down this overlap.
84 | crop_n_points_downscale_factor (int): The number of points-per-side
85 | sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
86 | point_grids (list(np.ndarray) or None): A list over explicit grids
87 | of points used for sampling, normalized to [0,1]. The nth grid in the
88 | list is used in the nth crop layer. Exclusive with points_per_side.
89 | min_mask_region_area (int): If >0, postprocessing will be applied
90 | to remove disconnected regions and holes in masks with area smaller
91 | than min_mask_region_area. Requires opencv.
92 | output_mode (str): The form masks are returned in. Can be 'binary_mask',
93 | 'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools.
94 | For large resolutions, 'binary_mask' may consume large amounts of
95 | memory.
96 | """
97 |
98 | assert (points_per_side is None) != (
99 | point_grids is None
100 | ), "Exactly one of points_per_side or point_grid must be provided."
101 | if points_per_side is not None:
102 | self.point_grids = build_all_layer_point_grids(
103 | points_per_side,
104 | crop_n_layers,
105 | crop_n_points_downscale_factor,
106 | )
107 | elif point_grids is not None:
108 | self.point_grids = point_grids
109 | else:
110 | raise ValueError("Can't have both points_per_side and point_grid be None.")
111 |
112 | assert output_mode in [
113 | "binary_mask",
114 | "uncompressed_rle",
115 | "coco_rle",
116 | ], f"Unknown output_mode {output_mode}."
117 | if output_mode == "coco_rle":
118 | from pycocotools import mask as mask_utils # type: ignore # noqa: F401
119 |
120 | if min_mask_region_area > 0:
121 | import cv2 # type: ignore # noqa: F401
122 |
123 | self.predictor = SamPredictor(model)
124 | self.points_per_batch = points_per_batch
125 | self.pred_iou_thresh = pred_iou_thresh
126 | self.stability_score_thresh = stability_score_thresh
127 | self.stability_score_offset = stability_score_offset
128 | self.box_nms_thresh = box_nms_thresh
129 | self.crop_n_layers = crop_n_layers
130 | self.crop_nms_thresh = crop_nms_thresh
131 | self.crop_overlap_ratio = crop_overlap_ratio
132 | self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
133 | self.min_mask_region_area = min_mask_region_area
134 | self.output_mode = output_mode
135 |
136 | @torch.no_grad()
137 | def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
138 | """
139 | Generates masks for the given image.
140 |
141 | Arguments:
142 | image (np.ndarray): The image to generate masks for, in HWC uint8 format.
143 |
144 | Returns:
145 | list(dict(str, any)): A list over records for masks. Each record is
146 | a dict containing the following keys:
147 | segmentation (dict(str, any) or np.ndarray): The mask. If
148 | output_mode='binary_mask', is an array of shape HW. Otherwise,
149 | is a dictionary containing the RLE.
150 | bbox (list(float)): The box around the mask, in XYWH format.
151 | area (int): The area in pixels of the mask.
152 | predicted_iou (float): The model's own prediction of the mask's
153 | quality. This is filtered by the pred_iou_thresh parameter.
154 | point_coords (list(list(float))): The point coordinates input
155 | to the model to generate this mask.
156 | stability_score (float): A measure of the mask's quality. This
157 | is filtered on using the stability_score_thresh parameter.
158 | crop_box (list(float)): The crop of the image used to generate
159 | the mask, given in XYWH format.
160 | """
161 |
162 | # Generate masks
163 | mask_data = self._generate_masks(image)
164 |
165 | # Filter small disconnected regions and holes in masks
166 | if self.min_mask_region_area > 0:
167 | mask_data = self.postprocess_small_regions(
168 | mask_data,
169 | self.min_mask_region_area,
170 | max(self.box_nms_thresh, self.crop_nms_thresh),
171 | )
172 |
173 | # Encode masks
174 | if self.output_mode == "coco_rle":
175 | mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]]
176 | elif self.output_mode == "binary_mask":
177 | mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
178 | else:
179 | mask_data["segmentations"] = mask_data["rles"]
180 |
181 | # Write mask records
182 | curr_anns = []
183 | for idx in range(len(mask_data["segmentations"])):
184 | ann = {
185 | "segmentation": mask_data["segmentations"][idx],
186 | "area": area_from_rle(mask_data["rles"][idx]),
187 | "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
188 | "predicted_iou": mask_data["iou_preds"][idx].item(),
189 | "point_coords": [mask_data["points"][idx].tolist()],
190 | "stability_score": mask_data["stability_score"][idx].item(),
191 | "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
192 | }
193 | curr_anns.append(ann)
194 |
195 | return curr_anns
196 |
197 | def _generate_masks(self, image: np.ndarray) -> MaskData:
198 | orig_size = image.shape[:2]
199 | crop_boxes, layer_idxs = generate_crop_boxes(
200 | orig_size, self.crop_n_layers, self.crop_overlap_ratio
201 | )
202 |
203 | # Iterate over image crops
204 | data = MaskData()
205 | for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
206 | crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
207 | data.cat(crop_data)
208 |
209 | # Remove duplicate masks between crops
210 | if len(crop_boxes) > 1:
211 | # Prefer masks from smaller crops
212 | scores = 1 / box_area(data["crop_boxes"])
213 | scores = scores.to(data["boxes"].device)
214 | keep_by_nms = batched_nms(
215 | data["boxes"].float(),
216 | scores,
217 | torch.zeros(len(data["boxes"])), # categories
218 | iou_threshold=self.crop_nms_thresh,
219 | )
220 | data.filter(keep_by_nms)
221 |
222 | data.to_numpy()
223 | return data
224 |
225 | def _process_crop(
226 | self,
227 | image: np.ndarray,
228 | crop_box: List[int],
229 | crop_layer_idx: int,
230 | orig_size: Tuple[int, ...],
231 | ) -> MaskData:
232 | # Crop the image and calculate embeddings
233 | x0, y0, x1, y1 = crop_box
234 | cropped_im = image[y0:y1, x0:x1, :]
235 | cropped_im_size = cropped_im.shape[:2]
236 | self.predictor.set_image(cropped_im)
237 |
238 | # Get points for this crop
239 | points_scale = np.array(cropped_im_size)[None, ::-1]
240 | points_for_image = self.point_grids[crop_layer_idx] * points_scale
241 |
242 | # Generate masks for this crop in batches
243 | data = MaskData()
244 | for (points,) in batch_iterator(self.points_per_batch, points_for_image):
245 | batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size)
246 | data.cat(batch_data)
247 | del batch_data
248 | self.predictor.reset_image()
249 |
250 | # Remove duplicates within this crop.
251 | keep_by_nms = batched_nms(
252 | data["boxes"].float(),
253 | data["iou_preds"],
254 | torch.zeros(len(data["boxes"])), # categories
255 | iou_threshold=self.box_nms_thresh,
256 | )
257 | data.filter(keep_by_nms)
258 |
259 | # Return to the original image frame
260 | data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box)
261 | data["points"] = uncrop_points(data["points"], crop_box)
262 | data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])
263 |
264 | return data
265 |
266 | def _process_batch(
267 | self,
268 | points: np.ndarray,
269 | im_size: Tuple[int, ...],
270 | crop_box: List[int],
271 | orig_size: Tuple[int, ...],
272 | ) -> MaskData:
273 | orig_h, orig_w = orig_size
274 |
275 | # Run model on this batch
276 | transformed_points = self.predictor.transform.apply_coords(points, im_size)
277 | in_points = torch.as_tensor(transformed_points, device=self.predictor.device)
278 | in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
279 | masks, iou_preds, _ = self.predictor.predict_torch(
280 | in_points[:, None, :],
281 | in_labels[:, None],
282 | multimask_output=True,
283 | return_logits=True,
284 | )
285 |
286 | # Serialize predictions and store in MaskData
287 | data = MaskData(
288 | masks=masks.flatten(0, 1),
289 | iou_preds=iou_preds.flatten(0, 1),
290 | points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)),
291 | )
292 | del masks
293 |
294 | # Filter by predicted IoU
295 | if self.pred_iou_thresh > 0.0:
296 | keep_mask = data["iou_preds"] > self.pred_iou_thresh
297 | data.filter(keep_mask)
298 |
299 | # Calculate stability score
300 | data["stability_score"] = calculate_stability_score(
301 | data["masks"], self.predictor.model.mask_threshold, self.stability_score_offset
302 | )
303 | if self.stability_score_thresh > 0.0:
304 | keep_mask = data["stability_score"] >= self.stability_score_thresh
305 | data.filter(keep_mask)
306 |
307 | # Threshold masks and calculate boxes
308 | data["masks"] = data["masks"] > self.predictor.model.mask_threshold
309 | data["boxes"] = batched_mask_to_box(data["masks"])
310 |
311 | # Filter boxes that touch crop boundaries
312 | keep_mask = ~is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h])
313 | if not torch.all(keep_mask):
314 | data.filter(keep_mask)
315 |
316 | # Compress to RLE
317 | data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
318 | data["rles"] = mask_to_rle_pytorch(data["masks"])
319 | del data["masks"]
320 |
321 | return data
322 |
323 | @staticmethod
324 | def postprocess_small_regions(
325 | mask_data: MaskData, min_area: int, nms_thresh: float
326 | ) -> MaskData:
327 | """
328 | Removes small disconnected regions and holes in masks, then reruns
329 | box NMS to remove any new duplicates.
330 |
331 | Edits mask_data in place.
332 |
333 | Requires open-cv as a dependency.
334 | """
335 | if len(mask_data["rles"]) == 0:
336 | return mask_data
337 |
338 | # Filter small disconnected regions and holes
339 | new_masks = []
340 | scores = []
341 | for rle in mask_data["rles"]:
342 | mask = rle_to_mask(rle)
343 |
344 | mask, changed = remove_small_regions(mask, min_area, mode="holes")
345 | unchanged = not changed
346 | mask, changed = remove_small_regions(mask, min_area, mode="islands")
347 | unchanged = unchanged and not changed
348 |
349 | new_masks.append(torch.as_tensor(mask).unsqueeze(0))
350 | # Give score=0 to changed masks and score=1 to unchanged masks
351 | # so NMS will prefer ones that didn't need postprocessing
352 | scores.append(float(unchanged))
353 |
354 | # Recalculate boxes and remove any new duplicates
355 | masks = torch.cat(new_masks, dim=0)
356 | boxes = batched_mask_to_box(masks)
357 | keep_by_nms = batched_nms(
358 | boxes.float(),
359 | torch.as_tensor(scores),
360 | torch.zeros(len(boxes)), # categories
361 | iou_threshold=nms_thresh,
362 | )
363 |
364 | # Only recalculate RLEs for masks that have changed
365 | for i_mask in keep_by_nms:
366 | if scores[i_mask] == 0.0:
367 | mask_torch = masks[i_mask].unsqueeze(0)
368 | mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
369 | mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly
370 | mask_data.filter(keep_by_nms)
371 |
372 | return mask_data
373 |
--------------------------------------------------------------------------------
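
A minimal usage sketch for the generator above, not taken from the repo: the checkpoint path and the image file are placeholder assumptions, and `opencv-python` is used here only to load an HWC uint8 RGB image. The keys printed at the end are the ones documented in `generate`.

```python
import cv2  # assumed available; only used to load an HWC uint8 RGB image
from extend_sam.segment_anything_ori import sam_model_registry, SamAutomaticMaskGenerator

sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")  # placeholder path
mask_generator = SamAutomaticMaskGenerator(
    sam,
    points_per_side=32,           # 32*32 point prompts laid over the image
    pred_iou_thresh=0.88,         # drop masks the model itself scores as low quality
    stability_score_thresh=0.95,  # drop masks that flip under small threshold shifts
    min_mask_region_area=100,     # needs opencv; removes tiny islands and holes
)

image = cv2.cvtColor(cv2.imread("example.jpg"), cv2.COLOR_BGR2RGB)  # placeholder image
masks = mask_generator.generate(image)
# Each record carries the documented keys: 'segmentation', 'bbox' (XYWH), 'area',
# 'predicted_iou', 'point_coords', 'stability_score', 'crop_box'.
print(len(masks), masks[0]["bbox"], masks[0]["predicted_iou"])
```
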
/extend_sam/segment_anything_ori/build_sam.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # modified by ziqi-jin
8 |
9 | import torch
10 |
11 | from functools import partial
12 |
13 | from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer
14 |
15 |
16 | def build_sam_vit_h(checkpoint=None):
17 | return _build_sam(
18 | encoder_embed_dim=1280,
19 | encoder_depth=32,
20 | encoder_num_heads=16,
21 | encoder_global_attn_indexes=[7, 15, 23, 31],
22 | checkpoint=checkpoint,
23 | )
24 |
25 |
26 | build_sam = build_sam_vit_h
27 |
28 |
29 | def build_sam_vit_l(checkpoint=None):
30 | return _build_sam(
31 | encoder_embed_dim=1024,
32 | encoder_depth=24,
33 | encoder_num_heads=16,
34 | encoder_global_attn_indexes=[5, 11, 17, 23],
35 | checkpoint=checkpoint,
36 | )
37 |
38 |
39 | def build_sam_vit_b(checkpoint=None):
40 | return _build_sam(
41 | encoder_embed_dim=768,
42 | encoder_depth=12,
43 | encoder_num_heads=12,
44 | encoder_global_attn_indexes=[2, 5, 8, 11],
45 | checkpoint=checkpoint,
46 | )
47 |
48 |
49 | sam_model_registry = {
50 | "default": build_sam_vit_h,
51 | "vit_h": build_sam_vit_h,
52 | "vit_l": build_sam_vit_l,
53 | "vit_b": build_sam_vit_b,
54 | }
55 |
56 |
57 | def _build_sam(
58 | encoder_embed_dim,
59 | encoder_depth,
60 | encoder_num_heads,
61 | encoder_global_attn_indexes,
62 | checkpoint=None,
63 | ):
64 | prompt_embed_dim = 256
65 | image_size = 1024
66 | vit_patch_size = 16
67 | image_embedding_size = image_size // vit_patch_size
68 | sam = Sam(
69 | image_encoder=ImageEncoderViT(
70 | depth=encoder_depth,
71 | embed_dim=encoder_embed_dim,
72 | img_size=image_size,
73 | mlp_ratio=4,
74 | norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
75 | num_heads=encoder_num_heads,
76 | patch_size=vit_patch_size,
77 | qkv_bias=True,
78 | use_rel_pos=True,
79 | global_attn_indexes=encoder_global_attn_indexes,
80 | window_size=14,
81 | out_chans=prompt_embed_dim,
82 | ),
83 | prompt_encoder=PromptEncoder(
84 | embed_dim=prompt_embed_dim,
85 | image_embedding_size=(image_embedding_size, image_embedding_size),
86 | input_image_size=(image_size, image_size),
87 | mask_in_chans=16,
88 | ),
89 | mask_decoder=MaskDecoder(
90 | num_multimask_outputs=3,
91 | transformer=TwoWayTransformer(
92 | depth=2,
93 | embedding_dim=prompt_embed_dim,
94 | mlp_dim=2048,
95 | num_heads=8,
96 | ),
97 | transformer_dim=prompt_embed_dim,
98 | iou_head_depth=3,
99 | iou_head_hidden_dim=256,
100 | ),
101 | pixel_mean=[123.675, 116.28, 103.53],
102 | pixel_std=[58.395, 57.12, 57.375],
103 | )
104 | if checkpoint is not None:
105 | with open(checkpoint, "rb") as f:
106 | state_dict = torch.load(f)
107 | sam.load_state_dict(state_dict)
108 | return sam
109 |
--------------------------------------------------------------------------------
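
The three builders above differ only in the ViT hyperparameters; the prompt encoder, mask decoder, and 1024-pixel input size are shared. A short sketch of the registry follows (not part of the repo); the commented checkpoint line assumes the official SAM ViT-B weights are stored locally under that filename.

```python
import torch
from extend_sam.segment_anything_ori import sam_model_registry

# "vit_b" / "vit_l" / "vit_h" pick the encoder width and depth configured above;
# "default" is an alias for the ViT-H builder.
sam = sam_model_registry["vit_b"](checkpoint=None)  # random init, no weights loaded

# With pretrained weights (filename assumed to be the official ViT-B checkpoint):
# sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")

sam = sam.to("cuda" if torch.cuda.is_available() else "cpu").eval()
print(sum(p.numel() for p in sam.parameters()) / 1e6, "M parameters")  # roughly 94M for ViT-B
```
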
/extend_sam/segment_anything_ori/modeling/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from .sam import Sam
8 | from .image_encoder import ImageEncoderViT
9 | from .mask_decoder import MaskDecoder
10 | from .prompt_encoder import PromptEncoder
11 | from .transformer import TwoWayTransformer
12 |
--------------------------------------------------------------------------------
/extend_sam/segment_anything_ori/modeling/common.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | import torch.nn as nn
9 |
10 | from typing import Type
11 |
12 |
13 | class MLPBlock(nn.Module):
14 | def __init__(
15 | self,
16 | embedding_dim: int,
17 | mlp_dim: int,
18 | act: Type[nn.Module] = nn.GELU,
19 | ) -> None:
20 | super().__init__()
21 | self.lin1 = nn.Linear(embedding_dim, mlp_dim)
22 | self.lin2 = nn.Linear(mlp_dim, embedding_dim)
23 | self.act = act()
24 |
25 | def forward(self, x: torch.Tensor) -> torch.Tensor:
26 | return self.lin2(self.act(self.lin1(x)))
27 |
28 |
29 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
30 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
31 | class LayerNorm2d(nn.Module):
32 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
33 | super().__init__()
34 | self.weight = nn.Parameter(torch.ones(num_channels))
35 | self.bias = nn.Parameter(torch.zeros(num_channels))
36 | self.eps = eps
37 |
38 | def forward(self, x: torch.Tensor) -> torch.Tensor:
39 | u = x.mean(1, keepdim=True)
40 | s = (x - u).pow(2).mean(1, keepdim=True)
41 | x = (x - u) / torch.sqrt(s + self.eps)
42 | x = self.weight[:, None, None] * x + self.bias[:, None, None]
43 | return x
44 |
--------------------------------------------------------------------------------
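
`LayerNorm2d` above normalizes over the channel dimension of an NCHW tensor, whereas `nn.LayerNorm` expects channels-last input; `MLPBlock` is a plain two-layer feed-forward on channels-last tokens. A small shape-check sketch, not part of the repo:

```python
import torch
from extend_sam.segment_anything_ori.modeling.common import LayerNorm2d, MLPBlock

x = torch.randn(2, 256, 64, 64)        # NCHW, e.g. the encoder neck output
ln = LayerNorm2d(256)
y = ln(x)
# Every spatial position is normalized across its 256 channels (dim=1), then
# scaled/shifted per channel; weight starts at 1 and bias at 0, so the mean stays ~0.
print(y.shape, y.mean(dim=1).abs().max().item())

tokens = torch.randn(2, 4096, 256)     # channels-last tokens
mlp = MLPBlock(embedding_dim=256, mlp_dim=2048)
print(mlp(tokens).shape)               # torch.Size([2, 4096, 256])
```
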
/extend_sam/segment_anything_ori/modeling/image_encoder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 |
11 | from typing import Optional, Tuple, Type
12 |
13 | from .common import LayerNorm2d, MLPBlock
14 |
15 |
16 | # This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa
17 | class ImageEncoderViT(nn.Module):
18 | def __init__(
19 | self,
20 | img_size: int = 1024,
21 | patch_size: int = 16,
22 | in_chans: int = 3,
23 | embed_dim: int = 768,
24 | depth: int = 12,
25 | num_heads: int = 12,
26 | mlp_ratio: float = 4.0,
27 | out_chans: int = 256,
28 | qkv_bias: bool = True,
29 | norm_layer: Type[nn.Module] = nn.LayerNorm,
30 | act_layer: Type[nn.Module] = nn.GELU,
31 | use_abs_pos: bool = True,
32 | use_rel_pos: bool = False,
33 | rel_pos_zero_init: bool = True,
34 | window_size: int = 0,
35 | global_attn_indexes: Tuple[int, ...] = (),
36 | ) -> None:
37 | """
38 | Args:
39 | img_size (int): Input image size.
40 | patch_size (int): Patch size.
41 | in_chans (int): Number of input image channels.
42 | embed_dim (int): Patch embedding dimension.
43 | depth (int): Depth of ViT.
44 | num_heads (int): Number of attention heads in each ViT block.
45 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
46 | qkv_bias (bool): If True, add a learnable bias to query, key, value.
47 | norm_layer (nn.Module): Normalization layer.
48 | act_layer (nn.Module): Activation layer.
49 | use_abs_pos (bool): If True, use absolute positional embeddings.
50 | use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
51 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
52 | window_size (int): Window size for window attention blocks.
53 | global_attn_indexes (list): Indexes for blocks using global attention.
54 | """
55 | super().__init__()
56 | self.img_size = img_size
57 |
58 | self.patch_embed = PatchEmbed(
59 | kernel_size=(patch_size, patch_size),
60 | stride=(patch_size, patch_size),
61 | in_chans=in_chans,
62 | embed_dim=embed_dim,
63 | )
64 |
65 | self.pos_embed: Optional[nn.Parameter] = None
66 | if use_abs_pos:
67 | # Initialize absolute positional embedding with pretrain image size.
68 | self.pos_embed = nn.Parameter(
69 | torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim)
70 | )
71 |
72 | self.blocks = nn.ModuleList()
73 | for i in range(depth):
74 | block = Block(
75 | dim=embed_dim,
76 | num_heads=num_heads,
77 | mlp_ratio=mlp_ratio,
78 | qkv_bias=qkv_bias,
79 | norm_layer=norm_layer,
80 | act_layer=act_layer,
81 | use_rel_pos=use_rel_pos,
82 | rel_pos_zero_init=rel_pos_zero_init,
83 | window_size=window_size if i not in global_attn_indexes else 0,
84 | input_size=(img_size // patch_size, img_size // patch_size),
85 | )
86 | self.blocks.append(block)
87 |
88 | self.neck = nn.Sequential(
89 | nn.Conv2d(
90 | embed_dim,
91 | out_chans,
92 | kernel_size=1,
93 | bias=False,
94 | ),
95 | LayerNorm2d(out_chans),
96 | nn.Conv2d(
97 | out_chans,
98 | out_chans,
99 | kernel_size=3,
100 | padding=1,
101 | bias=False,
102 | ),
103 | LayerNorm2d(out_chans),
104 | )
105 |
106 | def forward(self, x: torch.Tensor) -> torch.Tensor:
107 | x = self.patch_embed(x)
108 | if self.pos_embed is not None:
109 | x = x + self.pos_embed
110 |
111 | for blk in self.blocks:
112 | x = blk(x)
113 |
114 | x = self.neck(x.permute(0, 3, 1, 2))
115 |
116 | return x
117 |
118 |
119 | class Block(nn.Module):
120 | """Transformer blocks with support of window attention and residual propagation blocks"""
121 |
122 | def __init__(
123 | self,
124 | dim: int,
125 | num_heads: int,
126 | mlp_ratio: float = 4.0,
127 | qkv_bias: bool = True,
128 | norm_layer: Type[nn.Module] = nn.LayerNorm,
129 | act_layer: Type[nn.Module] = nn.GELU,
130 | use_rel_pos: bool = False,
131 | rel_pos_zero_init: bool = True,
132 | window_size: int = 0,
133 | input_size: Optional[Tuple[int, int]] = None,
134 | ) -> None:
135 | """
136 | Args:
137 | dim (int): Number of input channels.
138 | num_heads (int): Number of attention heads in each ViT block.
139 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
140 | qkv_bias (bool): If True, add a learnable bias to query, key, value.
141 | norm_layer (nn.Module): Normalization layer.
142 | act_layer (nn.Module): Activation layer.
143 | use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
144 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
145 | window_size (int): Window size for window attention blocks. If it equals 0, then
146 | use global attention.
147 |             input_size (tuple(int, int) or None): Input resolution for calculating the relative positional
148 | parameter size.
149 | """
150 | super().__init__()
151 | self.norm1 = norm_layer(dim)
152 | self.attn = Attention(
153 | dim,
154 | num_heads=num_heads,
155 | qkv_bias=qkv_bias,
156 | use_rel_pos=use_rel_pos,
157 | rel_pos_zero_init=rel_pos_zero_init,
158 | input_size=input_size if window_size == 0 else (window_size, window_size),
159 | )
160 |
161 | self.norm2 = norm_layer(dim)
162 | self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)
163 |
164 | self.window_size = window_size
165 |
166 | def forward(self, x: torch.Tensor) -> torch.Tensor:
167 | shortcut = x
168 | x = self.norm1(x)
169 | # Window partition
170 | if self.window_size > 0:
171 | H, W = x.shape[1], x.shape[2]
172 | x, pad_hw = window_partition(x, self.window_size)
173 |
174 | x = self.attn(x)
175 | # Reverse window partition
176 | if self.window_size > 0:
177 | x = window_unpartition(x, self.window_size, pad_hw, (H, W))
178 |
179 | x = shortcut + x
180 | x = x + self.mlp(self.norm2(x))
181 |
182 | return x
183 |
184 |
185 | class Attention(nn.Module):
186 | """Multi-head Attention block with relative position embeddings."""
187 |
188 | def __init__(
189 | self,
190 | dim: int,
191 | num_heads: int = 8,
192 | qkv_bias: bool = True,
193 | use_rel_pos: bool = False,
194 | rel_pos_zero_init: bool = True,
195 | input_size: Optional[Tuple[int, int]] = None,
196 | ) -> None:
197 | """
198 | Args:
199 | dim (int): Number of input channels.
200 | num_heads (int): Number of attention heads.
201 |             qkv_bias (bool): If True, add a learnable bias to query, key, value.
202 |             use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
203 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
204 |             input_size (tuple(int, int) or None): Input resolution for calculating the relative positional
205 | parameter size.
206 | """
207 | super().__init__()
208 | self.num_heads = num_heads
209 | head_dim = dim // num_heads
210 | self.scale = head_dim**-0.5
211 |
212 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
213 | self.proj = nn.Linear(dim, dim)
214 |
215 | self.use_rel_pos = use_rel_pos
216 | if self.use_rel_pos:
217 | assert (
218 | input_size is not None
219 | ), "Input size must be provided if using relative positional encoding."
220 | # initialize relative positional embeddings
221 | self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
222 | self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
223 |
224 | def forward(self, x: torch.Tensor) -> torch.Tensor:
225 | B, H, W, _ = x.shape
226 | # qkv with shape (3, B, nHead, H * W, C)
227 | qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
228 | # q, k, v with shape (B * nHead, H * W, C)
229 | q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
230 |
231 | attn = (q * self.scale) @ k.transpose(-2, -1)
232 |
233 | if self.use_rel_pos:
234 | attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
235 |
236 | attn = attn.softmax(dim=-1)
237 | x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
238 | x = self.proj(x)
239 |
240 | return x
241 |
242 |
243 | def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
244 | """
245 | Partition into non-overlapping windows with padding if needed.
246 | Args:
247 | x (tensor): input tokens with [B, H, W, C].
248 | window_size (int): window size.
249 |
250 | Returns:
251 | windows: windows after partition with [B * num_windows, window_size, window_size, C].
252 | (Hp, Wp): padded height and width before partition
253 | """
254 | B, H, W, C = x.shape
255 |
256 | pad_h = (window_size - H % window_size) % window_size
257 | pad_w = (window_size - W % window_size) % window_size
258 | if pad_h > 0 or pad_w > 0:
259 | x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
260 | Hp, Wp = H + pad_h, W + pad_w
261 |
262 | x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
263 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
264 | return windows, (Hp, Wp)
265 |
266 |
267 | def window_unpartition(
268 | windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
269 | ) -> torch.Tensor:
270 | """
271 |     Window unpartition into original sequences and remove padding.
272 |     Args:
273 |         windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
274 | window_size (int): window size.
275 | pad_hw (Tuple): padded height and width (Hp, Wp).
276 | hw (Tuple): original height and width (H, W) before padding.
277 |
278 | Returns:
279 | x: unpartitioned sequences with [B, H, W, C].
280 | """
281 | Hp, Wp = pad_hw
282 | H, W = hw
283 | B = windows.shape[0] // (Hp * Wp // window_size // window_size)
284 | x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
285 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
286 |
287 | if Hp > H or Wp > W:
288 | x = x[:, :H, :W, :].contiguous()
289 | return x
290 |
291 |
292 | def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
293 | """
294 | Get relative positional embeddings according to the relative positions of
295 | query and key sizes.
296 | Args:
297 | q_size (int): size of query q.
298 | k_size (int): size of key k.
299 | rel_pos (Tensor): relative position embeddings (L, C).
300 |
301 | Returns:
302 | Extracted positional embeddings according to relative positions.
303 | """
304 | max_rel_dist = int(2 * max(q_size, k_size) - 1)
305 | # Interpolate rel pos if needed.
306 | if rel_pos.shape[0] != max_rel_dist:
307 | # Interpolate rel pos.
308 | rel_pos_resized = F.interpolate(
309 | rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
310 | size=max_rel_dist,
311 | mode="linear",
312 | )
313 | rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
314 | else:
315 | rel_pos_resized = rel_pos
316 |
317 | # Scale the coords with short length if shapes for q and k are different.
318 | q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
319 | k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
320 | relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
321 |
322 | return rel_pos_resized[relative_coords.long()]
323 |
324 |
325 | def add_decomposed_rel_pos(
326 | attn: torch.Tensor,
327 | q: torch.Tensor,
328 | rel_pos_h: torch.Tensor,
329 | rel_pos_w: torch.Tensor,
330 | q_size: Tuple[int, int],
331 | k_size: Tuple[int, int],
332 | ) -> torch.Tensor:
333 | """
334 | Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
335 | https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
336 | Args:
337 | attn (Tensor): attention map.
338 | q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
339 | rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
340 | rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
341 | q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
342 | k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
343 |
344 | Returns:
345 | attn (Tensor): attention map with added relative positional embeddings.
346 | """
347 | q_h, q_w = q_size
348 | k_h, k_w = k_size
349 | Rh = get_rel_pos(q_h, k_h, rel_pos_h)
350 | Rw = get_rel_pos(q_w, k_w, rel_pos_w)
351 |
352 | B, _, dim = q.shape
353 | r_q = q.reshape(B, q_h, q_w, dim)
354 | rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
355 | rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
356 |
357 | attn = (
358 | attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
359 | ).view(B, q_h * q_w, k_h * k_w)
360 |
361 | return attn
362 |
363 |
364 | class PatchEmbed(nn.Module):
365 | """
366 | Image to Patch Embedding.
367 | """
368 |
369 | def __init__(
370 | self,
371 | kernel_size: Tuple[int, int] = (16, 16),
372 | stride: Tuple[int, int] = (16, 16),
373 | padding: Tuple[int, int] = (0, 0),
374 | in_chans: int = 3,
375 | embed_dim: int = 768,
376 | ) -> None:
377 | """
378 | Args:
379 | kernel_size (Tuple): kernel size of the projection layer.
380 | stride (Tuple): stride of the projection layer.
381 | padding (Tuple): padding size of the projection layer.
382 | in_chans (int): Number of input image channels.
383 |             embed_dim (int): Patch embedding dimension.
384 | """
385 | super().__init__()
386 |
387 | self.proj = nn.Conv2d(
388 | in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
389 | )
390 |
391 | def forward(self, x: torch.Tensor) -> torch.Tensor:
392 | x = self.proj(x)
393 | # B C H W -> B H W C
394 | x = x.permute(0, 2, 3, 1)
395 | return x
396 |
--------------------------------------------------------------------------------
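
The encoder keeps a token grid at 1/`patch_size` resolution throughout, and the neck projects it to `out_chans` channels in NCHW, so a 1024x1024 input yields a 256x64x64 embedding. The sketch below uses deliberately tiny, made-up hyperparameters so the forward pass is cheap on CPU; the real configurations are the ones set in `build_sam.py`.

```python
import torch
from extend_sam.segment_anything_ori.modeling import ImageEncoderViT

# Toy hyperparameters (much smaller than the ViT-B/L/H configs in build_sam.py);
# the shape contract is the same: H/16 x W/16 tokens, out_chans output channels.
encoder = ImageEncoderViT(
    img_size=256,
    patch_size=16,
    embed_dim=64,
    depth=2,
    num_heads=2,
    window_size=8,
    global_attn_indexes=(1,),
    out_chans=256,
)

x = torch.randn(1, 3, 256, 256)   # NCHW image, already resized/padded
with torch.no_grad():
    feat = encoder(x)
print(feat.shape)  # torch.Size([1, 256, 16, 16]): img_size/patch_size per side
```
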
/extend_sam/segment_anything_ori/modeling/mask_decoder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | from torch import nn
9 | from torch.nn import functional as F
10 |
11 | from typing import List, Tuple, Type
12 |
13 | from .common import LayerNorm2d
14 |
15 |
16 | class MaskDecoder(nn.Module):
17 | def __init__(
18 | self,
19 | *,
20 | transformer_dim: int,
21 | transformer: nn.Module,
22 | num_multimask_outputs: int = 3,
23 | activation: Type[nn.Module] = nn.GELU,
24 | iou_head_depth: int = 3,
25 | iou_head_hidden_dim: int = 256,
26 | ) -> None:
27 | """
28 | Predicts masks given an image and prompt embeddings, using a
29 |         transformer architecture.
30 |
31 | Arguments:
32 | transformer_dim (int): the channel dimension of the transformer
33 | transformer (nn.Module): the transformer used to predict masks
34 | num_multimask_outputs (int): the number of masks to predict
35 | when disambiguating masks
36 | activation (nn.Module): the type of activation to use when
37 | upscaling masks
38 | iou_head_depth (int): the depth of the MLP used to predict
39 | mask quality
40 | iou_head_hidden_dim (int): the hidden dimension of the MLP
41 | used to predict mask quality
42 | """
43 | super().__init__()
44 | self.transformer_dim = transformer_dim
45 | self.transformer = transformer
46 |
47 | self.num_multimask_outputs = num_multimask_outputs
48 |
49 | self.iou_token = nn.Embedding(1, transformer_dim)
50 | self.num_mask_tokens = num_multimask_outputs + 1
51 | self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
52 | self.iou_head_depth = iou_head_depth
53 | self.iou_head_hidden_dim = iou_head_hidden_dim
54 | self.output_upscaling = nn.Sequential(
55 | nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
56 | LayerNorm2d(transformer_dim // 4),
57 | activation(),
58 | nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
59 | activation(),
60 | )
61 | self.output_hypernetworks_mlps = nn.ModuleList(
62 | [
63 | MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
64 | for i in range(self.num_mask_tokens)
65 | ]
66 | )
67 |
68 | self.iou_prediction_head = MLP(
69 | transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth
70 | )
71 |
72 | def forward(
73 | self,
74 | image_embeddings: torch.Tensor,
75 | image_pe: torch.Tensor,
76 | sparse_prompt_embeddings: torch.Tensor,
77 | dense_prompt_embeddings: torch.Tensor,
78 | multimask_output: bool,
79 | ) -> Tuple[torch.Tensor, torch.Tensor]:
80 | """
81 | Predict masks given image and prompt embeddings.
82 |
83 | Arguments:
84 | image_embeddings (torch.Tensor): the embeddings from the image encoder
85 | image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
86 | sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
87 | dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
88 | multimask_output (bool): Whether to return multiple masks or a single
89 | mask.
90 |
91 | Returns:
92 | torch.Tensor: batched predicted masks
93 | torch.Tensor: batched predictions of mask quality
94 | """
95 | masks, iou_pred = self.predict_masks(
96 | image_embeddings=image_embeddings,
97 | image_pe=image_pe,
98 | sparse_prompt_embeddings=sparse_prompt_embeddings,
99 | dense_prompt_embeddings=dense_prompt_embeddings,
100 | )
101 |
102 |         # Select the correct mask or masks for output
103 | if multimask_output:
104 | mask_slice = slice(1, None)
105 | else:
106 | mask_slice = slice(0, 1)
107 | masks = masks[:, mask_slice, :, :]
108 | iou_pred = iou_pred[:, mask_slice]
109 |
110 | # Prepare output
111 | return masks, iou_pred
112 |
113 | def predict_masks(
114 | self,
115 | image_embeddings: torch.Tensor,
116 | image_pe: torch.Tensor,
117 | sparse_prompt_embeddings: torch.Tensor,
118 | dense_prompt_embeddings: torch.Tensor,
119 | ) -> Tuple[torch.Tensor, torch.Tensor]:
120 | """Predicts masks. See 'forward' for more details."""
121 | # Concatenate output tokens
122 | output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
123 | output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
124 | tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
125 |
126 | # Expand per-image data in batch direction to be per-mask
127 | src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
128 | src = src + dense_prompt_embeddings
129 | pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
130 | b, c, h, w = src.shape
131 |
132 | # Run the transformer
133 | hs, src = self.transformer(src, pos_src, tokens)
134 | iou_token_out = hs[:, 0, :]
135 | mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]
136 |
137 | # Upscale mask embeddings and predict masks using the mask tokens
138 | src = src.transpose(1, 2).view(b, c, h, w)
139 | upscaled_embedding = self.output_upscaling(src)
140 | hyper_in_list: List[torch.Tensor] = []
141 | for i in range(self.num_mask_tokens):
142 | hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
143 | hyper_in = torch.stack(hyper_in_list, dim=1)
144 | b, c, h, w = upscaled_embedding.shape
145 | masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
146 |
147 | # Generate mask quality predictions
148 | iou_pred = self.iou_prediction_head(iou_token_out)
149 |
150 | return masks, iou_pred
151 |
152 |
153 | # Lightly adapted from
154 | # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
155 | class MLP(nn.Module):
156 | def __init__(
157 | self,
158 | input_dim: int,
159 | hidden_dim: int,
160 | output_dim: int,
161 | num_layers: int,
162 | sigmoid_output: bool = False,
163 | ) -> None:
164 | super().__init__()
165 | self.num_layers = num_layers
166 | h = [hidden_dim] * (num_layers - 1)
167 | self.layers = nn.ModuleList(
168 | nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
169 | )
170 | self.sigmoid_output = sigmoid_output
171 |
172 | def forward(self, x):
173 | for i, layer in enumerate(self.layers):
174 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
175 | if self.sigmoid_output:
176 | x = F.sigmoid(x)
177 | return x
178 |
--------------------------------------------------------------------------------
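
A dummy-input sketch (not from the repo) of the decoder above, built with the same `TwoWayTransformer` settings as `_build_sam`: four mask tokens are predicted internally, and `multimask_output=True` keeps the three disambiguation masks.

```python
import torch
from extend_sam.segment_anything_ori.modeling import MaskDecoder, TwoWayTransformer

decoder = MaskDecoder(
    transformer_dim=256,
    transformer=TwoWayTransformer(depth=2, embedding_dim=256, mlp_dim=2048, num_heads=8),
    num_multimask_outputs=3,
)

# Dummy inputs with the shapes SAM produces: a 64x64x256 image embedding,
# its dense positional encoding, and two prompt sets of three sparse tokens each.
image_embeddings = torch.randn(1, 256, 64, 64)
image_pe = torch.randn(1, 256, 64, 64)
sparse_prompts = torch.randn(2, 3, 256)
dense_prompts = torch.randn(2, 256, 64, 64)

with torch.no_grad():
    masks, iou_pred = decoder(
        image_embeddings=image_embeddings,
        image_pe=image_pe,
        sparse_prompt_embeddings=sparse_prompts,
        dense_prompt_embeddings=dense_prompts,
        multimask_output=True,
    )
print(masks.shape)     # torch.Size([2, 3, 256, 256]): 4 mask tokens minus the single-mask slot
print(iou_pred.shape)  # torch.Size([2, 3])
```
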
/extend_sam/segment_anything_ori/modeling/prompt_encoder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import numpy as np
8 | import torch
9 | from torch import nn
10 |
11 | from typing import Any, Optional, Tuple, Type
12 |
13 | from .common import LayerNorm2d
14 |
15 |
16 | class PromptEncoder(nn.Module):
17 | def __init__(
18 | self,
19 | embed_dim: int,
20 | image_embedding_size: Tuple[int, int],
21 | input_image_size: Tuple[int, int],
22 | mask_in_chans: int,
23 | activation: Type[nn.Module] = nn.GELU,
24 | ) -> None:
25 | """
26 | Encodes prompts for input to SAM's mask decoder.
27 |
28 | Arguments:
29 | embed_dim (int): The prompts' embedding dimension
30 | image_embedding_size (tuple(int, int)): The spatial size of the
31 | image embedding, as (H, W).
32 |             input_image_size (tuple(int, int)): The padded size of the image as input
33 | to the image encoder, as (H, W).
34 | mask_in_chans (int): The number of hidden channels used for
35 | encoding input masks.
36 | activation (nn.Module): The activation to use when encoding
37 | input masks.
38 | """
39 | super().__init__()
40 | self.embed_dim = embed_dim
41 | self.input_image_size = input_image_size
42 | self.image_embedding_size = image_embedding_size
43 | self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
44 |
45 | self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners
46 | point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)]
47 | self.point_embeddings = nn.ModuleList(point_embeddings)
48 | self.not_a_point_embed = nn.Embedding(1, embed_dim)
49 |
50 | self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1])
51 | self.mask_downscaling = nn.Sequential(
52 | nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
53 | LayerNorm2d(mask_in_chans // 4),
54 | activation(),
55 | nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
56 | LayerNorm2d(mask_in_chans),
57 | activation(),
58 | nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
59 | )
60 | self.no_mask_embed = nn.Embedding(1, embed_dim)
61 |
62 | def get_dense_pe(self) -> torch.Tensor:
63 | """
64 | Returns the positional encoding used to encode point prompts,
65 |         applied to a dense set of points matching the shape of the image encoding.
66 |
67 | Returns:
68 | torch.Tensor: Positional encoding with shape
69 | 1x(embed_dim)x(embedding_h)x(embedding_w)
70 | """
71 | return self.pe_layer(self.image_embedding_size).unsqueeze(0)
72 |
73 | def _embed_points(
74 | self,
75 | points: torch.Tensor,
76 | labels: torch.Tensor,
77 | pad: bool,
78 | ) -> torch.Tensor:
79 | """Embeds point prompts."""
80 | points = points + 0.5 # Shift to center of pixel
81 | if pad:
82 | padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
83 | padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
84 | points = torch.cat([points, padding_point], dim=1)
85 | labels = torch.cat([labels, padding_label], dim=1)
86 | point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)
87 | point_embedding[labels == -1] = 0.0
88 | point_embedding[labels == -1] += self.not_a_point_embed.weight
89 | point_embedding[labels == 0] += self.point_embeddings[0].weight
90 | point_embedding[labels == 1] += self.point_embeddings[1].weight
91 | return point_embedding
92 |
93 | def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
94 | """Embeds box prompts."""
95 | boxes = boxes + 0.5 # Shift to center of pixel
96 | coords = boxes.reshape(-1, 2, 2)
97 | corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)
98 | corner_embedding[:, 0, :] += self.point_embeddings[2].weight
99 | corner_embedding[:, 1, :] += self.point_embeddings[3].weight
100 | return corner_embedding
101 |
102 | def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
103 | """Embeds mask inputs."""
104 | mask_embedding = self.mask_downscaling(masks)
105 | return mask_embedding
106 |
107 | def _get_batch_size(
108 | self,
109 | points: Optional[Tuple[torch.Tensor, torch.Tensor]],
110 | boxes: Optional[torch.Tensor],
111 | masks: Optional[torch.Tensor],
112 | ) -> int:
113 | """
114 | Gets the batch size of the output given the batch size of the input prompts.
115 | """
116 | if points is not None:
117 | return points[0].shape[0]
118 | elif boxes is not None:
119 | return boxes.shape[0]
120 | elif masks is not None:
121 | return masks.shape[0]
122 | else:
123 | return 1
124 |
125 | def _get_device(self) -> torch.device:
126 | return self.point_embeddings[0].weight.device
127 |
128 | def forward(
129 | self,
130 | points: Optional[Tuple[torch.Tensor, torch.Tensor]],
131 | boxes: Optional[torch.Tensor],
132 | masks: Optional[torch.Tensor],
133 | ) -> Tuple[torch.Tensor, torch.Tensor]:
134 | """
135 | Embeds different types of prompts, returning both sparse and dense
136 | embeddings.
137 |
138 | Arguments:
139 | points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates
140 | and labels to embed.
141 | boxes (torch.Tensor or none): boxes to embed
142 | masks (torch.Tensor or none): masks to embed
143 |
144 | Returns:
145 | torch.Tensor: sparse embeddings for the points and boxes, with shape
146 | BxNx(embed_dim), where N is determined by the number of input points
147 | and boxes.
148 | torch.Tensor: dense embeddings for the masks, in the shape
149 | Bx(embed_dim)x(embed_H)x(embed_W)
150 | """
151 | bs = self._get_batch_size(points, boxes, masks)
152 | sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())
153 | if points is not None:
154 | coords, labels = points
155 | point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
156 | sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
157 | if boxes is not None:
158 | box_embeddings = self._embed_boxes(boxes)
159 | sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
160 |
161 | if masks is not None:
162 | dense_embeddings = self._embed_masks(masks)
163 | else:
164 | dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
165 | bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
166 | )
167 |
168 | return sparse_embeddings, dense_embeddings
169 |
170 |
171 | class PositionEmbeddingRandom(nn.Module):
172 | """
173 | Positional encoding using random spatial frequencies.
174 | """
175 |
176 | def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
177 | super().__init__()
178 | if scale is None or scale <= 0.0:
179 | scale = 1.0
180 | self.register_buffer(
181 | "positional_encoding_gaussian_matrix",
182 | scale * torch.randn((2, num_pos_feats)),
183 | )
184 |
185 | def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
186 | """Positionally encode points that are normalized to [0,1]."""
187 | # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
188 | coords = 2 * coords - 1
189 | coords = coords @ self.positional_encoding_gaussian_matrix
190 | coords = 2 * np.pi * coords
191 | # outputs d_1 x ... x d_n x C shape
192 | return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
193 |
194 | def forward(self, size: Tuple[int, int]) -> torch.Tensor:
195 | """Generate positional encoding for a grid of the specified size."""
196 | h, w = size
197 | device: Any = self.positional_encoding_gaussian_matrix.device
198 | grid = torch.ones((h, w), device=device, dtype=torch.float32)
199 | y_embed = grid.cumsum(dim=0) - 0.5
200 | x_embed = grid.cumsum(dim=1) - 0.5
201 | y_embed = y_embed / h
202 | x_embed = x_embed / w
203 |
204 | pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
205 | return pe.permute(2, 0, 1) # C x H x W
206 |
207 | def forward_with_coords(
208 | self, coords_input: torch.Tensor, image_size: Tuple[int, int]
209 | ) -> torch.Tensor:
210 | """Positionally encode points that are not normalized to [0,1]."""
211 | coords = coords_input.clone()
212 | coords[:, :, 0] = coords[:, :, 0] / image_size[1]
213 | coords[:, :, 1] = coords[:, :, 1] / image_size[0]
214 | return self._pe_encoding(coords.to(torch.float)) # B x N x C
215 |
--------------------------------------------------------------------------------
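
A sketch of the prompt encoder above with the `_build_sam` configuration and a single foreground point; the coordinates are arbitrary assumptions in input-image pixels. Because no box is supplied, a "not a point" pad entry is appended to the sparse embeddings, and with no mask input the dense embedding falls back to the learned no-mask token.

```python
import torch
from extend_sam.segment_anything_ori.modeling import PromptEncoder

# The configuration _build_sam uses: 256-d embeddings and a 64x64 embedding
# grid for a 1024x1024 padded input image.
encoder = PromptEncoder(
    embed_dim=256,
    image_embedding_size=(64, 64),
    input_image_size=(1024, 1024),
    mask_in_chans=16,
)

coords = torch.tensor([[[512.0, 384.0]]])  # B=1, N=1 point in input-image pixels (assumed)
labels = torch.ones((1, 1))                # 1 = foreground point
sparse, dense = encoder(points=(coords, labels), boxes=None, masks=None)
print(sparse.shape)  # torch.Size([1, 2, 256]): the point plus a "not a point" pad (no box given)
print(dense.shape)   # torch.Size([1, 256, 64, 64]): the learned no-mask embedding, broadcast
print(encoder.get_dense_pe().shape)  # torch.Size([1, 256, 64, 64])
```
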
/extend_sam/segment_anything_ori/modeling/sam.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # modified by ziqi-jin
8 |
9 | import torch
10 | from torch import nn
11 | from torch.nn import functional as F
12 |
13 | from typing import Any, Dict, List, Tuple
14 |
15 | from .image_encoder import ImageEncoderViT
16 | from .mask_decoder import MaskDecoder
17 | from .prompt_encoder import PromptEncoder
18 |
19 |
20 | class Sam(nn.Module):
21 | mask_threshold: float = 0.0
22 | image_format: str = "RGB"
23 |
24 | def __init__(
25 | self,
26 | image_encoder: ImageEncoderViT,
27 | prompt_encoder: PromptEncoder,
28 | mask_decoder: MaskDecoder,
29 | pixel_mean: List[float] = [123.675, 116.28, 103.53],
30 | pixel_std: List[float] = [58.395, 57.12, 57.375],
31 | ) -> None:
32 | """
33 | SAM predicts object masks from an image and input prompts.
34 |
35 | Arguments:
36 | image_encoder (ImageEncoderViT): The backbone used to encode the
37 | image into image embeddings that allow for efficient mask prediction.
38 | prompt_encoder (PromptEncoder): Encodes various types of input prompts.
39 | mask_decoder (MaskDecoder): Predicts masks from the image embeddings
40 | and encoded prompts.
41 | pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
42 | pixel_std (list(float)): Std values for normalizing pixels in the input image.
43 | """
44 | super().__init__()
45 | self.image_encoder = image_encoder
46 | self.prompt_encoder = prompt_encoder
47 | self.mask_decoder = mask_decoder
48 | self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
49 | self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
50 |
51 | @property
52 | def device(self) -> Any:
53 | return self.pixel_mean.device
54 |
55 | def forward(
56 | self,
57 | batched_input: List[Dict[str, Any]],
58 | multimask_output: bool,
59 | ) -> List[Dict[str, torch.Tensor]]:
60 | """
61 | Predicts masks end-to-end from provided images and prompts.
62 | If prompts are not known in advance, using SamPredictor is
63 | recommended over calling the model directly.
64 |
65 | Arguments:
66 | batched_input (list(dict)): A list over input images, each a
67 | dictionary with the following keys. A prompt key can be
68 | excluded if it is not present.
69 | 'image': The image as a torch tensor in 3xHxW format,
70 | already transformed for input to the model.
71 | 'original_size': (tuple(int, int)) The original size of
72 | the image before transformation, as (H, W).
73 | 'point_coords': (torch.Tensor) Batched point prompts for
74 | this image, with shape BxNx2. Already transformed to the
75 | input frame of the model.
76 | 'point_labels': (torch.Tensor) Batched labels for point prompts,
77 | with shape BxN.
78 | 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4.
79 | Already transformed to the input frame of the model.
80 | 'mask_inputs': (torch.Tensor) Batched mask inputs to the model,
81 | in the form Bx1xHxW.
82 | multimask_output (bool): Whether the model should predict multiple
83 | disambiguating masks, or return a single mask.
84 |
85 | Returns:
86 | (list(dict)): A list over input images, where each element is
87 |             a dictionary with the following keys.
88 | 'masks': (torch.Tensor) Batched binary mask predictions,
89 |               with shape BxCxHxW, where B is the number of input prompts,
90 |               C is determined by multimask_output, and (H, W) is the
91 | original size of the image.
92 | 'iou_predictions': (torch.Tensor) The model's predictions
93 | of mask quality, in shape BxC.
94 | 'low_res_logits': (torch.Tensor) Low resolution logits with
95 | shape BxCxHxW, where H=W=256. Can be passed as mask input
96 | to subsequent iterations of prediction.
97 | """
98 | input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0)
99 | image_embeddings = self.image_encoder(input_images)
100 |
101 | outputs = []
102 | for image_record, curr_embedding in zip(batched_input, image_embeddings):
103 | if "point_coords" in image_record:
104 | points = (image_record["point_coords"], image_record["point_labels"])
105 | else:
106 | points = None
107 | sparse_embeddings, dense_embeddings = self.prompt_encoder(
108 | points=points,
109 | boxes=image_record.get("boxes", None),
110 | masks=image_record.get("mask_inputs", None),
111 | )
112 | low_res_masks, iou_predictions = self.mask_decoder(
113 | image_embeddings=curr_embedding.unsqueeze(0),
114 | image_pe=self.prompt_encoder.get_dense_pe(),
115 | sparse_prompt_embeddings=sparse_embeddings,
116 | dense_prompt_embeddings=dense_embeddings,
117 | multimask_output=multimask_output,
118 | )
119 | masks = self.postprocess_masks(
120 | low_res_masks,
121 | input_size=image_record["image"].shape[-2:],
122 | original_size=image_record["original_size"],
123 | )
124 | masks = masks > self.mask_threshold
125 | outputs.append(
126 | {
127 | "masks": masks,
128 | "iou_predictions": iou_predictions,
129 | "low_res_logits": low_res_masks,
130 | }
131 | )
132 | return outputs
133 |
134 | def postprocess_masks(
135 | self,
136 | masks: torch.Tensor,
137 | input_size: Tuple[int, ...],
138 | original_size: Tuple[int, ...],
139 | ) -> torch.Tensor:
140 | """
141 | Remove padding and upscale masks to the original image size.
142 |
143 | Arguments:
144 | masks (torch.Tensor): Batched masks from the mask_decoder,
145 | in BxCxHxW format.
146 | input_size (tuple(int, int)): The size of the image input to the
147 | model, in (H, W) format. Used to remove padding.
148 | original_size (tuple(int, int)): The original size of the image
149 | before resizing for input to the model, in (H, W) format.
150 |
151 | Returns:
152 | (torch.Tensor): Batched masks in BxCxHxW format, where (H, W)
153 | is given by original_size.
154 | """
155 | masks = F.interpolate(
156 | masks,
157 | (self.image_encoder.img_size, self.image_encoder.img_size),
158 | mode="bilinear",
159 | align_corners=False,
160 | )
161 | masks = masks[..., : input_size[0], : input_size[1]]
162 | masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
163 | return masks
164 |
165 | def preprocess(self, x: torch.Tensor) -> torch.Tensor:
166 | """Normalize pixel values and pad to a square input."""
167 | # Normalize colors
168 | x = (x - self.pixel_mean) / self.pixel_std
169 |
170 | # Pad
171 | h, w = x.shape[-2:]
172 | padh = self.image_encoder.img_size - h
173 | padw = self.image_encoder.img_size - w
174 | x = F.pad(x, (0, padw, 0, padh))
175 | return x
176 |
--------------------------------------------------------------------------------
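
An end-to-end sketch of `Sam.forward`'s `batched_input` contract with a single box prompt. The sizes, box, and random image are illustrative assumptions: the image tensor is expected to already be resized to the model's input frame (here 768x1024 for an assumed 1536x2048 original), and `preprocess` handles normalization and padding. With random weights this only demonstrates shapes; as the docstring notes, `SamPredictor` is the recommended path when prompts are not known in advance.

```python
import torch
from extend_sam.segment_anything_ori import sam_model_registry

sam = sam_model_registry["vit_b"](checkpoint=None).eval()  # pass a real checkpoint in practice

# Original image assumed to be 1536x2048 (H, W); resized so the longest side is
# 1024 -> 768x1024, the frame the image tensor and box coordinates must live in.
image = torch.randint(0, 256, (3, 768, 1024), dtype=torch.float32)
batched_input = [
    {
        "image": image,                                          # 3xHxW, model input frame
        "original_size": (1536, 2048),                           # (H, W) before any resizing
        "boxes": torch.tensor([[100.0, 120.0, 600.0, 500.0]]),   # Bx4 in XYXY, input frame
    }
]

# Note: this is a full ViT-B forward pass; expect a noticeable wait and a few GB of RAM on CPU.
with torch.no_grad():
    outputs = sam(batched_input, multimask_output=False)

out = outputs[0]
print(out["masks"].shape)            # torch.Size([1, 1, 1536, 2048]), boolean, original size
print(out["iou_predictions"].shape)  # torch.Size([1, 1])
print(out["low_res_logits"].shape)   # torch.Size([1, 1, 256, 256])
```
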
/extend_sam/segment_anything_ori/modeling/transformer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | from torch import Tensor, nn
9 |
10 | import math
11 | from typing import Tuple, Type
12 |
13 | from .common import MLPBlock
14 |
15 |
16 | class TwoWayTransformer(nn.Module):
17 | def __init__(
18 | self,
19 | depth: int,
20 | embedding_dim: int,
21 | num_heads: int,
22 | mlp_dim: int,
23 | activation: Type[nn.Module] = nn.ReLU,
24 | attention_downsample_rate: int = 2,
25 | ) -> None:
26 | """
27 | A transformer decoder that attends to an input image using
28 | queries whose positional embedding is supplied.
29 |
30 | Args:
31 | depth (int): number of layers in the transformer
32 | embedding_dim (int): the channel dimension for the input embeddings
33 | num_heads (int): the number of heads for multihead attention. Must
34 | divide embedding_dim
35 | mlp_dim (int): the channel dimension internal to the MLP block
36 | activation (nn.Module): the activation to use in the MLP block
37 | """
38 | super().__init__()
39 | self.depth = depth
40 | self.embedding_dim = embedding_dim
41 | self.num_heads = num_heads
42 | self.mlp_dim = mlp_dim
43 | self.layers = nn.ModuleList()
44 |
45 | for i in range(depth):
46 | self.layers.append(
47 | TwoWayAttentionBlock(
48 | embedding_dim=embedding_dim,
49 | num_heads=num_heads,
50 | mlp_dim=mlp_dim,
51 | activation=activation,
52 | attention_downsample_rate=attention_downsample_rate,
53 | skip_first_layer_pe=(i == 0),
54 | )
55 | )
56 |
57 | self.final_attn_token_to_image = Attention(
58 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate
59 | )
60 | self.norm_final_attn = nn.LayerNorm(embedding_dim)
61 |
62 | def forward(
63 | self,
64 | image_embedding: Tensor,
65 | image_pe: Tensor,
66 | point_embedding: Tensor,
67 | ) -> Tuple[Tensor, Tensor]:
68 | """
69 | Args:
70 | image_embedding (torch.Tensor): image to attend to. Should be shape
71 | B x embedding_dim x h x w for any h and w.
72 | image_pe (torch.Tensor): the positional encoding to add to the image. Must
73 | have the same shape as image_embedding.
74 | point_embedding (torch.Tensor): the embedding to add to the query points.
75 | Must have shape B x N_points x embedding_dim for any N_points.
76 |
77 | Returns:
78 | torch.Tensor: the processed point_embedding
79 | torch.Tensor: the processed image_embedding
80 | """
81 | # BxCxHxW -> BxHWxC == B x N_image_tokens x C
82 | bs, c, h, w = image_embedding.shape
83 | image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
84 | image_pe = image_pe.flatten(2).permute(0, 2, 1)
85 |
86 | # Prepare queries
87 | queries = point_embedding
88 | keys = image_embedding
89 |
90 | # Apply transformer blocks and final layernorm
91 | for layer in self.layers:
92 | queries, keys = layer(
93 | queries=queries,
94 | keys=keys,
95 | query_pe=point_embedding,
96 | key_pe=image_pe,
97 | )
98 |
99 |         # Apply the final attention layer from the points to the image
100 | q = queries + point_embedding
101 | k = keys + image_pe
102 | attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
103 | queries = queries + attn_out
104 | queries = self.norm_final_attn(queries)
105 |
106 | return queries, keys
107 |
108 |
109 | class TwoWayAttentionBlock(nn.Module):
110 | def __init__(
111 | self,
112 | embedding_dim: int,
113 | num_heads: int,
114 | mlp_dim: int = 2048,
115 | activation: Type[nn.Module] = nn.ReLU,
116 | attention_downsample_rate: int = 2,
117 | skip_first_layer_pe: bool = False,
118 | ) -> None:
119 | """
120 | A transformer block with four layers: (1) self-attention of sparse
121 | inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
122 | block on sparse inputs, and (4) cross attention of dense inputs to sparse
123 | inputs.
124 |
125 | Arguments:
126 | embedding_dim (int): the channel dimension of the embeddings
127 | num_heads (int): the number of heads in the attention layers
128 | mlp_dim (int): the hidden dimension of the mlp block
129 | activation (nn.Module): the activation of the mlp block
130 | skip_first_layer_pe (bool): skip the PE on the first layer
131 | """
132 | super().__init__()
133 | self.self_attn = Attention(embedding_dim, num_heads)
134 | self.norm1 = nn.LayerNorm(embedding_dim)
135 |
136 | self.cross_attn_token_to_image = Attention(
137 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate
138 | )
139 | self.norm2 = nn.LayerNorm(embedding_dim)
140 |
141 | self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
142 | self.norm3 = nn.LayerNorm(embedding_dim)
143 |
144 | self.norm4 = nn.LayerNorm(embedding_dim)
145 | self.cross_attn_image_to_token = Attention(
146 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate
147 | )
148 |
149 | self.skip_first_layer_pe = skip_first_layer_pe
150 |
151 | def forward(
152 | self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
153 | ) -> Tuple[Tensor, Tensor]:
154 | # Self attention block
155 | if self.skip_first_layer_pe:
156 | queries = self.self_attn(q=queries, k=queries, v=queries)
157 | else:
158 | q = queries + query_pe
159 | attn_out = self.self_attn(q=q, k=q, v=queries)
160 | queries = queries + attn_out
161 | queries = self.norm1(queries)
162 |
163 | # Cross attention block, tokens attending to image embedding
164 | q = queries + query_pe
165 | k = keys + key_pe
166 | attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
167 | queries = queries + attn_out
168 | queries = self.norm2(queries)
169 |
170 | # MLP block
171 | mlp_out = self.mlp(queries)
172 | queries = queries + mlp_out
173 | queries = self.norm3(queries)
174 |
175 | # Cross attention block, image embedding attending to tokens
176 | q = queries + query_pe
177 | k = keys + key_pe
178 | attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
179 | keys = keys + attn_out
180 | keys = self.norm4(keys)
181 |
182 | return queries, keys
183 |
184 |
185 | class Attention(nn.Module):
186 | """
187 | An attention layer that allows for downscaling the size of the embedding
188 | after projection to queries, keys, and values.
189 | """
190 |
191 | def __init__(
192 | self,
193 | embedding_dim: int,
194 | num_heads: int,
195 | downsample_rate: int = 1,
196 | ) -> None:
197 | super().__init__()
198 | self.embedding_dim = embedding_dim
199 | self.internal_dim = embedding_dim // downsample_rate
200 | self.num_heads = num_heads
201 | assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."
202 |
203 | self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
204 | self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
205 | self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
206 | self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
207 |
208 | def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
209 | b, n, c = x.shape
210 | x = x.reshape(b, n, num_heads, c // num_heads)
211 | return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head
212 |
213 | def _recombine_heads(self, x: Tensor) -> Tensor:
214 | b, n_heads, n_tokens, c_per_head = x.shape
215 | x = x.transpose(1, 2)
216 | return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C
217 |
218 | def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
219 | # Input projections
220 | q = self.q_proj(q)
221 | k = self.k_proj(k)
222 | v = self.v_proj(v)
223 |
224 | # Separate into heads
225 | q = self._separate_heads(q, self.num_heads)
226 | k = self._separate_heads(k, self.num_heads)
227 | v = self._separate_heads(v, self.num_heads)
228 |
229 | # Attention
230 | _, _, _, c_per_head = q.shape
231 | attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens
232 | attn = attn / math.sqrt(c_per_head)
233 | attn = torch.softmax(attn, dim=-1)
234 |
235 | # Get output
236 | out = attn @ v
237 | out = self._recombine_heads(out)
238 | out = self.out_proj(out)
239 |
240 | return out
241 |
--------------------------------------------------------------------------------
/extend_sam/segment_anything_ori/predictor.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import numpy as np
8 | import torch
9 |
10 | from extend_sam.segment_anything_ori.modeling import Sam
11 |
12 | from typing import Optional, Tuple
13 |
14 | from .utils.transforms import ResizeLongestSide
15 |
16 |
17 | class SamPredictor:
18 | def __init__(
19 | self,
20 | sam_model: Sam,
21 | ) -> None:
22 | """
23 | Uses SAM to calculate the image embedding for an image, and then
24 |         allows repeated, efficient mask prediction given prompts.
25 |
26 | Arguments:
27 | sam_model (Sam): The model to use for mask prediction.
28 | """
29 | super().__init__()
30 | self.model = sam_model
31 | self.transform = ResizeLongestSide(sam_model.image_encoder.img_size)
32 | self.reset_image()
33 |
34 | def set_image(
35 | self,
36 | image: np.ndarray,
37 | image_format: str = "RGB",
38 | ) -> None:
39 | """
40 | Calculates the image embeddings for the provided image, allowing
41 | masks to be predicted with the 'predict' method.
42 |
43 | Arguments:
44 | image (np.ndarray): The image for calculating masks. Expects an
45 | image in HWC uint8 format, with pixel values in [0, 255].
46 | image_format (str): The color format of the image, in ['RGB', 'BGR'].
47 | """
48 | assert image_format in [
49 | "RGB",
50 | "BGR",
51 | ], f"image_format must be in ['RGB', 'BGR'], is {image_format}."
52 | if image_format != self.model.image_format:
53 | image = image[..., ::-1]
54 |
55 | # Transform the image to the form expected by the model
56 | input_image = self.transform.apply_image(image)
57 | input_image_torch = torch.as_tensor(input_image, device=self.device)
58 | input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]
59 |
60 | self.set_torch_image(input_image_torch, image.shape[:2])
61 |
62 | @torch.no_grad()
63 | def set_torch_image(
64 | self,
65 | transformed_image: torch.Tensor,
66 | original_image_size: Tuple[int, ...],
67 | ) -> None:
68 | """
69 | Calculates the image embeddings for the provided image, allowing
70 | masks to be predicted with the 'predict' method. Expects the input
71 | image to be already transformed to the format expected by the model.
72 |
73 | Arguments:
74 | transformed_image (torch.Tensor): The input image, with shape
75 | 1x3xHxW, which has been transformed with ResizeLongestSide.
76 | original_image_size (tuple(int, int)): The size of the image
77 | before transformation, in (H, W) format.
78 | """
79 | assert (
80 | len(transformed_image.shape) == 4
81 | and transformed_image.shape[1] == 3
82 | and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size
83 | ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}."
84 | self.reset_image()
85 |
86 | self.original_size = original_image_size
87 | self.input_size = tuple(transformed_image.shape[-2:])
88 | input_image = self.model.preprocess(transformed_image)
89 | self.features = self.model.image_encoder(input_image)
90 | self.is_image_set = True
91 |
92 | def predict(
93 | self,
94 | point_coords: Optional[np.ndarray] = None,
95 | point_labels: Optional[np.ndarray] = None,
96 | box: Optional[np.ndarray] = None,
97 | mask_input: Optional[np.ndarray] = None,
98 | multimask_output: bool = True,
99 | return_logits: bool = False,
100 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
101 | """
102 | Predict masks for the given input prompts, using the currently set image.
103 |
104 | Arguments:
105 | point_coords (np.ndarray or None): A Nx2 array of point prompts to the
106 | model. Each point is in (X,Y) in pixels.
107 | point_labels (np.ndarray or None): A length N array of labels for the
108 | point prompts. 1 indicates a foreground point and 0 indicates a
109 | background point.
110 |           box (np.ndarray or None): A length 4 array giving a box prompt to the
111 | model, in XYXY format.
112 | mask_input (np.ndarray): A low resolution mask input to the model, typically
113 | coming from a previous prediction iteration. Has form 1xHxW, where
114 | for SAM, H=W=256.
115 | multimask_output (bool): If true, the model will return three masks.
116 | For ambiguous input prompts (such as a single click), this will often
117 | produce better masks than a single prediction. If only a single
118 | mask is needed, the model's predicted quality score can be used
119 | to select the best mask. For non-ambiguous prompts, such as multiple
120 | input prompts, multimask_output=False can give better results.
121 |           return_logits (bool): If true, returns un-thresholded mask logits
122 | instead of a binary mask.
123 |
124 | Returns:
125 | (np.ndarray): The output masks in CxHxW format, where C is the
126 | number of masks, and (H, W) is the original image size.
127 | (np.ndarray): An array of length C containing the model's
128 | predictions for the quality of each mask.
129 | (np.ndarray): An array of shape CxHxW, where C is the number
130 | of masks and H=W=256. These low resolution logits can be passed to
131 | a subsequent iteration as mask input.
132 | """
133 | if not self.is_image_set:
134 | raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
135 |
136 | # Transform input prompts
137 | coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None
138 | if point_coords is not None:
139 | assert (
140 | point_labels is not None
141 | ), "point_labels must be supplied if point_coords is supplied."
142 | point_coords = self.transform.apply_coords(point_coords, self.original_size)
143 | coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device)
144 | labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device)
145 | coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]
146 | if box is not None:
147 | box = self.transform.apply_boxes(box, self.original_size)
148 | box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device)
149 | box_torch = box_torch[None, :]
150 | if mask_input is not None:
151 | mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device)
152 | mask_input_torch = mask_input_torch[None, :, :, :]
153 |
154 | masks, iou_predictions, low_res_masks = self.predict_torch(
155 | coords_torch,
156 | labels_torch,
157 | box_torch,
158 | mask_input_torch,
159 | multimask_output,
160 | return_logits=return_logits,
161 | )
162 |
163 | masks = masks[0].detach().cpu().numpy()
164 | iou_predictions = iou_predictions[0].detach().cpu().numpy()
165 | low_res_masks = low_res_masks[0].detach().cpu().numpy()
166 | return masks, iou_predictions, low_res_masks
167 |
168 | @torch.no_grad()
169 | def predict_torch(
170 | self,
171 | point_coords: Optional[torch.Tensor],
172 | point_labels: Optional[torch.Tensor],
173 | boxes: Optional[torch.Tensor] = None,
174 | mask_input: Optional[torch.Tensor] = None,
175 | multimask_output: bool = True,
176 | return_logits: bool = False,
177 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
178 | """
179 | Predict masks for the given input prompts, using the currently set image.
180 | Input prompts are batched torch tensors and are expected to already be
181 | transformed to the input frame using ResizeLongestSide.
182 |
183 | Arguments:
184 | point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the
185 | model. Each point is in (X,Y) in pixels.
186 | point_labels (torch.Tensor or None): A BxN array of labels for the
187 | point prompts. 1 indicates a foreground point and 0 indicates a
188 | background point.
189 |           boxes (torch.Tensor or None): A Bx4 array giving a box prompt to the
190 |             model, in XYXY format.
191 |           mask_input (torch.Tensor or None): A low resolution mask input to the model, typically
192 | coming from a previous prediction iteration. Has form Bx1xHxW, where
193 | for SAM, H=W=256. Masks returned by a previous iteration of the
194 | predict method do not need further transformation.
195 | multimask_output (bool): If true, the model will return three masks.
196 | For ambiguous input prompts (such as a single click), this will often
197 | produce better masks than a single prediction. If only a single
198 | mask is needed, the model's predicted quality score can be used
199 | to select the best mask. For non-ambiguous prompts, such as multiple
200 | input prompts, multimask_output=False can give better results.
201 |           return_logits (bool): If true, returns un-thresholded mask logits
202 | instead of a binary mask.
203 |
204 | Returns:
205 | (torch.Tensor): The output masks in BxCxHxW format, where C is the
206 | number of masks, and (H, W) is the original image size.
207 | (torch.Tensor): An array of shape BxC containing the model's
208 | predictions for the quality of each mask.
209 | (torch.Tensor): An array of shape BxCxHxW, where C is the number
210 | of masks and H=W=256. These low res logits can be passed to
211 | a subsequent iteration as mask input.
212 | """
213 | if not self.is_image_set:
214 | raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
215 |
216 | if point_coords is not None:
217 | points = (point_coords, point_labels)
218 | else:
219 | points = None
220 |
221 | # Embed prompts
222 | sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
223 | points=points,
224 | boxes=boxes,
225 | masks=mask_input,
226 | )
227 |
228 | # Predict masks
229 | low_res_masks, iou_predictions = self.model.mask_decoder(
230 | image_embeddings=self.features,
231 | image_pe=self.model.prompt_encoder.get_dense_pe(),
232 | sparse_prompt_embeddings=sparse_embeddings,
233 | dense_prompt_embeddings=dense_embeddings,
234 | multimask_output=multimask_output,
235 | )
236 |
237 | # Upscale the masks to the original image resolution
238 | masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size)
239 |
240 | if not return_logits:
241 | masks = masks > self.model.mask_threshold
242 |
243 | return masks, iou_predictions, low_res_masks
244 |
245 | def get_image_embedding(self) -> torch.Tensor:
246 | """
247 | Returns the image embeddings for the currently set image, with
248 | shape 1xCxHxW, where C is the embedding dimension and (H,W) are
249 | the embedding spatial dimension of SAM (typically C=256, H=W=64).
250 | """
251 | if not self.is_image_set:
252 | raise RuntimeError(
253 | "An image must be set with .set_image(...) to generate an embedding."
254 | )
255 | assert self.features is not None, "Features must exist if an image has been set."
256 | return self.features
257 |
258 | @property
259 | def device(self) -> torch.device:
260 | return self.model.device
261 |
262 | def reset_image(self) -> None:
263 | """Resets the currently set image."""
264 | self.is_image_set = False
265 | self.features = None
266 | self.orig_h = None
267 | self.orig_w = None
268 | self.input_h = None
269 | self.input_w = None
270 |
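271 | # Example usage (a sketch; `sam_model` is assumed to be an already built Sam
272 | # instance and `image` an HWC uint8 RGB numpy array):
273 | #
274 | #   predictor = SamPredictor(sam_model)
275 | #   predictor.set_image(image)                      # computes the image embedding once
276 | #   masks, scores, low_res_logits = predictor.predict(
277 | #       point_coords=np.array([[500, 375]]),        # Nx2 points in (X, Y) pixels
278 | #       point_labels=np.array([1]),                 # 1 = foreground, 0 = background
279 | #       multimask_output=True,
280 | #   )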
--------------------------------------------------------------------------------
/extend_sam/segment_anything_ori/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/extend_sam/segment_anything_ori/utils/amg.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import numpy as np
8 | import torch
9 |
10 | import math
11 | from copy import deepcopy
12 | from itertools import product
13 | from typing import Any, Dict, Generator, ItemsView, List, Tuple
14 |
15 |
16 | class MaskData:
17 | """
18 | A structure for storing masks and their related data in batched format.
19 | Implements basic filtering and concatenation.
20 | """
21 |
22 | def __init__(self, **kwargs) -> None:
23 | for v in kwargs.values():
24 | assert isinstance(
25 | v, (list, np.ndarray, torch.Tensor)
26 | ), "MaskData only supports list, numpy arrays, and torch tensors."
27 | self._stats = dict(**kwargs)
28 |
29 | def __setitem__(self, key: str, item: Any) -> None:
30 | assert isinstance(
31 | item, (list, np.ndarray, torch.Tensor)
32 | ), "MaskData only supports list, numpy arrays, and torch tensors."
33 | self._stats[key] = item
34 |
35 | def __delitem__(self, key: str) -> None:
36 | del self._stats[key]
37 |
38 | def __getitem__(self, key: str) -> Any:
39 | return self._stats[key]
40 |
41 | def items(self) -> ItemsView[str, Any]:
42 | return self._stats.items()
43 |
44 | def filter(self, keep: torch.Tensor) -> None:
45 | for k, v in self._stats.items():
46 | if v is None:
47 | self._stats[k] = None
48 | elif isinstance(v, torch.Tensor):
49 | self._stats[k] = v[torch.as_tensor(keep, device=v.device)]
50 | elif isinstance(v, np.ndarray):
51 | self._stats[k] = v[keep.detach().cpu().numpy()]
52 | elif isinstance(v, list) and keep.dtype == torch.bool:
53 | self._stats[k] = [a for i, a in enumerate(v) if keep[i]]
54 | elif isinstance(v, list):
55 | self._stats[k] = [v[i] for i in keep]
56 | else:
57 | raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")
58 |
59 | def cat(self, new_stats: "MaskData") -> None:
60 | for k, v in new_stats.items():
61 | if k not in self._stats or self._stats[k] is None:
62 | self._stats[k] = deepcopy(v)
63 | elif isinstance(v, torch.Tensor):
64 | self._stats[k] = torch.cat([self._stats[k], v], dim=0)
65 | elif isinstance(v, np.ndarray):
66 | self._stats[k] = np.concatenate([self._stats[k], v], axis=0)
67 | elif isinstance(v, list):
68 | self._stats[k] = self._stats[k] + deepcopy(v)
69 | else:
70 | raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")
71 |
72 | def to_numpy(self) -> None:
73 | for k, v in self._stats.items():
74 | if isinstance(v, torch.Tensor):
75 | self._stats[k] = v.detach().cpu().numpy()
76 |
77 |
78 | def is_box_near_crop_edge(
79 | boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0
80 | ) -> torch.Tensor:
81 | """Filter masks at the edge of a crop, but not at the edge of the original image."""
82 | crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
83 | orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)
84 | boxes = uncrop_boxes_xyxy(boxes, crop_box).float()
85 | near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0)
86 | near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0)
87 | near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge)
88 | return torch.any(near_crop_edge, dim=1)
89 |
90 |
91 | def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor:
92 | box_xywh = deepcopy(box_xyxy)
93 | box_xywh[2] = box_xywh[2] - box_xywh[0]
94 | box_xywh[3] = box_xywh[3] - box_xywh[1]
95 | return box_xywh
96 |
97 |
98 | def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
99 | assert len(args) > 0 and all(
100 | len(a) == len(args[0]) for a in args
101 | ), "Batched iteration must have inputs of all the same size."
102 | n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0)
103 | for b in range(n_batches):
104 | yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args]
105 |
106 |
107 | def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]:
108 | """
109 | Encodes masks to an uncompressed RLE, in the format expected by
110 |     pycocotools.
111 | """
112 | # Put in fortran order and flatten h,w
113 | b, h, w = tensor.shape
114 | tensor = tensor.permute(0, 2, 1).flatten(1)
115 |
116 | # Compute change indices
117 | diff = tensor[:, 1:] ^ tensor[:, :-1]
118 | change_indices = diff.nonzero()
119 |
120 | # Encode run length
121 | out = []
122 | for i in range(b):
123 | cur_idxs = change_indices[change_indices[:, 0] == i, 1]
124 | cur_idxs = torch.cat(
125 | [
126 | torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device),
127 | cur_idxs + 1,
128 | torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device),
129 | ]
130 | )
131 | btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
132 | counts = [] if tensor[i, 0] == 0 else [0]
133 | counts.extend(btw_idxs.detach().cpu().tolist())
134 | out.append({"size": [h, w], "counts": counts})
135 | return out
136 |
137 |
138 | def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray:
139 | """Compute a binary mask from an uncompressed RLE."""
140 | h, w = rle["size"]
141 | mask = np.empty(h * w, dtype=bool)
142 | idx = 0
143 | parity = False
144 | for count in rle["counts"]:
145 | mask[idx : idx + count] = parity
146 | idx += count
147 | parity ^= True
148 | mask = mask.reshape(w, h)
149 | return mask.transpose() # Put in C order
150 |
151 |
152 | def area_from_rle(rle: Dict[str, Any]) -> int:
153 | return sum(rle["counts"][1::2])
154 |
155 |
156 | def calculate_stability_score(
157 | masks: torch.Tensor, mask_threshold: float, threshold_offset: float
158 | ) -> torch.Tensor:
159 | """
160 | Computes the stability score for a batch of masks. The stability
161 | score is the IoU between the binary masks obtained by thresholding
162 | the predicted mask logits at high and low values.
163 | """
164 | # One mask is always contained inside the other.
165 |     # Save memory by preventing unnecessary cast to torch.int64
166 | intersections = (
167 | (masks > (mask_threshold + threshold_offset))
168 | .sum(-1, dtype=torch.int16)
169 | .sum(-1, dtype=torch.int32)
170 | )
171 | unions = (
172 | (masks > (mask_threshold - threshold_offset))
173 | .sum(-1, dtype=torch.int16)
174 | .sum(-1, dtype=torch.int32)
175 | )
176 | return intersections / unions
177 |
178 |
179 | def build_point_grid(n_per_side: int) -> np.ndarray:
180 | """Generates a 2D grid of points evenly spaced in [0,1]x[0,1]."""
181 | offset = 1 / (2 * n_per_side)
182 | points_one_side = np.linspace(offset, 1 - offset, n_per_side)
183 | points_x = np.tile(points_one_side[None, :], (n_per_side, 1))
184 | points_y = np.tile(points_one_side[:, None], (1, n_per_side))
185 | points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2)
186 | return points
187 |
188 |
189 | def build_all_layer_point_grids(
190 | n_per_side: int, n_layers: int, scale_per_layer: int
191 | ) -> List[np.ndarray]:
192 | """Generates point grids for all crop layers."""
193 | points_by_layer = []
194 | for i in range(n_layers + 1):
195 | n_points = int(n_per_side / (scale_per_layer**i))
196 | points_by_layer.append(build_point_grid(n_points))
197 | return points_by_layer
198 |
199 |
200 | def generate_crop_boxes(
201 | im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float
202 | ) -> Tuple[List[List[int]], List[int]]:
203 | """
204 | Generates a list of crop boxes of different sizes. Each layer
205 | has (2**i)**2 boxes for the ith layer.
206 | """
207 | crop_boxes, layer_idxs = [], []
208 | im_h, im_w = im_size
209 | short_side = min(im_h, im_w)
210 |
211 | # Original image
212 | crop_boxes.append([0, 0, im_w, im_h])
213 | layer_idxs.append(0)
214 |
215 | def crop_len(orig_len, n_crops, overlap):
216 | return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops))
217 |
218 | for i_layer in range(n_layers):
219 | n_crops_per_side = 2 ** (i_layer + 1)
220 | overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side))
221 |
222 | crop_w = crop_len(im_w, n_crops_per_side, overlap)
223 | crop_h = crop_len(im_h, n_crops_per_side, overlap)
224 |
225 | crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)]
226 | crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)]
227 |
228 | # Crops in XYWH format
229 | for x0, y0 in product(crop_box_x0, crop_box_y0):
230 | box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)]
231 | crop_boxes.append(box)
232 | layer_idxs.append(i_layer + 1)
233 |
234 | return crop_boxes, layer_idxs
235 |
236 |
237 | def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
238 | x0, y0, _, _ = crop_box
239 | offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device)
240 | # Check if boxes has a channel dimension
241 | if len(boxes.shape) == 3:
242 | offset = offset.unsqueeze(1)
243 | return boxes + offset
244 |
245 |
246 | def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
247 | x0, y0, _, _ = crop_box
248 | offset = torch.tensor([[x0, y0]], device=points.device)
249 | # Check if points has a channel dimension
250 | if len(points.shape) == 3:
251 | offset = offset.unsqueeze(1)
252 | return points + offset
253 |
254 |
255 | def uncrop_masks(
256 | masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int
257 | ) -> torch.Tensor:
258 | x0, y0, x1, y1 = crop_box
259 | if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h:
260 | return masks
261 | # Coordinate transform masks
262 | pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0)
263 | pad = (x0, pad_x - x0, y0, pad_y - y0)
264 | return torch.nn.functional.pad(masks, pad, value=0)
265 |
266 |
267 | def remove_small_regions(
268 | mask: np.ndarray, area_thresh: float, mode: str
269 | ) -> Tuple[np.ndarray, bool]:
270 | """
271 | Removes small disconnected regions and holes in a mask. Returns the
272 |     mask and an indicator of whether the mask has been modified.
273 | """
274 | import cv2 # type: ignore
275 |
276 | assert mode in ["holes", "islands"]
277 | correct_holes = mode == "holes"
278 | working_mask = (correct_holes ^ mask).astype(np.uint8)
279 | n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
280 | sizes = stats[:, -1][1:] # Row 0 is background label
281 | small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
282 | if len(small_regions) == 0:
283 | return mask, False
284 | fill_labels = [0] + small_regions
285 | if not correct_holes:
286 | fill_labels = [i for i in range(n_labels) if i not in fill_labels]
287 | # If every region is below threshold, keep largest
288 | if len(fill_labels) == 0:
289 | fill_labels = [int(np.argmax(sizes)) + 1]
290 | mask = np.isin(regions, fill_labels)
291 | return mask, True
292 |
293 |
294 | def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]:
295 | from pycocotools import mask as mask_utils # type: ignore
296 |
297 | h, w = uncompressed_rle["size"]
298 | rle = mask_utils.frPyObjects(uncompressed_rle, h, w)
299 | rle["counts"] = rle["counts"].decode("utf-8") # Necessary to serialize with json
300 | return rle
301 |
302 |
303 | def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
304 | """
305 | Calculates boxes in XYXY format around masks. Return [0,0,0,0] for
306 | an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4.
307 | """
308 | # torch.max below raises an error on empty inputs, just skip in this case
309 | if torch.numel(masks) == 0:
310 | return torch.zeros(*masks.shape[:-2], 4, device=masks.device)
311 |
312 | # Normalize shape to CxHxW
313 | shape = masks.shape
314 | h, w = shape[-2:]
315 | if len(shape) > 2:
316 | masks = masks.flatten(0, -3)
317 | else:
318 | masks = masks.unsqueeze(0)
319 |
320 | # Get top and bottom edges
321 | in_height, _ = torch.max(masks, dim=-1)
322 | in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :]
323 | bottom_edges, _ = torch.max(in_height_coords, dim=-1)
324 | in_height_coords = in_height_coords + h * (~in_height)
325 | top_edges, _ = torch.min(in_height_coords, dim=-1)
326 |
327 | # Get left and right edges
328 | in_width, _ = torch.max(masks, dim=-2)
329 | in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :]
330 | right_edges, _ = torch.max(in_width_coords, dim=-1)
331 | in_width_coords = in_width_coords + w * (~in_width)
332 | left_edges, _ = torch.min(in_width_coords, dim=-1)
333 |
334 | # If the mask is empty the right edge will be to the left of the left edge.
335 | # Replace these boxes with [0, 0, 0, 0]
336 | empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
337 | out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1)
338 | out = out * (~empty_filter).unsqueeze(-1)
339 |
340 | # Return to original shape
341 | if len(shape) > 2:
342 | out = out.reshape(*shape[:-2], 4)
343 | else:
344 | out = out[0]
345 |
346 | return out
347 |
--------------------------------------------------------------------------------
/extend_sam/segment_anything_ori/utils/onnx.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | import torch.nn as nn
9 | from torch.nn import functional as F
10 |
11 | from typing import Tuple
12 |
13 | from ..modeling import Sam
14 | from .amg import calculate_stability_score
15 |
16 |
17 | class SamOnnxModel(nn.Module):
18 | """
19 | This model should not be called directly, but is used in ONNX export.
20 | It combines the prompt encoder, mask decoder, and mask postprocessing of Sam,
21 | with some functions modified to enable model tracing. Also supports extra
22 |     options controlling what information is returned. See the ONNX export script for details.
23 | """
24 |
25 | def __init__(
26 | self,
27 | model: Sam,
28 | return_single_mask: bool,
29 | use_stability_score: bool = False,
30 | return_extra_metrics: bool = False,
31 | ) -> None:
32 | super().__init__()
33 | self.mask_decoder = model.mask_decoder
34 | self.model = model
35 | self.img_size = model.image_encoder.img_size
36 | self.return_single_mask = return_single_mask
37 | self.use_stability_score = use_stability_score
38 | self.stability_score_offset = 1.0
39 | self.return_extra_metrics = return_extra_metrics
40 |
41 | @staticmethod
42 | def resize_longest_image_size(
43 | input_image_size: torch.Tensor, longest_side: int
44 | ) -> torch.Tensor:
45 | input_image_size = input_image_size.to(torch.float32)
46 | scale = longest_side / torch.max(input_image_size)
47 | transformed_size = scale * input_image_size
48 | transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64)
49 | return transformed_size
50 |
51 | def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor:
52 | point_coords = point_coords + 0.5
53 | point_coords = point_coords / self.img_size
54 | point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords)
55 | point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding)
56 |
57 | point_embedding = point_embedding * (point_labels != -1)
58 | point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * (
59 | point_labels == -1
60 | )
61 |
62 | for i in range(self.model.prompt_encoder.num_point_embeddings):
63 | point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[
64 | i
65 | ].weight * (point_labels == i)
66 |
67 | return point_embedding
68 |
69 | def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor:
70 | mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask)
71 | mask_embedding = mask_embedding + (
72 | 1 - has_mask_input
73 | ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1)
74 | return mask_embedding
75 |
76 | def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor:
77 | masks = F.interpolate(
78 | masks,
79 | size=(self.img_size, self.img_size),
80 | mode="bilinear",
81 | align_corners=False,
82 | )
83 |
84 | prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(torch.int64)
85 | masks = masks[..., : prepadded_size[0], : prepadded_size[1]]
86 |
87 | orig_im_size = orig_im_size.to(torch.int64)
88 | h, w = orig_im_size[0], orig_im_size[1]
89 | masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False)
90 | return masks
91 |
92 | def select_masks(
93 | self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int
94 | ) -> Tuple[torch.Tensor, torch.Tensor]:
95 | # Determine if we should return the multiclick mask or not from the number of points.
96 | # The reweighting is used to avoid control flow.
97 | score_reweight = torch.tensor(
98 | [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)]
99 | ).to(iou_preds.device)
100 | score = iou_preds + (num_points - 2.5) * score_reweight
101 | best_idx = torch.argmax(score, dim=1)
102 | masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1)
103 | iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1)
104 |
105 | return masks, iou_preds
106 |
107 | @torch.no_grad()
108 | def forward(
109 | self,
110 | image_embeddings: torch.Tensor,
111 | point_coords: torch.Tensor,
112 | point_labels: torch.Tensor,
113 | mask_input: torch.Tensor,
114 | has_mask_input: torch.Tensor,
115 | orig_im_size: torch.Tensor,
116 | ):
117 | sparse_embedding = self._embed_points(point_coords, point_labels)
118 | dense_embedding = self._embed_masks(mask_input, has_mask_input)
119 |
120 | masks, scores = self.model.mask_decoder.predict_masks(
121 | image_embeddings=image_embeddings,
122 | image_pe=self.model.prompt_encoder.get_dense_pe(),
123 | sparse_prompt_embeddings=sparse_embedding,
124 | dense_prompt_embeddings=dense_embedding,
125 | )
126 |
127 | if self.use_stability_score:
128 | scores = calculate_stability_score(
129 | masks, self.model.mask_threshold, self.stability_score_offset
130 | )
131 |
132 | if self.return_single_mask:
133 | masks, scores = self.select_masks(masks, scores, point_coords.shape[1])
134 |
135 | upscaled_masks = self.mask_postprocessing(masks, orig_im_size)
136 |
137 | if self.return_extra_metrics:
138 | stability_scores = calculate_stability_score(
139 | upscaled_masks, self.model.mask_threshold, self.stability_score_offset
140 | )
141 | areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1)
142 | return upscaled_masks, scores, stability_scores, areas, masks
143 |
144 | return upscaled_masks, scores, masks
145 |
--------------------------------------------------------------------------------
/extend_sam/segment_anything_ori/utils/transforms.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import numpy as np
8 | import torch
9 | from torch.nn import functional as F
10 | from torchvision.transforms.functional import resize, to_pil_image # type: ignore
11 |
12 | from copy import deepcopy
13 | from typing import Tuple
14 |
15 |
16 | class ResizeLongestSide:
17 | """
18 | Resizes images to longest side 'target_length', as well as provides
19 | methods for resizing coordinates and boxes. Provides methods for
20 | transforming both numpy array and batched torch tensors.
21 | """
22 |
23 | def __init__(self, target_length: int) -> None:
24 | self.target_length = target_length
25 |
26 | def apply_image(self, image: np.ndarray) -> np.ndarray:
27 | """
28 | Expects a numpy array with shape HxWxC in uint8 format.
29 | """
30 | target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length)
31 | return np.array(resize(to_pil_image(image), target_size))
32 |
33 | def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
34 | """
35 | Expects a numpy array of length 2 in the final dimension. Requires the
36 | original image size in (H, W) format.
37 | """
38 | old_h, old_w = original_size
39 | new_h, new_w = self.get_preprocess_shape(
40 | original_size[0], original_size[1], self.target_length
41 | )
42 | coords = deepcopy(coords).astype(float)
43 | coords[..., 0] = coords[..., 0] * (new_w / old_w)
44 | coords[..., 1] = coords[..., 1] * (new_h / old_h)
45 | return coords
46 |
47 | def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
48 | """
49 | Expects a numpy array shape Bx4. Requires the original image size
50 | in (H, W) format.
51 | """
52 | boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size)
53 | return boxes.reshape(-1, 4)
54 |
55 | def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor:
56 | """
57 | Expects batched images with shape BxCxHxW and float format. This
58 | transformation may not exactly match apply_image. apply_image is
59 | the transformation expected by the model.
60 | """
61 | # Expects an image in BCHW format. May not exactly match apply_image.
62 |         target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length)
63 | return F.interpolate(
64 | image, target_size, mode="bilinear", align_corners=False, antialias=True
65 | )
66 |
67 | def apply_coords_torch(
68 | self, coords: torch.Tensor, original_size: Tuple[int, ...]
69 | ) -> torch.Tensor:
70 | """
71 | Expects a torch tensor with length 2 in the last dimension. Requires the
72 | original image size in (H, W) format.
73 | """
74 | old_h, old_w = original_size
75 | new_h, new_w = self.get_preprocess_shape(
76 | original_size[0], original_size[1], self.target_length
77 | )
78 | coords = deepcopy(coords).to(torch.float)
79 | coords[..., 0] = coords[..., 0] * (new_w / old_w)
80 | coords[..., 1] = coords[..., 1] * (new_h / old_h)
81 | return coords
82 |
83 | def apply_boxes_torch(
84 | self, boxes: torch.Tensor, original_size: Tuple[int, ...]
85 | ) -> torch.Tensor:
86 | """
87 | Expects a torch tensor with shape Bx4. Requires the original image
88 | size in (H, W) format.
89 | """
90 | boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size)
91 | return boxes.reshape(-1, 4)
92 |
93 | @staticmethod
94 | def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
95 | """
96 | Compute the output size given input size and target long side length.
97 | """
98 | scale = long_side_length * 1.0 / max(oldh, oldw)
99 | newh, neww = oldh * scale, oldw * scale
100 | neww = int(neww + 0.5)
101 | newh = int(newh + 0.5)
102 | return (newh, neww)
103 |
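104 | # Example usage (a sketch; `image` is assumed to be an HxWxC uint8 numpy array,
105 | # `coords` an Nx2 point array, and 1024 the SAM image encoder's img_size):
106 | #
107 | #   transform = ResizeLongestSide(1024)
108 | #   resized_image = transform.apply_image(image)                       # longest side -> 1024
109 | #   resized_coords = transform.apply_coords(coords, image.shape[:2])   # (H, W) original size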
--------------------------------------------------------------------------------
/extend_sam/utils.py:
--------------------------------------------------------------------------------
1 | '''
2 | @copyright ziqi-jin
3 | '''
4 | import time
5 | import numpy as np
6 | import torch
7 | import torch.nn.functional as F
8 | import os.path as osp
9 | import os
10 |
11 |
12 | def fix_params(model):
13 | for name, param in model.named_parameters():
14 | param.requires_grad = False
15 |
16 |
17 | def load_params(model, params):
18 | pass
19 |
20 |
21 | def get_opt_pamams(model, lr_list, group_keys, wd_list):
22 | '''
23 |
24 | :param model: model
25 |     :param lr_list: list, contains the lr for each params group
26 |     :param wd_list: list, contains the weight decay for each params group
27 |     :param group_keys: list of lists; params whose names contain a key in a sub-list are assigned to that group
28 | :return: list of dict
29 | '''
30 |     assert len(lr_list) == len(group_keys), "lr_list should have the same length as group_keys"
31 |     assert len(lr_list) == len(wd_list), "lr_list should have the same length as wd_list"
32 | params_group = [[] for _ in range(len(lr_list))]
33 | for name, value in model.named_parameters():
34 | for index, g_keys in enumerate(group_keys):
35 | for g_key in g_keys:
36 | if g_key in name:
37 | params_group[index].append(value)
38 | return [{'params': params_group[i], 'lr': lr_list[i], 'weight_decay': wd_list[i]} for i in range(len(lr_list))]
39 |
40 |
41 | class Timer:
42 |
43 | def __init__(self):
44 | self.start_time = 0.0
45 | self.end_time = 0.0
46 |
47 | self.start()
48 |
49 | def start(self):
50 | self.start_time = time.time()
51 |
52 | def end(self, ms=False, clear=False):
53 | self.end_time = time.time()
54 |
55 | if ms:
56 | duration = int((self.end_time - self.start_time) * 1000)
57 | else:
58 | duration = int(self.end_time - self.start_time)
59 |
60 | if clear:
61 | self.start()
62 |
63 | return duration
64 |
65 |
66 | class Average_Meter:
67 | def __init__(self, keys):
68 | self.keys = keys
69 | self.clear()
70 |
71 | def add(self, dic):
72 | for key, value in dic.items():
73 | self.data_dic[key].append(value)
74 |
75 | def get(self, keys=None, clear=False):
76 | if keys is None:
77 | keys = self.keys
78 |
79 | dataset = {}
80 | for key in keys:
81 | dataset[key] = float(np.mean(self.data_dic[key]))
82 |
83 | if clear:
84 | self.clear()
85 |
86 | return dataset
87 |
88 | def clear(self):
89 | self.data_dic = {key: [] for key in self.keys}
90 |
91 |
92 | def print_and_save_log(message, path):
93 | print(message)
94 |
95 | with open(path, 'a+') as f:
96 | f.write(message + '\n')
97 |
98 |
99 | class mIoUOnline:
100 | def __init__(self, class_names):
101 | self.class_names = ['background'] + class_names
102 | self.class_num = len(self.class_names)
103 |
104 | self.clear()
105 |
106 | def get_data(self, pred_mask, gt_mask):
107 | obj_mask = gt_mask < 255
108 | correct_mask = (pred_mask == gt_mask) * obj_mask
109 |
110 | P_list, T_list, TP_list = [], [], []
111 | for i in range(self.class_num):
112 | P_list.append(np.sum((pred_mask == i) * obj_mask))
113 | T_list.append(np.sum((gt_mask == i) * obj_mask))
114 | TP_list.append(np.sum((gt_mask == i) * correct_mask))
115 |
116 | return (P_list, T_list, TP_list)
117 |
118 | def add_using_data(self, data):
119 | P_list, T_list, TP_list = data
120 | for i in range(self.class_num):
121 | self.P[i] += P_list[i]
122 | self.T[i] += T_list[i]
123 | self.TP[i] += TP_list[i]
124 |
125 | def add(self, pred_mask, gt_mask):
126 | obj_mask = gt_mask < 255
127 | correct_mask = (pred_mask == gt_mask) * obj_mask
128 |
129 | for i in range(self.class_num):
130 | self.P[i] += np.sum((pred_mask == i) * obj_mask)
131 | self.T[i] += np.sum((gt_mask == i) * obj_mask)
132 | self.TP[i] += np.sum((gt_mask == i) * correct_mask)
133 |
134 | def get(self, detail=False, clear=True):
135 | IoU_dic = {}
136 | IoU_list = []
137 |
138 | FP_list = [] # over activation
139 | FN_list = [] # under activation
140 |
141 | for i in range(self.class_num):
142 | IoU = self.TP[i] / (self.T[i] + self.P[i] - self.TP[i] + 1e-10) * 100
143 | FP = (self.P[i] - self.TP[i]) / (self.T[i] + self.P[i] - self.TP[i] + 1e-10)
144 | FN = (self.T[i] - self.TP[i]) / (self.T[i] + self.P[i] - self.TP[i] + 1e-10)
145 |
146 | IoU_dic[self.class_names[i]] = IoU
147 |
148 | IoU_list.append(IoU)
149 | FP_list.append(FP)
150 | FN_list.append(FN)
151 |
152 | mIoU = np.mean(np.asarray(IoU_list))
153 | mIoU_foreground = np.mean(np.asarray(IoU_list)[1:])
154 |
155 | FP = np.mean(np.asarray(FP_list))
156 | FN = np.mean(np.asarray(FN_list))
157 |
158 | if clear:
159 | self.clear()
160 |
161 | if detail:
162 | return mIoU, mIoU_foreground, IoU_dic, FP, FN
163 | else:
164 | return mIoU, mIoU_foreground
165 |
166 | def clear(self):
167 | self.TP = []
168 | self.P = []
169 | self.T = []
170 |
171 | for _ in range(self.class_num):
172 | self.TP.append(0)
173 | self.P.append(0)
174 | self.T.append(0)
175 |
176 |
177 | def get_numpy_from_tensor(tensor):
178 | return tensor.cpu().detach().numpy()
179 |
180 |
181 | def save_model(model, model_path, parallel=False, is_final=False):
182 | if is_final:
183 | model_path_split = model_path.split('.')
184 | model_path = model_path_split[0] + "_final.pth"
185 | if parallel:
186 | torch.save(model.module.state_dict(), model_path)
187 | else:
188 | torch.save(model.state_dict(), model_path)
189 |
190 |
191 | def write_log(iteration, log_path, log_data, status, writer, timer):
192 | log_data['iteration'] = iteration
193 | log_data['time'] = timer.end(clear=True)
194 | message = "iteration : {val}, ".format(val=log_data['iteration'])
195 | for key, value in log_data.items():
196 | if key == 'iteration':
197 | continue
198 | message += "{key} : {val}, ".format(key=key, val=value)
199 | message = message[:-2] # + '\n'
200 | print_and_save_log(message, log_path)
201 | # visualize
202 | if writer is not None:
203 | for key, value in log_data.items():
204 | writer.add_scalar("{status}/{key}".format(status=status, key=key), value, iteration)
205 |
206 |
207 | def check_folder(file_path, is_folder=False):
208 | '''
209 |
210 |     :param file_path: the path of a file; by default the input is a complete file name with its dir path.
211 |     :param is_folder: if the input is a dir rather than a file name, is_folder should be True
212 |     :return: no return; this function checks whether the needed dirs exist and creates them if not.
213 |     '''
214 |     if is_folder:
215 |         if not osp.exists(file_path):
216 |             os.makedirs(file_path)
217 |
218 | else:
219 | splits = file_path.split("/")
220 | folder_name = "/".join(splits[:-1])
221 | if not osp.exists(folder_name):
222 | os.makedirs(folder_name)
223 |
224 |
225 | def one_hot_embedding_3d(labels, class_num=21):
226 | '''
227 |
228 |     :param labels: labels of shape B H W
229 | :param class_num: N
230 | :return: B N H W
231 | '''
232 | one_hot_labels = labels.clone()
233 | one_hot_labels[one_hot_labels == 255] = 0 # 0 is background
234 | return F.one_hot(one_hot_labels, num_classes=class_num).permute(0, 3, 1, 2).contiguous().float()
235 |
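236 | # Example usage of get_opt_pamams (a sketch; the values mirror the optimizer section
237 | # of the example yaml in how_to_use_finetune_anything.md). Note that only params whose
238 | # names contain one of the group_keys entries are returned:
239 | #
240 | #   groups = get_opt_pamams(model,
241 | #                           lr_list=[1e-2, 1e-4],
242 | #                           group_keys=[['mask_adapter.decoder_head.output_hypernetworks_mlps'], ['second_module']],
243 | #                           wd_list=[0.0, 0.1])
244 | #   optimizer = torch.optim.SGD(groups, lr=1e-3, momentum=0.9, weight_decay=1e-4)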
--------------------------------------------------------------------------------
/how_to_use_finetune_anything.md:
--------------------------------------------------------------------------------
1 | # How to use finetune-anything
2 | finetune-anything (FA) is intended as a tool to help users quickly build extended SAM models. It not only supports the built-in basic tasks and basic models, but also supports user-defined extensions of different modules, training processes, and datasets for the extended SAM.
3 |
4 | - Content
5 | - [Structure](#Structure)
6 | - [Model](#Model)
7 | - [Datasets](#Datasets)
8 | - [Losses](#Losses)
9 | - [Optimizer](#Optimizer)
10 | - [Runner](#Runner)
11 | - [Logger](#Logger)
12 | - [One more thing](#One-more-thing)
13 |
14 |
15 | ## Structure
16 | Using FA can be divided into two parts: training and testing. The training part includes [Model](#Model), [Datasets](#Datasets), [Losses](#Losses), [Optimizer](#Optimizer), [Logger](#Logger), and [Runner](#Runner).
17 | The above content is configured through the yaml files in `config`.
18 | - Tasks already supported by FA can be trained and tested directly by passing the `task_name`.
19 | ```
20 | CUDA_VISIBLE_DEVICES=${your GPU number} python train.py --task_name ${one of supported task names}
21 | ```
22 | - Custom configuration files can be used for training and testing by passing `--cfg`
23 | ```
24 | CUDA_VISIBLE_DEVICES=${your GPU number} python train.py --cfg config/${yaml file name}
25 | ```
26 | The testing part is coming soon ~
27 |
28 | ## Model
29 | The SAM model includes an image encoder, a prompt encoder, and a mask decoder. FA further encapsulates the encoders and decoder of SAM, so that an Extend-SAM model consists of an image encoder adapter, a prompt encoder adapter, and a mask decoder adapter. The initialization process of Extend-SAM is shown below.
30 |
31 |
32 | Users can choose which adapters should be fixed or learned during the finetuning process. This is configured in the `model` part of the yaml file, as shown in the following example:
33 |
34 | ```yaml
35 | model:
36 | sam_name: 'extend sam name' # e.g., 'sem_sam', custom SAM model name, you should implement this model('sem_sam') first
37 | params:
38 |     # Fix a part of the parameters in SAM
39 | fix_img_en: True # fix image encoder adapter parameters
40 | fix_prompt_en: True # fix prompt encoder adapter parameters
41 | fix_mask_de: False # unfix mask decoder adapter parameters to learn
42 | ckpt_path: 'your original sam weights' # e.g., 'sam_ckpt/sam_vit_b_01ec64.pth'
43 | class_num: 21 # number of classes for your dataset(20) + background(1)
44 |     model_type: 'vit_b' # type should be in [vit_h, vit_b, vit_l, default]; this is the original SAM type,
45 |                         # related to the different original SAM models. The type should correspond to the ckpt_path
46 | ```
47 | ### Customized Model
48 | If you need to redesign the structure of a certain SAM module, write code according to the following three steps. Take [SemanticSAM](https://github.com/ziqi-jin/finetune-anything/blob/350c1fbf7f122a8525e7ffdecc40f259b262983f/extend_sam/extend_sam.py#L43) as an example.
49 | - step1
50 |
51 | First, inherit the corresponding adapter base class in `extend_sam/xxx_(encoder or decoder)_adapter.py`, and then implement the `__init__` and `forward` functions of the adapter.
52 | ```python
53 | class SemMaskDecoderAdapter(BaseMaskDecoderAdapter):
54 | def __init__(self, ori_sam: Sam, fix=False, class_num=20):
55 | super(SemMaskDecoderAdapter, self).__init__(ori_sam, fix) # init super class
56 | self.decoder_neck = MaskDecoderNeck(...) # custom module
57 | self.decoder_head = SemSegHead(...) # custom module
58 | # pair the params between ori mask_decoder and new mask_decoder_adapter
59 |         self.pair_params(self.decoder_neck) # copy the weights that share names with the original SAM into the customized module
60 | self.pair_params(self.decoder_head)
61 |
62 | def forward(self, ...):
63 | ... = self.decoder_neck(...)
64 | masks, iou_pred = self.decoder_head(...)
65 | return masks, iou_pred
66 | ```
67 | - step2
68 |
69 | First, inherit the `BaseExtendSam` base class in [extend_sam.py](https://github.com/ziqi-jin/finetune-anything/blob/350c1fbf7f122a8525e7ffdecc40f259b262983f/extend_sam/extend_sam.py#L43), and make the necessary modifications to the `__init__` function.
70 | ```python
71 | class SemanticSam(BaseExtendSam):
72 |
73 | def __init__(self, ...):
74 | super().__init__(...) # init super class
75 |         self.mask_adapter = SemMaskDecoderAdapter(...) # replace the original adapter with the newly defined customized adapter
76 | ```
77 | - step3
78 |
79 | Add the new Extend-SAM class to the [AVAI_MODEL](https://github.com/ziqi-jin/finetune-anything/blob/350c1fbf7f122a8525e7ffdecc40f259b262983f/extend_sam/__init__.py#L10) dict and give it a key.
80 | Then you can train this new model by modifying `sam_name` in the config file.
81 |
82 | ## Datasets
83 |
84 | FA comes with datasets for multiple tasks, also supports custom datasets, and configures the training and test datasets separately. Taking `torch_voc_sem` as an example, the dataset part of the configuration file is as follows.
85 | The dataset part includes `name`, `params`, `transforms`, and `target_transforms`.
86 | `params` is a `dict` containing the keys and values you want to pass to the init function of the corresponding dataset; make sure the dataset's init function has parameters with the same names as the keys.
87 | `transforms` and `target_transforms` apply transform processing to the input image and the ground truth, respectively.
88 | `transforms`/`target_transforms` accept any implemented transform function together with its `params` (again a `dict`); the transforms are applied in the order they appear in the configuration file.
89 | ```yaml
90 | # Dataset
91 | dataset:
92 | name: 'torch_voc_sem'
93 | params:
94 | root: '/your/dataset/path/'
95 | year: '2012'
96 | image_set: 'train'
97 | transforms:
98 | resize:
99 | params:
100 | size: [1024, 1024]
101 | to_tensor:
102 | params: ~ # no parameters, set to '~'
103 | target_transforms:
104 | resize:
105 | params:
106 | size: [1024, 1024]
107 | ```
108 |
109 | ### Customized Dataset
110 |
111 | ### Customized Transform
112 |
113 | If you want to customize a transform, follow these three steps:
114 |
115 | - step1
116 |
117 | - Torch-supported transform, skip this step.
118 |
119 | - Torch-unsupported transform
120 |
121 | Create it in [datasets/transforms.py](https://github.com/ziqi-jin/finetune-anything/blob/main/datasets/transforms.py), and implement the `__init__` and `forward` functions.
122 |
123 | ```python
124 | import torch.nn as nn
125 | class CustomTransform(nn.Module):
126 |     def __init__(self):
127 |         super().__init__()  # define your init process here
128 |     def forward(self, x):
129 |         return x  # define your transform process here
130 | ```
131 |
132 |
133 | - step2
134 |
135 | Import the torch-supported transform you want, or the torch-unsupported transform you defined, in [datasets/transforms.py](https://github.com/ziqi-jin/finetune-anything/blob/main/datasets/transforms.py).
136 | Then add this transform to the AVIAL_TRANSFORM dict: give the transform a key like `resize`, and use the transform class as the value.
137 |
138 | ```python
139 | import torchvision.transforms as T
140 | AVIAL_TRANSFORM = {'resize': T.Resize, 'your_transform_name': CustomTransform}
141 | ```
142 |
143 | - step3
144 |
145 | Set the transform in your config file.
146 | ```yaml
147 | transforms:
148 | your_transform_name:
149 |     params: # parameters of the transform's __init__ function to be set; if there are none, set to '~'
150 | params_1: xxx
151 | params_2: xxx
152 | ```
153 |
154 | ## Losses
155 |
156 | FA supports multiple torch loss functions and also allows users to customize the loss function. The configuration of the loss part is as follows:
157 | ```yaml
158 | losses:
159 | ce:
160 | weight: 0.5
161 |     params: # the init params of the loss can be set here
162 | ignore_index: 255
163 | label_one_hot: False
164 | mse:
165 | weight: 5.0
166 | params: ~ # no parameters, set '~'
167 | label_one_hot: True
168 | ```
169 | The loss part has `weight`, `params`, and `label_one_hot` keys; `weight` controls the weight of each loss in the total loss. Taking the config above as an example, denote the `ce` loss as $Loss_{ce}$ and the `mse` loss as $Loss_{mse}$; the final total loss is
170 |
171 | $$
172 | Loss_{total} = weight_{ce} \times Loss_{ce} + weight_{mse} \times Loss_{mse} = 0.5 \times Loss_{ce} + 5.0 \times Loss_{mse}
173 | $$
174 |
175 | `params` is a `dict` containing the keys and values you want to pass to the corresponding loss function's parameters; make sure the loss function has parameters with the same names as the keys. If you don't need to set params, set params to `~`.
176 | For the semantic segmentation task, if your loss function needs a one-hot label, set `label_one_hot` to `True`.
177 |
178 |
179 | ### Customized Losses
180 |
181 | If you want to customize the loss function, follow these three steps:
182 |
183 | - step1
184 |
185 | - Torch-supported loss: if the loss is already provided by torch, skip this step.
186 |
187 | - Torch-unsupported Loss
188 |
189 | Create it in [losses/losses.py](https://github.com/ziqi-jin/finetune-anything/blob/main/losses/losses.py), and implement its `__init__` and `forward` functions.
190 |
191 | ```python
192 | import torch.nn as nn
193 | class CustormLoss(nn.Module):
194 |     def __init__(self, xxx):
195 |         super().__init__()  # define your init process here
196 |     def forward(self, x, y):
197 |         ...  # define your forward process here and return the loss value
198 | ```
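
As a concrete, purely illustrative example (not shipped with FA), a soft Dice loss for one-hot labels could fill this template as follows:

```python
import torch
import torch.nn as nn

class DiceLoss(nn.Module):
    # Illustrative custom loss: soft Dice loss, assuming `x` are logits of shape
    # (N, C, H, W) and `y` is a one-hot label of the same shape (label_one_hot: True).
    def __init__(self, smooth=1.0):
        super().__init__()
        self.smooth = smooth

    def forward(self, x, y):
        probs = torch.softmax(x, dim=1)
        inter = (probs * y).sum(dim=(2, 3))
        union = probs.sum(dim=(2, 3)) + y.sum(dim=(2, 3))
        dice = (2 * inter + self.smooth) / (union + self.smooth)
        return 1 - dice.mean()
```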
199 |
200 |
201 | - step2
202 |
203 | Import the torch-supported loss you want, or the torch-unsupported loss you defined, in [losses/\_\_init\_\_.py](https://github.com/ziqi-jin/finetune-anything/blob/26b9ebd1b035a2f0ec8ce4e358eac79de7e263a2/losses/__init__.py#L2).
204 | Then add this loss to the AVAI_LOSS dict: give it a key like `ce`, with the loss class as the value.
205 |
206 | ```python
207 |
208 | import torch.nn as nn
209 | from .losses import YourCustomLoss
210 | AVAI_LOSS = {'your_custom_loss_key': YourCustomLoss, 'your_torch_loss_key': nn.MSELoss}  # each loss needs a unique key
211 | ```
212 |
213 | - step3
214 |
215 | Set the loss in your config file.
216 |
217 | ```yaml
218 | losses:
219 | your_loss_key:
220 | weight: your_weight # float
221 | params:
222 | your_loss_param1: xx
223 | your_loss_param2: xx
224 | label_one_hot: False
225 | ```
226 |
227 | ## Optimizer
228 | FA's optimizer supports setting the learning rate (`lr`) and weight decay (`wd`) for any unfixed module in the adapters.
229 | Users can choose the optimizer with the keywords `sgd`, `adam`, and `adamw`; `opt_params` holds the necessary parameters for each kind of optimizer.
230 | - Normal module setting
231 |
232 | `lr_default` sets the default learning rate for all unfixed params, `wd_default` sets the default weight decay for all unfixed params,
233 | and `momentum` sets the momentum for the optimizer. If the chosen optimizer does not use a parameter, e.g., `adam` has no `momentum`, just set it to `~`.
234 | - Specific module setting
235 |
236 | The remaining three params `group_keys`, `lr_list`, and `wd_list` are for specific modules.
237 | They are lists of the same length and correspond to module names, learning rates, and weight decays respectively.
238 | For example, if you want to give the `mask_adapter.decoder_head.output_hypernetworks_mlps` module its own optimization setting, put its name into `group_keys` as a list first, and then set the corresponding learning rate and weight decay in `lr_list` and `wd_list`. A rough sketch of this grouping idea is shown after the config below.
239 | If multiple modules need to share the same specific setting, just add their names to the corresponding list in `group_keys`; for example, add `modulexxx` to the first list of `group_keys`.
240 | ```yaml
241 | # Optimizer
242 | opt_params:
243 | lr_default: 1e-3
244 | wd_default: 1e-4
245 | momentum: 0.9
246 | group_keys: [ [ 'mask_adapter.decoder_head.output_hypernetworks_mlps', 'modulexxx' ], ['second_module'], ]
247 | lr_list: [ 1e-2, 1e-4, ]
248 | wd_list: [ 0.0, 0.1, ]
249 | opt_name: 'sgd' # 'sgd'
250 | scheduler_name: 'cosine'
251 | ```
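The grouping itself is handled inside FA (see `get_opt_pamams` used in train.py); the sketch below is only an assumption-level illustration of the idea, not FA's actual implementation:

```python
# Illustrative sketch: map `group_keys`/`lr_list`/`wd_list` to optimizer param groups.
# Parameters whose names contain one of the keys in a `group_keys` list get that
# group's lr/wd; all other trainable parameters fall back to the defaults.
def build_param_groups(model, group_keys, lr_list, wd_list, lr_default, wd_default):
    groups = [{'params': [], 'lr': lr, 'weight_decay': wd}
              for lr, wd in zip(lr_list, wd_list)]
    default_group = {'params': [], 'lr': lr_default, 'weight_decay': wd_default}
    for name, param in model.named_parameters():
        if not param.requires_grad:      # fixed modules are skipped
            continue
        for idx, keys in enumerate(group_keys):
            if any(key in name for key in keys):
                groups[idx]['params'].append(param)
                break
        else:
            default_group['params'].append(param)
    return groups + [default_group]
```

Such groups could then be passed to, e.g., `torch.optim.SGD(groups, lr=lr_default, momentum=0.9, weight_decay=wd_default)`.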
252 | FA also supports multiple schedulers, which can be set using the keywords `single_step`, `multi_step`, `warmup_multi_step`, `cosine`, and `linear`.
253 | ## Runner
254 |
255 | ## Logger
256 | As shown in the config file, FA provides two kinds of loggers: the default text log, which is saved in `log_folder`, and the TensorBoard log, which is saved in `tensorboard_folder` when `use_tensorboard` is `True`.
257 | The best model will be saved in `model_folder`.
258 | ```yaml
259 | # Logger
260 | use_tensorboard: True
261 | tensorboard_folder: './experiment/tensorboard'
262 | log_folder: './experiment/log'
263 | model_folder: './experiment/model'
264 | ```
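Assuming TensorBoard is installed, the saved TensorBoard logs can then be inspected with `tensorboard --logdir ./experiment/tensorboard`.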
265 |
266 | ## One more thing
267 |
268 | If you need a loss, dataset, or other feature that FA does not yet support, please submit an issue and I will help you implement it. Developers are also welcome to contribute new losses, datasets, or other features to FA; please submit a PR (pull request).
--------------------------------------------------------------------------------
/losses/__init__.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from .losses import CustormLoss
3 |
4 | AVAI_LOSS = {'ce': nn.CrossEntropyLoss, 'multi_label_soft_margin': nn.MultiLabelSoftMarginLoss,
5 | 'test_custom': CustormLoss, 'mse': nn.MSELoss}
6 |
7 |
8 | def get_losses(losses):
9 | loss_dict = {}
10 | for name in losses:
11 |         assert name in AVAI_LOSS, '{name} is not supported, please implement it first.'.format(name=name)
12 | if losses[name].params is not None:
13 | loss_dict[name] = AVAI_LOSS[name](**losses[name].params)
14 | else:
15 | loss_dict[name] = AVAI_LOSS[name]()
16 | return loss_dict
17 |
--------------------------------------------------------------------------------
/losses/losses.py:
--------------------------------------------------------------------------------
1 | '''
2 | @copyright ziqi-jin
3 | You can create custom loss functions in this file, then import the created loss in ./__init__.py and add it to AVAI_LOSS
4 | '''
5 | import torch.nn as nn
6 |
7 |
8 | # example
9 | class CustormLoss(nn.Module):
10 |     def __init__(self):
11 |         super().__init__()  # add your own init arguments / members here
12 | 
13 |     def forward(self, x, y):
14 |         pass  # compute and return your loss value here
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy==1.24.2
2 | omegaconf==2.3.0
3 | opencv_python==4.7.0.72
4 | pandas==2.0.1
5 | Pillow==9.5.0
6 | torch==1.7.1
7 | torchvision==0.8.2
8 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | '''
2 | @copyright ziqi-jin
3 | '''
4 | import argparse
5 | from omegaconf import OmegaConf
6 | from torch.utils.data import DataLoader
7 | from datasets import get_dataset
8 | from losses import get_losses
9 | from extend_sam import get_model, get_optimizer, get_scheduler, get_opt_pamams, get_runner
10 |
11 | supported_tasks = ['detection', 'semantic_seg', 'instance_seg']
12 | parser = argparse.ArgumentParser()
13 | parser.add_argument('--task_name', default='semantic_seg', type=str)
14 | parser.add_argument('--cfg', default=None, type=str)
15 |
16 | if __name__ == '__main__':
17 | args = parser.parse_args()
18 | task_name = args.task_name
19 | if args.cfg is not None:
20 | config = OmegaConf.load(args.cfg)
21 | else:
22 | assert task_name in supported_tasks, "Please input the supported task name."
23 | config = OmegaConf.load("./config/{task_name}.yaml".format(task_name=args.task_name))
24 |
25 | train_cfg = config.train
26 | val_cfg = config.val
27 | test_cfg = config.test
28 |
29 | train_dataset = get_dataset(train_cfg.dataset)
30 | train_loader = DataLoader(train_dataset, batch_size=train_cfg.bs, shuffle=True, num_workers=train_cfg.num_workers,
31 | drop_last=train_cfg.drop_last)
32 | val_dataset = get_dataset(val_cfg.dataset)
33 | val_loader = DataLoader(val_dataset, batch_size=val_cfg.bs, shuffle=False, num_workers=val_cfg.num_workers,
34 | drop_last=val_cfg.drop_last)
35 | losses = get_losses(losses=train_cfg.losses)
36 | # get the adapted model according to the model name
37 | model = get_model(model_name=train_cfg.model.sam_name, **train_cfg.model.params)
38 | opt_params = get_opt_pamams(model, lr_list=train_cfg.opt_params.lr_list, group_keys=train_cfg.opt_params.group_keys,
39 | wd_list=train_cfg.opt_params.wd_list)
40 | optimizer = get_optimizer(opt_name=train_cfg.opt_name, params=opt_params, lr=train_cfg.opt_params.lr_default,
41 | momentum=train_cfg.opt_params.momentum, weight_decay=train_cfg.opt_params.wd_default)
42 | scheduler = get_scheduler(optimizer=optimizer, lr_scheduler=train_cfg.scheduler_name)
43 | runner = get_runner(train_cfg.runner_name)(model, optimizer, losses, train_loader, val_loader, scheduler)
44 | # train_step
45 | runner.train(train_cfg)
46 | if test_cfg.need_test:
47 | runner.test(test_cfg)
48 |
--------------------------------------------------------------------------------