├── images ├── 2007_000464.jpg └── End-to-end self-supervised semantic segmentation.png ├── model ├── __pycache__ │ ├── vit.cpython-310.pyc │ ├── align_model.cpython-310.pyc │ ├── criterion.cpython-310.pyc │ └── transforms.cpython-310.pyc ├── criterion.py ├── align_model.py ├── transforms.py └── vit.py ├── data ├── __pycache__ │ ├── coco_data.cpython-310.pyc │ └── voc_data.cpython-310.pyc ├── movi_data.py ├── voc_data.py ├── coco_data.py └── stuffthing_2017.json ├── configs ├── eval_voc_config.yml ├── eval_coco27_config.yml ├── eval_movi_config.yml ├── train_coco_config.yml ├── train_movi_config.yml └── train_voc_config.yml ├── LICENSE.txt ├── README.md ├── evaluate ├── cocoStuff27_mask_visualize.py ├── proto_similarity.py ├── object_evaluation.py ├── sup_overcluster.py ├── visualize_segment.py └── eval_utils.py ├── train.py └── utils.py /images/2007_000464.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yliu1229/AlignSeg/HEAD/images/2007_000464.jpg -------------------------------------------------------------------------------- /model/__pycache__/vit.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yliu1229/AlignSeg/HEAD/model/__pycache__/vit.cpython-310.pyc -------------------------------------------------------------------------------- /data/__pycache__/coco_data.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yliu1229/AlignSeg/HEAD/data/__pycache__/coco_data.cpython-310.pyc -------------------------------------------------------------------------------- /data/__pycache__/voc_data.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yliu1229/AlignSeg/HEAD/data/__pycache__/voc_data.cpython-310.pyc -------------------------------------------------------------------------------- /model/__pycache__/align_model.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yliu1229/AlignSeg/HEAD/model/__pycache__/align_model.cpython-310.pyc -------------------------------------------------------------------------------- /model/__pycache__/criterion.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yliu1229/AlignSeg/HEAD/model/__pycache__/criterion.cpython-310.pyc -------------------------------------------------------------------------------- /model/__pycache__/transforms.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yliu1229/AlignSeg/HEAD/model/__pycache__/transforms.cpython-310.pyc -------------------------------------------------------------------------------- /images/End-to-end self-supervised semantic segmentation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yliu1229/AlignSeg/HEAD/images/End-to-end self-supervised semantic segmentation.png -------------------------------------------------------------------------------- /configs/eval_voc_config.yml: -------------------------------------------------------------------------------- 1 | num_workers: 1 2 | 3 | data: 4 | data_dir: '' 5 | dataset_name: "voc" # coco, imagenet100k, imagenet or voc 6 | num_classes: 21 7 | 
size_crops: 448 8 | 9 | val: 10 | arch: 'vit_small' 11 | batch_size: 1 12 | seed: 3407 13 | patch_size: 16 14 | embed_dim: 384 15 | hidden_dim: 384 16 | num_decode_layers: 1 17 | decoder_num_heads: 3 18 | num_queries: 5 # effective queries for mask generation, always ends with an 'Others' query 19 | last_self_attention: False # whether use attention map as foreground hint 20 | mask_eval_size: 100 21 | checkpoint: './epoch10.pth' 22 | -------------------------------------------------------------------------------- /configs/eval_coco27_config.yml: -------------------------------------------------------------------------------- 1 | num_workers: 1 2 | 3 | data: 4 | data_dir: '' 5 | dataset_name: "coco-all" # coco-all, coco-stuff, coco-thing or voc 6 | num_classes: 27 7 | size_crops: 448 8 | 9 | val: 10 | arch: 'vit_small' 11 | batch_size: 1 12 | seed: 3407 13 | patch_size: 16 14 | embed_dim: 384 15 | hidden_dim: 768 16 | num_decode_layers: 3 17 | decoder_num_heads: 3 18 | num_queries: 5 # effective queries for mask generation, always ends with an 'Others' query 19 | last_self_attention: False # whether use attention map as foreground hint 20 | mask_eval_size: 100 21 | checkpoint: './epoch10.pth' 22 | -------------------------------------------------------------------------------- /configs/eval_movi_config.yml: -------------------------------------------------------------------------------- 1 | num_workers: 1 2 | 3 | data: 4 | data_dir: "./Data/" 5 | dataset_name: "movi_e" # "movi_e" or "movi_c" 6 | num_classes: 17 7 | size_crops: 256 8 | 9 | val: 10 | arch: 'vit_small' 11 | batch_size: 1 12 | seed: 3407 13 | patch_size: 16 14 | embed_dim: 384 15 | hidden_dim: 768 16 | num_decode_layers: 6 17 | decoder_num_heads: 4 18 | num_queries: 18 # effective queries for mask generation, always ends with an 'Others' query 19 | last_self_attention: False # whether use attention map as foreground hint 20 | mask_eval_size: 256 21 | checkpoint: './log_tmp/movi_e-vit_small-bs32/model/epoch10.pth' 22 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 yliu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /configs/train_coco_config.yml: -------------------------------------------------------------------------------- 1 | num_workers: 4 2 | 3 | data: 4 | data_dir: "" 5 | dataset_name: "coco" # coco, or voc 6 | size_crops: [224, 224] 7 | augment_image: False 8 | jitter_strength: 1.0 9 | blur_strength: 1.0 10 | min_scale_crops: [0.5, 0.05] 11 | max_scale_crops: [1., 0.25] 12 | min_intersection_crops: 0.05 13 | nmb_crops: [1, 2] 14 | size_crops_val: 448 15 | 16 | train: 17 | batch_size: 32 # effective batch size is bs * gpus * res_w ** 2 18 | max_epochs: 100 19 | seed: 3407 20 | fix_vit: True 21 | exclude_norm_bias: True 22 | roi_align_kernel_size: 7 23 | arch: 'vit_small' 24 | patch_size: 16 25 | embed_dim: 384 26 | hidden_dim: 768 27 | num_decode_layers: 1 28 | decoder_num_heads: 3 29 | num_queries: 5 # effective queries for mask generation, always ends with an 'Others' query 30 | last_self_attention: True # whether use attention map as foreground hint 31 | ce_temperature: 1 32 | lr_decoder: 0.0005 33 | final_lr: 0. 34 | weight_decay: 0.04 35 | weight_decay_end: 0.5 36 | negative_pressure: 0.10 37 | epsilon: 0.05 38 | save_checkpoint_every_n_epochs: 1 39 | checkpoint: '' 40 | pretrained_model: './dino_vitsmall16.pth' 41 | fix_prototypes: False 42 | -------------------------------------------------------------------------------- /configs/train_movi_config.yml: -------------------------------------------------------------------------------- 1 | num_workers: 4 2 | 3 | data: 4 | data_dir: "./Data/" 5 | dataset_name: "movi_e" # "movi_c" or "movi_e" 6 | size_crops: [224, 224] 7 | augment_image: False 8 | jitter_strength: 1.0 9 | blur_strength: 1.0 10 | min_scale_crops: [0.5, 0.05] 11 | max_scale_crops: [1., 0.25] 12 | min_intersection_crops: 0.05 13 | nmb_crops: [1, 2] 14 | size_crops_val: 256 # Crops size for validation and seg maps viz 15 | num_classes_val: 17 16 | 17 | train: 18 | batch_size: 32 # effective batch size is bs * gpus * res_w ** 2 19 | max_epochs: 10 20 | seed: 3407 21 | fix_vit: True 22 | exclude_norm_bias: True 23 | roi_align_kernel_size: 7 24 | arch: 'vit_small' 25 | patch_size: 16 26 | embed_dim: 384 27 | hidden_dim: 768 28 | num_decode_layers: 6 29 | decoder_num_heads: 4 30 | num_queries: 18 # effective queries for mask generation, always ends with an 'Others' query 31 | last_self_attention: True # whether use attention map as foreground hint 32 | ce_temperature: 1 33 | lr_decoder: 0.0005 34 | final_lr: 0. 
35 | weight_decay: 0.04 36 | weight_decay_end: 0.5 37 | negative_pressure: 0.13 # 0.13 for MOVi-C, 0.13 for MOVi-E 38 | corr_coefficient: 0.15 39 | epsilon: 0.05 40 | save_checkpoint_every_n_epochs: 1 41 | checkpoint: 42 | pretrained_model: 'dino_vitsmall16.pth' 43 | prototype_queries: 44 | fix_prototypes: 45 | -------------------------------------------------------------------------------- /configs/train_voc_config.yml: -------------------------------------------------------------------------------- 1 | num_workers: 4 2 | 3 | data: 4 | data_dir: "" 5 | dataset_name: "voc" # coco, imagenet100k, imagenet or voc 6 | size_crops: [224, 224] 7 | augment_image: False 8 | jitter_strength: 1.0 9 | blur_strength: 1.0 10 | min_scale_crops: [0.5, 0.05] 11 | max_scale_crops: [1., 0.25] 12 | min_intersection_crops: 0.05 13 | nmb_crops: [1, 2] 14 | size_crops_val: 448 # Crops size for validation and seg maps viz 15 | num_classes_val: 21 16 | voc_data_path: "" 17 | 18 | train: 19 | batch_size: 32 # effective batch size is bs * gpus * res_w ** 2 20 | max_epochs: 10 21 | seed: 3407 22 | fix_vit: True 23 | exclude_norm_bias: True 24 | roi_align_kernel_size: 7 25 | arch: 'vit_small' 26 | patch_size: 16 27 | embed_dim: 384 28 | hidden_dim: 768 29 | num_decode_layers: 1 30 | decoder_num_heads: 3 31 | num_queries: 5 # effective queries for mask generation, always ends with an 'Others' query 32 | last_self_attention: True # whether use attention map as foreground hint 33 | ce_temperature: 1 34 | lr_decoder: 0.0005 35 | final_lr: 0. 36 | weight_decay: 0.04 37 | weight_decay_end: 0.5 38 | negative_pressure: 0.11 39 | corr_coefficient: 0.15 40 | epsilon: 0.05 41 | save_checkpoint_every_n_epochs: 1 42 | checkpoint: #'' 43 | pretrained_model: './dino_vitsmall16.pth' 44 | prototype_queries: 45 | fix_prototypes: False 46 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## [AlignSeg] Rethinking Self-Supervised Semantic Segmentation: Achieving End-to-End Segmentation 2 | 3 | This is the PyTorch implementation of AlignSeg. 4 | 5 |
6 | <div align="center">
7 |   <img alt="Self-supervised Semantic Segmentation Evaluation Protocols" src="images/End-to-end self-supervised semantic segmentation.png"/>
8 | </div>
9 | 10 | ### Dataset Setup 11 | 12 | Please download the data and organize as detailed in the next subsections. 13 | 14 | ##### Pascal VOC 15 | Here's a [zipped version](https://www.dropbox.com/s/6gd4x0i9ewasymb/voc_data.zip?dl=0) for convenience. 16 | 17 | The structure for training and evaluation should be as follows: 18 | ``` 19 | dataset root. 20 | └───SegmentationClass 21 | │ │ *.png 22 | │ │ ... 23 | └───SegmentationClassAug # contains segmentation masks from trainaug extension 24 | │ │ *.png 25 | │ │ ... 26 | └───images 27 | │ │ *.jpg 28 | │ │ ... 29 | └───sets 30 | │ │ train.txt 31 | │ │ trainaug.txt 32 | │ │ val.txt 33 | ``` 34 | 35 | #### COCO-Stuff-27 36 | The structure for training and evaluation should be as follows: 37 | ``` 38 | dataset root. 39 | └───annotations 40 | │ └─── annotations 41 | │ └─── stuffthingmaps_trainval2017 42 | │ │ stuffthing_2017.json 43 | │ └─── train2017 44 | │ │ *.png 45 | │ │ ... 46 | │ └─── val2017 47 | │ │ *.png 48 | │ │ ... 49 | └───coco 50 | │ └─── images 51 | │ └─── train2017 52 | │ │ *.jpg 53 | │ │ ... 54 | │ └─── val2017 55 | │ │ *.jpg 56 | │ │ ... 57 | ``` 58 | The “curated” split introduced by IIC can be downloaded [here](https://www.robots.ox.ac.uk/~xuji/datasets/COCOStuff164kCurated.tar.gz). 59 | 60 | ### Self-supervised Training with Frozen ViT 61 | 62 | We provide the training configuration files for PVOC and COCO-Stuff in ```/configs``` folder, fill in your own path to dataset and pre-trained ViT. 63 | 64 | As the image encoder is frozen during training, the self-supervised training is quite efficient and can be implemented with only one GPU. 65 | To start training on PVOC, you can run the following exemplary command: 66 | ``` 67 | python train.py --config_path ./configs/train_voc_config.yml 68 | ``` 69 | 70 | The pre-trained ViT by DINO can be found [here](https://github.com/facebookresearch/dino). 71 | 72 | ### End-to-End Semantic Segmentation Inference 73 | 74 | AlignSeg can perform real-time and end-to-end segmentation inference. 75 | 76 | To perform segmentation inference and visualization, you can run the following exemplary command: 77 | ``` 78 | python evaluate/visualize_segment.py --pretrained_weights {model.pth} --image_path ./images/2007_000464.jpg 79 | ``` 80 | replace `{model.pth}` with the path to the pre-trained model. 81 | 82 | ### Evaluation 83 | 84 | We provide the evaluation configuration files for PVOC and COCO-Stuff-27 in ```/configs``` folder, fill in your own path to dataset and pre-trained model. 85 | 86 | To evaluate the pre-trained model on PVOC, you can run the following exemplary command: 87 | ``` 88 | python evaluate/sup_overcluster.py --config_path ../configs/eval_voc_config.yml 89 | ``` 90 | 91 | ### Pre-trained Models 92 | 93 | We provide our pre-trained models, they can be downloaded by links below. 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 |
| Encoder  | Dataset       | mIoU | Download |
|----------|---------------|------|----------|
| ViT-S/16 | PVOC          | 69.5 | model    |
| ViT-S/16 | COCO-Stuff-27 | 35.1 | model    |

| Encoder  | Dataset | FG-ARI | mBO  | Download |
|----------|---------|--------|------|----------|
| ViT-S/16 | MOVi-C  | 48.0   | 31.2 | model    |
| ViT-S/16 | MOVi-E  | 44.1   | 20.4 | model    |
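
For reference, the snippet below is a minimal sketch of how a downloaded checkpoint might be loaded into `AlignSegmentor` for patch-level inference. It follows the pattern used in `evaluate/proto_similarity.py` and `evaluate/sup_overcluster.py`; the constructor arguments mirror `configs/eval_voc_config.yml` (ViT-S/16, 5 queries), and the checkpoint path `./epoch10.pth` as well as the sample image path are placeholders.

```
import torch
import torch.nn.functional as F
import torchvision.transforms as T
from PIL import Image

from model.align_model import AlignSegmentor

# Build the segmentor with the PVOC evaluation settings (see configs/eval_voc_config.yml).
model = AlignSegmentor(arch='vit_small', patch_size=16, embed_dim=384, hidden_dim=384,
                       num_heads=3, num_queries=5, nmb_crops=[1, 0],
                       num_decode_layers=1, last_self_attention=False)
state = torch.load('./epoch10.pth', map_location='cpu')   # placeholder checkpoint path
print(model.load_state_dict(state['state_dict'], strict=False))
model.eval()

# Preprocess a single image as in the evaluation configs (448x448 crop, ImageNet statistics).
transform = T.Compose([T.Resize((448, 448)), T.ToTensor(),
                       T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
img = transform(Image.open('./images/2007_000464.jpg').convert('RGB')).unsqueeze(0)

with torch.no_grad():
    # forward returns (all_queries, gc_tokens, lc_tokens, attn_hard, gc_res, lc_res)
    all_queries, tokens, _, _, res, _ = model([img])
    # cosine similarity between patch tokens and the aligned queries -> per-patch query assignment
    sim = torch.einsum("bnc,bqc->bnq",
                       F.normalize(tokens, dim=-1, eps=1e-10),
                       F.normalize(all_queries[0], dim=-1, eps=1e-10))
    seg = sim.softmax(dim=-1).argmax(dim=-1).reshape(1, res, res)  # patch-level segmentation map

print(seg.shape)  # torch.Size([1, 28, 28]) for a 448x448 input with patch size 16
```

Upsampling `seg` to the input resolution (e.g. with nearest-neighbour interpolation, as in `evaluate/sup_overcluster.py`) yields a full-resolution segmentation mask.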
139 | 140 | 141 | 142 | 143 | 144 | 145 | -------------------------------------------------------------------------------- /evaluate/cocoStuff27_mask_visualize.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import torch 4 | import numpy as np 5 | 6 | from PIL import Image 7 | import matplotlib.pyplot as plt 8 | import torchvision.transforms as T 9 | from torchvision.transforms.functional import InterpolationMode 10 | 11 | 12 | label_to_color = { 13 | 0: (245, 245, 220), # accessory 14 | 1: (0, 100, 0), # animal 15 | 2: (178, 34, 34), # appliance 16 | 3: (0, 0, 139), # building 17 | 4: (148, 0, 211), # ceiling 18 | 5: (105, 105, 105), # electronic 19 | 6: (205, 92, 92), # floor 20 | 7: (244, 164, 96), # food 21 | 8: (245, 222, 179), # food-stuff 22 | 9: (75, 0, 130), # furniture 23 | 10: (138, 43, 226), # furniture-stuff 24 | 11: (72, 61, 139), # ground 25 | 12: (25, 25, 112), # indoor 26 | 13: (253, 245, 230), # kitchen 27 | 14: (47, 79, 79), # outdoor 28 | 15: (139, 0, 0), # person 29 | 16: (124, 252, 0), # plant 30 | 17: (210, 180, 140), # raw-material 31 | 18: (135, 206, 235), # sky 32 | 19: (85, 107, 47), # solid 33 | 20: (255, 105, 180), # sports 34 | 21: (210, 105, 30), # structural 35 | 22: (211, 211, 211), # textile 36 | 23: (184, 134, 11), # vehicle 37 | 24: (128, 128, 128), # wall 38 | 25: (32, 178, 170), # water 39 | 26: (189, 183, 107), # window 40 | 27: (255, 250, 250), # other 41 | } 42 | 43 | super_cat_to_id = { 44 | 'accessory': 0, 'animal': 1, 'appliance': 2, 'building': 3, 45 | 'ceiling': 4, 'electronic': 5, 'floor': 6, 'food': 7, 46 | 'food-stuff': 8, 'furniture': 9, 'furniture-stuff': 10, 'ground': 11, 47 | 'indoor': 12, 'kitchen': 13, 'outdoor': 14, 'person': 15, 48 | 'plant': 16, 'raw-material': 17, 'sky': 18, 'solid': 19, 49 | 'sports': 20, 'structural': 21, 'textile': 22, 'vehicle': 23, 50 | 'wall': 24, 'water': 25, 'window': 26, 51 | 'other': 27 52 | } 53 | 54 | 55 | def visual_mask(img_path, transforms, cat_id_map, RGB=False): 56 | 57 | mask = Image.open(img_path) 58 | mask = transforms(mask) 59 | save_path = img_path.replace('.png', '_RGB.jpg') 60 | 61 | # move 'id' labels from [0, 182] to [0,27] with 27=={182,255} 62 | # (182 is 'other' and 0 is things) 63 | mask *= 255 64 | assert torch.min(mask).item() >= 0 65 | mask[mask == 255] = 182 66 | assert torch.max(mask).item() <= 182 67 | for cat_id in torch.unique(mask): 68 | mask[mask == cat_id] = cat_id_map[cat_id.item()] 69 | 70 | assert torch.max(mask).item() <= 27 71 | assert torch.min(mask).item() >= 0 72 | 73 | mask = mask.squeeze(0).numpy().astype(int) 74 | mask = mask.astype(np.uint8) 75 | 76 | if not RGB: 77 | img = Image.fromarray(mask, 'L') 78 | plt.figure(figsize=(8, 8)) 79 | plt.axis('off') 80 | plt.imshow(img) 81 | # plt.savefig('mask.png', bbox_inches='tight', pad_inches=0.0) 82 | plt.tight_layout(pad=0.0, h_pad=0.0, w_pad=0.0) 83 | plt.show() 84 | else: 85 | # visualize by configuring palette 86 | img = Image.fromarray(mask, 'L') 87 | img_p = img.convert('P') 88 | img_p.putpalette([rgb for pixel in label_to_color.values() for rgb in pixel]) 89 | 90 | img_rgb = img_p.convert('RGB') 91 | plt.figure(figsize=(12, 4)) 92 | plt.subplot(1, 3, 1), plt.imshow(img) 93 | plt.subplot(1, 3, 2), plt.imshow(img_p) 94 | plt.subplot(1, 3, 3), plt.imshow(img_rgb) 95 | plt.tight_layout(), plt.show() 96 | img_rgb.save(save_path) 97 | 98 | 99 | if __name__ == '__main__': 100 | 101 | root = './COCO/annotations/stuffthingmaps_trainval2017/' 102 | 
json_file = "stuffthing_2017.json" 103 | mask_name = 'val2017/000000512194.png' 104 | 105 | mask_transforms = T.Compose([T.Resize((448, 448), interpolation=InterpolationMode.NEAREST), T.ToTensor()]) 106 | 107 | with open(os.path.join(root, json_file)) as f: 108 | an_json = json.load(f) 109 | all_cat = an_json['categories'] 110 | 111 | super_cats = set([cat_dict['supercategory'] for cat_dict in all_cat]) 112 | super_cats.remove("other") # remove others from prediction targets as this is not semantic 113 | super_cat_to_id = {super_cat: i for i, super_cat in enumerate(sorted(super_cats))} 114 | super_cat_to_id["other"] = 27 # ignore_index 115 | # Align 'id' labels: PNG_label = GT_label - 1 116 | cat_id_map = {(cat_dict['id'] - 1): super_cat_to_id[cat_dict['supercategory']] for cat_dict in all_cat} 117 | 118 | visual_mask(os.path.join(root, mask_name), mask_transforms, cat_id_map, RGB=False) 119 | -------------------------------------------------------------------------------- /evaluate/proto_similarity.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import argparse 4 | import cv2 5 | import random 6 | import colorsys 7 | 8 | import skimage.io 9 | from skimage.measure import find_contours 10 | from matplotlib.patches import Polygon 11 | import torch 12 | import torch.nn as nn 13 | import torchvision 14 | from torchvision import transforms as pth_transforms 15 | from torchvision.transforms import GaussianBlur 16 | import torch.nn.functional as F 17 | import numpy as np 18 | from PIL import Image 19 | from skimage.measure import label 20 | from matplotlib import pyplot as plt 21 | 22 | from model.align_model import AlignSegmentor 23 | from utils import neq_load_external 24 | 25 | 26 | def norm(t): 27 | return F.normalize(t, dim=-1, eps=1e-10) 28 | 29 | 30 | def process_attentions(attentions: torch.Tensor, spatial_res: int, threshold: float = 0.6, blur_sigma: float = 0.6) \ 31 | -> torch.Tensor: 32 | """ 33 | Process [0,1] attentions to binary 0-1 mask. Applies a Guassian filter, keeps threshold % of mass and removes 34 | components smaller than 3 pixels. 35 | The code is adapted from https://github.com/facebookresearch/dino/blob/main/visualize_attention.py but removes the 36 | need for using ground-truth data to find the best performing head. Instead we simply average all head's attentions 37 | so that we can use the foreground mask during training time. 38 | :param attentions: torch 4D-Tensor containing the averaged attentions 39 | :param spatial_res: spatial resolution of the attention map 40 | :param threshold: the percentage of mass to keep as foreground. 41 | :param blur_sigma: standard deviation to be used for creating kernel to perform blurring. 42 | :return: the foreground mask obtained from the ViT's attention. 
43 | """ 44 | # Blur attentions 45 | attentions = GaussianBlur(7, sigma=(blur_sigma))(attentions) 46 | attentions = attentions.reshape(attentions.size(0), 1, spatial_res ** 2) 47 | # Keep threshold% of mass 48 | val, idx = torch.sort(attentions) 49 | val /= torch.sum(val, dim=-1, keepdim=True) 50 | cumval = torch.cumsum(val, dim=-1) 51 | th_attn = cumval > (1 - threshold) 52 | idx2 = torch.argsort(idx) 53 | th_attn[:, 0] = torch.gather(th_attn[:, 0], dim=1, index=idx2[:, 0]) 54 | th_attn = th_attn.reshape(attentions.size(0), 1, spatial_res, spatial_res).float() 55 | # Remove components with less than 3 pixels 56 | for j, th_att in enumerate(th_attn): 57 | labelled = label(th_att.cpu().numpy()) 58 | for k in range(1, np.max(labelled) + 1): 59 | mask = labelled == k 60 | if np.sum(mask) <= 2: 61 | th_attn[j, 0][mask] = 0 62 | return th_attn 63 | 64 | 65 | if __name__ == '__main__': 66 | parser = argparse.ArgumentParser('Evaluate segmentation on pretrained model') 67 | parser.add_argument('--pretrained_weights', default='./epoch10.pth', 68 | type=str, help="Path to pretrained weights to load.") 69 | parser.add_argument("--image_path", default='', 70 | type=str, help="Path of the image to load.") 71 | parser.add_argument('--patch_size', default=16, type=int, help='Patch resolution of the model.') 72 | parser.add_argument("--image_size", default=(480, 480), type=int, nargs="+", help="Resize image.") 73 | parser.add_argument('--output_dir', default='./outputs/', help='Path where to save visualizations.') 74 | parser.add_argument("--threshold", type=float, default=0.6, help="""We visualize masks 75 | obtained by thresholding the self-attention maps to keep xx% of the mass.""") 76 | args = parser.parse_args() 77 | 78 | device = torch.device("cpu") 79 | # build model 80 | model = AlignSegmentor(arch='vit_small', 81 | patch_size=16, 82 | embed_dim=384, 83 | hidden_dim=384, 84 | num_heads=2, 85 | num_queries=5, 86 | nmb_crops=[1, 0], 87 | num_decode_layers=1, 88 | last_self_attention=True) 89 | 90 | # set model to eval mode 91 | for p in model.parameters(): 92 | p.requires_grad = False 93 | model.eval() 94 | model.to(device) 95 | 96 | # load pretrained weights 97 | if os.path.isfile(args.pretrained_weights): 98 | pratrained_model = torch.load(args.pretrained_weights, map_location="cpu") 99 | msg = model.load_state_dict(pratrained_model['state_dict'], strict=False) 100 | print(msg) 101 | else: 102 | print('no pretrained pth found!') 103 | 104 | queries = model.clsQueries.weight 105 | 106 | prototypes = torch.load('../log_tmp/prototypes21.pth') 107 | prototypes = prototypes.to(device) 108 | # calculate query assignment score 109 | sim_query_proto = norm(queries) @ norm(prototypes).T 110 | sim_query_proto = sim_query_proto.clamp(min=0.0) 111 | 112 | for i in range(sim_query_proto.size(0)): 113 | print('Proto', i, '=', sim_query_proto[i]*10) 114 | -------------------------------------------------------------------------------- /data/movi_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from typing import Optional, Callable 4 | from PIL import Image 5 | from pathlib import Path 6 | from torch.utils.data import DataLoader 7 | from torchvision.datasets import VisionDataset 8 | from typing import Tuple, Any 9 | 10 | 11 | class MOViDataModule: 12 | 13 | def __init__(self, 14 | data_dir: str, 15 | dataset_name: str, 16 | train_split: str, 17 | val_split: str, 18 | train_image_transform: Optional[Callable], 19 | val_image_transform: 
Optional[Callable], 20 | val_target_transform: Optional[Callable], 21 | batch_size: int, 22 | num_workers: int, 23 | shuffle: bool = True, 24 | return_masks: bool = False, 25 | drop_last: bool = True): 26 | """ 27 | Data module for MOVi data. 28 | If return_masks is set train_image_transform should be callable with imgs and masks or None. 29 | """ 30 | super().__init__() 31 | self.root = os.path.join(data_dir, dataset_name) 32 | self.train_split = train_split 33 | self.val_split = val_split 34 | self.batch_size = batch_size 35 | self.num_workers = num_workers 36 | self.train_image_transform = train_image_transform 37 | self.val_image_transform = val_image_transform 38 | self.val_target_transform = val_target_transform 39 | self.shuffle = shuffle 40 | self.drop_last = drop_last 41 | self.return_masks = return_masks 42 | 43 | # Set up datasets in __init__ as we need to know the number of samples to init cosine lr schedules 44 | self.movi_train = MOViDataset(root=self.root, image_set=train_split, transforms=self.train_image_transform, 45 | return_masks=self.return_masks) 46 | self.movi_val = MOViDataset(root=self.root, image_set=val_split, transform=self.val_image_transform, 47 | target_transform=self.val_target_transform) 48 | print('--- Loaded ' + dataset_name + ' with Train %d, Val %d ---' % (len(self.movi_train), len(self.movi_val))) 49 | 50 | def __len__(self): 51 | return len(self.movi_train) 52 | 53 | def train_dataloader(self): 54 | return DataLoader(self.movi_train, batch_size=self.batch_size, 55 | shuffle=self.shuffle, num_workers=self.num_workers, 56 | drop_last=self.drop_last, pin_memory=True) 57 | 58 | def val_dataloader(self): 59 | return DataLoader(self.movi_val, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers, 60 | drop_last=self.drop_last, pin_memory=True) 61 | 62 | 63 | class MOViDataset(VisionDataset): 64 | 65 | def __init__( 66 | self, 67 | root: str, 68 | image_set: str = "frames", 69 | transform: Optional[Callable] = None, 70 | target_transform: Optional[Callable] = None, 71 | transforms: Optional[Callable] = None, 72 | return_masks: bool = False 73 | ): 74 | super(MOViDataset, self).__init__(root, transforms, transform, target_transform) 75 | self.image_set = image_set 76 | if self.image_set == "frames": # set for training 77 | img_folder = "frames" 78 | elif self.image_set == "images": # set for validation 79 | img_folder = "images" 80 | else: 81 | raise ValueError(f"No support for image set {self.image_set}") 82 | image_dir = os.path.join(root, img_folder) 83 | seg_dir = os.path.join(root, 'masks') 84 | if not os.path.isdir(seg_dir) or not os.path.isdir(image_dir) or not os.path.isdir(root): 85 | raise RuntimeError('Dataset not found or corrupted.') 86 | 87 | self.images = [os.path.join(image_dir, x) for x in os.listdir(image_dir)] 88 | self.masks = [os.path.join(seg_dir, x) for x in os.listdir(seg_dir)] 89 | self.return_masks = return_masks 90 | 91 | assert all([Path(f).is_file() for f in self.masks]) and all([Path(f).is_file() for f in self.images]) 92 | 93 | def __getitem__(self, index: int) -> Tuple[Any, Any]: 94 | img = Image.open(self.images[index]).convert('RGB') 95 | if self.image_set == "images": # for validation 96 | # print('img = ', self.images[index]) 97 | mask = Image.open(self.masks[index]) 98 | print("image: ", self.images[index]) 99 | if self.transforms: 100 | img, mask = self.transforms(img, mask) 101 | return img, mask 102 | elif self.image_set == "frames": # for training 103 | if self.transforms: 104 | if self.return_masks: 
105 | mask = Image.open(self.masks[index]) 106 | res = self.transforms(img, mask) 107 | else: 108 | res = self.transforms(img) 109 | return res 110 | return img 111 | 112 | def __len__(self) -> int: 113 | return len(self.images) 114 | -------------------------------------------------------------------------------- /evaluate/object_evaluation.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import os 4 | import sys 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | import torchvision.transforms as T 9 | 10 | import yaml 11 | from torchvision.transforms.functional import InterpolationMode 12 | 13 | from data.movi_data import MOViDataModule 14 | from model.align_model import AlignSegmentor 15 | from eval_utils import ARIMetric, AverageBestOverlapMetric 16 | 17 | 18 | def norm(t): 19 | return F.normalize(t, dim=-1, eps=1e-10) 20 | 21 | 22 | def eval_objectmasks(): 23 | with open(args.config_path) as file: 24 | config = yaml.safe_load(file.read()) 25 | # print('Config: ', config) 26 | 27 | data_config = config['data'] 28 | val_config = config['val'] 29 | input_size = data_config["size_crops"] 30 | torch.manual_seed(val_config['seed']) 31 | 32 | # Init data and transforms 33 | val_image_transforms = T.Compose([T.Resize((input_size, input_size)), 34 | T.ToTensor(), 35 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) 36 | val_target_transforms = T.Compose([T.Resize((input_size, input_size), interpolation=InterpolationMode.NEAREST), 37 | T.ToTensor()]) 38 | 39 | data_dir = data_config["data_dir"] 40 | dataset_name = data_config["dataset_name"] 41 | if "movi" in dataset_name: 42 | ignore_index = 0 43 | num_classes = 17 44 | data_module = MOViDataModule(data_dir=data_dir, 45 | dataset_name=data_config['dataset_name'], 46 | batch_size=val_config["batch_size"], 47 | return_masks=True, 48 | drop_last=True, 49 | num_workers=config["num_workers"], 50 | train_split="frames", 51 | val_split="images", 52 | train_image_transform=None, 53 | val_image_transform=val_image_transforms, 54 | val_target_transform=val_target_transforms) 55 | else: 56 | raise ValueError(f"{dataset_name} not supported") 57 | 58 | # Init method 59 | patch_size = val_config["patch_size"] 60 | spatial_res = input_size / patch_size 61 | num_proto = val_config['num_queries'] 62 | assert spatial_res.is_integer() 63 | model = AlignSegmentor(arch=val_config['arch'], 64 | patch_size=val_config['patch_size'], 65 | embed_dim=val_config['embed_dim'], 66 | hidden_dim=val_config['hidden_dim'], 67 | num_heads=val_config['decoder_num_heads'], 68 | num_queries=val_config['num_queries'], 69 | num_decode_layers=val_config['num_decode_layers'], 70 | last_self_attention=val_config['last_self_attention']) 71 | 72 | # set model to eval mode 73 | for p in model.parameters(): 74 | p.requires_grad = False 75 | model.eval() 76 | 77 | # load pretrained weights 78 | if val_config["checkpoint"] is not None: 79 | checkpoint = torch.load(val_config["checkpoint"]) 80 | msg = model.load_state_dict(checkpoint["state_dict"], strict=True) 81 | print(msg) 82 | else: 83 | print('no pretrained pth found!') 84 | 85 | dataloader = data_module.val_dataloader() 86 | ARI_metric = ARIMetric() 87 | BO_metric = AverageBestOverlapMetric() 88 | 89 | # Calculate IoU for each image individually 90 | for idx, batch in enumerate(dataloader): 91 | imgs, masks = batch 92 | B = imgs.size(0) 93 | # assert B == 1 # image has to be evaluated individually 94 | all_queries, tokens, _, _, res, _ = 
model([imgs]) # tokens=(1,N,dim) 95 | 96 | # calculate token assignment 97 | token_cls = torch.einsum("bnc,bqc->bnq", norm(tokens), norm(all_queries[0])) 98 | token_cls = torch.softmax(token_cls, dim=-1) 99 | token_cls = token_cls.reshape(B, res, res, -1).permute(0, 3, 1, 2) # (1,num_query,res,res) 100 | 101 | # downsample masks / upsample preds to masks_eval_size 102 | preds = F.interpolate(token_cls, size=(val_config['mask_eval_size'], val_config['mask_eval_size']), 103 | mode='bilinear') 104 | masks *= 255 105 | if masks.size(3) != val_config['mask_eval_size']: 106 | masks = F.interpolate(masks, size=(val_config['mask_eval_size'], val_config['mask_eval_size']), 107 | mode='nearest') 108 | 109 | # turn masks to one-hot 110 | masks = masks.squeeze(dim=1).reshape(B, -1) 111 | masks = masks.long() 112 | num_classes = masks.max().item() + 1 113 | masks = torch.nn.functional.one_hot(masks, num_classes) 114 | masks = masks.permute(0, 2, 1).reshape(B, num_classes, 115 | val_config['mask_eval_size'], 116 | val_config['mask_eval_size']) # to (B, K, H, W) 117 | 118 | ARI_metric.update(preds, masks) 119 | BO_metric.update(preds, masks) 120 | # sys.exit(1) 121 | 122 | ARI_metric.compute() 123 | BO_metric.compute() 124 | 125 | 126 | if __name__ == "__main__": 127 | parser = argparse.ArgumentParser() 128 | parser.add_argument('--config_path', default='../configs/eval_movi_config.yml', type=str) 129 | 130 | args = parser.parse_args() 131 | 132 | eval_objectmasks() 133 | -------------------------------------------------------------------------------- /data/voc_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from typing import Optional, Callable 4 | from PIL import Image 5 | from pathlib import Path 6 | from torch.utils.data import DataLoader 7 | from torchvision.datasets import VisionDataset 8 | from typing import Tuple, Any 9 | 10 | 11 | class VOCDataModule: 12 | 13 | CLASS_IDX_TO_NAME = ['background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 14 | 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 15 | 'train', 'tvmonitor'] 16 | 17 | def __init__(self, 18 | data_dir: str, 19 | train_split: str, 20 | val_split: str, 21 | train_image_transform: Optional[Callable], 22 | val_image_transform: Optional[Callable], 23 | val_target_transform: Optional[Callable], 24 | batch_size: int, 25 | num_workers: int, 26 | shuffle: bool = True, 27 | return_masks: bool = False, 28 | drop_last: bool = True): 29 | """ 30 | Data module for PVOC data. "trainaug" and "train" are valid train_splits. 31 | If return_masks is set train_image_transform should be callable with imgs and masks or None. 
32 | """ 33 | super().__init__() 34 | self.root = os.path.join(data_dir, "PVOC") 35 | self.train_split = train_split 36 | self.val_split = val_split 37 | self.batch_size = batch_size 38 | self.num_workers = num_workers 39 | self.train_image_transform = train_image_transform 40 | self.val_image_transform = val_image_transform 41 | self.val_target_transform = val_target_transform 42 | self.shuffle = shuffle 43 | self.drop_last = drop_last 44 | self.return_masks = return_masks 45 | 46 | # Set up datasets in __init__ as we need to know the number of samples to init cosine lr schedules 47 | assert train_split == "trainaug" or train_split == "train" 48 | self.voc_train = VOCDataset(root=self.root, image_set=train_split, transforms=self.train_image_transform, 49 | return_masks=self.return_masks) 50 | self.voc_val = VOCDataset(root=self.root, image_set=val_split, transform=self.val_image_transform, 51 | target_transform=self.val_target_transform) 52 | print('--- loaded VOC with Train %d, Val %d ---' % (len(self.voc_train), len(self.voc_val))) 53 | 54 | def __len__(self): 55 | return len(self.voc_train) 56 | 57 | def class_id_to_name(self, i: int): 58 | return self.CLASS_IDX_TO_NAME[i] 59 | 60 | def train_dataloader(self): 61 | return DataLoader(self.voc_train, batch_size=self.batch_size, 62 | shuffle=self.shuffle, num_workers=self.num_workers, 63 | drop_last=self.drop_last, pin_memory=True) 64 | 65 | def val_dataloader(self): 66 | return DataLoader(self.voc_val, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers, 67 | drop_last=self.drop_last, pin_memory=True) 68 | 69 | 70 | class VOCDataset(VisionDataset): 71 | 72 | def __init__( 73 | self, 74 | root: str, 75 | image_set: str = "trainaug", 76 | transform: Optional[Callable] = None, 77 | target_transform: Optional[Callable] = None, 78 | transforms: Optional[Callable] = None, 79 | return_masks: bool = False 80 | ): 81 | super(VOCDataset, self).__init__(root, transforms, transform, target_transform) 82 | self.image_set = image_set 83 | if self.image_set == "trainaug" or self.image_set == "train": 84 | seg_folder = "SegmentationClassAug" 85 | elif self.image_set == "val": 86 | seg_folder = "SegmentationClass" 87 | else: 88 | raise ValueError(f"No support for image set {self.image_set}") 89 | seg_dir = os.path.join(root, seg_folder) 90 | image_dir = os.path.join(root, 'images') 91 | if not os.path.isdir(seg_dir) or not os.path.isdir(image_dir) or not os.path.isdir(root): 92 | raise RuntimeError('Dataset not found or corrupted.') 93 | splits_dir = os.path.join(root, 'sets') 94 | split_f = os.path.join(splits_dir, self.image_set.rstrip('\n') + '.txt') 95 | 96 | with open(os.path.join(split_f), "r") as f: 97 | file_names = [x.strip() for x in f.readlines()] 98 | 99 | self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names] 100 | self.masks = [os.path.join(seg_dir, x + ".png") for x in file_names] 101 | self.return_masks = return_masks 102 | 103 | assert all([Path(f).is_file() for f in self.masks]) and all([Path(f).is_file() for f in self.images]) 104 | 105 | def __getitem__(self, index: int) -> Tuple[Any, Any]: 106 | img = Image.open(self.images[index]).convert('RGB') 107 | if self.image_set == "val": 108 | # print('img = ', self.images[index]) 109 | mask = Image.open(self.masks[index]) 110 | if self.transforms: 111 | img, mask = self.transforms(img, mask) 112 | return img, mask 113 | elif "train" in self.image_set: 114 | if self.transforms: 115 | if self.return_masks: 116 | mask = Image.open(self.masks[index]) 117 | 
res = self.transforms(img, mask) 118 | else: 119 | res = self.transforms(img) 120 | return res 121 | return img 122 | 123 | def __len__(self) -> int: 124 | return len(self.images) 125 | -------------------------------------------------------------------------------- /model/criterion.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from skimage.measure import label 8 | 9 | from utils import calc_topk_accuracy 10 | 11 | 12 | def norm(t): 13 | return F.normalize(t, dim=-1, eps=1e-10) 14 | 15 | 16 | def query_ce_loss(gc_query, lc_query, num_queries, temperature=1): 17 | B = gc_query.size(0) 18 | N = 2 * num_queries 19 | criterion = nn.CrossEntropyLoss() 20 | mask = mask_correlated_samples(num_queries, gc_query.device) 21 | 22 | # calculate ce loss for each query set in B 23 | loss = top1_avg = 0 24 | labels = torch.zeros(N, device=gc_query.device).long() 25 | for i in range(B): 26 | z = torch.cat((gc_query[i], lc_query[i]), dim=0) 27 | 28 | sim = torch.matmul(z, z.T) / temperature 29 | sim_gc_lc = torch.diag(sim, num_queries) 30 | sim_lc_gc = torch.diag(sim, -num_queries) 31 | 32 | positive_samples = torch.cat((sim_gc_lc, sim_lc_gc), dim=0).reshape(N, 1) 33 | negative_samples = sim[mask].reshape(N, -1) 34 | 35 | logits = torch.cat((positive_samples, negative_samples), dim=1) 36 | ce_loss = criterion(logits, labels) 37 | 38 | top1 = calc_topk_accuracy(logits, labels, (1,)) 39 | 40 | loss += ce_loss 41 | top1_avg += top1[0] 42 | 43 | return loss / B, top1_avg / B 44 | 45 | 46 | def mask_correlated_samples(num_seq, device): 47 | N = 2 * num_seq 48 | mask = torch.ones((N, N), device=device) 49 | mask = mask.fill_diagonal_(0) 50 | for i in range(num_seq): 51 | mask[i, num_seq + i] = 0 52 | mask[num_seq + i, i] = 0 53 | mask = mask.bool() 54 | return mask 55 | 56 | 57 | class AlignCriterion(nn.Module): 58 | def __init__(self, patch_size=16, 59 | num_queries=5, 60 | nmb_crops=(1, 1), 61 | roi_align_kernel_size=7, 62 | ce_temperature=1, 63 | negative_pressure=0.1, 64 | last_self_attention=True): 65 | super(AlignCriterion, self).__init__() 66 | self.patch_size = patch_size 67 | self.num_queries = num_queries 68 | self.nmb_crops = nmb_crops 69 | self.roi_align_kernel_size = roi_align_kernel_size 70 | self.ce_temperature = ce_temperature 71 | self.negative_pressure = negative_pressure 72 | self.last_self_attention = last_self_attention 73 | 74 | def forward(self, results, bboxes): 75 | all_queries, gc_output, lc_output, attn_hard, gc_spatial_res, lc_spatial_res = results 76 | B = gc_output.size(0) 77 | 78 | # prepare foreground mask 79 | mask = attn_hard.reshape(B*sum(self.nmb_crops), -1) 80 | mask = mask.int() 81 | mask_gc, masks_lc = mask[:B * self.nmb_crops[0]], mask[B * self.nmb_crops[0]:] 82 | 83 | loss = 0 84 | ''' 85 | 1. 
Compute patch correlation to assignment similarity alignment loss 86 | -- Compute similarity between Queries and spatial_tokens, and align to patch correlation 87 | -- use attention map as foreground hint to mask correlation matrix 88 | -- assuming there is ONLY 1 global crop 89 | ''' 90 | # compute patch correlation between gc and lc, use as assignment target later 91 | with torch.no_grad(): 92 | gclc_correlations = [] 93 | masks_gc_lc = [] 94 | mask_gc = mask_gc.repeat(1, lc_spatial_res**2).reshape(B, lc_spatial_res**2, -1) 95 | mask_gc = mask_gc.transpose(1, 2) # (B,n,m) 96 | for i in range(self.nmb_crops[-1]): 97 | # compute cosine similarity 98 | correlation = torch.einsum("bnc,bmc->bnm", norm(gc_output), norm(lc_output[:, i])) 99 | # spatial centering for better recognizing small objects 100 | old_mean = correlation.mean() 101 | correlation -= correlation.mean(dim=-1, keepdim=True) 102 | correlation = correlation - correlation.mean() + old_mean 103 | gclc_correlations.append(correlation) 104 | 105 | # compute gc-lc foreground intersection mask 106 | mask_lc_ = masks_lc[i*B:(i+1)*B] # (B, m) 107 | mask_lc_ = mask_lc_.repeat(1, gc_spatial_res**2).reshape(B, gc_spatial_res**2, -1) # (B,n,m) 108 | mask_gc_lc_ = mask_gc * mask_lc_ 109 | masks_gc_lc.append(mask_gc_lc_.bool()) 110 | 111 | # compute gc and lc token assignment 112 | gc_token_assign = torch.einsum("bnc,bqc->bnq", norm(gc_output), norm(all_queries[0])) 113 | 114 | gclc_cor_loss = 0 115 | lc_assigns_detached = [] 116 | for i in range(self.nmb_crops[-1]): 117 | lc_token_assign = torch.einsum("bmc,bqc->bmq", norm(lc_output[:, i]), norm(all_queries[i+1])) 118 | # store lc intersection assignment 119 | lc_tmp = torch.clone(lc_token_assign.detach()) 120 | lc_tmp = lc_tmp.reshape(B, lc_spatial_res, lc_spatial_res, -1).permute(0, 3, 1, 2) # (B, num_queries, 6, 6) 121 | lc_assigns_detached.append(lc_tmp) 122 | 123 | # note here correlation value is not cosine similarity 124 | gc_token_assign_ = gc_token_assign.clamp(min=0.0) 125 | lc_token_assign_ = lc_token_assign.clamp(min=0.0) 126 | gclc_assign_cor = torch.einsum("bnq,bmq->bnm", gc_token_assign_.softmax(dim=-1), lc_token_assign_.softmax(dim=-1)) 127 | # align patch assignment similarity to feature correlation 128 | cor_align_loss = (- gclc_assign_cor * (gclc_correlations[i] - self.negative_pressure))[masks_gc_lc[i]] 129 | gclc_cor_loss += 0.15*cor_align_loss.sum() 130 | 131 | loss += gclc_cor_loss / self.nmb_crops[-1] 132 | 133 | ''' 134 | 2. 
Compute Global-Local Query Alignment loss 135 | -- use cross-entropy loss to align queries, and make each query different 136 | ''' 137 | query_align_loss = 0 138 | for i in range(self.nmb_crops[-1]): 139 | tmp_loss, top1 = query_ce_loss(norm(all_queries[0]), norm(all_queries[i + 1]), self.num_queries, 140 | self.ce_temperature) 141 | query_align_loss += tmp_loss 142 | 143 | loss += query_align_loss / self.nmb_crops[-1] 144 | 145 | return loss 146 | -------------------------------------------------------------------------------- /evaluate/sup_overcluster.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import os 4 | import torch 5 | import torch.nn.functional as F 6 | import torchvision.transforms as T 7 | 8 | import yaml 9 | from torchvision.transforms.functional import InterpolationMode 10 | 11 | from data.coco_data import CocoDataModule 12 | from data.voc_data import VOCDataModule 13 | from model.align_model import AlignSegmentor 14 | from utils import PredsmIoU 15 | 16 | 17 | def norm(t): 18 | return F.normalize(t, dim=-1, eps=1e-10) 19 | 20 | 21 | def eval_overcluster(): 22 | with open(args.config_path) as file: 23 | config = yaml.safe_load(file.read()) 24 | # print('Config: ', config) 25 | 26 | data_config = config['data'] 27 | val_config = config['val'] 28 | input_size = data_config["size_crops"] 29 | torch.manual_seed(val_config['seed']) 30 | torch.cuda.manual_seed_all(val_config['seed']) 31 | 32 | # Init data and transforms 33 | val_image_transforms = T.Compose([T.Resize((input_size, input_size)), 34 | T.ToTensor(), 35 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) 36 | val_target_transforms = T.Compose([T.Resize((input_size, input_size), interpolation=InterpolationMode.NEAREST), 37 | T.ToTensor()]) 38 | 39 | data_dir = data_config["data_dir"] 40 | dataset_name = data_config["dataset_name"] 41 | if dataset_name == "voc": 42 | ignore_index = 255 43 | num_classes = 21 44 | data_module = VOCDataModule(batch_size=val_config["batch_size"], 45 | return_masks=True, 46 | num_workers=config["num_workers"], 47 | train_split="trainaug", 48 | val_split="val", 49 | data_dir=data_dir, 50 | train_image_transform=None, 51 | drop_last=True, 52 | val_image_transform=val_image_transforms, 53 | val_target_transform=val_target_transforms) 54 | elif "coco" in dataset_name: 55 | assert len(dataset_name.split("-")) == 2 56 | mask_type = dataset_name.split("-")[-1] 57 | assert mask_type in ["all", "stuff", "thing"] 58 | if mask_type == "all": 59 | num_classes = 27 60 | elif mask_type == "stuff": 61 | num_classes = 15 62 | elif mask_type == "thing": 63 | num_classes = 12 64 | ignore_index = 255 65 | file_list = os.listdir(os.path.join(data_dir, "images", "train2017")) 66 | file_list_val = os.listdir(os.path.join(data_dir, "images", "val2017")) 67 | # random.shuffle(file_list_val) 68 | data_module = CocoDataModule(batch_size=val_config["batch_size"], 69 | num_workers=config["num_workers"], 70 | file_list=file_list, 71 | data_dir=data_dir, 72 | file_list_val=file_list_val, 73 | mask_type=mask_type, 74 | train_transforms=None, 75 | val_transforms=val_image_transforms, 76 | val_target_transforms=val_target_transforms) 77 | elif dataset_name == "ade20k": 78 | num_classes = 111 79 | ignore_index = 255 80 | val_target_transforms = T.Compose([T.Resize((input_size, input_size), interpolation=InterpolationMode.NEAREST)]) 81 | data_module = None 82 | else: 83 | raise ValueError(f"{dataset_name} not supported") 84 | 85 | # Init 
method 86 | patch_size = val_config["patch_size"] 87 | spatial_res = input_size / patch_size 88 | assert spatial_res.is_integer() 89 | model = AlignSegmentor(arch=val_config['arch'], 90 | patch_size=val_config['patch_size'], 91 | embed_dim=val_config['embed_dim'], 92 | hidden_dim=val_config['hidden_dim'], 93 | num_heads=val_config['decoder_num_heads'], 94 | num_queries=val_config['num_queries'], 95 | num_decode_layers=val_config['num_decode_layers'], 96 | last_self_attention=val_config['last_self_attention']) 97 | 98 | # set model to eval mode 99 | for p in model.parameters(): 100 | p.requires_grad = False 101 | model.eval() 102 | model.to(cuda) 103 | 104 | # load pretrained weights 105 | if val_config["checkpoint"] is not None: 106 | checkpoint = torch.load(val_config["checkpoint"]) 107 | msg = model.load_state_dict(checkpoint["state_dict"], strict=True) 108 | print(msg) 109 | else: 110 | print('no pretrained pth found!') 111 | 112 | dataloader = data_module.val_dataloader() 113 | metric = PredsmIoU(val_config['num_queries'], num_classes) 114 | 115 | # Calculate IoU for each image individually 116 | for idx, batch in enumerate(dataloader): 117 | imgs, masks = batch 118 | B = imgs.size(0) 119 | assert B == 1 # image has to be evaluated individually 120 | all_queries, tokens, _, _, res, _ = model([imgs.to(cuda)]) # tokens=(1,N,dim) 121 | 122 | # calculate token assignment 123 | token_cls = torch.einsum("bnc,bqc->bnq", norm(tokens), norm(all_queries[0])) 124 | token_cls = torch.softmax(token_cls, dim=-1) 125 | token_cls = token_cls.reshape(B, res, res, -1).permute(0, 3, 1, 2) # (1,num_query,res,res) 126 | token_cls = token_cls.max(dim=1, keepdim=True)[1].float() # (1,1,res,res) 127 | 128 | # downsample masks/upsample preds to masks_eval_size 129 | preds = F.interpolate(token_cls, size=(val_config['mask_eval_size'], val_config['mask_eval_size']), mode='nearest') 130 | masks *= 255 131 | if masks.size(3) != val_config['mask_eval_size']: 132 | masks = F.interpolate(masks, size=(val_config['mask_eval_size'], val_config['mask_eval_size']), mode='nearest') 133 | 134 | metric.update(masks[masks != ignore_index], preds[masks != ignore_index]) 135 | # sys.exit(1) 136 | 137 | metric.compute() 138 | 139 | 140 | if __name__ == "__main__": 141 | parser = argparse.ArgumentParser() 142 | parser.add_argument('--config_path', default='../configs/eval_voc_config.yml', type=str) 143 | parser.add_argument('--gpu', default='0', type=str) 144 | 145 | args = parser.parse_args() 146 | 147 | os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu) 148 | cuda = torch.device('cuda') 149 | eval_overcluster() 150 | -------------------------------------------------------------------------------- /model/align_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from model.vit import vit_small, vit_base, vit_large, trunc_normal_ 6 | from utils import process_attentions 7 | 8 | 9 | class CrossAttentionLayer(nn.Module): 10 | 11 | def __init__(self, d_model, nhead, dropout=0.0): 12 | super().__init__() 13 | self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True) 14 | 15 | self.norm = nn.LayerNorm(d_model) 16 | self.dropout = nn.Dropout(dropout) 17 | 18 | # self._reset_parameters() 19 | 20 | def _reset_parameters(self): 21 | for p in self.parameters(): 22 | if p.dim() > 1: 23 | nn.init.xavier_uniform_(p) 24 | 25 | def with_pos_embed(self, tensor, pos): 26 | return tensor 
if pos is None else tensor + pos 27 | 28 | def forward(self, tgt, memory, 29 | pos=None, 30 | query_pos=None): 31 | tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos), 32 | key=self.with_pos_embed(memory, pos), value=memory)[0] 33 | tgt = tgt + self.dropout(tgt2) 34 | tgt = self.norm(tgt) 35 | return tgt 36 | 37 | 38 | class SelfAttentionLayer(nn.Module): 39 | 40 | def __init__(self, d_model, nhead, dropout=0.0): 41 | super().__init__() 42 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True) 43 | 44 | self.norm = nn.LayerNorm(d_model) 45 | self.dropout = nn.Dropout(dropout) 46 | 47 | # self._reset_parameters() 48 | 49 | def _reset_parameters(self): 50 | for p in self.parameters(): 51 | if p.dim() > 1: 52 | nn.init.xavier_uniform_(p) 53 | 54 | def with_pos_embed(self, tensor, pos): 55 | return tensor if pos is None else tensor + pos 56 | 57 | def forward(self, tgt, query_pos=None): 58 | q = k = self.with_pos_embed(tgt, query_pos) 59 | tgt2 = self.self_attn(q, k, value=tgt)[0] 60 | tgt = tgt + self.dropout(tgt2) 61 | tgt = self.norm(tgt) 62 | 63 | return tgt 64 | 65 | 66 | class FFNLayer(nn.Module): 67 | 68 | def __init__(self, d_model, dim_feedforward=1024, dropout=0.0): 69 | super().__init__() 70 | # Implementation of Feedforward model 71 | self.linear1 = nn.Linear(d_model, dim_feedforward) 72 | self.dropout = nn.Dropout(dropout) 73 | self.linear2 = nn.Linear(dim_feedforward, d_model) 74 | 75 | self.norm = nn.LayerNorm(d_model) 76 | self.activation = F.relu 77 | 78 | self.apply(self._init_weights) 79 | 80 | def _init_weights(self, m): 81 | if isinstance(m, nn.Linear): 82 | nn.init.trunc_normal_(m.weight, std=.02) 83 | if isinstance(m, nn.Linear) and m.bias is not None: 84 | nn.init.constant_(m.bias, 0) 85 | 86 | def forward(self, tgt): 87 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) 88 | tgt = tgt + self.dropout(tgt2) 89 | tgt = self.norm(tgt) 90 | return tgt 91 | 92 | 93 | class MLP(nn.Module): 94 | 95 | def __init__(self, input_dim, hidden_dim, output_dim, num_layers=2): 96 | super().__init__() 97 | self.num_layers = num_layers 98 | h = [hidden_dim] * (num_layers - 1) 99 | self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) 100 | 101 | def forward(self, x): 102 | for i, layer in enumerate(self.layers): 103 | x = F.gelu(layer(x)) if i < self.num_layers - 1 else layer(x) 104 | return x 105 | 106 | 107 | class AlignSegmentor(nn.Module): 108 | 109 | def __init__(self, arch='vit_small', 110 | patch_size=16, 111 | embed_dim=384, 112 | hidden_dim=384, 113 | num_heads=4, 114 | num_queries=21, 115 | nmb_crops=(1, 0), 116 | num_decode_layers=1, 117 | last_self_attention=True): 118 | super(AlignSegmentor, self).__init__() 119 | self.patch_size = patch_size 120 | self.embed_dim = embed_dim 121 | self.hidden_dim = hidden_dim 122 | self.nmb_crops = nmb_crops 123 | self.num_decode_layers = num_decode_layers 124 | self.last_self_attention = last_self_attention 125 | 126 | # Initialize model 127 | if arch == 'vit_small': 128 | self.backbone = vit_small(patch_size=patch_size) 129 | elif arch == 'vit_base': 130 | self.backbone = vit_base(patch_size=patch_size) 131 | elif arch == 'vit_large': 132 | self.backbone = vit_large(patch_size=patch_size) 133 | else: 134 | raise ValueError(f"{self.arch} is not supported") 135 | 136 | # learnable CLS queries and/or positional queries 137 | self.clsQueries = nn.Embedding(num_queries, embed_dim) 138 | 139 | # simple Transformer Decoder with 
num_decoder_layers 140 | self.decoder_cross_attention_layers = nn.ModuleList() 141 | self.decoder_self_attention_layers = nn.ModuleList() 142 | self.decoder_ffn_layers = nn.ModuleList() 143 | for _ in range(self.num_decode_layers): 144 | self.decoder_cross_attention_layers.append( 145 | CrossAttentionLayer(d_model=embed_dim, nhead=num_heads) 146 | ) 147 | self.decoder_self_attention_layers.append( 148 | SelfAttentionLayer(d_model=embed_dim, nhead=num_heads) 149 | ) 150 | self.decoder_ffn_layers.append( 151 | FFNLayer(d_model=embed_dim, dim_feedforward=hidden_dim) 152 | ) 153 | 154 | self.apply(self._init_weights) 155 | 156 | def _init_weights(self, m): 157 | if isinstance(m, nn.Linear): 158 | trunc_normal_(m.weight, std=.02) 159 | if m.bias is not None: 160 | nn.init.constant_(m.bias, 0) 161 | elif isinstance(m, nn.LayerNorm): 162 | nn.init.constant_(m.bias, 0) 163 | nn.init.constant_(m.weight, 1.0) 164 | 165 | def set_clsQuery(self, prototypes): 166 | # initialize clsQueries with generated prototypes of [num_queries, embed_dim] 167 | self.clsQueries.weight = nn.Parameter(prototypes) 168 | 169 | def forward(self, inputs, threshold=0.6): 170 | # inputs is a list of crop images 171 | B = inputs[0].size(0) 172 | 173 | # repeat query for batch use, (B, num_queries, embed_dim) 174 | outQueries = self.clsQueries.weight.unsqueeze(0).repeat(B, 1, 1) 175 | posQueries = pos = None 176 | 177 | # Extract feature 178 | outputs = self.backbone(inputs, self.nmb_crops, self.last_self_attention) 179 | if self.last_self_attention: 180 | outputs, attentions = outputs # outputs=[B*N(196+36), embed_dim], attentions(only global)=[B, heads, 196] 181 | 182 | # calculate gc and lc resolutions. Split output in gc and lc embeddings 183 | gc_res_w = inputs[0].size(2) / self.patch_size 184 | gc_res_h = inputs[0].size(3) / self.patch_size 185 | assert gc_res_w.is_integer() and gc_res_w.is_integer(), "Image dims need to be divisible by patch size" 186 | assert gc_res_w == gc_res_h, f"Only supporting square images not {inputs[0].size(2)}x{inputs[0].size(3)}" 187 | gc_spatial_res = int(gc_res_w) 188 | lc_res_w = inputs[-1].size(2) / self.patch_size 189 | assert lc_res_w.is_integer(), "Image dims need to be divisible by patch size" 190 | lc_spatial_res = int(lc_res_w) 191 | gc_spatial_output, lc_spatial_output = outputs[:B * self.nmb_crops[0] * gc_spatial_res ** 2], \ 192 | outputs[B * self.nmb_crops[0] * gc_spatial_res ** 2:] 193 | # (B*N, C) -> (B, N, C) 194 | gc_spatial_output = gc_spatial_output.reshape(B, -1, self.embed_dim) 195 | if self.nmb_crops[-1] != 0: 196 | lc_spatial_output = lc_spatial_output.reshape(B, self.nmb_crops[-1], lc_spatial_res**2, self.embed_dim) 197 | 198 | # merge attention heads and threshold attentions 199 | attn_hard = None 200 | if self.last_self_attention: 201 | attn_smooth = sum(attentions[:, i] * 1 / attentions.size(1) for i in range(attentions.size(1))) 202 | attn_smooth = attn_smooth.reshape(B * sum(self.nmb_crops), 1, gc_spatial_res, gc_spatial_res) 203 | # attn_hard is later served as 'foreground' hint, use attn_hard.bool() 204 | attn_hard = process_attentions(attn_smooth, gc_spatial_res, threshold=threshold, blur_sigma=0.6) 205 | attn_hard = attn_hard.squeeze(1) 206 | 207 | # Align Queries to each image crop's features with decoder, assuming only 1 global crop 208 | all_queries = [] 209 | for i in range(sum(self.nmb_crops)): 210 | if i == 0: 211 | features = gc_spatial_output 212 | else: 213 | features = lc_spatial_output[:, i-1] 214 | for j in range(self.num_decode_layers): 215 | # 
214 | for j in range(self.num_decode_layers):
215 | # cross-attention first
216 | queries = self.decoder_cross_attention_layers[j](
217 | outQueries, features, pos=pos, query_pos=posQueries)
218 | # self-attention
219 | queries = self.decoder_self_attention_layers[j](
220 | queries, query_pos=posQueries)
221 | # FFN
222 | queries = self.decoder_ffn_layers[j](queries)
223 |
224 | all_queries.append(queries)
225 |
226 | return all_queries, gc_spatial_output, lc_spatial_output, attn_hard, gc_spatial_res, lc_spatial_res
227 |
228 |
229 | if __name__ == '__main__':
230 | model = AlignSegmentor()
--------------------------------------------------------------------------------
/model/transforms.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import torch
4 | import torchvision
5 |
6 | from PIL import ImageFilter, Image
7 | from typing import List, Tuple, Dict
8 | from torch import Tensor
9 | from torchvision.transforms import functional as F
10 |
11 |
12 | class GaussianBlur:
13 | """
14 | Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709 following
15 | https://github.com/facebookresearch/swav/blob/5e073db0cc69dea22aa75e92bfdd75011e888f28/src/multicropdataset.py#L64
16 | """
17 |
18 | def __init__(self, sigma=[.1, 2.]):
19 | self.sigma = sigma
20 |
21 | def __call__(self, x: Image):
22 | sigma = random.uniform(self.sigma[0], self.sigma[1])
23 | x = x.filter(ImageFilter.GaussianBlur(radius=sigma))
24 | return x
25 |
26 |
27 | class TrainTransforms:
28 |
29 | def __init__(self,
30 | size_crops: List[int],  # [224, 224]
31 | nmb_crops: List[int],  # [2, 4]
32 | min_scale_crops: List[float],  # [0.5, 0.05]
33 | max_scale_crops: List[float],  # [1., 0.25]
34 | augment_image: bool = False,
35 | jitter_strength: float = 1.0,
36 | min_intersection: float = 0.05,
37 | blur_strength: float = 1.0):
38 | """
39 | Main transform used for aligning the target image. Implements multi-crop and calculates the corresponding
40 | crop bounding boxes for each crop pair.
41 | :param size_crops: sizes of the global and local crops
42 | :param nmb_crops: number of global and local crops
43 | :param min_scale_crops: the lower bound for the random area of the global and local crops before resizing
44 | :param max_scale_crops: the upper bound for the random area of the global and local crops before resizing
45 | :param augment_image: whether to perform image augmentation
46 | :param jitter_strength: the strength of jittering for brightness, contrast, saturation and hue
47 | :param min_intersection: minimum fraction of the image area that two crops sampled from the
48 | same picture must share. This makes sure that we can always calculate a loss for each pair of
49 | global and local crops.
50 | :param blur_strength: the maximum standard deviation of the Gaussian kernel 51 | """ 52 | assert len(size_crops) == len(nmb_crops) 53 | assert len(min_scale_crops) == len(nmb_crops) 54 | assert len(max_scale_crops) == len(nmb_crops) 55 | assert 0 < min_intersection < 1 56 | self.size_crops = size_crops 57 | self.nmb_crops = nmb_crops 58 | self.min_scale_crops = min_scale_crops 59 | self.max_scale_crops = max_scale_crops 60 | self.min_intersection = min_intersection 61 | self.augment_image = augment_image 62 | 63 | if self.augment_image: 64 | # Construct color transforms 65 | self.color_jitter = torchvision.transforms.ColorJitter( 66 | 0.8 * jitter_strength, 0.8 * jitter_strength, 0.8 * jitter_strength, 67 | 0.2 * jitter_strength 68 | ) 69 | color_transform = [torchvision.transforms.RandomApply([self.color_jitter], p=0.8), 70 | torchvision.transforms.RandomGrayscale(p=0.2)] 71 | blur = GaussianBlur(sigma=[blur_strength * .1, blur_strength * 2.]) 72 | color_transform.append(torchvision.transforms.RandomApply([blur], p=0.5)) 73 | self.color_transform = torchvision.transforms.Compose(color_transform) 74 | 75 | # Construct final transforms 76 | normalize = torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 77 | self.final_transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(), normalize]) 78 | 79 | # Construct randomly resized crops transforms 80 | self.rrc_transforms = [] 81 | for i in range(len(self.size_crops)): # [224, 96] 82 | random_resized_crop = torchvision.transforms.RandomResizedCrop( 83 | self.size_crops[i], 84 | scale=(self.min_scale_crops[i], self.max_scale_crops[i]), 85 | ) 86 | self.rrc_transforms.extend([random_resized_crop] * self.nmb_crops[i]) 87 | 88 | def __call__(self, sample: torch.Tensor) -> Tuple[List[Tensor], Dict[str, Tensor]]: 89 | multi_crops = [] 90 | crop_bboxes = torch.zeros(len(self.rrc_transforms), 4) 91 | 92 | for i, rrc_transform in enumerate(self.rrc_transforms): 93 | # Get random crop params 94 | y1, x1, h, w = rrc_transform.get_params(sample, rrc_transform.scale, rrc_transform.ratio) 95 | if i > 0: 96 | # Check whether crop has min overlap with existing global crops. If not resample. 
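# Rejection sampling: the candidate box is intersected with every previously sampled
# global crop and only accepted once each intersection area exceeds min_intersection
# (a fraction of the full image area, doubled for global-global pairs); otherwise new
# crop parameters are drawn.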
97 | while True: 98 | # Calculate intersection between sampled crop and all sampled global crops 99 | bbox = torch.Tensor([x1, y1, x1 + w, y1 + h]) 100 | left_top = torch.max(bbox.unsqueeze(0)[:, None, :2], 101 | crop_bboxes[:min(i, self.nmb_crops[0]), :2]) 102 | right_bottom = torch.min(bbox.unsqueeze(0)[:, None, 2:], 103 | crop_bboxes[:min(i, self.nmb_crops[0]), 2:]) 104 | wh = _upcast(right_bottom - left_top).clamp(min=0) 105 | inter = wh[:, :, 0] * wh[:, :, 1] 106 | 107 | # set min intersection to at least 1% of image area 108 | min_intersection = int((sample.size[0] * sample.size[1]) * self.min_intersection) 109 | # Global crops should have twice the min_intersection with each other 110 | if i in list(range(self.nmb_crops[0])): 111 | min_intersection *= 2 112 | if not torch.all(inter > min_intersection): 113 | y1, x1, h, w = rrc_transform.get_params(sample, rrc_transform.scale, rrc_transform.ratio) 114 | else: 115 | break 116 | 117 | # Apply rrc params and store absolute crop bounding box 118 | img = F.resized_crop(sample, y1, x1, h, w, rrc_transform.size, rrc_transform.interpolation) 119 | crop_bboxes[i] = torch.Tensor([x1, y1, x1 + w, y1 + h]) 120 | 121 | if self.augment_image: 122 | # Apply color transforms 123 | img = self.color_transform(img) 124 | 125 | # Apply final transform 126 | img = self.final_transform(img) 127 | multi_crops.append(img) 128 | 129 | # Calculate relative bboxes for each crop pair from absolute bboxes 130 | gc_bboxes, otc_bboxes = self.calculate_bboxes(crop_bboxes) 131 | 132 | return multi_crops, {"gc": gc_bboxes, "all": otc_bboxes} 133 | 134 | def calculate_bboxes(self, crop_bboxes: Tensor): 135 | # 1. Calculate two intersection bboxes for each global crop - other crop pair 136 | gc_bboxes = crop_bboxes[:self.nmb_crops[0]] 137 | left_top = torch.max(gc_bboxes[:, None, :2], crop_bboxes[:, :2]) # [nmb_crops[0], sum(nmb_crops), 2] 138 | right_bottom = torch.min(gc_bboxes[:, None, 2:], crop_bboxes[:, 2:]) # [nmb_crops[0], sum(nmb_crops), 2] 139 | # Testing for non-intersecting crops. This should always be true, just as safeguard. 140 | assert torch.all((right_bottom - left_top) > 0) 141 | 142 | # 2. Scale intersection bbox with crop size 143 | # Extract height and width of all crop bounding boxes. Each row contains (w,h) of a crop. 
[sum(nmb_crops),1,2] 144 | ws_hs = torch.stack((crop_bboxes[:, 2] - crop_bboxes[:, 0], crop_bboxes[:, 3] - crop_bboxes[:, 1])).T[:, None] 145 | 146 | # Stack global crop sizes for each bbox dimension 147 | crops_sizes = torch.repeat_interleave(torch.Tensor([self.size_crops[0]]), self.nmb_crops[0] * 2) \ 148 | .reshape(self.nmb_crops[0], 2) 149 | if len(self.size_crops) == 2: 150 | lc_crops_sizes = torch.repeat_interleave(torch.Tensor([self.size_crops[1]]), self.nmb_crops[1] * 2) \ 151 | .reshape(self.nmb_crops[1], 2) 152 | crops_sizes = torch.cat((crops_sizes, lc_crops_sizes))[:, None] # [sum(nmb_crops), 1, 2] 153 | 154 | # Calculate x1s and y1s of each crop bbox 155 | x1s_y1s = crop_bboxes[:, None, :2] 156 | 157 | # Scale top left and right bottom points by percentage of width and height covered 158 | left_top_scaled_gc = crops_sizes[:self.nmb_crops[0]] \ 159 | * ((left_top - x1s_y1s[:self.nmb_crops[0]]) / ws_hs[:self.nmb_crops[0]]) 160 | right_bottom_scaled_gc = crops_sizes[:self.nmb_crops[0]] \ 161 | * ((right_bottom - x1s_y1s[:self.nmb_crops[0]]) / ws_hs[:self.nmb_crops[0]]) 162 | left_top_otc_points_per_gc = torch.stack([left_top[i] for i in range(self.nmb_crops[0])], dim=1) 163 | right_bottom_otc_points_per_gc = torch.stack([right_bottom[i] for i in range(self.nmb_crops[0])], dim=1) 164 | left_top_scaled_otc = crops_sizes * ((left_top_otc_points_per_gc - x1s_y1s) / ws_hs) 165 | right_bottom_scaled_otc = crops_sizes * ((right_bottom_otc_points_per_gc - x1s_y1s) / ws_hs) 166 | 167 | # 3. Construct bboxes in x1, y1, x2, y2 format from left top and right bottom points 168 | # gc_bboxes = relative bboxes of gc and its intersection with lc, [num_crops[0], sum(nmb_crops), 4] 169 | # otc_bboxes = relative bboxes of lc and its intersection with gc, [sum(nmb_crops), 1, 4] 170 | gc_bboxes = torch.cat((left_top_scaled_gc, right_bottom_scaled_gc), dim=2) 171 | otc_bboxes = torch.cat((left_top_scaled_otc, right_bottom_scaled_otc), dim=2) 172 | 173 | return gc_bboxes, otc_bboxes 174 | 175 | 176 | def _upcast(t: Tensor) -> Tensor: 177 | # Protects from numerical overflows in multiplications by upcasting to the equivalent higher type 178 | if t.is_floating_point(): 179 | return t if t.dtype in (torch.float32, torch.float64) else t.float() 180 | else: 181 | return t if t.dtype in (torch.int32, torch.int64) else t.int() 182 | -------------------------------------------------------------------------------- /data/coco_data.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import torch 4 | 5 | from PIL import Image 6 | from torch.utils.data import DataLoader, Dataset 7 | from torchvision.datasets import VisionDataset 8 | from typing import List, Optional, Callable, Tuple, Any 9 | 10 | 11 | class CocoDataModule: 12 | 13 | def __init__(self, 14 | num_workers: int, 15 | batch_size: int, 16 | data_dir: str, 17 | train_transforms, 18 | val_transforms, 19 | file_list: List[str], 20 | mask_type: str = None, 21 | file_list_val: List[str] = None, 22 | val_target_transforms=None, 23 | shuffle: bool = True, 24 | size_val_set: int = 10): 25 | super().__init__() 26 | self.num_workers = num_workers 27 | self.batch_size = batch_size 28 | self.shuffle = shuffle 29 | self.size_val_set = size_val_set 30 | self.file_list = file_list 31 | self.file_list_val = file_list_val 32 | self.data_dir = data_dir 33 | self.train_transforms = train_transforms 34 | self.val_transforms = val_transforms 35 | self.file_list_val = file_list_val 36 | 
self.val_target_transforms = val_target_transforms 37 | self.mask_type = mask_type 38 | self.coco_train = None 39 | self.coco_val = None 40 | 41 | if self.mask_type is None: 42 | self.coco_train = UnlabelledCoco(self.file_list, 43 | self.train_transforms, 44 | os.path.join(self.data_dir, "images/train2017")) 45 | self.coco_val = UnlabelledCoco(self.file_list[:self.size_val_set * self.batch_size], 46 | self.val_transforms, 47 | os.path.join(self.data_dir, "images/val2017")) 48 | else: 49 | self.coco_train = COCOSegmentation(self.data_dir, 50 | self.file_list, 51 | self.mask_type, 52 | image_set="train", 53 | transforms=self.train_transforms) 54 | self.coco_val = COCOSegmentation(self.data_dir, 55 | self.file_list_val, 56 | self.mask_type, 57 | image_set="val", 58 | transform=self.val_transforms, 59 | target_transform=self.val_target_transforms) 60 | 61 | print(f"Train size {len(self.coco_train)}") 62 | print(f"Val size {len(self.coco_val)}") 63 | 64 | def __len__(self): 65 | return len(self.file_list) 66 | 67 | def train_dataloader(self): 68 | return DataLoader(self.coco_train, batch_size=self.batch_size, 69 | shuffle=self.shuffle, num_workers=self.num_workers, 70 | drop_last=True, pin_memory=True) 71 | 72 | def val_dataloader(self): 73 | return DataLoader(self.coco_val, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers, 74 | drop_last=False, pin_memory=True) 75 | 76 | 77 | class COCOSegmentation(VisionDataset): 78 | 79 | def __init__( 80 | self, 81 | root: str, 82 | file_names: List[str], 83 | mask_type: str, 84 | image_set: str = "train", 85 | transform: Optional[Callable] = None, 86 | target_transform: Optional[Callable] = None, 87 | transforms: Optional[Callable] = None, 88 | ): 89 | super(COCOSegmentation, self).__init__(root, transforms, transform, target_transform) 90 | self.image_set = image_set 91 | self.file_names = file_names 92 | self.mask_type = mask_type 93 | assert self.image_set in ["train", "val"] 94 | assert mask_type in ["all", "stuff", "thing"] 95 | 96 | # Set mask folder depending on mask_type 97 | if mask_type == "all": 98 | seg_folder = "annotations/stuffthingmaps_trainval2017/{}2017/" 99 | json_file = "annotations/stuffthingmaps_trainval2017/stuffthing_2017.json" 100 | elif mask_type == "thing": 101 | seg_folder = "annotations/panoptic_annotations/semantic_segmentation_{}2017/" 102 | json_file = "annotations/panoptic_annotations/panoptic_val2017.json" 103 | elif mask_type == "stuff": 104 | seg_folder = "annotations/stuff_annotations/stuff_{}2017_pixelmaps/" 105 | json_file = "annotations/stuff_annotations/stuff_val2017.json" 106 | else: 107 | raise ValueError(f"No support for image set {self.image_set}") 108 | seg_folder = seg_folder.format(image_set) 109 | 110 | # Load categories to category to id map for merging to coarse categories 111 | with open(os.path.join(root, json_file)) as f: 112 | an_json = json.load(f) 113 | all_cat = an_json['categories'] 114 | if mask_type == "all": 115 | super_cats = set([cat_dict['supercategory'] for cat_dict in all_cat]) 116 | super_cats.remove("other") # remove others from prediction targets as this is not semantic 117 | super_cat_to_id = {super_cat: i for i, super_cat in enumerate(sorted(super_cats))} 118 | super_cat_to_id["other"] = 255 # ignore_index for CE 119 | # Align 'id' labels: PNG_label = GT_label - 1 120 | self.cat_id_map = {(cat_dict['id']-1): super_cat_to_id[cat_dict['supercategory']] for cat_dict in all_cat} 121 | elif mask_type == "thing": 122 | all_thing_cat_sup = set(cat_dict["supercategory"] 
for cat_dict in all_cat if cat_dict["isthing"] == 1) 123 | super_cat_to_id = {super_cat: i for i, super_cat in enumerate(sorted(all_thing_cat_sup))} 124 | self.cat_id_map = {} 125 | for cat_dict in all_cat: 126 | if cat_dict["isthing"] == 1: 127 | self.cat_id_map[cat_dict["id"]] = super_cat_to_id[cat_dict["supercategory"]] 128 | elif cat_dict["isthing"] == 0: 129 | self.cat_id_map[cat_dict["id"]] = 255 130 | elif mask_type == "stuff": 131 | super_cats = set([cat_dict['supercategory'] for cat_dict in all_cat]) 132 | super_cats.remove("other") # remove others from prediction targets as this is not semantic 133 | super_cat_to_id = {super_cat: i for i, super_cat in enumerate(sorted(super_cats))} 134 | super_cat_to_id["other"] = 255 # ignore_index for CE 135 | self.cat_id_map = {cat_dict['id']: super_cat_to_id[cat_dict['supercategory']] for cat_dict in all_cat} 136 | 137 | # Get images and masks fnames 138 | seg_dir = os.path.join(root, seg_folder) 139 | image_dir = os.path.join(root, "images", f"{image_set}2017") 140 | if not os.path.isdir(seg_dir) or not os.path.isdir(image_dir): 141 | print(seg_dir) 142 | print(image_dir) 143 | raise RuntimeError('Dataset not found or corrupted.') 144 | self.images = [os.path.join(image_dir, x) for x in self.file_names] 145 | self.masks = [os.path.join(seg_dir, x.replace("jpg", "png")) for x in self.file_names] 146 | 147 | def __len__(self): 148 | return len(self.file_names) 149 | 150 | def __getitem__(self, index: int) -> Tuple[Any, Any]: 151 | img = Image.open(self.images[index]).convert('RGB') 152 | mask = Image.open(self.masks[index]) 153 | 154 | if self.transforms: 155 | img, mask = self.transforms(img, mask) 156 | 157 | if self.mask_type == "all": 158 | # move 'id' labels from [0, 182] to [0,26], and 255=={182,255} 159 | # (183 is 'other' and 0 is things) 160 | mask *= 255 161 | assert torch.min(mask).item() >= 0 162 | mask[mask == 255] = 182 163 | assert torch.max(mask).item() <= 182 164 | for cat_id in torch.unique(mask): 165 | mask[mask == cat_id] = self.cat_id_map[cat_id.item()] 166 | 167 | assert torch.max(mask).item() <= 255 168 | assert torch.min(mask).item() >= 0 169 | mask /= 255 170 | return img, mask 171 | elif self.mask_type == "stuff": 172 | # move stuff labels from {0} U [92, 183] to [0,15] and [255] with 255 == {0, 183} 173 | # (183 is 'other' and 0 is things) 174 | mask *= 255 175 | assert torch.max(mask).item() <= 183 176 | mask[mask == 0] = 183 # [92, 183] 177 | assert torch.min(mask).item() >= 92 178 | for cat_id in torch.unique(mask): 179 | mask[mask == cat_id] = self.cat_id_map[cat_id.item()] 180 | 181 | assert torch.max(mask).item() <= 255 182 | assert torch.min(mask).item() >= 0 183 | mask /= 255 184 | return img, mask 185 | elif self.mask_type == "thing": 186 | mask *= 255 187 | assert torch.max(mask).item() <= 200 188 | mask[mask == 0] = 200 # map unlabelled to stuff 189 | merged_mask = mask.clone() 190 | for cat_id in torch.unique(mask): 191 | merged_mask[mask == cat_id] = self.cat_id_map[int(cat_id.item())] # [0, 11] + {255} 192 | 193 | assert torch.max(merged_mask).item() <= 255 194 | assert torch.min(merged_mask).item() >= 0 195 | merged_mask /= 255 196 | return img, merged_mask 197 | return img, mask 198 | 199 | 200 | class UnlabelledCoco(Dataset): 201 | 202 | def __init__(self, file_list, transforms, data_dir): 203 | self.file_names = file_list 204 | self.transform = transforms 205 | self.data_dir = data_dir 206 | 207 | def __len__(self): 208 | return len(self.file_names) 209 | 210 | def __getitem__(self, idx): 211 | 
img_path = self.file_names[idx] 212 | image = Image.open(os.path.join(self.data_dir, img_path)).convert('RGB') 213 | if self.transform: 214 | image = self.transform(image) 215 | return image 216 | -------------------------------------------------------------------------------- /evaluate/visualize_segment.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import argparse 4 | import cv2 5 | import random 6 | import colorsys 7 | 8 | import skimage.io 9 | from skimage.measure import find_contours 10 | from matplotlib.patches import Polygon 11 | import torch 12 | import torch.nn as nn 13 | import torchvision 14 | from torchvision import transforms as pth_transforms 15 | from torchvision.transforms import GaussianBlur 16 | import torch.nn.functional as F 17 | import numpy as np 18 | from PIL import Image 19 | from skimage.measure import label 20 | from matplotlib import pyplot as plt 21 | 22 | from model.align_model import AlignSegmentor 23 | from utils import neq_load_external 24 | 25 | 26 | def norm(t): 27 | return F.normalize(t, dim=-1, eps=1e-10) 28 | 29 | 30 | def apply_mask(image, mask, color, alpha=0.5): 31 | for c in range(3): 32 | image[:, :, c] = image[:, :, c] * (1 - alpha * mask) + alpha * mask * color[c] * 255 33 | return image 34 | 35 | 36 | def random_colors(N, bright=True): 37 | """ 38 | Generate random colors. 39 | """ 40 | brightness = 1.0 if bright else 0.7 41 | hsv = [(i / N, 1, brightness) for i in range(N)] 42 | colors = list(map(lambda c: colorsys.hsv_to_rgb(*c), hsv)) 43 | random.shuffle(colors) 44 | return colors 45 | 46 | 47 | def display_instances(image, mask, fname="test", figsize=(5, 5), blur=False, contour=True, alpha=0.5): 48 | fig = plt.figure(figsize=figsize, frameon=False) 49 | ax = plt.Axes(fig, [0., 0., 1., 1.]) 50 | ax.set_axis_off() 51 | fig.add_axes(ax) 52 | ax = plt.gca() 53 | 54 | N = 1 55 | mask = mask[None, :, :] 56 | # Generate random colors 57 | colors = random_colors(N) 58 | 59 | # Show area outside image boundaries. 60 | height, width = image.shape[:2] 61 | margin = 0 62 | ax.set_ylim(height + margin, -margin) 63 | ax.set_xlim(-margin, width + margin) 64 | ax.axis('off') 65 | masked_image = image.astype(np.uint32).copy() 66 | 67 | for i in range(N): 68 | color = colors[i] 69 | _mask = mask[i] 70 | if blur: 71 | _mask = cv2.blur(_mask, (10, 10)) 72 | # Mask 73 | masked_image = apply_mask(masked_image, _mask, color, alpha) 74 | # Mask Polygon 75 | # Pad to ensure proper polygons for masks that touch image edges. 
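# find_contours traces the 0.5 iso-level of the padded mask; the one-pixel zero border
# guarantees closed contours for regions that touch the image boundary. Vertices are
# returned as (row, col), hence the fliplr below to get (x, y) for matplotlib's Polygon.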
76 | if contour: 77 | padded_mask = np.zeros((_mask.shape[0] + 2, _mask.shape[1] + 2)) 78 | padded_mask[1:-1, 1:-1] = _mask 79 | contours = find_contours(padded_mask, 0.5) 80 | for verts in contours: 81 | # Subtract the padding and flip (y, x) to (x, y) 82 | verts = np.fliplr(verts) - 1 83 | p = Polygon(verts, facecolor="none", edgecolor=color) 84 | ax.add_patch(p) 85 | ax.imshow(masked_image.astype(np.uint8), aspect='auto') 86 | fig.savefig(fname) 87 | print(f"{fname} saved.") 88 | 89 | 90 | def display_segments(image, masks, fname="test", figsize=(5, 5), alpha=0.7): 91 | N = 5 92 | # Generate random colors 93 | # colors = random_colors(N) 94 | colors = [(128, 0, 0), (30, 144, 255), (75, 0, 130), (184, 134, 11), (0, 128, 0)] 95 | colors = [(x / 255, y / 255, z / 255) for (x, y, z) in colors] 96 | 97 | fig = plt.figure(figsize=figsize, frameon=False) 98 | ax = plt.Axes(fig, [0., 0., 1., 1.]) 99 | ax.set_axis_off() 100 | fig.add_axes(ax) 101 | ax = plt.gca() 102 | 103 | # Show area outside image boundaries. 104 | height, width = image.shape[:2] 105 | margin = 0 106 | ax.set_ylim(height + margin, -margin) 107 | ax.set_xlim(-margin, width + margin) 108 | ax.axis('off') 109 | 110 | for i in range(N): 111 | color = colors[i] 112 | _mask = masks[i] 113 | # Mask 114 | masked_image = image.astype(np.uint32).copy() 115 | masked_image = apply_mask(masked_image, _mask, color, alpha) 116 | 117 | ax.imshow(masked_image.astype(np.uint8), aspect='auto') 118 | file = os.path.join(fname, "cls" + str(i) + ".png") 119 | fig.savefig(file) 120 | print(f"{file} saved.") 121 | 122 | 123 | def display_allsegments(image, masks, n=5, fname="test", figsize=(5, 5), alpha=0.6): 124 | N = n # num of colors 125 | # colors = [(128,0,0), (184,134,11), (0,128,0), (62,78,94), (0,0,0)] # last two backgrounds 126 | colors = [(128,0,0), (30,144,255), (75,0,130), (184,134,11), (0,128,0)] # for coco 127 | colors = [(x/255, y/255, z/255) for (x, y, z) in colors] 128 | print(colors) 129 | 130 | # Generate random colors 131 | # colors = random_colors(N) 132 | 133 | fig = plt.figure(figsize=figsize, frameon=False) 134 | ax = plt.Axes(fig, [0., 0., 1., 1.]) 135 | ax.set_axis_off() 136 | fig.add_axes(ax) 137 | ax = plt.gca() 138 | 139 | # Show area outside image boundaries. 140 | height, width = image.shape[:2] 141 | margin = 0 142 | ax.set_ylim(height + margin, -margin) 143 | ax.set_xlim(-margin, width + margin) 144 | ax.axis('off') 145 | 146 | masked_image = image.astype(np.uint32).copy() 147 | for i in range(N): 148 | color = colors[i] 149 | _mask = masks[i] 150 | # Mask 151 | masked_image = apply_mask(masked_image, _mask, color, alpha) 152 | 153 | ax.imshow(masked_image.astype(np.uint8), aspect='auto') 154 | file = os.path.join(fname, "cls" + str(i) + ".png") 155 | fig.savefig(file) 156 | print(f"{file} saved.") 157 | 158 | 159 | def process_attentions(attentions: torch.Tensor, spatial_res: int, threshold: float = 0.6, blur_sigma: float = 0.6) \ 160 | -> torch.Tensor: 161 | """ 162 | Process [0,1] attentions to binary 0-1 mask. Applies a Guassian filter, keeps threshold % of mass and removes 163 | components smaller than 3 pixels. 164 | The code is adapted from https://github.com/facebookresearch/dino/blob/main/visualize_attention.py but removes the 165 | need for using ground-truth data to find the best performing head. Instead we simply average all head's attentions 166 | so that we can use the foreground mask during training time. 
167 | :param attentions: torch 4D-Tensor containing the averaged attentions 168 | :param spatial_res: spatial resolution of the attention map 169 | :param threshold: the percentage of mass to keep as foreground. 170 | :param blur_sigma: standard deviation to be used for creating kernel to perform blurring. 171 | :return: the foreground mask obtained from the ViT's attention. 172 | """ 173 | # Blur attentions 174 | attentions = GaussianBlur(7, sigma=(blur_sigma))(attentions) 175 | attentions = attentions.reshape(attentions.size(0), 1, spatial_res ** 2) 176 | # Keep threshold% of mass 177 | val, idx = torch.sort(attentions) 178 | val /= torch.sum(val, dim=-1, keepdim=True) 179 | cumval = torch.cumsum(val, dim=-1) 180 | th_attn = cumval > (1 - threshold) 181 | idx2 = torch.argsort(idx) 182 | th_attn[:, 0] = torch.gather(th_attn[:, 0], dim=1, index=idx2[:, 0]) 183 | th_attn = th_attn.reshape(attentions.size(0), 1, spatial_res, spatial_res).float() 184 | # Remove components with less than 3 pixels 185 | for j, th_att in enumerate(th_attn): 186 | labelled = label(th_att.cpu().numpy()) 187 | for k in range(1, np.max(labelled) + 1): 188 | mask = labelled == k 189 | if np.sum(mask) <= 2: 190 | th_attn[j, 0][mask] = 0 191 | return th_attn 192 | 193 | 194 | if __name__ == '__main__': 195 | parser = argparse.ArgumentParser('Evaluate segmentation on pretrained model') 196 | parser.add_argument('--pretrained_weights', default='./epoch10.pth', 197 | type=str, help="Path to pretrained weights to load.") 198 | parser.add_argument("--image_path", default='', type=str, help="Path of the image to load.") 199 | parser.add_argument('--patch_size', default=16, type=int, help='Patch resolution of the model.') 200 | parser.add_argument("--image_size", default=(480, 480), type=int, nargs="+", help="Resize image.") 201 | parser.add_argument('--output_dir', default='./outputs/', help='Path where to save visualizations.') 202 | parser.add_argument("--threshold", type=float, default=0.6, help="""We visualize masks 203 | obtained by thresholding the self-attention maps to keep xx% of the mass.""") 204 | args = parser.parse_args() 205 | 206 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 207 | # build model 208 | # ''' 209 | model = AlignSegmentor(arch='vit_small', 210 | patch_size=16, 211 | embed_dim=384, 212 | hidden_dim=768, 213 | num_heads=3, 214 | num_queries=5, 215 | nmb_crops=[1, 0], 216 | num_decode_layers=1, 217 | last_self_attention=True) 218 | 219 | # set model to eval mode 220 | for p in model.parameters(): 221 | p.requires_grad = False 222 | model.eval() 223 | model.to(device) 224 | 225 | # load pretrained weights 226 | if os.path.isfile(args.pretrained_weights): 227 | pratrained_model = torch.load(args.pretrained_weights, map_location="cpu") 228 | msg = model.load_state_dict(pratrained_model['state_dict'], strict=False) 229 | print(msg) 230 | else: 231 | print('no pretrained pth found!') 232 | 233 | if os.path.isfile(args.image_path): 234 | with open(args.image_path, 'rb') as f: 235 | img = Image.open(f) 236 | img = img.convert('RGB') 237 | else: 238 | print(f"Provided image path {args.image_path} is non valid.") 239 | sys.exit(1) 240 | transform = pth_transforms.Compose([ 241 | pth_transforms.Resize(args.image_size), 242 | pth_transforms.ToTensor(), 243 | pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 244 | ]) 245 | img = transform(img) 246 | 247 | # make the image divisible by the patch size 248 | # img = c, w, h (3, 480, 480); unsqueeze -> 
(1, 3, 480, 480) 249 | w, h = img.shape[1] - img.shape[1] % args.patch_size, img.shape[2] - img.shape[2] % args.patch_size 250 | img = img[:, :w, :h].unsqueeze(0) 251 | 252 | w_spatial_res = img.shape[-2] // args.patch_size 253 | 254 | # get aligned_queries, spatial_token_output and attention_map 255 | all_queries, gc_output, _, attn_hard, _, _ = model([img.to(device)], threshold=args.threshold) 256 | 257 | os.makedirs(args.output_dir, exist_ok=True) 258 | torchvision.utils.save_image(torchvision.utils.make_grid(img, normalize=True, scale_each=True), 259 | os.path.join(args.output_dir, "img.png")) 260 | 261 | # interpolate binary mask 262 | attn_hard = nn.functional.interpolate(attn_hard.unsqueeze(1), scale_factor=args.patch_size, mode="nearest")[0].cpu().numpy() 263 | image = skimage.io.imread(os.path.join(args.output_dir, "img.png")) 264 | display_instances(image, attn_hard[0], fname=os.path.join(args.output_dir, "mask_" + str(args.threshold) + ".png"), blur=False) 265 | 266 | # calculate query assignment score 267 | gc_token_sim = torch.einsum("bnc,bqc->bnq", norm(gc_output), norm(all_queries[0])) 268 | gc_token_cls = torch.softmax(gc_token_sim, dim=-1) 269 | gc_token_cls = gc_token_cls.reshape(1, w_spatial_res, w_spatial_res, -1).permute(0, 3, 1, 2) 270 | 271 | # Smooth interpolation 272 | masks_prob = F.interpolate(gc_token_cls, size=w, mode='bilinear') 273 | masks_oh = masks_prob.argmax(dim=1) 274 | masks_oh = torch.nn.functional.one_hot(masks_oh, masks_prob.shape[1]) 275 | masks_oh = masks_oh.squeeze(dim=0).permute(2, 0, 1) 276 | 277 | masks = [] 278 | for i in range(masks_prob.shape[1]): 279 | mask = masks_oh[i].cpu().numpy() 280 | # print('mask = ', mask.shape, mask) 281 | masks.append(mask) 282 | display_allsegments(image, masks, n=5, fname=args.output_dir) 283 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import time 4 | 5 | import torch 6 | import yaml 7 | 8 | from torchvision.transforms import ToTensor, Compose, Resize, Normalize 9 | from torchvision.transforms.functional import InterpolationMode 10 | 11 | from data.coco_data import CocoDataModule 12 | from data.voc_data import VOCDataModule 13 | from model.align_model import AlignSegmentor 14 | from model.criterion import AlignCriterion 15 | from model.transforms import TrainTransforms 16 | from utils import AverageMeter, save_checkpoint, neq_load_external 17 | 18 | 19 | def set_path(config): 20 | if config['train']['checkpoint']: 21 | model_path = os.path.dirname(config['train']['checkpoint']) 22 | else: 23 | model_path = './log_tmp/{0}-{1}-bs{2}/model'.format(config["data"]["dataset_name"], 24 | config["train"]["arch"], 25 | config["train"]["batch_size"]) 26 | 27 | if not os.path.exists(model_path): os.makedirs(model_path) 28 | return model_path 29 | 30 | 31 | def exclude_from_wt_decay(named_params, weight_decay: float, lr: float): 32 | params = [] 33 | excluded_params = [] 34 | query_param = [] 35 | 36 | for name, param in named_params: 37 | if not param.requires_grad: 38 | continue 39 | # do not regularize biases nor Norm parameters 40 | if name.endswith(".bias") or len(param.shape) == 1: 41 | excluded_params.append(param) 42 | elif 'clsQueries' in name: 43 | query_param.append(param) 44 | else: 45 | params.append(param) 46 | return [{'params': params, 'weight_decay': weight_decay, 'lr': lr}, 47 | {'params': excluded_params, 'weight_decay': 0., 'lr': lr}, 48 
| {'params': query_param, 'weight_decay': 0., 'lr': lr * 1}] 49 | 50 | 51 | def configure_optimizers(model, train_config): 52 | # Separate Decoder params from ViT params 53 | # only train Decoder 54 | decoder_params_named = [] 55 | for name, param in model.named_parameters(): 56 | if name.startswith("backbone"): 57 | param.requires_grad = False 58 | elif train_config['fix_prototypes'] and 'clsQueries' in name: 59 | param.requires_grad = False 60 | else: 61 | decoder_params_named.append((name, param)) 62 | 63 | # Prepare param groups. Exclude norm and bias from weight decay if flag set. 64 | if train_config['exclude_norm_bias']: 65 | params = exclude_from_wt_decay(decoder_params_named, 66 | weight_decay=train_config["weight_decay"], 67 | lr=train_config['lr_decoder']) 68 | else: 69 | decoder_params = [param for _, param in decoder_params_named] 70 | params = [{'params': decoder_params, 'lr': train_config['lr_decoder']}] 71 | 72 | # Init optimizer and lr schedule 73 | optimizer = torch.optim.AdamW(params, weight_decay=train_config["weight_decay"]) 74 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.5) 75 | 76 | return optimizer, scheduler 77 | 78 | 79 | def start_train(): 80 | with open(args.config_path) as file: 81 | config = yaml.safe_load(file.read()) 82 | # print('Config: ', config) 83 | 84 | data_config = config['data'] 85 | train_config = config['train'] 86 | torch.manual_seed(train_config['seed']) 87 | torch.cuda.manual_seed_all(train_config['seed']) 88 | 89 | # Init data modules and tranforms 90 | dataset_name = data_config["dataset_name"] 91 | train_transforms = TrainTransforms(size_crops=data_config["size_crops"], 92 | nmb_crops=data_config["nmb_crops"], 93 | min_intersection=data_config["min_intersection_crops"], 94 | min_scale_crops=data_config["min_scale_crops"], 95 | max_scale_crops=data_config["max_scale_crops"], 96 | augment_image=data_config["augment_image"]) 97 | 98 | # Setup voc dataset used for evaluation 99 | val_size = data_config["size_crops_val"] 100 | val_image_transforms = Compose([Resize((val_size, val_size)), 101 | ToTensor(), 102 | Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) 103 | val_target_transforms = Compose([Resize((val_size, val_size), interpolation=InterpolationMode.NEAREST), 104 | ToTensor()]) 105 | 106 | # Setup train data 107 | if dataset_name == "voc": 108 | train_data_module = VOCDataModule(batch_size=train_config["batch_size"], 109 | num_workers=config["num_workers"], 110 | train_split="trainaug", 111 | val_split="val", 112 | data_dir=data_config["voc_data_path"], 113 | train_image_transform=train_transforms, 114 | val_image_transform=val_image_transforms, 115 | val_target_transform=val_target_transforms) 116 | elif dataset_name == 'coco': 117 | file_list = os.listdir(os.path.join(data_config["data_dir"], "images/train2017")) 118 | train_data_module = CocoDataModule(batch_size=train_config["batch_size"], 119 | num_workers=config["num_workers"], 120 | file_list=file_list, 121 | data_dir=data_config["data_dir"], 122 | train_transforms=train_transforms, 123 | val_transforms=None) 124 | elif 'movi' in dataset_name: 125 | train_data_module = MOViDataModule(data_dir=data_config["data_dir"], 126 | dataset_name=data_config['dataset_name'], 127 | batch_size=train_config["batch_size"], 128 | num_workers=config["num_workers"], 129 | train_split="frames", 130 | val_split="images", 131 | train_image_transform=train_transforms, 132 | val_image_transform=val_image_transforms, 133 | 
val_target_transform=val_target_transforms) 134 | else: 135 | raise ValueError(f"Data set {dataset_name} not supported") 136 | 137 | model_path = set_path(config) 138 | 139 | model = AlignSegmentor(arch=train_config['arch'], 140 | patch_size=train_config['patch_size'], 141 | embed_dim=train_config['embed_dim'], 142 | hidden_dim=train_config['hidden_dim'], 143 | num_heads=train_config['decoder_num_heads'], 144 | num_queries=train_config['num_queries'], 145 | nmb_crops=data_config["nmb_crops"], 146 | num_decode_layers=train_config['num_decode_layers'], 147 | last_self_attention=train_config['last_self_attention']) 148 | model = model.to(cuda) 149 | 150 | criterion = AlignCriterion(patch_size=train_config['patch_size'], 151 | num_queries=train_config['num_queries'], 152 | nmb_crops=data_config["nmb_crops"], 153 | roi_align_kernel_size=train_config['roi_align_kernel_size'], 154 | ce_temperature=train_config['ce_temperature'], 155 | negative_pressure=train_config['negative_pressure'], 156 | last_self_attention=train_config['last_self_attention']) 157 | criterion = criterion.to(cuda) 158 | 159 | # Initialize model 160 | start_epoch = 0 161 | if train_config["checkpoint"] is not None: 162 | checkpoint = torch.load(train_config["checkpoint"]) 163 | start_epoch = checkpoint['epoch'] 164 | msg = model.load_state_dict(checkpoint["state_dict"], strict=True) 165 | print(msg) 166 | elif train_config["checkpoint"] is None \ 167 | and train_config["pretrained_model"] is not None \ 168 | and train_config["prototype_queries"] is not None: 169 | # initialize model with pre-trained ViT and prepared Prototypes 170 | pretrained_model = torch.load(train_config["pretrained_model"], map_location=torch.device('cpu')) 171 | neq_load_external(model, pretrained_model) 172 | protos = torch.load(train_config["prototype_queries"]).to(cuda) 173 | model.set_clsQuery(protos) 174 | elif train_config["checkpoint"] is None \ 175 | and train_config["pretrained_model"] is not None \ 176 | and train_config["prototype_queries"] is None: 177 | # only load pre-trained ViT 178 | pretrained_model = torch.load(train_config["pretrained_model"], map_location=torch.device('cpu')) 179 | neq_load_external(model, pretrained_model) 180 | 181 | # Optionally fix ViT, Queries 182 | optimizer, scheduler = configure_optimizers(model, train_config) 183 | dataloader = train_data_module.train_dataloader() 184 | 185 | for epoch in range(start_epoch, train_config['max_epochs']): 186 | 187 | train(dataloader, model, optimizer, criterion, epoch) 188 | 189 | scheduler.step() 190 | print('\t Epoch: ', epoch, 'with lr: ', scheduler.get_last_lr()) 191 | 192 | if epoch % train_config['save_checkpoint_every_n_epochs'] == 0: 193 | # save check_point 194 | save_checkpoint({'epoch': epoch + 1, 195 | 'net': train_config['arch'], 196 | 'state_dict': model.state_dict(), 197 | 'optimizer': optimizer.state_dict(), 198 | }, gap=train_config['save_checkpoint_every_n_epochs'], 199 | filename=os.path.join(model_path, 'epoch%s.pth' % str(epoch + 1)), keep_all=False) 200 | 201 | print('Training %d epochs finished' % (train_config['max_epochs'])) 202 | 203 | 204 | def train(data_loader, model, optimizer, criterion, epoch): 205 | losses = AverageMeter() 206 | model.train() 207 | 208 | for idx, batch in enumerate(data_loader): 209 | inputs, bboxes = batch # inputs = [sum(num_crops), (B, 3, w, h)] 210 | B = inputs[0].size(0) 211 | tic = time.time() 212 | for i in range(len(inputs)): 213 | inputs[i] = inputs[i].to(cuda, non_blocking=True) 214 | bboxes['gc'] = 
bboxes['gc'].to(cuda, non_blocking=True) 215 | bboxes['all'] = bboxes['all'].to(cuda, non_blocking=True) 216 | 217 | results = model(inputs) 218 | 219 | # Calculate loss 220 | loss = criterion(results, bboxes) 221 | losses.update(loss.item(), B, step=len(data_loader)) 222 | 223 | optimizer.zero_grad() 224 | loss.backward() 225 | optimizer.step() 226 | 227 | if idx % 1 == 0: 228 | print('Epoch: [{0}][{1}/{2}]\t' 229 | 'Loss {loss.val:.4f} ({loss.local_avg:.4f}) Time:{3:.2f}\t'. 230 | format(epoch, idx, len(data_loader), time.time() - tic, loss=losses)) 231 | 232 | return losses.local_avg 233 | 234 | 235 | if __name__ == '__main__': 236 | parser = argparse.ArgumentParser() 237 | parser.add_argument('--config_path', default='./configs/train_voc_config.yml', type=str) 238 | parser.add_argument('--gpu', default='0', type=str) 239 | 240 | args = parser.parse_args() 241 | 242 | os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu) 243 | cuda = torch.device('cuda') 244 | start_train() 245 | -------------------------------------------------------------------------------- /data/stuffthing_2017.json: -------------------------------------------------------------------------------- 1 | { 2 | "info": {"description": "COCO 2017 Stuff Dataset", "url": "http://cocodataset.org", "version": "1.0", "year": 2017, "contributor": "H. Caesar, J. Uijlings, M. Maire, T.-Y. Lin, P. Dollar and V. Ferrari", "date_created": "2017-08-31 00:00:00.0"}, 3 | "categories": [{"supercategory": "person", "isthing": 1, "id": 1, "name": "person"}, {"supercategory": "vehicle", "isthing": 1, "id": 2, "name": "bicycle"}, {"supercategory": "vehicle", "isthing": 1, "id": 3, "name": "car"}, {"supercategory": "vehicle", "isthing": 1, "id": 4, "name": "motorcycle"}, {"supercategory": "vehicle", "isthing": 1, "id": 5, "name": "airplane"}, {"supercategory": "vehicle", "isthing": 1, "id": 6, "name": "bus"}, {"supercategory": "vehicle", "isthing": 1, "id": 7, "name": "train"}, {"supercategory": "vehicle", "isthing": 1, "id": 8, "name": "truck"}, {"supercategory": "vehicle", "isthing": 1, "id": 9, "name": "boat"}, {"supercategory": "outdoor", "isthing": 1, "id": 10, "name": "traffic light"}, {"supercategory": "outdoor", "isthing": 1, "id": 11, "name": "fire hydrant"}, {"supercategory": "outdoor", "isthing": 1, "id": 13, "name": "stop sign"}, {"supercategory": "outdoor", "isthing": 1, "id": 14, "name": "parking meter"}, {"supercategory": "outdoor", "isthing": 1, "id": 15, "name": "bench"}, {"supercategory": "animal", "isthing": 1, "id": 16, "name": "bird"}, {"supercategory": "animal", "isthing": 1, "id": 17, "name": "cat"}, {"supercategory": "animal", "isthing": 1, "id": 18, "name": "dog"}, {"supercategory": "animal", "isthing": 1, "id": 19, "name": "horse"}, {"supercategory": "animal", "isthing": 1, "id": 20, "name": "sheep"}, {"supercategory": "animal", "isthing": 1, "id": 21, "name": "cow"}, {"supercategory": "animal", "isthing": 1, "id": 22, "name": "elephant"}, {"supercategory": "animal", "isthing": 1, "id": 23, "name": "bear"}, {"supercategory": "animal", "isthing": 1, "id": 24, "name": "zebra"}, {"supercategory": "animal", "isthing": 1, "id": 25, "name": "giraffe"}, {"supercategory": "accessory", "isthing": 1, "id": 27, "name": "backpack"}, {"supercategory": "accessory", "isthing": 1, "id": 28, "name": "umbrella"}, {"supercategory": "accessory", "isthing": 1, "id": 31, "name": "handbag"}, {"supercategory": "accessory", "isthing": 1, "id": 32, "name": "tie"}, {"supercategory": "accessory", "isthing": 1, "id": 33, "name": "suitcase"}, 
{"supercategory": "sports", "isthing": 1, "id": 34, "name": "frisbee"}, {"supercategory": "sports", "isthing": 1, "id": 35, "name": "skis"}, {"supercategory": "sports", "isthing": 1, "id": 36, "name": "snowboard"}, {"supercategory": "sports", "isthing": 1, "id": 37, "name": "sports ball"}, {"supercategory": "sports", "isthing": 1, "id": 38, "name": "kite"}, {"supercategory": "sports", "isthing": 1, "id": 39, "name": "baseball bat"}, {"supercategory": "sports", "isthing": 1, "id": 40, "name": "baseball glove"}, {"supercategory": "sports", "isthing": 1, "id": 41, "name": "skateboard"}, {"supercategory": "sports", "isthing": 1, "id": 42, "name": "surfboard"}, {"supercategory": "sports", "isthing": 1, "id": 43, "name": "tennis racket"}, {"supercategory": "kitchen", "isthing": 1, "id": 44, "name": "bottle"}, {"supercategory": "kitchen", "isthing": 1, "id": 46, "name": "wine glass"}, {"supercategory": "kitchen", "isthing": 1, "id": 47, "name": "cup"}, {"supercategory": "kitchen", "isthing": 1, "id": 48, "name": "fork"}, {"supercategory": "kitchen", "isthing": 1, "id": 49, "name": "knife"}, {"supercategory": "kitchen", "isthing": 1, "id": 50, "name": "spoon"}, {"supercategory": "kitchen", "isthing": 1, "id": 51, "name": "bowl"}, {"supercategory": "food", "isthing": 1, "id": 52, "name": "banana"}, {"supercategory": "food", "isthing": 1, "id": 53, "name": "apple"}, {"supercategory": "food", "isthing": 1, "id": 54, "name": "sandwich"}, {"supercategory": "food", "isthing": 1, "id": 55, "name": "orange"}, {"supercategory": "food", "isthing": 1, "id": 56, "name": "broccoli"}, {"supercategory": "food", "isthing": 1, "id": 57, "name": "carrot"}, {"supercategory": "food", "isthing": 1, "id": 58, "name": "hot dog"}, {"supercategory": "food", "isthing": 1, "id": 59, "name": "pizza"}, {"supercategory": "food", "isthing": 1, "id": 60, "name": "donut"}, {"supercategory": "food", "isthing": 1, "id": 61, "name": "cake"}, {"supercategory": "furniture", "isthing": 1, "id": 62, "name": "chair"}, {"supercategory": "furniture", "isthing": 1, "id": 63, "name": "couch"}, {"supercategory": "furniture", "isthing": 1, "id": 64, "name": "potted plant"}, {"supercategory": "furniture", "isthing": 1, "id": 65, "name": "bed"}, {"supercategory": "furniture", "isthing": 1, "id": 67, "name": "dining table"}, {"supercategory": "furniture", "isthing": 1, "id": 70, "name": "toilet"}, {"supercategory": "electronic", "isthing": 1, "id": 72, "name": "tv"}, {"supercategory": "electronic", "isthing": 1, "id": 73, "name": "laptop"}, {"supercategory": "electronic", "isthing": 1, "id": 74, "name": "mouse"}, {"supercategory": "electronic", "isthing": 1, "id": 75, "name": "remote"}, {"supercategory": "electronic", "isthing": 1, "id": 76, "name": "keyboard"}, {"supercategory": "electronic", "isthing": 1, "id": 77, "name": "cell phone"}, {"supercategory": "appliance", "isthing": 1, "id": 78, "name": "microwave"}, {"supercategory": "appliance", "isthing": 1, "id": 79, "name": "oven"}, {"supercategory": "appliance", "isthing": 1, "id": 80, "name": "toaster"}, {"supercategory": "appliance", "isthing": 1, "id": 81, "name": "sink"}, {"supercategory": "appliance", "isthing": 1, "id": 82, "name": "refrigerator"}, {"supercategory": "indoor", "isthing": 1, "id": 84, "name": "book"}, {"supercategory": "indoor", "isthing": 1, "id": 85, "name": "clock"}, {"supercategory": "indoor", "isthing": 1, "id": 86, "name": "vase"}, {"supercategory": "indoor", "isthing": 1, "id": 87, "name": "scissors"}, {"supercategory": "indoor", "isthing": 1, "id": 88, "name": 
"teddy bear"}, {"supercategory": "indoor", "isthing": 1, "id": 89, "name": "hair drier"}, {"supercategory": "indoor", "isthing": 1, "id": 90, "name": "toothbrush"}, {"supercategory": "textile", "isthing": 0, "id": 92, "name": "banner"}, {"supercategory": "textile", "isthing": 0, "id": 93, "name": "blanket"}, {"supercategory": "plant", "isthing": 0, "id": 94, "name": "branch"}, {"supercategory": "building", "isthing": 0, "id": 95, "name": "bridge"}, {"supercategory": "building", "isthing": 0, "id": 96, "name": "building-other"}, {"supercategory": "plant", "isthing": 0, "id": 97, "name": "bush"}, {"supercategory": "furniture-stuff", "isthing": 0, "id": 98, "name": "cabinet"}, {"supercategory": "structural", "isthing": 0, "id": 99, "name": "cage"}, {"supercategory": "raw-material", "isthing": 0, "id": 100, "name": "cardboard"}, {"supercategory": "floor", "isthing": 0, "id": 101, "name": "carpet"}, {"supercategory": "ceiling", "isthing": 0, "id": 102, "name": "ceiling-other"}, {"supercategory": "ceiling", "isthing": 0, "id": 103, "name": "ceiling-tile"}, {"supercategory": "textile", "isthing": 0, "id": 104, "name": "cloth"}, {"supercategory": "textile", "isthing": 0, "id": 105, "name": "clothes"}, {"supercategory": "sky", "isthing": 0, "id": 106, "name": "clouds"}, {"supercategory": "furniture-stuff", "isthing": 0, "id": 107, "name": "counter"}, {"supercategory": "furniture-stuff", "isthing": 0, "id": 108, "name": "cupboard"}, {"supercategory": "textile", "isthing": 0, "id": 109, "name": "curtain"}, {"supercategory": "furniture-stuff", "isthing": 0, "id": 110, "name": "desk-stuff"}, {"supercategory": "ground", "isthing": 0, "id": 111, "name": "dirt"}, {"supercategory": "furniture-stuff", "isthing": 0, "id": 112, "name": "door-stuff"}, {"supercategory": "structural", "isthing": 0, "id": 113, "name": "fence"}, {"supercategory": "floor", "isthing": 0, "id": 114, "name": "floor-marble"}, {"supercategory": "floor", "isthing": 0, "id": 115, "name": "floor-other"}, {"supercategory": "floor", "isthing": 0, "id": 116, "name": "floor-stone"}, {"supercategory": "floor", "isthing": 0, "id": 117, "name": "floor-tile"}, {"supercategory": "floor", "isthing": 0, "id": 118, "name": "floor-wood"}, {"supercategory": "plant", "isthing": 0, "id": 119, "name": "flower"}, {"supercategory": "water", "isthing": 0, "id": 120, "name": "fog"}, {"supercategory": "food-stuff", "isthing": 0, "id": 121, "name": "food-other"}, {"supercategory": "food-stuff", "isthing": 0, "id": 122, "name": "fruit"}, {"supercategory": "furniture-stuff", "isthing": 0, "id": 123, "name": "furniture-other"}, {"supercategory": "plant", "isthing": 0, "id": 124, "name": "grass"}, {"supercategory": "ground", "isthing": 0, "id": 125, "name": "gravel"}, {"supercategory": "ground", "isthing": 0, "id": 126, "name": "ground-other"}, {"supercategory": "solid", "isthing": 0, "id": 127, "name": "hill"}, {"supercategory": "building", "isthing": 0, "id": 128, "name": "house"}, {"supercategory": "plant", "isthing": 0, "id": 129, "name": "leaves"}, {"supercategory": "furniture-stuff", "isthing": 0, "id": 130, "name": "light"}, {"supercategory": "textile", "isthing": 0, "id": 131, "name": "mat"}, {"supercategory": "raw-material", "isthing": 0, "id": 132, "name": "metal"}, {"supercategory": "furniture-stuff", "isthing": 0, "id": 133, "name": "mirror-stuff"}, {"supercategory": "plant", "isthing": 0, "id": 134, "name": "moss"}, {"supercategory": "solid", "isthing": 0, "id": 135, "name": "mountain"}, {"supercategory": "ground", "isthing": 0, "id": 136, "name": 
"mud"}, {"supercategory": "textile", "isthing": 0, "id": 137, "name": "napkin"}, {"supercategory": "structural", "isthing": 0, "id": 138, "name": "net"}, {"supercategory": "raw-material", "isthing": 0, "id": 139, "name": "paper"}, {"supercategory": "ground", "isthing": 0, "id": 140, "name": "pavement"}, {"supercategory": "textile", "isthing": 0, "id": 141, "name": "pillow"}, {"supercategory": "plant", "isthing": 0, "id": 142, "name": "plant-other"}, {"supercategory": "raw-material", "isthing": 0, "id": 143, "name": "plastic"}, {"supercategory": "ground", "isthing": 0, "id": 144, "name": "platform"}, {"supercategory": "ground", "isthing": 0, "id": 145, "name": "playingfield"}, {"supercategory": "structural", "isthing": 0, "id": 146, "name": "railing"}, {"supercategory": "ground", "isthing": 0, "id": 147, "name": "railroad"}, {"supercategory": "water", "isthing": 0, "id": 148, "name": "river"}, {"supercategory": "ground", "isthing": 0, "id": 149, "name": "road"}, {"supercategory": "solid", "isthing": 0, "id": 150, "name": "rock"}, {"supercategory": "building", "isthing": 0, "id": 151, "name": "roof"}, {"supercategory": "textile", "isthing": 0, "id": 152, "name": "rug"}, {"supercategory": "food-stuff", "isthing": 0, "id": 153, "name": "salad"}, {"supercategory": "ground", "isthing": 0, "id": 154, "name": "sand"}, {"supercategory": "water", "isthing": 0, "id": 155, "name": "sea"}, {"supercategory": "furniture-stuff", "isthing": 0, "id": 156, "name": "shelf"}, {"supercategory": "sky", "isthing": 0, "id": 157, "name": "sky-other"}, {"supercategory": "building", "isthing": 0, "id": 158, "name": "skyscraper"}, {"supercategory": "ground", "isthing": 0, "id": 159, "name": "snow"}, {"supercategory": "solid", "isthing": 0, "id": 160, "name": "solid-other"}, {"supercategory": "furniture-stuff", "isthing": 0, "id": 161, "name": "stairs"}, {"supercategory": "solid", "isthing": 0, "id": 162, "name": "stone"}, {"supercategory": "plant", "isthing": 0, "id": 163, "name": "straw"}, {"supercategory": "structural", "isthing": 0, "id": 164, "name": "structural-other"}, {"supercategory": "furniture-stuff", "isthing": 0, "id": 165, "name": "table"}, {"supercategory": "building", "isthing": 0, "id": 166, "name": "tent"}, {"supercategory": "textile", "isthing": 0, "id": 167, "name": "textile-other"}, {"supercategory": "textile", "isthing": 0, "id": 168, "name": "towel"}, {"supercategory": "plant", "isthing": 0, "id": 169, "name": "tree"}, {"supercategory": "food-stuff", "isthing": 0, "id": 170, "name": "vegetable"}, {"supercategory": "wall", "isthing": 0, "id": 171, "name": "wall-brick"}, {"supercategory": "wall", "isthing": 0, "id": 172, "name": "wall-concrete"}, {"supercategory": "wall", "isthing": 0, "id": 173, "name": "wall-other"}, {"supercategory": "wall", "isthing": 0, "id": 174, "name": "wall-panel"}, {"supercategory": "wall", "isthing": 0, "id": 175, "name": "wall-stone"}, {"supercategory": "wall", "isthing": 0, "id": 176, "name": "wall-tile"}, {"supercategory": "wall", "isthing": 0, "id": 177, "name": "wall-wood"}, {"supercategory": "water", "isthing": 0, "id": 178, "name": "water-other"}, {"supercategory": "water", "isthing": 0, "id": 179, "name": "waterdrops"}, {"supercategory": "window", "isthing": 0, "id": 180, "name": "window-blind"}, {"supercategory": "window", "isthing": 0, "id": 181, "name": "window-other"}, {"supercategory": "solid", "isthing": 0, "id": 182, "name": "wood"}, {"supercategory": "other", "isthing": 0, "id": 183, "name": "other"}] 4 | } 
-------------------------------------------------------------------------------- /model/vit.py: -------------------------------------------------------------------------------- 1 | import math 2 | import warnings 3 | import torch 4 | import torch.nn as nn 5 | 6 | from functools import partial 7 | 8 | 9 | def _no_grad_trunc_normal_(tensor, mean, std, a, b): 10 | def norm_cdf(x): 11 | # Computes standard normal cumulative distribution function 12 | return (1. + math.erf(x / math.sqrt(2.))) / 2. 13 | 14 | if (mean < a - 2 * std) or (mean > b + 2 * std): 15 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " 16 | "The distribution of values may be incorrect.", 17 | stacklevel=2) 18 | 19 | with torch.no_grad(): 20 | # Values are generated by using a truncated uniform distribution and 21 | # then using the inverse CDF for the normal distribution. 22 | # Get upper and lower cdf values 23 | l = norm_cdf((a - mean) / std) 24 | u = norm_cdf((b - mean) / std) 25 | 26 | # Uniformly fill tensor with values from [l, u], then translate to 27 | # [2l-1, 2u-1]. 28 | tensor.uniform_(2 * l - 1, 2 * u - 1) 29 | 30 | # Use inverse cdf transform for normal distribution to get truncated 31 | # standard normal 32 | tensor.erfinv_() 33 | 34 | # Transform to proper mean, std 35 | tensor.mul_(std * math.sqrt(2.)) 36 | tensor.add_(mean) 37 | 38 | # Clamp to ensure it's in the proper range 39 | tensor.clamp_(min=a, max=b) 40 | return tensor 41 | 42 | 43 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): 44 | # type: (Tensor, float, float, float, float) -> Tensor 45 | return _no_grad_trunc_normal_(tensor, mean, std, a, b) 46 | 47 | 48 | def drop_path(x, drop_prob: float = 0., training: bool = False): 49 | if drop_prob == 0. or not training: 50 | return x 51 | keep_prob = 1 - drop_prob 52 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 53 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) 54 | random_tensor.floor_() # binarize 55 | output = x.div(keep_prob) * random_tensor 56 | return output 57 | 58 | 59 | class DropPath(nn.Module): 60 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
61 | """ 62 | 63 | def __init__(self, drop_prob=None): 64 | super(DropPath, self).__init__() 65 | self.drop_prob = drop_prob 66 | 67 | def forward(self, x): 68 | return drop_path(x, self.drop_prob, self.training) 69 | 70 | 71 | class Mlp(nn.Module): 72 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 73 | super().__init__() 74 | out_features = out_features or in_features 75 | hidden_features = hidden_features or in_features 76 | self.fc1 = nn.Linear(in_features, hidden_features) 77 | self.act = act_layer() 78 | self.fc2 = nn.Linear(hidden_features, out_features) 79 | self.drop = nn.Dropout(drop) 80 | 81 | def forward(self, x): 82 | x = self.fc1(x) 83 | x = self.act(x) 84 | x = self.drop(x) 85 | x = self.fc2(x) 86 | x = self.drop(x) 87 | return x 88 | 89 | 90 | class Attention(nn.Module): 91 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 92 | super().__init__() 93 | self.num_heads = num_heads 94 | head_dim = dim // num_heads 95 | self.scale = qk_scale or head_dim ** -0.5 96 | 97 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 98 | self.attn_drop = nn.Dropout(attn_drop) 99 | self.proj = nn.Linear(dim, dim) 100 | self.proj_drop = nn.Dropout(proj_drop) 101 | 102 | def forward(self, x): 103 | B, N, C = x.shape 104 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 105 | q, k, v = qkv[0], qkv[1], qkv[2] 106 | 107 | attn = (q @ k.transpose(-2, -1)) * self.scale 108 | attn = attn.softmax(dim=-1) 109 | attn = self.attn_drop(attn) 110 | 111 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 112 | x = self.proj(x) 113 | x = self.proj_drop(x) 114 | return x, attn 115 | 116 | 117 | class Block(nn.Module): 118 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 119 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 120 | super().__init__() 121 | self.norm1 = norm_layer(dim) 122 | self.attn = Attention( 123 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 124 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 125 | self.norm2 = norm_layer(dim) 126 | mlp_hidden_dim = int(dim * mlp_ratio) 127 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 128 | 129 | def forward(self, x, return_attention=False): 130 | y, attn = self.attn(self.norm1(x)) 131 | x = x + self.drop_path(y) 132 | x = x + self.drop_path(self.mlp(self.norm2(x))) 133 | if return_attention: 134 | return x, attn 135 | return x 136 | 137 | 138 | class PatchEmbed(nn.Module): 139 | """ Image to Patch Embedding 140 | """ 141 | 142 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): 143 | super().__init__() 144 | num_patches = (img_size // patch_size) * (img_size // patch_size) 145 | self.img_size = img_size 146 | self.patch_size = patch_size 147 | self.num_patches = num_patches 148 | 149 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 150 | 151 | def forward(self, x): 152 | B, C, H, W = x.shape 153 | x = self.proj(x).flatten(2).transpose(1, 2) 154 | return x 155 | 156 | 157 | class VisionTransformer(nn.Module): 158 | """ Vision Transformer """ 159 | 160 | def __init__(self, img_size=[224], patch_size=16, in_chans=3, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., 161 | qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., 162 | norm_layer=nn.LayerNorm): 163 | super().__init__() 164 | self.num_features = self.embed_dim = embed_dim 165 | 166 | self.patch_embed = PatchEmbed( 167 | img_size=img_size[0], patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 168 | num_patches = self.patch_embed.num_patches 169 | 170 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 171 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) 172 | self.pos_drop = nn.Dropout(p=drop_rate) 173 | 174 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 175 | self.blocks = nn.ModuleList([ 176 | Block( 177 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 178 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) 179 | for i in range(depth)]) 180 | self.norm = norm_layer(embed_dim) 181 | 182 | trunc_normal_(self.pos_embed, std=.02) 183 | trunc_normal_(self.cls_token, std=.02) 184 | self.apply(self._init_weights) 185 | 186 | def _init_weights(self, m): 187 | if isinstance(m, nn.Linear): 188 | trunc_normal_(m.weight, std=.02) 189 | if m.bias is not None: 190 | nn.init.constant_(m.bias, 0) 191 | elif isinstance(m, nn.LayerNorm): 192 | nn.init.constant_(m.bias, 0) 193 | nn.init.constant_(m.weight, 1.0) 194 | elif isinstance(m, nn.Conv2d): 195 | trunc_normal_(m.weight, std=.02) 196 | if m.bias is not None: 197 | nn.init.constant_(m.bias, 0) 198 | 199 | def interpolate_pos_encoding(self, x, w, h): 200 | npatch = x.shape[1] - 1 201 | N = self.pos_embed.shape[1] - 1 202 | if npatch == N and w == h: 203 | return self.pos_embed 204 | class_pos_embed = self.pos_embed[:, 0] 205 | patch_pos_embed = self.pos_embed[:, 1:] 206 | dim = x.shape[-1] 207 | w0 = w // self.patch_embed.patch_size 208 | h0 = h // self.patch_embed.patch_size 209 | # we add a small number to avoid floating point error in the interpolation 210 | w0, h0 = w0 + 0.1, h0 + 0.1 211 | patch_pos_embed = nn.functional.interpolate( 212 | patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), 213 | scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), 214 | 
mode='bicubic', 215 | ) 216 | assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1] 217 | patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) 218 | return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) 219 | 220 | def get_intermediate_layers(self, x, n=1): 221 | x = self.prepare_tokens(x) 222 | # we return the output tokens from the `n` last blocks 223 | output = [] 224 | for i, blk in enumerate(self.blocks): 225 | x = blk(x) 226 | if len(self.blocks) - i <= n: 227 | output.append(self.norm(x)) 228 | return output 229 | 230 | def prepare_tokens(self, x): 231 | B, nc, w, h = x.shape 232 | x = self.patch_embed(x) # patch linear embedding 233 | 234 | # add the [CLS] token to the embed patch tokens 235 | cls_tokens = self.cls_token.expand(B, -1, -1) 236 | x = torch.cat((cls_tokens, x), dim=1) 237 | 238 | # add positional encoding to each token 239 | x = x + self.interpolate_pos_encoding(x, w, h) 240 | 241 | return self.pos_drop(x) 242 | 243 | @torch.no_grad() 244 | def forward(self, inputs, nmb_crops=(1,0), last_self_attention=False): 245 | if not isinstance(inputs, list): 246 | inputs = [inputs] 247 | idx_crops = [1, ] # for inference 248 | if sum(nmb_crops) > 1: 249 | # for training 250 | idx_crops.append(sum(nmb_crops)) 251 | 252 | assert len(idx_crops) <= 2, "Only supporting at most two different type of crops (global and local crops)" 253 | start_idx = 0 254 | for end_idx in idx_crops: 255 | _out = torch.cat(inputs[start_idx:end_idx]) 256 | _out = self.forward_backbone(_out, last_self_attention=last_self_attention) 257 | if last_self_attention: 258 | _out, _attn = _out 259 | spatial_tokens = _out[:, 1:] # remove CLS token 260 | spatial_tokens = spatial_tokens.reshape(-1, self.embed_dim) # [B*196/36, embed_dim] 261 | 262 | if start_idx == 0: 263 | output_spatial = spatial_tokens 264 | if last_self_attention: 265 | # only keep 1st global crop attention 266 | attentions = _attn 267 | else: 268 | output_spatial = torch.cat((output_spatial, spatial_tokens)) 269 | if last_self_attention: 270 | attentions = torch.cat((attentions, _attn)) 271 | start_idx = end_idx 272 | 273 | result = output_spatial 274 | if last_self_attention: 275 | result = (result, attentions) 276 | return result 277 | 278 | def forward_backbone(self, x, last_self_attention=False): 279 | x = self.prepare_tokens(x) 280 | for i, blk in enumerate(self.blocks): 281 | if i < len(self.blocks) - 1: 282 | x = blk(x) 283 | else: 284 | x = blk(x, return_attention=last_self_attention) 285 | if last_self_attention: 286 | x, attn = x 287 | x = self.norm(x) 288 | if last_self_attention: 289 | return x, attn[:, :, 0, 1:] # [B, heads, cls, cls-patch] 290 | return x 291 | 292 | def get_last_selfattention(self, x): 293 | x = self.prepare_tokens(x) 294 | for i, blk in enumerate(self.blocks): 295 | if i < len(self.blocks) - 1: 296 | x = blk(x) 297 | else: 298 | # return attention of the last block 299 | return blk(x, return_attention=True)[1] 300 | 301 | def get_cls_tokens(self, x): 302 | x = self.prepare_tokens(x) 303 | for blk in self.blocks: 304 | x = blk(x) 305 | x = self.norm(x) 306 | return x[:, 0] 307 | 308 | 309 | def vit_small(patch_size=16, **kwargs): 310 | model = VisionTransformer( 311 | patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, 312 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 313 | return model 314 | 315 | 316 | def vit_base(patch_size=16, **kwargs): 317 | model = VisionTransformer( 318 | 
patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, 319 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 320 | return model 321 | 322 | 323 | def vit_large(patch_size=16, **kwargs): 324 | model = VisionTransformer( 325 | patch_size=patch_size, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, 326 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 327 | return model 328 | 329 | 330 | if __name__ == '__main__': 331 | clsQueries = nn.Embedding(2, 6) 332 | print(clsQueries.weight) 333 | trunc_normal_(clsQueries.weight, std=.02) 334 | print(clsQueries.weight, clsQueries.weight.mean(), clsQueries.weight.sum()) 335 | 336 | cls = nn.Parameter(torch.zeros(1, 1, 5)) 337 | trunc_normal_(cls, std=.02) 338 | print(cls, cls.mean(), cls.sum()) 339 | -------------------------------------------------------------------------------- /evaluate/eval_utils.py: -------------------------------------------------------------------------------- 1 | """Utility functions used in metrics computation.""" 2 | from typing import Optional 3 | 4 | import scipy.optimize 5 | import torch 6 | import torchmetrics 7 | 8 | 9 | class ARIMetric(torchmetrics.Metric): 10 | """Computes ARI metric.""" 11 | 12 | def __init__( 13 | self, 14 | foreground: bool = True, 15 | convert_target_one_hot: bool = False, 16 | ignore_overlaps: bool = False, 17 | ): 18 | super().__init__() 19 | self.foreground = foreground 20 | self.convert_target_one_hot = convert_target_one_hot 21 | self.ignore_overlaps = ignore_overlaps 22 | self.add_state( 23 | "values", default=torch.tensor(0.0, dtype=torch.float64), dist_reduce_fx="sum" 24 | ) 25 | self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") 26 | 27 | def update( 28 | self, prediction: torch.Tensor, target: torch.Tensor, ignore: Optional[torch.Tensor] = None 29 | ): 30 | """Update this metric. 31 | 32 | Args: 33 | prediction: Predicted mask of shape (B, C, H, W) or (B, F, C, H, W), where C is the 34 | number of classes. 35 | target: Ground truth mask of shape (B, K, H, W) or (B, F, K, H, W), where K is the 36 | number of classes. 37 | ignore: Ignore mask of shape (B, 1, H, W) or (B, 1, K, H, W) 38 | """ 39 | if prediction.ndim == 5: 40 | # Merge frames, height and width to single dimension. 41 | prediction = prediction.transpose(1, 2).flatten(-3, -1) 42 | target = target.transpose(1, 2).flatten(-3, -1) 43 | if ignore is not None: 44 | ignore = ignore.to(torch.bool).transpose(1, 2).flatten(-3, -1) 45 | elif prediction.ndim == 4: 46 | # Merge height and width to single dimension. 47 | prediction = prediction.flatten(-2, -1) 48 | target = target.flatten(-2, -1) 49 | if ignore is not None: 50 | ignore = ignore.to(torch.bool).flatten(-2, -1) 51 | else: 52 | raise ValueError(f"Incorrect input shape: f{prediction.shape}") 53 | 54 | if self.ignore_overlaps: 55 | overlaps = (target > 0).sum(1, keepdim=True) > 1 56 | if ignore is None: 57 | ignore = overlaps 58 | else: 59 | ignore = ignore | overlaps 60 | 61 | if ignore is not None: 62 | assert ignore.ndim == 3 and ignore.shape[1] == 1 63 | prediction = prediction.clone() 64 | prediction[ignore.expand_as(prediction)] = 0 65 | target = target.clone() 66 | target[ignore.expand_as(target)] = 0 67 | 68 | # Make channels / gt labels the last dimension. 
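        # Illustrative shape note (added commentary, example sizes assumed): after the
        # flattening above, `prediction` is (B, C, P) and `target` is (B, K, P) with
        # P = H*W (or F*H*W for video input). The transposes below yield (B, P, C) and
        # (B, P, K), the per-point layout expected by adjusted_rand_index /
        # fg_adjusted_rand_index, e.g. a (2, 5, 196) prediction becomes (2, 196, 5).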
69 | prediction = prediction.transpose(-2, -1) 70 | target = target.transpose(-2, -1) 71 | 72 | if self.convert_target_one_hot: 73 | target_oh = tensor_to_one_hot(target, dim=2) 74 | # For empty pixels (all values zero), one-hot assigns 1 to the first class, correct for 75 | # this (then it is technically not one-hot anymore). 76 | target_oh[:, :, 0][target.sum(dim=2) == 0] = 0 77 | target = target_oh 78 | 79 | # Should be either 0 (empty, padding) or 1 (single object). 80 | assert torch.all(target.sum(dim=-1) < 2), "Issues with target format, mask non-exclusive" 81 | 82 | if self.foreground: 83 | ari = fg_adjusted_rand_index(prediction, target) 84 | else: 85 | ari = adjusted_rand_index(prediction, target) 86 | 87 | print("\tupdating ari... ", ari.item()) 88 | 89 | self.values += ari.sum() 90 | self.total += len(ari) 91 | 92 | def compute(self) -> torch.Tensor: 93 | return self.values / self.total 94 | 95 | 96 | class UnsupervisedMaskIoUMetric(torchmetrics.Metric): 97 | """Computes IoU metric for segmentation masks when correspondences to ground truth are not known. 98 | 99 | Uses Hungarian matching to compute the assignment between predicted classes and ground truth 100 | classes. 101 | 102 | Args: 103 | use_threshold: If `True`, convert predicted class probabilities to mask using a threshold. 104 | If `False`, class probabilities are turned into mask using a softmax instead. 105 | threshold: Value to use for thresholding masks. 106 | matching: Approach to match predicted to ground truth classes. For "hungarian", computes 107 | assignment that maximizes total IoU between all classes. For "best_overlap", uses the 108 | predicted class with maximum overlap for each ground truth class. Using "best_overlap" 109 | leads to the "average best overlap" metric. 110 | compute_discovery_fraction: Instead of the IoU, compute the fraction of ground truth classes 111 | that were "discovered", meaning that they have an IoU greater than some threshold. 112 | correct_localization: Instead of the IoU, compute the fraction of images on which at least 113 | one ground truth class was correctly localised, meaning that they have an IoU 114 | greater than some threshold. 115 | discovery_threshold: Minimum IoU to count a class as discovered/correctly localized. 116 | ignore_background: If true, assume class at index 0 of ground truth masks is background class 117 | that is removed before computing IoU. 118 | ignore_overlaps: If true, remove points where ground truth masks has overlappign classes from 119 | predictions and ground truth masks. 120 | """ 121 | 122 | def __init__( 123 | self, 124 | use_threshold: bool = False, 125 | threshold: float = 0.5, 126 | matching: str = "hungarian", 127 | compute_discovery_fraction: bool = False, 128 | correct_localization: bool = False, 129 | discovery_threshold: float = 0.5, 130 | ignore_background: bool = False, 131 | ignore_overlaps: bool = False, 132 | ): 133 | super().__init__() 134 | self.use_threshold = use_threshold 135 | self.threshold = threshold 136 | self.discovery_threshold = discovery_threshold 137 | self.compute_discovery_fraction = compute_discovery_fraction 138 | self.correct_localization = correct_localization 139 | if compute_discovery_fraction and correct_localization: 140 | raise ValueError( 141 | "Only one of `compute_discovery_fraction` and `correct_localization` can be enabled." 142 | ) 143 | 144 | matchings = ("hungarian", "best_overlap") 145 | if matching not in matchings: 146 | raise ValueError(f"Unknown matching type {matching}. 
Valid values are {matchings}.") 147 | self.matching = matching 148 | self.ignore_background = ignore_background 149 | self.ignore_overlaps = ignore_overlaps 150 | 151 | self.add_state( 152 | "values", default=torch.tensor(0.0, dtype=torch.float64), dist_reduce_fx="sum" 153 | ) 154 | self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") 155 | 156 | def update( 157 | self, prediction: torch.Tensor, target: torch.Tensor, ignore: Optional[torch.Tensor] = None 158 | ): 159 | """Update this metric. 160 | 161 | Args: 162 | prediction: Predicted mask of shape (B, C, H, W) or (B, F, C, H, W), where C is the 163 | number of classes. Assumes class probabilities as inputs. 164 | target: Ground truth mask of shape (B, K, H, W) or (B, F, K, H, W), where K is the 165 | number of classes. 166 | ignore: Ignore mask of shape (B, 1, H, W) or (B, 1, K, H, W) 167 | """ 168 | if prediction.ndim == 5: 169 | # Merge frames, height and width to single dimension. 170 | predictions = prediction.transpose(1, 2).flatten(-3, -1) 171 | targets = target.transpose(1, 2).flatten(-3, -1) 172 | if ignore is not None: 173 | ignore = ignore.to(torch.bool).transpose(1, 2).flatten(-3, -1) 174 | elif prediction.ndim == 4: 175 | # Merge height and width to single dimension. 176 | predictions = prediction.flatten(-2, -1) 177 | targets = target.flatten(-2, -1) 178 | if ignore is not None: 179 | ignore = ignore.to(torch.bool).flatten(-2, -1) 180 | else: 181 | raise ValueError(f"Incorrect input shape: {prediction.shape}") 182 | 183 | if self.use_threshold: 184 | predictions = predictions > self.threshold 185 | else: 186 | indices = torch.argmax(predictions, dim=1) 187 | predictions = torch.nn.functional.one_hot(indices, num_classes=predictions.shape[1]) 188 | predictions = predictions.transpose(1, 2) 189 | 190 | if self.ignore_background: 191 | targets = targets[:, 1:] 192 | 193 | targets = targets > 0 # Ensure masks are binary 194 | 195 | if self.ignore_overlaps: 196 | overlaps = targets.sum(1, keepdim=True) > 1 197 | if ignore is None: 198 | ignore = overlaps 199 | else: 200 | ignore = ignore | overlaps 201 | 202 | if ignore is not None: 203 | assert ignore.ndim == 3 and ignore.shape[1] == 1 204 | predictions[ignore.expand_as(predictions)] = 0 205 | targets[ignore.expand_as(targets)] = 0 206 | 207 | # Should be either 0 (empty, padding) or 1 (single object). 208 | assert torch.all(targets.sum(dim=1) < 2), "Issues with target format, mask non-exclusive" 209 | 210 | for pred, target in zip(predictions, targets): 211 | nonzero_classes = torch.sum(target, dim=-1) > 0 212 | target = target[nonzero_classes] # Remove empty (e.g.
padded) classes 213 | if len(target) == 0: 214 | continue # Skip elements without any target mask 215 | 216 | iou_per_class = unsupervised_mask_iou( 217 | pred, target, matching=self.matching, reduction="none" 218 | ) 219 | 220 | if self.compute_discovery_fraction: 221 | discovered = iou_per_class > self.discovery_threshold 222 | self.values += discovered.sum() / len(discovered) 223 | elif self.correct_localization: 224 | correctly_localized = torch.any(iou_per_class > self.discovery_threshold) 225 | self.values += correctly_localized.sum() 226 | else: 227 | self.values += iou_per_class.mean() 228 | self.total += 1 229 | 230 | def compute(self) -> torch.Tensor: 231 | if self.total == 0: 232 | return torch.zeros_like(self.values) 233 | else: 234 | return self.values / self.total 235 | 236 | 237 | class MaskCorLocMetric(UnsupervisedMaskIoUMetric): 238 | def __init__(self, **kwargs): 239 | super().__init__(matching="best_overlap", correct_localization=True, **kwargs) 240 | 241 | 242 | class AverageBestOverlapMetric(UnsupervisedMaskIoUMetric): 243 | def __init__(self, **kwargs): 244 | super().__init__(matching="best_overlap", **kwargs) 245 | 246 | 247 | class BestOverlapObjectRecoveryMetric(UnsupervisedMaskIoUMetric): 248 | def __init__(self, **kwargs): 249 | super().__init__(matching="best_overlap", compute_discovery_fraction=True, **kwargs) 250 | 251 | 252 | def unsupervised_mask_iou( 253 | pred_mask: torch.Tensor, 254 | true_mask: torch.Tensor, 255 | matching: str = "hungarian", 256 | reduction: str = "mean", 257 | iou_empty: float = 0.0, 258 | ) -> torch.Tensor: 259 | """Compute intersection-over-union (IoU) between masks with unknown class correspondences. 260 | 261 | This metric is also known as Jaccard index. Note that this is a non-batched implementation. 262 | 263 | Args: 264 | pred_mask: Predicted mask of shape (C, N), where C is the number of predicted classes and 265 | N is the number of points. Masks are assumed to be binary. 266 | true_mask: Ground truth mask of shape (K, N), where K is the number of ground truth 267 | classes and N is the number of points. Masks are assumed to be binary. 268 | matching: How to match predicted classes to ground truth classes. For "hungarian", computes 269 | assignment that maximizes total IoU between all classes. For "best_overlap", uses the 270 | predicted class with maximum overlap for each ground truth class (each predicted class 271 | can be assigned to multiple ground truth classes). Empty ground truth classes are 272 | assigned IoU of zero. 273 | reduction: If "mean", return IoU averaged over classes. If "none", return per-class IoU. 274 | iou_empty: IoU for the case when a class does not occur, but was also not predicted. 275 | 276 | Returns: 277 | Mean IoU over classes if reduction is `mean`, tensor of shape (K,) containing per-class IoU 278 | otherwise. 279 | """ 280 | assert pred_mask.ndim == 2 281 | assert true_mask.ndim == 2 282 | n_gt_classes = len(true_mask) 283 | pred_mask = pred_mask.unsqueeze(1).to(torch.bool) 284 | true_mask = true_mask.unsqueeze(0).to(torch.bool) 285 | 286 | intersection = torch.sum(pred_mask & true_mask, dim=-1).to(torch.float64) 287 | union = torch.sum(pred_mask | true_mask, dim=-1).to(torch.float64) 288 | pairwise_iou = intersection / union 289 | 290 | # Remove NaN from divide-by-zero: class does not occur, and class was not predicted. 
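    # Worked micro-example (illustrative values, added commentary): for pred_mask = [[1, 0]]
    # and true_mask = [[0, 0]] the pair has intersection 0 and union 1, so its IoU is 0;
    # only pairs where *both* masks are empty hit union == 0, and those NaN entries are
    # overwritten with `iou_empty` below.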
291 | pairwise_iou[union == 0] = iou_empty 292 | 293 | if matching == "hungarian": 294 | pred_idxs, true_idxs = scipy.optimize.linear_sum_assignment( 295 | pairwise_iou.cpu(), maximize=True 296 | ) 297 | pred_idxs = torch.as_tensor(pred_idxs, dtype=torch.int64, device=pairwise_iou.device) 298 | true_idxs = torch.as_tensor(true_idxs, dtype=torch.int64, device=pairwise_iou.device) 299 | elif matching == "best_overlap": 300 | non_empty_gt = torch.sum(true_mask.squeeze(0), dim=1) > 0 301 | pred_idxs = torch.argmax(pairwise_iou, dim=0)[non_empty_gt] 302 | true_idxs = torch.arange(pairwise_iou.shape[1])[non_empty_gt] 303 | else: 304 | raise ValueError(f"Unknown matching {matching}") 305 | 306 | matched_iou = pairwise_iou[pred_idxs, true_idxs] 307 | iou = torch.zeros(n_gt_classes, dtype=torch.float64, device=pairwise_iou.device) 308 | iou[true_idxs] = matched_iou 309 | 310 | if reduction == "mean": 311 | return iou.mean() 312 | else: 313 | return iou 314 | 315 | 316 | def tensor_to_one_hot(tensor: torch.Tensor, dim: int) -> torch.Tensor: 317 | """Convert tensor to one-hot encoding by using maximum across dimension as one-hot element.""" 318 | assert 0 <= dim 319 | max_idxs = torch.argmax(tensor, dim=dim, keepdim=True) 320 | shape = [1] * dim + [-1] + [1] * (tensor.ndim - dim - 1) 321 | one_hot = max_idxs == torch.arange(tensor.shape[dim], device=tensor.device).view(*shape) 322 | return one_hot.to(torch.long) 323 | 324 | 325 | def adjusted_rand_index(pred_mask: torch.Tensor, true_mask: torch.Tensor) -> torch.Tensor: 326 | """Computes adjusted Rand index (ARI), a clustering similarity score. 327 | 328 | This implementation ignores points with no cluster label in `true_mask` (i.e. those points for 329 | which `true_mask` is a zero vector). In the context of segmentation, that means this function 330 | can ignore points in an image corresponding to the background (i.e. not to an object). 331 | 332 | Implementation adapted from https://github.com/deepmind/multi_object_datasets and 333 | https://github.com/google-research/slot-attention-video/blob/main/savi/lib/metrics.py 334 | 335 | Args: 336 | pred_mask: Predicted cluster assignment encoded as categorical probabilities of shape 337 | (batch_size, n_points, n_pred_clusters). 338 | true_mask: True cluster assignment encoded as one-hot of shape (batch_size, n_points, 339 | n_true_clusters). 340 | 341 | Returns: 342 | ARI scores of shape (batch_size,). 343 | """ 344 | n_pred_clusters = pred_mask.shape[-1] 345 | pred_cluster_ids = torch.argmax(pred_mask, axis=-1) 346 | 347 | # Convert true and predicted clusters to one-hot ('oh') representations. We use float64 here on 348 | # purpose, otherwise mixed precision training automatically casts to FP16 in some of the 349 | # operations below, which can create overflows. 
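    # Reading guide for the computation below (added commentary): n_ij is the per-batch
    # contingency table counting points that fall in true cluster i and predicted cluster j,
    # `a` and `b` are its row and column sums, and the score is
    # (rindex - expected_rindex) / (max_rindex - expected_rindex), with the zero-denominator
    # cases mapped to 1.0 at the end of the function.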
350 | true_mask_oh = true_mask.to(torch.float64) # already one-hot 351 | pred_mask_oh = torch.nn.functional.one_hot(pred_cluster_ids, n_pred_clusters).to(torch.float64) 352 | 353 | n_ij = torch.einsum("bnc,bnk->bck", true_mask_oh, pred_mask_oh) 354 | a = torch.sum(n_ij, axis=-1) 355 | b = torch.sum(n_ij, axis=-2) 356 | n_fg_points = torch.sum(a, axis=1) 357 | 358 | rindex = torch.sum(n_ij * (n_ij - 1), axis=(1, 2)) 359 | aindex = torch.sum(a * (a - 1), axis=1) 360 | bindex = torch.sum(b * (b - 1), axis=1) 361 | expected_rindex = aindex * bindex / torch.clamp(n_fg_points * (n_fg_points - 1), min=1) 362 | max_rindex = (aindex + bindex) / 2 363 | denominator = max_rindex - expected_rindex 364 | ari = (rindex - expected_rindex) / denominator 365 | 366 | # There are two cases for which the denominator can be zero: 367 | # 1. If both true_mask and pred_mask assign all pixels to a single cluster. 368 | # (max_rindex == expected_rindex == rindex == n_fg_points * (n_fg_points-1)) 369 | # 2. If both true_mask and pred_mask assign max 1 point to each cluster. 370 | # (max_rindex == expected_rindex == rindex == 0) 371 | # In both cases, we want the ARI score to be 1.0: 372 | return torch.where(denominator > 0, ari, torch.ones_like(ari)) 373 | 374 | 375 | def fg_adjusted_rand_index( 376 | pred_mask: torch.Tensor, true_mask: torch.Tensor, bg_dim: int = 0 377 | ) -> torch.Tensor: 378 | """Compute adjusted random index using only foreground groups (FG-ARI). 379 | 380 | Args: 381 | pred_mask: Predicted cluster assignment encoded as categorical probabilities of shape 382 | (batch_size, n_points, n_pred_clusters). 383 | true_mask: True cluster assignment encoded as one-hot of shape (batch_size, n_points, 384 | n_true_clusters). 385 | bg_dim: Index of background class in true mask. 386 | 387 | Returns: 388 | ARI scores of shape (batch_size,). 389 | """ 390 | n_true_clusters = true_mask.shape[-1] 391 | assert 0 <= bg_dim < n_true_clusters 392 | if bg_dim == 0: 393 | true_mask_only_fg = true_mask[..., 1:] 394 | elif bg_dim == n_true_clusters - 1: 395 | true_mask_only_fg = true_mask[..., :-1] 396 | else: 397 | true_mask_only_fg = torch.cat( 398 | (true_mask[..., :bg_dim], true_mask[..., bg_dim + 1 :]), dim=-1 399 | ) 400 | 401 | return adjusted_rand_index(pred_mask, true_mask_only_fg) 402 | 403 | 404 | def _all_equal_masked(values: torch.Tensor, mask: torch.Tensor, dim=-1) -> torch.Tensor: 405 | """Check if all masked values along a dimension of a tensor are the same. 406 | 407 | All non-masked values are considered as true, i.e. if no value is masked, true is returned 408 | for this dimension. 409 | """ 410 | assert mask.dtype == torch.bool 411 | _, first_non_masked_idx = torch.max(mask, dim=dim) 412 | 413 | comparison_value = values.gather(index=first_non_masked_idx.unsqueeze(dim), dim=dim) 414 | 415 | return torch.logical_or(~mask, values == comparison_value).all(dim=dim) 416 | 417 | 418 | def masks_to_bboxes(masks: torch.Tensor, empty_value: float = -1.0) -> torch.Tensor: 419 | """Compute bounding boxes around the provided masks. 420 | 421 | Adapted from DETR: https://github.com/facebookresearch/detr/blob/main/util/box_ops.py 422 | 423 | Args: 424 | masks: Tensor of shape (N, H, W), where N is the number of masks, H and W are the spatial 425 | dimensions. 426 | empty_value: Value bounding boxes should contain for empty masks. 
427 | 428 | Returns: 429 | Tensor of shape (N, 4), containing bounding boxes in (x1, y1, x2, y2) format, where (x1, y1) 430 | is the coordinate of top-left corner and (x2, y2) is the coordinate of the bottom-right 431 | corner (inclusive) in pixel coordinates. If mask is empty, all coordinates contain 432 | `empty_value` instead. 433 | """ 434 | masks = masks.bool() 435 | if masks.numel() == 0: 436 | return torch.zeros((0, 4), device=masks.device) 437 | 438 | large_value = 1e8 439 | inv_mask = ~masks 440 | 441 | h, w = masks.shape[-2:] 442 | 443 | y = torch.arange(0, h, dtype=torch.float, device=masks.device) 444 | x = torch.arange(0, w, dtype=torch.float, device=masks.device) 445 | y, x = torch.meshgrid(y, x, indexing="ij") 446 | 447 | x_mask = masks * x.unsqueeze(0) 448 | x_max = x_mask.flatten(1).max(-1)[0] 449 | x_min = x_mask.masked_fill(inv_mask, large_value).flatten(1).min(-1)[0] 450 | 451 | y_mask = masks * y.unsqueeze(0) 452 | y_max = y_mask.flatten(1).max(-1)[0] 453 | y_min = y_mask.masked_fill(inv_mask, large_value).flatten(1).min(-1)[0] 454 | 455 | bboxes = torch.stack((x_min, y_min, x_max, y_max), dim=1) 456 | bboxes[x_min == large_value] = empty_value 457 | 458 | return bboxes 459 | 460 | 461 | def _remap_one_hot_mask( 462 | mask: torch.Tensor, new_classes: torch.Tensor, n_new_classes: int, strip_empty: bool = False 463 | ): 464 | """Remap classes from binary mask to new classes. 465 | 466 | In the case of an overlap of classes for a point, the new class with the highest ID is 467 | assigned to that point. If no class is assigned to a point, the point will have no class 468 | assigned after remapping as well. 469 | 470 | Args: 471 | mask: Binary mask of shape (B, P, K) where K is the number of old classes and P is the 472 | number of points. 473 | new_classes: Tensor of shape (B, K) containing ids of new classes for each old class. 474 | n_new_classes: Number of classes after remapping, i.e. highest class id that can occur. 475 | strip_empty: Whether to remove the empty pixels mask 476 | 477 | Returns: 478 | Tensor of shape (B, P, J), where J is the new number of classes. 
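        Illustrative example (values assumed, not part of the original docstring): for a
        mask of shape (1, P, 3) with new_classes = [[2, 1, 2]] and n_new_classes = 2, the
        result has shape (1, P, 3) including the "empty" channel at index 0, or (1, P, 2)
        when strip_empty is True.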
479 | """ 480 | assert new_classes.shape[1] == mask.shape[2] 481 | mask_dense = (mask * new_classes.unsqueeze(1)).max(dim=-1).values 482 | mask = torch.nn.functional.one_hot(mask_dense.to(torch.long), num_classes=n_new_classes + 1) 483 | 484 | if strip_empty: 485 | mask = mask[..., 1:] 486 | 487 | return mask 488 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | from typing import Tuple, List, Dict 3 | 4 | import faiss 5 | import torch 6 | import numpy as np 7 | import os 8 | from datetime import datetime 9 | import glob 10 | import matplotlib.pyplot as plt 11 | from joblib import Parallel, delayed 12 | from scipy.optimize import linear_sum_assignment 13 | from skimage.measure import label 14 | from collections import deque, defaultdict 15 | 16 | from torch import nn 17 | from torchvision import transforms 18 | from torchvision.transforms import GaussianBlur 19 | from torchmetrics import Metric 20 | 21 | 22 | def save_checkpoint(state, is_best=0, gap=1, filename='models/checkpoint.pth', keep_all=False): 23 | torch.save(state, filename) 24 | last_epoch_path = os.path.join(os.path.dirname(filename), 25 | 'epoch%s.pth' % str(state['epoch'] - gap)) 26 | if not keep_all: 27 | try: 28 | os.remove(last_epoch_path) 29 | except: 30 | pass 31 | if is_best: 32 | past_best = glob.glob(os.path.join(os.path.dirname(filename), 'model_best_*.pth')) 33 | for i in past_best: 34 | try: 35 | os.remove(i) 36 | except: 37 | pass 38 | torch.save(state, os.path.join(os.path.dirname(filename), 'model_best_epoch%s.pth' % str(state['epoch']))) 39 | 40 | 41 | class PredsmIoU(Metric): 42 | """ 43 | Subclasses Metric. Computes mean Intersection over Union (mIoU) given ground-truth and predictions. 44 | .update() can be called repeatedly to add data from multiple validation loops. 45 | """ 46 | 47 | def __init__(self, 48 | num_pred_classes: int, 49 | num_gt_classes: int): 50 | """ 51 | :param num_pred_classes: The number of predicted classes. 52 | :param num_gt_classes: The number of gt classes. 53 | """ 54 | super().__init__(dist_sync_on_step=False, compute_on_step=False) 55 | self.num_pred_classes = num_pred_classes 56 | self.num_gt_classes = num_gt_classes 57 | self.add_state("iou", []) 58 | self.add_state("iou_excludeFirst", []) 59 | self.n_jobs = -1 60 | 61 | def update(self, gt: torch.Tensor, pred: torch.Tensor, many_to_one=True, precision_based=True, linear_probe=False): 62 | pred = pred.cpu().numpy().astype(int) 63 | gt = gt.cpu().numpy().astype(int) 64 | assert len(np.unique(pred)) <= self.num_pred_classes 65 | assert np.max(pred) <= self.num_pred_classes 66 | iou_all, iou_excludeFirst = self.compute_miou(gt, pred, self.num_pred_classes, len(np.unique(gt)), 67 | many_to_one=many_to_one, precision_based=precision_based, linear_probe=linear_probe) 68 | self.iou.append(iou_all) 69 | self.iou_excludeFirst.append(iou_excludeFirst) 70 | 71 | def compute(self): 72 | """ 73 | Compute mIoU 74 | """ 75 | mIoU = np.mean(self.iou) 76 | mIoU_excludeFirst = np.mean(self.iou_excludeFirst) 77 | print('---mIoU computed---', mIoU) 78 | print('---mIoU exclude first---', mIoU_excludeFirst) 79 | return mIoU 80 | 81 | def compute_miou(self, gt: np.ndarray, pred: np.ndarray, num_pred: int, num_gt: int, 82 | many_to_one=False, precision_based=False, linear_probe=False): 83 | """ 84 | Compute mIoU with optional hungarian matching or many-to-one matching (extracts information from labels). 
85 | :param gt: numpy array with all flattened ground-truth class assignments per pixel 86 | :param pred: numpy array with all flattened class assignment predictions per pixel 87 | :param num_pred: number of predicted classes 88 | :param num_gt: number of ground truth classes 89 | :param many_to_one: Compute a many-to-one mapping of predicted classes to ground truth instead of hungarian 90 | matching. 91 | :param precision_based: Use precision as matching criteria instead of IoU for assigning predicted class to 92 | ground truth class. 93 | :param linear_probe: Skip hungarian / many-to-one matching. Used for evaluating predictions of fine-tuned heads. 94 | :return: mIoU over all classes, true positives per class, false negatives per class, false positives per class, 95 | reordered predictions matching gt 96 | """ 97 | assert pred.shape == gt.shape 98 | print(f"unique semantic class = {np.unique(gt)}") 99 | gt_class = np.unique(gt).tolist() 100 | tp = [0] * num_gt 101 | fp = [0] * num_gt 102 | fn = [0] * num_gt 103 | iou = [0] * num_gt 104 | 105 | if linear_probe: 106 | reordered_preds = pred 107 | else: 108 | if many_to_one: 109 | match = self._original_match(num_pred, num_gt, pred, gt, precision_based=precision_based) 110 | # remap predictions 111 | reordered_preds = np.zeros(len(pred)) 112 | for target_i, matched_preds in match.items(): 113 | for pred_i in matched_preds: 114 | reordered_preds[pred == int(pred_i)] = int(target_i) 115 | else: 116 | match = self._hungarian_match(num_pred, num_gt, pred, gt) 117 | # remap predictions 118 | reordered_preds = np.zeros(len(pred)) 119 | for target_i, pred_i in zip(*match): 120 | reordered_preds[pred == int(pred_i)] = int(target_i) 121 | # merge all unmatched predictions to background 122 | for unmatched_pred in np.delete(np.arange(num_pred), np.array(match[1])): 123 | reordered_preds[pred == int(unmatched_pred)] = 0 124 | 125 | # tp, fp, and fn evaluation 126 | for i_part in range(0, num_gt): 127 | tmp_all_gt = (gt == gt_class[i_part]) 128 | tmp_pred = (reordered_preds == gt_class[i_part]) 129 | tp[i_part] += np.sum(tmp_all_gt & tmp_pred) 130 | fp[i_part] += np.sum(~tmp_all_gt & tmp_pred) 131 | fn[i_part] += np.sum(tmp_all_gt & ~tmp_pred) 132 | 133 | # Calculate IoU per class 134 | for i_part in range(0, num_gt): 135 | iou[i_part] = float(tp[i_part]) / max(float(tp[i_part] + fp[i_part] + fn[i_part]), 1e-8) 136 | 137 | print('\tiou = ', iou, np.mean(iou[1:])) 138 | if len(iou) > 1: 139 | return np.mean(iou), np.mean(iou[1:]) 140 | else: 141 | # return np.mean(iou), tp, fp, fn, reordered_preds.astype(int).tolist() 142 | return np.mean(iou), np.mean(iou) 143 | 144 | @staticmethod 145 | def get_score(flat_preds: np.ndarray, flat_targets: np.ndarray, c1: int, c2: int, precision_based: bool = False) \ 146 | -> float: 147 | """ 148 | Calculates IoU given gt class c1 and prediction class c2. 149 | :param flat_preds: flattened predictions 150 | :param flat_targets: flattened gt 151 | :param c1: ground truth class to match 152 | :param c2: predicted class to match 153 | :param precision_based: flag to calculate precision instead of IoU. 154 | :return: The score if gt-c1 was matched to predicted c2. 
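        Worked example (illustrative numbers): for flat_targets = [0, 1, 1, 0] and
        flat_preds = [2, 2, 2, 0] with c1 = 1 and c2 = 2, tp = 2, fp = 1 and fn = 0,
        giving an IoU of 2/3 and a precision-based score of 2/3.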
155 | """ 156 | tmp_all_gt = (flat_targets == c1) 157 | tmp_pred = (flat_preds == c2) 158 | tp = np.sum(tmp_all_gt & tmp_pred) 159 | fp = np.sum(~tmp_all_gt & tmp_pred) 160 | if not precision_based: 161 | fn = np.sum(tmp_all_gt & ~tmp_pred) 162 | jac = float(tp) / max(float(tp + fp + fn), 1e-8) 163 | return jac 164 | else: 165 | prec = float(tp) / max(float(tp + fp), 1e-8) 166 | # print('\tgt, pred = ', c1, c2, ' | precision=', prec) 167 | return prec 168 | 169 | def compute_score_matrix(self, num_pred: int, num_gt: int, pred: np.ndarray, gt: np.ndarray, 170 | precision_based: bool = False) -> np.ndarray: 171 | """ 172 | Compute score matrix. Each element i, j of matrix is the score if i was matched j. Computation is parallelized 173 | over self.n_jobs. 174 | :param num_pred: number of predicted classes 175 | :param num_gt: number of ground-truth classes 176 | :param pred: flattened predictions 177 | :param gt: flattened gt 178 | :param precision_based: flag to calculate precision instead of IoU. 179 | :return: num_pred x num_gt matrix with A[i, j] being the score if ground-truth class i was matched to 180 | predicted class j. 181 | """ 182 | # print("Parallelizing iou computation") 183 | # start = time.time() 184 | score_mat = Parallel(n_jobs=self.n_jobs)(delayed(self.get_score)(pred, gt, c1, c2, precision_based=precision_based) 185 | for c2 in range(num_pred) for c1 in np.unique(gt)) 186 | # print(f"took {time.time() - start} seconds") 187 | score_mat = np.array(score_mat) 188 | return score_mat.reshape((num_pred, num_gt)).T 189 | 190 | def _hungarian_match(self, num_pred: int, num_gt: int, pred: np.ndarray, gt: np.ndarray): 191 | # do hungarian matching. If num_pred > num_gt match will be partial only. 192 | iou_mat = self.compute_score_matrix(num_pred, num_gt, pred, gt) 193 | match = linear_sum_assignment(1 - iou_mat) 194 | print("Matched clusters to gt classes:") 195 | print(match) 196 | return match 197 | 198 | def _original_match(self, num_pred, num_gt, pred, gt, precision_based=False) -> Dict[int, list]: 199 | score_mat = self.compute_score_matrix(num_pred, num_gt, pred, gt, precision_based=precision_based) 200 | gt_class = np.unique(gt).tolist() 201 | preds_to_gts = {} 202 | preds_to_gt_scores = {} 203 | # Greedily match predicted class to ground-truth class by best score. 204 | for pred_c in range(num_pred): 205 | for gt_i in range(num_gt): 206 | score = score_mat[gt_i, pred_c] 207 | if (pred_c not in preds_to_gts) or (score > preds_to_gt_scores[pred_c]): 208 | preds_to_gts[pred_c] = gt_class[gt_i] 209 | preds_to_gt_scores[pred_c] = score 210 | gt_to_matches = defaultdict(list) 211 | for k, v in preds_to_gts.items(): 212 | gt_to_matches[v].append(k) 213 | # print('original match:', gt_to_matches) 214 | return gt_to_matches 215 | 216 | 217 | class PredsmIoUKmeans(PredsmIoU): 218 | """ 219 | Used to track k-means cluster correspondence to ground-truth categories during fine-tuning. 
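    Flow implemented in compute(), summarised here for reference: collected embeddings are
    bilinearly interpolated to the ground-truth mask resolution, standardised, reduced to
    pca_dim dimensions with a faiss PCA, clustered with faiss k-means once per granularity
    in clustering_granularities, and each cluster map is scored against the ground truth
    with compute_miou (many-to-one matching whenever k differs from the number of
    ground-truth classes).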
220 | """ 221 | 222 | def __init__(self, 223 | clustering_granularities: List[int], 224 | num_gt_classes: int, 225 | pca_dim: int = 50): 226 | """ 227 | :param clustering_granularities: list of clustering granularities for embeddings 228 | :param num_gt_classes: number of ground-truth classes 229 | :param pca_dim: target dimensionality of PCA 230 | """ 231 | super(PredsmIoU, self).__init__(compute_on_step=False, dist_sync_on_step=False) # Init Metric super class 232 | self.pca_dim = pca_dim 233 | self.num_pred_classes = clustering_granularities 234 | self.num_gt_classes = num_gt_classes 235 | self.add_state("masks", []) 236 | self.add_state("embeddings", []) 237 | self.add_state("gt", []) 238 | self.n_jobs = -1 # num_jobs = num_cores 239 | self.num_train_pca = 4000000 # take num_train_pca many vectors at max for training pca 240 | 241 | def update(self, masks: torch.Tensor, embeddings: torch.Tensor, gt: torch.Tensor) -> None: 242 | self.masks.append(masks) 243 | self.embeddings.append(embeddings) 244 | self.gt.append(gt) 245 | 246 | def compute(self, is_global_zero: bool) -> List[any]: 247 | if is_global_zero: 248 | # interpolate embeddings to match ground-truth masks spatially 249 | embeddings = torch.cat([e.cpu() for e in self.embeddings], dim=0) # move everything to cpu before catting 250 | valid_masks = torch.cat(self.masks, dim=0).cpu().numpy() 251 | res_w = valid_masks.shape[2] 252 | embeddings = nn.functional.interpolate(embeddings, size=(res_w, res_w), mode='bilinear') 253 | embeddings = embeddings.permute(0, 2, 3, 1).reshape(valid_masks.shape[0] * res_w ** 2, -1).numpy() 254 | 255 | # Normalize embeddings and reduce dims of embeddings by PCA 256 | normalized_embeddings = (embeddings - np.mean(embeddings, axis=0)) / ( 257 | np.std(embeddings, axis=0, ddof=0) + 1e-5) 258 | d_orig = embeddings.shape[1] 259 | pca = faiss.PCAMatrix(d_orig, self.pca_dim) 260 | pca.train(normalized_embeddings[:self.num_train_pca]) 261 | assert pca.is_trained 262 | transformed_feats = pca.apply_py(normalized_embeddings) 263 | 264 | # Cluster transformed feats with kmeans 265 | results = [] 266 | gt = torch.cat(self.gt, dim=0).cpu().numpy()[valid_masks] 267 | for k in self.num_pred_classes: # [500, 300, 21] 268 | kmeans = faiss.Kmeans(self.pca_dim, k, niter=50, nredo=5, seed=1, verbose=True, gpu=False, 269 | spherical=False) 270 | kmeans.train(transformed_feats) 271 | _, pred_labels = kmeans.index.search(transformed_feats, 1) 272 | clusters = pred_labels.squeeze() 273 | 274 | # Filter predictions by valid masks (removes voc boundary gt class) 275 | pred_flattened = clusters.reshape(valid_masks.shape[0], 1, res_w, res_w)[valid_masks] 276 | assert len(np.unique(pred_flattened)) == k 277 | assert np.max(pred_flattened) == k - 1 278 | 279 | # Calculate mIoU. Do many-to-one matching if k > self.num_gt_classes. 280 | if k == self.num_gt_classes: 281 | results.append((k, k, self.compute_miou(gt, pred_flattened, k, self.num_gt_classes, 282 | many_to_one=False))) 283 | else: 284 | results.append((k, k, self.compute_miou(gt, pred_flattened, k, self.num_gt_classes, 285 | many_to_one=True))) 286 | results.append((k, f"{k}_prec", self.compute_miou(gt, pred_flattened, k, self.num_gt_classes, 287 | many_to_one=True, precision_based=True))) 288 | return results 289 | 290 | 291 | def eval_jac(gt: torch.Tensor, pred_mask: torch.Tensor, with_boundary: bool = True) -> float: 292 | """ 293 | Calculate Intersection over Union averaged over all pictures. 
with_boundary flag, if set, doesn't filter out the 294 | boundary class as background. 295 | """ 296 | jacs = 0 297 | for k, mask in enumerate(gt): 298 | if with_boundary: 299 | gt_fg_mask = (mask != 0).float() 300 | else: 301 | gt_fg_mask = ((mask != 0) & (mask != 255)).float() 302 | intersection = gt_fg_mask * pred_mask[k] 303 | intersection = torch.sum(torch.sum(intersection, dim=-1), dim=-1) 304 | union = (gt_fg_mask + pred_mask[k]) > 0 305 | union = torch.sum(torch.sum(union, dim=-1), dim=-1) 306 | jacs += intersection / union 307 | res = jacs / gt.size(0) 308 | print(res) 309 | return res.item() 310 | 311 | 312 | def process_attentions(attentions: torch.Tensor, spatial_res: int, threshold: float = 0.6, blur_sigma: float = 0.6) \ 313 | -> torch.Tensor: 314 | """ 315 | Process [0,1] attentions to binary 0-1 mask. Applies a Guassian filter, keeps threshold % of mass and removes 316 | components smaller than 3 pixels. 317 | The code is adapted from https://github.com/facebookresearch/dino/blob/main/visualize_attention.py but removes the 318 | need for using ground-truth data to find the best performing head. Instead we simply average all head's attentions 319 | so that we can use the foreground mask during training time. 320 | :param attentions: torch 4D-Tensor containing the averaged attentions 321 | :param spatial_res: spatial resolution of the attention map 322 | :param threshold: the percentage of mass to keep as foreground. 323 | :param blur_sigma: standard deviation to be used for creating kernel to perform blurring. 324 | :return: the foreground mask obtained from the ViT's attention. 325 | """ 326 | # Blur attentions 327 | attentions = GaussianBlur(7, sigma=(blur_sigma))(attentions) 328 | attentions = attentions.reshape(attentions.size(0), 1, spatial_res ** 2) 329 | # Keep threshold% of mass 330 | val, idx = torch.sort(attentions) 331 | val /= torch.sum(val, dim=-1, keepdim=True) 332 | cumval = torch.cumsum(val, dim=-1) 333 | th_attn = cumval > (1 - threshold) 334 | idx2 = torch.argsort(idx) 335 | th_attn[:, 0] = torch.gather(th_attn[:, 0], dim=1, index=idx2[:, 0]) 336 | th_attn = th_attn.reshape(attentions.size(0), 1, spatial_res, spatial_res).float() 337 | # Remove components with less than 3 pixels 338 | for j, th_att in enumerate(th_attn): 339 | labelled = label(th_att.cpu().numpy()) 340 | for k in range(1, np.max(labelled) + 1): 341 | mask = labelled == k 342 | if np.sum(mask) <= 2: 343 | th_attn[j, 0][mask] = 0 344 | return th_attn.detach() 345 | 346 | 347 | def neq_load_customized(model, pretrained_dict): 348 | """ 349 | load pre-trained model in a non-equal way, 350 | when new model has been partially modified 351 | """ 352 | model_dict = model.state_dict() 353 | tmp = {} 354 | print('\n=======Check Weights Loading======') 355 | print('Weights not used from pretrained file:') 356 | for k, v in pretrained_dict.items(): 357 | if k in model_dict: 358 | tmp[k] = v 359 | else: 360 | print(k) 361 | 362 | print('\n-----------------------------------') 363 | print('Weights not loaded into new model:') 364 | for k, v in model_dict.items(): 365 | if k not in pretrained_dict: 366 | print(k) 367 | print('===================================\n') 368 | 369 | # pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 370 | del pretrained_dict 371 | model_dict.update(tmp) 372 | del tmp 373 | model.load_state_dict(model_dict) 374 | return model 375 | 376 | 377 | def neq_load_external(model, pretrained_dict): 378 | """ 379 | load pre-trained model from external source 
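    Illustrative call (checkpoint file name assumed, not shipped with this repository):
        state_dict = torch.load('leopart_vits16.ckpt', map_location='cpu')
        model = neq_load_external(model, state_dict)
    Keys prefixed with 'model.' (e.g. Leopart checkpoints) are remapped into the
    'backbone.' namespace of this model; keys that do not match are printed and skipped.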
380 | """ 381 | model_dict = model.state_dict() 382 | tmp = {} 383 | print('\n=======Check Weights Loading======') 384 | print('Weights not used from pretrained file:') 385 | for k, v in pretrained_dict.items(): 386 | if k.startswith('model'): 387 | k = k.removeprefix('model.') # for Leopart 388 | if 'backbone.' + k in model_dict: 389 | tmp['backbone.' + k] = v 390 | else: 391 | print(k) 392 | 393 | print('\n-----------------------------------') 394 | print('Weights not loaded into new model:') 395 | for k, v in model_dict.items(): 396 | if k not in tmp: 397 | print(k) 398 | print('===================================\n') 399 | 400 | # pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 401 | del pretrained_dict 402 | model_dict.update(tmp) 403 | del tmp 404 | model.load_state_dict(model_dict) 405 | return model 406 | 407 | 408 | def write_log(content, epoch, filename): 409 | if not os.path.exists(filename): 410 | log_file = open(filename, 'w') 411 | else: 412 | log_file = open(filename, 'a') 413 | log_file.write('## Epoch %d:\n' % epoch) 414 | log_file.write('time: %s\n' % str(datetime.now())) 415 | log_file.write(content + '\n\n') 416 | log_file.close() 417 | 418 | 419 | def calc_topk_accuracy(output, target, topk=(1,)): 420 | ''' 421 | Given predicted and ground truth labels, 422 | calculate top-k accuracies. 423 | ''' 424 | maxk = max(topk) 425 | batch_size = target.size(0) 426 | 427 | _, pred = output.topk(maxk, 1, True, True) 428 | pred = pred.t() 429 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 430 | 431 | res = [] 432 | for k in topk: 433 | correct_k = correct[:k].contiguous().view(-1).float().sum(0) 434 | res.append(correct_k.mul_(1 / batch_size)) 435 | return res 436 | 437 | 438 | def calc_accuracy(output, target): 439 | '''output: (B, N); target: (B)''' 440 | target = target.squeeze() 441 | _, pred = torch.max(output, 1) 442 | return torch.mean((pred == target).float()) 443 | 444 | 445 | def calc_accuracy_binary(output, target): 446 | '''output, target: (B, N), output is logits, before sigmoid ''' 447 | pred = output > 0 448 | acc = torch.mean((pred == target.byte()).float()) 449 | del pred, output, target 450 | return acc 451 | 452 | 453 | def denorm(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]): 454 | assert len(mean) == len(std) == 3 455 | inv_mean = [-mean[i] / std[i] for i in range(3)] 456 | inv_std = [1 / i for i in std] 457 | return transforms.Normalize(mean=inv_mean, std=inv_std) 458 | 459 | 460 | class AverageMeter(object): 461 | """Computes and stores the average and current value""" 462 | 463 | def __init__(self): 464 | self.reset() 465 | 466 | def reset(self): 467 | self.val = 0 468 | self.avg = 0 469 | self.sum = 0 470 | self.count = 0 471 | self.local_history = deque([]) 472 | self.local_avg = 0 473 | self.history = [] 474 | self.dict = {} # save all data values here 475 | self.save_dict = {} # save mean and std here, for summary table 476 | 477 | def update(self, val, n=1, history=0, step=5): 478 | self.val = val 479 | self.sum += val * n 480 | self.count += n 481 | self.avg = self.sum / self.count 482 | if history: 483 | self.history.append(val) 484 | if step > 0: 485 | self.local_history.append(val) 486 | if len(self.local_history) > step: 487 | self.local_history.popleft() 488 | self.local_avg = np.average(self.local_history) 489 | 490 | def dict_update(self, val, key): 491 | if key in self.dict.keys(): 492 | self.dict[key].append(val) 493 | else: 494 | self.dict[key] = [val] 495 | 496 | def __len__(self): 497 | 
return self.count 498 | 499 | 500 | class AccuracyTable(object): 501 | '''compute accuracy for each class''' 502 | 503 | def __init__(self): 504 | self.dict = {} 505 | 506 | def update(self, pred, tar): 507 | pred = torch.squeeze(pred) 508 | tar = torch.squeeze(tar) 509 | for i, j in zip(pred, tar): 510 | i = int(i) 511 | j = int(j) 512 | if j not in self.dict.keys(): 513 | self.dict[j] = {'count': 0, 'correct': 0} 514 | self.dict[j]['count'] += 1 515 | if i == j: 516 | self.dict[j]['correct'] += 1 517 | 518 | def print_table(self, label): 519 | for key in self.dict.keys(): 520 | acc = self.dict[key]['correct'] / self.dict[key]['count'] 521 | print('%s: %2d, accuracy: %3d/%3d = %0.6f' \ 522 | % (label, key, self.dict[key]['correct'], self.dict[key]['count'], acc)) 523 | 524 | 525 | class ConfusionMeter(object): 526 | '''compute and show confusion matrix''' 527 | 528 | def __init__(self, num_class): 529 | self.num_class = num_class 530 | self.mat = np.zeros((num_class, num_class)) 531 | self.precision = [] 532 | self.recall = [] 533 | 534 | def update(self, pred, tar): 535 | pred, tar = pred.cpu().numpy(), tar.cpu().numpy() 536 | pred = np.squeeze(pred) 537 | tar = np.squeeze(tar) 538 | for p, t in zip(pred.flat, tar.flat): 539 | self.mat[p][t] += 1 540 | 541 | def print_mat(self): 542 | print('Confusion Matrix: (target in columns)') 543 | print(self.mat) 544 | 545 | def plot_mat(self, path, dictionary=None, annotate=False): 546 | plt.figure(dpi=600) 547 | plt.imshow(self.mat, 548 | cmap=plt.cm.jet, 549 | interpolation=None, 550 | extent=(0.5, np.shape(self.mat)[0] + 0.5, np.shape(self.mat)[1] + 0.5, 0.5)) 551 | width, height = self.mat.shape 552 | if annotate: 553 | for x in range(width): 554 | for y in range(height): 555 | plt.annotate(str(int(self.mat[x][y])), xy=(y + 1, x + 1), 556 | horizontalalignment='center', 557 | verticalalignment='center', 558 | fontsize=8) 559 | 560 | if dictionary is not None: 561 | plt.xticks([i + 1 for i in range(width)], 562 | [dictionary[i] for i in range(width)], 563 | rotation='vertical') 564 | plt.yticks([i + 1 for i in range(height)], 565 | [dictionary[i] for i in range(height)]) 566 | plt.xlabel('Ground Truth') 567 | plt.ylabel('Prediction') 568 | plt.colorbar() 569 | plt.tight_layout() 570 | plt.savefig(path, format='svg') 571 | plt.clf() 572 | 573 | # for i in range(width): 574 | # if np.sum(self.mat[i,:]) != 0: 575 | # self.precision.append(self.mat[i,i] / np.sum(self.mat[i,:])) 576 | # if np.sum(self.mat[:,i]) != 0: 577 | # self.recall.append(self.mat[i,i] / np.sum(self.mat[:,i])) 578 | # print('Average Precision: %0.4f' % np.mean(self.precision)) 579 | # print('Average Recall: %0.4f' % np.mean(self.recall)) 580 | 581 | 582 | if __name__ == '__main__': 583 | pass 584 | --------------------------------------------------------------------------------
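# --- Illustrative usage sketch (added commentary, not a file from this repository) ---
# A minimal smoke test of the PredsmIoU metric from utils.py on random data, assuming
# utils.py and its dependencies (torch, numpy, scipy, joblib, faiss, torchmetrics) are
# importable from the repository root:
#
#   import torch
#   from utils import PredsmIoU
#
#   metric = PredsmIoU(num_pred_classes=21, num_gt_classes=21)
#   gt = torch.randint(0, 21, (10_000,))      # flattened per-pixel ground-truth labels
#   pred = torch.randint(0, 21, (10_000,))    # flattened per-pixel cluster assignments
#   metric.update(gt, pred, many_to_one=True, precision_based=True)
#   print(metric.compute())                   # mIoU after many-to-one matching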