├── LICENSE ├── README.md ├── SAM_counting_anything__ArXiv_.pdf ├── app.py ├── automatic_mask_generator.py ├── env.yaml ├── example.png ├── requirements.txt ├── resultFSC.png ├── resultcoco.png ├── test_FSC.py ├── test_coco.py ├── vis_FSC.ipynb └── vis_coco.ipynb /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Vision-Intelligence-and-Robots-Group 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # count-anything 2 | An empirical study on few-shot counting using Segment Anything (SAM) 3 | 4 | [Online Demo](https://huggingface.co/spaces/nebula/counting-anything) 5 | 6 | [arXiv](http://arxiv.org/abs/2304.10817) 7 | 8 | Meta AI recently released the Segment Anything model [[SAM]](https://github.com/facebookresearch/segment-anything), which has garnered attention due to its impressive performance in class-agnostic segmentation. In this study, we explore the use of SAM for the challenging task of few-shot object counting, which involves counting objects of an unseen category given only a few exemplar bounding boxes. We compare SAM's performance with other few-shot counting methods and find that it is currently unsatisfactory without further fine-tuning, particularly for small and crowded objects. 9 | 10 | ![image](example.png) 11 | ## Install 12 | Install the Python dependencies. We use conda with Python 3.10.4; the provided env.yaml installs PyTorch 2.0.0: 13 | > conda env create -f env.yaml 14 | 15 | ## Dataset preparation 16 | - For FSC-147: 17 | Images can be downloaded from here: https://drive.google.com/file/d/1ymDYrGs9DSRicfZbSCDiOu0ikGDh5k6S/view?usp=sharing 18 | 19 | - For COCO val2017: 20 | Images can be downloaded from here: https://cocodataset.org/ 21 | ## Comparison Results 22 | 23 | ### FSC 24 | 25 | ![image](resultFSC.png) 26 | 27 | ### COCO 28 | 29 | ![image](resultcoco.png) 30 | ## Test 31 | Download the [ViT-H SAM model](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth). 32 | 33 | - For FSC-147: 34 | ``` 35 | python test_FSC.py --data_path <path_to_FSC147> --model_path <path_to_sam_vit_h_4b8939.pth> 36 | ``` 37 | 38 | - For COCO val2017: 39 | ``` 40 | python test_coco.py --data_path <path_to_COCO_val2017> --model_path <path_to_sam_vit_h_4b8939.pth> 41 | ``` 42 | 43 | ## Visualize 44 | You can run [vis_FSC.ipynb](vis_FSC.ipynb) for FSC-147 or [vis_coco.ipynb](vis_coco.ipynb) for COCO.
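## Minimal usage example
For programmatic use, the sketch below mirrors how `app.py` and `test_FSC.py` drive the modified `SamAutomaticMaskGenerator` in this repo: a few exemplar boxes are passed to `generate`, and the number of returned mask records is taken as the predicted count. The checkpoint path, image path, and box coordinates are placeholders.
```
import cv2
from segment_anything import sam_model_registry
from automatic_mask_generator import SamAutomaticMaskGenerator  # repo-local generator

# Load the ViT-H SAM checkpoint (placeholder path) and build the exemplar-guided generator.
sam = sam_model_registry['vit_h'](checkpoint='./sam_vit_h_4b8939.pth')
sam.to(device='cuda')
mask_generator = SamAutomaticMaskGenerator(model=sam, min_mask_region_area=25)

# Read an image and convert BGR -> RGB, as the test scripts do.
image = cv2.cvtColor(cv2.imread('image.jpg'), cv2.COLOR_BGR2RGB)

# A few exemplar boxes of the target category, in [x1, y1, x2, y2] pixel coordinates.
exemplar_boxes = [[10, 20, 60, 80], [100, 40, 150, 95]]

masks = mask_generator.generate(image, exemplar_boxes)
print('Count:', len(masks))  # each returned mask record counts as one object
```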
45 | 46 | ## Acknowledgement 47 | We thank facebookresearch for their segment-anything model [[project]](https://github.com/facebookresearch/segment-anything), cvlab-stonybrook for their Learning To Count Everything [[project]](https://github.com/cvlab-stonybrook/LearningToCountEverything) and coco [[datasets]](https://cocodataset.org/). 48 | 49 | ## Citation 50 | If you find the code useful, please cite: 51 | ``` 52 | @article{ma2023countanything, 53 | title={CAN SAM COUNT ANYTHING? AN EMPIRICAL STUDY ON SAM COUNTING}, 54 | author={Ma, Zhiheng and Hong, Xiaopeng and Shangguan Qinnan}, 55 | journal={arXiv preprint arXiv:2304.10817}, 56 | year={2023} 57 | } 58 | ``` 59 | -------------------------------------------------------------------------------- /SAM_counting_anything__ArXiv_.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vision-Intelligence-and-Robots-Group/count-anything/2756753f529a159893718ff99c00b0299f9c36a2/SAM_counting_anything__ArXiv_.pdf -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | from PIL import Image, ImageDraw 2 | import cv2 3 | import gradio as gr 4 | import torch 5 | from segment_anything import sam_model_registry 6 | from automatic_mask_generator import SamAutomaticMaskGenerator 7 | 8 | device = 'cuda' 9 | sam = sam_model_registry['vit_h'](checkpoint='./sam_vit_h_4b8939.pth') 10 | sam.to(device=device) 11 | 12 | 13 | mask_generator = SamAutomaticMaskGenerator( 14 | model=sam, 15 | min_mask_region_area=25 16 | ) 17 | 18 | def binarize(x): 19 | return (x != 0).astype('uint8') * 255 20 | 21 | def draw_box(boxes=[], img=None): 22 | if len(boxes) == 0 and img is None: 23 | return None 24 | 25 | if img is None: 26 | img = Image.new('RGB', (512, 512), (255, 255, 255)) 27 | colors = ["red", "olive", "blue", "green", "orange", "brown", "cyan", "purple"] 28 | draw = ImageDraw.Draw(img) 29 | # print(boxes) 30 | for bid, box in enumerate(boxes): 31 | draw.rectangle([box[0], box[1], box[2], box[3]], outline=colors[bid % len(colors)], width=4) 32 | return img 33 | 34 | 35 | def draw_pred_box(boxes=[], img=None): 36 | if len(boxes) == 0 and img is None: 37 | return None 38 | 39 | if img is None: 40 | img = Image.new('RGB', (512, 512), (255, 255, 255)) 41 | colors = "green" 42 | draw = ImageDraw.Draw(img) 43 | # print(boxes) 44 | for bid, box in enumerate(boxes): 45 | draw.rectangle([box[0], box[1], box[2], box[3]], outline=colors, width=4) 46 | return img 47 | 48 | 49 | def debug(input_img): 50 | mask = input_img["mask"] 51 | mask = mask[..., 0] 52 | contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) 53 | 54 | boxes = [] 55 | for contour in contours: 56 | y1, y2 = contour[:, 0, 1].min(), contour[:, 0, 1].max() 57 | x1, x2 = contour[:, 0, 0].min(), contour[:, 0, 0].max() 58 | boxes.append([x1, y1, x2, y2]) 59 | draw_image = draw_box(boxes, Image.fromarray(input_img["image"])) 60 | 61 | masks = mask_generator.generate(input_img["image"], boxes) 62 | pred_cnt = len(masks) 63 | pred_bboxes = [] 64 | for i in masks: 65 | x0, y0, w, h = i['bbox'] 66 | pred_bboxes.append([x0, y0, x0+w, y0+h]) 67 | pred_image = draw_pred_box(pred_bboxes, Image.fromarray(input_img["image"])) 68 | return [draw_image, pred_image, "Count: {}".format(pred_cnt)] 69 | 70 | description = """

71 | Count Anything 72 |
73 | 74 | [Project Page] 75 | [Paper] 76 | [GitHub] 77 | 78 |

79 | """ 80 | 81 | run = gr.Interface( 82 | debug, 83 | gr.Image(shape=[512, 512], source="upload", tool="sketch").style(height=500, width=500), 84 | [gr.Image(), gr.Image(), gr.Text()], 85 | description = description 86 | ) 87 | 88 | run.launch() -------------------------------------------------------------------------------- /automatic_mask_generator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torchvision.ops.boxes import batched_nms, box_area # type: ignore 4 | 5 | from typing import Any, Dict, List, Optional, Tuple 6 | import torch.nn.functional as F 7 | from collections import defaultdict 8 | 9 | from segment_anything.modeling import Sam 10 | from segment_anything.predictor import SamPredictor 11 | from segment_anything.utils.amg import ( 12 | MaskData, 13 | area_from_rle, 14 | batch_iterator, 15 | batched_mask_to_box, 16 | box_xyxy_to_xywh, 17 | build_all_layer_point_grids, 18 | calculate_stability_score, 19 | coco_encode_rle, 20 | generate_crop_boxes, 21 | is_box_near_crop_edge, 22 | mask_to_rle_pytorch, 23 | remove_small_regions, 24 | rle_to_mask, 25 | uncrop_boxes_xyxy, 26 | uncrop_masks, 27 | uncrop_points, 28 | ) 29 | 30 | 31 | def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor: 32 | x0, y0, _, _ = crop_box 33 | offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device) 34 | # Check if boxes has a channel dimension 35 | if len(boxes.shape) == 3: 36 | offset = offset.unsqueeze(1) 37 | return boxes + offset 38 | 39 | def pre_process_ref_box(ref_box, crop_box, layer_idx): 40 | if layer_idx == 0: 41 | return ref_box 42 | else: 43 | new_bbox = [] 44 | x0, y0, x1, y1 = crop_box 45 | for ref in ref_box: 46 | x0_r, y0_r, x1_r, y1_r = ref 47 | area = (y1_r - y0_r) * (x1_r - x0_r) 48 | x_0_new = max(x0, x0_r) 49 | y_0_new = max(y0, y0_r) 50 | x_1_new = min(x1, x1_r) 51 | y_1_new = min(y1, y1_r) 52 | crop_area = (y_1_new - y_0_new) * (x_1_new - x_0_new) 53 | if crop_area / area > 0.7: 54 | new_bbox.append([x_0_new, y_0_new, x_1_new, y_1_new]) 55 | 56 | return new_bbox 57 | 58 | 59 | 60 | 61 | class SamAutomaticMaskGenerator: 62 | def __init__( 63 | self, 64 | model: Sam, 65 | points_per_side: Optional[int] = 32, 66 | points_per_batch: int = 64, 67 | pred_iou_thresh: float = 0.88, 68 | stability_score_thresh: float = 0.95, 69 | stability_score_offset: float = 1.0, 70 | box_nms_thresh: float = 0.7, 71 | crop_n_layers: int = 0, 72 | crop_nms_thresh: float = 0.7, 73 | crop_overlap_ratio: float = 512 / 1500, 74 | crop_n_points_downscale_factor: int = 1, 75 | point_grids: Optional[List[np.ndarray]] = None, 76 | min_mask_region_area: int = 0, 77 | output_mode: str = "binary_mask", 78 | ) -> None: 79 | """ 80 | Using a SAM model, generates masks for the entire image. 81 | Generates a grid of point prompts over the image, then filters 82 | low quality and duplicate masks. The default settings are chosen 83 | for SAM with a ViT-H backbone. 84 | 85 | Arguments: 86 | model (Sam): The SAM model to use for mask prediction. 87 | points_per_side (int or None): The number of points to be sampled 88 | along one side of the image. The total number of points is 89 | points_per_side**2. If None, 'point_grids' must provide explicit 90 | point sampling. 91 | points_per_batch (int): Sets the number of points run simultaneously 92 | by the model. Higher numbers may be faster but use more GPU memory. 
93 | pred_iou_thresh (float): A filtering threshold in [0,1], using the 94 | model's predicted mask quality. 95 | stability_score_thresh (float): A filtering threshold in [0,1], using 96 | the stability of the mask under changes to the cutoff used to binarize 97 | the model's mask predictions. 98 | stability_score_offset (float): The amount to shift the cutoff when 99 | calculated the stability score. 100 | box_nms_thresh (float): The box IoU cutoff used by non-maximal 101 | suppression to filter duplicate masks. 102 | crops_n_layers (int): If >0, mask prediction will be run again on 103 | crops of the image. Sets the number of layers to run, where each 104 | layer has 2**i_layer number of image crops. 105 | crops_nms_thresh (float): The box IoU cutoff used by non-maximal 106 | suppression to filter duplicate masks between different crops. 107 | crop_overlap_ratio (float): Sets the degree to which crops overlap. 108 | In the first crop layer, crops will overlap by this fraction of 109 | the image length. Later layers with more crops scale down this overlap. 110 | crop_n_points_downscale_factor (int): The number of points-per-side 111 | sampled in layer n is scaled down by crop_n_points_downscale_factor**n. 112 | point_grids (list(np.ndarray) or None): A list over explicit grids 113 | of points used for sampling, normalized to [0,1]. The nth grid in the 114 | list is used in the nth crop layer. Exclusive with points_per_side. 115 | min_mask_region_area (int): If >0, postprocessing will be applied 116 | to remove disconnected regions and holes in masks with area smaller 117 | than min_mask_region_area. Requires opencv. 118 | output_mode (str): The form masks are returned in. Can be 'binary_mask', 119 | 'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools. 120 | For large resolutions, 'binary_mask' may consume large amounts of 121 | memory. 122 | """ 123 | 124 | assert (points_per_side is None) != ( 125 | point_grids is None 126 | ), "Exactly one of points_per_side or point_grid must be provided." 127 | if points_per_side is not None: 128 | self.point_grids = build_all_layer_point_grids( 129 | points_per_side, 130 | crop_n_layers, 131 | crop_n_points_downscale_factor, 132 | ) 133 | elif point_grids is not None: 134 | self.point_grids = point_grids 135 | else: 136 | raise ValueError("Can't have both points_per_side and point_grid be None.") 137 | 138 | assert output_mode in [ 139 | "binary_mask", 140 | "uncompressed_rle", 141 | "coco_rle", 142 | ], f"Unknown output_mode {output_mode}." 143 | if output_mode == "coco_rle": 144 | from pycocotools import mask as mask_utils # type: ignore # noqa: F401 145 | 146 | if min_mask_region_area > 0: 147 | import cv2 # type: ignore # noqa: F401 148 | 149 | self.predictor = SamPredictor(model) 150 | self.points_per_batch = points_per_batch 151 | self.pred_iou_thresh = pred_iou_thresh 152 | self.stability_score_thresh = stability_score_thresh 153 | self.stability_score_offset = stability_score_offset 154 | self.box_nms_thresh = box_nms_thresh 155 | self.crop_n_layers = crop_n_layers 156 | self.crop_nms_thresh = crop_nms_thresh 157 | self.crop_overlap_ratio = crop_overlap_ratio 158 | self.crop_n_points_downscale_factor = crop_n_points_downscale_factor 159 | self.min_mask_region_area = min_mask_region_area 160 | self.output_mode = output_mode 161 | 162 | self.prototype = defaultdict(list) 163 | 164 | @torch.no_grad() 165 | def generate(self, image: np.ndarray, ref_bbox) -> List[Dict[str, Any]]: 166 | """ 167 | Generates masks for the given image. 
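        The exemplar boxes in ref_bbox (XYXY pixel coordinates in the original image frame) are used both as box prompts and to pool prototype features from the image embedding; point-grid proposals whose cosine similarity to these prototypes falls below an exemplar-derived threshold are filtered out.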
168 | 169 | Arguments: 170 | image (np.ndarray): The image to generate masks for, in HWC uint8 format. 171 | 172 | 173 | Returns: 174 | list(dict(str, any)): A list over records for masks. Each record is 175 | a dict containing the following keys: 176 | segmentation (dict(str, any) or np.ndarray): The mask. If 177 | output_mode='binary_mask', is an array of shape HW. Otherwise, 178 | is a dictionary containing the RLE. 179 | bbox (list(float)): The box around the mask, in XYWH format. 180 | area (int): The area in pixels of the mask. 181 | predicted_iou (float): The model's own prediction of the mask's 182 | quality. This is filtered by the pred_iou_thresh parameter. 183 | point_coords (list(list(float))): The point coordinates input 184 | to the model to generate this mask. 185 | stability_score (float): A measure of the mask's quality. This 186 | is filtered on using the stability_score_thresh parameter. 187 | crop_box (list(float)): The crop of the image used to generate 188 | the mask, given in XYWH format. 189 | """ 190 | 191 | # Generate masks 192 | mask_data = self._generate_masks(image, ref_bbox) 193 | 194 | # Filter small disconnected regions and holes in masks 195 | if self.min_mask_region_area > 0: 196 | mask_data = self.postprocess_small_regions( 197 | mask_data, 198 | self.min_mask_region_area, 199 | max(self.box_nms_thresh, self.crop_nms_thresh), 200 | ) 201 | 202 | # Encode masks 203 | if self.output_mode == "coco_rle": 204 | mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]] 205 | elif self.output_mode == "binary_mask": 206 | mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]] 207 | else: 208 | mask_data["segmentations"] = mask_data["rles"] 209 | 210 | # Write mask records 211 | curr_anns = [] 212 | for idx in range(len(mask_data["segmentations"])): 213 | ann = { 214 | "segmentation": mask_data["segmentations"][idx], 215 | "area": area_from_rle(mask_data["rles"][idx]), 216 | "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(), 217 | "predicted_iou": mask_data["iou_preds"][idx].item(), 218 | "point_coords": [mask_data["points"][idx].tolist()], 219 | "stability_score": mask_data["stability_score"][idx].item(), 220 | "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(), 221 | } 222 | curr_anns.append(ann) 223 | 224 | return curr_anns 225 | 226 | def _generate_masks(self, image: np.ndarray, ref_box) -> MaskData: 227 | orig_size = image.shape[:2] 228 | crop_boxes, layer_idxs = generate_crop_boxes( 229 | orig_size, self.crop_n_layers, self.crop_overlap_ratio 230 | ) 231 | 232 | # Iterate over image crops 233 | # data = MaskData() 234 | data_dic = defaultdict(MaskData) 235 | for crop_box, layer_idx in zip(crop_boxes, layer_idxs): 236 | crop_data = self._process_crop(image, crop_box, layer_idx, orig_size, ref_box) 237 | data_dic[layer_idx].cat(crop_data) 238 | 239 | data = MaskData() 240 | for layer_idx in data_dic.keys(): 241 | proto_fea = torch.concat(self.prototype[layer_idx], dim=0) 242 | if len(proto_fea) > 1: 243 | cos_dis = proto_fea @ proto_fea.t() 244 | sim_thresh = torch.min(cos_dis) 245 | else: 246 | sim_thresh = 0.7 247 | sub_data = data_dic[layer_idx] 248 | fea = sub_data['fea'] 249 | cos_dis = torch.max(fea @ proto_fea.t(), dim=1)[0] 250 | sub_data.filter(cos_dis>=sim_thresh) 251 | data.cat(sub_data) 252 | 253 | self.prototype = defaultdict(list) 254 | 255 | 256 | # Remove duplicate masks between crops 257 | if len(crop_boxes) > 1: 258 | # Prefer masks from smaller crops 259 | scores = 1 / 
box_area(data["crop_boxes"]) 260 | scores = scores.to(data["boxes"].device) 261 | keep_by_nms = batched_nms( 262 | data["boxes"].float(), 263 | scores, 264 | torch.zeros(len(data["boxes"])), # categories 265 | iou_threshold=self.crop_nms_thresh, 266 | ) 267 | data.filter(keep_by_nms) 268 | 269 | data.to_numpy() 270 | return data 271 | 272 | def _process_crop( 273 | self, 274 | image: np.ndarray, 275 | crop_box: List[int], 276 | crop_layer_idx: int, 277 | orig_size: Tuple[int, ...], 278 | ref_box 279 | ) -> MaskData: 280 | # Crop the image and calculate embeddings 281 | x0, y0, x1, y1 = crop_box 282 | cropped_im = image[y0:y1, x0:x1, :] 283 | cropped_im_size = cropped_im.shape[:2] 284 | self.predictor.set_image(cropped_im) 285 | 286 | ref_box = pre_process_ref_box(ref_box, crop_box, crop_layer_idx) 287 | if len(ref_box) > 0: 288 | ref_box = torch.tensor(ref_box, device=self.predictor.device) 289 | transformed_boxes = self.predictor.transform.apply_boxes_torch(ref_box, cropped_im_size) 290 | masks, iou_preds, low_res_masks = self.predictor.predict_torch( 291 | point_coords=None, 292 | point_labels=None, 293 | boxes=transformed_boxes, 294 | multimask_output=False 295 | ) 296 | feature = self.predictor.get_image_embedding() 297 | 298 | low_res_masks = F.interpolate(low_res_masks, size=feature.shape[-2:], mode='bilinear', align_corners=False) 299 | 300 | feature = feature.flatten(2, 3) 301 | low_res_masks = low_res_masks.flatten(2, 3) 302 | masks_low_res = (low_res_masks > self.predictor.model.mask_threshold).float() 303 | topk_idx = torch.topk(low_res_masks, 1)[1] 304 | masks_low_res.scatter_(2, topk_idx, 1.0) 305 | 306 | 307 | prototype_fea = (feature * masks_low_res).sum(dim=2) / masks_low_res.sum(dim=2) 308 | prototype_fea = F.normalize(prototype_fea, dim=1) 309 | self.prototype[crop_layer_idx].append(prototype_fea) 310 | 311 | 312 | if crop_layer_idx == 0: # add reference gounding 313 | x = ref_box[:, 0] + (ref_box[:, 2] - ref_box[:, 0]) / 2 314 | y = ref_box[:, 1] + (ref_box[:, 3] - ref_box[:, 1]) / 2 315 | points = torch.stack([x, y], dim=1) 316 | data = MaskData( 317 | masks=masks.flatten(0, 1), 318 | iou_preds= torch.ones_like(iou_preds.flatten(0, 1)), 319 | fea = prototype_fea, 320 | points=points.cpu(), 321 | stability_score = torch.ones_like(iou_preds.flatten(0, 1)), 322 | ) 323 | data["boxes"] = batched_mask_to_box(data["masks"]) 324 | data["rles"] = mask_to_rle_pytorch(data["masks"]) 325 | del data["masks"] 326 | else: 327 | data = MaskData() 328 | 329 | 330 | 331 | # Get points for this crop 332 | points_scale = np.array(cropped_im_size)[None, ::-1] 333 | points_for_image = self.point_grids[crop_layer_idx] * points_scale 334 | 335 | # Generate masks for this crop in batches 336 | for (points,) in batch_iterator(self.points_per_batch, points_for_image): 337 | batch_data = self._process_batch(points, cropped_im_size, 338 | crop_box, orig_size) 339 | data.cat(batch_data) 340 | del batch_data 341 | self.predictor.reset_image() 342 | 343 | # Remove duplicates within this crop. 
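        # With all category ids zero, batched_nms reduces to plain NMS: boxes are
        # ranked by predicted IoU and any box overlapping a higher-scoring one by
        # more than box_nms_thresh is suppressed, so near-duplicate masks from
        # neighbouring point prompts are dropped (the exemplar-prompted masks,
        # scored 1.0 above, are preferred when they overlap point proposals).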
344 | keep_by_nms = batched_nms( 345 | data["boxes"].float(), 346 | data["iou_preds"], 347 | torch.zeros(len(data["boxes"])), # categories 348 | iou_threshold=self.box_nms_thresh, 349 | ) 350 | data.filter(keep_by_nms) 351 | 352 | # Return to the original image frame 353 | data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box) 354 | data["points"] = uncrop_points(data["points"], crop_box) 355 | data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))]) 356 | 357 | return data 358 | 359 | def _process_batch( 360 | self, 361 | points: np.ndarray, 362 | im_size: Tuple[int, ...], 363 | crop_box: List[int], 364 | orig_size: Tuple[int, ...], 365 | ) -> MaskData: 366 | orig_h, orig_w = orig_size 367 | 368 | # Run model on this batch 369 | transformed_points = self.predictor.transform.apply_coords(points, im_size) 370 | in_points = torch.as_tensor(transformed_points, device=self.predictor.device) 371 | in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device) 372 | masks, iou_preds, low_res_masks = self.predictor.predict_torch( 373 | in_points[:, None, :], 374 | in_labels[:, None], 375 | multimask_output=True, 376 | return_logits=True, 377 | ) 378 | 379 | feature = self.predictor.get_image_embedding() 380 | low_res_masks=low_res_masks.flatten(0, 1) 381 | low_res_masks = F.interpolate(low_res_masks[:, None, :, :], size=feature.shape[-2:], 382 | mode='bilinear', align_corners=False) 383 | # low_res_masks = low_res_masks > self.predictor.model.mask_threshold 384 | 385 | # fea = feature.flatten(2, 3) 386 | # low_res_masks = low_res_masks.flatten(2, 3) 387 | # topk_idx = torch.topk(low_res_masks, 4)[1] 388 | # fea = fea.expand(topk_idx.shape[0], -1, -1) 389 | # topk_idx = topk_idx.expand(-1, fea.shape[1], -1) 390 | # fea = fea.gather(2, topk_idx) 391 | 392 | 393 | feature = feature.flatten(2, 3) 394 | low_res_masks = low_res_masks.flatten(2, 3) 395 | masks_low_res = (low_res_masks > self.predictor.model.mask_threshold).float() 396 | topk_idx = torch.topk(low_res_masks, 1)[1] 397 | masks_low_res.scatter_(2, topk_idx, 1.0) 398 | pool_fea = (feature * masks_low_res).sum(dim=2) / masks_low_res.sum(dim=2) 399 | pool_fea = F.normalize(pool_fea, dim=1) 400 | 401 | # k_val = torch.topk(torch.flatten(low_res_masks, start_dim=2, end_dim=3), k=4, dim=-1)[0][:, :, -1, None] 402 | # low_res_masks = (low_res_masks >= k_val).float() 403 | # low_res_masks = low_res_masks.float() 404 | # pool_fea = (feature * low_res_masks).sum(dim=(2, 3)) / low_res_masks.sum(dim=(2, 3)) 405 | 406 | 407 | 408 | # Serialize predictions and store in MaskData 409 | data = MaskData( 410 | masks=masks.flatten(0, 1), 411 | iou_preds=iou_preds.flatten(0, 1), 412 | points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)), 413 | fea = pool_fea, 414 | ) 415 | del masks 416 | 417 | 418 | # Filter by predicted IoU 419 | if self.pred_iou_thresh > 0.0: 420 | keep_mask = data["iou_preds"] > self.pred_iou_thresh 421 | data.filter(keep_mask) 422 | 423 | # Calculate stability score 424 | data["stability_score"] = calculate_stability_score( 425 | data["masks"], self.predictor.model.mask_threshold, self.stability_score_offset 426 | ) 427 | if self.stability_score_thresh > 0.0: 428 | keep_mask = data["stability_score"] >= self.stability_score_thresh 429 | data.filter(keep_mask) 430 | 431 | # Threshold masks and calculate boxes 432 | data["masks"] = data["masks"] > self.predictor.model.mask_threshold 433 | data["boxes"] = batched_mask_to_box(data["masks"]) 434 | 435 | # Filter boxes that touch crop 
boundaries 436 | keep_mask = ~is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h]) 437 | if not torch.all(keep_mask): 438 | data.filter(keep_mask) 439 | 440 | # Compress to RLE 441 | data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w) 442 | data["rles"] = mask_to_rle_pytorch(data["masks"]) 443 | del data["masks"] 444 | 445 | return data 446 | 447 | @staticmethod 448 | def postprocess_small_regions( 449 | mask_data: MaskData, min_area: int, nms_thresh: float 450 | ) -> MaskData: 451 | """ 452 | Removes small disconnected regions and holes in masks, then reruns 453 | box NMS to remove any new duplicates. 454 | 455 | Edits mask_data in place. 456 | 457 | Requires open-cv as a dependency. 458 | """ 459 | if len(mask_data["rles"]) == 0: 460 | return mask_data 461 | 462 | # Filter small disconnected regions and holes 463 | new_masks = [] 464 | scores = [] 465 | for rle in mask_data["rles"]: 466 | mask = rle_to_mask(rle) 467 | 468 | mask, changed = remove_small_regions(mask, min_area, mode="holes") 469 | unchanged = not changed 470 | mask, changed = remove_small_regions(mask, min_area, mode="islands") 471 | unchanged = unchanged and not changed 472 | 473 | new_masks.append(torch.as_tensor(mask).unsqueeze(0)) 474 | # Give score=0 to changed masks and score=1 to unchanged masks 475 | # so NMS will prefer ones that didn't need postprocessing 476 | scores.append(float(unchanged)) 477 | 478 | # Recalculate boxes and remove any new duplicates 479 | masks = torch.cat(new_masks, dim=0) 480 | boxes = batched_mask_to_box(masks) 481 | keep_by_nms = batched_nms( 482 | boxes.float(), 483 | torch.as_tensor(scores), 484 | torch.zeros(len(boxes)), # categories 485 | iou_threshold=nms_thresh, 486 | ) 487 | 488 | # Only recalculate RLEs for masks that have changed 489 | for i_mask in keep_by_nms: 490 | if scores[i_mask] == 0.0: 491 | mask_torch = masks[i_mask].unsqueeze(0) 492 | mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0] 493 | mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly 494 | mask_data.filter(keep_by_nms) 495 | 496 | return mask_data 497 | -------------------------------------------------------------------------------- /env.yaml: -------------------------------------------------------------------------------- 1 | name: ltce 2 | channels: 3 | - conda-forge 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=main 7 | - _openmp_mutex=5.1=1_gnu 8 | - asttokens=2.2.1=pyhd8ed1ab_0 9 | - backcall=0.2.0=pyh9f0ad1d_0 10 | - backports=1.0=pyhd8ed1ab_3 11 | - backports.functools_lru_cache=1.6.4=pyhd8ed1ab_0 12 | - blas=1.0=openblas 13 | - brotli=1.0.9=h5eee18b_7 14 | - brotli-bin=1.0.9=h5eee18b_7 15 | - bzip2=1.0.8=h7b6447c_0 16 | - ca-certificates=2023.01.10=h06a4308_0 17 | - cairo=1.16.0=hb05425b_4 18 | - certifi=2022.12.7=py310h06a4308_0 19 | - contourpy=1.0.5=py310hdb19cb5_0 20 | - cycler=0.11.0=pyhd3eb1b0_0 21 | - dbus=1.13.18=hb2f20db_0 22 | - debugpy=1.5.1=py310h295c915_0 23 | - decorator=5.1.1=pyhd8ed1ab_0 24 | - eigen=3.4.0=h4bd325d_0 25 | - entrypoints=0.4=pyhd8ed1ab_0 26 | - executing=1.2.0=pyhd8ed1ab_0 27 | - expat=2.2.10=h9c3ff4c_0 28 | - ffmpeg=4.2.2=h20bf706_0 29 | - fontconfig=2.14.1=hef1e5e3_0 30 | - fonttools=4.25.0=pyhd3eb1b0_0 31 | - freetype=2.10.4=h0708190_1 32 | - giflib=5.2.1=h36c2ea0_2 33 | - glib=2.69.1=h4ff587b_1 34 | - gmp=6.2.1=h58526e2_0 35 | - gnutls=3.6.13=h85f3911_1 36 | - graphite2=1.3.14=h295c915_1 37 | - gst-plugins-base=1.14.1=h6a678d5_1 38 | - gstreamer=1.14.1=h5eee18b_1 39 | - harfbuzz=4.3.0=hf52aaf7_1 40 
| - hdf5=1.10.6=h3ffc7dd_1 41 | - icu=58.2=hf484d3e_1000 42 | - ipykernel=6.15.0=pyh210e3f2_0 43 | - ipython=8.12.0=pyh41d4057_0 44 | - jedi=0.18.2=pyhd8ed1ab_0 45 | - jpeg=9e=h166bdaf_1 46 | - jupyter_client=7.3.4=pyhd8ed1ab_0 47 | - jupyter_core=5.3.0=py310hff52083_0 48 | - keyutils=1.6.1=h166bdaf_0 49 | - kiwisolver=1.4.4=py310h6a678d5_0 50 | - krb5=1.19.3=h3790be6_0 51 | - lame=3.100=h7f98852_1001 52 | - lcms2=2.12=h3be6417_0 53 | - ld_impl_linux-64=2.38=h1181459_1 54 | - lerc=3.0=h295c915_0 55 | - libblas=3.9.0=15_linux64_openblas 56 | - libbrotlicommon=1.0.9=h5eee18b_7 57 | - libbrotlidec=1.0.9=h5eee18b_7 58 | - libbrotlienc=1.0.9=h5eee18b_7 59 | - libcblas=3.9.0=15_linux64_openblas 60 | - libclang=10.0.1=default_hb85057a_2 61 | - libdeflate=1.17=h5eee18b_0 62 | - libedit=3.1.20191231=he28a2e2_2 63 | - libevent=2.1.12=h8f2d780_0 64 | - libffi=3.3=he6710b0_2 65 | - libgcc-ng=11.2.0=h1234567_1 66 | - libgfortran-ng=12.2.0=h69a702a_19 67 | - libgfortran5=12.2.0=h337968e_19 68 | - libgomp=11.2.0=h1234567_1 69 | - liblapack=3.9.0=15_linux64_openblas 70 | - libllvm10=10.0.1=he513fc3_3 71 | - libopenblas=0.3.20=pthreads_h78a6416_0 72 | - libopus=1.3.1=h7f98852_1 73 | - libpng=1.6.39=h5eee18b_0 74 | - libpq=12.9=h16c4e8d_3 75 | - libprotobuf=3.20.3=he621ea3_0 76 | - libsodium=1.0.18=h36c2ea0_1 77 | - libstdcxx-ng=11.2.0=h1234567_1 78 | - libtiff=4.5.0=h6a678d5_2 79 | - libuuid=1.41.5=h5eee18b_0 80 | - libvpx=1.7.0=h439df22_0 81 | - libwebp=1.2.4=h11a3e52_1 82 | - libwebp-base=1.2.4=h5eee18b_1 83 | - libxcb=1.15=h7f8727e_0 84 | - libxkbcommon=1.0.1=hfa300c1_0 85 | - libxml2=2.9.14=h74e7548_0 86 | - libxslt=1.1.35=h4e12654_0 87 | - lz4-c=1.9.3=h9c3ff4c_1 88 | - matplotlib=3.7.1=py310h06a4308_1 89 | - matplotlib-base=3.7.1=py310h1128e8f_1 90 | - matplotlib-inline=0.1.6=pyhd8ed1ab_0 91 | - munkres=1.1.4=py_0 92 | - ncurses=6.4=h6a678d5_0 93 | - nest-asyncio=1.5.6=pyhd8ed1ab_0 94 | - nettle=3.6=he412f7d_0 95 | - nspr=4.33=h295c915_0 96 | - nss=3.74=h0370c37_0 97 | - opencv=4.6.0=py310h1128e8f_3 98 | - openh264=2.1.1=h4ff587b_0 99 | - openjpeg=2.4.0=h3ad879b_0 100 | - openssl=1.1.1t=h7f8727e_0 101 | - packaging=23.1=pyhd8ed1ab_0 102 | - parso=0.8.3=pyhd8ed1ab_0 103 | - pcre=8.45=h9c3ff4c_0 104 | - pexpect=4.8.0=pyh1a96a4e_2 105 | - pickleshare=0.7.5=py_1003 106 | - pip=23.0.1=py310h06a4308_0 107 | - pixman=0.40.0=h36c2ea0_0 108 | - platformdirs=3.2.0=pyhd8ed1ab_0 109 | - ply=3.11=py310h06a4308_0 110 | - prompt-toolkit=3.0.38=pyha770c72_0 111 | - prompt_toolkit=3.0.38=hd8ed1ab_0 112 | - psutil=5.9.0=py310h5eee18b_0 113 | - ptyprocess=0.7.0=pyhd3deb0d_0 114 | - pure_eval=0.2.2=pyhd8ed1ab_0 115 | - pygments=2.15.0=pyhd8ed1ab_0 116 | - pyparsing=3.0.9=py310h06a4308_0 117 | - pyqt=5.15.7=py310h6a678d5_1 118 | - python=3.10.4=h12debd9_0 119 | - python-dateutil=2.8.2=pyhd8ed1ab_0 120 | - python_abi=3.10=2_cp310 121 | - pyzmq=23.2.0=py310h6a678d5_0 122 | - qt-main=5.15.2=h327a75a_7 123 | - qt-webengine=5.15.9=hd2b0992_4 124 | - qtwebkit=5.212=h4eab89a_4 125 | - readline=8.2=h5eee18b_0 126 | - setuptools=65.6.3=py310h06a4308_0 127 | - sip=6.6.2=py310h6a678d5_0 128 | - six=1.16.0=pyh6c4a22f_0 129 | - sqlite=3.41.2=h5eee18b_0 130 | - stack_data=0.6.2=pyhd8ed1ab_0 131 | - tk=8.6.12=h1ccaba5_0 132 | - toml=0.10.2=pyhd3eb1b0_0 133 | - tornado=6.1=py310h5764c6d_3 134 | - tqdm=4.65.0=py310h2f386ee_0 135 | - traitlets=5.9.0=pyhd8ed1ab_0 136 | - typing-extensions=4.5.0=hd8ed1ab_0 137 | - typing_extensions=4.5.0=pyha770c72_0 138 | - tzdata=2023c=h04d1e81_0 139 | - wcwidth=0.2.6=pyhd8ed1ab_0 140 | - 
wheel=0.38.4=py310h06a4308_0 141 | - x264=1!157.20191217=h7b6447c_0 142 | - xz=5.2.10=h5eee18b_1 143 | - zeromq=4.3.4=h9c3ff4c_1 144 | - zlib=1.2.13=h5eee18b_0 145 | - zstd=1.5.2=ha4553b6_0 146 | - pip: 147 | - charset-normalizer==3.1.0 148 | - cmake==3.26.3 149 | - filelock==3.11.0 150 | - idna==3.4 151 | - jinja2==3.1.2 152 | - lit==16.0.1 153 | - markupsafe==2.1.2 154 | - mpmath==1.3.0 155 | - networkx==3.1 156 | - numpy==1.24.2 157 | - nvidia-cublas-cu11==11.10.3.66 158 | - nvidia-cuda-cupti-cu11==11.7.101 159 | - nvidia-cuda-nvrtc-cu11==11.7.99 160 | - nvidia-cuda-runtime-cu11==11.7.99 161 | - nvidia-cudnn-cu11==8.5.0.96 162 | - nvidia-cufft-cu11==10.9.0.58 163 | - nvidia-curand-cu11==10.2.10.91 164 | - nvidia-cusolver-cu11==11.4.0.1 165 | - nvidia-cusparse-cu11==11.7.4.91 166 | - nvidia-nccl-cu11==2.14.3 167 | - nvidia-nvtx-cu11==11.7.91 168 | - pillow==9.5.0 169 | - pyqt5-sip==12.11.0 170 | - requests==2.28.2 171 | - segment-anything==1.0 172 | - sympy==1.11.1 173 | - torch==2.0.0 174 | - torchaudio==2.0.1 175 | - torchvision==0.15.1 176 | - triton==2.0.0 177 | - urllib3==1.26.15 -------------------------------------------------------------------------------- /example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vision-Intelligence-and-Robots-Group/count-anything/2756753f529a159893718ff99c00b0299f9c36a2/example.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | gradio==3.27.0 2 | matplotlib==3.7.1 3 | numpy==1.24.1 4 | opencv_python==4.7.0.72 5 | Pillow==9.3.0 6 | pycocotools==2.0.6 7 | segment_anything==1.0 8 | torch==2.0.0 9 | torchvision==0.15.1 10 | tqdm==4.65.0 11 | -------------------------------------------------------------------------------- /resultFSC.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vision-Intelligence-and-Robots-Group/count-anything/2756753f529a159893718ff99c00b0299f9c36a2/resultFSC.png -------------------------------------------------------------------------------- /resultcoco.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vision-Intelligence-and-Robots-Group/count-anything/2756753f529a159893718ff99c00b0299f9c36a2/resultcoco.png -------------------------------------------------------------------------------- /test_FSC.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import argparse 3 | import json 4 | import numpy as np 5 | from tqdm import tqdm 6 | from os.path import exists 7 | import os 8 | 9 | from segment_anything import sam_model_registry 10 | from automatic_mask_generator import SamAutomaticMaskGenerator 11 | import matplotlib.pyplot as plt 12 | 13 | 14 | 15 | 16 | parser = argparse.ArgumentParser(description="Few Shot Counting Evaluation code") 17 | parser.add_argument("-dp", "--data_path", type=str, default='/data/counte/', help="Path to the FSC147 dataset") 18 | parser.add_argument("-ts", "--test_split", type=str, default='val', choices=["val_PartA","val_PartB","test_PartA","test_PartB","test", "val"], help="what data split to evaluate on") 19 | parser.add_argument("-mt", "--model_type", type=str, default="vit_h", help="model type") 20 | parser.add_argument("-mp", "--model_path", type=str, 
default="/home/teddy/segment-anything/sam_vit_h_4b8939.pth", help="path to trained model") 21 | parser.add_argument("-v", "--viz", type=bool, default=True, help="wether to visualize") 22 | parser.add_argument("-d", "--device", default='0', help='assign device') 23 | args = parser.parse_args() 24 | 25 | data_path = args.data_path 26 | anno_file = data_path + 'annotation_FSC147_384.json' 27 | data_split_file = data_path + 'Train_Test_Val_FSC_147.json' 28 | im_dir = data_path + 'images_384_VarV2' 29 | 30 | 31 | if not exists(anno_file) or not exists(im_dir): 32 | print("Make sure you set up the --data-path correctly.") 33 | print("Current setting is {}, but the image dir and annotation file do not exist.".format(args.data_path)) 34 | print("Aborting the evaluation") 35 | exit(-1) 36 | 37 | def show_anns(anns): 38 | if len(anns) == 0: 39 | return 40 | sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True) 41 | ax = plt.gca() 42 | ax.set_autoscale_on(False) 43 | for ann in sorted_anns: 44 | x0, y0, w, h = ann['bbox'] 45 | ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2)) 46 | ax.scatter([x0+w//2], [y0+h//2], color='green', marker='*', s=10, edgecolor='white', linewidth=1.25) 47 | 48 | 49 | debug = True 50 | os.environ['CUDA_VISIBLE_DEVICES'] = args.device.strip() 51 | device = 'cuda' 52 | sam = sam_model_registry[args.model_type](checkpoint=args.model_path) 53 | sam.to(device=device) 54 | 55 | 56 | mask_generator = SamAutomaticMaskGenerator( 57 | model=sam, 58 | min_mask_region_area=25 59 | ) 60 | 61 | with open(anno_file) as f: 62 | annotations = json.load(f) 63 | 64 | with open(data_split_file) as f: 65 | data_split = json.load(f) 66 | 67 | 68 | cnt = 0 69 | SAE = 0 # sum of absolute errors 70 | SSE = 0 # sum of square errors 71 | 72 | print("Evaluation on {} data".format(args.test_split)) 73 | im_ids = data_split[args.test_split] 74 | 75 | # with open("err.json") as f: 76 | # im_ids = json.load(f) 77 | 78 | 79 | pbar = tqdm(im_ids) 80 | # err_list = [] 81 | for im_id in pbar: 82 | anno = annotations[im_id] 83 | bboxes = anno['box_examples_coordinates'] 84 | dots = np.array(anno['points']) 85 | 86 | image = cv2.imread('{}/{}'.format(im_dir, im_id)) 87 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 88 | 89 | input_boxes = list() 90 | for bbox in bboxes: 91 | x1, y1 = bbox[0][0], bbox[0][1] 92 | x2, y2 = bbox[2][0], bbox[2][1] 93 | input_boxes.append([x1, y1, x2, y2]) 94 | 95 | masks = mask_generator.generate(image, input_boxes) 96 | if args.viz: 97 | if not exists('viz'): 98 | os.mkdir('viz') 99 | plt.figure(figsize=(10,10)) 100 | plt.imshow(image) 101 | show_anns(masks) 102 | plt.axis('off') 103 | plt.savefig('viz/{}'.format(im_id)) 104 | plt.close() 105 | 106 | gt_cnt = dots.shape[0] 107 | pred_cnt = len(masks) 108 | cnt = cnt + 1 109 | err = abs(gt_cnt - pred_cnt) 110 | SAE += err 111 | SSE += err**2 112 | # if err / gt_cnt > 0.7: 113 | # err_list.append(im_id) 114 | 115 | pbar.set_description('{:<8}: actual-predicted: {:6d}, {:6.1f}, error: {:6.1f}. 
Current MAE: {:5.2f}, RMSE: {:5.2f}'.\ 116 | format(im_id, gt_cnt, pred_cnt, abs(pred_cnt - gt_cnt), SAE/cnt, (SSE/cnt)**0.5)) 117 | 118 | print('On {} data, MAE: {:6.2f}, RMSE: {:6.2f}'.format(args.test_split, SAE/cnt, (SSE/cnt)**0.5)) 119 | # with open('err.json', "w") as f: 120 | # json.dump(err_list, f) -------------------------------------------------------------------------------- /test_coco.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import argparse 3 | import json 4 | import numpy as np 5 | from tqdm import tqdm 6 | from os.path import exists 7 | import os 8 | 9 | from segment_anything import sam_model_registry 10 | from automatic_mask_generator import SamAutomaticMaskGenerator 11 | import matplotlib.pyplot as plt 12 | 13 | 14 | 15 | 16 | parser = argparse.ArgumentParser(description="Few Shot Counting Evaluation code") 17 | parser.add_argument("-dp", "--data_path", type=str, default='/data/counte/', help="Path to the coco dataset") 18 | parser.add_argument("-ts", "--test_split", type=str, default='val2017', choices=["val2017"], help="what data split to evaluate on") 19 | parser.add_argument("-mt", "--model_type", type=str, default="vit_h", help="model type") 20 | parser.add_argument("-mp", "--model_path", type=str, default="/home/teddy/segment-anything/sam_vit_h_4b8939.pth", help="path to trained model") 21 | parser.add_argument("-v", "--viz", type=bool, default=True, help="wether to visualize") 22 | parser.add_argument("-d", "--device", default='0', help='assign device') 23 | args = parser.parse_args() 24 | 25 | data_path = args.data_path 26 | anno_file = data_path + 'annotations_trainval2017/annotations/instances_val2017.json' 27 | im_dir = data_path + 'val2017' 28 | 29 | 30 | if not exists(anno_file) or not exists(im_dir): 31 | print("Make sure you set up the --data-path correctly.") 32 | print("Current setting is {}, but the image dir and annotation file do not exist.".format(args.data_path)) 33 | print("Aborting the evaluation") 34 | exit(-1) 35 | 36 | def show_anns(anns): 37 | if len(anns) == 0: 38 | return 39 | sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True) 40 | ax = plt.gca() 41 | ax.set_autoscale_on(False) 42 | for ann in sorted_anns: 43 | x0, y0, w, h = ann['bbox'] 44 | ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2)) 45 | ax.scatter([x0+w//2], [y0+h//2], color='green', marker='*', s=10, edgecolor='white', linewidth=1.25) 46 | 47 | 48 | debug = True 49 | os.environ['CUDA_VISIBLE_DEVICES'] = args.device.strip() 50 | device = 'cuda' 51 | sam = sam_model_registry[args.model_type](checkpoint=args.model_path) 52 | sam.to(device=device) 53 | 54 | 55 | mask_generator = SamAutomaticMaskGenerator( 56 | model=sam, 57 | min_mask_region_area=25 58 | ) 59 | 60 | with open(anno_file) as f: 61 | annotations = json.load(f) 62 | 63 | images = sorted(annotations['images'],key=lambda x:x['file_name']) 64 | 65 | prepared_json = {} 66 | for i in images: 67 | prepared_json[i['file_name']] = { 68 | "H":i['height'], 69 | "W":i['width'], 70 | "boxes":{}, 71 | # "category_ids":[], 72 | } 73 | for i in annotations['annotations']: 74 | im_id = str(i['image_id']) 75 | prezero = 12 - len(im_id) 76 | im_id = '0'*prezero + im_id + ".jpg" 77 | if i["category_id"] in prepared_json[im_id]["boxes"]: 78 | prepared_json[im_id]["boxes"][i["category_id"]].append(i['bbox']) 79 | else: 80 | prepared_json[im_id]["boxes"][i["category_id"]] = [] 81 | 
prepared_json[im_id]["boxes"][i["category_id"]].append(i['bbox']) 82 | 83 | im_ids = [] 84 | for i in prepared_json.keys(): 85 | im_ids.append(i) 86 | 87 | 88 | cnt = 0 89 | folds = [ 90 | [1,5,9,14,18,22,27,33,37,41,46,50,54,58,62,67,74,78,82,87], 91 | [2,6,10,15,19,23,28,34,38,42,47,51,55,59,63,70,75,79,84,88], 92 | [3,7,11,16,20,24,31,35,39,43,48,52,56,60,64,72,76,80,85,89], 93 | [4,8,13,17,21,25,32,36,40,44,49,53,57,61,65,73,77,81,86,90], 94 | ] 95 | SAE = [0,0,0,0] # sum of absolute errors 96 | SSE = [0,0,0,0] # sum of square errors 97 | 98 | print("Evaluation on {} data".format(args.test_split)) 99 | 100 | # logs = [] 101 | 102 | 103 | pbar = tqdm(im_ids) 104 | # err_list = [] 105 | for im_id in pbar: 106 | category_id = list(prepared_json[im_id]['boxes'].keys()) 107 | 108 | image = cv2.imread('{}/{}'.format(im_dir, im_id)) 109 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 110 | # log = [] 111 | # log.append(im_id) 112 | 113 | for id in category_id: 114 | boxes = prepared_json[im_id]['boxes'][id] 115 | 116 | input_boxes = list() 117 | x1, y1 = boxes[0][0],boxes[0][1] 118 | x2, y2 = boxes[0][0] + boxes[0][2],boxes[0][1] + boxes[0][3] 119 | input_boxes.append([x1, y1, x2, y2]) 120 | 121 | masks = mask_generator.generate(image, input_boxes) 122 | 123 | if args.viz: 124 | if not exists('viz'): 125 | os.mkdir('viz') 126 | plt.figure(figsize=(10,10)) 127 | plt.imshow(image) 128 | show_anns(masks) 129 | plt.axis('off') 130 | plt.savefig('viz/{}_{}.jpg'.format(im_id[0:-4],id)) 131 | plt.close() 132 | 133 | gt_cnt = len(boxes) 134 | pred_cnt = len(masks) 135 | err = abs(gt_cnt - pred_cnt) 136 | log.append("\n{},gt_cnt:{},pred_cnt:{}".format(id,gt_cnt,pred_cnt)) 137 | if id in folds[0]: 138 | SAE[0] += err 139 | SSE[0] += err**2 140 | elif id in folds[1]: 141 | SAE[1] += err 142 | SSE[1] += err**2 143 | elif id in folds[2]: 144 | SAE[2] += err 145 | SSE[2] += err**2 146 | elif id in folds[3]: 147 | SAE[3] += err 148 | SSE[3] += err**2 149 | 150 | cnt = cnt + 1 151 | # logs.append(log) 152 | pbar.set_description('fold1: {:6.2f}, fold2: {:6.2f}, fold3: {:6.2f}, fold4: {:6.2f},'.\ 153 | format(SAE[0]/cnt,SAE[1]/cnt,SAE[2]/cnt,SAE[3]/cnt)) 154 | 155 | print('On {} data, fold1 MAE: {:6.2f}, RMSE: {:6.2f}\n \ 156 | fold2 MAE: {:6.2f}, RMSE: {:6.2f}\n \ 157 | fold3 MAE: {:6.2f}, RMSE: {:6.2f}\n \ 158 | fold4 MAE: {:6.2f}, RMSE: {:6.2f}\n \ 159 | '.format(args.test_split,SAE[0]/cnt,(SSE[0]/cnt)**0.5,SAE[1]/cnt,(SSE[1]/cnt)**0.5,SAE[2]/cnt,(SSE[2]/cnt)**0.5,SAE[3]/cnt,(SSE[3]/cnt)**0.5)) --------------------------------------------------------------------------------