├── .gitignore
├── README.md
├── datasets
│   └── dataset_ufpr_sam.py
├── lora_predictor.py
├── sam_lora_image_encoder.py
├── sam_lora_image_encoder_mask_decoder.py
├── segment_anything
│   ├── __init__.py
│   ├── automatic_mask_generator.py
│   ├── build_sam.py
│   ├── modeling
│   │   ├── __init__.py
│   │   ├── common.py
│   │   ├── image_encoder.py
│   │   ├── mask_decoder.py
│   │   ├── prompt_encoder.py
│   │   ├── sam.py
│   │   └── transformer.py
│   ├── predictor.py
│   └── utils
│       ├── __init__.py
│       ├── amg.py
│       ├── onnx.py
│       └── transforms.py
├── test.py
├── train.py
├── trainer.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | *.pth
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | SamLP is a license plate detector built on a visual foundation model. We fine-tune the Segment Anything Model (SAM) for the license plate detection task.
2 |
3 | Thanks to the previous works:
4 |
5 |
6 |
--------------------------------------------------------------------------------
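
The README stops short of a usage snippet. A minimal inference sketch based on the files in this repository, assuming ViT-B SAM weights and a fine-tuned LoRA checkpoint (both filenames below are placeholders, and r=4 is an arbitrary rank, not necessarily the value used for training):

    import torch
    from segment_anything import sam_model_registry
    from sam_lora_image_encoder_mask_decoder import LoRA_Sam
    from lora_predictor import LoRA_SamPredictor

    # Build the frozen SAM backbone and wrap it with LoRA adapters.
    sam = sam_model_registry['vit_b'](checkpoint='sam_vit_b_01ec64.pth')
    lora_sam = LoRA_Sam(sam, r=4).cuda()
    lora_sam.load_lora_parameters('samlp_lora.pth')  # placeholder checkpoint name

    predictor = LoRA_SamPredictor(lora_sam)
    # Stand-in for a batch produced by SamTransformTest in datasets/dataset_ufpr_sam.py
    # (long side resized to 1024, normalized, zero-padded to 1024x1024).
    image_batch = torch.randn(1, 3, 1024, 1024).cuda()
    masks, iou_predictions, low_res_masks = predictor.forward_test(image_batch)
    print(masks.shape)  # (1, 3, 1080, 1920): binary masks upscaled to the original frame size
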
/datasets/dataset_ufpr_sam.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Any
3 | import numpy as np
4 | import cv2
5 | import torch.utils.data as data
6 | import torch
7 | import torchvision.transforms as transforms
8 | import PIL.Image as Image
9 | import random
10 | from einops import repeat
11 | from icecream import ic
12 | from scipy.ndimage import zoom
13 | from scipy import ndimage
14 | from torchvision.transforms.functional import resize, to_pil_image, rotate, hflip, vflip # type: ignore
15 | from torch.nn import functional as F
16 |
17 |
18 | def random_rot_flip_torch(image, label):
19 | k = np.random.randint(0, 4)
20 | image = rotate(image, k*90)
21 | label = rotate(label, k*90)
22 | axis = np.random.randint(0, 2)
23 | if axis == 0:
24 | image = hflip(image)
25 | label = hflip(label)
26 | elif axis == 1:
27 | image = vflip(image)
28 | label = vflip(label)
29 | return image, label
30 |
31 | def random_rotate_torch(image, label):
32 | angle = np.random.randint(-20, 20)
33 | image = rotate(image, angle)
34 | label = rotate(label, angle)
35 | return image, label
36 |
37 |
38 | class SamTransformTest:
39 | def __init__(self, target_length):
40 | self.target_length = target_length
41 | self.pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
42 | self.pixel_std = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)
43 |
44 | def get_low_res_mask(self, x):
45 | pass
46 |
47 | def pad_image(self, x):
48 | h, w = x.shape[-2:]
49 | padh = self.target_length - h
50 | padw = self.target_length - w
51 | x = F.pad(x, (0, padw, 0, padh))
52 | return x
53 |
54 | def normalize_image(self, x):
55 | x = (x - self.pixel_mean) / self.pixel_std
56 | return x
57 |
58 | @staticmethod
59 | def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int):
60 | """
61 | Compute the output size given input size and target long side length.
62 | """
63 | scale = long_side_length * 1.0 / max(oldh, oldw)
64 | newh, neww = oldh * scale, oldw * scale
65 | neww = int(neww + 0.5)
66 | newh = int(newh + 0.5)
67 | return (newh, neww)
68 |
69 | def apply_image(self, image: np.ndarray) -> np.ndarray:
70 | """
71 |         Expects a numpy array with shape HxWxC in uint8 format.
72 | """
73 | target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length)
74 | # print(image_batch.shape)
75 | return np.array(resize(to_pil_image(image), target_size))
76 |
77 | def __call__(self, sample):
78 | image, label = sample['image'], sample['label']
79 |
80 | image = self.apply_image(image)
81 | label = self.apply_image(label)
82 |
83 | image_torch = torch.as_tensor(image).permute(2,0,1).contiguous()#[None, :, :, :]
84 | label_torch = torch.as_tensor(label).contiguous()
85 |
86 | image_torch = self.pad_image(self.normalize_image(image_torch))
87 | label_torch = self.pad_image(label_torch)[None, :, :]
88 |
89 | # if random.random() > 0.5:
90 | # image_torch, label_torch = random_rot_flip_torch(image_torch, label_torch)
91 | # elif random.random() > 0.5:
92 | # image_torch, label_torch = random_rotate_torch(image_torch, label_torch)
93 |
94 | low_res_label = resize(label_torch, self.target_length//4)#.squeeze()
95 | sample = {'image': image_torch, 'label': label_torch.float(), 'low_res_label': low_res_label.float()}
96 | return sample
97 |
98 |
99 |
100 | class SamTransform:
101 | def __init__(self, target_length):
102 | self.target_length = target_length
103 | self.pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
104 | self.pixel_std = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)
105 |
106 | def get_low_res_mask(self, x):
107 | pass
108 |
109 | def pad_image(self, x):
110 | h, w = x.shape[-2:]
111 | padh = self.target_length - h
112 | padw = self.target_length - w
113 | x = F.pad(x, (0, padw, 0, padh))
114 | return x
115 |
116 | def normalize_image(self, x):
117 | x = (x - self.pixel_mean) / self.pixel_std
118 | return x
119 |
120 | @staticmethod
121 | def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int):
122 | """
123 | Compute the output size given input size and target long side length.
124 | """
125 | scale = long_side_length * 1.0 / max(oldh, oldw)
126 | newh, neww = oldh * scale, oldw * scale
127 | neww = int(neww + 0.5)
128 | newh = int(newh + 0.5)
129 | return (newh, neww)
130 |
131 | def apply_image(self, image: np.ndarray) -> np.ndarray:
132 | """
133 |         Expects a numpy array with shape HxWxC in uint8 format.
134 | """
135 | target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length)
136 | # print(image_batch.shape)
137 | return np.array(resize(to_pil_image(image), target_size))
138 |
139 | def __call__(self, sample):
140 | image, label = sample['image'], sample['label']
141 |
142 | image = self.apply_image(image)
143 | label = self.apply_image(label)
144 |
145 | image_torch = torch.as_tensor(image).permute(2,0,1).contiguous()#[None, :, :, :]
146 | label_torch = torch.as_tensor(label).contiguous()
147 |
148 | image_torch = self.pad_image(self.normalize_image(image_torch))
149 | label_torch = self.pad_image(label_torch)[None, :, :]
150 |
151 | if random.random() > 0.5:
152 | image_torch, label_torch = random_rot_flip_torch(image_torch, label_torch)
153 | elif random.random() > 0.5:
154 | image_torch, label_torch = random_rotate_torch(image_torch, label_torch)
155 |
156 | low_res_label = resize(label_torch, self.target_length//4)#.squeeze()
157 | sample = {'image': image_torch, 'label': label_torch.float(), 'low_res_label': low_res_label.float()}
158 | return sample
159 |
160 | def collater(data):
161 | images = [s['image'] for s in data]
162 | labels = [s['label'] for s in data]
163 | low_res_labels = [s['low_res_label'] for s in data]
164 |
165 | images = torch.stack(images, dim=0)
166 |     labels = torch.stack(labels, dim=0).squeeze(1)
167 |     low_res_labels = torch.stack(low_res_labels, dim=0).squeeze(1)
168 |
169 | return {'image': images, 'label': labels, 'low_res_label': low_res_labels}
170 | # pass
171 |
172 |
173 | class UFPR_ALPR_Dataset(data.Dataset):
174 | def __init__(self, root, split='training', transform=None):
175 | self.data_dir = os.path.join(root, split)
176 | self.image_list = self.build_image_list()
177 | self.transform = transform
178 |
179 | def build_image_list(self):
180 | image_list = []
181 | for i in range(len(os.listdir(self.data_dir))):
182 | path = os.path.join(self.data_dir, os.listdir(self.data_dir)[i])
183 | files = os.listdir(path)
184 | for j in range(len(files)):
185 | if os.path.splitext(files[j])[-1] == '.png':
186 | image_list.append(os.path.join(path, files[j]))
187 | # image_list = image_list[490:]
188 | return image_list
189 |
190 | def load_image(self, path):
191 | img = cv2.imread(path)
192 | img = img.astype(np.uint8)
193 | return img
194 |
195 | def load_annotations(self, path):
196 | file = path.replace('png', 'txt')
197 | with open(file, 'r') as f:
198 | data = f.read()
199 |
200 | lines = data.replace('\t', '').replace('-', '').split('\n')
201 | for line in lines:
202 | line_split = line.split(':')
203 | prop = line_split[0].strip()
204 | data = line_split[1].strip()
205 | if prop == "position_plate":
206 | data = data.split(" ")
207 | data = np.array(data, dtype=np.float32)
208 | label = data.reshape((1,4))
209 |
210 | return label
211 |
212 | def plate_mask(self, img, annot):
213 | h, w = img.shape[0], img.shape[1]
214 | mask = np.zeros((h, w))
215 | mask[int(annot[:,1]):int(annot[:,1]+annot[:,3]),int(annot[:,0]):int(annot[:,0]+annot[:,2])] = 1
216 | mask = mask.astype(np.uint8)
217 | return mask
218 |
219 | def __len__(self):
220 | return len(self.image_list)
221 |
222 | def __getitem__(self, idx):
223 | path = self.image_list[idx]
224 | img = self.load_image(path)
225 | plate_annot = self.load_annotations(path)
226 | mask = self.plate_mask(img, plate_annot)
227 | sample = {'image': img, 'label': mask}
228 | if self.transform:
229 | sample = self.transform(sample)
230 |
231 | return sample
232 |
233 |
234 | if __name__=='__main__':
235 | # db_train = UFPR_ALPR_Dataset(root='/media/disk1/yxding/dhx/Dataset/UFPR-ALPR/', split="training",
236 | # transform=transforms.Compose(
237 | # [RandomGenerator(output_size=[512, 512], low_res=[128, 128])]))
238 |
239 | db_train = UFPR_ALPR_Dataset(root='/media/disk1/yxding/dhx/Dataset/UFPR-ALPR/', split="training",
240 | transform=SamTransform(1024))
241 |
242 | trainloader = data.DataLoader(db_train, batch_size=2, shuffle=True, collate_fn=collater, drop_last=True, num_workers=2)
243 |
244 | for v in trainloader:
245 | images = v['image']
246 | labels = v['label']
247 | low_res_labels = v['low_res_label']
248 |
249 | print(images.shape)
250 | print(labels.shape)
251 | print(low_res_labels.shape)
252 |         break  # inspect only the first batch
253 |
254 | # sample = db_train[10]
255 | # label = sample['label']
256 | # image = sample['image']
257 | # low_res_label = sample['low_res_label']
258 |
259 | # # print(label.shape)
260 | # # print(image.shape)
261 |
262 | # image = sample['image'].permute(1,2,0).numpy()
263 | # cv2.imwrite('test_image.png', image*100)
264 |
265 | # label = sample['label'].permute(1,2,0).numpy().astype(np.uint8)
266 | # label = cv2.cvtColor(label, cv2.COLOR_GRAY2BGR)
267 | # cv2.imwrite('test_label.png', label*100)
268 |
269 | # low_res_label = sample['low_res_label'].permute(1,2,0).numpy().astype(np.uint8)
270 | # low_res_label = cv2.cvtColor(low_res_label, cv2.COLOR_GRAY2BGR)
271 | # cv2.imwrite('test_low_res_label.png', low_res_label*100)
272 |
273 |
274 | # label = sample['label'].numpy().astype(np.uint8)
275 | # image = sample['image'].permute(1,2,0).numpy()
276 |
277 | # cv2.imwrite('test_image.png', image*255)
278 | # label = cv2.cvtColor(label, cv2.COLOR_GRAY2BGR)
279 | # cv2.imwrite('test_label.png', label*255)
280 |
281 | # print(image.shape)
282 | # print(label.shape)
283 |
284 |
285 |
286 |
--------------------------------------------------------------------------------
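
As a reading aid for the transforms above: UFPR-ALPR frames are 1920x1080, so get_preprocess_shape scales the long side to 1024 (giving 576x1024), the image is mean/std-normalized and zero-padded to 1024x1024, and the label is additionally resized to 256x256 to match SAM's low-resolution mask output. A standalone shape check using random data rather than the real dataset:

    import numpy as np
    from datasets.dataset_ufpr_sam import SamTransform

    transform = SamTransform(target_length=1024)
    fake_image = np.random.randint(0, 256, (1080, 1920, 3), dtype=np.uint8)
    fake_label = np.zeros((1080, 1920), dtype=np.uint8)

    sample = transform({'image': fake_image, 'label': fake_label})
    print(sample['image'].shape)          # torch.Size([3, 1024, 1024]): resized, normalized, padded
    print(sample['label'].shape)          # torch.Size([1, 1024, 1024])
    print(sample['low_res_label'].shape)  # torch.Size([1, 256, 256]): target for the mask-decoder logits
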
/lora_predictor.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import numpy as np
8 | import torch
9 |
10 | from segment_anything.modeling import Sam
11 | from sam_lora_image_encoder_mask_decoder import LoRA_Sam
12 | # from sam_lora_image_encoder import LoRA_Sam
13 |
14 | from typing import Optional, Tuple
15 |
16 | from datasets.dataset_ufpr_sam import SamTransform
17 | from segment_anything.utils.transforms import ResizeLongestSide
18 |
19 |
20 | class LoRA_SamPredictor:
21 | def __init__(
22 | self,
23 | sam_model: LoRA_Sam,
24 | ) -> None:
25 | """
26 | Uses SAM to calculate the image embedding for an image, and then
27 | allow repeated, efficient mask prediction given prompts.
28 |
29 | Arguments:
30 | sam_model (Sam): The model to use for mask prediction.
31 | """
32 | super().__init__()
33 | self.model = sam_model
34 | self.original_size = (1080, 1920)
35 | self.input_size = (576, 1024)
36 | self.pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
37 | self.pixel_std = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)
38 | self.transform = ResizeLongestSide(sam_model.sam.image_encoder.img_size)
39 | self.reset_image()
40 |
41 | def forward(self, image_batch, multimask_output=True):
42 |         features = self.model.sam.image_encoder(image_batch)
43 |
44 | sparse_embeddings, dense_embeddings = self.model.sam.prompt_encoder(
45 | points=None,
46 | boxes=None,
47 | masks=None,
48 | )
49 |
50 | low_res_masks, iou_predictions = self.model.sam.mask_decoder(
51 |             image_embeddings=features,
52 | image_pe=self.model.sam.prompt_encoder.get_dense_pe(),
53 | sparse_prompt_embeddings=sparse_embeddings,
54 | dense_prompt_embeddings=dense_embeddings,
55 | multimask_output=multimask_output,
56 | )
57 | # masks = self.model.sam.postprocess_masks(low_res_masks, self.input_size, self.original_size)
58 | # masks = masks > self.model.sam.mask_threshold
59 | # masks = self.model.sam.postprocess_masks(low_res_masks, self.input_size, self.original_size)
60 | return low_res_masks#.sum(dim=1)
61 |
62 | @torch.no_grad()
63 | def forward_test(self, image_batch, multimask_output=True):
64 | self.features = self.model.sam.image_encoder(image_batch)
65 | self.is_image_set = True
66 |
67 | sparse_embeddings, dense_embeddings = self.model.sam.prompt_encoder(
68 | points=None,
69 | boxes=None,
70 | masks=None,
71 | )
72 |
73 | low_res_masks, iou_predictions = self.model.sam.mask_decoder(
74 | image_embeddings=self.features,
75 | image_pe=self.model.sam.prompt_encoder.get_dense_pe(),
76 | sparse_prompt_embeddings=sparse_embeddings,
77 | dense_prompt_embeddings=dense_embeddings,
78 | multimask_output=multimask_output,
79 | )
80 | # print(low_res_masks.shape)
81 | # print(self.input_size)
82 | masks = self.model.sam.postprocess_masks(low_res_masks, self.input_size, self.original_size)
83 | masks = masks > self.model.sam.mask_threshold
84 |
85 | return masks, iou_predictions, low_res_masks
86 |
87 | @torch.no_grad()
88 | def foward_refine(self, point_corrds, point_labels, masks, multimask_output=True):
89 |         if point_corrds is not None:
90 | points = (point_corrds, point_labels)
91 | else:
92 | points = None
93 |
94 | sparse_embeddings, dense_embeddings = self.model.sam.prompt_encoder(
95 | points=points,
96 | boxes=None,
97 | masks=masks
98 | )
99 |
100 | low_res_masks, iou_predictions = self.model.sam.mask_decoder(
101 | image_embeddings=self.features,
102 | image_pe=self.model.sam.prompt_encoder.get_dense_pe(),
103 | sparse_prompt_embeddings=sparse_embeddings,
104 | dense_prompt_embeddings=dense_embeddings,
105 | multimask_output=multimask_output,
106 | )
107 | masks = self.model.sam.postprocess_masks(low_res_masks, self.input_size, self.original_size)
108 | masks = masks > self.model.sam.mask_threshold
109 | return masks, iou_predictions, low_res_masks
110 |
111 |
112 | def set_image(
113 | self,
114 | image: np.ndarray,
115 | image_format: str = "RGB",
116 | ) -> None:
117 | """
118 | Calculates the image embeddings for the provided image, allowing
119 | masks to be predicted with the 'predict' method.
120 |
121 | Arguments:
122 | image (np.ndarray): The image for calculating masks. Expects an
123 | image in HWC uint8 format, with pixel values in [0, 255].
124 | image_format (str): The color format of the image, in ['RGB', 'BGR'].
125 | """
126 | assert image_format in [
127 | "RGB",
128 | "BGR",
129 | ], f"image_format must be in ['RGB', 'BGR'], is {image_format}."
130 |         if image_format != self.model.sam.image_format:
131 | image = image[..., ::-1]
132 |
133 | # Transform the image to the form expected by the model
134 | input_image = self.transform.apply_image(image)
135 | input_image_torch = torch.as_tensor(input_image, device=self.device)
136 | input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]
137 |
138 | self.set_torch_image(input_image_torch, image.shape[:2])
139 |
140 | @torch.no_grad()
141 | def set_torch_image(
142 | self,
143 | transformed_image: torch.Tensor,
144 | original_image_size: Tuple[int, ...],
145 | ) -> None:
146 | """
147 | Calculates the image embeddings for the provided image, allowing
148 | masks to be predicted with the 'predict' method. Expects the input
149 | image to be already transformed to the format expected by the model.
150 |
151 | Arguments:
152 | transformed_image (torch.Tensor): The input image, with shape
153 | 1x3xHxW, which has been transformed with ResizeLongestSide.
154 | original_image_size (tuple(int, int)): The size of the image
155 | before transformation, in (H, W) format.
156 | """
157 | assert (
158 | len(transformed_image.shape) == 4
159 | and transformed_image.shape[1] == 3
160 |             and max(*transformed_image.shape[2:]) == self.model.sam.image_encoder.img_size
161 |         ), f"set_torch_image input must be BCHW with long side {self.model.sam.image_encoder.img_size}."
162 | self.reset_image()
163 |
164 | self.original_size = original_image_size
165 | self.input_size = tuple(transformed_image.shape[-2:])
166 |         input_image = self.model.sam.preprocess(transformed_image)
167 |         self.features = self.model.sam.image_encoder(input_image)
168 | self.is_image_set = True
169 | # self.device =
170 |
171 | def predict(
172 | self,
173 | point_coords: Optional[np.ndarray] = None,
174 | point_labels: Optional[np.ndarray] = None,
175 | box: Optional[np.ndarray] = None,
176 | mask_input: Optional[np.ndarray] = None,
177 | multimask_output: bool = True,
178 | return_logits: bool = False,
179 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
180 | """
181 | Predict masks for the given input prompts, using the currently set image.
182 |
183 | Arguments:
184 | point_coords (np.ndarray or None): A Nx2 array of point prompts to the
185 | model. Each point is in (X,Y) in pixels.
186 | point_labels (np.ndarray or None): A length N array of labels for the
187 | point prompts. 1 indicates a foreground point and 0 indicates a
188 | background point.
189 | box (np.ndarray or None): A length 4 array given a box prompt to the
190 | model, in XYXY format.
191 | mask_input (np.ndarray): A low resolution mask input to the model, typically
192 | coming from a previous prediction iteration. Has form 1xHxW, where
193 | for SAM, H=W=256.
194 | multimask_output (bool): If true, the model will return three masks.
195 | For ambiguous input prompts (such as a single click), this will often
196 | produce better masks than a single prediction. If only a single
197 | mask is needed, the model's predicted quality score can be used
198 | to select the best mask. For non-ambiguous prompts, such as multiple
199 | input prompts, multimask_output=False can give better results.
200 | return_logits (bool): If true, returns un-thresholded masks logits
201 | instead of a binary mask.
202 |
203 | Returns:
204 | (np.ndarray): The output masks in CxHxW format, where C is the
205 | number of masks, and (H, W) is the original image size.
206 | (np.ndarray): An array of length C containing the model's
207 | predictions for the quality of each mask.
208 | (np.ndarray): An array of shape CxHxW, where C is the number
209 | of masks and H=W=256. These low resolution logits can be passed to
210 | a subsequent iteration as mask input.
211 | """
212 | if not self.is_image_set:
213 | raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
214 |
215 | # Transform input prompts
216 | coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None
217 | if point_coords is not None:
218 | assert (
219 | point_labels is not None
220 | ), "point_labels must be supplied if point_coords is supplied."
221 | point_coords = self.transform.apply_coords(point_coords, self.original_size)
222 | # print()
223 | coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device)
224 | labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device)
225 | coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]
226 | if box is not None:
227 | box = self.transform.apply_boxes(box, self.original_size)
228 | box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device)
229 | box_torch = box_torch[None, :]
230 | if mask_input is not None:
231 | mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device)
232 | mask_input_torch = mask_input_torch[None, :, :, :]
233 |
234 | masks, iou_predictions, low_res_masks = self.predict_torch(
235 | coords_torch,
236 | labels_torch,
237 | box_torch,
238 | mask_input_torch,
239 | multimask_output,
240 | return_logits=return_logits,
241 | )
242 |
243 | masks_np = masks[0].detach().cpu().numpy()
244 | iou_predictions_np = iou_predictions[0].detach().cpu().numpy()
245 | low_res_masks_np = low_res_masks[0].detach().cpu().numpy()
246 | return masks_np, iou_predictions_np, low_res_masks_np
247 |
248 | @torch.no_grad()
249 | def predict_torch(
250 | self,
251 | point_coords: Optional[torch.Tensor],
252 | point_labels: Optional[torch.Tensor],
253 | boxes: Optional[torch.Tensor] = None,
254 | mask_input: Optional[torch.Tensor] = None,
255 | multimask_output: bool = True,
256 | return_logits: bool = False,
257 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
258 | """
259 | Predict masks for the given input prompts, using the currently set image.
260 | Input prompts are batched torch tensors and are expected to already be
261 | transformed to the input frame using ResizeLongestSide.
262 |
263 | Arguments:
264 | point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the
265 | model. Each point is in (X,Y) in pixels.
266 | point_labels (torch.Tensor or None): A BxN array of labels for the
267 | point prompts. 1 indicates a foreground point and 0 indicates a
268 | background point.
269 | boxes (np.ndarray or None): A Bx4 array given a box prompt to the
270 | model, in XYXY format.
271 | mask_input (np.ndarray): A low resolution mask input to the model, typically
272 | coming from a previous prediction iteration. Has form Bx1xHxW, where
273 | for SAM, H=W=256. Masks returned by a previous iteration of the
274 | predict method do not need further transformation.
275 | multimask_output (bool): If true, the model will return three masks.
276 | For ambiguous input prompts (such as a single click), this will often
277 | produce better masks than a single prediction. If only a single
278 | mask is needed, the model's predicted quality score can be used
279 | to select the best mask. For non-ambiguous prompts, such as multiple
280 | input prompts, multimask_output=False can give better results.
281 | return_logits (bool): If true, returns un-thresholded masks logits
282 | instead of a binary mask.
283 |
284 | Returns:
285 | (torch.Tensor): The output masks in BxCxHxW format, where C is the
286 | number of masks, and (H, W) is the original image size.
287 | (torch.Tensor): An array of shape BxC containing the model's
288 | predictions for the quality of each mask.
289 | (torch.Tensor): An array of shape BxCxHxW, where C is the number
290 | of masks and H=W=256. These low res logits can be passed to
291 | a subsequent iteration as mask input.
292 | """
293 | if not self.is_image_set:
294 | raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
295 |
296 | if point_coords is not None:
297 | points = (point_coords, point_labels)
298 | else:
299 | points = None
300 |
301 | # Embed prompts
302 | sparse_embeddings, dense_embeddings = self.model.sam.prompt_encoder(
303 | points=points,
304 | boxes=boxes,
305 | masks=mask_input,
306 | )
307 |
308 | # Predict masks
309 | low_res_masks, iou_predictions = self.model.sam.mask_decoder(
310 | image_embeddings=self.features,
311 | image_pe=self.model.sam.prompt_encoder.get_dense_pe(),
312 | sparse_prompt_embeddings=sparse_embeddings,
313 | dense_prompt_embeddings=dense_embeddings,
314 | multimask_output=multimask_output,
315 | )
316 |
317 | # Upscale the masks to the original image resolution
318 | masks = self.model.sam.postprocess_masks(low_res_masks, self.input_size, self.original_size)
319 |
320 | if not return_logits:
321 | masks = masks > self.model.sam.mask_threshold
322 |
323 | return masks, iou_predictions, low_res_masks
324 |
325 | def get_image_embedding(self) -> torch.Tensor:
326 | """
327 | Returns the image embeddings for the currently set image, with
328 | shape 1xCxHxW, where C is the embedding dimension and (H,W) are
329 | the embedding spatial dimension of SAM (typically C=256, H=W=64).
330 | """
331 | if not self.is_image_set:
332 | raise RuntimeError(
333 | "An image must be set with .set_image(...) to generate an embedding."
334 | )
335 | assert self.features is not None, "Features must exist if an image has been set."
336 | return self.features
337 |
338 | @property
339 | def device(self) -> torch.device:
340 | return self.model.sam.device
341 |
342 | def reset_image(self) -> None:
343 | """Resets the currently set image."""
344 | self.is_image_set = False
345 | self.features = None
346 | self.orig_h = None
347 | self.orig_w = None
348 | self.input_h = None
349 | self.input_w = None
350 |
--------------------------------------------------------------------------------
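
Unlike the stock SamPredictor, both forward and forward_test above call the prompt encoder with no points, boxes or masks, so the LoRA-tuned model segments plates prompt-free; foward_refine then feeds a previous low-resolution prediction back in as a mask prompt against the cached image embedding. A hedged sketch of that refinement loop (the number of rounds and the predictor/batch setup are assumptions, not taken from test.py):

    # `predictor` is a LoRA_SamPredictor and `image_batch` a preprocessed
    # (B, 3, 1024, 1024) tensor on the same device, as in the README sketch above.
    masks, iou_pred, low_res = predictor.forward_test(image_batch, multimask_output=False)

    for _ in range(2):  # refinement rounds: an arbitrary choice for illustration
        # Reuse the cached image embedding; prompt only with the previous low-res logits.
        masks, iou_pred, low_res = predictor.foward_refine(
            point_corrds=None, point_labels=None, masks=low_res,
            multimask_output=False)
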
/sam_lora_image_encoder.py:
--------------------------------------------------------------------------------
1 | from segment_anything import build_sam, SamPredictor
2 | from segment_anything import sam_model_registry
3 |
4 | import math
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from torch import Tensor
9 | from torch.nn.parameter import Parameter
10 | from segment_anything.modeling import Sam
11 | from safetensors import safe_open
12 | from safetensors.torch import save_file
13 |
14 | from icecream import ic
15 |
16 | class LoRA(nn.Module):
17 | def __init__(self,
18 | qkv: nn.Module,
19 | linear_a_q: nn.Module,
20 | linear_b_q: nn.Module,
21 | linear_a_v: nn.Module,
22 | linear_b_v: nn.Module
23 | ):
24 | # super(LoRA, self).__init__()
25 | super().__init__()
26 | self.qkv = qkv
27 | self.linear_a_q = linear_a_q
28 | self.linear_b_q = linear_b_q
29 | self.linear_a_v = linear_a_v
30 | self.linear_b_v = linear_b_v
31 | self.dim = qkv.in_features
32 | self.w_identity = torch.eye(qkv.in_features)
33 |
34 | def forward(self, x):
35 | qkv = self.qkv(x)
36 | new_q = self.linear_b_q(self.linear_a_q(x))
37 | new_v = self.linear_b_v(self.linear_a_v(x))
38 | qkv[:, :, :, : self.dim] += new_q
39 | qkv[:, :, :, -self.dim:] += new_v
40 | return qkv
41 |
42 | class LoRA_Sam(nn.Module):
43 | def __init__(self, sam_model: Sam, r: int, lora_layer=None):
44 | super(LoRA_Sam, self).__init__()
45 |
46 | assert r > 0
47 | if lora_layer:
48 | self.lora_layer = lora_layer
49 | else:
50 | self.lora_layer = list(
51 | range(len(sam_model.image_encoder.blocks)))
52 |
53 | self.w_As = []
54 | self.w_Bs = []
55 |
56 | for param in sam_model.image_encoder.parameters():
57 | param.requires_grad = False
58 | for param in sam_model.prompt_encoder.parameters():
59 | param.requires_grad = False
60 | for param in sam_model.mask_decoder.parameters():
61 | param.requires_grad = False
62 |
63 | for layer_i, block in enumerate(sam_model.image_encoder.blocks):
64 | if layer_i not in self.lora_layer:
65 | continue
66 | w_qkv_linear = block.attn.qkv
67 | self.dim = w_qkv_linear.in_features
68 | w_a_linear_q = nn.Linear(self.dim, r, bias=False)
69 | w_b_linear_q = nn.Linear(r, self.dim, bias=False)
70 | w_a_linear_v = nn.Linear(self.dim, r, bias=False)
71 | w_b_linear_v = nn.Linear(r, self.dim, bias=False)
72 | self.w_As.append(w_a_linear_q)
73 | self.w_Bs.append(w_b_linear_q)
74 | self.w_As.append(w_a_linear_v)
75 | self.w_Bs.append(w_b_linear_v)
76 | block.attn.qkv = LoRA(
77 | w_qkv_linear,
78 | w_a_linear_q,
79 | w_b_linear_q,
80 | w_a_linear_v,
81 | w_b_linear_v,
82 | )
83 |
84 | self.reset_parameters()
85 | self.sam = sam_model
86 |
87 | def reset_parameters(self):
88 | for w_A in self.w_As:
89 | nn.init.kaiming_uniform_(w_A.weight, a=math.sqrt(5))
90 | for w_B in self.w_Bs:
91 | nn.init.zeros_(w_B.weight)
92 |
93 | def forward(self, batched_input, multimask_output, image_size):
94 | return self.sam(batched_input, multimask_output, image_size)
95 |
96 | def save_lora_parameters(self, filename: str):
97 |         assert filename.endswith('.pt') or filename.endswith('.pth')
98 | num_layer = len(self.w_As)
99 | a_tensors = {f'w_a_{i:03d}': self.w_As[i].weight for i in range(num_layer)}
100 | b_tensors = {f'w_b_{i:03d}': self.w_Bs[i].weight for i in range(num_layer)}
101 | prompt_encoder_tensors = {}
102 | mask_decoder_tensors = {}
103 |
104 | if isinstance(self.sam, torch.nn.DataParallel) or isinstance(self.sam, torch.nn.parallel.DistributedDataParallel):
105 | state_dict = self.sam.module.state_dict()
106 | else:
107 | state_dict = self.sam.state_dict()
108 |
109 | for key, value in state_dict.items():
110 | if 'prompt_encoder' in key:
111 | prompt_encoder_tensors[key] = value
112 | if 'mask_decoder' in key:
113 | mask_decoder_tensors[key] = value
114 |
115 | merged_dict = {**a_tensors, **b_tensors, **prompt_encoder_tensors, **mask_decoder_tensors}
116 | torch.save(merged_dict, filename)
117 |
118 | def load_lora_parameters(self, filename: str):
119 | assert filename.endswith('.pt') or filename.endswith('.pth')
120 |
121 | state_dict = torch.load(filename)
122 |
123 | for i, w_A_linear in enumerate(self.w_As):
124 | saved_key = f'w_a_{i:03d}'
125 | saved_tensor = state_dict[saved_key]
126 | w_A_linear.weight = Parameter(saved_tensor)
127 |
128 | for i, w_B_linear in enumerate(self.w_Bs):
129 | saved_key = f'w_b_{i:03d}'
130 | saved_tensor = state_dict[saved_key]
131 | w_B_linear.weight = Parameter(saved_tensor)
132 |
133 | sam_dict = self.sam.state_dict()
134 | sam_keys = sam_dict.keys()
135 |
136 | prompt_encoder_keys = [k for k in sam_keys if 'prompt_encoder' in k]
137 | prompt_encoder_values = [state_dict[k] for k in prompt_encoder_keys]
138 | prompt_encoder_new_state_dict = {k: v for k, v in zip(prompt_encoder_keys, prompt_encoder_values)}
139 | sam_dict.update(prompt_encoder_new_state_dict)
140 |
141 | mask_decoder_keys = [k for k in sam_keys if 'mask_decoder' in k]
142 | mask_decoder_values = [state_dict[k] for k in mask_decoder_keys]
143 | mask_decoder_new_state_dict = {k: v for k, v in zip(mask_decoder_keys, mask_decoder_values)}
144 | sam_dict.update(mask_decoder_new_state_dict)
145 |
146 | self.sam.load_state_dict(sam_dict)
147 |
148 | def get_parameter_number(model):
149 | total_num = sum(p.numel() for p in model.parameters())
150 | trainable_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
151 | return {'Total': total_num, 'Trainable': trainable_num}
152 |
153 |
--------------------------------------------------------------------------------
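
The wrapper above leaves the pretrained qkv projection untouched and adds a rank-r update only to the query and value slices, i.e. q <- q + B_q A_q x and v <- v + B_v A_v x, with A of shape r x d, B of shape d x r, and B initialized to zero so training starts exactly at the pretrained SAM. A quick way to see how small the trainable set is (ViT-B weights and r=4 are assumptions, and the checkpoint path is a placeholder):

    from segment_anything import sam_model_registry
    from sam_lora_image_encoder import LoRA_Sam, get_parameter_number

    sam = sam_model_registry['vit_b'](checkpoint='sam_vit_b_01ec64.pth')
    lora_sam = LoRA_Sam(sam, r=4)
    print(get_parameter_number(lora_sam))
    # Roughly 1.5e5 trainable parameters (12 blocks x 4 rank-4 matrices of width 768)
    # against ~9e7 total; exact figures depend on the backbone variant.
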
/sam_lora_image_encoder_mask_decoder.py:
--------------------------------------------------------------------------------
1 | from segment_anything import build_sam, SamPredictor
2 | from segment_anything import sam_model_registry
3 |
4 | import math
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from torch import Tensor
9 | from torch.nn.parameter import Parameter
10 | from segment_anything.modeling import Sam
11 | from safetensors import safe_open
12 | from safetensors.torch import save_file
13 |
14 | from icecream import ic
15 |
16 |
17 | class _LoRA_qkv(nn.Module):
18 | """In Sam it is implemented as
19 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
20 | B, N, C = x.shape
21 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
22 | q, k, v = qkv.unbind(0)
23 | """
24 |
25 | def __init__(
26 | self,
27 | qkv: nn.Module,
28 | linear_a_q: nn.Module,
29 | linear_b_q: nn.Module,
30 | linear_a_v: nn.Module,
31 | linear_b_v: nn.Module,
32 | ):
33 | super().__init__()
34 | self.qkv = qkv
35 | self.linear_a_q = linear_a_q
36 | self.linear_b_q = linear_b_q
37 | self.linear_a_v = linear_a_v
38 | self.linear_b_v = linear_b_v
39 | self.dim = qkv.in_features
40 | self.w_identity = torch.eye(qkv.in_features)
41 |
42 | def forward(self, x):
43 | qkv = self.qkv(x) # B,N,N,3*org_C
44 | new_q = self.linear_b_q(self.linear_a_q(x))
45 | new_v = self.linear_b_v(self.linear_a_v(x))
46 | qkv[:, :, :, : self.dim] += new_q
47 | qkv[:, :, :, -self.dim:] += new_v
48 | return qkv
49 |
50 |
51 | class _LoRA_qkv_proj(nn.Module):
52 | def __init__(self, proj: nn.Module, w_a: nn.Module, w_b: nn.Module):
53 | super().__init__()
54 | self.proj = proj
55 | self.w_a = w_a
56 | self.w_b = w_b
57 |
58 | def forward(self, x):
59 | out = self.proj(x) + self.w_b(self.w_a(x))
60 | return out
61 |
62 |
63 | class LoRA_Sam(nn.Module):
64 | """Applies low-rank adaptation to a Sam model's image encoder.
65 |
66 | Args:
67 | sam_model: a vision transformer model, see base_vit.py
68 | r: rank of LoRA
69 | num_classes: how many classes the model output, default to the vit model
70 | lora_layer: which layer we apply LoRA.
71 |
72 | Examples::
73 | >>> model = ViT('B_16_imagenet1k')
74 | >>> lora_model = LoRA_ViT(model, r=4)
75 | >>> preds = lora_model(img)
76 | >>> print(preds.shape)
77 | torch.Size([1, 1000])
78 | """
79 |
80 | def __init__(self, sam_model: Sam, r: int, lora_layer=None):
81 | super(LoRA_Sam, self).__init__()
82 |
83 | assert r > 0
84 | # base_vit_dim = sam_model.image_encoder.patch_embed.proj.out_channels
85 | # dim = base_vit_dim
86 | if lora_layer:
87 | self.lora_layer = lora_layer
88 | else:
89 | self.lora_layer = list(
90 | range(len(sam_model.image_encoder.blocks)))
91 | # create for storage, then we can init them or load weights
92 | self.w_As = [] # These are linear layers
93 | self.w_Bs = []
94 |
95 | # lets freeze first
96 | for param in sam_model.image_encoder.parameters():
97 | param.requires_grad = False
98 |
99 | # Here, we do the surgery
100 | for t_layer_i, blk in enumerate(sam_model.image_encoder.blocks):
101 | # If we only want few lora layer instead of all
102 | if t_layer_i not in self.lora_layer:
103 | continue
104 | w_qkv_linear = blk.attn.qkv
105 | self.dim = w_qkv_linear.in_features
106 | w_a_linear_q = nn.Linear(self.dim, r, bias=False)
107 | w_b_linear_q = nn.Linear(r, self.dim, bias=False)
108 | w_a_linear_v = nn.Linear(self.dim, r, bias=False)
109 | w_b_linear_v = nn.Linear(r, self.dim, bias=False)
110 | self.w_As.append(w_a_linear_q)
111 | self.w_Bs.append(w_b_linear_q)
112 | self.w_As.append(w_a_linear_v)
113 | self.w_Bs.append(w_b_linear_v)
114 | blk.attn.qkv = _LoRA_qkv(
115 | w_qkv_linear,
116 | w_a_linear_q,
117 | w_b_linear_q,
118 | w_a_linear_v,
119 | w_b_linear_v,
120 | )
121 |
122 | # Additional surgery for the mask decoder
123 | self.self_attn_As = []
124 | self.self_attn_Bs = []
125 | self.cross_attn_ti_As = []
126 | self.cross_attn_ti_Bs = []
127 | self.cross_attn_it_As = []
128 | self.cross_attn_it_Bs = []
129 |
130 | for param in sam_model.mask_decoder.parameters():
131 | param.requires_grad = False
132 | for param in sam_model.prompt_encoder.parameters():
133 | param.requires_grad = False
134 |
135 | decoder_transformer = sam_model.mask_decoder.transformer
136 | for layer_idx, blk in enumerate(decoder_transformer.layers):
137 | self_attn_q_proj = blk.self_attn.q_proj
138 | self_attn_v_proj = blk.self_attn.v_proj
139 | input_dim = blk.self_attn.embedding_dim
140 | output_dim = blk.self_attn.internal_dim
141 | w_a_linear_q_self_attn = nn.Linear(input_dim, r, bias=False)
142 | w_b_linear_q_self_attn = nn.Linear(r, output_dim, bias=False)
143 | w_a_linear_v_self_attn = nn.Linear(input_dim, r, bias=False)
144 | w_b_linear_v_self_attn = nn.Linear(r, output_dim, bias=False)
145 | self.self_attn_As.append(w_a_linear_q_self_attn)
146 | self.self_attn_Bs.append(w_b_linear_q_self_attn)
147 | self.self_attn_As.append(w_a_linear_v_self_attn)
148 | self.self_attn_Bs.append(w_b_linear_v_self_attn)
149 | blk.self_attn.q_proj = _LoRA_qkv_proj(self_attn_q_proj, w_a_linear_q_self_attn, w_b_linear_q_self_attn)
150 | blk.self_attn.v_proj = _LoRA_qkv_proj(self_attn_v_proj, w_a_linear_v_self_attn, w_b_linear_v_self_attn)
151 |
152 | cross_attn_ti_q_proj = blk.cross_attn_token_to_image.q_proj
153 | cross_attn_ti_v_proj = blk.cross_attn_token_to_image.v_proj
154 | ti_input_dim = blk.cross_attn_token_to_image.embedding_dim
155 | ti_output_dim = blk.cross_attn_token_to_image.internal_dim
156 | w_a_linear_q_cross_attn_ti = nn.Linear(ti_input_dim, r, bias=False)
157 | w_b_linear_q_cross_attn_ti = nn.Linear(r, ti_output_dim, bias=False)
158 | w_a_linear_v_cross_attn_ti = nn.Linear(ti_input_dim, r, bias=False)
159 | w_b_linear_v_cross_attn_ti = nn.Linear(r, ti_output_dim, bias=False)
160 | self.cross_attn_ti_As.append(w_a_linear_q_cross_attn_ti)
161 | self.cross_attn_ti_Bs.append(w_b_linear_q_cross_attn_ti)
162 | self.cross_attn_ti_As.append(w_a_linear_v_cross_attn_ti)
163 | self.cross_attn_ti_Bs.append(w_b_linear_v_cross_attn_ti)
164 | blk.cross_attn_token_to_image.q_proj = _LoRA_qkv_proj(cross_attn_ti_q_proj, w_a_linear_q_cross_attn_ti,
165 | w_b_linear_q_cross_attn_ti)
166 | blk.cross_attn_token_to_image.v_proj = _LoRA_qkv_proj(cross_attn_ti_v_proj, w_a_linear_v_cross_attn_ti,
167 | w_b_linear_v_cross_attn_ti)
168 |
169 | cross_attn_it_q_proj = blk.cross_attn_image_to_token.q_proj
170 | cross_attn_it_v_proj = blk.cross_attn_image_to_token.v_proj
171 | it_input_dim = blk.cross_attn_image_to_token.embedding_dim
172 | it_output_dim = blk.cross_attn_image_to_token.internal_dim
173 | w_a_linear_q_cross_attn_it = nn.Linear(it_input_dim, r, bias=False)
174 | w_b_linear_q_cross_attn_it = nn.Linear(r, it_output_dim, bias=False)
175 | w_a_linear_v_cross_attn_it = nn.Linear(it_input_dim, r, bias=False)
176 | w_b_linear_v_cross_attn_it = nn.Linear(r, it_output_dim, bias=False)
177 | self.cross_attn_it_As.append(w_a_linear_q_cross_attn_it)
178 | self.cross_attn_it_Bs.append(w_b_linear_q_cross_attn_it)
179 | self.cross_attn_it_As.append(w_a_linear_v_cross_attn_it)
180 | self.cross_attn_it_Bs.append(w_b_linear_v_cross_attn_it)
181 | blk.cross_attn_image_to_token.q_proj = _LoRA_qkv_proj(cross_attn_it_q_proj, w_a_linear_q_cross_attn_it,
182 | w_b_linear_q_cross_attn_it)
183 | blk.cross_attn_image_to_token.v_proj = _LoRA_qkv_proj(cross_attn_it_v_proj, w_a_linear_v_cross_attn_it,
184 | w_b_linear_v_cross_attn_it)
185 |
186 | # final attention token to image
187 | block = decoder_transformer.final_attn_token_to_image
188 | fa_ti_q_proj = block.q_proj
189 | fa_ti_v_proj = block.v_proj
190 | in_dim, out_dim = block.embedding_dim, block.internal_dim
191 | self.fa_ti_q_proj_A = nn.Linear(in_dim, r, bias=False)
192 | self.fa_ti_q_proj_B = nn.Linear(r, out_dim, bias=False)
193 | self.fa_ti_v_proj_A = nn.Linear(in_dim, r, bias=False)
194 | self.fa_ti_v_proj_B = nn.Linear(r, out_dim, bias=False)
195 | # block.q_proj = _LoRA_qkv_proj(fa_ti_q_proj, self.fa_ti_q_proj_A, self.fa_ti_q_proj_B)
196 | # block.v_proj = _LoRA_qkv_proj(fa_ti_v_proj, self.fa_ti_v_proj_A, self.fa_ti_v_proj_B)
197 | block.q_proj = _LoRA_qkv_proj(fa_ti_q_proj, self.fa_ti_q_proj_A, self.fa_ti_q_proj_B)
198 | block.v_proj = _LoRA_qkv_proj(fa_ti_v_proj, self.fa_ti_v_proj_A, self.fa_ti_v_proj_B)
199 |
200 | self.reset_parameters()
201 | self.sam = sam_model
202 |
203 | def save_lora_parameters(self, filename: str) -> None:
204 | r"""Only safetensors is supported now.
205 |
206 | pip install safetensor if you do not have one installed yet.
207 |
208 | save both lora and fc parameters.
209 | """
210 |
211 | assert filename.endswith(".pt") or filename.endswith('.pth')
212 |
213 |         num_layer = len(self.w_As)  # two A matrices (q and v) per adapted encoder block
214 | a_tensors = {f"w_a_{i:03d}": self.w_As[i].weight for i in range(num_layer)}
215 | b_tensors = {f"w_b_{i:03d}": self.w_Bs[i].weight for i in range(num_layer)}
216 | sa_a_tensors = {f"sa_a_{i:03d}": self.self_attn_As[i].weight for i in range(len(self.self_attn_As))}
217 | sa_b_tensors = {f"sa_b_{i:03d}": self.self_attn_Bs[i].weight for i in range(len(self.self_attn_Bs))}
218 | cti_a_tensors = {f"cti_a_{i:03d}": self.cross_attn_ti_As[i].weight for i in range(len(self.cross_attn_ti_As))}
219 | cti_b_tensors = {f"cti_b_{i:03d}": self.cross_attn_ti_Bs[i].weight for i in range(len(self.cross_attn_ti_Bs))}
220 | cit_a_tensors = {f"cit_a_{i:03d}": self.cross_attn_it_As[i].weight for i in range(len(self.cross_attn_it_As))}
221 | cit_b_tensors = {f"cit_b_{i:03d}": self.cross_attn_it_Bs[i].weight for i in range(len(self.cross_attn_it_Bs))}
222 | fa_ti_tensors = {'fati_qa': self.fa_ti_q_proj_A.weight, 'fati_qb': self.fa_ti_q_proj_B.weight,
223 | 'fati_va': self.fa_ti_v_proj_A.weight,
224 | 'fati_vb': self.fa_ti_v_proj_B.weight}
225 | prompt_encoder_tensors = {}
226 | mask_decoder_tensors = {}
227 |
228 | # save prompt encoder, only `state_dict`, the `named_parameter` is not permitted
229 | if isinstance(self.sam, torch.nn.DataParallel) or isinstance(self.sam,
230 | torch.nn.parallel.DistributedDataParallel):
231 | state_dict = self.sam.module.state_dict()
232 | else:
233 | state_dict = self.sam.state_dict()
234 | for key, value in state_dict.items():
235 | if 'prompt_encoder' in key:
236 | prompt_encoder_tensors[key] = value
237 | if 'mask_decoder' in key and 'transformer' not in key:
238 | mask_decoder_tensors[key] = value
239 |
240 | merged_dict = {**a_tensors, **b_tensors, **sa_a_tensors, **sa_b_tensors, **cti_a_tensors, **cti_b_tensors,
241 | **cit_a_tensors, **cit_b_tensors, **prompt_encoder_tensors, **mask_decoder_tensors,
242 | **fa_ti_tensors}
243 | torch.save(merged_dict, filename)
244 |
245 | def load_lora_parameters(self, filename: str) -> None:
246 | r"""Only safetensors is supported now.
247 |
248 | pip install safetensor if you do not have one installed yet.\
249 |
250 | load both lora and fc parameters.
251 | """
252 |
253 | assert filename.endswith(".pt") or filename.endswith('.pth')
254 |
255 | state_dict = torch.load(filename)
256 |
257 | for i, w_A_linear in enumerate(self.w_As):
258 | saved_key = f"w_a_{i:03d}"
259 | saved_tensor = state_dict[saved_key]
260 | w_A_linear.weight = Parameter(saved_tensor)
261 |
262 | for i, w_B_linear in enumerate(self.w_Bs):
263 | saved_key = f"w_b_{i:03d}"
264 | saved_tensor = state_dict[saved_key]
265 | w_B_linear.weight = Parameter(saved_tensor)
266 |
267 | for i, sa_A_linear in enumerate(self.self_attn_As):
268 | saved_key = f"sa_a_{i:03d}"
269 | saved_tensor = state_dict[saved_key]
270 | sa_A_linear.weight = Parameter(saved_tensor)
271 |
272 | for i, sa_B_linear in enumerate(self.self_attn_Bs):
273 | saved_key = f"sa_b_{i:03d}"
274 | saved_tensor = state_dict[saved_key]
275 | sa_B_linear.weight = Parameter(saved_tensor)
276 |
277 | for i, cti_a_linear in enumerate(self.cross_attn_ti_As):
278 | saved_key = f"cti_a_{i:03d}"
279 | saved_tensor = state_dict[saved_key]
280 | cti_a_linear.weight = Parameter(saved_tensor)
281 |
282 | for i, cti_b_linear in enumerate(self.cross_attn_ti_Bs):
283 | saved_key = f"cti_b_{i:03d}"
284 | saved_tensor = state_dict[saved_key]
285 | cti_b_linear.weight = Parameter(saved_tensor)
286 |
287 | for i, cit_a_linear in enumerate(self.cross_attn_it_As):
288 | saved_key = f"cit_a_{i:03d}"
289 | saved_tensor = state_dict[saved_key]
290 | cit_a_linear.weight = Parameter(saved_tensor)
291 |
292 | for i, cit_b_linear in enumerate(self.cross_attn_it_Bs):
293 | saved_key = f"cit_b_{i:03d}"
294 | saved_tensor = state_dict[saved_key]
295 | cit_b_linear.weight = Parameter(saved_tensor)
296 |
297 | self.fa_ti_q_proj_A.weight = Parameter(state_dict["fati_qa"])
298 | self.fa_ti_q_proj_B.weight = Parameter(state_dict["fati_qb"])
299 | self.fa_ti_v_proj_A.weight = Parameter(state_dict["fati_va"])
300 | self.fa_ti_v_proj_B.weight = Parameter(state_dict["fati_vb"])
301 |
302 | sam_dict = self.sam.state_dict()
303 | sam_keys = sam_dict.keys()
304 |
305 | # load prompt encoder
306 | prompt_encoder_keys = [k for k in sam_keys if 'prompt_encoder' in k]
307 | prompt_encoder_values = [state_dict[k] for k in prompt_encoder_keys]
308 | prompt_encoder_new_state_dict = {k: v for k, v in zip(prompt_encoder_keys, prompt_encoder_values)}
309 | sam_dict.update(prompt_encoder_new_state_dict)
310 |
311 | # load mask decoder
312 | mask_decoder_keys = [k for k in sam_keys if 'mask_decoder' in k and 'transformer' not in k]
313 | mask_decoder_values = [state_dict[k] for k in mask_decoder_keys]
314 | mask_decoder_new_state_dict = {k: v for k, v in zip(mask_decoder_keys, mask_decoder_values)}
315 | sam_dict.update(mask_decoder_new_state_dict)
316 | self.sam.load_state_dict(sam_dict)
317 |
318 | def reset_parameters(self) -> None:
319 | for w_A in self.w_As:
320 | nn.init.kaiming_uniform_(w_A.weight, a=math.sqrt(5))
321 | for w_B in self.w_Bs:
322 | nn.init.zeros_(w_B.weight)
323 | for w_A in self.self_attn_As:
324 | nn.init.kaiming_uniform_(w_A.weight, a=math.sqrt(5))
325 | for w_B in self.self_attn_Bs:
326 | nn.init.zeros_(w_B.weight)
327 | for w_A in self.cross_attn_ti_As:
328 | nn.init.kaiming_uniform_(w_A.weight, a=math.sqrt(5))
329 | for w_B in self.cross_attn_ti_Bs:
330 | nn.init.zeros_(w_B.weight)
331 | for w_A in self.cross_attn_it_As:
332 | nn.init.kaiming_uniform_(w_A.weight, a=math.sqrt(5))
333 | for w_B in self.cross_attn_it_Bs:
334 | nn.init.zeros_(w_B.weight)
335 | nn.init.kaiming_uniform_(self.fa_ti_q_proj_A.weight, a=math.sqrt(5))
336 | nn.init.zeros_(self.fa_ti_q_proj_B.weight)
337 | nn.init.kaiming_uniform_(self.fa_ti_v_proj_A.weight, a=math.sqrt(5))
338 | nn.init.zeros_(self.fa_ti_v_proj_B.weight)
339 |
340 | def forward(self, batched_input, multimask_output):
341 | # image_embedding = self.sam.image_encoder(batched_input)
342 | # sparse_embeddings, dense_embeddings = self.model.sam.prompt_encoder(
343 | # points=None,
344 | # boxes=None,
345 | # masks=None,
346 | # )
347 | return self.sam(batched_input, multimask_output)
348 |
349 | # def forward(self, x: Tensor) -> Tensor:
350 | # return self.lora_vit(x)
351 |
352 | def get_parameter_number(model):
353 | total_num = sum(p.numel() for p in model.parameters())
354 | trainable_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
355 | return {'Total': total_num, 'Trainable': trainable_num}
356 |
357 |
358 |
--------------------------------------------------------------------------------
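
save_lora_parameters above serializes only the LoRA A/B matrices plus the prompt-encoder and non-transformer mask-decoder weights, so a SamLP checkpoint stays small and the full SAM weights never need to be re-saved; load_lora_parameters restores them into a freshly wrapped model built with the same rank. A hedged round-trip sketch (paths and r are placeholders):

    from segment_anything import sam_model_registry
    from sam_lora_image_encoder_mask_decoder import LoRA_Sam

    sam = sam_model_registry['vit_b'](checkpoint='sam_vit_b_01ec64.pth')
    lora_sam = LoRA_Sam(sam, r=4)
    # ... fine-tune on UFPR-ALPR via train.py / trainer.py ...
    lora_sam.save_lora_parameters('samlp_lora.pth')

    # Later: rebuild the frozen backbone and reload only the adapted weights.
    sam_again = sam_model_registry['vit_b'](checkpoint='sam_vit_b_01ec64.pth')
    lora_sam_again = LoRA_Sam(sam_again, r=4)  # rank must match the saved checkpoint
    lora_sam_again.load_lora_parameters('samlp_lora.pth')
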
/segment_anything/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from .build_sam import (
8 | build_sam,
9 | build_sam_vit_h,
10 | build_sam_vit_l,
11 | build_sam_vit_b,
12 | sam_model_registry,
13 | )
14 | from .predictor import SamPredictor
15 | from .automatic_mask_generator import SamAutomaticMaskGenerator
16 |
--------------------------------------------------------------------------------
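
The package re-exports the stock SAM interfaces unchanged; in the upstream release the registry keys are 'default', 'vit_h', 'vit_l' and 'vit_b' (assumed identical here). For reference, the automatic, grid-prompted mask generator implemented in the next file is used like this (checkpoint and image paths are placeholders):

    import cv2
    from segment_anything import sam_model_registry, SamAutomaticMaskGenerator

    sam = sam_model_registry['vit_b'](checkpoint='sam_vit_b_01ec64.pth')
    mask_generator = SamAutomaticMaskGenerator(sam, points_per_side=32, pred_iou_thresh=0.88)

    image = cv2.cvtColor(cv2.imread('frame.png'), cv2.COLOR_BGR2RGB)  # HWC uint8, RGB
    masks = mask_generator.generate(image)
    # Each record carries 'segmentation', 'bbox' (XYWH), 'area', 'predicted_iou',
    # 'point_coords', 'stability_score' and 'crop_box'.
    print(len(masks))
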
/segment_anything/automatic_mask_generator.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import numpy as np
8 | import torch
9 | from torchvision.ops.boxes import batched_nms, box_area # type: ignore
10 |
11 | from typing import Any, Dict, List, Optional, Tuple
12 |
13 | from .modeling import Sam
14 | from .predictor import SamPredictor
15 | from .utils.amg import (
16 | MaskData,
17 | area_from_rle,
18 | batch_iterator,
19 | batched_mask_to_box,
20 | box_xyxy_to_xywh,
21 | build_all_layer_point_grids,
22 | calculate_stability_score,
23 | coco_encode_rle,
24 | generate_crop_boxes,
25 | is_box_near_crop_edge,
26 | mask_to_rle_pytorch,
27 | remove_small_regions,
28 | rle_to_mask,
29 | uncrop_boxes_xyxy,
30 | uncrop_masks,
31 | uncrop_points,
32 | )
33 |
34 |
35 | class SamAutomaticMaskGenerator:
36 | def __init__(
37 | self,
38 | model: Sam,
39 | points_per_side: Optional[int] = 32,
40 | points_per_batch: int = 64,
41 | pred_iou_thresh: float = 0.88,
42 | stability_score_thresh: float = 0.95,
43 | stability_score_offset: float = 1.0,
44 | box_nms_thresh: float = 0.7,
45 | crop_n_layers: int = 0,
46 | crop_nms_thresh: float = 0.7,
47 | crop_overlap_ratio: float = 512 / 1500,
48 | crop_n_points_downscale_factor: int = 1,
49 | point_grids: Optional[List[np.ndarray]] = None,
50 | min_mask_region_area: int = 0,
51 | output_mode: str = "binary_mask",
52 | ) -> None:
53 | """
54 | Using a SAM model, generates masks for the entire image.
55 | Generates a grid of point prompts over the image, then filters
56 | low quality and duplicate masks. The default settings are chosen
57 | for SAM with a ViT-H backbone.
58 |
59 | Arguments:
60 | model (Sam): The SAM model to use for mask prediction.
61 | points_per_side (int or None): The number of points to be sampled
62 | along one side of the image. The total number of points is
63 | points_per_side**2. If None, 'point_grids' must provide explicit
64 | point sampling.
65 | points_per_batch (int): Sets the number of points run simultaneously
66 | by the model. Higher numbers may be faster but use more GPU memory.
67 | pred_iou_thresh (float): A filtering threshold in [0,1], using the
68 | model's predicted mask quality.
69 | stability_score_thresh (float): A filtering threshold in [0,1], using
70 | the stability of the mask under changes to the cutoff used to binarize
71 | the model's mask predictions.
72 | stability_score_offset (float): The amount to shift the cutoff when
73 |             calculating the stability score.
74 | box_nms_thresh (float): The box IoU cutoff used by non-maximal
75 | suppression to filter duplicate masks.
76 | crop_n_layers (int): If >0, mask prediction will be run again on
77 | crops of the image. Sets the number of layers to run, where each
78 | layer has 2**i_layer number of image crops.
79 | crop_nms_thresh (float): The box IoU cutoff used by non-maximal
80 | suppression to filter duplicate masks between different crops.
81 | crop_overlap_ratio (float): Sets the degree to which crops overlap.
82 | In the first crop layer, crops will overlap by this fraction of
83 | the image length. Later layers with more crops scale down this overlap.
84 | crop_n_points_downscale_factor (int): The number of points-per-side
85 | sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
86 | point_grids (list(np.ndarray) or None): A list over explicit grids
87 | of points used for sampling, normalized to [0,1]. The nth grid in the
88 | list is used in the nth crop layer. Exclusive with points_per_side.
89 | min_mask_region_area (int): If >0, postprocessing will be applied
90 | to remove disconnected regions and holes in masks with area smaller
91 | than min_mask_region_area. Requires opencv.
92 | output_mode (str): The form masks are returned in. Can be 'binary_mask',
93 | 'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools.
94 | For large resolutions, 'binary_mask' may consume large amounts of
95 | memory.
96 | """
97 |
98 | assert (points_per_side is None) != (
99 | point_grids is None
100 | ), "Exactly one of points_per_side or point_grid must be provided."
101 | if points_per_side is not None:
102 | self.point_grids = build_all_layer_point_grids(
103 | points_per_side,
104 | crop_n_layers,
105 | crop_n_points_downscale_factor,
106 | )
107 | elif point_grids is not None:
108 | self.point_grids = point_grids
109 | else:
110 | raise ValueError("Can't have both points_per_side and point_grid be None.")
111 |
112 | assert output_mode in [
113 | "binary_mask",
114 | "uncompressed_rle",
115 | "coco_rle",
116 | ], f"Unknown output_mode {output_mode}."
117 | if output_mode == "coco_rle":
118 | from pycocotools import mask as mask_utils # type: ignore # noqa: F401
119 |
120 | if min_mask_region_area > 0:
121 | import cv2 # type: ignore # noqa: F401
122 |
123 | self.predictor = SamPredictor(model)
124 | self.points_per_batch = points_per_batch
125 | self.pred_iou_thresh = pred_iou_thresh
126 | self.stability_score_thresh = stability_score_thresh
127 | self.stability_score_offset = stability_score_offset
128 | self.box_nms_thresh = box_nms_thresh
129 | self.crop_n_layers = crop_n_layers
130 | self.crop_nms_thresh = crop_nms_thresh
131 | self.crop_overlap_ratio = crop_overlap_ratio
132 | self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
133 | self.min_mask_region_area = min_mask_region_area
134 | self.output_mode = output_mode
135 |
136 | @torch.no_grad()
137 | def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
138 | """
139 | Generates masks for the given image.
140 |
141 | Arguments:
142 | image (np.ndarray): The image to generate masks for, in HWC uint8 format.
143 |
144 | Returns:
145 | list(dict(str, any)): A list over records for masks. Each record is
146 | a dict containing the following keys:
147 | segmentation (dict(str, any) or np.ndarray): The mask. If
148 | output_mode='binary_mask', is an array of shape HW. Otherwise,
149 | is a dictionary containing the RLE.
150 | bbox (list(float)): The box around the mask, in XYWH format.
151 | area (int): The area in pixels of the mask.
152 | predicted_iou (float): The model's own prediction of the mask's
153 | quality. This is filtered by the pred_iou_thresh parameter.
154 | point_coords (list(list(float))): The point coordinates input
155 | to the model to generate this mask.
156 | stability_score (float): A measure of the mask's quality. This
157 | is filtered on using the stability_score_thresh parameter.
158 | crop_box (list(float)): The crop of the image used to generate
159 | the mask, given in XYWH format.
160 | """
161 |
162 | # Generate masks
163 | mask_data = self._generate_masks(image)
164 |
165 | # Filter small disconnected regions and holes in masks
166 | if self.min_mask_region_area > 0:
167 | mask_data = self.postprocess_small_regions(
168 | mask_data,
169 | self.min_mask_region_area,
170 | max(self.box_nms_thresh, self.crop_nms_thresh),
171 | )
172 |
173 | # Encode masks
174 | if self.output_mode == "coco_rle":
175 | mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]]
176 | elif self.output_mode == "binary_mask":
177 | mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
178 | else:
179 | mask_data["segmentations"] = mask_data["rles"]
180 |
181 | # Write mask records
182 | curr_anns = []
183 | for idx in range(len(mask_data["segmentations"])):
184 | ann = {
185 | "segmentation": mask_data["segmentations"][idx],
186 | "area": area_from_rle(mask_data["rles"][idx]),
187 | "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
188 | "predicted_iou": mask_data["iou_preds"][idx].item(),
189 | "point_coords": [mask_data["points"][idx].tolist()],
190 | "stability_score": mask_data["stability_score"][idx].item(),
191 | "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
192 | }
193 | curr_anns.append(ann)
194 |
195 | return curr_anns
196 |
197 | def _generate_masks(self, image: np.ndarray) -> MaskData:
198 | orig_size = image.shape[:2]
199 | crop_boxes, layer_idxs = generate_crop_boxes(
200 | orig_size, self.crop_n_layers, self.crop_overlap_ratio
201 | )
202 |
203 | # Iterate over image crops
204 | data = MaskData()
205 | for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
206 | crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
207 | data.cat(crop_data)
208 |
209 | # Remove duplicate masks between crops
210 | if len(crop_boxes) > 1:
211 | # Prefer masks from smaller crops
212 | scores = 1 / box_area(data["crop_boxes"])
213 | scores = scores.to(data["boxes"].device)
214 | keep_by_nms = batched_nms(
215 | data["boxes"].float(),
216 | scores,
217 | torch.zeros_like(data["boxes"][:, 0]), # categories
218 | iou_threshold=self.crop_nms_thresh,
219 | )
220 | data.filter(keep_by_nms)
221 |
222 | data.to_numpy()
223 | return data
224 |
225 | def _process_crop(
226 | self,
227 | image: np.ndarray,
228 | crop_box: List[int],
229 | crop_layer_idx: int,
230 | orig_size: Tuple[int, ...],
231 | ) -> MaskData:
232 | # Crop the image and calculate embeddings
233 | x0, y0, x1, y1 = crop_box
234 | cropped_im = image[y0:y1, x0:x1, :]
235 | cropped_im_size = cropped_im.shape[:2]
236 | self.predictor.set_image(cropped_im)
237 |
238 | # Get points for this crop
239 | points_scale = np.array(cropped_im_size)[None, ::-1]
240 | points_for_image = self.point_grids[crop_layer_idx] * points_scale
241 |
242 | # Generate masks for this crop in batches
243 | data = MaskData()
244 | for (points,) in batch_iterator(self.points_per_batch, points_for_image):
245 | batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size)
246 | data.cat(batch_data)
247 | del batch_data
248 | self.predictor.reset_image()
249 |
250 | # Remove duplicates within this crop.
251 | keep_by_nms = batched_nms(
252 | data["boxes"].float(),
253 | data["iou_preds"],
254 | torch.zeros_like(data["boxes"][:, 0]), # categories
255 | iou_threshold=self.box_nms_thresh,
256 | )
257 | data.filter(keep_by_nms)
258 |
259 | # Return to the original image frame
260 | data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box)
261 | data["points"] = uncrop_points(data["points"], crop_box)
262 | data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])
263 |
264 | return data
265 |
266 | def _process_batch(
267 | self,
268 | points: np.ndarray,
269 | im_size: Tuple[int, ...],
270 | crop_box: List[int],
271 | orig_size: Tuple[int, ...],
272 | ) -> MaskData:
273 | orig_h, orig_w = orig_size
274 |
275 | # Run model on this batch
276 | transformed_points = self.predictor.transform.apply_coords(points, im_size)
277 | in_points = torch.as_tensor(transformed_points, device=self.predictor.device)
278 | in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
279 | masks, iou_preds, _ = self.predictor.predict_torch(
280 | in_points[:, None, :],
281 | in_labels[:, None],
282 | multimask_output=True,
283 | return_logits=True,
284 | )
285 |
286 | # Serialize predictions and store in MaskData
287 | data = MaskData(
288 | masks=masks.flatten(0, 1),
289 | iou_preds=iou_preds.flatten(0, 1),
290 | points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)),
291 | )
292 | del masks
293 |
294 | # Filter by predicted IoU
295 | if self.pred_iou_thresh > 0.0:
296 | keep_mask = data["iou_preds"] > self.pred_iou_thresh
297 | data.filter(keep_mask)
298 |
299 | # Calculate stability score
300 | data["stability_score"] = calculate_stability_score(
301 | data["masks"], self.predictor.model.mask_threshold, self.stability_score_offset
302 | )
303 | if self.stability_score_thresh > 0.0:
304 | keep_mask = data["stability_score"] >= self.stability_score_thresh
305 | data.filter(keep_mask)
306 |
307 | # Threshold masks and calculate boxes
308 | data["masks"] = data["masks"] > self.predictor.model.mask_threshold
309 | data["boxes"] = batched_mask_to_box(data["masks"])
310 |
311 | # Filter boxes that touch crop boundaries
312 | keep_mask = ~is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h])
313 | if not torch.all(keep_mask):
314 | data.filter(keep_mask)
315 |
316 | # Compress to RLE
317 | data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
318 | data["rles"] = mask_to_rle_pytorch(data["masks"])
319 | del data["masks"]
320 |
321 | return data
322 |
323 | @staticmethod
324 | def postprocess_small_regions(
325 | mask_data: MaskData, min_area: int, nms_thresh: float
326 | ) -> MaskData:
327 | """
328 | Removes small disconnected regions and holes in masks, then reruns
329 | box NMS to remove any new duplicates.
330 |
331 | Edits mask_data in place.
332 |
333 |         Requires OpenCV (cv2) as a dependency.
334 | """
335 | if len(mask_data["rles"]) == 0:
336 | return mask_data
337 |
338 | # Filter small disconnected regions and holes
339 | new_masks = []
340 | scores = []
341 | for rle in mask_data["rles"]:
342 | mask = rle_to_mask(rle)
343 |
344 | mask, changed = remove_small_regions(mask, min_area, mode="holes")
345 | unchanged = not changed
346 | mask, changed = remove_small_regions(mask, min_area, mode="islands")
347 | unchanged = unchanged and not changed
348 |
349 | new_masks.append(torch.as_tensor(mask).unsqueeze(0))
350 | # Give score=0 to changed masks and score=1 to unchanged masks
351 | # so NMS will prefer ones that didn't need postprocessing
352 | scores.append(float(unchanged))
353 |
354 | # Recalculate boxes and remove any new duplicates
355 | masks = torch.cat(new_masks, dim=0)
356 | boxes = batched_mask_to_box(masks)
357 | keep_by_nms = batched_nms(
358 | boxes.float(),
359 | torch.as_tensor(scores),
360 | torch.zeros_like(boxes[:, 0]), # categories
361 | iou_threshold=nms_thresh,
362 | )
363 |
364 | # Only recalculate RLEs for masks that have changed
365 | for i_mask in keep_by_nms:
366 | if scores[i_mask] == 0.0:
367 | mask_torch = masks[i_mask].unsqueeze(0)
368 | mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
369 | mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly
370 | mask_data.filter(keep_by_nms)
371 |
372 | return mask_data
373 |
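374 | # Usage sketch: how the records returned by generate() are typically consumed.
375 | # Assumptions not shown in this file: the generator class defined above is exported from
376 | # the package as SamAutomaticMaskGenerator, as in the upstream SAM repository, and
377 | # "weights/sam_vit_b.pth" / "example.jpg" are placeholder paths for a downloaded
378 | # checkpoint and a test image.
379 | if __name__ == "__main__":
380 |     import cv2
381 | 
382 |     from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
383 | 
384 |     sam = sam_model_registry["vit_b"](checkpoint="weights/sam_vit_b.pth")
385 |     generator = SamAutomaticMaskGenerator(sam)  # default output_mode="binary_mask"
386 | 
387 |     image = cv2.cvtColor(cv2.imread("example.jpg"), cv2.COLOR_BGR2RGB)  # HxWxC uint8 RGB
388 |     for ann in generator.generate(image):
389 |         x, y, w, h = ann["bbox"]            # XYWH box around the mask
390 |         binary_mask = ann["segmentation"]   # HxW array under output_mode="binary_mask"
391 |         print(ann["area"], ann["predicted_iou"], ann["stability_score"])
392 | 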
--------------------------------------------------------------------------------
/segment_anything/build_sam.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 |
9 | from functools import partial
10 |
11 | from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer
12 |
13 |
14 | def build_sam_vit_h(checkpoint=None):
15 | return _build_sam(
16 | encoder_embed_dim=1280,
17 | encoder_depth=32,
18 | encoder_num_heads=16,
19 | encoder_global_attn_indexes=[7, 15, 23, 31],
20 | checkpoint=checkpoint,
21 | )
22 |
23 |
24 | build_sam = build_sam_vit_h
25 |
26 |
27 | def build_sam_vit_l(checkpoint=None):
28 | return _build_sam(
29 | encoder_embed_dim=1024,
30 | encoder_depth=24,
31 | encoder_num_heads=16,
32 | encoder_global_attn_indexes=[5, 11, 17, 23],
33 | checkpoint=checkpoint,
34 | )
35 |
36 |
37 | def build_sam_vit_b(checkpoint=None):
38 | return _build_sam(
39 | encoder_embed_dim=768,
40 | encoder_depth=12,
41 | encoder_num_heads=12,
42 | encoder_global_attn_indexes=[2, 5, 8, 11],
43 | checkpoint=checkpoint,
44 | )
45 |
46 |
47 | sam_model_registry = {
48 | "default": build_sam_vit_h,
49 | "vit_h": build_sam_vit_h,
50 | "vit_l": build_sam_vit_l,
51 | "vit_b": build_sam_vit_b,
52 | }
53 |
54 |
55 | def _build_sam(
56 | encoder_embed_dim,
57 | encoder_depth,
58 | encoder_num_heads,
59 | encoder_global_attn_indexes,
60 | checkpoint=None,
61 | ):
62 | prompt_embed_dim = 256
63 | image_size = 1024
64 | vit_patch_size = 16
65 | image_embedding_size = image_size // vit_patch_size
66 | sam = Sam(
67 | image_encoder=ImageEncoderViT(
68 | depth=encoder_depth,
69 | embed_dim=encoder_embed_dim,
70 | img_size=image_size,
71 | mlp_ratio=4,
72 | norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
73 | num_heads=encoder_num_heads,
74 | patch_size=vit_patch_size,
75 | qkv_bias=True,
76 | use_rel_pos=True,
77 | global_attn_indexes=encoder_global_attn_indexes,
78 | window_size=14,
79 | out_chans=prompt_embed_dim,
80 | ),
81 | prompt_encoder=PromptEncoder(
82 | embed_dim=prompt_embed_dim,
83 | image_embedding_size=(image_embedding_size, image_embedding_size),
84 | input_image_size=(image_size, image_size),
85 | mask_in_chans=16,
86 | ),
87 | mask_decoder=MaskDecoder(
88 | num_multimask_outputs=3,
89 | transformer=TwoWayTransformer(
90 | depth=2,
91 | embedding_dim=prompt_embed_dim,
92 | mlp_dim=2048,
93 | num_heads=8,
94 | ),
95 | transformer_dim=prompt_embed_dim,
96 | iou_head_depth=3,
97 | iou_head_hidden_dim=256,
98 | ),
99 | pixel_mean=[123.675, 116.28, 103.53],
100 | pixel_std=[58.395, 57.12, 57.375],
101 | )
102 | sam.eval()
103 | if checkpoint is not None:
104 | with open(checkpoint, "rb") as f:
105 | state_dict = torch.load(f)
106 | sam.load_state_dict(state_dict)
107 | return sam
108 |
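109 | # Usage sketch: sam_model_registry maps a model key to its builder. Passing
110 | # checkpoint=None builds a randomly initialized model (handy for a quick shape or
111 | # parameter-count check); pass the path of a downloaded .pth file to load SAM weights.
112 | if __name__ == "__main__":
113 |     sam = sam_model_registry["vit_b"](checkpoint=None)
114 |     n_params = sum(p.numel() for p in sam.parameters())
115 |     print(f"ViT-B SAM parameters: {n_params / 1e6:.1f}M")
116 | 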
--------------------------------------------------------------------------------
/segment_anything/modeling/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from .sam import Sam
8 | from .image_encoder import ImageEncoderViT
9 | from .mask_decoder import MaskDecoder
10 | from .prompt_encoder import PromptEncoder
11 | from .transformer import TwoWayTransformer
12 |
--------------------------------------------------------------------------------
/segment_anything/modeling/common.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | import torch.nn as nn
9 |
10 | from typing import Type
11 |
12 |
13 | class MLPBlock(nn.Module):
14 | def __init__(
15 | self,
16 | embedding_dim: int,
17 | mlp_dim: int,
18 | act: Type[nn.Module] = nn.GELU,
19 | ) -> None:
20 | super().__init__()
21 | self.lin1 = nn.Linear(embedding_dim, mlp_dim)
22 | self.lin2 = nn.Linear(mlp_dim, embedding_dim)
23 | self.act = act()
24 |
25 | def forward(self, x: torch.Tensor) -> torch.Tensor:
26 | return self.lin2(self.act(self.lin1(x)))
27 |
28 |
29 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
30 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
31 | class LayerNorm2d(nn.Module):
32 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
33 | super().__init__()
34 | self.weight = nn.Parameter(torch.ones(num_channels))
35 | self.bias = nn.Parameter(torch.zeros(num_channels))
36 | self.eps = eps
37 |
38 | def forward(self, x: torch.Tensor) -> torch.Tensor:
39 | u = x.mean(1, keepdim=True)
40 | s = (x - u).pow(2).mean(1, keepdim=True)
41 | x = (x - u) / torch.sqrt(s + self.eps)
42 | x = self.weight[:, None, None] * x + self.bias[:, None, None]
43 | return x
44 |
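45 | # Quick shape check (a minimal sketch): LayerNorm2d normalizes over the channel
46 | # dimension of an NCHW tensor, unlike nn.LayerNorm, which expects channels-last input.
47 | if __name__ == "__main__":
48 |     x = torch.randn(2, 256, 64, 64)
49 |     y = LayerNorm2d(256)(x)
50 |     print(y.shape)                    # torch.Size([2, 256, 64, 64])
51 |     print(y.mean(dim=1).abs().max())  # per-position channel mean is ~0 after normalization
52 | 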
--------------------------------------------------------------------------------
/segment_anything/modeling/image_encoder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 |
11 | from typing import Optional, Tuple, Type
12 |
13 | from .common import LayerNorm2d, MLPBlock
14 |
15 |
16 | # This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa
17 | class ImageEncoderViT(nn.Module):
18 | def __init__(
19 | self,
20 | img_size: int = 1024,
21 | patch_size: int = 16,
22 | in_chans: int = 3,
23 | embed_dim: int = 768,
24 | depth: int = 12,
25 | num_heads: int = 12,
26 | mlp_ratio: float = 4.0,
27 | out_chans: int = 256,
28 | qkv_bias: bool = True,
29 | norm_layer: Type[nn.Module] = nn.LayerNorm,
30 | act_layer: Type[nn.Module] = nn.GELU,
31 | use_abs_pos: bool = True,
32 | use_rel_pos: bool = False,
33 | rel_pos_zero_init: bool = True,
34 | window_size: int = 0,
35 | global_attn_indexes: Tuple[int, ...] = (),
36 | ) -> None:
37 | """
38 | Args:
39 | img_size (int): Input image size.
40 | patch_size (int): Patch size.
41 | in_chans (int): Number of input image channels.
42 | embed_dim (int): Patch embedding dimension.
43 | depth (int): Depth of ViT.
44 | num_heads (int): Number of attention heads in each ViT block.
45 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
46 | qkv_bias (bool): If True, add a learnable bias to query, key, value.
47 | norm_layer (nn.Module): Normalization layer.
48 | act_layer (nn.Module): Activation layer.
49 | use_abs_pos (bool): If True, use absolute positional embeddings.
50 | use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
51 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
52 | window_size (int): Window size for window attention blocks.
53 | global_attn_indexes (list): Indexes for blocks using global attention.
54 | """
55 | super().__init__()
56 | self.img_size = img_size
57 |
58 | self.patch_embed = PatchEmbed(
59 | kernel_size=(patch_size, patch_size),
60 | stride=(patch_size, patch_size),
61 | in_chans=in_chans,
62 | embed_dim=embed_dim,
63 | )
64 |
65 | self.pos_embed: Optional[nn.Parameter] = None
66 | if use_abs_pos:
67 | # Initialize absolute positional embedding with pretrain image size.
68 | self.pos_embed = nn.Parameter(
69 | torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim)
70 | )
71 |
72 | self.blocks = nn.ModuleList()
73 | for i in range(depth):
74 | block = Block(
75 | dim=embed_dim,
76 | num_heads=num_heads,
77 | mlp_ratio=mlp_ratio,
78 | qkv_bias=qkv_bias,
79 | norm_layer=norm_layer,
80 | act_layer=act_layer,
81 | use_rel_pos=use_rel_pos,
82 | rel_pos_zero_init=rel_pos_zero_init,
83 | window_size=window_size if i not in global_attn_indexes else 0,
84 | input_size=(img_size // patch_size, img_size // patch_size),
85 | )
86 | self.blocks.append(block)
87 |
88 | self.neck = nn.Sequential(
89 | nn.Conv2d(
90 | embed_dim,
91 | out_chans,
92 | kernel_size=1,
93 | bias=False,
94 | ),
95 | LayerNorm2d(out_chans),
96 | nn.Conv2d(
97 | out_chans,
98 | out_chans,
99 | kernel_size=3,
100 | padding=1,
101 | bias=False,
102 | ),
103 | LayerNorm2d(out_chans),
104 | )
105 |
106 | def forward(self, x: torch.Tensor) -> torch.Tensor:
107 | x = self.patch_embed(x)
108 | if self.pos_embed is not None:
109 | x = x + self.pos_embed
110 |
111 | for blk in self.blocks:
112 | x = blk(x)
113 |
114 | x = self.neck(x.permute(0, 3, 1, 2))
115 |
116 | return x
117 |
118 |
119 | class Block(nn.Module):
120 | """Transformer blocks with support of window attention and residual propagation blocks"""
121 |
122 | def __init__(
123 | self,
124 | dim: int,
125 | num_heads: int,
126 | mlp_ratio: float = 4.0,
127 | qkv_bias: bool = True,
128 | norm_layer: Type[nn.Module] = nn.LayerNorm,
129 | act_layer: Type[nn.Module] = nn.GELU,
130 | use_rel_pos: bool = False,
131 | rel_pos_zero_init: bool = True,
132 | window_size: int = 0,
133 | input_size: Optional[Tuple[int, int]] = None,
134 | ) -> None:
135 | """
136 | Args:
137 | dim (int): Number of input channels.
138 | num_heads (int): Number of attention heads in each ViT block.
139 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
140 | qkv_bias (bool): If True, add a learnable bias to query, key, value.
141 | norm_layer (nn.Module): Normalization layer.
142 | act_layer (nn.Module): Activation layer.
143 | use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
144 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
145 | window_size (int): Window size for window attention blocks. If it equals 0, then
146 | use global attention.
147 | input_size (tuple(int, int) or None): Input resolution for calculating the relative
148 | positional parameter size.
149 | """
150 | super().__init__()
151 | self.norm1 = norm_layer(dim)
152 | self.attn = Attention(
153 | dim,
154 | num_heads=num_heads,
155 | qkv_bias=qkv_bias,
156 | use_rel_pos=use_rel_pos,
157 | rel_pos_zero_init=rel_pos_zero_init,
158 | input_size=input_size if window_size == 0 else (window_size, window_size),
159 | )
160 |
161 | self.norm2 = norm_layer(dim)
162 | self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)
163 |
164 | self.window_size = window_size
165 |
166 | def forward(self, x: torch.Tensor) -> torch.Tensor:
167 | shortcut = x
168 | x = self.norm1(x)
169 | # Window partition
170 | if self.window_size > 0:
171 | H, W = x.shape[1], x.shape[2]
172 | x, pad_hw = window_partition(x, self.window_size)
173 |
174 | x = self.attn(x)
175 | # Reverse window partition
176 | if self.window_size > 0:
177 | x = window_unpartition(x, self.window_size, pad_hw, (H, W))
178 |
179 | x = shortcut + x
180 | x = x + self.mlp(self.norm2(x))
181 |
182 | return x
183 |
184 |
185 | class Attention(nn.Module):
186 | """Multi-head Attention block with relative position embeddings."""
187 |
188 | def __init__(
189 | self,
190 | dim: int,
191 | num_heads: int = 8,
192 | qkv_bias: bool = True,
193 | use_rel_pos: bool = False,
194 | rel_pos_zero_init: bool = True,
195 | input_size: Optional[Tuple[int, int]] = None,
196 | ) -> None:
197 | """
198 | Args:
199 | dim (int): Number of input channels.
200 | num_heads (int): Number of attention heads.
201 | qkv_bias (bool): If True, add a learnable bias to query, key, value.
202 |             use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
203 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
204 | input_size (tuple(int, int) or None): Input resolution for calculating the relative
205 | positional parameter size.
206 | """
207 | super().__init__()
208 | self.num_heads = num_heads
209 | head_dim = dim // num_heads
210 | self.scale = head_dim**-0.5
211 |
212 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
213 | self.proj = nn.Linear(dim, dim)
214 |
215 | self.use_rel_pos = use_rel_pos
216 | if self.use_rel_pos:
217 | assert (
218 | input_size is not None
219 | ), "Input size must be provided if using relative positional encoding."
220 | # initialize relative positional embeddings
221 | self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
222 | self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
223 |
224 | def forward(self, x: torch.Tensor) -> torch.Tensor:
225 | B, H, W, _ = x.shape
226 | # qkv with shape (3, B, nHead, H * W, C)
227 | qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
228 | # q, k, v with shape (B * nHead, H * W, C)
229 | q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
230 |
231 | attn = (q * self.scale) @ k.transpose(-2, -1)
232 |
233 | if self.use_rel_pos:
234 | attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
235 |
236 | attn = attn.softmax(dim=-1)
237 | x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
238 | x = self.proj(x)
239 |
240 | return x
241 |
242 |
243 | def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
244 | """
245 | Partition into non-overlapping windows with padding if needed.
246 | Args:
247 | x (tensor): input tokens with [B, H, W, C].
248 | window_size (int): window size.
249 |
250 | Returns:
251 | windows: windows after partition with [B * num_windows, window_size, window_size, C].
252 | (Hp, Wp): padded height and width before partition
253 | """
254 | B, H, W, C = x.shape
255 |
256 | pad_h = (window_size - H % window_size) % window_size
257 | pad_w = (window_size - W % window_size) % window_size
258 | if pad_h > 0 or pad_w > 0:
259 | x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
260 | Hp, Wp = H + pad_h, W + pad_w
261 |
262 | x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
263 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
264 | return windows, (Hp, Wp)
265 |
266 |
267 | def window_unpartition(
268 | windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
269 | ) -> torch.Tensor:
270 | """
271 |     Window unpartition into original sequences, removing padding.
272 | Args:
273 | windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
274 | window_size (int): window size.
275 | pad_hw (Tuple): padded height and width (Hp, Wp).
276 | hw (Tuple): original height and width (H, W) before padding.
277 |
278 | Returns:
279 | x: unpartitioned sequences with [B, H, W, C].
280 | """
281 | Hp, Wp = pad_hw
282 | H, W = hw
283 | B = windows.shape[0] // (Hp * Wp // window_size // window_size)
284 | x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
285 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
286 |
287 | if Hp > H or Wp > W:
288 | x = x[:, :H, :W, :].contiguous()
289 | return x
290 |
291 |
292 | def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
293 | """
294 | Get relative positional embeddings according to the relative positions of
295 | query and key sizes.
296 | Args:
297 | q_size (int): size of query q.
298 | k_size (int): size of key k.
299 | rel_pos (Tensor): relative position embeddings (L, C).
300 |
301 | Returns:
302 | Extracted positional embeddings according to relative positions.
303 | """
304 | max_rel_dist = int(2 * max(q_size, k_size) - 1)
305 | # Interpolate rel pos if needed.
306 | if rel_pos.shape[0] != max_rel_dist:
307 | # Interpolate rel pos.
308 | rel_pos_resized = F.interpolate(
309 | rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
310 | size=max_rel_dist,
311 | mode="linear",
312 | )
313 | rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
314 | else:
315 | rel_pos_resized = rel_pos
316 |
317 | # Scale the coords with short length if shapes for q and k are different.
318 | q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
319 | k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
320 | relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
321 |
322 | return rel_pos_resized[relative_coords.long()]
323 |
324 |
325 | def add_decomposed_rel_pos(
326 | attn: torch.Tensor,
327 | q: torch.Tensor,
328 | rel_pos_h: torch.Tensor,
329 | rel_pos_w: torch.Tensor,
330 | q_size: Tuple[int, int],
331 | k_size: Tuple[int, int],
332 | ) -> torch.Tensor:
333 | """
334 | Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
335 | https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
336 | Args:
337 | attn (Tensor): attention map.
338 | q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
339 | rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
340 | rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
341 | q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
342 | k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
343 |
344 | Returns:
345 | attn (Tensor): attention map with added relative positional embeddings.
346 | """
347 | q_h, q_w = q_size
348 | k_h, k_w = k_size
349 | Rh = get_rel_pos(q_h, k_h, rel_pos_h)
350 | Rw = get_rel_pos(q_w, k_w, rel_pos_w)
351 |
352 | B, _, dim = q.shape
353 | r_q = q.reshape(B, q_h, q_w, dim)
354 | rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
355 | rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
356 |
357 | attn = (
358 | attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
359 | ).view(B, q_h * q_w, k_h * k_w)
360 |
361 | return attn
362 |
363 |
364 | class PatchEmbed(nn.Module):
365 | """
366 | Image to Patch Embedding.
367 | """
368 |
369 | def __init__(
370 | self,
371 | kernel_size: Tuple[int, int] = (16, 16),
372 | stride: Tuple[int, int] = (16, 16),
373 | padding: Tuple[int, int] = (0, 0),
374 | in_chans: int = 3,
375 | embed_dim: int = 768,
376 | ) -> None:
377 | """
378 | Args:
379 | kernel_size (Tuple): kernel size of the projection layer.
380 | stride (Tuple): stride of the projection layer.
381 | padding (Tuple): padding size of the projection layer.
382 | in_chans (int): Number of input image channels.
383 | embed_dim (int): Patch embedding dimension.
384 | """
385 | super().__init__()
386 |
387 | self.proj = nn.Conv2d(
388 | in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
389 | )
390 |
391 | def forward(self, x: torch.Tensor) -> torch.Tensor:
392 | x = self.proj(x)
393 | # B C H W -> B H W C
394 | x = x.permute(0, 2, 3, 1)
395 | return x
396 |
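397 | # Sanity-check sketch: window_partition pads the feature map to a multiple of
398 | # window_size and splits it into windows; window_unpartition inverts the operation
399 | # exactly. The tiny (non-SAM) encoder configuration below illustrates the output shape:
400 | # out_chans channels over an (img_size // patch_size) x (img_size // patch_size) grid.
401 | if __name__ == "__main__":
402 |     x = torch.randn(1, 10, 10, 32)  # B x H x W x C, with H and W not divisible by 7
403 |     windows, pad_hw = window_partition(x, window_size=7)
404 |     print(windows.shape)  # torch.Size([4, 7, 7, 32]): 2x2 windows after padding to 14x14
405 |     print(torch.equal(window_unpartition(windows, 7, pad_hw, (10, 10)), x))  # True
406 | 
407 |     tiny_encoder = ImageEncoderViT(img_size=64, patch_size=16, embed_dim=96, depth=2, num_heads=4)
408 |     print(tiny_encoder(torch.randn(1, 3, 64, 64)).shape)  # torch.Size([1, 256, 4, 4])
409 | 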
--------------------------------------------------------------------------------
/segment_anything/modeling/mask_decoder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | from torch import nn
9 | from torch.nn import functional as F
10 |
11 | from typing import List, Tuple, Type
12 |
13 | from .common import LayerNorm2d
14 |
15 |
16 | class MaskDecoder(nn.Module):
17 | def __init__(
18 | self,
19 | *,
20 | transformer_dim: int,
21 | transformer: nn.Module,
22 | num_multimask_outputs: int = 3,
23 | activation: Type[nn.Module] = nn.GELU,
24 | iou_head_depth: int = 3,
25 | iou_head_hidden_dim: int = 256,
26 | ) -> None:
27 | """
28 | Predicts masks given an image and prompt embeddings, using a
29 | transformer architecture.
30 |
31 | Arguments:
32 | transformer_dim (int): the channel dimension of the transformer
33 | transformer (nn.Module): the transformer used to predict masks
34 | num_multimask_outputs (int): the number of masks to predict
35 | when disambiguating masks
36 | activation (nn.Module): the type of activation to use when
37 | upscaling masks
38 | iou_head_depth (int): the depth of the MLP used to predict
39 | mask quality
40 | iou_head_hidden_dim (int): the hidden dimension of the MLP
41 | used to predict mask quality
42 | """
43 | super().__init__()
44 | self.transformer_dim = transformer_dim
45 | self.transformer = transformer
46 |
47 | self.num_multimask_outputs = num_multimask_outputs
48 |
49 | self.iou_token = nn.Embedding(1, transformer_dim)
50 | self.num_mask_tokens = num_multimask_outputs + 1
51 | self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
52 |
53 | self.output_upscaling = nn.Sequential(
54 | nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
55 | LayerNorm2d(transformer_dim // 4),
56 | activation(),
57 | nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
58 | activation(),
59 | )
60 | self.output_hypernetworks_mlps = nn.ModuleList(
61 | [
62 | MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
63 | for i in range(self.num_mask_tokens)
64 | ]
65 | )
66 |
67 | self.iou_prediction_head = MLP(
68 | transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth
69 | )
70 |
71 | def forward(
72 | self,
73 | image_embeddings: torch.Tensor,
74 | image_pe: torch.Tensor,
75 | sparse_prompt_embeddings: torch.Tensor,
76 | dense_prompt_embeddings: torch.Tensor,
77 | multimask_output: bool,
78 | ) -> Tuple[torch.Tensor, torch.Tensor]:
79 | """
80 | Predict masks given image and prompt embeddings.
81 |
82 | Arguments:
83 | image_embeddings (torch.Tensor): the embeddings from the image encoder
84 | image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
85 | sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
86 | dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
87 | multimask_output (bool): Whether to return multiple masks or a single
88 | mask.
89 |
90 | Returns:
91 | torch.Tensor: batched predicted masks
92 | torch.Tensor: batched predictions of mask quality
93 | """
94 | masks, iou_pred = self.predict_masks(
95 | image_embeddings=image_embeddings,
96 | image_pe=image_pe,
97 | sparse_prompt_embeddings=sparse_prompt_embeddings,
98 | dense_prompt_embeddings=dense_prompt_embeddings,
99 | )
100 |
101 | # Select the correct mask or masks for output
102 | if multimask_output:
103 | mask_slice = slice(1, None)
104 | else:
105 | mask_slice = slice(0, 1)
106 | masks = masks[:, mask_slice, :, :]
107 | iou_pred = iou_pred[:, mask_slice]
108 |
109 | # Prepare output
110 | return masks, iou_pred
111 |
112 | def predict_masks(
113 | self,
114 | image_embeddings: torch.Tensor,
115 | image_pe: torch.Tensor,
116 | sparse_prompt_embeddings: torch.Tensor,
117 | dense_prompt_embeddings: torch.Tensor,
118 | ) -> Tuple[torch.Tensor, torch.Tensor]:
119 | """Predicts masks. See 'forward' for more details."""
120 | # Concatenate output tokens
121 | output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
122 | output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
123 | tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
124 |
125 | # Expand per-image data in batch direction to be per-mask
126 | src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
127 | src = src + dense_prompt_embeddings
128 | pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
129 | b, c, h, w = src.shape
130 |
131 | # Run the transformer
132 | hs, src = self.transformer(src, pos_src, tokens)
133 | iou_token_out = hs[:, 0, :]
134 | mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]
135 |
136 |
137 | # Upscale mask embeddings and predict masks using the mask tokens
138 | src = src.transpose(1, 2).view(b, c, h, w)
139 | upscaled_embedding = self.output_upscaling(src)
140 | # print(upscaled_embedding.shape)
141 | hyper_in_list: List[torch.Tensor] = []
142 | for i in range(self.num_mask_tokens):
143 | hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
144 | hyper_in = torch.stack(hyper_in_list, dim=1)
145 | # print(hyper_in.shape)
146 | b, c, h, w = upscaled_embedding.shape
147 | masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
148 |
149 | # Generate mask quality predictions
150 | iou_pred = self.iou_prediction_head(iou_token_out)
151 |
152 | return masks, iou_pred
153 |
154 |
155 | # Lightly adapted from
156 | # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
157 | class MLP(nn.Module):
158 | def __init__(
159 | self,
160 | input_dim: int,
161 | hidden_dim: int,
162 | output_dim: int,
163 | num_layers: int,
164 | sigmoid_output: bool = False,
165 | ) -> None:
166 | super().__init__()
167 | self.num_layers = num_layers
168 | h = [hidden_dim] * (num_layers - 1)
169 | self.layers = nn.ModuleList(
170 | nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
171 | )
172 | self.sigmoid_output = sigmoid_output
173 |
174 | def forward(self, x):
175 | for i, layer in enumerate(self.layers):
176 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
177 | if self.sigmoid_output:
178 | x = F.sigmoid(x)
179 | return x
180 |
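181 | # Shape sketch: with SAM's 256-dim, 64x64 image embeddings, the decoder upscales by 4x and
182 | # predicts num_multimask_outputs + 1 = 4 mask tokens; multimask_output=True returns the last
183 | # three. The random tensors below stand in for real encoder and prompt-encoder outputs.
184 | if __name__ == "__main__":
185 |     from .transformer import TwoWayTransformer
186 | 
187 |     decoder = MaskDecoder(
188 |         transformer_dim=256,
189 |         transformer=TwoWayTransformer(depth=2, embedding_dim=256, mlp_dim=2048, num_heads=8),
190 |     )
191 |     image_embeddings = torch.randn(1, 256, 64, 64)
192 |     image_pe = torch.randn(1, 256, 64, 64)
193 |     sparse_prompts = torch.randn(1, 2, 256)   # e.g. one point embedding plus its padding token
194 |     dense_prompts = torch.randn(1, 256, 64, 64)
195 |     masks, iou_pred = decoder(image_embeddings, image_pe, sparse_prompts, dense_prompts, True)
196 |     print(masks.shape, iou_pred.shape)  # torch.Size([1, 3, 256, 256]) torch.Size([1, 3])
197 | 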
--------------------------------------------------------------------------------
/segment_anything/modeling/prompt_encoder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import numpy as np
8 | import torch
9 | from torch import nn
10 |
11 | from typing import Any, Optional, Tuple, Type
12 |
13 | from .common import LayerNorm2d
14 |
15 |
16 | class PromptEncoder(nn.Module):
17 | def __init__(
18 | self,
19 | embed_dim: int,
20 | image_embedding_size: Tuple[int, int],
21 | input_image_size: Tuple[int, int],
22 | mask_in_chans: int,
23 | activation: Type[nn.Module] = nn.GELU,
24 | ) -> None:
25 | """
26 | Encodes prompts for input to SAM's mask decoder.
27 |
28 | Arguments:
29 | embed_dim (int): The prompts' embedding dimension
30 | image_embedding_size (tuple(int, int)): The spatial size of the
31 | image embedding, as (H, W).
32 |           input_image_size (tuple(int, int)): The padded size of the image as input
33 | to the image encoder, as (H, W).
34 | mask_in_chans (int): The number of hidden channels used for
35 | encoding input masks.
36 | activation (nn.Module): The activation to use when encoding
37 | input masks.
38 | """
39 | super().__init__()
40 | self.embed_dim = embed_dim
41 | self.input_image_size = input_image_size
42 | self.image_embedding_size = image_embedding_size
43 | self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
44 |
45 | self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners
46 | point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)]
47 | self.point_embeddings = nn.ModuleList(point_embeddings)
48 | self.not_a_point_embed = nn.Embedding(1, embed_dim)
49 |
50 | self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1])
51 | self.mask_downscaling = nn.Sequential(
52 | nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
53 | LayerNorm2d(mask_in_chans // 4),
54 | activation(),
55 | nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
56 | LayerNorm2d(mask_in_chans),
57 | activation(),
58 | nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
59 | )
60 | self.no_mask_embed = nn.Embedding(1, embed_dim)
61 |
62 | def get_dense_pe(self) -> torch.Tensor:
63 | """
64 | Returns the positional encoding used to encode point prompts,
65 | applied to a dense set of points the shape of the image encoding.
66 |
67 | Returns:
68 | torch.Tensor: Positional encoding with shape
69 | 1x(embed_dim)x(embedding_h)x(embedding_w)
70 | """
71 | return self.pe_layer(self.image_embedding_size).unsqueeze(0)
72 |
73 | def _embed_points(
74 | self,
75 | points: torch.Tensor,
76 | labels: torch.Tensor,
77 | pad: bool,
78 | ) -> torch.Tensor:
79 | """Embeds point prompts."""
80 | points = points + 0.5 # Shift to center of pixel
81 | if pad:
82 | padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
83 | padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
84 | # print(padding_point.shape)
85 | points = torch.cat([points, padding_point], dim=1)
86 | # print(labels.shape)
87 |
88 | labels = torch.cat([labels, padding_label], dim=1)
89 |
90 | point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)
91 | point_embedding[labels == -1] = 0.0
92 | point_embedding[labels == -1] += self.not_a_point_embed.weight
93 | point_embedding[labels == 0] += self.point_embeddings[0].weight
94 | point_embedding[labels == 1] += self.point_embeddings[1].weight
95 | # print(self.point_embeddings[1].weight.shape)
96 | return point_embedding
97 |
98 | def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
99 | """Embeds box prompts."""
100 | boxes = boxes + 0.5 # Shift to center of pixel
101 | coords = boxes.reshape(-1, 2, 2)
102 | corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)
103 | corner_embedding[:, 0, :] += self.point_embeddings[2].weight
104 | corner_embedding[:, 1, :] += self.point_embeddings[3].weight
105 | return corner_embedding
106 |
107 | def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
108 | """Embeds mask inputs."""
109 | mask_embedding = self.mask_downscaling(masks)
110 | return mask_embedding
111 |
112 | def _get_batch_size(
113 | self,
114 | points: Optional[Tuple[torch.Tensor, torch.Tensor]],
115 | boxes: Optional[torch.Tensor],
116 | masks: Optional[torch.Tensor],
117 | ) -> int:
118 | """
119 | Gets the batch size of the output given the batch size of the input prompts.
120 | """
121 | if points is not None:
122 | return points[0].shape[0]
123 | elif boxes is not None:
124 | return boxes.shape[0]
125 | elif masks is not None:
126 | return masks.shape[0]
127 | else:
128 | return 1
129 |
130 | def _get_device(self) -> torch.device:
131 | return self.point_embeddings[0].weight.device
132 |
133 | def forward(
134 | self,
135 | points: Optional[Tuple[torch.Tensor, torch.Tensor]],
136 | boxes: Optional[torch.Tensor],
137 | masks: Optional[torch.Tensor],
138 | ) -> Tuple[torch.Tensor, torch.Tensor]:
139 | """
140 | Embeds different types of prompts, returning both sparse and dense
141 | embeddings.
142 |
143 | Arguments:
144 | points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates
145 | and labels to embed.
146 | boxes (torch.Tensor or none): boxes to embed
147 | masks (torch.Tensor or none): masks to embed
148 |
149 | Returns:
150 | torch.Tensor: sparse embeddings for the points and boxes, with shape
151 | BxNx(embed_dim), where N is determined by the number of input points
152 | and boxes.
153 | torch.Tensor: dense embeddings for the masks, in the shape
154 | Bx(embed_dim)x(embed_H)x(embed_W)
155 | """
156 | bs = self._get_batch_size(points, boxes, masks)
157 | sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())
158 | if points is not None:
159 | coords, labels = points
160 | point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
161 | sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
162 | if boxes is not None:
163 | box_embeddings = self._embed_boxes(boxes)
164 | sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
165 |
166 | if masks is not None:
167 | dense_embeddings = self._embed_masks(masks)
168 | else:
169 | dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
170 | bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
171 | )
172 |
173 | return sparse_embeddings, dense_embeddings
174 |
175 |
176 | class PositionEmbeddingRandom(nn.Module):
177 | """
178 | Positional encoding using random spatial frequencies.
179 | """
180 |
181 | def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
182 | super().__init__()
183 | if scale is None or scale <= 0.0:
184 | scale = 1.0
185 | self.register_buffer(
186 | "positional_encoding_gaussian_matrix",
187 | scale * torch.randn((2, num_pos_feats)),
188 | )
189 |
190 | def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
191 | """Positionally encode points that are normalized to [0,1]."""
192 | # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
193 | coords = 2 * coords - 1
194 | coords = coords @ self.positional_encoding_gaussian_matrix
195 | coords = 2 * np.pi * coords
196 | # outputs d_1 x ... x d_n x C shape
197 | return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
198 |
199 | def forward(self, size: Tuple[int, int]) -> torch.Tensor:
200 | """Generate positional encoding for a grid of the specified size."""
201 | h, w = size
202 | device: Any = self.positional_encoding_gaussian_matrix.device
203 | grid = torch.ones((h, w), device=device, dtype=torch.float32)
204 | y_embed = grid.cumsum(dim=0) - 0.5
205 | x_embed = grid.cumsum(dim=1) - 0.5
206 | y_embed = y_embed / h
207 | x_embed = x_embed / w
208 |
209 | pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
210 | return pe.permute(2, 0, 1) # C x H x W
211 |
212 | def forward_with_coords(
213 | self, coords_input: torch.Tensor, image_size: Tuple[int, int]
214 | ) -> torch.Tensor:
215 | """Positionally encode points that are not normalized to [0,1]."""
216 | coords = coords_input.clone()
217 | coords[:, :, 0] = coords[:, :, 0] / image_size[1]
218 | coords[:, :, 1] = coords[:, :, 1] / image_size[0]
219 | return self._pe_encoding(coords.to(torch.float)) # B x N x C
220 |
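221 | # Shape sketch using SAM's defaults (256-dim embeddings, a 64x64 embedding grid, 1024x1024
222 | # padded inputs): a single foreground click yields two sparse tokens, because a padding
223 | # point is appended whenever no box prompt is given, plus the dense no-mask embedding.
224 | if __name__ == "__main__":
225 |     prompt_encoder = PromptEncoder(
226 |         embed_dim=256,
227 |         image_embedding_size=(64, 64),
228 |         input_image_size=(1024, 1024),
229 |         mask_in_chans=16,
230 |     )
231 |     coords = torch.tensor([[[512.0, 512.0]]])   # B x N x 2, one click in (X, Y) pixels
232 |     labels = torch.ones(1, 1, dtype=torch.int)  # 1 = foreground point
233 |     sparse, dense = prompt_encoder(points=(coords, labels), boxes=None, masks=None)
234 |     print(sparse.shape)  # torch.Size([1, 2, 256]): the point plus the padding token
235 |     print(dense.shape)   # torch.Size([1, 256, 64, 64]): no_mask_embed broadcast over the grid
236 | 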
--------------------------------------------------------------------------------
/segment_anything/modeling/sam.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | from torch import nn
9 | from torch.nn import functional as F
10 |
11 | from typing import Any, Dict, List, Tuple
12 |
13 | from .image_encoder import ImageEncoderViT
14 | from .mask_decoder import MaskDecoder
15 | from .prompt_encoder import PromptEncoder
16 |
17 |
18 | class Sam(nn.Module):
19 | mask_threshold: float = 0.0
20 | image_format: str = "RGB"
21 |
22 | def __init__(
23 | self,
24 | image_encoder: ImageEncoderViT,
25 | prompt_encoder: PromptEncoder,
26 | mask_decoder: MaskDecoder,
27 | pixel_mean: List[float] = [123.675, 116.28, 103.53],
28 | pixel_std: List[float] = [58.395, 57.12, 57.375],
29 | ) -> None:
30 | """
31 | SAM predicts object masks from an image and input prompts.
32 |
33 | Arguments:
34 | image_encoder (ImageEncoderViT): The backbone used to encode the
35 | image into image embeddings that allow for efficient mask prediction.
36 | prompt_encoder (PromptEncoder): Encodes various types of input prompts.
37 | mask_decoder (MaskDecoder): Predicts masks from the image embeddings
38 | and encoded prompts.
39 | pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
40 | pixel_std (list(float)): Std values for normalizing pixels in the input image.
41 | """
42 | super().__init__()
43 | self.image_encoder = image_encoder
44 | self.prompt_encoder = prompt_encoder
45 | self.mask_decoder = mask_decoder
46 | self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
47 | self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
48 |
49 | @property
50 | def device(self) -> Any:
51 | return self.pixel_mean.device
52 |
53 |     # def forward(self, batched_input, multimask_output, mode='train'):
54 | # if isinstance(batched_input, list):
55 | # outputs = self.forward_test(batched_input, multimask_output)
56 | # else:
57 | # outputs = self.forward_train(batched_input, multimask_output)
58 | # return outputs
59 |
60 | # def forward_train(self, batched_input, multimask_output):
61 | # input_images = self.preprocess(batched_input)
62 | # image_embeddings = self.image_encoder(input_images)
63 | # sparse_embeddings, dense_embeddings = self.prompt_encoder(
64 | # points=None,
65 | # boxes=None,
66 | # masks=None,
67 | # )
68 | # low_res_masks, iou_predictions = self.mask_decoder(
69 | # image_embeddings=image_embeddings,
70 | # image_pe=self.prompt_encoder.get_dense_pe(),
71 | # sparse_prompt_embeddings=sparse_embeddings,
72 | # dense_prompt_embeddings=dense_embeddings,
73 | # multimask_output=multimask_output,
74 | # )
75 | # masks = self.postprocess_masks(
76 | # low_res_masks,
77 | # input_size=input_images.shape[-2:],
78 | # original_size=input_images.shape[-2:],
79 | # )
80 | # outputs = {
81 | # "masks": masks,
82 | # "iou_predictions": iou_predictions,
83 | # "low_res_logits": low_res_masks,
84 | # }
85 | # return outputs
86 |
87 | @torch.no_grad()
88 | def forward(
89 | self,
90 | batched_input: List[Dict[str, Any]],
91 | multimask_output: bool,
92 | ) -> List[Dict[str, torch.Tensor]]:
93 | """
94 | Predicts masks end-to-end from provided images and prompts.
95 | If prompts are not known in advance, using SamPredictor is
96 | recommended over calling the model directly.
97 |
98 | Arguments:
99 | batched_input (list(dict)): A list over input images, each a
100 | dictionary with the following keys. A prompt key can be
101 | excluded if it is not present.
102 | 'image': The image as a torch tensor in 3xHxW format,
103 | already transformed for input to the model.
104 | 'original_size': (tuple(int, int)) The original size of
105 | the image before transformation, as (H, W).
106 | 'point_coords': (torch.Tensor) Batched point prompts for
107 | this image, with shape BxNx2. Already transformed to the
108 | input frame of the model.
109 | 'point_labels': (torch.Tensor) Batched labels for point prompts,
110 | with shape BxN.
111 | 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4.
112 | Already transformed to the input frame of the model.
113 | 'mask_inputs': (torch.Tensor) Batched mask inputs to the model,
114 | in the form Bx1xHxW.
115 | multimask_output (bool): Whether the model should predict multiple
116 | disambiguating masks, or return a single mask.
117 |
118 | Returns:
119 | (list(dict)): A list over input images, where each element is
120 |             a dictionary with the following keys.
121 | 'masks': (torch.Tensor) Batched binary mask predictions,
122 | with shape BxCxHxW, where B is the number of input prompts,
123 | C is determined by multimask_output, and (H, W) is the
124 | original size of the image.
125 | 'iou_predictions': (torch.Tensor) The model's predictions
126 | of mask quality, in shape BxC.
127 | 'low_res_logits': (torch.Tensor) Low resolution logits with
128 | shape BxCxHxW, where H=W=256. Can be passed as mask input
129 | to subsequent iterations of prediction.
130 | """
131 | input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0)
132 | # input_images = self.preprocess(batched_input)
133 | image_embeddings = self.image_encoder(input_images)
134 |
135 | outputs = []
136 | for image_record, curr_embedding in zip(batched_input, image_embeddings):
137 | if "point_coords" in image_record:
138 | points = (image_record["point_coords"], image_record["point_labels"])
139 | else:
140 | points = None
141 | sparse_embeddings, dense_embeddings = self.prompt_encoder(
142 | points=points,
143 | boxes=image_record.get("boxes", None),
144 | masks=image_record.get("mask_inputs", None),
145 | )
146 | low_res_masks, iou_predictions = self.mask_decoder(
147 | image_embeddings=curr_embedding.unsqueeze(0),
148 | image_pe=self.prompt_encoder.get_dense_pe(),
149 | sparse_prompt_embeddings=sparse_embeddings,
150 | dense_prompt_embeddings=dense_embeddings,
151 | multimask_output=multimask_output,
152 | )
153 | masks = self.postprocess_masks(
154 | low_res_masks,
155 | input_size=image_record["image"].shape[-2:],
156 | original_size=image_record["original_size"],
157 | )
158 | masks = masks > self.mask_threshold
159 | outputs.append(
160 | {
161 | "masks": masks,
162 | "iou_predictions": iou_predictions,
163 | "low_res_logits": low_res_masks,
164 | }
165 | )
166 | return outputs
167 |
168 | def postprocess_masks(
169 | self,
170 | masks: torch.Tensor,
171 | input_size: Tuple[int, ...],
172 | original_size: Tuple[int, ...],
173 | ) -> torch.Tensor:
174 | """
175 | Remove padding and upscale masks to the original image size.
176 |
177 | Arguments:
178 | masks (torch.Tensor): Batched masks from the mask_decoder,
179 | in BxCxHxW format.
180 | input_size (tuple(int, int)): The size of the image input to the
181 | model, in (H, W) format. Used to remove padding.
182 | original_size (tuple(int, int)): The original size of the image
183 | before resizing for input to the model, in (H, W) format.
184 |
185 | Returns:
186 | (torch.Tensor): Batched masks in BxCxHxW format, where (H, W)
187 | is given by original_size.
188 | """
189 | masks = F.interpolate(
190 | masks,
191 | (self.image_encoder.img_size, self.image_encoder.img_size),
192 | mode="bilinear",
193 | align_corners=False,
194 | )
195 | masks = masks[..., : input_size[0], : input_size[1]]
196 | masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
197 | return masks
198 |
199 | def preprocess(self, x: torch.Tensor) -> torch.Tensor:
200 | """Normalize pixel values and pad to a square input."""
201 | # Normalize colors
202 | x = (x - self.pixel_mean) / self.pixel_std
203 |
204 | # Pad
205 | h, w = x.shape[-2:]
206 | padh = self.image_encoder.img_size - h
207 | padw = self.image_encoder.img_size - w
208 | x = F.pad(x, (0, padw, 0, padh))
209 | return x
210 |
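211 | # End-to-end format sketch with a deliberately tiny configuration (not the released SAM
212 | # weights; see build_sam.py for those): shows the batched_input dictionaries that forward()
213 | # expects and the keys it returns.
214 | if __name__ == "__main__":
215 |     from .transformer import TwoWayTransformer
216 | 
217 |     tiny_sam = Sam(
218 |         image_encoder=ImageEncoderViT(img_size=64, patch_size=16, embed_dim=96, depth=2, num_heads=4),
219 |         prompt_encoder=PromptEncoder(
220 |             embed_dim=256,
221 |             image_embedding_size=(4, 4),  # img_size // patch_size
222 |             input_image_size=(64, 64),
223 |             mask_in_chans=16,
224 |         ),
225 |         mask_decoder=MaskDecoder(
226 |             transformer_dim=256,
227 |             transformer=TwoWayTransformer(depth=2, embedding_dim=256, mlp_dim=512, num_heads=8),
228 |         ),
229 |     )
230 |     batched_input = [{
231 |         "image": torch.randint(0, 256, (3, 64, 48)).float(),  # already resized: long side == img_size
232 |         "original_size": (128, 96),                           # (H, W) before resizing
233 |         "point_coords": torch.tensor([[[32.0, 24.0]]]),       # B x N x 2, in the resized frame
234 |         "point_labels": torch.ones(1, 1, dtype=torch.int),
235 |     }]
236 |     out = tiny_sam(batched_input, multimask_output=False)[0]
237 |     print(out["masks"].shape)            # torch.Size([1, 1, 128, 96]), boolean, original size
238 |     print(out["iou_predictions"].shape)  # torch.Size([1, 1])
239 |     print(out["low_res_logits"].shape)   # torch.Size([1, 1, 16, 16])
240 | 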
--------------------------------------------------------------------------------
/segment_anything/modeling/transformer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | from torch import Tensor, nn
9 |
10 | import math
11 | from typing import Tuple, Type
12 |
13 | from .common import MLPBlock
14 |
15 |
16 | class TwoWayTransformer(nn.Module):
17 | def __init__(
18 | self,
19 | depth: int,
20 | embedding_dim: int,
21 | num_heads: int,
22 | mlp_dim: int,
23 | activation: Type[nn.Module] = nn.ReLU,
24 | attention_downsample_rate: int = 2,
25 | ) -> None:
26 | """
27 | A transformer decoder that attends to an input image using
28 | queries whose positional embedding is supplied.
29 |
30 | Args:
31 | depth (int): number of layers in the transformer
32 | embedding_dim (int): the channel dimension for the input embeddings
33 | num_heads (int): the number of heads for multihead attention. Must
34 | divide embedding_dim
35 | mlp_dim (int): the channel dimension internal to the MLP block
36 | activation (nn.Module): the activation to use in the MLP block
37 | """
38 | super().__init__()
39 | self.depth = depth
40 | self.embedding_dim = embedding_dim
41 | self.num_heads = num_heads
42 | self.mlp_dim = mlp_dim
43 | self.layers = nn.ModuleList()
44 |
45 | for i in range(depth):
46 | self.layers.append(
47 | TwoWayAttentionBlock(
48 | embedding_dim=embedding_dim,
49 | num_heads=num_heads,
50 | mlp_dim=mlp_dim,
51 | activation=activation,
52 | attention_downsample_rate=attention_downsample_rate,
53 | skip_first_layer_pe=(i == 0),
54 | )
55 | )
56 |
57 | self.final_attn_token_to_image = Attention(
58 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate
59 | )
60 | self.norm_final_attn = nn.LayerNorm(embedding_dim)
61 |
62 | def forward(
63 | self,
64 | image_embedding: Tensor,
65 | image_pe: Tensor,
66 | point_embedding: Tensor,
67 | ) -> Tuple[Tensor, Tensor]:
68 | """
69 | Args:
70 | image_embedding (torch.Tensor): image to attend to. Should be shape
71 | B x embedding_dim x h x w for any h and w.
72 | image_pe (torch.Tensor): the positional encoding to add to the image. Must
73 | have the same shape as image_embedding.
74 | point_embedding (torch.Tensor): the embedding to add to the query points.
75 | Must have shape B x N_points x embedding_dim for any N_points.
76 |
77 | Returns:
78 | torch.Tensor: the processed point_embedding
79 | torch.Tensor: the processed image_embedding
80 | """
81 | # BxCxHxW -> BxHWxC == B x N_image_tokens x C
82 | bs, c, h, w = image_embedding.shape
83 | image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
84 | image_pe = image_pe.flatten(2).permute(0, 2, 1)
85 |
86 | # Prepare queries
87 | queries = point_embedding
88 | keys = image_embedding
89 |
90 | # Apply transformer blocks and final layernorm
91 | for layer in self.layers:
92 | queries, keys = layer(
93 | queries=queries,
94 | keys=keys,
95 | query_pe=point_embedding,
96 | key_pe=image_pe,
97 | )
98 |
99 | # Apply the final attention layer from the points to the image
100 | q = queries + point_embedding
101 | k = keys + image_pe
102 | attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
103 | queries = queries + attn_out
104 | queries = self.norm_final_attn(queries)
105 |
106 | return queries, keys
107 |
108 |
109 | class TwoWayAttentionBlock(nn.Module):
110 | def __init__(
111 | self,
112 | embedding_dim: int,
113 | num_heads: int,
114 | mlp_dim: int = 2048,
115 | activation: Type[nn.Module] = nn.ReLU,
116 | attention_downsample_rate: int = 2,
117 | skip_first_layer_pe: bool = False,
118 | ) -> None:
119 | """
120 | A transformer block with four layers: (1) self-attention of sparse
121 | inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
122 | block on sparse inputs, and (4) cross attention of dense inputs to sparse
123 | inputs.
124 |
125 | Arguments:
126 | embedding_dim (int): the channel dimension of the embeddings
127 | num_heads (int): the number of heads in the attention layers
128 | mlp_dim (int): the hidden dimension of the mlp block
129 | activation (nn.Module): the activation of the mlp block
130 | skip_first_layer_pe (bool): skip the PE on the first layer
131 | """
132 | super().__init__()
133 | self.self_attn = Attention(embedding_dim, num_heads)
134 | self.norm1 = nn.LayerNorm(embedding_dim)
135 |
136 | self.cross_attn_token_to_image = Attention(
137 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate
138 | )
139 | self.norm2 = nn.LayerNorm(embedding_dim)
140 |
141 | self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
142 | self.norm3 = nn.LayerNorm(embedding_dim)
143 |
144 | self.norm4 = nn.LayerNorm(embedding_dim)
145 | self.cross_attn_image_to_token = Attention(
146 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate
147 | )
148 |
149 | self.skip_first_layer_pe = skip_first_layer_pe
150 |
151 | def forward(
152 | self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
153 | ) -> Tuple[Tensor, Tensor]:
154 | # Self attention block
155 | if self.skip_first_layer_pe:
156 | queries = self.self_attn(q=queries, k=queries, v=queries)
157 | else:
158 | q = queries + query_pe
159 | attn_out = self.self_attn(q=q, k=q, v=queries)
160 | queries = queries + attn_out
161 | queries = self.norm1(queries)
162 |
163 | # Cross attention block, tokens attending to image embedding
164 | q = queries + query_pe
165 | k = keys + key_pe
166 | attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
167 | queries = queries + attn_out
168 | queries = self.norm2(queries)
169 |
170 | # MLP block
171 | mlp_out = self.mlp(queries)
172 | queries = queries + mlp_out
173 | queries = self.norm3(queries)
174 |
175 | # Cross attention block, image embedding attending to tokens
176 | q = queries + query_pe
177 | k = keys + key_pe
178 | attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
179 | keys = keys + attn_out
180 | keys = self.norm4(keys)
181 |
182 | return queries, keys
183 |
184 |
185 | class Attention(nn.Module):
186 | """
187 | An attention layer that allows for downscaling the size of the embedding
188 | after projection to queries, keys, and values.
189 | """
190 |
191 | def __init__(
192 | self,
193 | embedding_dim: int,
194 | num_heads: int,
195 | downsample_rate: int = 1,
196 | ) -> None:
197 | super().__init__()
198 | self.embedding_dim = embedding_dim
199 | self.internal_dim = embedding_dim // downsample_rate
200 | self.num_heads = num_heads
201 | assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."
202 |
203 | self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
204 | self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
205 | self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
206 | self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
207 |
208 | def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
209 | b, n, c = x.shape
210 | x = x.reshape(b, n, num_heads, c // num_heads)
211 | return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head
212 |
213 | def _recombine_heads(self, x: Tensor) -> Tensor:
214 | b, n_heads, n_tokens, c_per_head = x.shape
215 | x = x.transpose(1, 2)
216 | return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C
217 |
218 | def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
219 | # Input projections
220 | q = self.q_proj(q)
221 | k = self.k_proj(k)
222 | v = self.v_proj(v)#.clone()
223 |
224 | # Separate into heads
225 | q = self._separate_heads(q, self.num_heads)
226 | k = self._separate_heads(k, self.num_heads)
227 | v = self._separate_heads(v, self.num_heads)
228 |
229 | # Attention
230 | _, _, _, c_per_head = q.shape
231 | attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens
232 | attn = attn / math.sqrt(c_per_head)
233 | attn = torch.softmax(attn, dim=-1)
234 |
235 | # Get output
236 | out = attn @ v
237 | out = self._recombine_heads(out)
238 | out = self.out_proj(out)
239 |
240 | return out
241 |
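242 | # Shape sketch: the two-way transformer keeps prompt tokens (queries) and flattened image
243 | # tokens (keys) at their own lengths and returns both; attention_downsample_rate=2 halves
244 | # the width used inside the cross-attention layers without changing the output dimension.
245 | if __name__ == "__main__":
246 |     transformer = TwoWayTransformer(depth=2, embedding_dim=256, mlp_dim=512, num_heads=8)
247 |     image_embedding = torch.randn(1, 256, 16, 16)  # B x C x H x W
248 |     image_pe = torch.randn(1, 256, 16, 16)
249 |     point_embedding = torch.randn(1, 7, 256)       # B x N_points x C
250 |     queries, keys = transformer(image_embedding, image_pe, point_embedding)
251 |     print(queries.shape, keys.shape)  # torch.Size([1, 7, 256]) torch.Size([1, 256, 256])
252 | 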
--------------------------------------------------------------------------------
/segment_anything/predictor.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import numpy as np
8 | import torch
9 |
10 | from segment_anything.modeling import Sam
11 |
12 | from typing import Optional, Tuple
13 |
14 | from .utils.transforms import ResizeLongestSide
15 |
16 |
17 | class SamPredictor:
18 | def __init__(
19 | self,
20 | sam_model: Sam,
21 | ) -> None:
22 | """
23 | Uses SAM to calculate the image embedding for an image, and then
24 |         allows repeated, efficient mask prediction given prompts.
25 |
26 | Arguments:
27 | sam_model (Sam): The model to use for mask prediction.
28 | """
29 | super().__init__()
30 | self.model = sam_model
31 | self.transform = ResizeLongestSide(sam_model.image_encoder.img_size)
32 | self.reset_image()
33 |
34 | def set_image(
35 | self,
36 | image: np.ndarray,
37 | image_format: str = "RGB",
38 | ) -> None:
39 | """
40 | Calculates the image embeddings for the provided image, allowing
41 | masks to be predicted with the 'predict' method.
42 |
43 | Arguments:
44 | image (np.ndarray): The image for calculating masks. Expects an
45 | image in HWC uint8 format, with pixel values in [0, 255].
46 | image_format (str): The color format of the image, in ['RGB', 'BGR'].
47 | """
48 | assert image_format in [
49 | "RGB",
50 | "BGR",
51 | ], f"image_format must be in ['RGB', 'BGR'], is {image_format}."
52 | if image_format != self.model.image_format:
53 | image = image[..., ::-1]
54 |
55 | # Transform the image to the form expected by the model
56 | input_image = self.transform.apply_image(image)
57 | input_image_torch = torch.as_tensor(input_image, device=self.device)
58 | input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]
59 |
60 | self.set_torch_image(input_image_torch, image.shape[:2])
61 |
62 | @torch.no_grad()
63 | def set_torch_image(
64 | self,
65 | transformed_image: torch.Tensor,
66 | original_image_size: Tuple[int, ...],
67 | ) -> None:
68 | """
69 | Calculates the image embeddings for the provided image, allowing
70 | masks to be predicted with the 'predict' method. Expects the input
71 | image to be already transformed to the format expected by the model.
72 |
73 | Arguments:
74 | transformed_image (torch.Tensor): The input image, with shape
75 | 1x3xHxW, which has been transformed with ResizeLongestSide.
76 | original_image_size (tuple(int, int)): The size of the image
77 | before transformation, in (H, W) format.
78 | """
79 | assert (
80 | len(transformed_image.shape) == 4
81 | and transformed_image.shape[1] == 3
82 | and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size
83 | ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}."
84 | self.reset_image()
85 |
86 | self.original_size = original_image_size
87 | self.input_size = tuple(transformed_image.shape[-2:])
88 | input_image = self.model.preprocess(transformed_image)
89 | self.features = self.model.image_encoder(input_image)
90 | self.is_image_set = True
91 |
92 | def predict(
93 | self,
94 | point_coords: Optional[np.ndarray] = None,
95 | point_labels: Optional[np.ndarray] = None,
96 | box: Optional[np.ndarray] = None,
97 | mask_input: Optional[np.ndarray] = None,
98 | multimask_output: bool = True,
99 | return_logits: bool = False,
100 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
101 | """
102 | Predict masks for the given input prompts, using the currently set image.
103 |
104 | Arguments:
105 | point_coords (np.ndarray or None): A Nx2 array of point prompts to the
106 | model. Each point is in (X,Y) in pixels.
107 | point_labels (np.ndarray or None): A length N array of labels for the
108 | point prompts. 1 indicates a foreground point and 0 indicates a
109 | background point.
110 | box (np.ndarray or None): A length 4 array giving a box prompt to the
111 | model, in XYXY format.
112 | mask_input (np.ndarray): A low resolution mask input to the model, typically
113 | coming from a previous prediction iteration. Has form 1xHxW, where
114 | for SAM, H=W=256.
115 | multimask_output (bool): If true, the model will return three masks.
116 | For ambiguous input prompts (such as a single click), this will often
117 | produce better masks than a single prediction. If only a single
118 | mask is needed, the model's predicted quality score can be used
119 | to select the best mask. For non-ambiguous prompts, such as multiple
120 | input prompts, multimask_output=False can give better results.
121 | return_logits (bool): If true, returns un-thresholded mask logits
122 | instead of a binary mask.
123 |
124 | Returns:
125 | (np.ndarray): The output masks in CxHxW format, where C is the
126 | number of masks, and (H, W) is the original image size.
127 | (np.ndarray): An array of length C containing the model's
128 | predictions for the quality of each mask.
129 | (np.ndarray): An array of shape CxHxW, where C is the number
130 | of masks and H=W=256. These low resolution logits can be passed to
131 | a subsequent iteration as mask input.
132 | """
133 | if not self.is_image_set:
134 | raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
135 |
136 | # Transform input prompts
137 | coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None
138 | if point_coords is not None:
139 | assert (
140 | point_labels is not None
141 | ), "point_labels must be supplied if point_coords is supplied."
142 | point_coords = self.transform.apply_coords(point_coords, self.original_size)
143 | coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device)
144 | labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device)
145 | coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]
146 | if box is not None:
147 | box = self.transform.apply_boxes(box, self.original_size)
148 | box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device)
149 | box_torch = box_torch[None, :]
150 | if mask_input is not None:
151 | mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device)
152 | mask_input_torch = mask_input_torch[None, :, :, :]
153 |
154 | masks, iou_predictions, low_res_masks = self.predict_torch(
155 | coords_torch,
156 | labels_torch,
157 | box_torch,
158 | mask_input_torch,
159 | multimask_output,
160 | return_logits=return_logits,
161 | )
162 |
163 | masks_np = masks[0].detach().cpu().numpy()
164 | iou_predictions_np = iou_predictions[0].detach().cpu().numpy()
165 | low_res_masks_np = low_res_masks[0].detach().cpu().numpy()
166 | return masks_np, iou_predictions_np, low_res_masks_np
167 |
168 | @torch.no_grad()
169 | def predict_torch(
170 | self,
171 | point_coords: Optional[torch.Tensor],
172 | point_labels: Optional[torch.Tensor],
173 | boxes: Optional[torch.Tensor] = None,
174 | mask_input: Optional[torch.Tensor] = None,
175 | multimask_output: bool = True,
176 | return_logits: bool = False,
177 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
178 | """
179 | Predict masks for the given input prompts, using the currently set image.
180 | Input prompts are batched torch tensors and are expected to already be
181 | transformed to the input frame using ResizeLongestSide.
182 |
183 | Arguments:
184 | point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the
185 | model. Each point is in (X,Y) in pixels.
186 | point_labels (torch.Tensor or None): A BxN array of labels for the
187 | point prompts. 1 indicates a foreground point and 0 indicates a
188 | background point.
189 | boxes (torch.Tensor or None): A Bx4 array giving a box prompt to the
190 | model, in XYXY format.
191 | mask_input (torch.Tensor): A low resolution mask input to the model, typically
192 | coming from a previous prediction iteration. Has form Bx1xHxW, where
193 | for SAM, H=W=256. Masks returned by a previous iteration of the
194 | predict method do not need further transformation.
195 | multimask_output (bool): If true, the model will return three masks.
196 | For ambiguous input prompts (such as a single click), this will often
197 | produce better masks than a single prediction. If only a single
198 | mask is needed, the model's predicted quality score can be used
199 | to select the best mask. For non-ambiguous prompts, such as multiple
200 | input prompts, multimask_output=False can give better results.
201 | return_logits (bool): If true, returns un-thresholded mask logits
202 | instead of a binary mask.
203 |
204 | Returns:
205 | (torch.Tensor): The output masks in BxCxHxW format, where C is the
206 | number of masks, and (H, W) is the original image size.
207 | (torch.Tensor): An array of shape BxC containing the model's
208 | predictions for the quality of each mask.
209 | (torch.Tensor): An array of shape BxCxHxW, where C is the number
210 | of masks and H=W=256. These low res logits can be passed to
211 | a subsequent iteration as mask input.
212 | """
213 | if not self.is_image_set:
214 | raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
215 |
216 | if point_coords is not None:
217 | points = (point_coords, point_labels)
218 | else:
219 | points = None
220 |
221 | # Embed prompts
222 | sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
223 | points=points,
224 | boxes=boxes,
225 | masks=mask_input,
226 | )
227 |
228 | # Predict masks
229 | low_res_masks, iou_predictions = self.model.mask_decoder(
230 | image_embeddings=self.features,
231 | image_pe=self.model.prompt_encoder.get_dense_pe(),
232 | sparse_prompt_embeddings=sparse_embeddings,
233 | dense_prompt_embeddings=dense_embeddings,
234 | multimask_output=multimask_output,
235 | )
236 |
237 | # Upscale the masks to the original image resolution
238 | masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size)
239 |
240 | if not return_logits:
241 | masks = masks > self.model.mask_threshold
242 |
243 | return masks, iou_predictions, low_res_masks
244 |
245 | def get_image_embedding(self) -> torch.Tensor:
246 | """
247 | Returns the image embeddings for the currently set image, with
248 | shape 1xCxHxW, where C is the embedding dimension and (H,W) are
249 | the embedding spatial dimension of SAM (typically C=256, H=W=64).
250 | """
251 | if not self.is_image_set:
252 | raise RuntimeError(
253 | "An image must be set with .set_image(...) to generate an embedding."
254 | )
255 | assert self.features is not None, "Features must exist if an image has been set."
256 | return self.features
257 |
258 | @property
259 | def device(self) -> torch.device:
260 | return self.model.device
261 |
262 | def reset_image(self) -> None:
263 | """Resets the currently set image."""
264 | self.is_image_set = False
265 | self.features = None
266 | self.orig_h = None
267 | self.orig_w = None
268 | self.input_h = None
269 | self.input_w = None
270 |
--------------------------------------------------------------------------------
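A minimal usage sketch of SamPredictor; the checkpoint path matches the repository defaults, while "example.jpg" and the single point prompt are placeholders:

import cv2
import numpy as np
from segment_anything import sam_model_registry, SamPredictor

sam = sam_model_registry["vit_b"](checkpoint="./checkpoints/sam_vit_b_01ec64.pth")
predictor = SamPredictor(sam)

image = cv2.cvtColor(cv2.imread("example.jpg"), cv2.COLOR_BGR2RGB)  # HxWxC uint8
predictor.set_image(image)

# one foreground click at pixel (x=500, y=375)
masks, scores, low_res_logits = predictor.predict(
    point_coords=np.array([[500, 375]]),
    point_labels=np.array([1]),
    multimask_output=True,
)
print(masks.shape, scores.shape)  # (3, H, W) and (3,)
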
/segment_anything/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/segment_anything/utils/amg.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import numpy as np
8 | import torch
9 |
10 | import math
11 | from copy import deepcopy
12 | from itertools import product
13 | from typing import Any, Dict, Generator, ItemsView, List, Tuple
14 |
15 |
16 | class MaskData:
17 | """
18 | A structure for storing masks and their related data in batched format.
19 | Implements basic filtering and concatenation.
20 | """
21 |
22 | def __init__(self, **kwargs) -> None:
23 | for v in kwargs.values():
24 | assert isinstance(
25 | v, (list, np.ndarray, torch.Tensor)
26 | ), "MaskData only supports list, numpy arrays, and torch tensors."
27 | self._stats = dict(**kwargs)
28 |
29 | def __setitem__(self, key: str, item: Any) -> None:
30 | assert isinstance(
31 | item, (list, np.ndarray, torch.Tensor)
32 | ), "MaskData only supports list, numpy arrays, and torch tensors."
33 | self._stats[key] = item
34 |
35 | def __delitem__(self, key: str) -> None:
36 | del self._stats[key]
37 |
38 | def __getitem__(self, key: str) -> Any:
39 | return self._stats[key]
40 |
41 | def items(self) -> ItemsView[str, Any]:
42 | return self._stats.items()
43 |
44 | def filter(self, keep: torch.Tensor) -> None:
45 | for k, v in self._stats.items():
46 | if v is None:
47 | self._stats[k] = None
48 | elif isinstance(v, torch.Tensor):
49 | self._stats[k] = v[torch.as_tensor(keep, device=v.device)]
50 | elif isinstance(v, np.ndarray):
51 | self._stats[k] = v[keep.detach().cpu().numpy()]
52 | elif isinstance(v, list) and keep.dtype == torch.bool:
53 | self._stats[k] = [a for i, a in enumerate(v) if keep[i]]
54 | elif isinstance(v, list):
55 | self._stats[k] = [v[i] for i in keep]
56 | else:
57 | raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")
58 |
59 | def cat(self, new_stats: "MaskData") -> None:
60 | for k, v in new_stats.items():
61 | if k not in self._stats or self._stats[k] is None:
62 | self._stats[k] = deepcopy(v)
63 | elif isinstance(v, torch.Tensor):
64 | self._stats[k] = torch.cat([self._stats[k], v], dim=0)
65 | elif isinstance(v, np.ndarray):
66 | self._stats[k] = np.concatenate([self._stats[k], v], axis=0)
67 | elif isinstance(v, list):
68 | self._stats[k] = self._stats[k] + deepcopy(v)
69 | else:
70 | raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")
71 |
72 | def to_numpy(self) -> None:
73 | for k, v in self._stats.items():
74 | if isinstance(v, torch.Tensor):
75 | self._stats[k] = v.detach().cpu().numpy()
76 |
77 |
78 | def is_box_near_crop_edge(
79 | boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0
80 | ) -> torch.Tensor:
81 | """Filter masks at the edge of a crop, but not at the edge of the original image."""
82 | crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
83 | orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)
84 | boxes = uncrop_boxes_xyxy(boxes, crop_box).float()
85 | near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0)
86 | near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0)
87 | near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge)
88 | return torch.any(near_crop_edge, dim=1)
89 |
90 |
91 | def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor:
92 | box_xywh = deepcopy(box_xyxy)
93 | box_xywh[2] = box_xywh[2] - box_xywh[0]
94 | box_xywh[3] = box_xywh[3] - box_xywh[1]
95 | return box_xywh
96 |
97 |
98 | def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
99 | assert len(args) > 0 and all(
100 | len(a) == len(args[0]) for a in args
101 | ), "Batched iteration must have inputs of all the same size."
102 | n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0)
103 | for b in range(n_batches):
104 | yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args]
105 |
106 |
107 | def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]:
108 | """
109 | Encodes masks to an uncompressed RLE, in the format expected by
110 | pycoco tools.
111 | """
112 | # Put in fortran order and flatten h,w
113 | b, h, w = tensor.shape
114 | tensor = tensor.permute(0, 2, 1).flatten(1)
115 |
116 | # Compute change indices
117 | diff = tensor[:, 1:] ^ tensor[:, :-1]
118 | change_indices = diff.nonzero()
119 |
120 | # Encode run length
121 | out = []
122 | for i in range(b):
123 | cur_idxs = change_indices[change_indices[:, 0] == i, 1]
124 | cur_idxs = torch.cat(
125 | [
126 | torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device),
127 | cur_idxs + 1,
128 | torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device),
129 | ]
130 | )
131 | btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
132 | counts = [] if tensor[i, 0] == 0 else [0]
133 | counts.extend(btw_idxs.detach().cpu().tolist())
134 | out.append({"size": [h, w], "counts": counts})
135 | return out
136 |
137 |
138 | def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray:
139 | """Compute a binary mask from an uncompressed RLE."""
140 | h, w = rle["size"]
141 | mask = np.empty(h * w, dtype=bool)
142 | idx = 0
143 | parity = False
144 | for count in rle["counts"]:
145 | mask[idx : idx + count] = parity
146 | idx += count
147 | parity ^= True
148 | mask = mask.reshape(w, h)
149 | return mask.transpose() # Put in C order
150 |
151 |
152 | def area_from_rle(rle: Dict[str, Any]) -> int:
153 | return sum(rle["counts"][1::2])
154 |
155 |
156 | def calculate_stability_score(
157 | masks: torch.Tensor, mask_threshold: float, threshold_offset: float
158 | ) -> torch.Tensor:
159 | """
160 | Computes the stability score for a batch of masks. The stability
161 | score is the IoU between the binary masks obtained by thresholding
162 | the predicted mask logits at high and low values.
163 | """
164 | # One mask is always contained inside the other.
165 | # Save memory by preventing unnecessary cast to torch.int64
166 | intersections = (
167 | (masks > (mask_threshold + threshold_offset))
168 | .sum(-1, dtype=torch.int16)
169 | .sum(-1, dtype=torch.int32)
170 | )
171 | unions = (
172 | (masks > (mask_threshold - threshold_offset))
173 | .sum(-1, dtype=torch.int16)
174 | .sum(-1, dtype=torch.int32)
175 | )
176 | return intersections / unions
177 |
178 |
179 | def build_point_grid(n_per_side: int) -> np.ndarray:
180 | """Generates a 2D grid of points evenly spaced in [0,1]x[0,1]."""
181 | offset = 1 / (2 * n_per_side)
182 | points_one_side = np.linspace(offset, 1 - offset, n_per_side)
183 | points_x = np.tile(points_one_side[None, :], (n_per_side, 1))
184 | points_y = np.tile(points_one_side[:, None], (1, n_per_side))
185 | points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2)
186 | return points
187 |
188 |
189 | def build_all_layer_point_grids(
190 | n_per_side: int, n_layers: int, scale_per_layer: int
191 | ) -> List[np.ndarray]:
192 | """Generates point grids for all crop layers."""
193 | points_by_layer = []
194 | for i in range(n_layers + 1):
195 | n_points = int(n_per_side / (scale_per_layer**i))
196 | points_by_layer.append(build_point_grid(n_points))
197 | return points_by_layer
198 |
199 |
200 | def generate_crop_boxes(
201 | im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float
202 | ) -> Tuple[List[List[int]], List[int]]:
203 | """
204 | Generates a list of crop boxes of different sizes. Each layer
205 | has (2**i)**2 boxes for the ith layer.
206 | """
207 | crop_boxes, layer_idxs = [], []
208 | im_h, im_w = im_size
209 | short_side = min(im_h, im_w)
210 |
211 | # Original image
212 | crop_boxes.append([0, 0, im_w, im_h])
213 | layer_idxs.append(0)
214 |
215 | def crop_len(orig_len, n_crops, overlap):
216 | return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops))
217 |
218 | for i_layer in range(n_layers):
219 | n_crops_per_side = 2 ** (i_layer + 1)
220 | overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side))
221 |
222 | crop_w = crop_len(im_w, n_crops_per_side, overlap)
223 | crop_h = crop_len(im_h, n_crops_per_side, overlap)
224 |
225 | crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)]
226 | crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)]
227 |
228 | # Crops in XYWH format
229 | for x0, y0 in product(crop_box_x0, crop_box_y0):
230 | box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)]
231 | crop_boxes.append(box)
232 | layer_idxs.append(i_layer + 1)
233 |
234 | return crop_boxes, layer_idxs
235 |
236 |
237 | def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
238 | x0, y0, _, _ = crop_box
239 | offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device)
240 | # Check if boxes has a channel dimension
241 | if len(boxes.shape) == 3:
242 | offset = offset.unsqueeze(1)
243 | return boxes + offset
244 |
245 |
246 | def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
247 | x0, y0, _, _ = crop_box
248 | offset = torch.tensor([[x0, y0]], device=points.device)
249 | # Check if points has a channel dimension
250 | if len(points.shape) == 3:
251 | offset = offset.unsqueeze(1)
252 | return points + offset
253 |
254 |
255 | def uncrop_masks(
256 | masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int
257 | ) -> torch.Tensor:
258 | x0, y0, x1, y1 = crop_box
259 | if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h:
260 | return masks
261 | # Coordinate transform masks
262 | pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0)
263 | pad = (x0, pad_x - x0, y0, pad_y - y0)
264 | return torch.nn.functional.pad(masks, pad, value=0)
265 |
266 |
267 | def remove_small_regions(
268 | mask: np.ndarray, area_thresh: float, mode: str
269 | ) -> Tuple[np.ndarray, bool]:
270 | """
271 | Removes small disconnected regions and holes in a mask. Returns the
272 | mask and an indicator of if the mask has been modified.
273 | """
274 | import cv2 # type: ignore
275 |
276 | assert mode in ["holes", "islands"]
277 | correct_holes = mode == "holes"
278 | working_mask = (correct_holes ^ mask).astype(np.uint8)
279 | n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
280 | sizes = stats[:, -1][1:] # Row 0 is background label
281 | small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
282 | if len(small_regions) == 0:
283 | return mask, False
284 | fill_labels = [0] + small_regions
285 | if not correct_holes:
286 | fill_labels = [i for i in range(n_labels) if i not in fill_labels]
287 | # If every region is below threshold, keep largest
288 | if len(fill_labels) == 0:
289 | fill_labels = [int(np.argmax(sizes)) + 1]
290 | mask = np.isin(regions, fill_labels)
291 | return mask, True
292 |
293 |
294 | def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]:
295 | from pycocotools import mask as mask_utils # type: ignore
296 |
297 | h, w = uncompressed_rle["size"]
298 | rle = mask_utils.frPyObjects(uncompressed_rle, h, w)
299 | rle["counts"] = rle["counts"].decode("utf-8") # Necessary to serialize with json
300 | return rle
301 |
302 |
303 | def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
304 | """
305 | Calculates boxes in XYXY format around masks. Return [0,0,0,0] for
306 | an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4.
307 | """
308 | # torch.max below raises an error on empty inputs, just skip in this case
309 | if torch.numel(masks) == 0:
310 | return torch.zeros(*masks.shape[:-2], 4, device=masks.device)
311 |
312 | # Normalize shape to CxHxW
313 | shape = masks.shape
314 | h, w = shape[-2:]
315 | if len(shape) > 2:
316 | masks = masks.flatten(0, -3)
317 | else:
318 | masks = masks.unsqueeze(0)
319 |
320 | # Get top and bottom edges
321 | in_height, _ = torch.max(masks, dim=-1)
322 | in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :]
323 | bottom_edges, _ = torch.max(in_height_coords, dim=-1)
324 | in_height_coords = in_height_coords + h * (~in_height)
325 | top_edges, _ = torch.min(in_height_coords, dim=-1)
326 |
327 | # Get left and right edges
328 | in_width, _ = torch.max(masks, dim=-2)
329 | in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :]
330 | right_edges, _ = torch.max(in_width_coords, dim=-1)
331 | in_width_coords = in_width_coords + w * (~in_width)
332 | left_edges, _ = torch.min(in_width_coords, dim=-1)
333 |
334 | # If the mask is empty the right edge will be to the left of the left edge.
335 | # Replace these boxes with [0, 0, 0, 0]
336 | empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
337 | out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1)
338 | out = out * (~empty_filter).unsqueeze(-1)
339 |
340 | # Return to original shape
341 | if len(shape) > 2:
342 | out = out.reshape(*shape[:-2], 4)
343 | else:
344 | out = out[0]
345 |
346 | return out
347 |
--------------------------------------------------------------------------------
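A small self-contained sketch of the mask utilities above on a toy boolean mask (the 6x6 mask and the printed values are hand-computed examples, not taken from the repository):

import torch
from segment_anything.utils.amg import (
    area_from_rle, batched_mask_to_box, mask_to_rle_pytorch, rle_to_mask,
)

mask = torch.zeros(1, 6, 6, dtype=torch.bool)
mask[0, 2:5, 1:4] = True  # a 3x3 square of foreground pixels

rles = mask_to_rle_pytorch(mask)             # uncompressed RLE, one dict per mask
print(area_from_rle(rles[0]))                # 9
restored = rle_to_mask(rles[0])              # back to an HxW boolean numpy array
print((restored == mask[0].numpy()).all())   # True

print(batched_mask_to_box(mask))             # tensor([[1, 2, 3, 4]]) in XYXY format
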
/segment_anything/utils/onnx.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | import torch.nn as nn
9 | from torch.nn import functional as F
10 |
11 | from typing import Tuple
12 |
13 | from ..modeling import Sam
14 | from .amg import calculate_stability_score
15 |
16 |
17 | class SamOnnxModel(nn.Module):
18 | """
19 | This model should not be called directly, but is used in ONNX export.
20 | It combines the prompt encoder, mask decoder, and mask postprocessing of Sam,
21 | with some functions modified to enable model tracing. Also supports extra
22 | options controlling what information is returned. See the ONNX export script for details.
23 | """
24 |
25 | def __init__(
26 | self,
27 | model: Sam,
28 | return_single_mask: bool,
29 | use_stability_score: bool = False,
30 | return_extra_metrics: bool = False,
31 | ) -> None:
32 | super().__init__()
33 | self.mask_decoder = model.mask_decoder
34 | self.model = model
35 | self.img_size = model.image_encoder.img_size
36 | self.return_single_mask = return_single_mask
37 | self.use_stability_score = use_stability_score
38 | self.stability_score_offset = 1.0
39 | self.return_extra_metrics = return_extra_metrics
40 |
41 | @staticmethod
42 | def resize_longest_image_size(
43 | input_image_size: torch.Tensor, longest_side: int
44 | ) -> torch.Tensor:
45 | input_image_size = input_image_size.to(torch.float32)
46 | scale = longest_side / torch.max(input_image_size)
47 | transformed_size = scale * input_image_size
48 | transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64)
49 | return transformed_size
50 |
51 | def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor:
52 | point_coords = point_coords + 0.5
53 | point_coords = point_coords / self.img_size
54 | point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords)
55 | point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding)
56 |
57 | point_embedding = point_embedding * (point_labels != -1)
58 | point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * (
59 | point_labels == -1
60 | )
61 |
62 | for i in range(self.model.prompt_encoder.num_point_embeddings):
63 | point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[
64 | i
65 | ].weight * (point_labels == i)
66 |
67 | return point_embedding
68 |
69 | def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor:
70 | mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask)
71 | mask_embedding = mask_embedding + (
72 | 1 - has_mask_input
73 | ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1)
74 | return mask_embedding
75 |
76 | def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor:
77 | masks = F.interpolate(
78 | masks,
79 | size=(self.img_size, self.img_size),
80 | mode="bilinear",
81 | align_corners=False,
82 | )
83 |
84 | prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(torch.int64)
85 | masks = masks[..., : prepadded_size[0], : prepadded_size[1]] # type: ignore
86 |
87 | orig_im_size = orig_im_size.to(torch.int64)
88 | h, w = orig_im_size[0], orig_im_size[1]
89 | masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False)
90 | return masks
91 |
92 | def select_masks(
93 | self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int
94 | ) -> Tuple[torch.Tensor, torch.Tensor]:
95 | # Determine if we should return the multiclick mask or not from the number of points.
96 | # The reweighting is used to avoid control flow.
97 | score_reweight = torch.tensor(
98 | [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)]
99 | ).to(iou_preds.device)
100 | score = iou_preds + (num_points - 2.5) * score_reweight
101 | best_idx = torch.argmax(score, dim=1)
102 | masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1)
103 | iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1)
104 |
105 | return masks, iou_preds
106 |
107 | @torch.no_grad()
108 | def forward(
109 | self,
110 | image_embeddings: torch.Tensor,
111 | point_coords: torch.Tensor,
112 | point_labels: torch.Tensor,
113 | mask_input: torch.Tensor,
114 | has_mask_input: torch.Tensor,
115 | orig_im_size: torch.Tensor,
116 | ):
117 | sparse_embedding = self._embed_points(point_coords, point_labels)
118 | dense_embedding = self._embed_masks(mask_input, has_mask_input)
119 |
120 | masks, scores = self.model.mask_decoder.predict_masks(
121 | image_embeddings=image_embeddings,
122 | image_pe=self.model.prompt_encoder.get_dense_pe(),
123 | sparse_prompt_embeddings=sparse_embedding,
124 | dense_prompt_embeddings=dense_embedding,
125 | )
126 |
127 | if self.use_stability_score:
128 | scores = calculate_stability_score(
129 | masks, self.model.mask_threshold, self.stability_score_offset
130 | )
131 |
132 | if self.return_single_mask:
133 | masks, scores = self.select_masks(masks, scores, point_coords.shape[1])
134 |
135 | upscaled_masks = self.mask_postprocessing(masks, orig_im_size)
136 |
137 | if self.return_extra_metrics:
138 | stability_scores = calculate_stability_score(
139 | upscaled_masks, self.model.mask_threshold, self.stability_score_offset
140 | )
141 | areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1)
142 | return upscaled_masks, scores, stability_scores, areas, masks
143 |
144 | return upscaled_masks, scores, masks
145 |
--------------------------------------------------------------------------------
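A sanity-check sketch for SamOnnxModel: it wraps a vit_b checkpoint (path as in the repository defaults) and runs one forward pass on dummy prompt tensors; the dummy shapes mirror SAM's standard ONNX export inputs and are assumptions rather than values taken from this repository:

import torch
from segment_anything import sam_model_registry
from segment_anything.utils.onnx import SamOnnxModel

sam = sam_model_registry["vit_b"](checkpoint="./checkpoints/sam_vit_b_01ec64.pth")
onnx_model = SamOnnxModel(sam, return_single_mask=True)

embed_dim = sam.prompt_encoder.embed_dim               # 256 for vit_b
embed_size = sam.prompt_encoder.image_embedding_size   # (64, 64)
dummy_inputs = {
    "image_embeddings": torch.randn(1, embed_dim, *embed_size),
    "point_coords": torch.randint(0, 1024, (1, 5, 2), dtype=torch.float),
    "point_labels": torch.randint(0, 4, (1, 5), dtype=torch.float),
    "mask_input": torch.randn(1, 1, 256, 256),
    "has_mask_input": torch.tensor([1], dtype=torch.float),
    "orig_im_size": torch.tensor([1080, 1920], dtype=torch.float),
}

upscaled_masks, scores, low_res_masks = onnx_model(**dummy_inputs)
print(upscaled_masks.shape)  # (1, 1, 1080, 1920) with return_single_mask=True
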
/segment_anything/utils/transforms.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import numpy as np
8 | import torch
9 | from torch.nn import functional as F
10 | from torchvision.transforms.functional import resize, to_pil_image # type: ignore
11 |
12 | from copy import deepcopy
13 | from typing import Tuple
14 |
15 |
16 | # class ResizeLongestSideBatch:
17 | # def __init__(self, target_length: int) -> None:
18 | # self.target_length = target_length
19 |
20 | # def apply_image_batch(self, image_batch: np.ndarray) -> np.ndarray:
21 | # """
22 | # Expects a numpy array with shape BxHxWxC in uint8 format.
23 | # """
24 | # target_size = self.get_preprocess_shape(image_batch.shape[1], image_batch.shape[2], self.target_length)
25 | # # print(image_batch.shape)
26 | # return np.array(resize(to_pil_image(image_batch[:]), target_size))
27 |
28 |
29 | # @staticmethod
30 | # def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
31 | # """
32 | # Compute the output size given input size and target long side length.
33 | # """
34 | # scale = long_side_length * 1.0 / max(oldh, oldw)
35 | # newh, neww = oldh * scale, oldw * scale
36 | # neww = int(neww + 0.5)
37 | # newh = int(newh + 0.5)
38 | # return (newh, neww)
39 |
40 |
41 | # if __name__ == '__main__':
42 | # transform = ResizeLongestSideBatch(1024)
43 |
44 | # image_batch = np.ones((2,3,1080,1920))
45 | # transformer_batch = transform.apply_image_batch(image_batch)
46 | # print(transformer_batch.shape)
47 |
48 |
49 |
50 |
51 |
52 | class ResizeLongestSide:
53 | """
54 | Resizes images to the longest side 'target_length', as well as provides
55 | methods for resizing coordinates and boxes. Provides methods for
56 | transforming both numpy array and batched torch tensors.
57 | """
58 |
59 | def __init__(self, target_length: int) -> None:
60 | self.target_length = target_length
61 |
62 | def apply_image(self, image: np.ndarray) -> np.ndarray:
63 | """
64 | Expects a numpy array with shape HxWxC in uint8 format.
65 | """
66 | target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length)
67 | return np.array(resize(to_pil_image(image), target_size))
68 |
69 | def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
70 | """
71 | Expects a numpy array of length 2 in the final dimension. Requires the
72 | original image size in (H, W) format.
73 | """
74 | old_h, old_w = original_size
75 | new_h, new_w = self.get_preprocess_shape(
76 | original_size[0], original_size[1], self.target_length
77 | )
78 | coords = deepcopy(coords).astype(float)
79 | coords[..., 0] = coords[..., 0] * (new_w / old_w)
80 | coords[..., 1] = coords[..., 1] * (new_h / old_h)
81 | return coords
82 |
83 | def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
84 | """
85 | Expects a numpy array shape Bx4. Requires the original image size
86 | in (H, W) format.
87 | """
88 | boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size)
89 | return boxes.reshape(-1, 4)
90 |
91 | def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor:
92 | """
93 | Expects batched images with shape BxCxHxW and float format. This
94 | transformation may not exactly match apply_image. apply_image is
95 | the transformation expected by the model.
96 | """
97 | # Expects an image in BCHW format. May not exactly match apply_image.
98 | target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length)
99 | return F.interpolate(
100 | image, target_size, mode="bilinear", align_corners=False, antialias=True
101 | )
102 |
103 | def apply_coords_torch(
104 | self, coords: torch.Tensor, original_size: Tuple[int, ...]
105 | ) -> torch.Tensor:
106 | """
107 | Expects a torch tensor with length 2 in the last dimension. Requires the
108 | original image size in (H, W) format.
109 | """
110 | old_h, old_w = original_size
111 | new_h, new_w = self.get_preprocess_shape(
112 | original_size[0], original_size[1], self.target_length
113 | )
114 | coords = deepcopy(coords).to(torch.float)
115 | coords[..., 0] = coords[..., 0] * (new_w / old_w)
116 | coords[..., 1] = coords[..., 1] * (new_h / old_h)
117 | return coords
118 |
119 | def apply_boxes_torch(
120 | self, boxes: torch.Tensor, original_size: Tuple[int, ...]
121 | ) -> torch.Tensor:
122 | """
123 | Expects a torch tensor with shape Bx4. Requires the original image
124 | size in (H, W) format.
125 | """
126 | boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size)
127 | return boxes.reshape(-1, 4)
128 |
129 | @staticmethod
130 | def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
131 | """
132 | Compute the output size given input size and target long side length.
133 | """
134 | scale = long_side_length * 1.0 / max(oldh, oldw)
135 | newh, neww = oldh * scale, oldw * scale
136 | neww = int(neww + 0.5)
137 | newh = int(newh + 0.5)
138 | return (newh, neww)
139 |
--------------------------------------------------------------------------------
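A short numeric sketch of ResizeLongestSide: an assumed 720x1280 image is resized so its long side becomes 1024, and one point is remapped into the resized frame:

import numpy as np
from segment_anything.utils.transforms import ResizeLongestSide

transform = ResizeLongestSide(target_length=1024)

# (H, W) = (720, 1280): scale = 1024 / 1280 = 0.8, so the output is (576, 1024)
print(transform.get_preprocess_shape(720, 1280, 1024))  # (576, 1024)

# a point at (x=640, y=360) in the original image lands at (512, 288)
coords = np.array([[640.0, 360.0]])
print(transform.apply_coords(coords, original_size=(720, 1280)))  # [[512. 288.]]
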
/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ["CUDA_VISIBLE_DEVICES"] = '0'
3 |
4 | import logging
5 | import numpy as np
6 | import argparse
7 | import random
8 | import torch.backends.cudnn as cudnn
9 | import torch
10 | import torch.nn.functional as F
11 | from torch.utils.data import DataLoader
12 | import torchvision.transforms as transforms
13 | from importlib import import_module
14 | from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
15 | from datasets.dataset_ufpr_sam import UFPR_ALPR_Dataset, SamTransform, SamTransformTest, collater
16 | from lora_predictor import LoRA_SamPredictor
17 | import cv2
18 | from icecream import ic
19 | from tqdm import tqdm
20 | from scipy.ndimage import zoom
21 | from segment_anything.utils.amg import (
22 | MaskData,
23 | area_from_rle,
24 | batch_iterator,
25 | batched_mask_to_box,
26 | box_xyxy_to_xywh,
27 | build_all_layer_point_grids,
28 | calculate_stability_score,
29 | coco_encode_rle,
30 | generate_crop_boxes,
31 | is_box_near_crop_edge,
32 | mask_to_rle_pytorch,
33 | remove_small_regions,
34 | rle_to_mask,
35 | uncrop_boxes_xyxy,
36 | uncrop_masks,
37 | uncrop_points,
38 | )
39 |
40 | def ap(tp, conf, count):
41 | tp = np.array(tp)
42 | conf = np.array(conf)
43 | i = np.argsort(-conf)
44 | tp, conf = tp[i], conf[i]
45 | n_gt = count
46 | fpc = (1 - tp).cumsum()  # tp and conf are already sorted by descending confidence
47 | tpc = tp.cumsum()
48 | recall_curve = tpc / (n_gt + 1e-16)
49 | precision_curve = tpc / (tpc + fpc)
50 |
51 | ap = compute_ap(precision_curve, recall_curve)
52 | return ap
53 |
54 | def compute_ap(precision, recall):
55 | """ Compute the average precision, given the recall and precision curves.
56 | Code originally from https://github.com/rbgirshick/py-faster-rcnn.
57 | # Arguments
58 | recall: The recall curve (list).
59 | precision: The precision curve (list).
60 | # Returns
61 | The average precision as computed in py-faster-rcnn.
62 | """
63 | # correct AP calculation
64 | # first append sentinel values at the end
65 | mrec = np.concatenate(([0.0], recall, [1.0]))
66 | mpre = np.concatenate(([0.0], precision, [0.0]))
67 |
68 | # compute the precision envelope
69 | for i in range(mpre.size - 1, 0, -1):
70 | mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])
71 |
72 | # to calculate area under PR curve, look for points
73 | # where X axis (recall) changes value
74 | i = np.where(mrec[1:] != mrec[:-1])[0]
75 |
76 | # and sum (\Delta recall) * prec
77 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
78 | return ap
79 |
80 | def iou(a, b):
81 |     # boxes in (x1, y1, x2, y2) format
82 |     left1, top1, right1, bottom1 = a[0], a[1], a[2], a[3]
83 |     left2, top2, right2, bottom2 = b[0], b[1], b[2], b[3]
84 | 
85 |     area1 = (right1 - left1) * (bottom1 - top1)
86 |     area2 = (right2 - left2) * (bottom2 - top2)
87 |     area_sum = area1 + area2
88 | 
89 |     left = max(left1, left2)
90 |     right = min(right1, right2)
91 |     top = max(top1, top2)
92 |     bottom = min(bottom1, bottom2)
93 | 
94 |     if left >= right or top >= bottom:
95 |         return 0
96 |     else:
97 |         inter = (right - left) * (bottom - top)
98 |         return inter / (area_sum - inter)
99 |
100 |
101 | def mask2bbox(mask, is_gt):
102 | # pred: w, h | label: w, h
103 | if isinstance(mask, torch.Tensor):
104 | mask = mask.cpu().detach().numpy().astype(np.uint8)
105 | elif isinstance(mask, np.ndarray):
106 | mask = mask.astype(np.uint8)
107 | kernel = np.ones((3,3), np.uint8)
108 | mask = cv2.erode(mask, kernel, iterations=2)
109 | mask = cv2.dilate(mask, kernel, iterations=3)
110 | contours, hierarchy = cv2.findContours(
111 | mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
112 | )
113 |
114 | bboxes_list = []
115 | max_w, max_h = 0, 0
116 |
117 | for cont in contours:
118 | x1, y1, w, h = cv2.boundingRect(cont)
119 | x2, y2 = x1+w, y1+h
120 | bboxes_list.append([x1, y1, x2, y2])
121 |
122 | return bboxes_list
123 |
124 | TP = 0
125 | FP = 0
126 | FN = 0
127 | tp_list = []
128 | conf_list = []
129 | gt_count = 0
130 | pred_count = 0
131 |
132 | def evaluation(pred, label, mask_iou):
133 | global TP
134 | global FP
135 | global FN
136 | global tp_list
137 | global conf_list
138 | global gt_count
139 | global pred_count
140 |
141 | pred_bboxes = mask2bbox(pred, False)
142 | label_bboxes = mask2bbox(label, True)
143 |
144 | gt_count += len(label_bboxes)
145 | pred_count += len(pred_bboxes)
146 |
147 | if len(pred_bboxes) == 0:
148 | FN += 1
149 | else:
150 | for gt in label_bboxes:
151 | is_true = False
152 | for pred in pred_bboxes:
153 | # print(iou(pred, gt))
154 | if iou(pred, gt) >= 0.5:
155 | is_true = True
156 | if is_true:
157 | TP += 1
158 | tp_list.append(1.0)
159 | conf_list.append(mask_iou.item())
160 | else:
161 | FP += 1
162 | tp_list.append(0.0)
163 | conf_list.append(mask_iou.item())
164 | return pred_bboxes, label_bboxes
165 |
166 |
167 | def inference(args, multimask_output, predictor, test_save_path):
168 | # testset = UFPR_ALPR_Dataset(root=args.root_path, split='testing', transform=transforms.Compose([Resizer([args.img_size, args.img_size])]))
169 | testset = UFPR_ALPR_Dataset(root=args.root_path, split='testing', transform=SamTransformTest(1024))
170 |
171 | testloader = DataLoader(testset, batch_size=1, shuffle=False, num_workers=2, collate_fn=collater, pin_memory=True)
172 | # trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2, collate_fn=collater, pin_memory=True, worker_init_fn=worker_init_fn)
173 |
174 | logging.info(f'{len(testloader)} test iterations per epoch')
175 | predictor.model.eval()
176 |
177 | for i_batch, sample_batch in tqdm(enumerate(testloader)):
178 | # print(sample_batch.keys())
179 |
180 | with torch.no_grad():
181 | image, label = sample_batch['image'].cuda(), sample_batch['label'].cuda()
182 |
183 | show_image = image.squeeze(0) * predictor.pixel_std.cuda() + predictor.pixel_mean.cuda()
184 | # h, w = image.shape[2], image.shape[3]
185 | label = label.unsqueeze(0).unsqueeze(1)
186 |
187 | label = predictor.model.sam.postprocess_masks(label, predictor.input_size, predictor.original_size).squeeze().detach().cpu().numpy()
188 |
189 | masks, iou_predictions, low_res_masks = predictor.forward_test(image, multimask_output)
190 | masks = masks.squeeze()
191 | iou_predictions = iou_predictions.squeeze()
192 | best_idx = torch.argmax(iou_predictions)
193 | 
194 | # keep the mask with the highest predicted IoU and its
195 | # confidence score for the AP computation
196 | mask_iou = iou_predictions[best_idx]
197 | mask = masks[best_idx].squeeze().detach().cpu().numpy()
198 | 
199 |
200 | min_area = 2500
201 | mask, _ = remove_small_regions(mask, min_area, 'islands')
202 | mask, _ = remove_small_regions(mask, min_area, 'holes')
203 |
204 | pred_bboxes, label_bboxes = evaluation(mask, label, mask_iou)
205 |
206 | image_np = predictor.model.sam.postprocess_masks(show_image.clone().unsqueeze(0), predictor.input_size, predictor.original_size).squeeze().permute(1,2,0).detach().cpu().numpy()
207 | # image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
208 | show_mask = np.expand_dims(mask.copy(), axis=2).astype(np.uint8)
209 | show_mask = cv2.cvtColor(show_mask, cv2.COLOR_GRAY2BGR)
210 | results = cv2.addWeighted(image_np, 1.0, show_mask*255, 0.5, 0, 0, cv2.CV_32F)
211 |
212 | if label_bboxes is not None:
213 | for gt in label_bboxes:
214 | cv2.rectangle(results, (int(gt[0]),int(gt[1])), (int(gt[2]),int(gt[3])), color=(255,0,0), thickness=2)
215 | if pred_bboxes is not None:
216 | for pred in pred_bboxes:
217 | cv2.rectangle(results, (int(pred[0]),int(pred[1])), (int(pred[2]),int(pred[3])), color=(0,255,0), thickness=2)
218 |
219 | cv2.imwrite(os.path.join(test_save_path, '{}.png'.format(i_batch)), results)
220 |
221 | P = TP / (pred_count + 1e-16)
222 | R = TP / (gt_count + 1e-16)
223 | F1 = 2 * P * R / (P + R + 1e-16)
224 | AP50 = ap(tp_list, conf_list, gt_count)
225 |
226 | print('P: {:.4f}\t'.format(P),
227 | 'R: {:.4f}\t'.format(R),
228 | 'F1: {:.4f}\t'.format(F1),
229 | 'AP50: {:.4f}\t'.format(AP50))
230 | # return P, R, F1, AP50
231 |
232 |
233 | if __name__=='__main__':
234 | parser = argparse.ArgumentParser()
235 | parser.add_argument('--root_path', type=str, default='/media/disk1/yxding/dhx/Dataset/UFPR-ALPR/')
236 | parser.add_argument('--dataset', type=str, default='UFPR')
237 | parser.add_argument('--num_classes', type=int, default=1)
238 | parser.add_argument('--img_size', type=int, default=1024)
239 | parser.add_argument('--seed', type=int, default=0)
240 | parser.add_argument('--save_image', action='store_true')
241 | parser.add_argument('--deterministic', type=int, default=1)
242 | parser.add_argument('--ckpt', type=str, default='./checkpoints/sam_vit_b_01ec64.pth')
243 | parser.add_argument('--lora_ckpt', type=str,
244 | default="/media/disk1/yxding/dhx/Project/LP_SAM/LoRA_LP/exp/refine/UFPR_1024_2023-08-14-12:47:59_vit_b_sam_lora_image_encoder_mask_decoder_cls1_epo160_bs1_lr0.0005_seed0/epoch_90.pth")
245 | parser.add_argument('--vit_name', type=str, default='vit_b')
246 | parser.add_argument('--rank', type=int, default=4)
247 | parser.add_argument('--module', type=str, default='sam_lora_image_encoder_mask_decoder')
248 |
249 | args = parser.parse_args()
250 |
251 | if not args.deterministic:
252 | cudnn.benchmark = True
253 | cudnn.deterministic = False
254 | else:
255 | cudnn.benchmark = False
256 | cudnn.deterministic = True
257 |
258 | random.seed(args.seed)
259 | np.random.seed(args.seed)
260 | torch.manual_seed(args.seed)
261 | torch.cuda.manual_seed(args.seed)
262 | dataset_name = args.dataset
263 | dataset_config = {
264 | 'UFPR': {
265 | 'root_path': args.root_path,
266 | 'num_classes': args.num_classes,
267 | }
268 | }
269 |
270 | load_ckpt_path = args.lora_ckpt
271 | output_dir = os.path.join(os.path.split(load_ckpt_path)[0], 'predictions_predictor')
272 | if not os.path.exists(output_dir):
273 | os.makedirs(output_dir)
274 |
275 | sam = sam_model_registry[args.vit_name](checkpoint=args.ckpt)
276 |
277 | pkg = import_module(args.module)
278 | net = pkg.LoRA_Sam(sam, args.rank).cuda()
279 |
280 | predictor = LoRA_SamPredictor(net)
281 |
282 | assert args.lora_ckpt is not None
283 | predictor.model.load_lora_parameters(args.lora_ckpt)
284 |
285 | multimask_output = True
286 | print(os.path.split(load_ckpt_path)[1])
287 | inference(args, multimask_output, predictor, output_dir)
288 |
289 |
290 |
--------------------------------------------------------------------------------
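A quick hand-checked example for the two metric helpers above; it assumes test.py is importable from the repository root, and the boxes and the toy precision/recall curve are made-up numbers:

from test import compute_ap, iou  # assumes the repo root is first on sys.path

# two boxes in XYXY format overlapping on a 5x5 patch:
# intersection = 25, union = 100 + 100 - 25 = 175
print(iou([0, 0, 10, 10], [5, 5, 15, 15]))  # 0.142857...

# a toy PR curve: precision 1.0 up to recall 0.5, then 0.5 up to recall 1.0
precision = [1.0, 1.0, 0.5, 0.5]
recall = [0.25, 0.5, 0.75, 1.0]
print(compute_ap(precision, recall))  # 0.75
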
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | import os
4 |
5 | os.environ["CUDA_VISIBLE_DEVICES"] = '1'
6 | import random
7 | import numpy as np
8 | import torch
9 | import torch.backends.cudnn as cudnn
10 |
11 | from lora_predictor import LoRA_SamPredictor
12 |
13 | from importlib import import_module
14 |
15 | from segment_anything import sam_model_registry
16 | from trainer import trainer_UFPR
17 | from icecream import ic
18 | import time
19 |
20 | parser = argparse.ArgumentParser()
21 | parser.add_argument('--root_path', type=str, default='/media/disk1/yxding/dhx/Dataset/UFPR-ALPR/', help='root dir for data')
22 | parser.add_argument('--output', type=str, default='./exp/')
23 | parser.add_argument('--dataset', type=str, default='UFPR')
24 | parser.add_argument('--num_classes', type=int, default=1, help='output channel of network')
25 | parser.add_argument('--max_epochs', type=int, default=10)
26 | parser.add_argument('--batch_size', type=int, default=2)
27 | parser.add_argument('--n_gpu', type=int, default=1)
28 | parser.add_argument('--deterministic', type=int, default=1)
29 | parser.add_argument('--base_lr', type=float, default=0.005)
30 | parser.add_argument('--img_size', type=int, default=1024)
31 | parser.add_argument('--seed', type=int, default=0)
32 | parser.add_argument('--vit_name', type=str, default='vit_b')
33 | parser.add_argument('--ckpt', type=str, default='./checkpoints/sam_vit_b_01ec64.pth')
34 | parser.add_argument('--lora_ckpt', type=str, default=None)
35 | parser.add_argument('--rank', type=int, default=2)
36 | parser.add_argument('--warmup', action='store_true')
37 | parser.add_argument('--warmup_period', type=int, default=250)
38 | parser.add_argument('--AdamW', action='store_true')
39 | parser.add_argument('--module', type=str, default='sam_lora_image_encoder_mask_decoder')
40 | parser.add_argument('--dice_param', type=float, default=0.8)
41 | args = parser.parse_args()
42 |
43 | if __name__=='__main__':
44 | if not args.deterministic:
45 | cudnn.benchmark = True
46 | cudnn.deterministic = False
47 | else:
48 | cudnn.benchmark = False
49 | cudnn.deterministic = True
50 |
51 | random.seed(args.seed)
52 | np.random.seed(args.seed)
53 | torch.manual_seed(args.seed)
54 | torch.cuda.manual_seed(args.seed)
55 |
56 | dataset_name = args.dataset
57 | dataset_config = {
58 | 'UFPR': {
59 | 'root_path': args.root_path,
60 | 'num_classes': args.num_classes,
61 | }
62 | }
63 |
64 | args.is_pretrain = True
65 | args.exp = dataset_name + '_' + str(args.img_size)
66 | log_path = os.path.join(args.output, "{}".format(args.exp))
67 | time_str = time.strftime('_%Y-%m-%d-%H:%M:%S', time.localtime())
68 | log_path = log_path + time_str
69 | log_path = log_path + '_' + args.vit_name
70 | log_path = log_path + '_' + str(args.module)
71 | log_path = log_path + '_cls' + str(args.num_classes)
72 | log_path = log_path + '_epo' + str(args.max_epochs)
73 | log_path = log_path + '_bs' + str(args.batch_size)
74 | log_path = log_path + '_lr' + str(args.base_lr)
75 | log_path = log_path + '_seed' + str(args.seed)
76 | log_path = log_path + '_rank' + str(args.rank)
77 |
78 | if not os.path.exists(log_path):
79 | os.makedirs(log_path)
80 |
81 | sam = sam_model_registry[args.vit_name](checkpoint=args.ckpt)
82 |
83 | pkg = import_module(args.module)
84 | net = pkg.LoRA_Sam(sam, args.rank).cuda()
85 |
86 | predictor = LoRA_SamPredictor(net)
87 |
88 | if args.lora_ckpt is not None:
89 | predictor.model.load_lora_parameters(args.lora_ckpt)
90 |
91 | multimask_output = True
92 |
93 | trainer = {'UFPR': trainer_UFPR}
94 | trainer[dataset_name](args, predictor, log_path, multimask_output)
95 |
96 |
97 |
98 |
99 |
100 |
101 |
102 |
103 |
104 |
105 |
--------------------------------------------------------------------------------
/trainer.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import sys
4 | import random
5 | import time
6 | import math
7 | import numpy as np
8 | import torch
9 | import torch.nn as nn
10 | import torch.optim as optim
11 |
12 | from tensorboardX import SummaryWriter
13 | from torch.nn.modules.loss import CrossEntropyLoss, MSELoss
14 | from torch.utils.data import DataLoader
15 | import torch.nn.functional as F
16 | from tqdm import tqdm
17 | from utils import Focal_loss, DiceLoss, DiceLoss_softmax
18 | from torchvision import transforms
19 | from icecream import ic
20 | import cv2
21 |
22 | # from datasets.dataset_ufpr_cls2 import UFPR_ALPR_Dataset, RandomGenerator
23 | from datasets.dataset_ufpr_sam import UFPR_ALPR_Dataset, SamTransform, collater
24 |
25 | def calc_loss(outputs_logits, low_res_label_batch, dice_loss, dice_weight:float=0.8):
26 | loss_dice = dice_loss(outputs_logits, low_res_label_batch, softmax=True)
27 | return loss_dice
28 |
29 |
30 | def trainer_UFPR(args, predictor, log_path, multimask_output):
31 | logging.basicConfig(filename=os.path.join(log_path,'log.txt'), level=logging.INFO,
32 | format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S')
33 | logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
34 | logging.info(str(args))
35 |
36 | base_lr = args.base_lr
37 | num_classes = args.num_classes
38 | batch_size = args.batch_size * args.n_gpu
39 | trainset = UFPR_ALPR_Dataset(root=args.root_path, split='training', transform=SamTransform(1024))
40 | print("The length of train set is: {}".format(len(trainset)))
41 |
42 | curr_epoch = 0
43 | if args.lora_ckpt is not None:
44 | curr_epoch = os.path.split(args.lora_ckpt)[1]
45 | curr_epoch = int(curr_epoch.replace('epoch_','').replace('.pth',''))
46 |
47 |
48 | def worker_init_fn(worker_id):
49 | random.seed(args.seed + worker_id)
50 |
51 | trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2, collate_fn=collater, pin_memory=True, worker_init_fn=worker_init_fn)
52 |
53 | predictor.model.train()
54 |
55 | dice_loss = DiceLoss_softmax()
56 |
57 | if args.warmup:
58 | b_lr = base_lr / args.warmup_period
59 | else:
60 | b_lr = base_lr
61 |
62 | if args.AdamW:
63 | optimizer = optim.AdamW(filter(lambda p: p.requires_grad, predictor.model.parameters()), lr=b_lr, betas=(0.9, 0.999), weight_decay=0.1)
64 | else:
65 | optimizer = optim.SGD(filter(lambda p: p.requires_grad, predictor.model.parameters()), lr=b_lr, momentum=0.9, weight_decay=0.0001)
66 |
67 | writer = SummaryWriter(os.path.join(log_path, 'log'))
68 | iter_num = 0
69 | max_epoch = args.max_epochs
70 | max_iterations = args.max_epochs * len(trainloader)
71 | logging.info("{} iterations per epoch. {} max iterations ".format(len(trainloader), max_iterations))
72 |
73 |
74 | iterator = tqdm(range(curr_epoch, max_epoch), ncols=70)
75 |
76 | for epoch_num in iterator:
77 | for i_batch, sampled_batch in enumerate(trainloader):
78 | image_batch, label_batch = sampled_batch['image'], sampled_batch['label']
79 | low_res_label_batch = sampled_batch['low_res_label']
80 | image_batch, label_batch = image_batch.cuda(), label_batch.cuda()
81 | low_res_label_batch = low_res_label_batch.cuda()
82 | outputs = predictor.forward(image_batch, multimask_output)
83 |
84 | loss_level0 = calc_loss(outputs[:,0,:,:], low_res_label_batch, dice_loss, args.dice_param)
85 | loss_level1 = calc_loss(outputs[:,1,:,:], low_res_label_batch, dice_loss, args.dice_param)
86 | loss_level2 = calc_loss(outputs[:,2,:,:], low_res_label_batch, dice_loss, args.dice_param)
87 |
88 | loss = 1/3 * (loss_level0 + loss_level1 + loss_level2)
89 |
90 | optimizer.zero_grad()
91 | loss.backward()
92 | optimizer.step()
93 |
94 | if args.warmup and iter_num < args.warmup_period:
95 | lr_ = base_lr * ((iter_num + 1) / args.warmup_period)
96 | for param_group in optimizer.param_groups:
97 | param_group['lr'] = lr_
98 | else:
99 | if args.warmup:
100 | shift_iter = iter_num - args.warmup_period
101 | assert shift_iter >= 0, f'Shift iter is {shift_iter}, smaller than zero'
102 | else:
103 | shift_iter = iter_num
104 | lr_ = base_lr * (1.0 - shift_iter / max_iterations) ** 0.9
105 | for param_group in optimizer.param_groups:
106 | param_group['lr'] = lr_
107 |
108 | iter_num = iter_num + 1
109 | writer.add_scalar('info/lr', lr_, iter_num)
110 | writer.add_scalar('info/total_loss', loss, iter_num)
111 |
112 |
113 | logging.info('iteration %d : loss : %f, max_output : %f' % (iter_num, loss.item(), torch.max(outputs).item()))
114 |
115 | if iter_num % 20 == 0:
116 | image = image_batch[1, :, :, :]
117 | image = (image - image.min()) / (image.max() - image.min())
118 | writer.add_image('train/image', image, iter_num)
119 | writer.add_image('train/pred_level0', (outputs[1, 0, ...]>0).unsqueeze(0) * 50, iter_num)
120 | writer.add_image('train/pred_level1', (outputs[1, 1, ...]>0).unsqueeze(0) * 50, iter_num)
121 | writer.add_image('train/pred_level2', (outputs[1, 2, ...]>0).unsqueeze(0) * 50, iter_num)
122 | labs = low_res_label_batch[1, ...].unsqueeze(0) * 50
123 | writer.add_image('train/gt', labs, iter_num)
124 |
125 | save_interval = 1
126 | if (epoch_num + 1) % save_interval == 0:
127 | save_model_path = os.path.join(log_path, 'epoch_'+str(epoch_num+1)+'.pth')
128 | try:
129 | predictor.model.save_lora_parameters(save_model_path)
130 | except:
131 | predictor.model.module.save_lora_parameters(save_model_path)
132 | logging.info('save model to {}'.format(save_model_path))
133 |
134 | if epoch_num >= max_epoch - 1:
135 | save_model_path = os.path.join(log_path, 'epoch_'+str(epoch_num+1)+'.pth')
136 | try:
137 | predictor.model.save_lora_parameters(save_model_path)
138 | except:
139 | predictor.model.module.save_lora_parameters(save_model_path)
140 | logging.info('save model to {}'.format(save_model_path))
141 | iterator.close()
142 | break
143 |
144 | writer.close()
145 | return 'Training Finished!'
146 |
147 |
--------------------------------------------------------------------------------
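The learning-rate schedule in trainer_UFPR (linear warmup, then polynomial decay with power 0.9) can be summarized as a standalone helper; this is a sketch of the same arithmetic with an assumed 18000 total iterations, not code imported from the repository:

def lr_at_iter(iter_num, base_lr, max_iterations, warmup=True, warmup_period=250):
    """Linear warmup for warmup_period iterations, then (1 - t) ** 0.9 decay."""
    if warmup and iter_num < warmup_period:
        return base_lr * (iter_num + 1) / warmup_period
    shift_iter = iter_num - warmup_period if warmup else iter_num
    return base_lr * (1.0 - shift_iter / max_iterations) ** 0.9

# e.g. base_lr=0.005 as in train.py, 250 warmup steps, 18000 total iterations
print(lr_at_iter(0, 0.005, 18000))     # 2e-05   (start of warmup)
print(lr_at_iter(249, 0.005, 18000))   # 0.005   (end of warmup)
print(lr_at_iter(9250, 0.005, 18000))  # ~0.0027 (halfway through the decay)
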
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Any
3 | import numpy as np
4 | from scipy.ndimage import zoom
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | import imageio
9 | from einops import repeat
10 | from icecream import ic
11 |
12 | class Focal_loss(nn.Module):
13 | def __init__(self, alpha=0.25, gamma=2, num_classes=3, size_average=True):
14 | super(Focal_loss, self).__init__()
15 | self.size_average = size_average
16 | if isinstance(alpha, list):
17 | assert len(alpha) == num_classes
18 | print(f'Focal loss alpha={alpha}, will assign alpha values for each class')
19 | self.alpha = torch.Tensor(alpha)
20 | else:
21 | assert alpha < 1
22 | print(f'Focal loss alpha={alpha}, will shrink the impact in background')
23 | self.alpha = torch.zeros(num_classes)
24 | self.alpha[0] = alpha
25 | self.alpha[1:] = 1 - alpha
26 | self.gamma = gamma
27 | self.num_classes = num_classes
28 |
29 | def forward(self, preds, labels):
30 | self.alpha = self.alpha.to(preds.device)
31 | preds = preds.permute(0, 2, 3, 1).contiguous()
32 | preds = preds.view(-1, preds.size(-1))
33 | B, H, W = labels.shape
34 | assert B * W * H == preds.shape[0]
35 | assert preds.shape[-1] == self.num_classes
36 | preds_logsoft = F.log_softmax(preds, dim=1)
37 | preds_softmax = torch.exp(preds_logsoft)
38 |
39 | preds_softmax = preds_softmax.gather(1, labels.view(-1, 1))
40 | preds_logsoft = preds_logsoft.gather(1, labels.view(-1, 1))
41 | alpha = self.alpha.gather(0, labels.view(-1))
42 | loss = -torch.mul(torch.pow((1 - preds_softmax), self.gamma), preds_logsoft)
43 |
44 | loss = torch.mul(alpha, loss.t())
45 |
46 | if self.size_average:
47 | loss = loss.mean()
48 | else:
49 | loss = loss.sum()
50 | return loss
51 |
52 | # class focal_loss()
53 |
54 | class DiceLoss_softmax(nn.Module):
55 | def __init__(self):
56 | super(DiceLoss_softmax, self).__init__()
57 | self.num_classes = 2
58 |
59 | def _one_hot_encoder(self, input_tensor):
60 | tensor_list = []
61 | for i in range(self.num_classes):
62 | temp_prob = input_tensor == i
63 | tensor_list.append(temp_prob.unsqueeze(1))
64 | output_tensor = torch.cat(tensor_list, dim=1)
65 | return output_tensor.float()
66 |
67 | def _dice_loss(self, score, target):
68 | target = target.float()
69 | smooth = 1e-5
70 | intersect = torch.sum(score * target)
71 | y_sum = torch.sum(target * target)
72 | z_sum = torch.sum(score * score)
73 | loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth)
74 | loss = 1 - loss
75 | return loss
76 |
77 | def forward(self, inputs, target, softmax=False):
78 | inputs = inputs.unsqueeze(1)
79 | inputs = torch.concat([torch.ones_like(inputs)-inputs, inputs], dim=1)
80 | if softmax:
81 | inputs = torch.softmax(inputs, dim=1)
82 | target = self._one_hot_encoder(target)
83 |
84 | assert inputs.size() == target.size(), 'predict {} & target {} shape do not match'.format(inputs.size(), target.size())
85 | loss = 0.0
86 | for i in range(0, self.num_classes):
87 | dice = self._dice_loss(inputs[:, i], target[:, i])
88 | loss += dice
89 | return loss #/ self.num_classes
90 |
91 |
92 |
93 | class DiceLoss(nn.Module):
94 | def __init__(self):
95 | super(DiceLoss, self).__init__()
96 |
97 | def forward(self, inputs, target, smooth=1e-5):
98 | inputs = torch.sigmoid(inputs)
99 | inputs = inputs.view(-1)
100 | target = target.view(-1)
101 | intersection = (inputs * target).sum()
102 | dice = (2.*intersection + smooth)/(inputs.sum() + target.sum() + smooth)
103 | return 1 - dice
104 |
--------------------------------------------------------------------------------
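A minimal sketch of how the two dice losses above are called; the tensors are random and the 256x256 resolution is only an assumption matching SAM's low-resolution mask size:

import torch
from utils import DiceLoss, DiceLoss_softmax  # assumes the repo root is on sys.path

logits = torch.randn(2, 256, 256)            # per-pixel foreground logits (B, H, W)
labels = torch.randint(0, 2, (2, 256, 256))  # binary ground-truth masks

dice_softmax = DiceLoss_softmax()
print(dice_softmax(logits, labels, softmax=True))  # scalar tensor, lower is better

dice = DiceLoss()
print(dice(torch.randn(2, 1, 256, 256), torch.randint(0, 2, (2, 1, 256, 256)).float()))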