├── images
│   ├── 2007_000464.jpg
│   └── End-to-end self-supervised semantic segmentation.png
├── model
│   ├── __pycache__
│   │   ├── vit.cpython-310.pyc
│   │   ├── align_model.cpython-310.pyc
│   │   ├── criterion.cpython-310.pyc
│   │   └── transforms.cpython-310.pyc
│   ├── criterion.py
│   ├── align_model.py
│   ├── transforms.py
│   └── vit.py
├── data
│   ├── __pycache__
│   │   ├── coco_data.cpython-310.pyc
│   │   └── voc_data.cpython-310.pyc
│   ├── movi_data.py
│   ├── voc_data.py
│   ├── coco_data.py
│   └── stuffthing_2017.json
├── configs
│   ├── eval_voc_config.yml
│   ├── eval_coco27_config.yml
│   ├── eval_movi_config.yml
│   ├── train_coco_config.yml
│   ├── train_movi_config.yml
│   └── train_voc_config.yml
├── LICENSE.txt
├── README.md
├── evaluate
│   ├── cocoStuff27_mask_visualize.py
│   ├── proto_similarity.py
│   ├── object_evaluation.py
│   ├── sup_overcluster.py
│   ├── visualize_segment.py
│   └── eval_utils.py
├── train.py
└── utils.py
/images/2007_000464.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yliu1229/AlignSeg/HEAD/images/2007_000464.jpg
--------------------------------------------------------------------------------
/model/__pycache__/vit.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yliu1229/AlignSeg/HEAD/model/__pycache__/vit.cpython-310.pyc
--------------------------------------------------------------------------------
/data/__pycache__/coco_data.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yliu1229/AlignSeg/HEAD/data/__pycache__/coco_data.cpython-310.pyc
--------------------------------------------------------------------------------
/data/__pycache__/voc_data.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yliu1229/AlignSeg/HEAD/data/__pycache__/voc_data.cpython-310.pyc
--------------------------------------------------------------------------------
/model/__pycache__/align_model.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yliu1229/AlignSeg/HEAD/model/__pycache__/align_model.cpython-310.pyc
--------------------------------------------------------------------------------
/model/__pycache__/criterion.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yliu1229/AlignSeg/HEAD/model/__pycache__/criterion.cpython-310.pyc
--------------------------------------------------------------------------------
/model/__pycache__/transforms.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yliu1229/AlignSeg/HEAD/model/__pycache__/transforms.cpython-310.pyc
--------------------------------------------------------------------------------
/images/End-to-end self-supervised semantic segmentation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yliu1229/AlignSeg/HEAD/images/End-to-end self-supervised semantic segmentation.png
--------------------------------------------------------------------------------
/configs/eval_voc_config.yml:
--------------------------------------------------------------------------------
1 | num_workers: 1
2 |
3 | data:
4 | data_dir: ''
5 | dataset_name: "voc" # coco, imagenet100k, imagenet or voc
6 | num_classes: 21
7 | size_crops: 448
8 |
9 | val:
10 | arch: 'vit_small'
11 | batch_size: 1
12 | seed: 3407
13 | patch_size: 16
14 | embed_dim: 384
15 | hidden_dim: 384
16 | num_decode_layers: 1
17 | decoder_num_heads: 3
18 | num_queries: 5 # effective queries for mask generation, always ends with an 'Others' query
19 | last_self_attention: False # whether to use the attention map as a foreground hint
20 | mask_eval_size: 100
21 | checkpoint: './epoch10.pth'
22 |
--------------------------------------------------------------------------------
/configs/eval_coco27_config.yml:
--------------------------------------------------------------------------------
1 | num_workers: 1
2 |
3 | data:
4 | data_dir: ''
5 | dataset_name: "coco-all" # coco-all, coco-stuff, coco-thing or voc
6 | num_classes: 27
7 | size_crops: 448
8 |
9 | val:
10 | arch: 'vit_small'
11 | batch_size: 1
12 | seed: 3407
13 | patch_size: 16
14 | embed_dim: 384
15 | hidden_dim: 768
16 | num_decode_layers: 3
17 | decoder_num_heads: 3
18 | num_queries: 5 # effective queries for mask generation, always ends with an 'Others' query
19 | last_self_attention: False # whether to use the attention map as a foreground hint
20 | mask_eval_size: 100
21 | checkpoint: './epoch10.pth'
22 |
--------------------------------------------------------------------------------
/configs/eval_movi_config.yml:
--------------------------------------------------------------------------------
1 | num_workers: 1
2 |
3 | data:
4 | data_dir: "./Data/"
5 | dataset_name: "movi_e" # "movi_e" or "movi_c"
6 | num_classes: 17
7 | size_crops: 256
8 |
9 | val:
10 | arch: 'vit_small'
11 | batch_size: 1
12 | seed: 3407
13 | patch_size: 16
14 | embed_dim: 384
15 | hidden_dim: 768
16 | num_decode_layers: 6
17 | decoder_num_heads: 4
18 | num_queries: 18 # effective queries for mask generation, always ends with an 'Others' query
19 | last_self_attention: False # whether to use the attention map as a foreground hint
20 | mask_eval_size: 256
21 | checkpoint: './log_tmp/movi_e-vit_small-bs32/model/epoch10.pth'
22 |
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 yliu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/configs/train_coco_config.yml:
--------------------------------------------------------------------------------
1 | num_workers: 4
2 |
3 | data:
4 | data_dir: ""
5 | dataset_name: "coco" # coco, or voc
6 | size_crops: [224, 224]
7 | augment_image: False
8 | jitter_strength: 1.0
9 | blur_strength: 1.0
10 | min_scale_crops: [0.5, 0.05]
11 | max_scale_crops: [1., 0.25]
12 | min_intersection_crops: 0.05
13 | nmb_crops: [1, 2]
14 | size_crops_val: 448
15 |
16 | train:
17 | batch_size: 32 # effective batch size is bs * gpus * res_w ** 2
18 | max_epochs: 100
19 | seed: 3407
20 | fix_vit: True
21 | exclude_norm_bias: True
22 | roi_align_kernel_size: 7
23 | arch: 'vit_small'
24 | patch_size: 16
25 | embed_dim: 384
26 | hidden_dim: 768
27 | num_decode_layers: 1
28 | decoder_num_heads: 3
29 | num_queries: 5 # effective queries for mask generation, always ends with an 'Others' query
30 | last_self_attention: True # whether to use the attention map as a foreground hint
31 | ce_temperature: 1
32 | lr_decoder: 0.0005
33 | final_lr: 0.
34 | weight_decay: 0.04
35 | weight_decay_end: 0.5
36 | negative_pressure: 0.10
37 | epsilon: 0.05
38 | save_checkpoint_every_n_epochs: 1
39 | checkpoint: ''
40 | pretrained_model: './dino_vitsmall16.pth'
41 | fix_prototypes: False
42 |
--------------------------------------------------------------------------------
/configs/train_movi_config.yml:
--------------------------------------------------------------------------------
1 | num_workers: 4
2 |
3 | data:
4 | data_dir: "./Data/"
5 | dataset_name: "movi_e" # "movi_c" or "movi_e"
6 | size_crops: [224, 224]
7 | augment_image: False
8 | jitter_strength: 1.0
9 | blur_strength: 1.0
10 | min_scale_crops: [0.5, 0.05]
11 | max_scale_crops: [1., 0.25]
12 | min_intersection_crops: 0.05
13 | nmb_crops: [1, 2]
14 | size_crops_val: 256 # Crop size for validation and segmentation-map visualization
15 | num_classes_val: 17
16 |
17 | train:
18 | batch_size: 32 # effective batch size is bs * gpus * res_w ** 2
19 | max_epochs: 10
20 | seed: 3407
21 | fix_vit: True
22 | exclude_norm_bias: True
23 | roi_align_kernel_size: 7
24 | arch: 'vit_small'
25 | patch_size: 16
26 | embed_dim: 384
27 | hidden_dim: 768
28 | num_decode_layers: 6
29 | decoder_num_heads: 4
30 | num_queries: 18 # effective queries for mask generation, always ends with an 'Others' query
31 | last_self_attention: True # whether to use the attention map as a foreground hint
32 | ce_temperature: 1
33 | lr_decoder: 0.0005
34 | final_lr: 0.
35 | weight_decay: 0.04
36 | weight_decay_end: 0.5
37 | negative_pressure: 0.13 # 0.13 for both MOVi-C and MOVi-E
38 | corr_coefficient: 0.15
39 | epsilon: 0.05
40 | save_checkpoint_every_n_epochs: 1
41 | checkpoint:
42 | pretrained_model: 'dino_vitsmall16.pth'
43 | prototype_queries:
44 | fix_prototypes:
45 |
--------------------------------------------------------------------------------
/configs/train_voc_config.yml:
--------------------------------------------------------------------------------
1 | num_workers: 4
2 |
3 | data:
4 | data_dir: ""
5 | dataset_name: "voc" # coco, imagenet100k, imagenet or voc
6 | size_crops: [224, 224]
7 | augment_image: False
8 | jitter_strength: 1.0
9 | blur_strength: 1.0
10 | min_scale_crops: [0.5, 0.05]
11 | max_scale_crops: [1., 0.25]
12 | min_intersection_crops: 0.05
13 | nmb_crops: [1, 2]
14 | size_crops_val: 448 # Crop size for validation and segmentation-map visualization
15 | num_classes_val: 21
16 | voc_data_path: ""
17 |
18 | train:
19 | batch_size: 32 # effective batch size is bs * gpus * res_w ** 2
20 | max_epochs: 10
21 | seed: 3407
22 | fix_vit: True
23 | exclude_norm_bias: True
24 | roi_align_kernel_size: 7
25 | arch: 'vit_small'
26 | patch_size: 16
27 | embed_dim: 384
28 | hidden_dim: 768
29 | num_decode_layers: 1
30 | decoder_num_heads: 3
31 | num_queries: 5 # effective queries for mask generation, always ends with an 'Others' query
32 | last_self_attention: True # whether to use the attention map as a foreground hint
33 | ce_temperature: 1
34 | lr_decoder: 0.0005
35 | final_lr: 0.
36 | weight_decay: 0.04
37 | weight_decay_end: 0.5
38 | negative_pressure: 0.11
39 | corr_coefficient: 0.15
40 | epsilon: 0.05
41 | save_checkpoint_every_n_epochs: 1
42 | checkpoint: #''
43 | pretrained_model: './dino_vitsmall16.pth'
44 | prototype_queries:
45 | fix_prototypes: False
46 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## [AlignSeg] Rethinking Self-Supervised Semantic Segmentation: Achieving End-to-End Segmentation
2 |
3 | This is the PyTorch implementation of AlignSeg.
4 |
5 |
6 |
7 | ![End-to-end self-supervised semantic segmentation](images/End-to-end%20self-supervised%20semantic%20segmentation.png)
8 |
9 |
10 | ### Dataset Setup
11 |
12 | Please download the data and organize it as detailed in the following subsections.
13 |
14 | #### Pascal VOC
15 | Here's a [zipped version](https://www.dropbox.com/s/6gd4x0i9ewasymb/voc_data.zip?dl=0) for convenience.
16 |
17 | The structure for training and evaluation should be as follows:
18 | ```
19 | dataset root.
20 | └───SegmentationClass
21 | │ │ *.png
22 | │ │ ...
23 | └───SegmentationClassAug # contains segmentation masks from trainaug extension
24 | │ │ *.png
25 | │ │ ...
26 | └───images
27 | │ │ *.jpg
28 | │ │ ...
29 | └───sets
30 | │ │ train.txt
31 | │ │ trainaug.txt
32 | │ │ val.txt
33 | ```
34 |
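Note that `data/voc_data.py` resolves the dataset root as `data_dir/PVOC`, so the tree above is expected to live in a folder named `PVOC`. Below is a minimal sketch of loading the validation split; the path is a placeholder and the transforms mirror the ones used in `evaluate/sup_overcluster.py`:

```python
# Minimal sketch of consuming the PVOC layout above via data/voc_data.py.
# '/path/to/datasets' is a placeholder; VOCDataModule joins it with "PVOC",
# so the dataset root shown above should be located at /path/to/datasets/PVOC.
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
from data.voc_data import VOCDataModule

img_tf = T.Compose([T.Resize((448, 448)), T.ToTensor(),
                    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
mask_tf = T.Compose([T.Resize((448, 448), interpolation=InterpolationMode.NEAREST),
                     T.ToTensor()])

voc = VOCDataModule(data_dir='/path/to/datasets', train_split='trainaug', val_split='val',
                    train_image_transform=None, val_image_transform=img_tf,
                    val_target_transform=mask_tf, batch_size=1, num_workers=1)
val_loader = voc.val_dataloader()  # yields (image, mask) pairs
```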
35 | #### COCO-Stuff-27
36 | The structure for training and evaluation should be as follows:
37 | ```
38 | dataset root.
39 | └───annotations
40 | │ └─── annotations
41 | │ └─── stuffthingmaps_trainval2017
42 | │ │ stuffthing_2017.json
43 | │ └─── train2017
44 | │ │ *.png
45 | │ │ ...
46 | │ └─── val2017
47 | │ │ *.png
48 | │ │ ...
49 | └───coco
50 | │ └─── images
51 | │ └─── train2017
52 | │ │ *.jpg
53 | │ │ ...
54 | │ └─── val2017
55 | │ │ *.jpg
56 | │ │ ...
57 | ```
58 | The “curated” split introduced by IIC can be downloaded [here](https://www.robots.ox.ac.uk/~xuji/datasets/COCOStuff164kCurated.tar.gz).
59 |
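For reference, `evaluate/sup_overcluster.py` builds the COCO data module by listing `images/train2017` and `images/val2017` directly under `data_dir`, i.e. `data_dir` should point at the `coco` folder above. A hedged sketch (the path is a placeholder; `data/coco_data.py` handles annotation lookup internally and may impose additional path requirements):

```python
# Sketch of the CocoDataModule construction used in evaluate/sup_overcluster.py.
# '/path/to/coco' is a placeholder and must contain images/train2017 and images/val2017.
import os
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
from data.coco_data import CocoDataModule

data_dir = '/path/to/coco'
img_tf = T.Compose([T.Resize((448, 448)), T.ToTensor(),
                    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
mask_tf = T.Compose([T.Resize((448, 448), interpolation=InterpolationMode.NEAREST),
                     T.ToTensor()])

coco = CocoDataModule(batch_size=1,
                      num_workers=1,
                      file_list=os.listdir(os.path.join(data_dir, 'images', 'train2017')),
                      data_dir=data_dir,
                      file_list_val=os.listdir(os.path.join(data_dir, 'images', 'val2017')),
                      mask_type='all',      # "all" -> 27 coarse COCO-Stuff classes
                      train_transforms=None,
                      val_transforms=img_tf,
                      val_target_transforms=mask_tf)
val_loader = coco.val_dataloader()
```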
60 | ### Self-supervised Training with Frozen ViT
61 |
62 | We provide the training configuration files for PVOC and COCO-Stuff in the `/configs` folder; fill in your own paths to the dataset and the pre-trained ViT.
63 |
64 | Since the image encoder is frozen during training, the self-supervised training is quite efficient and can be run on a single GPU.
65 | To start training on PVOC, you can run the following example command:
66 | ```
67 | python train.py --config_path ./configs/train_voc_config.yml
68 | ```
69 |
70 | The pre-trained ViT by DINO can be found [here](https://github.com/facebookresearch/dino).
71 |
72 | ### End-to-End Semantic Segmentation Inference
73 |
74 | AlignSeg can perform real-time and end-to-end segmentation inference.
75 |
76 | To perform segmentation inference and visualization, you can run the following example command:
77 | ```
78 | python evaluate/visualize_segment.py --pretrained_weights {model.pth} --image_path ./images/2007_000464.jpg
79 | ```
80 | Replace `{model.pth}` with the path to your pre-trained model.
81 |
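For reference, here is a minimal single-image inference sketch along the lines of `evaluate/sup_overcluster.py`; the hyper-parameters mirror `configs/eval_voc_config.yml`, and `model.pth` is a placeholder for your checkpoint:

```python
# Minimal inference sketch mirroring evaluate/sup_overcluster.py.
# 'model.pth' is a placeholder; hyper-parameters follow configs/eval_voc_config.yml.
import torch
import torch.nn.functional as F
import torchvision.transforms as T
from PIL import Image
from model.align_model import AlignSegmentor

model = AlignSegmentor(arch='vit_small', patch_size=16, embed_dim=384, hidden_dim=384,
                       num_heads=3, num_queries=5, num_decode_layers=1,
                       last_self_attention=False)
model.load_state_dict(torch.load('model.pth', map_location='cpu')['state_dict'], strict=False)
model.eval()

tf = T.Compose([T.Resize((448, 448)), T.ToTensor(),
                T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
img = tf(Image.open('images/2007_000464.jpg').convert('RGB')).unsqueeze(0)

with torch.no_grad():
    all_queries, tokens, _, _, res, _ = model([img])        # tokens: (1, N, dim)
    sim = torch.einsum('bnc,bqc->bnq', F.normalize(tokens, dim=-1),
                       F.normalize(all_queries[0], dim=-1))
    seg = sim.softmax(dim=-1).argmax(dim=-1).reshape(1, res, res)  # per-patch query index
# Upsample `seg` to the input resolution (cast to float first) for visualization.
```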
82 | ### Evaluation
83 |
84 | We provide the evaluation configuration files for PVOC and COCO-Stuff-27 in the `/configs` folder; fill in your own paths to the dataset and the pre-trained model.
85 |
86 | To evaluate the pre-trained model on PVOC, you can run the following example command:
87 | ```
88 | python evaluate/sup_overcluster.py --config_path ../configs/eval_voc_config.yml
89 | ```
90 |
91 | ### Pre-trained Models
92 |
93 | We provide our pre-trained models; they can be downloaded via the links below.
94 |
95 |
96 |
97 | | Encoder  | Dataset       | mIoU | Download |
98 | |----------|---------------|------|----------|
99 | | ViT-S/16 | PVOC          | 69.5 | model    |
100 | | ViT-S/16 | COCO-Stuff-27 | 35.1 | model    |
101 |
102 | | Encoder  | Dataset | FG-ARI | mBO  | Download |
103 | |----------|---------|--------|------|----------|
104 | | ViT-S/16 | MOVi-C  | 48.0   | 31.2 | model    |
105 | | ViT-S/16 | MOVi-E  | 44.1   | 20.4 | model    |
106 |
--------------------------------------------------------------------------------
/evaluate/cocoStuff27_mask_visualize.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import torch
4 | import numpy as np
5 |
6 | from PIL import Image
7 | import matplotlib.pyplot as plt
8 | import torchvision.transforms as T
9 | from torchvision.transforms.functional import InterpolationMode
10 |
11 |
12 | label_to_color = {
13 | 0: (245, 245, 220), # accessory
14 | 1: (0, 100, 0), # animal
15 | 2: (178, 34, 34), # appliance
16 | 3: (0, 0, 139), # building
17 | 4: (148, 0, 211), # ceiling
18 | 5: (105, 105, 105), # electronic
19 | 6: (205, 92, 92), # floor
20 | 7: (244, 164, 96), # food
21 | 8: (245, 222, 179), # food-stuff
22 | 9: (75, 0, 130), # furniture
23 | 10: (138, 43, 226), # furniture-stuff
24 | 11: (72, 61, 139), # ground
25 | 12: (25, 25, 112), # indoor
26 | 13: (253, 245, 230), # kitchen
27 | 14: (47, 79, 79), # outdoor
28 | 15: (139, 0, 0), # person
29 | 16: (124, 252, 0), # plant
30 | 17: (210, 180, 140), # raw-material
31 | 18: (135, 206, 235), # sky
32 | 19: (85, 107, 47), # solid
33 | 20: (255, 105, 180), # sports
34 | 21: (210, 105, 30), # structural
35 | 22: (211, 211, 211), # textile
36 | 23: (184, 134, 11), # vehicle
37 | 24: (128, 128, 128), # wall
38 | 25: (32, 178, 170), # water
39 | 26: (189, 183, 107), # window
40 | 27: (255, 250, 250), # other
41 | }
42 |
43 | super_cat_to_id = {
44 | 'accessory': 0, 'animal': 1, 'appliance': 2, 'building': 3,
45 | 'ceiling': 4, 'electronic': 5, 'floor': 6, 'food': 7,
46 | 'food-stuff': 8, 'furniture': 9, 'furniture-stuff': 10, 'ground': 11,
47 | 'indoor': 12, 'kitchen': 13, 'outdoor': 14, 'person': 15,
48 | 'plant': 16, 'raw-material': 17, 'sky': 18, 'solid': 19,
49 | 'sports': 20, 'structural': 21, 'textile': 22, 'vehicle': 23,
50 | 'wall': 24, 'water': 25, 'window': 26,
51 | 'other': 27
52 | }
53 |
54 |
55 | def visual_mask(img_path, transforms, cat_id_map, RGB=False):
56 |
57 | mask = Image.open(img_path)
58 | mask = transforms(mask)
59 | save_path = img_path.replace('.png', '_RGB.jpg')
60 |
61 | # move 'id' labels from [0, 182] to [0,27] with 27=={182,255}
62 | # (182 is 'other' and 0 is things)
63 | mask *= 255
64 | assert torch.min(mask).item() >= 0
65 | mask[mask == 255] = 182
66 | assert torch.max(mask).item() <= 182
67 | for cat_id in torch.unique(mask):
68 | mask[mask == cat_id] = cat_id_map[cat_id.item()]
69 |
70 | assert torch.max(mask).item() <= 27
71 | assert torch.min(mask).item() >= 0
72 |
73 | mask = mask.squeeze(0).numpy().astype(int)
74 | mask = mask.astype(np.uint8)
75 |
76 | if not RGB:
77 | img = Image.fromarray(mask, 'L')
78 | plt.figure(figsize=(8, 8))
79 | plt.axis('off')
80 | plt.imshow(img)
81 | # plt.savefig('mask.png', bbox_inches='tight', pad_inches=0.0)
82 | plt.tight_layout(pad=0.0, h_pad=0.0, w_pad=0.0)
83 | plt.show()
84 | else:
85 | # visualize by configuring palette
86 | img = Image.fromarray(mask, 'L')
87 | img_p = img.convert('P')
88 | img_p.putpalette([rgb for pixel in label_to_color.values() for rgb in pixel])
89 |
90 | img_rgb = img_p.convert('RGB')
91 | plt.figure(figsize=(12, 4))
92 | plt.subplot(1, 3, 1), plt.imshow(img)
93 | plt.subplot(1, 3, 2), plt.imshow(img_p)
94 | plt.subplot(1, 3, 3), plt.imshow(img_rgb)
95 | plt.tight_layout(), plt.show()
96 | img_rgb.save(save_path)
97 |
98 |
99 | if __name__ == '__main__':
100 |
101 | root = './COCO/annotations/stuffthingmaps_trainval2017/'
102 | json_file = "stuffthing_2017.json"
103 | mask_name = 'val2017/000000512194.png'
104 |
105 | mask_transforms = T.Compose([T.Resize((448, 448), interpolation=InterpolationMode.NEAREST), T.ToTensor()])
106 |
107 | with open(os.path.join(root, json_file)) as f:
108 | an_json = json.load(f)
109 | all_cat = an_json['categories']
110 |
111 | super_cats = set([cat_dict['supercategory'] for cat_dict in all_cat])
112 | super_cats.remove("other") # remove others from prediction targets as this is not semantic
113 | super_cat_to_id = {super_cat: i for i, super_cat in enumerate(sorted(super_cats))}
114 | super_cat_to_id["other"] = 27 # ignore_index
115 | # Align 'id' labels: PNG_label = GT_label - 1
116 | cat_id_map = {(cat_dict['id'] - 1): super_cat_to_id[cat_dict['supercategory']] for cat_dict in all_cat}
117 |
118 | visual_mask(os.path.join(root, mask_name), mask_transforms, cat_id_map, RGB=False)
119 |
--------------------------------------------------------------------------------
/evaluate/proto_similarity.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import argparse
4 | import cv2
5 | import random
6 | import colorsys
7 |
8 | import skimage.io
9 | from skimage.measure import find_contours
10 | from matplotlib.patches import Polygon
11 | import torch
12 | import torch.nn as nn
13 | import torchvision
14 | from torchvision import transforms as pth_transforms
15 | from torchvision.transforms import GaussianBlur
16 | import torch.nn.functional as F
17 | import numpy as np
18 | from PIL import Image
19 | from skimage.measure import label
20 | from matplotlib import pyplot as plt
21 |
22 | from model.align_model import AlignSegmentor
23 | from utils import neq_load_external
24 |
25 |
26 | def norm(t):
27 | return F.normalize(t, dim=-1, eps=1e-10)
28 |
29 |
30 | def process_attentions(attentions: torch.Tensor, spatial_res: int, threshold: float = 0.6, blur_sigma: float = 0.6) \
31 | -> torch.Tensor:
32 | """
33 | Process [0,1] attentions to binary 0-1 mask. Applies a Gaussian filter, keeps threshold % of mass and removes
34 | components smaller than 3 pixels.
35 | The code is adapted from https://github.com/facebookresearch/dino/blob/main/visualize_attention.py but removes the
36 | need for using ground-truth data to find the best-performing head. Instead we simply average all heads' attentions
37 | so that we can use the foreground mask during training time.
38 | :param attentions: torch 4D-Tensor containing the averaged attentions
39 | :param spatial_res: spatial resolution of the attention map
40 | :param threshold: the percentage of mass to keep as foreground.
41 | :param blur_sigma: standard deviation to be used for creating kernel to perform blurring.
42 | :return: the foreground mask obtained from the ViT's attention.
43 | """
44 | # Blur attentions
45 | attentions = GaussianBlur(7, sigma=(blur_sigma))(attentions)
46 | attentions = attentions.reshape(attentions.size(0), 1, spatial_res ** 2)
47 | # Keep threshold% of mass
48 | val, idx = torch.sort(attentions)
49 | val /= torch.sum(val, dim=-1, keepdim=True)
50 | cumval = torch.cumsum(val, dim=-1)
51 | th_attn = cumval > (1 - threshold)
52 | idx2 = torch.argsort(idx)
53 | th_attn[:, 0] = torch.gather(th_attn[:, 0], dim=1, index=idx2[:, 0])
54 | th_attn = th_attn.reshape(attentions.size(0), 1, spatial_res, spatial_res).float()
55 | # Remove components with less than 3 pixels
56 | for j, th_att in enumerate(th_attn):
57 | labelled = label(th_att.cpu().numpy())
58 | for k in range(1, np.max(labelled) + 1):
59 | mask = labelled == k
60 | if np.sum(mask) <= 2:
61 | th_attn[j, 0][mask] = 0
62 | return th_attn
63 |
64 |
65 | if __name__ == '__main__':
66 | parser = argparse.ArgumentParser('Evaluate segmentation on pretrained model')
67 | parser.add_argument('--pretrained_weights', default='./epoch10.pth',
68 | type=str, help="Path to pretrained weights to load.")
69 | parser.add_argument("--image_path", default='',
70 | type=str, help="Path of the image to load.")
71 | parser.add_argument('--patch_size', default=16, type=int, help='Patch resolution of the model.')
72 | parser.add_argument("--image_size", default=(480, 480), type=int, nargs="+", help="Resize image.")
73 | parser.add_argument('--output_dir', default='./outputs/', help='Path where to save visualizations.')
74 | parser.add_argument("--threshold", type=float, default=0.6, help="""We visualize masks
75 | obtained by thresholding the self-attention maps to keep xx% of the mass.""")
76 | args = parser.parse_args()
77 |
78 | device = torch.device("cpu")
79 | # build model
80 | model = AlignSegmentor(arch='vit_small',
81 | patch_size=16,
82 | embed_dim=384,
83 | hidden_dim=384,
84 | num_heads=2,
85 | num_queries=5,
86 | nmb_crops=[1, 0],
87 | num_decode_layers=1,
88 | last_self_attention=True)
89 |
90 | # set model to eval mode
91 | for p in model.parameters():
92 | p.requires_grad = False
93 | model.eval()
94 | model.to(device)
95 |
96 | # load pretrained weights
97 | if os.path.isfile(args.pretrained_weights):
98 | pretrained_model = torch.load(args.pretrained_weights, map_location="cpu")
99 | msg = model.load_state_dict(pretrained_model['state_dict'], strict=False)
100 | print(msg)
101 | else:
102 | print('no pretrained pth found!')
103 |
104 | queries = model.clsQueries.weight
105 |
106 | prototypes = torch.load('../log_tmp/prototypes21.pth')
107 | prototypes = prototypes.to(device)
108 | # calculate query assignment score
109 | sim_query_proto = norm(queries) @ norm(prototypes).T
110 | sim_query_proto = sim_query_proto.clamp(min=0.0)
111 |
112 | for i in range(sim_query_proto.size(0)):
113 | print('Proto', i, '=', sim_query_proto[i]*10)
114 |
--------------------------------------------------------------------------------
/data/movi_data.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from typing import Optional, Callable
4 | from PIL import Image
5 | from pathlib import Path
6 | from torch.utils.data import DataLoader
7 | from torchvision.datasets import VisionDataset
8 | from typing import Tuple, Any
9 |
10 |
11 | class MOViDataModule:
12 |
13 | def __init__(self,
14 | data_dir: str,
15 | dataset_name: str,
16 | train_split: str,
17 | val_split: str,
18 | train_image_transform: Optional[Callable],
19 | val_image_transform: Optional[Callable],
20 | val_target_transform: Optional[Callable],
21 | batch_size: int,
22 | num_workers: int,
23 | shuffle: bool = True,
24 | return_masks: bool = False,
25 | drop_last: bool = True):
26 | """
27 | Data module for MOVi data.
28 | If return_masks is set, train_image_transform should be a callable that takes both image and mask, or None.
29 | """
30 | super().__init__()
31 | self.root = os.path.join(data_dir, dataset_name)
32 | self.train_split = train_split
33 | self.val_split = val_split
34 | self.batch_size = batch_size
35 | self.num_workers = num_workers
36 | self.train_image_transform = train_image_transform
37 | self.val_image_transform = val_image_transform
38 | self.val_target_transform = val_target_transform
39 | self.shuffle = shuffle
40 | self.drop_last = drop_last
41 | self.return_masks = return_masks
42 |
43 | # Set up datasets in __init__ as we need to know the number of samples to init cosine lr schedules
44 | self.movi_train = MOViDataset(root=self.root, image_set=train_split, transforms=self.train_image_transform,
45 | return_masks=self.return_masks)
46 | self.movi_val = MOViDataset(root=self.root, image_set=val_split, transform=self.val_image_transform,
47 | target_transform=self.val_target_transform)
48 | print('--- Loaded ' + dataset_name + ' with Train %d, Val %d ---' % (len(self.movi_train), len(self.movi_val)))
49 |
50 | def __len__(self):
51 | return len(self.movi_train)
52 |
53 | def train_dataloader(self):
54 | return DataLoader(self.movi_train, batch_size=self.batch_size,
55 | shuffle=self.shuffle, num_workers=self.num_workers,
56 | drop_last=self.drop_last, pin_memory=True)
57 |
58 | def val_dataloader(self):
59 | return DataLoader(self.movi_val, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers,
60 | drop_last=self.drop_last, pin_memory=True)
61 |
62 |
63 | class MOViDataset(VisionDataset):
64 |
65 | def __init__(
66 | self,
67 | root: str,
68 | image_set: str = "frames",
69 | transform: Optional[Callable] = None,
70 | target_transform: Optional[Callable] = None,
71 | transforms: Optional[Callable] = None,
72 | return_masks: bool = False
73 | ):
74 | super(MOViDataset, self).__init__(root, transforms, transform, target_transform)
75 | self.image_set = image_set
76 | if self.image_set == "frames": # set for training
77 | img_folder = "frames"
78 | elif self.image_set == "images": # set for validation
79 | img_folder = "images"
80 | else:
81 | raise ValueError(f"No support for image set {self.image_set}")
82 | image_dir = os.path.join(root, img_folder)
83 | seg_dir = os.path.join(root, 'masks')
84 | if not os.path.isdir(seg_dir) or not os.path.isdir(image_dir) or not os.path.isdir(root):
85 | raise RuntimeError('Dataset not found or corrupted.')
86 |
87 | self.images = [os.path.join(image_dir, x) for x in os.listdir(image_dir)]
88 | self.masks = [os.path.join(seg_dir, x) for x in os.listdir(seg_dir)]
89 | self.return_masks = return_masks
90 |
91 | assert all([Path(f).is_file() for f in self.masks]) and all([Path(f).is_file() for f in self.images])
92 |
93 | def __getitem__(self, index: int) -> Tuple[Any, Any]:
94 | img = Image.open(self.images[index]).convert('RGB')
95 | if self.image_set == "images": # for validation
96 | # print('img = ', self.images[index])
97 | mask = Image.open(self.masks[index])
98 | print("image: ", self.images[index])
99 | if self.transforms:
100 | img, mask = self.transforms(img, mask)
101 | return img, mask
102 | elif self.image_set == "frames": # for training
103 | if self.transforms:
104 | if self.return_masks:
105 | mask = Image.open(self.masks[index])
106 | res = self.transforms(img, mask)
107 | else:
108 | res = self.transforms(img)
109 | return res
110 | return img
111 |
112 | def __len__(self) -> int:
113 | return len(self.images)
114 |
--------------------------------------------------------------------------------
/evaluate/object_evaluation.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import os
4 | import sys
5 |
6 | import torch
7 | import torch.nn.functional as F
8 | import torchvision.transforms as T
9 |
10 | import yaml
11 | from torchvision.transforms.functional import InterpolationMode
12 |
13 | from data.movi_data import MOViDataModule
14 | from model.align_model import AlignSegmentor
15 | from eval_utils import ARIMetric, AverageBestOverlapMetric
16 |
17 |
18 | def norm(t):
19 | return F.normalize(t, dim=-1, eps=1e-10)
20 |
21 |
22 | def eval_objectmasks():
23 | with open(args.config_path) as file:
24 | config = yaml.safe_load(file.read())
25 | # print('Config: ', config)
26 |
27 | data_config = config['data']
28 | val_config = config['val']
29 | input_size = data_config["size_crops"]
30 | torch.manual_seed(val_config['seed'])
31 |
32 | # Init data and transforms
33 | val_image_transforms = T.Compose([T.Resize((input_size, input_size)),
34 | T.ToTensor(),
35 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
36 | val_target_transforms = T.Compose([T.Resize((input_size, input_size), interpolation=InterpolationMode.NEAREST),
37 | T.ToTensor()])
38 |
39 | data_dir = data_config["data_dir"]
40 | dataset_name = data_config["dataset_name"]
41 | if "movi" in dataset_name:
42 | ignore_index = 0
43 | num_classes = 17
44 | data_module = MOViDataModule(data_dir=data_dir,
45 | dataset_name=data_config['dataset_name'],
46 | batch_size=val_config["batch_size"],
47 | return_masks=True,
48 | drop_last=True,
49 | num_workers=config["num_workers"],
50 | train_split="frames",
51 | val_split="images",
52 | train_image_transform=None,
53 | val_image_transform=val_image_transforms,
54 | val_target_transform=val_target_transforms)
55 | else:
56 | raise ValueError(f"{dataset_name} not supported")
57 |
58 | # Init method
59 | patch_size = val_config["patch_size"]
60 | spatial_res = input_size / patch_size
61 | num_proto = val_config['num_queries']
62 | assert spatial_res.is_integer()
63 | model = AlignSegmentor(arch=val_config['arch'],
64 | patch_size=val_config['patch_size'],
65 | embed_dim=val_config['embed_dim'],
66 | hidden_dim=val_config['hidden_dim'],
67 | num_heads=val_config['decoder_num_heads'],
68 | num_queries=val_config['num_queries'],
69 | num_decode_layers=val_config['num_decode_layers'],
70 | last_self_attention=val_config['last_self_attention'])
71 |
72 | # set model to eval mode
73 | for p in model.parameters():
74 | p.requires_grad = False
75 | model.eval()
76 |
77 | # load pretrained weights
78 | if val_config["checkpoint"] is not None:
79 | checkpoint = torch.load(val_config["checkpoint"])
80 | msg = model.load_state_dict(checkpoint["state_dict"], strict=True)
81 | print(msg)
82 | else:
83 | print('no pretrained pth found!')
84 |
85 | dataloader = data_module.val_dataloader()
86 | ARI_metric = ARIMetric()
87 | BO_metric = AverageBestOverlapMetric()
88 |
89 | # Calculate IoU for each image individually
90 | for idx, batch in enumerate(dataloader):
91 | imgs, masks = batch
92 | B = imgs.size(0)
93 | # assert B == 1 # image has to be evaluated individually
94 | all_queries, tokens, _, _, res, _ = model([imgs]) # tokens=(1,N,dim)
95 |
96 | # calculate token assignment
97 | token_cls = torch.einsum("bnc,bqc->bnq", norm(tokens), norm(all_queries[0]))
98 | token_cls = torch.softmax(token_cls, dim=-1)
99 | token_cls = token_cls.reshape(B, res, res, -1).permute(0, 3, 1, 2) # (1,num_query,res,res)
100 |
101 | # downsample masks / upsample preds to masks_eval_size
102 | preds = F.interpolate(token_cls, size=(val_config['mask_eval_size'], val_config['mask_eval_size']),
103 | mode='bilinear')
104 | masks *= 255
105 | if masks.size(3) != val_config['mask_eval_size']:
106 | masks = F.interpolate(masks, size=(val_config['mask_eval_size'], val_config['mask_eval_size']),
107 | mode='nearest')
108 |
109 | # turn masks to one-hot
110 | masks = masks.squeeze(dim=1).reshape(B, -1)
111 | masks = masks.long()
112 | num_classes = masks.max().item() + 1
113 | masks = torch.nn.functional.one_hot(masks, num_classes)
114 | masks = masks.permute(0, 2, 1).reshape(B, num_classes,
115 | val_config['mask_eval_size'],
116 | val_config['mask_eval_size']) # to (B, K, H, W)
117 |
118 | ARI_metric.update(preds, masks)
119 | BO_metric.update(preds, masks)
120 | # sys.exit(1)
121 |
122 | ARI_metric.compute()
123 | BO_metric.compute()
124 |
125 |
126 | if __name__ == "__main__":
127 | parser = argparse.ArgumentParser()
128 | parser.add_argument('--config_path', default='../configs/eval_movi_config.yml', type=str)
129 |
130 | args = parser.parse_args()
131 |
132 | eval_objectmasks()
133 |
--------------------------------------------------------------------------------
/data/voc_data.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from typing import Optional, Callable
4 | from PIL import Image
5 | from pathlib import Path
6 | from torch.utils.data import DataLoader
7 | from torchvision.datasets import VisionDataset
8 | from typing import Tuple, Any
9 |
10 |
11 | class VOCDataModule:
12 |
13 | CLASS_IDX_TO_NAME = ['background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair',
14 | 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa',
15 | 'train', 'tvmonitor']
16 |
17 | def __init__(self,
18 | data_dir: str,
19 | train_split: str,
20 | val_split: str,
21 | train_image_transform: Optional[Callable],
22 | val_image_transform: Optional[Callable],
23 | val_target_transform: Optional[Callable],
24 | batch_size: int,
25 | num_workers: int,
26 | shuffle: bool = True,
27 | return_masks: bool = False,
28 | drop_last: bool = True):
29 | """
30 | Data module for PVOC data. "trainaug" and "train" are valid train_splits.
31 | If return_masks is set, train_image_transform should be a callable that takes both image and mask, or None.
32 | """
33 | super().__init__()
34 | self.root = os.path.join(data_dir, "PVOC")
35 | self.train_split = train_split
36 | self.val_split = val_split
37 | self.batch_size = batch_size
38 | self.num_workers = num_workers
39 | self.train_image_transform = train_image_transform
40 | self.val_image_transform = val_image_transform
41 | self.val_target_transform = val_target_transform
42 | self.shuffle = shuffle
43 | self.drop_last = drop_last
44 | self.return_masks = return_masks
45 |
46 | # Set up datasets in __init__ as we need to know the number of samples to init cosine lr schedules
47 | assert train_split == "trainaug" or train_split == "train"
48 | self.voc_train = VOCDataset(root=self.root, image_set=train_split, transforms=self.train_image_transform,
49 | return_masks=self.return_masks)
50 | self.voc_val = VOCDataset(root=self.root, image_set=val_split, transform=self.val_image_transform,
51 | target_transform=self.val_target_transform)
52 | print('--- loaded VOC with Train %d, Val %d ---' % (len(self.voc_train), len(self.voc_val)))
53 |
54 | def __len__(self):
55 | return len(self.voc_train)
56 |
57 | def class_id_to_name(self, i: int):
58 | return self.CLASS_IDX_TO_NAME[i]
59 |
60 | def train_dataloader(self):
61 | return DataLoader(self.voc_train, batch_size=self.batch_size,
62 | shuffle=self.shuffle, num_workers=self.num_workers,
63 | drop_last=self.drop_last, pin_memory=True)
64 |
65 | def val_dataloader(self):
66 | return DataLoader(self.voc_val, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers,
67 | drop_last=self.drop_last, pin_memory=True)
68 |
69 |
70 | class VOCDataset(VisionDataset):
71 |
72 | def __init__(
73 | self,
74 | root: str,
75 | image_set: str = "trainaug",
76 | transform: Optional[Callable] = None,
77 | target_transform: Optional[Callable] = None,
78 | transforms: Optional[Callable] = None,
79 | return_masks: bool = False
80 | ):
81 | super(VOCDataset, self).__init__(root, transforms, transform, target_transform)
82 | self.image_set = image_set
83 | if self.image_set == "trainaug" or self.image_set == "train":
84 | seg_folder = "SegmentationClassAug"
85 | elif self.image_set == "val":
86 | seg_folder = "SegmentationClass"
87 | else:
88 | raise ValueError(f"No support for image set {self.image_set}")
89 | seg_dir = os.path.join(root, seg_folder)
90 | image_dir = os.path.join(root, 'images')
91 | if not os.path.isdir(seg_dir) or not os.path.isdir(image_dir) or not os.path.isdir(root):
92 | raise RuntimeError('Dataset not found or corrupted.')
93 | splits_dir = os.path.join(root, 'sets')
94 | split_f = os.path.join(splits_dir, self.image_set.rstrip('\n') + '.txt')
95 |
96 | with open(os.path.join(split_f), "r") as f:
97 | file_names = [x.strip() for x in f.readlines()]
98 |
99 | self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
100 | self.masks = [os.path.join(seg_dir, x + ".png") for x in file_names]
101 | self.return_masks = return_masks
102 |
103 | assert all([Path(f).is_file() for f in self.masks]) and all([Path(f).is_file() for f in self.images])
104 |
105 | def __getitem__(self, index: int) -> Tuple[Any, Any]:
106 | img = Image.open(self.images[index]).convert('RGB')
107 | if self.image_set == "val":
108 | # print('img = ', self.images[index])
109 | mask = Image.open(self.masks[index])
110 | if self.transforms:
111 | img, mask = self.transforms(img, mask)
112 | return img, mask
113 | elif "train" in self.image_set:
114 | if self.transforms:
115 | if self.return_masks:
116 | mask = Image.open(self.masks[index])
117 | res = self.transforms(img, mask)
118 | else:
119 | res = self.transforms(img)
120 | return res
121 | return img
122 |
123 | def __len__(self) -> int:
124 | return len(self.images)
125 |
--------------------------------------------------------------------------------
/model/criterion.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from skimage.measure import label
8 |
9 | from utils import calc_topk_accuracy
10 |
11 |
12 | def norm(t):
13 | return F.normalize(t, dim=-1, eps=1e-10)
14 |
15 |
16 | def query_ce_loss(gc_query, lc_query, num_queries, temperature=1):
17 | B = gc_query.size(0)
18 | N = 2 * num_queries
19 | criterion = nn.CrossEntropyLoss()
20 | mask = mask_correlated_samples(num_queries, gc_query.device)
21 |
22 | # calculate ce loss for each query set in B
23 | loss = top1_avg = 0
24 | labels = torch.zeros(N, device=gc_query.device).long()
25 | for i in range(B):
26 | z = torch.cat((gc_query[i], lc_query[i]), dim=0)
27 |
28 | sim = torch.matmul(z, z.T) / temperature
29 | sim_gc_lc = torch.diag(sim, num_queries)
30 | sim_lc_gc = torch.diag(sim, -num_queries)
31 |
32 | positive_samples = torch.cat((sim_gc_lc, sim_lc_gc), dim=0).reshape(N, 1)
33 | negative_samples = sim[mask].reshape(N, -1)
34 |
35 | logits = torch.cat((positive_samples, negative_samples), dim=1)
36 | ce_loss = criterion(logits, labels)
37 |
38 | top1 = calc_topk_accuracy(logits, labels, (1,))
39 |
40 | loss += ce_loss
41 | top1_avg += top1[0]
42 |
43 | return loss / B, top1_avg / B
44 |
45 |
46 | def mask_correlated_samples(num_seq, device):
47 | N = 2 * num_seq
48 | mask = torch.ones((N, N), device=device)
49 | mask = mask.fill_diagonal_(0)
50 | for i in range(num_seq):
51 | mask[i, num_seq + i] = 0
52 | mask[num_seq + i, i] = 0
53 | mask = mask.bool()
54 | return mask
55 |
56 |
57 | class AlignCriterion(nn.Module):
58 | def __init__(self, patch_size=16,
59 | num_queries=5,
60 | nmb_crops=(1, 1),
61 | roi_align_kernel_size=7,
62 | ce_temperature=1,
63 | negative_pressure=0.1,
64 | last_self_attention=True):
65 | super(AlignCriterion, self).__init__()
66 | self.patch_size = patch_size
67 | self.num_queries = num_queries
68 | self.nmb_crops = nmb_crops
69 | self.roi_align_kernel_size = roi_align_kernel_size
70 | self.ce_temperature = ce_temperature
71 | self.negative_pressure = negative_pressure
72 | self.last_self_attention = last_self_attention
73 |
74 | def forward(self, results, bboxes):
75 | all_queries, gc_output, lc_output, attn_hard, gc_spatial_res, lc_spatial_res = results
76 | B = gc_output.size(0)
77 |
78 | # prepare foreground mask
79 | mask = attn_hard.reshape(B*sum(self.nmb_crops), -1)
80 | mask = mask.int()
81 | mask_gc, masks_lc = mask[:B * self.nmb_crops[0]], mask[B * self.nmb_crops[0]:]
82 |
83 | loss = 0
84 | '''
85 | 1. Compute the correlation-alignment loss (align patch assignment similarity to patch feature correlation)
86 | -- Compute similarity between Queries and spatial_tokens, and align to patch correlation
87 | -- use attention map as foreground hint to mask correlation matrix
88 | -- assuming there is ONLY 1 global crop
89 | '''
90 | # compute patch correlation between gc and lc, use as assignment target later
91 | with torch.no_grad():
92 | gclc_correlations = []
93 | masks_gc_lc = []
94 | mask_gc = mask_gc.repeat(1, lc_spatial_res**2).reshape(B, lc_spatial_res**2, -1)
95 | mask_gc = mask_gc.transpose(1, 2) # (B,n,m)
96 | for i in range(self.nmb_crops[-1]):
97 | # compute cosine similarity
98 | correlation = torch.einsum("bnc,bmc->bnm", norm(gc_output), norm(lc_output[:, i]))
99 | # spatial centering for better recognizing small objects
100 | old_mean = correlation.mean()
101 | correlation -= correlation.mean(dim=-1, keepdim=True)
102 | correlation = correlation - correlation.mean() + old_mean
103 | gclc_correlations.append(correlation)
104 |
105 | # compute gc-lc foreground intersection mask
106 | mask_lc_ = masks_lc[i*B:(i+1)*B] # (B, m)
107 | mask_lc_ = mask_lc_.repeat(1, gc_spatial_res**2).reshape(B, gc_spatial_res**2, -1) # (B,n,m)
108 | mask_gc_lc_ = mask_gc * mask_lc_
109 | masks_gc_lc.append(mask_gc_lc_.bool())
110 |
111 | # compute gc and lc token assignment
112 | gc_token_assign = torch.einsum("bnc,bqc->bnq", norm(gc_output), norm(all_queries[0]))
113 |
114 | gclc_cor_loss = 0
115 | lc_assigns_detached = []
116 | for i in range(self.nmb_crops[-1]):
117 | lc_token_assign = torch.einsum("bmc,bqc->bmq", norm(lc_output[:, i]), norm(all_queries[i+1]))
118 | # store lc intersection assignment
119 | lc_tmp = torch.clone(lc_token_assign.detach())
120 | lc_tmp = lc_tmp.reshape(B, lc_spatial_res, lc_spatial_res, -1).permute(0, 3, 1, 2) # (B, num_queries, 6, 6)
121 | lc_assigns_detached.append(lc_tmp)
122 |
123 | # note here correlation value is not cosine similarity
124 | gc_token_assign_ = gc_token_assign.clamp(min=0.0)
125 | lc_token_assign_ = lc_token_assign.clamp(min=0.0)
126 | gclc_assign_cor = torch.einsum("bnq,bmq->bnm", gc_token_assign_.softmax(dim=-1), lc_token_assign_.softmax(dim=-1))
127 | # align patch assignment similarity to feature correlation
128 | cor_align_loss = (- gclc_assign_cor * (gclc_correlations[i] - self.negative_pressure))[masks_gc_lc[i]]
129 | gclc_cor_loss += 0.15*cor_align_loss.sum()
130 |
131 | loss += gclc_cor_loss / self.nmb_crops[-1]
132 |
133 | '''
134 | 2. Compute Global-Local Query Alignment loss
135 | -- use cross-entropy loss to align queries, and make each query different
136 | '''
137 | query_align_loss = 0
138 | for i in range(self.nmb_crops[-1]):
139 | tmp_loss, top1 = query_ce_loss(norm(all_queries[0]), norm(all_queries[i + 1]), self.num_queries,
140 | self.ce_temperature)
141 | query_align_loss += tmp_loss
142 |
143 | loss += query_align_loss / self.nmb_crops[-1]
144 |
145 | return loss
146 |
--------------------------------------------------------------------------------
/evaluate/sup_overcluster.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import os
4 | import torch
5 | import torch.nn.functional as F
6 | import torchvision.transforms as T
7 |
8 | import yaml
9 | from torchvision.transforms.functional import InterpolationMode
10 |
11 | from data.coco_data import CocoDataModule
12 | from data.voc_data import VOCDataModule
13 | from model.align_model import AlignSegmentor
14 | from utils import PredsmIoU
15 |
16 |
17 | def norm(t):
18 | return F.normalize(t, dim=-1, eps=1e-10)
19 |
20 |
21 | def eval_overcluster():
22 | with open(args.config_path) as file:
23 | config = yaml.safe_load(file.read())
24 | # print('Config: ', config)
25 |
26 | data_config = config['data']
27 | val_config = config['val']
28 | input_size = data_config["size_crops"]
29 | torch.manual_seed(val_config['seed'])
30 | torch.cuda.manual_seed_all(val_config['seed'])
31 |
32 | # Init data and transforms
33 | val_image_transforms = T.Compose([T.Resize((input_size, input_size)),
34 | T.ToTensor(),
35 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
36 | val_target_transforms = T.Compose([T.Resize((input_size, input_size), interpolation=InterpolationMode.NEAREST),
37 | T.ToTensor()])
38 |
39 | data_dir = data_config["data_dir"]
40 | dataset_name = data_config["dataset_name"]
41 | if dataset_name == "voc":
42 | ignore_index = 255
43 | num_classes = 21
44 | data_module = VOCDataModule(batch_size=val_config["batch_size"],
45 | return_masks=True,
46 | num_workers=config["num_workers"],
47 | train_split="trainaug",
48 | val_split="val",
49 | data_dir=data_dir,
50 | train_image_transform=None,
51 | drop_last=True,
52 | val_image_transform=val_image_transforms,
53 | val_target_transform=val_target_transforms)
54 | elif "coco" in dataset_name:
55 | assert len(dataset_name.split("-")) == 2
56 | mask_type = dataset_name.split("-")[-1]
57 | assert mask_type in ["all", "stuff", "thing"]
58 | if mask_type == "all":
59 | num_classes = 27
60 | elif mask_type == "stuff":
61 | num_classes = 15
62 | elif mask_type == "thing":
63 | num_classes = 12
64 | ignore_index = 255
65 | file_list = os.listdir(os.path.join(data_dir, "images", "train2017"))
66 | file_list_val = os.listdir(os.path.join(data_dir, "images", "val2017"))
67 | # random.shuffle(file_list_val)
68 | data_module = CocoDataModule(batch_size=val_config["batch_size"],
69 | num_workers=config["num_workers"],
70 | file_list=file_list,
71 | data_dir=data_dir,
72 | file_list_val=file_list_val,
73 | mask_type=mask_type,
74 | train_transforms=None,
75 | val_transforms=val_image_transforms,
76 | val_target_transforms=val_target_transforms)
77 | elif dataset_name == "ade20k":
78 | num_classes = 111
79 | ignore_index = 255
80 | val_target_transforms = T.Compose([T.Resize((input_size, input_size), interpolation=InterpolationMode.NEAREST)])
81 | data_module = None
82 | else:
83 | raise ValueError(f"{dataset_name} not supported")
84 |
85 | # Init method
86 | patch_size = val_config["patch_size"]
87 | spatial_res = input_size / patch_size
88 | assert spatial_res.is_integer()
89 | model = AlignSegmentor(arch=val_config['arch'],
90 | patch_size=val_config['patch_size'],
91 | embed_dim=val_config['embed_dim'],
92 | hidden_dim=val_config['hidden_dim'],
93 | num_heads=val_config['decoder_num_heads'],
94 | num_queries=val_config['num_queries'],
95 | num_decode_layers=val_config['num_decode_layers'],
96 | last_self_attention=val_config['last_self_attention'])
97 |
98 | # set model to eval mode
99 | for p in model.parameters():
100 | p.requires_grad = False
101 | model.eval()
102 | model.to(cuda)
103 |
104 | # load pretrained weights
105 | if val_config["checkpoint"] is not None:
106 | checkpoint = torch.load(val_config["checkpoint"])
107 | msg = model.load_state_dict(checkpoint["state_dict"], strict=True)
108 | print(msg)
109 | else:
110 | print('no pretrained pth found!')
111 |
112 | dataloader = data_module.val_dataloader()
113 | metric = PredsmIoU(val_config['num_queries'], num_classes)
114 |
115 | # Calculate IoU for each image individually
116 | for idx, batch in enumerate(dataloader):
117 | imgs, masks = batch
118 | B = imgs.size(0)
119 | assert B == 1 # image has to be evaluated individually
120 | all_queries, tokens, _, _, res, _ = model([imgs.to(cuda)]) # tokens=(1,N,dim)
121 |
122 | # calculate token assignment
123 | token_cls = torch.einsum("bnc,bqc->bnq", norm(tokens), norm(all_queries[0]))
124 | token_cls = torch.softmax(token_cls, dim=-1)
125 | token_cls = token_cls.reshape(B, res, res, -1).permute(0, 3, 1, 2) # (1,num_query,res,res)
126 | token_cls = token_cls.max(dim=1, keepdim=True)[1].float() # (1,1,res,res)
127 |
128 | # downsample masks/upsample preds to masks_eval_size
129 | preds = F.interpolate(token_cls, size=(val_config['mask_eval_size'], val_config['mask_eval_size']), mode='nearest')
130 | masks *= 255
131 | if masks.size(3) != val_config['mask_eval_size']:
132 | masks = F.interpolate(masks, size=(val_config['mask_eval_size'], val_config['mask_eval_size']), mode='nearest')
133 |
134 | metric.update(masks[masks != ignore_index], preds[masks != ignore_index])
135 | # sys.exit(1)
136 |
137 | metric.compute()
138 |
139 |
140 | if __name__ == "__main__":
141 | parser = argparse.ArgumentParser()
142 | parser.add_argument('--config_path', default='../configs/eval_voc_config.yml', type=str)
143 | parser.add_argument('--gpu', default='0', type=str)
144 |
145 | args = parser.parse_args()
146 |
147 | os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
148 | cuda = torch.device('cuda')
149 | eval_overcluster()
150 |
--------------------------------------------------------------------------------
/model/align_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from model.vit import vit_small, vit_base, vit_large, trunc_normal_
6 | from utils import process_attentions
7 |
8 |
9 | class CrossAttentionLayer(nn.Module):
10 |
11 | def __init__(self, d_model, nhead, dropout=0.0):
12 | super().__init__()
13 | self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
14 |
15 | self.norm = nn.LayerNorm(d_model)
16 | self.dropout = nn.Dropout(dropout)
17 |
18 | # self._reset_parameters()
19 |
20 | def _reset_parameters(self):
21 | for p in self.parameters():
22 | if p.dim() > 1:
23 | nn.init.xavier_uniform_(p)
24 |
25 | def with_pos_embed(self, tensor, pos):
26 | return tensor if pos is None else tensor + pos
27 |
28 | def forward(self, tgt, memory,
29 | pos=None,
30 | query_pos=None):
31 | tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
32 | key=self.with_pos_embed(memory, pos), value=memory)[0]
33 | tgt = tgt + self.dropout(tgt2)
34 | tgt = self.norm(tgt)
35 | return tgt
36 |
37 |
38 | class SelfAttentionLayer(nn.Module):
39 |
40 | def __init__(self, d_model, nhead, dropout=0.0):
41 | super().__init__()
42 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
43 |
44 | self.norm = nn.LayerNorm(d_model)
45 | self.dropout = nn.Dropout(dropout)
46 |
47 | # self._reset_parameters()
48 |
49 | def _reset_parameters(self):
50 | for p in self.parameters():
51 | if p.dim() > 1:
52 | nn.init.xavier_uniform_(p)
53 |
54 | def with_pos_embed(self, tensor, pos):
55 | return tensor if pos is None else tensor + pos
56 |
57 | def forward(self, tgt, query_pos=None):
58 | q = k = self.with_pos_embed(tgt, query_pos)
59 | tgt2 = self.self_attn(q, k, value=tgt)[0]
60 | tgt = tgt + self.dropout(tgt2)
61 | tgt = self.norm(tgt)
62 |
63 | return tgt
64 |
65 |
66 | class FFNLayer(nn.Module):
67 |
68 | def __init__(self, d_model, dim_feedforward=1024, dropout=0.0):
69 | super().__init__()
70 | # Implementation of Feedforward model
71 | self.linear1 = nn.Linear(d_model, dim_feedforward)
72 | self.dropout = nn.Dropout(dropout)
73 | self.linear2 = nn.Linear(dim_feedforward, d_model)
74 |
75 | self.norm = nn.LayerNorm(d_model)
76 | self.activation = F.relu
77 |
78 | self.apply(self._init_weights)
79 |
80 | def _init_weights(self, m):
81 | if isinstance(m, nn.Linear):
82 | nn.init.trunc_normal_(m.weight, std=.02)
83 | if isinstance(m, nn.Linear) and m.bias is not None:
84 | nn.init.constant_(m.bias, 0)
85 |
86 | def forward(self, tgt):
87 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
88 | tgt = tgt + self.dropout(tgt2)
89 | tgt = self.norm(tgt)
90 | return tgt
91 |
92 |
93 | class MLP(nn.Module):
94 |
95 | def __init__(self, input_dim, hidden_dim, output_dim, num_layers=2):
96 | super().__init__()
97 | self.num_layers = num_layers
98 | h = [hidden_dim] * (num_layers - 1)
99 | self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
100 |
101 | def forward(self, x):
102 | for i, layer in enumerate(self.layers):
103 | x = F.gelu(layer(x)) if i < self.num_layers - 1 else layer(x)
104 | return x
105 |
106 |
107 | class AlignSegmentor(nn.Module):
108 |
109 | def __init__(self, arch='vit_small',
110 | patch_size=16,
111 | embed_dim=384,
112 | hidden_dim=384,
113 | num_heads=4,
114 | num_queries=21,
115 | nmb_crops=(1, 0),
116 | num_decode_layers=1,
117 | last_self_attention=True):
118 | super(AlignSegmentor, self).__init__()
119 | self.patch_size = patch_size
120 | self.embed_dim = embed_dim
121 | self.hidden_dim = hidden_dim
122 | self.nmb_crops = nmb_crops
123 | self.num_decode_layers = num_decode_layers
124 | self.last_self_attention = last_self_attention
125 |
126 | # Initialize model
127 | if arch == 'vit_small':
128 | self.backbone = vit_small(patch_size=patch_size)
129 | elif arch == 'vit_base':
130 | self.backbone = vit_base(patch_size=patch_size)
131 | elif arch == 'vit_large':
132 | self.backbone = vit_large(patch_size=patch_size)
133 | else:
134 | raise ValueError(f"{arch} is not supported")
135 |
136 | # learnable CLS queries and/or positional queries
137 | self.clsQueries = nn.Embedding(num_queries, embed_dim)
138 |
139 | # simple Transformer Decoder with num_decoder_layers
140 | self.decoder_cross_attention_layers = nn.ModuleList()
141 | self.decoder_self_attention_layers = nn.ModuleList()
142 | self.decoder_ffn_layers = nn.ModuleList()
143 | for _ in range(self.num_decode_layers):
144 | self.decoder_cross_attention_layers.append(
145 | CrossAttentionLayer(d_model=embed_dim, nhead=num_heads)
146 | )
147 | self.decoder_self_attention_layers.append(
148 | SelfAttentionLayer(d_model=embed_dim, nhead=num_heads)
149 | )
150 | self.decoder_ffn_layers.append(
151 | FFNLayer(d_model=embed_dim, dim_feedforward=hidden_dim)
152 | )
153 |
154 | self.apply(self._init_weights)
155 |
156 | def _init_weights(self, m):
157 | if isinstance(m, nn.Linear):
158 | trunc_normal_(m.weight, std=.02)
159 | if m.bias is not None:
160 | nn.init.constant_(m.bias, 0)
161 | elif isinstance(m, nn.LayerNorm):
162 | nn.init.constant_(m.bias, 0)
163 | nn.init.constant_(m.weight, 1.0)
164 |
165 | def set_clsQuery(self, prototypes):
166 | # initialize clsQueries with generated prototypes of [num_queries, embed_dim]
167 | self.clsQueries.weight = nn.Parameter(prototypes)
168 |
169 | def forward(self, inputs, threshold=0.6):
170 | # inputs is a list of crop images
171 | B = inputs[0].size(0)
172 |
173 | # repeat query for batch use, (B, num_queries, embed_dim)
174 | outQueries = self.clsQueries.weight.unsqueeze(0).repeat(B, 1, 1)
175 | posQueries = pos = None
176 |
177 | # Extract feature
178 | outputs = self.backbone(inputs, self.nmb_crops, self.last_self_attention)
179 | if self.last_self_attention:
180 | outputs, attentions = outputs # outputs=[B*N(196+36), embed_dim], attentions(only global)=[B, heads, 196]
181 |
182 | # calculate gc and lc resolutions. Split output in gc and lc embeddings
183 | gc_res_w = inputs[0].size(2) / self.patch_size
184 | gc_res_h = inputs[0].size(3) / self.patch_size
185 | assert gc_res_w.is_integer() and gc_res_h.is_integer(), "Image dims need to be divisible by patch size"
186 | assert gc_res_w == gc_res_h, f"Only supporting square images not {inputs[0].size(2)}x{inputs[0].size(3)}"
187 | gc_spatial_res = int(gc_res_w)
188 | lc_res_w = inputs[-1].size(2) / self.patch_size
189 | assert lc_res_w.is_integer(), "Image dims need to be divisible by patch size"
190 | lc_spatial_res = int(lc_res_w)
191 | gc_spatial_output, lc_spatial_output = outputs[:B * self.nmb_crops[0] * gc_spatial_res ** 2], \
192 | outputs[B * self.nmb_crops[0] * gc_spatial_res ** 2:]
193 | # (B*N, C) -> (B, N, C)
194 | gc_spatial_output = gc_spatial_output.reshape(B, -1, self.embed_dim)
195 | if self.nmb_crops[-1] != 0:
196 | lc_spatial_output = lc_spatial_output.reshape(B, self.nmb_crops[-1], lc_spatial_res**2, self.embed_dim)
197 |
198 | # merge attention heads and threshold attentions
199 | attn_hard = None
200 | if self.last_self_attention:
201 | attn_smooth = attentions.mean(dim=1)  # average attention over heads
202 | attn_smooth = attn_smooth.reshape(B * sum(self.nmb_crops), 1, gc_spatial_res, gc_spatial_res)
203 | # attn_hard later serves as a 'foreground' hint; use attn_hard.bool()
204 | attn_hard = process_attentions(attn_smooth, gc_spatial_res, threshold=threshold, blur_sigma=0.6)
205 | attn_hard = attn_hard.squeeze(1)
206 |
207 | # Align Queries to each image crop's features with decoder, assuming only 1 global crop
208 | all_queries = []
209 | for i in range(sum(self.nmb_crops)):
210 | if i == 0:
211 | features = gc_spatial_output
212 | else:
213 | features = lc_spatial_output[:, i-1]
214 | for j in range(self.num_decode_layers):
215 | # attention: cross-attention first
216 | queries = self.decoder_cross_attention_layers[j](
217 | outQueries, features, pos=pos, query_pos=posQueries)
218 | # self-attention
219 | queries = self.decoder_self_attention_layers[j](
220 | queries, query_pos=posQueries)
221 | # FFN
222 | queries = self.decoder_ffn_layers[j](queries)
223 |
224 | all_queries.append(queries)
225 |
226 | return all_queries, gc_spatial_output, lc_spatial_output, attn_hard, gc_spatial_res, lc_spatial_res
227 |
228 |
229 | if __name__ == '__main__':
230 | model = AlignSegmentor()
--------------------------------------------------------------------------------
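The `__main__` stub above only instantiates the model. For orientation, a slightly fuller smoke test is sketched below; it is not part of the repository, the expected shapes are inferred from `forward()` above (a single 224x224 global crop, `nmb_crops=(1, 0)`), and it assumes a randomly initialised `vit_small` backbone whose multi-crop forward matches the call made in `forward()`.

import torch

from model.align_model import AlignSegmentor

# Sketch only: batch of 2 images forming one global crop each.
model = AlignSegmentor(arch='vit_small', patch_size=16, embed_dim=384,
                       num_queries=21, nmb_crops=(1, 0), last_self_attention=True)
model.eval()

with torch.no_grad():
    img = torch.randn(2, 3, 224, 224)
    all_queries, gc_feat, _, attn_hard, gc_res, _ = model([img])

print(len(all_queries), all_queries[0].shape)  # expected: 1 crop -> one (2, 21, 384) query set
print(gc_feat.shape)                           # expected: (2, 196, 384), i.e. 14x14 patches at patch_size 16
print(attn_hard.shape, gc_res)                 # expected: (2, 14, 14) foreground mask, spatial res 14
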
/model/transforms.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import torch
4 | import torchvision
5 |
6 | from PIL import ImageFilter, Image
7 | from typing import List, Tuple, Dict
8 | from torch import Tensor
9 | from torchvision.transforms import functional as F
10 |
11 |
12 | class GaussianBlur:
13 | """
14 | Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709 following
15 | https://github.com/facebookresearch/swav/blob/5e073db0cc69dea22aa75e92bfdd75011e888f28/src/multicropdataset.py#L64
16 | """
17 |
18 | def __init__(self, sigma=[.1, 2.]):
19 | self.sigma = sigma
20 |
21 | def __call__(self, x: Image):
22 | sigma = random.uniform(self.sigma[0], self.sigma[1])
23 | x = x.filter(ImageFilter.GaussianBlur(radius=sigma))
24 | return x
25 |
26 |
27 | class TrainTransforms:
28 |
29 | def __init__(self,
30 | size_crops: List[int], # [224, 224]
31 | nmb_crops: List[int], # [2, 4]
32 | min_scale_crops: List[float], # [0.5, 0.05]
33 | max_scale_crops: List[float], # [1., 0.25]
34 | augment_image: bool = False,
35 | jitter_strength: float = 1.0,
36 | min_intersection: float = 0.05,
37 | blur_strength: float = 1.0):
38 | """
39 | Main transform used for the alignment target image. Implements multi-crop and calculates the corresponding
40 | crop bounding boxes for each crop-pair.
41 | :param size_crops: size of global and local crop
42 | :param nmb_crops: number of global and local crop
43 | :param min_scale_crops: the lower bound for the random area of the global and local crops before resizing
44 | :param max_scale_crops: the upper bound for the random area of the global and local crops before resizing
45 | :param augment_image: whether to perform image augmentation
46 | :param jitter_strength: the strength of jittering for brightness, contrast, saturation and hue
47 | :param min_intersection: minimum percentage of image area that two crops sampled from the
48 | same picture should share. This makes sure that we can always calculate a loss for each pair of
49 | global and local crops.
50 | :param blur_strength: the maximum standard deviation of the Gaussian kernel
51 | """
52 | assert len(size_crops) == len(nmb_crops)
53 | assert len(min_scale_crops) == len(nmb_crops)
54 | assert len(max_scale_crops) == len(nmb_crops)
55 | assert 0 < min_intersection < 1
56 | self.size_crops = size_crops
57 | self.nmb_crops = nmb_crops
58 | self.min_scale_crops = min_scale_crops
59 | self.max_scale_crops = max_scale_crops
60 | self.min_intersection = min_intersection
61 | self.augment_image = augment_image
62 |
63 | if self.augment_image:
64 | # Construct color transforms
65 | self.color_jitter = torchvision.transforms.ColorJitter(
66 | 0.8 * jitter_strength, 0.8 * jitter_strength, 0.8 * jitter_strength,
67 | 0.2 * jitter_strength
68 | )
69 | color_transform = [torchvision.transforms.RandomApply([self.color_jitter], p=0.8),
70 | torchvision.transforms.RandomGrayscale(p=0.2)]
71 | blur = GaussianBlur(sigma=[blur_strength * .1, blur_strength * 2.])
72 | color_transform.append(torchvision.transforms.RandomApply([blur], p=0.5))
73 | self.color_transform = torchvision.transforms.Compose(color_transform)
74 |
75 | # Construct final transforms
76 | normalize = torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
77 | self.final_transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(), normalize])
78 |
79 | # Construct randomly resized crops transforms
80 | self.rrc_transforms = []
81 | for i in range(len(self.size_crops)): # [224, 96]
82 | random_resized_crop = torchvision.transforms.RandomResizedCrop(
83 | self.size_crops[i],
84 | scale=(self.min_scale_crops[i], self.max_scale_crops[i]),
85 | )
86 | self.rrc_transforms.extend([random_resized_crop] * self.nmb_crops[i])
87 |
88 | def __call__(self, sample: Image.Image) -> Tuple[List[Tensor], Dict[str, Tensor]]:
89 | multi_crops = []
90 | crop_bboxes = torch.zeros(len(self.rrc_transforms), 4)
91 |
92 | for i, rrc_transform in enumerate(self.rrc_transforms):
93 | # Get random crop params
94 | y1, x1, h, w = rrc_transform.get_params(sample, rrc_transform.scale, rrc_transform.ratio)
95 | if i > 0:
96 | # Check whether crop has min overlap with existing global crops. If not resample.
97 | while True:
98 | # Calculate intersection between sampled crop and all sampled global crops
99 | bbox = torch.Tensor([x1, y1, x1 + w, y1 + h])
100 | left_top = torch.max(bbox.unsqueeze(0)[:, None, :2],
101 | crop_bboxes[:min(i, self.nmb_crops[0]), :2])
102 | right_bottom = torch.min(bbox.unsqueeze(0)[:, None, 2:],
103 | crop_bboxes[:min(i, self.nmb_crops[0]), 2:])
104 | wh = _upcast(right_bottom - left_top).clamp(min=0)
105 | inter = wh[:, :, 0] * wh[:, :, 1]
106 |
107 | # set min intersection to at least 1% of image area
108 | min_intersection = int((sample.size[0] * sample.size[1]) * self.min_intersection)
109 | # Global crops should have twice the min_intersection with each other
110 | if i in list(range(self.nmb_crops[0])):
111 | min_intersection *= 2
112 | if not torch.all(inter > min_intersection):
113 | y1, x1, h, w = rrc_transform.get_params(sample, rrc_transform.scale, rrc_transform.ratio)
114 | else:
115 | break
116 |
117 | # Apply rrc params and store absolute crop bounding box
118 | img = F.resized_crop(sample, y1, x1, h, w, rrc_transform.size, rrc_transform.interpolation)
119 | crop_bboxes[i] = torch.Tensor([x1, y1, x1 + w, y1 + h])
120 |
121 | if self.augment_image:
122 | # Apply color transforms
123 | img = self.color_transform(img)
124 |
125 | # Apply final transform
126 | img = self.final_transform(img)
127 | multi_crops.append(img)
128 |
129 | # Calculate relative bboxes for each crop pair from absolute bboxes
130 | gc_bboxes, otc_bboxes = self.calculate_bboxes(crop_bboxes)
131 |
132 | return multi_crops, {"gc": gc_bboxes, "all": otc_bboxes}
133 |
134 | def calculate_bboxes(self, crop_bboxes: Tensor):
135 | # 1. Calculate two intersection bboxes for each global crop - other crop pair
136 | gc_bboxes = crop_bboxes[:self.nmb_crops[0]]
137 | left_top = torch.max(gc_bboxes[:, None, :2], crop_bboxes[:, :2]) # [nmb_crops[0], sum(nmb_crops), 2]
138 | right_bottom = torch.min(gc_bboxes[:, None, 2:], crop_bboxes[:, 2:]) # [nmb_crops[0], sum(nmb_crops), 2]
139 | # Testing for non-intersecting crops. This should always be true, just as safeguard.
140 | assert torch.all((right_bottom - left_top) > 0)
141 |
142 | # 2. Scale intersection bbox with crop size
143 | # Extract height and width of all crop bounding boxes. Each row contains (w,h) of a crop. [sum(nmb_crops),1,2]
144 | ws_hs = torch.stack((crop_bboxes[:, 2] - crop_bboxes[:, 0], crop_bboxes[:, 3] - crop_bboxes[:, 1])).T[:, None]
145 |
146 | # Stack global crop sizes for each bbox dimension
147 | crops_sizes = torch.repeat_interleave(torch.Tensor([self.size_crops[0]]), self.nmb_crops[0] * 2) \
148 | .reshape(self.nmb_crops[0], 2)
149 | if len(self.size_crops) == 2:
150 | lc_crops_sizes = torch.repeat_interleave(torch.Tensor([self.size_crops[1]]), self.nmb_crops[1] * 2) \
151 | .reshape(self.nmb_crops[1], 2)
152 | crops_sizes = torch.cat((crops_sizes, lc_crops_sizes))[:, None] # [sum(nmb_crops), 1, 2]
153 |
154 | # Calculate x1s and y1s of each crop bbox
155 | x1s_y1s = crop_bboxes[:, None, :2]
156 |
157 | # Scale top left and right bottom points by percentage of width and height covered
158 | left_top_scaled_gc = crops_sizes[:self.nmb_crops[0]] \
159 | * ((left_top - x1s_y1s[:self.nmb_crops[0]]) / ws_hs[:self.nmb_crops[0]])
160 | right_bottom_scaled_gc = crops_sizes[:self.nmb_crops[0]] \
161 | * ((right_bottom - x1s_y1s[:self.nmb_crops[0]]) / ws_hs[:self.nmb_crops[0]])
162 | left_top_otc_points_per_gc = torch.stack([left_top[i] for i in range(self.nmb_crops[0])], dim=1)
163 | right_bottom_otc_points_per_gc = torch.stack([right_bottom[i] for i in range(self.nmb_crops[0])], dim=1)
164 | left_top_scaled_otc = crops_sizes * ((left_top_otc_points_per_gc - x1s_y1s) / ws_hs)
165 | right_bottom_scaled_otc = crops_sizes * ((right_bottom_otc_points_per_gc - x1s_y1s) / ws_hs)
166 |
167 | # 3. Construct bboxes in x1, y1, x2, y2 format from left top and right bottom points
168 | # gc_bboxes = relative bboxes of gc and its intersection with lc, [num_crops[0], sum(nmb_crops), 4]
169 | # otc_bboxes = relative bboxes of lc and its intersection with gc, [sum(nmb_crops), 1, 4]
170 | gc_bboxes = torch.cat((left_top_scaled_gc, right_bottom_scaled_gc), dim=2)
171 | otc_bboxes = torch.cat((left_top_scaled_otc, right_bottom_scaled_otc), dim=2)
172 |
173 | return gc_bboxes, otc_bboxes
174 |
175 |
176 | def _upcast(t: Tensor) -> Tensor:
177 | # Protects from numerical overflows in multiplications by upcasting to the equivalent higher type
178 | if t.is_floating_point():
179 | return t if t.dtype in (torch.float32, torch.float64) else t.float()
180 | else:
181 | return t if t.dtype in (torch.int32, torch.int64) else t.int()
182 |
--------------------------------------------------------------------------------
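For reference, a minimal usage sketch of `TrainTransforms` follows. It is not part of the file above; the crop settings are illustrative, the image path is a placeholder, and the input is a PIL image as the `__call__` implementation expects.

from PIL import Image

from model.transforms import TrainTransforms

# Illustrative multi-crop settings: 2 global 224px crops and 4 local 96px crops.
tfm = TrainTransforms(size_crops=[224, 96],
                      nmb_crops=[2, 4],
                      min_scale_crops=[0.5, 0.05],
                      max_scale_crops=[1.0, 0.25],
                      augment_image=True,
                      min_intersection=0.05)

img = Image.open("some_image.jpg").convert("RGB")   # placeholder path
crops, bboxes = tfm(img)

# crops: 2 tensors of shape (3, 224, 224) followed by 4 of shape (3, 96, 96)
# bboxes["gc"]:  crop intersections in global-crop coordinates, [nmb_crops[0], sum(nmb_crops), 4]
# bboxes["all"]: the same intersections in each crop's own coordinates, [sum(nmb_crops), nmb_crops[0], 4]
print(len(crops), crops[0].shape, crops[-1].shape)
print(bboxes["gc"].shape, bboxes["all"].shape)
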
/data/coco_data.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import torch
4 |
5 | from PIL import Image
6 | from torch.utils.data import DataLoader, Dataset
7 | from torchvision.datasets import VisionDataset
8 | from typing import List, Optional, Callable, Tuple, Any
9 |
10 |
11 | class CocoDataModule:
12 |
13 | def __init__(self,
14 | num_workers: int,
15 | batch_size: int,
16 | data_dir: str,
17 | train_transforms,
18 | val_transforms,
19 | file_list: List[str],
20 | mask_type: str = None,
21 | file_list_val: List[str] = None,
22 | val_target_transforms=None,
23 | shuffle: bool = True,
24 | size_val_set: int = 10):
25 | super().__init__()
26 | self.num_workers = num_workers
27 | self.batch_size = batch_size
28 | self.shuffle = shuffle
29 | self.size_val_set = size_val_set
30 | self.file_list = file_list
31 | self.file_list_val = file_list_val
32 | self.data_dir = data_dir
33 | self.train_transforms = train_transforms
34 | self.val_transforms = val_transforms
35 | self.file_list_val = file_list_val
36 | self.val_target_transforms = val_target_transforms
37 | self.mask_type = mask_type
38 | self.coco_train = None
39 | self.coco_val = None
40 |
41 | if self.mask_type is None:
42 | self.coco_train = UnlabelledCoco(self.file_list,
43 | self.train_transforms,
44 | os.path.join(self.data_dir, "images/train2017"))
45 | self.coco_val = UnlabelledCoco(self.file_list[:self.size_val_set * self.batch_size],
46 | self.val_transforms,
47 | os.path.join(self.data_dir, "images/val2017"))
48 | else:
49 | self.coco_train = COCOSegmentation(self.data_dir,
50 | self.file_list,
51 | self.mask_type,
52 | image_set="train",
53 | transforms=self.train_transforms)
54 | self.coco_val = COCOSegmentation(self.data_dir,
55 | self.file_list_val,
56 | self.mask_type,
57 | image_set="val",
58 | transform=self.val_transforms,
59 | target_transform=self.val_target_transforms)
60 |
61 | print(f"Train size {len(self.coco_train)}")
62 | print(f"Val size {len(self.coco_val)}")
63 |
64 | def __len__(self):
65 | return len(self.file_list)
66 |
67 | def train_dataloader(self):
68 | return DataLoader(self.coco_train, batch_size=self.batch_size,
69 | shuffle=self.shuffle, num_workers=self.num_workers,
70 | drop_last=True, pin_memory=True)
71 |
72 | def val_dataloader(self):
73 | return DataLoader(self.coco_val, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers,
74 | drop_last=False, pin_memory=True)
75 |
76 |
77 | class COCOSegmentation(VisionDataset):
78 |
79 | def __init__(
80 | self,
81 | root: str,
82 | file_names: List[str],
83 | mask_type: str,
84 | image_set: str = "train",
85 | transform: Optional[Callable] = None,
86 | target_transform: Optional[Callable] = None,
87 | transforms: Optional[Callable] = None,
88 | ):
89 | super(COCOSegmentation, self).__init__(root, transforms, transform, target_transform)
90 | self.image_set = image_set
91 | self.file_names = file_names
92 | self.mask_type = mask_type
93 | assert self.image_set in ["train", "val"]
94 | assert mask_type in ["all", "stuff", "thing"]
95 |
96 | # Set mask folder depending on mask_type
97 | if mask_type == "all":
98 | seg_folder = "annotations/stuffthingmaps_trainval2017/{}2017/"
99 | json_file = "annotations/stuffthingmaps_trainval2017/stuffthing_2017.json"
100 | elif mask_type == "thing":
101 | seg_folder = "annotations/panoptic_annotations/semantic_segmentation_{}2017/"
102 | json_file = "annotations/panoptic_annotations/panoptic_val2017.json"
103 | elif mask_type == "stuff":
104 | seg_folder = "annotations/stuff_annotations/stuff_{}2017_pixelmaps/"
105 | json_file = "annotations/stuff_annotations/stuff_val2017.json"
106 | else:
107 | raise ValueError(f"No support for mask type {mask_type}")
108 | seg_folder = seg_folder.format(image_set)
109 |
110 | # Load categories to category to id map for merging to coarse categories
111 | with open(os.path.join(root, json_file)) as f:
112 | an_json = json.load(f)
113 | all_cat = an_json['categories']
114 | if mask_type == "all":
115 | super_cats = set([cat_dict['supercategory'] for cat_dict in all_cat])
116 | super_cats.remove("other")  # remove 'other' from prediction targets as it is not a semantic class
117 | super_cat_to_id = {super_cat: i for i, super_cat in enumerate(sorted(super_cats))}
118 | super_cat_to_id["other"] = 255 # ignore_index for CE
119 | # Align 'id' labels: PNG_label = GT_label - 1
120 | self.cat_id_map = {(cat_dict['id']-1): super_cat_to_id[cat_dict['supercategory']] for cat_dict in all_cat}
121 | elif mask_type == "thing":
122 | all_thing_cat_sup = set(cat_dict["supercategory"] for cat_dict in all_cat if cat_dict["isthing"] == 1)
123 | super_cat_to_id = {super_cat: i for i, super_cat in enumerate(sorted(all_thing_cat_sup))}
124 | self.cat_id_map = {}
125 | for cat_dict in all_cat:
126 | if cat_dict["isthing"] == 1:
127 | self.cat_id_map[cat_dict["id"]] = super_cat_to_id[cat_dict["supercategory"]]
128 | elif cat_dict["isthing"] == 0:
129 | self.cat_id_map[cat_dict["id"]] = 255
130 | elif mask_type == "stuff":
131 | super_cats = set([cat_dict['supercategory'] for cat_dict in all_cat])
132 | super_cats.remove("other")  # remove 'other' from prediction targets as it is not a semantic class
133 | super_cat_to_id = {super_cat: i for i, super_cat in enumerate(sorted(super_cats))}
134 | super_cat_to_id["other"] = 255 # ignore_index for CE
135 | self.cat_id_map = {cat_dict['id']: super_cat_to_id[cat_dict['supercategory']] for cat_dict in all_cat}
136 |
137 | # Get images and masks fnames
138 | seg_dir = os.path.join(root, seg_folder)
139 | image_dir = os.path.join(root, "images", f"{image_set}2017")
140 | if not os.path.isdir(seg_dir) or not os.path.isdir(image_dir):
141 | print(seg_dir)
142 | print(image_dir)
143 | raise RuntimeError('Dataset not found or corrupted.')
144 | self.images = [os.path.join(image_dir, x) for x in self.file_names]
145 | self.masks = [os.path.join(seg_dir, x.replace("jpg", "png")) for x in self.file_names]
146 |
147 | def __len__(self):
148 | return len(self.file_names)
149 |
150 | def __getitem__(self, index: int) -> Tuple[Any, Any]:
151 | img = Image.open(self.images[index]).convert('RGB')
152 | mask = Image.open(self.masks[index])
153 |
154 | if self.transforms:
155 | img, mask = self.transforms(img, mask)
156 |
157 | if self.mask_type == "all":
158 | # move 'id' labels from [0, 182] to [0,26], and 255=={182,255}
159 | # (183 is 'other' and 0 is things)
160 | mask *= 255
161 | assert torch.min(mask).item() >= 0
162 | mask[mask == 255] = 182
163 | assert torch.max(mask).item() <= 182
164 | for cat_id in torch.unique(mask):
165 | mask[mask == cat_id] = self.cat_id_map[cat_id.item()]
166 |
167 | assert torch.max(mask).item() <= 255
168 | assert torch.min(mask).item() >= 0
169 | mask /= 255
170 | return img, mask
171 | elif self.mask_type == "stuff":
172 | # move stuff labels from {0} U [92, 183] to [0,15] and [255] with 255 == {0, 183}
173 | # (183 is 'other' and 0 is things)
174 | mask *= 255
175 | assert torch.max(mask).item() <= 183
176 | mask[mask == 0] = 183 # [92, 183]
177 | assert torch.min(mask).item() >= 92
178 | for cat_id in torch.unique(mask):
179 | mask[mask == cat_id] = self.cat_id_map[cat_id.item()]
180 |
181 | assert torch.max(mask).item() <= 255
182 | assert torch.min(mask).item() >= 0
183 | mask /= 255
184 | return img, mask
185 | elif self.mask_type == "thing":
186 | mask *= 255
187 | assert torch.max(mask).item() <= 200
188 | mask[mask == 0] = 200 # map unlabelled to stuff
189 | merged_mask = mask.clone()
190 | for cat_id in torch.unique(mask):
191 | merged_mask[mask == cat_id] = self.cat_id_map[int(cat_id.item())] # [0, 11] + {255}
192 |
193 | assert torch.max(merged_mask).item() <= 255
194 | assert torch.min(merged_mask).item() >= 0
195 | merged_mask /= 255
196 | return img, merged_mask
197 | return img, mask
198 |
199 |
200 | class UnlabelledCoco(Dataset):
201 |
202 | def __init__(self, file_list, transforms, data_dir):
203 | self.file_names = file_list
204 | self.transform = transforms
205 | self.data_dir = data_dir
206 |
207 | def __len__(self):
208 | return len(self.file_names)
209 |
210 | def __getitem__(self, idx):
211 | img_path = self.file_names[idx]
212 | image = Image.open(os.path.join(self.data_dir, img_path)).convert('RGB')
213 | if self.transform:
214 | image = self.transform(image)
215 | return image
216 |
--------------------------------------------------------------------------------
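To make the `mask_type == "all"` remapping in `__getitem__` concrete, here is a small standalone sketch. It is not part of the repository: it rebuilds the same supercategory map from `data/stuffthing_2017.json` (path relative to the repo root), writes into a cloned tensor for readability, and the helper name `remap_to_coarse` plus the toy mask are purely illustrative.

import json

import torch

# Rebuild the coarse-label map the way COCOSegmentation does for mask_type == "all".
with open("data/stuffthing_2017.json") as f:
    all_cat = json.load(f)["categories"]

super_cats = sorted({c["supercategory"] for c in all_cat} - {"other"})
super_cat_to_id = {sc: i for i, sc in enumerate(super_cats)}
super_cat_to_id["other"] = 255                          # ignore_index for cross-entropy
cat_id_map = {c["id"] - 1: super_cat_to_id[c["supercategory"]] for c in all_cat}

def remap_to_coarse(mask: torch.Tensor) -> torch.Tensor:
    """mask: float tensor in [0, 1], as ToTensor produces from the stuffthing PNG label map."""
    mask = mask * 255                                   # recover PNG labels (PNG_label = id - 1)
    mask[mask == 255] = 182                             # unlabelled pixels -> 'other' (id 183)
    coarse = mask.clone()
    for cat_id in torch.unique(mask):
        coarse[mask == cat_id] = cat_id_map[int(round(cat_id.item()))]
    return coarse / 255                                 # back to [0, 1], as __getitem__ returns it

# Toy 2x2 mask: person (PNG 0), unlabelled (PNG 255), grass (PNG 123), person again.
toy = torch.tensor([[0.0, 1.0], [123 / 255, 0.0]])
print(remap_to_coarse(toy) * 255)                       # coarse supercategory ids, 255 = ignore
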
/evaluate/visualize_segment.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import argparse
4 | import cv2
5 | import random
6 | import colorsys
7 |
8 | import skimage.io
9 | from skimage.measure import find_contours
10 | from matplotlib.patches import Polygon
11 | import torch
12 | import torch.nn as nn
13 | import torchvision
14 | from torchvision import transforms as pth_transforms
15 | from torchvision.transforms import GaussianBlur
16 | import torch.nn.functional as F
17 | import numpy as np
18 | from PIL import Image
19 | from skimage.measure import label
20 | from matplotlib import pyplot as plt
21 |
22 | from model.align_model import AlignSegmentor
23 | from utils import neq_load_external
24 |
25 |
26 | def norm(t):
27 | return F.normalize(t, dim=-1, eps=1e-10)
28 |
29 |
30 | def apply_mask(image, mask, color, alpha=0.5):
31 | for c in range(3):
32 | image[:, :, c] = image[:, :, c] * (1 - alpha * mask) + alpha * mask * color[c] * 255
33 | return image
34 |
35 |
36 | def random_colors(N, bright=True):
37 | """
38 | Generate random colors.
39 | """
40 | brightness = 1.0 if bright else 0.7
41 | hsv = [(i / N, 1, brightness) for i in range(N)]
42 | colors = list(map(lambda c: colorsys.hsv_to_rgb(*c), hsv))
43 | random.shuffle(colors)
44 | return colors
45 |
46 |
47 | def display_instances(image, mask, fname="test", figsize=(5, 5), blur=False, contour=True, alpha=0.5):
48 | fig = plt.figure(figsize=figsize, frameon=False)
49 | ax = plt.Axes(fig, [0., 0., 1., 1.])
50 | ax.set_axis_off()
51 | fig.add_axes(ax)
52 | ax = plt.gca()
53 |
54 | N = 1
55 | mask = mask[None, :, :]
56 | # Generate random colors
57 | colors = random_colors(N)
58 |
59 | # Show area outside image boundaries.
60 | height, width = image.shape[:2]
61 | margin = 0
62 | ax.set_ylim(height + margin, -margin)
63 | ax.set_xlim(-margin, width + margin)
64 | ax.axis('off')
65 | masked_image = image.astype(np.uint32).copy()
66 |
67 | for i in range(N):
68 | color = colors[i]
69 | _mask = mask[i]
70 | if blur:
71 | _mask = cv2.blur(_mask, (10, 10))
72 | # Mask
73 | masked_image = apply_mask(masked_image, _mask, color, alpha)
74 | # Mask Polygon
75 | # Pad to ensure proper polygons for masks that touch image edges.
76 | if contour:
77 | padded_mask = np.zeros((_mask.shape[0] + 2, _mask.shape[1] + 2))
78 | padded_mask[1:-1, 1:-1] = _mask
79 | contours = find_contours(padded_mask, 0.5)
80 | for verts in contours:
81 | # Subtract the padding and flip (y, x) to (x, y)
82 | verts = np.fliplr(verts) - 1
83 | p = Polygon(verts, facecolor="none", edgecolor=color)
84 | ax.add_patch(p)
85 | ax.imshow(masked_image.astype(np.uint8), aspect='auto')
86 | fig.savefig(fname)
87 | print(f"{fname} saved.")
88 |
89 |
90 | def display_segments(image, masks, fname="test", figsize=(5, 5), alpha=0.7):
91 | N = 5
92 | # Generate random colors
93 | # colors = random_colors(N)
94 | colors = [(128, 0, 0), (30, 144, 255), (75, 0, 130), (184, 134, 11), (0, 128, 0)]
95 | colors = [(x / 255, y / 255, z / 255) for (x, y, z) in colors]
96 |
97 | fig = plt.figure(figsize=figsize, frameon=False)
98 | ax = plt.Axes(fig, [0., 0., 1., 1.])
99 | ax.set_axis_off()
100 | fig.add_axes(ax)
101 | ax = plt.gca()
102 |
103 | # Show area outside image boundaries.
104 | height, width = image.shape[:2]
105 | margin = 0
106 | ax.set_ylim(height + margin, -margin)
107 | ax.set_xlim(-margin, width + margin)
108 | ax.axis('off')
109 |
110 | for i in range(N):
111 | color = colors[i]
112 | _mask = masks[i]
113 | # Mask
114 | masked_image = image.astype(np.uint32).copy()
115 | masked_image = apply_mask(masked_image, _mask, color, alpha)
116 |
117 | ax.imshow(masked_image.astype(np.uint8), aspect='auto')
118 | file = os.path.join(fname, "cls" + str(i) + ".png")
119 | fig.savefig(file)
120 | print(f"{file} saved.")
121 |
122 |
123 | def display_allsegments(image, masks, n=5, fname="test", figsize=(5, 5), alpha=0.6):
124 | N = n # num of colors
125 | # colors = [(128,0,0), (184,134,11), (0,128,0), (62,78,94), (0,0,0)] # last two backgrounds
126 | colors = [(128,0,0), (30,144,255), (75,0,130), (184,134,11), (0,128,0)] # for coco
127 | colors = [(x/255, y/255, z/255) for (x, y, z) in colors]
128 | print(colors)
129 |
130 | # Generate random colors
131 | # colors = random_colors(N)
132 |
133 | fig = plt.figure(figsize=figsize, frameon=False)
134 | ax = plt.Axes(fig, [0., 0., 1., 1.])
135 | ax.set_axis_off()
136 | fig.add_axes(ax)
137 | ax = plt.gca()
138 |
139 | # Show area outside image boundaries.
140 | height, width = image.shape[:2]
141 | margin = 0
142 | ax.set_ylim(height + margin, -margin)
143 | ax.set_xlim(-margin, width + margin)
144 | ax.axis('off')
145 |
146 | masked_image = image.astype(np.uint32).copy()
147 | for i in range(N):
148 | color = colors[i]
149 | _mask = masks[i]
150 | # Mask
151 | masked_image = apply_mask(masked_image, _mask, color, alpha)
152 |
153 | ax.imshow(masked_image.astype(np.uint8), aspect='auto')
154 | file = os.path.join(fname, "cls" + str(i) + ".png")
155 | fig.savefig(file)
156 | print(f"{file} saved.")
157 |
158 |
159 | def process_attentions(attentions: torch.Tensor, spatial_res: int, threshold: float = 0.6, blur_sigma: float = 0.6) \
160 | -> torch.Tensor:
161 | """
162 | Process [0,1] attentions to binary 0-1 mask. Applies a Gaussian filter, keeps threshold % of mass and removes
163 | components smaller than 3 pixels.
164 | The code is adapted from https://github.com/facebookresearch/dino/blob/main/visualize_attention.py but removes the
165 | need for using ground-truth data to find the best-performing head. Instead we simply average all heads' attentions
166 | so that we can use the foreground mask during training time.
167 | :param attentions: torch 4D-Tensor containing the averaged attentions
168 | :param spatial_res: spatial resolution of the attention map
169 | :param threshold: the percentage of mass to keep as foreground.
170 | :param blur_sigma: standard deviation to be used for creating kernel to perform blurring.
171 | :return: the foreground mask obtained from the ViT's attention.
172 | """
173 | # Blur attentions
174 | attentions = GaussianBlur(7, sigma=(blur_sigma))(attentions)
175 | attentions = attentions.reshape(attentions.size(0), 1, spatial_res ** 2)
176 | # Keep threshold% of mass
177 | val, idx = torch.sort(attentions)
178 | val /= torch.sum(val, dim=-1, keepdim=True)
179 | cumval = torch.cumsum(val, dim=-1)
180 | th_attn = cumval > (1 - threshold)
181 | idx2 = torch.argsort(idx)
182 | th_attn[:, 0] = torch.gather(th_attn[:, 0], dim=1, index=idx2[:, 0])
183 | th_attn = th_attn.reshape(attentions.size(0), 1, spatial_res, spatial_res).float()
184 | # Remove components with less than 3 pixels
185 | for j, th_att in enumerate(th_attn):
186 | labelled = label(th_att.cpu().numpy())
187 | for k in range(1, np.max(labelled) + 1):
188 | mask = labelled == k
189 | if np.sum(mask) <= 2:
190 | th_attn[j, 0][mask] = 0
191 | return th_attn
192 |
193 |
194 | if __name__ == '__main__':
195 | parser = argparse.ArgumentParser('Evaluate segmentation on pretrained model')
196 | parser.add_argument('--pretrained_weights', default='./epoch10.pth',
197 | type=str, help="Path to pretrained weights to load.")
198 | parser.add_argument("--image_path", default='', type=str, help="Path of the image to load.")
199 | parser.add_argument('--patch_size', default=16, type=int, help='Patch resolution of the model.')
200 | parser.add_argument("--image_size", default=(480, 480), type=int, nargs="+", help="Resize image.")
201 | parser.add_argument('--output_dir', default='./outputs/', help='Path where to save visualizations.')
202 | parser.add_argument("--threshold", type=float, default=0.6, help="""We visualize masks
203 | obtained by thresholding the self-attention maps to keep xx% of the mass.""")
204 | args = parser.parse_args()
205 |
206 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
207 | # build model
208 | # '''
209 | model = AlignSegmentor(arch='vit_small',
210 | patch_size=16,
211 | embed_dim=384,
212 | hidden_dim=768,
213 | num_heads=3,
214 | num_queries=5,
215 | nmb_crops=[1, 0],
216 | num_decode_layers=1,
217 | last_self_attention=True)
218 |
219 | # set model to eval mode
220 | for p in model.parameters():
221 | p.requires_grad = False
222 | model.eval()
223 | model.to(device)
224 |
225 | # load pretrained weights
226 | if os.path.isfile(args.pretrained_weights):
227 | pretrained_model = torch.load(args.pretrained_weights, map_location="cpu")
228 | msg = model.load_state_dict(pretrained_model['state_dict'], strict=False)
229 | print(msg)
230 | else:
231 | print('no pretrained pth found!')
232 |
233 | if os.path.isfile(args.image_path):
234 | with open(args.image_path, 'rb') as f:
235 | img = Image.open(f)
236 | img = img.convert('RGB')
237 | else:
238 | print(f"Provided image path {args.image_path} is not valid.")
239 | sys.exit(1)
240 | transform = pth_transforms.Compose([
241 | pth_transforms.Resize(args.image_size),
242 | pth_transforms.ToTensor(),
243 | pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
244 | ])
245 | img = transform(img)
246 |
247 | # make the image divisible by the patch size
248 | # img = c, w, h (3, 480, 480); unsqueeze -> (1, 3, 480, 480)
249 | w, h = img.shape[1] - img.shape[1] % args.patch_size, img.shape[2] - img.shape[2] % args.patch_size
250 | img = img[:, :w, :h].unsqueeze(0)
251 |
252 | w_spatial_res = img.shape[-2] // args.patch_size
253 |
254 | # get aligned_queries, spatial_token_output and attention_map
255 | all_queries, gc_output, _, attn_hard, _, _ = model([img.to(device)], threshold=args.threshold)
256 |
257 | os.makedirs(args.output_dir, exist_ok=True)
258 | torchvision.utils.save_image(torchvision.utils.make_grid(img, normalize=True, scale_each=True),
259 | os.path.join(args.output_dir, "img.png"))
260 |
261 | # interpolate binary mask
262 | attn_hard = nn.functional.interpolate(attn_hard.unsqueeze(1), scale_factor=args.patch_size, mode="nearest")[0].cpu().numpy()
263 | image = skimage.io.imread(os.path.join(args.output_dir, "img.png"))
264 | display_instances(image, attn_hard[0], fname=os.path.join(args.output_dir, "mask_" + str(args.threshold) + ".png"), blur=False)
265 |
266 | # calculate query assignment score
267 | gc_token_sim = torch.einsum("bnc,bqc->bnq", norm(gc_output), norm(all_queries[0]))
268 | gc_token_cls = torch.softmax(gc_token_sim, dim=-1)
269 | gc_token_cls = gc_token_cls.reshape(1, w_spatial_res, w_spatial_res, -1).permute(0, 3, 1, 2)
270 |
271 | # Smooth interpolation
272 | masks_prob = F.interpolate(gc_token_cls, size=w, mode='bilinear')
273 | masks_oh = masks_prob.argmax(dim=1)
274 | masks_oh = torch.nn.functional.one_hot(masks_oh, masks_prob.shape[1])
275 | masks_oh = masks_oh.squeeze(dim=0).permute(2, 0, 1)
276 |
277 | masks = []
278 | for i in range(masks_prob.shape[1]):
279 | mask = masks_oh[i].cpu().numpy()
280 | # print('mask = ', mask.shape, mask)
281 | masks.append(mask)
282 | display_allsegments(image, masks, n=5, fname=args.output_dir)
283 |
--------------------------------------------------------------------------------
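The mass-thresholding inside `process_attentions` is easiest to follow on a toy example. The sketch below is illustrative only (blurring and the small-component clean-up are omitted, and the attention values are made up); it reproduces the sort / cumulative-sum / scatter-back step on a flattened 2x2 attention map.

import torch

attn = torch.tensor([[0.05, 0.10, 0.25, 0.60]]).unsqueeze(1)   # (B=1, 1, 4) flattened 2x2 map
threshold = 0.6

val, idx = torch.sort(attn)                  # ascending values and their original positions
val = val / torch.sum(val, dim=-1, keepdim=True)
cumval = torch.cumsum(val, dim=-1)
th = cumval > (1 - threshold)                # keep the top patches holding ~threshold of the mass
idx2 = torch.argsort(idx)
th = torch.gather(th, dim=2, index=idx2)     # scatter the decision back to the original order

print(th.reshape(1, 2, 2))
# tensor([[[False, False],
#          [False,  True]]])  -> only the 0.60 patch is kept as foreground at threshold 0.6

At `threshold=0.6`, the smallest set of patches whose attention mass reaches roughly 60% is kept as foreground, which in this toy case is just the dominant patch.
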
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import time
4 |
5 | import torch
6 | import yaml
7 |
8 | from torchvision.transforms import ToTensor, Compose, Resize, Normalize
9 | from torchvision.transforms.functional import InterpolationMode
10 |
11 | from data.coco_data import CocoDataModule
12 | from data.voc_data import VOCDataModule
13 | from model.align_model import AlignSegmentor
14 | from model.criterion import AlignCriterion
15 | from model.transforms import TrainTransforms
16 | from utils import AverageMeter, save_checkpoint, neq_load_external
17 | from data.movi_data import MOViDataModule  # used by the 'movi' dataset branch below
18 |
19 | def set_path(config):
20 | if config['train']['checkpoint']:
21 | model_path = os.path.dirname(config['train']['checkpoint'])
22 | else:
23 | model_path = './log_tmp/{0}-{1}-bs{2}/model'.format(config["data"]["dataset_name"],
24 | config["train"]["arch"],
25 | config["train"]["batch_size"])
26 |
27 | if not os.path.exists(model_path): os.makedirs(model_path)
28 | return model_path
29 |
30 |
31 | def exclude_from_wt_decay(named_params, weight_decay: float, lr: float):
32 | params = []
33 | excluded_params = []
34 | query_param = []
35 |
36 | for name, param in named_params:
37 | if not param.requires_grad:
38 | continue
39 | # do not regularize biases nor Norm parameters
40 | if name.endswith(".bias") or len(param.shape) == 1:
41 | excluded_params.append(param)
42 | elif 'clsQueries' in name:
43 | query_param.append(param)
44 | else:
45 | params.append(param)
46 | return [{'params': params, 'weight_decay': weight_decay, 'lr': lr},
47 | {'params': excluded_params, 'weight_decay': 0., 'lr': lr},
48 | {'params': query_param, 'weight_decay': 0., 'lr': lr * 1}]
49 |
50 |
51 | def configure_optimizers(model, train_config):
52 | # Separate Decoder params from ViT params
53 | # only train Decoder
54 | decoder_params_named = []
55 | for name, param in model.named_parameters():
56 | if name.startswith("backbone"):
57 | param.requires_grad = False
58 | elif train_config['fix_prototypes'] and 'clsQueries' in name:
59 | param.requires_grad = False
60 | else:
61 | decoder_params_named.append((name, param))
62 |
63 | # Prepare param groups. Exclude norm and bias from weight decay if flag set.
64 | if train_config['exclude_norm_bias']:
65 | params = exclude_from_wt_decay(decoder_params_named,
66 | weight_decay=train_config["weight_decay"],
67 | lr=train_config['lr_decoder'])
68 | else:
69 | decoder_params = [param for _, param in decoder_params_named]
70 | params = [{'params': decoder_params, 'lr': train_config['lr_decoder']}]
71 |
72 | # Init optimizer and lr schedule
73 | optimizer = torch.optim.AdamW(params, weight_decay=train_config["weight_decay"])
74 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.5)
75 |
76 | return optimizer, scheduler
77 |
78 |
79 | def start_train():
80 | with open(args.config_path) as file:
81 | config = yaml.safe_load(file.read())
82 | # print('Config: ', config)
83 |
84 | data_config = config['data']
85 | train_config = config['train']
86 | torch.manual_seed(train_config['seed'])
87 | torch.cuda.manual_seed_all(train_config['seed'])
88 |
89 | # Init data modules and transforms
90 | dataset_name = data_config["dataset_name"]
91 | train_transforms = TrainTransforms(size_crops=data_config["size_crops"],
92 | nmb_crops=data_config["nmb_crops"],
93 | min_intersection=data_config["min_intersection_crops"],
94 | min_scale_crops=data_config["min_scale_crops"],
95 | max_scale_crops=data_config["max_scale_crops"],
96 | augment_image=data_config["augment_image"])
97 |
98 | # Setup voc dataset used for evaluation
99 | val_size = data_config["size_crops_val"]
100 | val_image_transforms = Compose([Resize((val_size, val_size)),
101 | ToTensor(),
102 | Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
103 | val_target_transforms = Compose([Resize((val_size, val_size), interpolation=InterpolationMode.NEAREST),
104 | ToTensor()])
105 |
106 | # Setup train data
107 | if dataset_name == "voc":
108 | train_data_module = VOCDataModule(batch_size=train_config["batch_size"],
109 | num_workers=config["num_workers"],
110 | train_split="trainaug",
111 | val_split="val",
112 | data_dir=data_config["voc_data_path"],
113 | train_image_transform=train_transforms,
114 | val_image_transform=val_image_transforms,
115 | val_target_transform=val_target_transforms)
116 | elif dataset_name == 'coco':
117 | file_list = os.listdir(os.path.join(data_config["data_dir"], "images/train2017"))
118 | train_data_module = CocoDataModule(batch_size=train_config["batch_size"],
119 | num_workers=config["num_workers"],
120 | file_list=file_list,
121 | data_dir=data_config["data_dir"],
122 | train_transforms=train_transforms,
123 | val_transforms=None)
124 | elif 'movi' in dataset_name:
125 | train_data_module = MOViDataModule(data_dir=data_config["data_dir"],
126 | dataset_name=data_config['dataset_name'],
127 | batch_size=train_config["batch_size"],
128 | num_workers=config["num_workers"],
129 | train_split="frames",
130 | val_split="images",
131 | train_image_transform=train_transforms,
132 | val_image_transform=val_image_transforms,
133 | val_target_transform=val_target_transforms)
134 | else:
135 | raise ValueError(f"Data set {dataset_name} not supported")
136 |
137 | model_path = set_path(config)
138 |
139 | model = AlignSegmentor(arch=train_config['arch'],
140 | patch_size=train_config['patch_size'],
141 | embed_dim=train_config['embed_dim'],
142 | hidden_dim=train_config['hidden_dim'],
143 | num_heads=train_config['decoder_num_heads'],
144 | num_queries=train_config['num_queries'],
145 | nmb_crops=data_config["nmb_crops"],
146 | num_decode_layers=train_config['num_decode_layers'],
147 | last_self_attention=train_config['last_self_attention'])
148 | model = model.to(cuda)
149 |
150 | criterion = AlignCriterion(patch_size=train_config['patch_size'],
151 | num_queries=train_config['num_queries'],
152 | nmb_crops=data_config["nmb_crops"],
153 | roi_align_kernel_size=train_config['roi_align_kernel_size'],
154 | ce_temperature=train_config['ce_temperature'],
155 | negative_pressure=train_config['negative_pressure'],
156 | last_self_attention=train_config['last_self_attention'])
157 | criterion = criterion.to(cuda)
158 |
159 | # Initialize model
160 | start_epoch = 0
161 | if train_config["checkpoint"] is not None:
162 | checkpoint = torch.load(train_config["checkpoint"])
163 | start_epoch = checkpoint['epoch']
164 | msg = model.load_state_dict(checkpoint["state_dict"], strict=True)
165 | print(msg)
166 | elif train_config["checkpoint"] is None \
167 | and train_config["pretrained_model"] is not None \
168 | and train_config["prototype_queries"] is not None:
169 | # initialize model with pre-trained ViT and prepared Prototypes
170 | pretrained_model = torch.load(train_config["pretrained_model"], map_location=torch.device('cpu'))
171 | neq_load_external(model, pretrained_model)
172 | protos = torch.load(train_config["prototype_queries"]).to(cuda)
173 | model.set_clsQuery(protos)
174 | elif train_config["checkpoint"] is None \
175 | and train_config["pretrained_model"] is not None \
176 | and train_config["prototype_queries"] is None:
177 | # only load pre-trained ViT
178 | pretrained_model = torch.load(train_config["pretrained_model"], map_location=torch.device('cpu'))
179 | neq_load_external(model, pretrained_model)
180 |
181 | # Optionally fix ViT, Queries
182 | optimizer, scheduler = configure_optimizers(model, train_config)
183 | dataloader = train_data_module.train_dataloader()
184 |
185 | for epoch in range(start_epoch, train_config['max_epochs']):
186 |
187 | train(dataloader, model, optimizer, criterion, epoch)
188 |
189 | scheduler.step()
190 | print('\t Epoch: ', epoch, 'with lr: ', scheduler.get_last_lr())
191 |
192 | if epoch % train_config['save_checkpoint_every_n_epochs'] == 0:
193 | # save check_point
194 | save_checkpoint({'epoch': epoch + 1,
195 | 'net': train_config['arch'],
196 | 'state_dict': model.state_dict(),
197 | 'optimizer': optimizer.state_dict(),
198 | }, gap=train_config['save_checkpoint_every_n_epochs'],
199 | filename=os.path.join(model_path, 'epoch%s.pth' % str(epoch + 1)), keep_all=False)
200 |
201 | print('Training %d epochs finished' % (train_config['max_epochs']))
202 |
203 |
204 | def train(data_loader, model, optimizer, criterion, epoch):
205 | losses = AverageMeter()
206 | model.train()
207 |
208 | for idx, batch in enumerate(data_loader):
209 | inputs, bboxes = batch # inputs = [sum(num_crops), (B, 3, w, h)]
210 | B = inputs[0].size(0)
211 | tic = time.time()
212 | for i in range(len(inputs)):
213 | inputs[i] = inputs[i].to(cuda, non_blocking=True)
214 | bboxes['gc'] = bboxes['gc'].to(cuda, non_blocking=True)
215 | bboxes['all'] = bboxes['all'].to(cuda, non_blocking=True)
216 |
217 | results = model(inputs)
218 |
219 | # Calculate loss
220 | loss = criterion(results, bboxes)
221 | losses.update(loss.item(), B, step=len(data_loader))
222 |
223 | optimizer.zero_grad()
224 | loss.backward()
225 | optimizer.step()
226 |
227 | if idx % 1 == 0:
228 | print('Epoch: [{0}][{1}/{2}]\t'
229 | 'Loss {loss.val:.4f} ({loss.local_avg:.4f}) Time:{3:.2f}\t'.
230 | format(epoch, idx, len(data_loader), time.time() - tic, loss=losses))
231 |
232 | return losses.local_avg
233 |
234 |
235 | if __name__ == '__main__':
236 | parser = argparse.ArgumentParser()
237 | parser.add_argument('--config_path', default='./configs/train_voc_config.yml', type=str)
238 | parser.add_argument('--gpu', default='0', type=str)
239 |
240 | args = parser.parse_args()
241 |
242 | os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
243 | cuda = torch.device('cuda')
244 | start_train()
245 |
--------------------------------------------------------------------------------
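As a quick way to check the optimizer wiring above, the sketch below freezes the backbone and prints the resulting parameter groups. It is not part of the repository; the `train_config` dict holds hypothetical values and only fills in the keys that `configure_optimizers` actually reads, and it assumes `train.py` and its data modules import cleanly in your environment.

from model.align_model import AlignSegmentor
from train import configure_optimizers

model = AlignSegmentor(arch='vit_small', patch_size=16, embed_dim=384, num_queries=21)

train_config = {                 # hypothetical values; only keys read by configure_optimizers
    'fix_prototypes': False,
    'exclude_norm_bias': True,
    'weight_decay': 0.01,
    'lr_decoder': 1e-4,
}
optimizer, scheduler = configure_optimizers(model, train_config)

# Backbone parameters are frozen; decoder/query parameters land in three AdamW groups:
# regular weights (with weight decay), biases + norm parameters (no decay), and clsQueries.
n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('trainable params:', n_trainable)
for i, g in enumerate(optimizer.param_groups):
    print(i, 'lr =', g['lr'], 'wd =', g['weight_decay'], 'tensors =', len(g['params']))
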
/data/stuffthing_2017.json:
--------------------------------------------------------------------------------
1 | {
2 | "info": {"description": "COCO 2017 Stuff Dataset", "url": "http://cocodataset.org", "version": "1.0", "year": 2017, "contributor": "H. Caesar, J. Uijlings, M. Maire, T.-Y. Lin, P. Dollar and V. Ferrari", "date_created": "2017-08-31 00:00:00.0"},
3 | "categories": [{"supercategory": "person", "isthing": 1, "id": 1, "name": "person"}, {"supercategory": "vehicle", "isthing": 1, "id": 2, "name": "bicycle"}, {"supercategory": "vehicle", "isthing": 1, "id": 3, "name": "car"}, {"supercategory": "vehicle", "isthing": 1, "id": 4, "name": "motorcycle"}, {"supercategory": "vehicle", "isthing": 1, "id": 5, "name": "airplane"}, {"supercategory": "vehicle", "isthing": 1, "id": 6, "name": "bus"}, {"supercategory": "vehicle", "isthing": 1, "id": 7, "name": "train"}, {"supercategory": "vehicle", "isthing": 1, "id": 8, "name": "truck"}, {"supercategory": "vehicle", "isthing": 1, "id": 9, "name": "boat"}, {"supercategory": "outdoor", "isthing": 1, "id": 10, "name": "traffic light"}, {"supercategory": "outdoor", "isthing": 1, "id": 11, "name": "fire hydrant"}, {"supercategory": "outdoor", "isthing": 1, "id": 13, "name": "stop sign"}, {"supercategory": "outdoor", "isthing": 1, "id": 14, "name": "parking meter"}, {"supercategory": "outdoor", "isthing": 1, "id": 15, "name": "bench"}, {"supercategory": "animal", "isthing": 1, "id": 16, "name": "bird"}, {"supercategory": "animal", "isthing": 1, "id": 17, "name": "cat"}, {"supercategory": "animal", "isthing": 1, "id": 18, "name": "dog"}, {"supercategory": "animal", "isthing": 1, "id": 19, "name": "horse"}, {"supercategory": "animal", "isthing": 1, "id": 20, "name": "sheep"}, {"supercategory": "animal", "isthing": 1, "id": 21, "name": "cow"}, {"supercategory": "animal", "isthing": 1, "id": 22, "name": "elephant"}, {"supercategory": "animal", "isthing": 1, "id": 23, "name": "bear"}, {"supercategory": "animal", "isthing": 1, "id": 24, "name": "zebra"}, {"supercategory": "animal", "isthing": 1, "id": 25, "name": "giraffe"}, {"supercategory": "accessory", "isthing": 1, "id": 27, "name": "backpack"}, {"supercategory": "accessory", "isthing": 1, "id": 28, "name": "umbrella"}, {"supercategory": "accessory", "isthing": 1, "id": 31, "name": "handbag"}, {"supercategory": "accessory", "isthing": 1, "id": 32, "name": "tie"}, {"supercategory": "accessory", "isthing": 1, "id": 33, "name": "suitcase"}, {"supercategory": "sports", "isthing": 1, "id": 34, "name": "frisbee"}, {"supercategory": "sports", "isthing": 1, "id": 35, "name": "skis"}, {"supercategory": "sports", "isthing": 1, "id": 36, "name": "snowboard"}, {"supercategory": "sports", "isthing": 1, "id": 37, "name": "sports ball"}, {"supercategory": "sports", "isthing": 1, "id": 38, "name": "kite"}, {"supercategory": "sports", "isthing": 1, "id": 39, "name": "baseball bat"}, {"supercategory": "sports", "isthing": 1, "id": 40, "name": "baseball glove"}, {"supercategory": "sports", "isthing": 1, "id": 41, "name": "skateboard"}, {"supercategory": "sports", "isthing": 1, "id": 42, "name": "surfboard"}, {"supercategory": "sports", "isthing": 1, "id": 43, "name": "tennis racket"}, {"supercategory": "kitchen", "isthing": 1, "id": 44, "name": "bottle"}, {"supercategory": "kitchen", "isthing": 1, "id": 46, "name": "wine glass"}, {"supercategory": "kitchen", "isthing": 1, "id": 47, "name": "cup"}, {"supercategory": "kitchen", "isthing": 1, "id": 48, "name": "fork"}, {"supercategory": "kitchen", "isthing": 1, "id": 49, "name": "knife"}, {"supercategory": "kitchen", "isthing": 1, "id": 50, "name": "spoon"}, {"supercategory": "kitchen", "isthing": 1, "id": 51, "name": "bowl"}, {"supercategory": "food", "isthing": 1, "id": 52, "name": "banana"}, {"supercategory": "food", "isthing": 1, "id": 53, "name": "apple"}, {"supercategory": "food", "isthing": 1, "id": 54, "name": 
"sandwich"}, {"supercategory": "food", "isthing": 1, "id": 55, "name": "orange"}, {"supercategory": "food", "isthing": 1, "id": 56, "name": "broccoli"}, {"supercategory": "food", "isthing": 1, "id": 57, "name": "carrot"}, {"supercategory": "food", "isthing": 1, "id": 58, "name": "hot dog"}, {"supercategory": "food", "isthing": 1, "id": 59, "name": "pizza"}, {"supercategory": "food", "isthing": 1, "id": 60, "name": "donut"}, {"supercategory": "food", "isthing": 1, "id": 61, "name": "cake"}, {"supercategory": "furniture", "isthing": 1, "id": 62, "name": "chair"}, {"supercategory": "furniture", "isthing": 1, "id": 63, "name": "couch"}, {"supercategory": "furniture", "isthing": 1, "id": 64, "name": "potted plant"}, {"supercategory": "furniture", "isthing": 1, "id": 65, "name": "bed"}, {"supercategory": "furniture", "isthing": 1, "id": 67, "name": "dining table"}, {"supercategory": "furniture", "isthing": 1, "id": 70, "name": "toilet"}, {"supercategory": "electronic", "isthing": 1, "id": 72, "name": "tv"}, {"supercategory": "electronic", "isthing": 1, "id": 73, "name": "laptop"}, {"supercategory": "electronic", "isthing": 1, "id": 74, "name": "mouse"}, {"supercategory": "electronic", "isthing": 1, "id": 75, "name": "remote"}, {"supercategory": "electronic", "isthing": 1, "id": 76, "name": "keyboard"}, {"supercategory": "electronic", "isthing": 1, "id": 77, "name": "cell phone"}, {"supercategory": "appliance", "isthing": 1, "id": 78, "name": "microwave"}, {"supercategory": "appliance", "isthing": 1, "id": 79, "name": "oven"}, {"supercategory": "appliance", "isthing": 1, "id": 80, "name": "toaster"}, {"supercategory": "appliance", "isthing": 1, "id": 81, "name": "sink"}, {"supercategory": "appliance", "isthing": 1, "id": 82, "name": "refrigerator"}, {"supercategory": "indoor", "isthing": 1, "id": 84, "name": "book"}, {"supercategory": "indoor", "isthing": 1, "id": 85, "name": "clock"}, {"supercategory": "indoor", "isthing": 1, "id": 86, "name": "vase"}, {"supercategory": "indoor", "isthing": 1, "id": 87, "name": "scissors"}, {"supercategory": "indoor", "isthing": 1, "id": 88, "name": "teddy bear"}, {"supercategory": "indoor", "isthing": 1, "id": 89, "name": "hair drier"}, {"supercategory": "indoor", "isthing": 1, "id": 90, "name": "toothbrush"}, {"supercategory": "textile", "isthing": 0, "id": 92, "name": "banner"}, {"supercategory": "textile", "isthing": 0, "id": 93, "name": "blanket"}, {"supercategory": "plant", "isthing": 0, "id": 94, "name": "branch"}, {"supercategory": "building", "isthing": 0, "id": 95, "name": "bridge"}, {"supercategory": "building", "isthing": 0, "id": 96, "name": "building-other"}, {"supercategory": "plant", "isthing": 0, "id": 97, "name": "bush"}, {"supercategory": "furniture-stuff", "isthing": 0, "id": 98, "name": "cabinet"}, {"supercategory": "structural", "isthing": 0, "id": 99, "name": "cage"}, {"supercategory": "raw-material", "isthing": 0, "id": 100, "name": "cardboard"}, {"supercategory": "floor", "isthing": 0, "id": 101, "name": "carpet"}, {"supercategory": "ceiling", "isthing": 0, "id": 102, "name": "ceiling-other"}, {"supercategory": "ceiling", "isthing": 0, "id": 103, "name": "ceiling-tile"}, {"supercategory": "textile", "isthing": 0, "id": 104, "name": "cloth"}, {"supercategory": "textile", "isthing": 0, "id": 105, "name": "clothes"}, {"supercategory": "sky", "isthing": 0, "id": 106, "name": "clouds"}, {"supercategory": "furniture-stuff", "isthing": 0, "id": 107, "name": "counter"}, {"supercategory": "furniture-stuff", "isthing": 0, "id": 108, "name": 
"cupboard"}, {"supercategory": "textile", "isthing": 0, "id": 109, "name": "curtain"}, {"supercategory": "furniture-stuff", "isthing": 0, "id": 110, "name": "desk-stuff"}, {"supercategory": "ground", "isthing": 0, "id": 111, "name": "dirt"}, {"supercategory": "furniture-stuff", "isthing": 0, "id": 112, "name": "door-stuff"}, {"supercategory": "structural", "isthing": 0, "id": 113, "name": "fence"}, {"supercategory": "floor", "isthing": 0, "id": 114, "name": "floor-marble"}, {"supercategory": "floor", "isthing": 0, "id": 115, "name": "floor-other"}, {"supercategory": "floor", "isthing": 0, "id": 116, "name": "floor-stone"}, {"supercategory": "floor", "isthing": 0, "id": 117, "name": "floor-tile"}, {"supercategory": "floor", "isthing": 0, "id": 118, "name": "floor-wood"}, {"supercategory": "plant", "isthing": 0, "id": 119, "name": "flower"}, {"supercategory": "water", "isthing": 0, "id": 120, "name": "fog"}, {"supercategory": "food-stuff", "isthing": 0, "id": 121, "name": "food-other"}, {"supercategory": "food-stuff", "isthing": 0, "id": 122, "name": "fruit"}, {"supercategory": "furniture-stuff", "isthing": 0, "id": 123, "name": "furniture-other"}, {"supercategory": "plant", "isthing": 0, "id": 124, "name": "grass"}, {"supercategory": "ground", "isthing": 0, "id": 125, "name": "gravel"}, {"supercategory": "ground", "isthing": 0, "id": 126, "name": "ground-other"}, {"supercategory": "solid", "isthing": 0, "id": 127, "name": "hill"}, {"supercategory": "building", "isthing": 0, "id": 128, "name": "house"}, {"supercategory": "plant", "isthing": 0, "id": 129, "name": "leaves"}, {"supercategory": "furniture-stuff", "isthing": 0, "id": 130, "name": "light"}, {"supercategory": "textile", "isthing": 0, "id": 131, "name": "mat"}, {"supercategory": "raw-material", "isthing": 0, "id": 132, "name": "metal"}, {"supercategory": "furniture-stuff", "isthing": 0, "id": 133, "name": "mirror-stuff"}, {"supercategory": "plant", "isthing": 0, "id": 134, "name": "moss"}, {"supercategory": "solid", "isthing": 0, "id": 135, "name": "mountain"}, {"supercategory": "ground", "isthing": 0, "id": 136, "name": "mud"}, {"supercategory": "textile", "isthing": 0, "id": 137, "name": "napkin"}, {"supercategory": "structural", "isthing": 0, "id": 138, "name": "net"}, {"supercategory": "raw-material", "isthing": 0, "id": 139, "name": "paper"}, {"supercategory": "ground", "isthing": 0, "id": 140, "name": "pavement"}, {"supercategory": "textile", "isthing": 0, "id": 141, "name": "pillow"}, {"supercategory": "plant", "isthing": 0, "id": 142, "name": "plant-other"}, {"supercategory": "raw-material", "isthing": 0, "id": 143, "name": "plastic"}, {"supercategory": "ground", "isthing": 0, "id": 144, "name": "platform"}, {"supercategory": "ground", "isthing": 0, "id": 145, "name": "playingfield"}, {"supercategory": "structural", "isthing": 0, "id": 146, "name": "railing"}, {"supercategory": "ground", "isthing": 0, "id": 147, "name": "railroad"}, {"supercategory": "water", "isthing": 0, "id": 148, "name": "river"}, {"supercategory": "ground", "isthing": 0, "id": 149, "name": "road"}, {"supercategory": "solid", "isthing": 0, "id": 150, "name": "rock"}, {"supercategory": "building", "isthing": 0, "id": 151, "name": "roof"}, {"supercategory": "textile", "isthing": 0, "id": 152, "name": "rug"}, {"supercategory": "food-stuff", "isthing": 0, "id": 153, "name": "salad"}, {"supercategory": "ground", "isthing": 0, "id": 154, "name": "sand"}, {"supercategory": "water", "isthing": 0, "id": 155, "name": "sea"}, {"supercategory": "furniture-stuff", 
"isthing": 0, "id": 156, "name": "shelf"}, {"supercategory": "sky", "isthing": 0, "id": 157, "name": "sky-other"}, {"supercategory": "building", "isthing": 0, "id": 158, "name": "skyscraper"}, {"supercategory": "ground", "isthing": 0, "id": 159, "name": "snow"}, {"supercategory": "solid", "isthing": 0, "id": 160, "name": "solid-other"}, {"supercategory": "furniture-stuff", "isthing": 0, "id": 161, "name": "stairs"}, {"supercategory": "solid", "isthing": 0, "id": 162, "name": "stone"}, {"supercategory": "plant", "isthing": 0, "id": 163, "name": "straw"}, {"supercategory": "structural", "isthing": 0, "id": 164, "name": "structural-other"}, {"supercategory": "furniture-stuff", "isthing": 0, "id": 165, "name": "table"}, {"supercategory": "building", "isthing": 0, "id": 166, "name": "tent"}, {"supercategory": "textile", "isthing": 0, "id": 167, "name": "textile-other"}, {"supercategory": "textile", "isthing": 0, "id": 168, "name": "towel"}, {"supercategory": "plant", "isthing": 0, "id": 169, "name": "tree"}, {"supercategory": "food-stuff", "isthing": 0, "id": 170, "name": "vegetable"}, {"supercategory": "wall", "isthing": 0, "id": 171, "name": "wall-brick"}, {"supercategory": "wall", "isthing": 0, "id": 172, "name": "wall-concrete"}, {"supercategory": "wall", "isthing": 0, "id": 173, "name": "wall-other"}, {"supercategory": "wall", "isthing": 0, "id": 174, "name": "wall-panel"}, {"supercategory": "wall", "isthing": 0, "id": 175, "name": "wall-stone"}, {"supercategory": "wall", "isthing": 0, "id": 176, "name": "wall-tile"}, {"supercategory": "wall", "isthing": 0, "id": 177, "name": "wall-wood"}, {"supercategory": "water", "isthing": 0, "id": 178, "name": "water-other"}, {"supercategory": "water", "isthing": 0, "id": 179, "name": "waterdrops"}, {"supercategory": "window", "isthing": 0, "id": 180, "name": "window-blind"}, {"supercategory": "window", "isthing": 0, "id": 181, "name": "window-other"}, {"supercategory": "solid", "isthing": 0, "id": 182, "name": "wood"}, {"supercategory": "other", "isthing": 0, "id": 183, "name": "other"}]
4 | }
--------------------------------------------------------------------------------
/model/vit.py:
--------------------------------------------------------------------------------
1 | import math
2 | import warnings
3 | import torch
4 | import torch.nn as nn
5 |
6 | from functools import partial
7 |
8 |
9 | def _no_grad_trunc_normal_(tensor, mean, std, a, b):
10 | def norm_cdf(x):
11 | # Computes standard normal cumulative distribution function
12 | return (1. + math.erf(x / math.sqrt(2.))) / 2.
13 |
14 | if (mean < a - 2 * std) or (mean > b + 2 * std):
15 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
16 | "The distribution of values may be incorrect.",
17 | stacklevel=2)
18 |
19 | with torch.no_grad():
20 | # Values are generated by using a truncated uniform distribution and
21 | # then using the inverse CDF for the normal distribution.
22 | # Get upper and lower cdf values
23 | l = norm_cdf((a - mean) / std)
24 | u = norm_cdf((b - mean) / std)
25 |
26 | # Uniformly fill tensor with values from [l, u], then translate to
27 | # [2l-1, 2u-1].
28 | tensor.uniform_(2 * l - 1, 2 * u - 1)
29 |
30 | # Use inverse cdf transform for normal distribution to get truncated
31 | # standard normal
32 | tensor.erfinv_()
33 |
34 | # Transform to proper mean, std
35 | tensor.mul_(std * math.sqrt(2.))
36 | tensor.add_(mean)
37 |
38 | # Clamp to ensure it's in the proper range
39 | tensor.clamp_(min=a, max=b)
40 | return tensor
41 |
42 |
43 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
44 | # type: (Tensor, float, float, float, float) -> Tensor
45 | return _no_grad_trunc_normal_(tensor, mean, std, a, b)
46 |
47 |
48 | def drop_path(x, drop_prob: float = 0., training: bool = False):
49 | if drop_prob == 0. or not training:
50 | return x
51 | keep_prob = 1 - drop_prob
52 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
53 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
54 | random_tensor.floor_() # binarize
55 | output = x.div(keep_prob) * random_tensor
56 | return output
57 |
58 |
59 | class DropPath(nn.Module):
60 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
61 | """
62 |
63 | def __init__(self, drop_prob=None):
64 | super(DropPath, self).__init__()
65 | self.drop_prob = drop_prob
66 |
67 | def forward(self, x):
68 | return drop_path(x, self.drop_prob, self.training)
69 |
70 |
71 | class Mlp(nn.Module):
72 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
73 | super().__init__()
74 | out_features = out_features or in_features
75 | hidden_features = hidden_features or in_features
76 | self.fc1 = nn.Linear(in_features, hidden_features)
77 | self.act = act_layer()
78 | self.fc2 = nn.Linear(hidden_features, out_features)
79 | self.drop = nn.Dropout(drop)
80 |
81 | def forward(self, x):
82 | x = self.fc1(x)
83 | x = self.act(x)
84 | x = self.drop(x)
85 | x = self.fc2(x)
86 | x = self.drop(x)
87 | return x
88 |
89 |
90 | class Attention(nn.Module):
91 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
92 | super().__init__()
93 | self.num_heads = num_heads
94 | head_dim = dim // num_heads
95 | self.scale = qk_scale or head_dim ** -0.5
96 |
97 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
98 | self.attn_drop = nn.Dropout(attn_drop)
99 | self.proj = nn.Linear(dim, dim)
100 | self.proj_drop = nn.Dropout(proj_drop)
101 |
102 | def forward(self, x):
103 | B, N, C = x.shape
104 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
105 | q, k, v = qkv[0], qkv[1], qkv[2]
106 |
107 | attn = (q @ k.transpose(-2, -1)) * self.scale
108 | attn = attn.softmax(dim=-1)
109 | attn = self.attn_drop(attn)
110 |
111 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
112 | x = self.proj(x)
113 | x = self.proj_drop(x)
114 | return x, attn
115 |
116 |
117 | class Block(nn.Module):
118 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
119 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
120 | super().__init__()
121 | self.norm1 = norm_layer(dim)
122 | self.attn = Attention(
123 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
124 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
125 | self.norm2 = norm_layer(dim)
126 | mlp_hidden_dim = int(dim * mlp_ratio)
127 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
128 |
129 | def forward(self, x, return_attention=False):
130 | y, attn = self.attn(self.norm1(x))
131 | x = x + self.drop_path(y)
132 | x = x + self.drop_path(self.mlp(self.norm2(x)))
133 | if return_attention:
134 | return x, attn
135 | return x
136 |
137 |
138 | class PatchEmbed(nn.Module):
139 | """ Image to Patch Embedding
140 | """
141 |
142 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
143 | super().__init__()
144 | num_patches = (img_size // patch_size) * (img_size // patch_size)
145 | self.img_size = img_size
146 | self.patch_size = patch_size
147 | self.num_patches = num_patches
148 |
149 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
150 |
151 | def forward(self, x):
152 | B, C, H, W = x.shape
153 | x = self.proj(x).flatten(2).transpose(1, 2)
154 | return x
155 |
156 |
157 | class VisionTransformer(nn.Module):
158 | """ Vision Transformer """
159 |
160 | def __init__(self, img_size=[224], patch_size=16, in_chans=3, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.,
161 | qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
162 | norm_layer=nn.LayerNorm):
163 | super().__init__()
164 | self.num_features = self.embed_dim = embed_dim
165 |
166 | self.patch_embed = PatchEmbed(
167 | img_size=img_size[0], patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
168 | num_patches = self.patch_embed.num_patches
169 |
170 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
171 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
172 | self.pos_drop = nn.Dropout(p=drop_rate)
173 |
174 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
175 | self.blocks = nn.ModuleList([
176 | Block(
177 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
178 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
179 | for i in range(depth)])
180 | self.norm = norm_layer(embed_dim)
181 |
182 | trunc_normal_(self.pos_embed, std=.02)
183 | trunc_normal_(self.cls_token, std=.02)
184 | self.apply(self._init_weights)
185 |
186 | def _init_weights(self, m):
187 | if isinstance(m, nn.Linear):
188 | trunc_normal_(m.weight, std=.02)
189 | if m.bias is not None:
190 | nn.init.constant_(m.bias, 0)
191 | elif isinstance(m, nn.LayerNorm):
192 | nn.init.constant_(m.bias, 0)
193 | nn.init.constant_(m.weight, 1.0)
194 | elif isinstance(m, nn.Conv2d):
195 | trunc_normal_(m.weight, std=.02)
196 | if m.bias is not None:
197 | nn.init.constant_(m.bias, 0)
198 |
199 | def interpolate_pos_encoding(self, x, w, h):
200 | npatch = x.shape[1] - 1
201 | N = self.pos_embed.shape[1] - 1
202 | if npatch == N and w == h:
203 | return self.pos_embed
204 | class_pos_embed = self.pos_embed[:, 0]
205 | patch_pos_embed = self.pos_embed[:, 1:]
206 | dim = x.shape[-1]
207 | w0 = w // self.patch_embed.patch_size
208 | h0 = h // self.patch_embed.patch_size
209 | # we add a small number to avoid floating point error in the interpolation
210 | w0, h0 = w0 + 0.1, h0 + 0.1
211 | patch_pos_embed = nn.functional.interpolate(
212 | patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
213 | scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
214 | mode='bicubic',
215 | )
216 | assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
217 | patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
218 | return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
219 |
220 | def get_intermediate_layers(self, x, n=1):
221 | x = self.prepare_tokens(x)
222 | # we return the output tokens from the `n` last blocks
223 | output = []
224 | for i, blk in enumerate(self.blocks):
225 | x = blk(x)
226 | if len(self.blocks) - i <= n:
227 | output.append(self.norm(x))
228 | return output
229 |
230 | def prepare_tokens(self, x):
231 | B, nc, w, h = x.shape
232 | x = self.patch_embed(x) # patch linear embedding
233 |
234 | # add the [CLS] token to the embed patch tokens
235 | cls_tokens = self.cls_token.expand(B, -1, -1)
236 | x = torch.cat((cls_tokens, x), dim=1)
237 |
238 | # add positional encoding to each token
239 | x = x + self.interpolate_pos_encoding(x, w, h)
240 |
241 | return self.pos_drop(x)
242 |
243 | @torch.no_grad()
244 |     def forward(self, inputs, nmb_crops=(1, 0), last_self_attention=False):
245 | if not isinstance(inputs, list):
246 | inputs = [inputs]
247 | idx_crops = [1, ] # for inference
248 | if sum(nmb_crops) > 1:
249 | # for training
250 | idx_crops.append(sum(nmb_crops))
251 |
252 |         assert len(idx_crops) <= 2, "Only supporting at most two different types of crops (global and local crops)"
253 | start_idx = 0
254 | for end_idx in idx_crops:
255 | _out = torch.cat(inputs[start_idx:end_idx])
256 | _out = self.forward_backbone(_out, last_self_attention=last_self_attention)
257 | if last_self_attention:
258 | _out, _attn = _out
259 | spatial_tokens = _out[:, 1:] # remove CLS token
260 | spatial_tokens = spatial_tokens.reshape(-1, self.embed_dim) # [B*196/36, embed_dim]
261 |
262 | if start_idx == 0:
263 | output_spatial = spatial_tokens
264 | if last_self_attention:
265 | # only keep 1st global crop attention
266 | attentions = _attn
267 | else:
268 | output_spatial = torch.cat((output_spatial, spatial_tokens))
269 | if last_self_attention:
270 | attentions = torch.cat((attentions, _attn))
271 | start_idx = end_idx
272 |
273 | result = output_spatial
274 | if last_self_attention:
275 | result = (result, attentions)
276 | return result
277 |
278 | def forward_backbone(self, x, last_self_attention=False):
279 | x = self.prepare_tokens(x)
280 | for i, blk in enumerate(self.blocks):
281 | if i < len(self.blocks) - 1:
282 | x = blk(x)
283 | else:
284 | x = blk(x, return_attention=last_self_attention)
285 | if last_self_attention:
286 | x, attn = x
287 | x = self.norm(x)
288 | if last_self_attention:
289 |             return x, attn[:, :, 0, 1:]  # CLS-to-patch attention of the last block: [B, heads, num_patches]
290 | return x
291 |
292 | def get_last_selfattention(self, x):
293 | x = self.prepare_tokens(x)
294 | for i, blk in enumerate(self.blocks):
295 | if i < len(self.blocks) - 1:
296 | x = blk(x)
297 | else:
298 | # return attention of the last block
299 | return blk(x, return_attention=True)[1]
300 |
301 | def get_cls_tokens(self, x):
302 | x = self.prepare_tokens(x)
303 | for blk in self.blocks:
304 | x = blk(x)
305 | x = self.norm(x)
306 | return x[:, 0]
307 |
308 |
309 | def vit_small(patch_size=16, **kwargs):
310 | model = VisionTransformer(
311 | patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4,
312 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
313 | return model
314 |
315 |
316 | def vit_base(patch_size=16, **kwargs):
317 | model = VisionTransformer(
318 | patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
319 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
320 | return model
321 |
322 |
323 | def vit_large(patch_size=16, **kwargs):
324 | model = VisionTransformer(
325 | patch_size=patch_size, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
326 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
327 | return model
328 |
329 |
330 | if __name__ == '__main__':
331 | clsQueries = nn.Embedding(2, 6)
332 | print(clsQueries.weight)
333 | trunc_normal_(clsQueries.weight, std=.02)
334 | print(clsQueries.weight, clsQueries.weight.mean(), clsQueries.weight.sum())
335 |
336 | cls = nn.Parameter(torch.zeros(1, 1, 5))
337 | trunc_normal_(cls, std=.02)
338 | print(cls, cls.mean(), cls.sum())
339 |
--------------------------------------------------------------------------------
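Usage note (illustrative sketch, not part of the repository): the ViT backbone above returns flattened spatial tokens with the CLS token removed and, optionally, the CLS-to-patch attention of the last block. The snippet below assumes the file is importable as model.vit (matching the repository tree) and uses random weights and a dummy batch purely to show the output shapes.

# Sketch: extract patch tokens and last-block CLS attention from ViT-S/16 (model/vit.py above).
import torch
from model.vit import vit_small  # assumed import path, per the repository layout

backbone = vit_small(patch_size=16)   # 384-dim ViT-S/16, randomly initialised here
backbone.eval()

images = torch.randn(2, 3, 224, 224)  # dummy batch of two 224x224 RGB images
# forward() is already wrapped in torch.no_grad(); it returns (spatial tokens, attention).
tokens, attn = backbone(images, nmb_crops=(1, 0), last_self_attention=True)

print(tokens.shape)  # torch.Size([392, 384]): 2 images x 196 patch tokens, CLS removed
print(attn.shape)    # torch.Size([2, 6, 196]): per-head CLS-to-patch attention of the last block
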
/evaluate/eval_utils.py:
--------------------------------------------------------------------------------
1 | """Utility functions used in metrics computation."""
2 | from typing import Optional
3 |
4 | import scipy.optimize
5 | import torch
6 | import torchmetrics
7 |
8 |
9 | class ARIMetric(torchmetrics.Metric):
10 | """Computes ARI metric."""
11 |
12 | def __init__(
13 | self,
14 | foreground: bool = True,
15 | convert_target_one_hot: bool = False,
16 | ignore_overlaps: bool = False,
17 | ):
18 | super().__init__()
19 | self.foreground = foreground
20 | self.convert_target_one_hot = convert_target_one_hot
21 | self.ignore_overlaps = ignore_overlaps
22 | self.add_state(
23 | "values", default=torch.tensor(0.0, dtype=torch.float64), dist_reduce_fx="sum"
24 | )
25 | self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
26 |
27 | def update(
28 | self, prediction: torch.Tensor, target: torch.Tensor, ignore: Optional[torch.Tensor] = None
29 | ):
30 | """Update this metric.
31 |
32 | Args:
33 | prediction: Predicted mask of shape (B, C, H, W) or (B, F, C, H, W), where C is the
34 | number of classes.
35 | target: Ground truth mask of shape (B, K, H, W) or (B, F, K, H, W), where K is the
36 | number of classes.
37 |             ignore: Ignore mask of shape (B, 1, H, W) or (B, F, 1, H, W)
38 | """
39 | if prediction.ndim == 5:
40 | # Merge frames, height and width to single dimension.
41 | prediction = prediction.transpose(1, 2).flatten(-3, -1)
42 | target = target.transpose(1, 2).flatten(-3, -1)
43 | if ignore is not None:
44 | ignore = ignore.to(torch.bool).transpose(1, 2).flatten(-3, -1)
45 | elif prediction.ndim == 4:
46 | # Merge height and width to single dimension.
47 | prediction = prediction.flatten(-2, -1)
48 | target = target.flatten(-2, -1)
49 | if ignore is not None:
50 | ignore = ignore.to(torch.bool).flatten(-2, -1)
51 | else:
52 |             raise ValueError(f"Incorrect input shape: {prediction.shape}")
53 |
54 | if self.ignore_overlaps:
55 | overlaps = (target > 0).sum(1, keepdim=True) > 1
56 | if ignore is None:
57 | ignore = overlaps
58 | else:
59 | ignore = ignore | overlaps
60 |
61 | if ignore is not None:
62 | assert ignore.ndim == 3 and ignore.shape[1] == 1
63 | prediction = prediction.clone()
64 | prediction[ignore.expand_as(prediction)] = 0
65 | target = target.clone()
66 | target[ignore.expand_as(target)] = 0
67 |
68 | # Make channels / gt labels the last dimension.
69 | prediction = prediction.transpose(-2, -1)
70 | target = target.transpose(-2, -1)
71 |
72 | if self.convert_target_one_hot:
73 | target_oh = tensor_to_one_hot(target, dim=2)
74 | # For empty pixels (all values zero), one-hot assigns 1 to the first class, correct for
75 | # this (then it is technically not one-hot anymore).
76 | target_oh[:, :, 0][target.sum(dim=2) == 0] = 0
77 | target = target_oh
78 |
79 | # Should be either 0 (empty, padding) or 1 (single object).
80 | assert torch.all(target.sum(dim=-1) < 2), "Issues with target format, mask non-exclusive"
81 |
82 | if self.foreground:
83 | ari = fg_adjusted_rand_index(prediction, target)
84 | else:
85 | ari = adjusted_rand_index(prediction, target)
86 |
87 |         print("\tupdating ari... ", ari.mean().item())
88 |
89 | self.values += ari.sum()
90 | self.total += len(ari)
91 |
92 | def compute(self) -> torch.Tensor:
93 | return self.values / self.total
94 |
95 |
96 | class UnsupervisedMaskIoUMetric(torchmetrics.Metric):
97 | """Computes IoU metric for segmentation masks when correspondences to ground truth are not known.
98 |
99 | Uses Hungarian matching to compute the assignment between predicted classes and ground truth
100 | classes.
101 |
102 | Args:
103 | use_threshold: If `True`, convert predicted class probabilities to mask using a threshold.
104 |             If `False`, class probabilities are turned into a mask by taking the argmax instead.
105 | threshold: Value to use for thresholding masks.
106 | matching: Approach to match predicted to ground truth classes. For "hungarian", computes
107 | assignment that maximizes total IoU between all classes. For "best_overlap", uses the
108 | predicted class with maximum overlap for each ground truth class. Using "best_overlap"
109 | leads to the "average best overlap" metric.
110 | compute_discovery_fraction: Instead of the IoU, compute the fraction of ground truth classes
111 | that were "discovered", meaning that they have an IoU greater than some threshold.
112 | correct_localization: Instead of the IoU, compute the fraction of images on which at least
113 |             one ground truth class was correctly localized, meaning that it has an IoU
114 | greater than some threshold.
115 | discovery_threshold: Minimum IoU to count a class as discovered/correctly localized.
116 | ignore_background: If true, assume class at index 0 of ground truth masks is background class
117 | that is removed before computing IoU.
118 |         ignore_overlaps: If true, remove points where ground truth masks have overlapping classes from
119 | predictions and ground truth masks.
120 | """
121 |
122 | def __init__(
123 | self,
124 | use_threshold: bool = False,
125 | threshold: float = 0.5,
126 | matching: str = "hungarian",
127 | compute_discovery_fraction: bool = False,
128 | correct_localization: bool = False,
129 | discovery_threshold: float = 0.5,
130 | ignore_background: bool = False,
131 | ignore_overlaps: bool = False,
132 | ):
133 | super().__init__()
134 | self.use_threshold = use_threshold
135 | self.threshold = threshold
136 | self.discovery_threshold = discovery_threshold
137 | self.compute_discovery_fraction = compute_discovery_fraction
138 | self.correct_localization = correct_localization
139 | if compute_discovery_fraction and correct_localization:
140 | raise ValueError(
141 | "Only one of `compute_discovery_fraction` and `correct_localization` can be enabled."
142 | )
143 |
144 | matchings = ("hungarian", "best_overlap")
145 | if matching not in matchings:
146 | raise ValueError(f"Unknown matching type {matching}. Valid values are {matchings}.")
147 | self.matching = matching
148 | self.ignore_background = ignore_background
149 | self.ignore_overlaps = ignore_overlaps
150 |
151 | self.add_state(
152 | "values", default=torch.tensor(0.0, dtype=torch.float64), dist_reduce_fx="sum"
153 | )
154 | self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
155 |
156 | def update(
157 | self, prediction: torch.Tensor, target: torch.Tensor, ignore: Optional[torch.Tensor] = None
158 | ):
159 | """Update this metric.
160 |
161 | Args:
162 | prediction: Predicted mask of shape (B, C, H, W) or (B, F, C, H, W), where C is the
163 | number of classes. Assumes class probabilities as inputs.
164 | target: Ground truth mask of shape (B, K, H, W) or (B, F, K, H, W), where K is the
165 | number of classes.
166 |             ignore: Ignore mask of shape (B, 1, H, W) or (B, F, 1, H, W)
167 | """
168 | if prediction.ndim == 5:
169 | # Merge frames, height and width to single dimension.
170 | predictions = prediction.transpose(1, 2).flatten(-3, -1)
171 | targets = target.transpose(1, 2).flatten(-3, -1)
172 | if ignore is not None:
173 | ignore = ignore.to(torch.bool).transpose(1, 2).flatten(-3, -1)
174 | elif prediction.ndim == 4:
175 | # Merge height and width to single dimension.
176 | predictions = prediction.flatten(-2, -1)
177 | targets = target.flatten(-2, -1)
178 | if ignore is not None:
179 | ignore = ignore.to(torch.bool).flatten(-2, -1)
180 | else:
181 |             raise ValueError(f"Incorrect input shape: {prediction.shape}")
182 |
183 | if self.use_threshold:
184 | predictions = predictions > self.threshold
185 | else:
186 | indices = torch.argmax(predictions, dim=1)
187 | predictions = torch.nn.functional.one_hot(indices, num_classes=predictions.shape[1])
188 | predictions = predictions.transpose(1, 2)
189 |
190 | if self.ignore_background:
191 | targets = targets[:, 1:]
192 |
193 | targets = targets > 0 # Ensure masks are binary
194 |
195 | if self.ignore_overlaps:
196 | overlaps = targets.sum(1, keepdim=True) > 1
197 | if ignore is None:
198 | ignore = overlaps
199 | else:
200 | ignore = ignore | overlaps
201 |
202 | if ignore is not None:
203 | assert ignore.ndim == 3 and ignore.shape[1] == 1
204 | predictions[ignore.expand_as(predictions)] = 0
205 | targets[ignore.expand_as(targets)] = 0
206 |
207 | # Should be either 0 (empty, padding) or 1 (single object).
208 | assert torch.all(targets.sum(dim=1) < 2), "Issues with target format, mask non-exclusive"
209 |
210 | for pred, target in zip(predictions, targets):
211 | nonzero_classes = torch.sum(target, dim=-1) > 0
212 | target = target[nonzero_classes] # Remove empty (e.g. padded) classes
213 | if len(target) == 0:
214 | continue # Skip elements without any target mask
215 |
216 | iou_per_class = unsupervised_mask_iou(
217 | pred, target, matching=self.matching, reduction="none"
218 | )
219 |
220 | if self.compute_discovery_fraction:
221 | discovered = iou_per_class > self.discovery_threshold
222 | self.values += discovered.sum() / len(discovered)
223 | elif self.correct_localization:
224 | correctly_localized = torch.any(iou_per_class > self.discovery_threshold)
225 | self.values += correctly_localized.sum()
226 | else:
227 | self.values += iou_per_class.mean()
228 | self.total += 1
229 |
230 | def compute(self) -> torch.Tensor:
231 | if self.total == 0:
232 | return torch.zeros_like(self.values)
233 | else:
234 | return self.values / self.total
235 |
236 |
237 | class MaskCorLocMetric(UnsupervisedMaskIoUMetric):
238 | def __init__(self, **kwargs):
239 | super().__init__(matching="best_overlap", correct_localization=True, **kwargs)
240 |
241 |
242 | class AverageBestOverlapMetric(UnsupervisedMaskIoUMetric):
243 | def __init__(self, **kwargs):
244 | super().__init__(matching="best_overlap", **kwargs)
245 |
246 |
247 | class BestOverlapObjectRecoveryMetric(UnsupervisedMaskIoUMetric):
248 | def __init__(self, **kwargs):
249 | super().__init__(matching="best_overlap", compute_discovery_fraction=True, **kwargs)
250 |
251 |
252 | def unsupervised_mask_iou(
253 | pred_mask: torch.Tensor,
254 | true_mask: torch.Tensor,
255 | matching: str = "hungarian",
256 | reduction: str = "mean",
257 | iou_empty: float = 0.0,
258 | ) -> torch.Tensor:
259 | """Compute intersection-over-union (IoU) between masks with unknown class correspondences.
260 |
261 | This metric is also known as Jaccard index. Note that this is a non-batched implementation.
262 |
263 | Args:
264 | pred_mask: Predicted mask of shape (C, N), where C is the number of predicted classes and
265 | N is the number of points. Masks are assumed to be binary.
266 | true_mask: Ground truth mask of shape (K, N), where K is the number of ground truth
267 | classes and N is the number of points. Masks are assumed to be binary.
268 | matching: How to match predicted classes to ground truth classes. For "hungarian", computes
269 | assignment that maximizes total IoU between all classes. For "best_overlap", uses the
270 | predicted class with maximum overlap for each ground truth class (each predicted class
271 | can be assigned to multiple ground truth classes). Empty ground truth classes are
272 | assigned IoU of zero.
273 | reduction: If "mean", return IoU averaged over classes. If "none", return per-class IoU.
274 | iou_empty: IoU for the case when a class does not occur, but was also not predicted.
275 |
276 | Returns:
277 | Mean IoU over classes if reduction is `mean`, tensor of shape (K,) containing per-class IoU
278 | otherwise.
279 | """
280 | assert pred_mask.ndim == 2
281 | assert true_mask.ndim == 2
282 | n_gt_classes = len(true_mask)
283 | pred_mask = pred_mask.unsqueeze(1).to(torch.bool)
284 | true_mask = true_mask.unsqueeze(0).to(torch.bool)
285 |
286 | intersection = torch.sum(pred_mask & true_mask, dim=-1).to(torch.float64)
287 | union = torch.sum(pred_mask | true_mask, dim=-1).to(torch.float64)
288 | pairwise_iou = intersection / union
289 |
290 | # Remove NaN from divide-by-zero: class does not occur, and class was not predicted.
291 | pairwise_iou[union == 0] = iou_empty
292 |
293 | if matching == "hungarian":
294 | pred_idxs, true_idxs = scipy.optimize.linear_sum_assignment(
295 | pairwise_iou.cpu(), maximize=True
296 | )
297 | pred_idxs = torch.as_tensor(pred_idxs, dtype=torch.int64, device=pairwise_iou.device)
298 | true_idxs = torch.as_tensor(true_idxs, dtype=torch.int64, device=pairwise_iou.device)
299 | elif matching == "best_overlap":
300 | non_empty_gt = torch.sum(true_mask.squeeze(0), dim=1) > 0
301 | pred_idxs = torch.argmax(pairwise_iou, dim=0)[non_empty_gt]
302 | true_idxs = torch.arange(pairwise_iou.shape[1])[non_empty_gt]
303 | else:
304 | raise ValueError(f"Unknown matching {matching}")
305 |
306 | matched_iou = pairwise_iou[pred_idxs, true_idxs]
307 | iou = torch.zeros(n_gt_classes, dtype=torch.float64, device=pairwise_iou.device)
308 | iou[true_idxs] = matched_iou
309 |
310 | if reduction == "mean":
311 | return iou.mean()
312 | else:
313 | return iou
314 |
315 |
316 | def tensor_to_one_hot(tensor: torch.Tensor, dim: int) -> torch.Tensor:
317 | """Convert tensor to one-hot encoding by using maximum across dimension as one-hot element."""
318 | assert 0 <= dim
319 | max_idxs = torch.argmax(tensor, dim=dim, keepdim=True)
320 | shape = [1] * dim + [-1] + [1] * (tensor.ndim - dim - 1)
321 | one_hot = max_idxs == torch.arange(tensor.shape[dim], device=tensor.device).view(*shape)
322 | return one_hot.to(torch.long)
323 |
324 |
325 | def adjusted_rand_index(pred_mask: torch.Tensor, true_mask: torch.Tensor) -> torch.Tensor:
326 | """Computes adjusted Rand index (ARI), a clustering similarity score.
327 |
328 | This implementation ignores points with no cluster label in `true_mask` (i.e. those points for
329 | which `true_mask` is a zero vector). In the context of segmentation, that means this function
330 | can ignore points in an image corresponding to the background (i.e. not to an object).
331 |
332 | Implementation adapted from https://github.com/deepmind/multi_object_datasets and
333 | https://github.com/google-research/slot-attention-video/blob/main/savi/lib/metrics.py
334 |
335 | Args:
336 | pred_mask: Predicted cluster assignment encoded as categorical probabilities of shape
337 | (batch_size, n_points, n_pred_clusters).
338 | true_mask: True cluster assignment encoded as one-hot of shape (batch_size, n_points,
339 | n_true_clusters).
340 |
341 | Returns:
342 | ARI scores of shape (batch_size,).
343 | """
344 | n_pred_clusters = pred_mask.shape[-1]
345 | pred_cluster_ids = torch.argmax(pred_mask, axis=-1)
346 |
347 | # Convert true and predicted clusters to one-hot ('oh') representations. We use float64 here on
348 | # purpose, otherwise mixed precision training automatically casts to FP16 in some of the
349 | # operations below, which can create overflows.
350 | true_mask_oh = true_mask.to(torch.float64) # already one-hot
351 | pred_mask_oh = torch.nn.functional.one_hot(pred_cluster_ids, n_pred_clusters).to(torch.float64)
352 |
353 | n_ij = torch.einsum("bnc,bnk->bck", true_mask_oh, pred_mask_oh)
354 | a = torch.sum(n_ij, axis=-1)
355 | b = torch.sum(n_ij, axis=-2)
356 | n_fg_points = torch.sum(a, axis=1)
357 |
358 | rindex = torch.sum(n_ij * (n_ij - 1), axis=(1, 2))
359 | aindex = torch.sum(a * (a - 1), axis=1)
360 | bindex = torch.sum(b * (b - 1), axis=1)
361 | expected_rindex = aindex * bindex / torch.clamp(n_fg_points * (n_fg_points - 1), min=1)
362 | max_rindex = (aindex + bindex) / 2
363 | denominator = max_rindex - expected_rindex
364 | ari = (rindex - expected_rindex) / denominator
365 |
366 | # There are two cases for which the denominator can be zero:
367 | # 1. If both true_mask and pred_mask assign all pixels to a single cluster.
368 | # (max_rindex == expected_rindex == rindex == n_fg_points * (n_fg_points-1))
369 | # 2. If both true_mask and pred_mask assign max 1 point to each cluster.
370 | # (max_rindex == expected_rindex == rindex == 0)
371 | # In both cases, we want the ARI score to be 1.0:
372 | return torch.where(denominator > 0, ari, torch.ones_like(ari))
373 |
374 |
375 | def fg_adjusted_rand_index(
376 | pred_mask: torch.Tensor, true_mask: torch.Tensor, bg_dim: int = 0
377 | ) -> torch.Tensor:
378 |     """Compute adjusted Rand index using only foreground groups (FG-ARI).
379 |
380 | Args:
381 | pred_mask: Predicted cluster assignment encoded as categorical probabilities of shape
382 | (batch_size, n_points, n_pred_clusters).
383 | true_mask: True cluster assignment encoded as one-hot of shape (batch_size, n_points,
384 | n_true_clusters).
385 | bg_dim: Index of background class in true mask.
386 |
387 | Returns:
388 | ARI scores of shape (batch_size,).
389 | """
390 | n_true_clusters = true_mask.shape[-1]
391 | assert 0 <= bg_dim < n_true_clusters
392 | if bg_dim == 0:
393 | true_mask_only_fg = true_mask[..., 1:]
394 | elif bg_dim == n_true_clusters - 1:
395 | true_mask_only_fg = true_mask[..., :-1]
396 | else:
397 | true_mask_only_fg = torch.cat(
398 | (true_mask[..., :bg_dim], true_mask[..., bg_dim + 1 :]), dim=-1
399 | )
400 |
401 | return adjusted_rand_index(pred_mask, true_mask_only_fg)
402 |
403 |
404 | def _all_equal_masked(values: torch.Tensor, mask: torch.Tensor, dim=-1) -> torch.Tensor:
405 | """Check if all masked values along a dimension of a tensor are the same.
406 |
407 |     Positions where the mask is False are ignored (treated as equal), so if no value is masked,
408 |     True is returned for that dimension.
409 | """
410 | assert mask.dtype == torch.bool
411 | _, first_non_masked_idx = torch.max(mask, dim=dim)
412 |
413 | comparison_value = values.gather(index=first_non_masked_idx.unsqueeze(dim), dim=dim)
414 |
415 | return torch.logical_or(~mask, values == comparison_value).all(dim=dim)
416 |
417 |
418 | def masks_to_bboxes(masks: torch.Tensor, empty_value: float = -1.0) -> torch.Tensor:
419 | """Compute bounding boxes around the provided masks.
420 |
421 | Adapted from DETR: https://github.com/facebookresearch/detr/blob/main/util/box_ops.py
422 |
423 | Args:
424 | masks: Tensor of shape (N, H, W), where N is the number of masks, H and W are the spatial
425 | dimensions.
426 | empty_value: Value bounding boxes should contain for empty masks.
427 |
428 | Returns:
429 | Tensor of shape (N, 4), containing bounding boxes in (x1, y1, x2, y2) format, where (x1, y1)
430 | is the coordinate of top-left corner and (x2, y2) is the coordinate of the bottom-right
431 | corner (inclusive) in pixel coordinates. If mask is empty, all coordinates contain
432 | `empty_value` instead.
433 | """
434 | masks = masks.bool()
435 | if masks.numel() == 0:
436 | return torch.zeros((0, 4), device=masks.device)
437 |
438 | large_value = 1e8
439 | inv_mask = ~masks
440 |
441 | h, w = masks.shape[-2:]
442 |
443 | y = torch.arange(0, h, dtype=torch.float, device=masks.device)
444 | x = torch.arange(0, w, dtype=torch.float, device=masks.device)
445 | y, x = torch.meshgrid(y, x, indexing="ij")
446 |
447 | x_mask = masks * x.unsqueeze(0)
448 | x_max = x_mask.flatten(1).max(-1)[0]
449 | x_min = x_mask.masked_fill(inv_mask, large_value).flatten(1).min(-1)[0]
450 |
451 | y_mask = masks * y.unsqueeze(0)
452 | y_max = y_mask.flatten(1).max(-1)[0]
453 | y_min = y_mask.masked_fill(inv_mask, large_value).flatten(1).min(-1)[0]
454 |
455 | bboxes = torch.stack((x_min, y_min, x_max, y_max), dim=1)
456 | bboxes[x_min == large_value] = empty_value
457 |
458 | return bboxes
459 |
460 |
461 | def _remap_one_hot_mask(
462 | mask: torch.Tensor, new_classes: torch.Tensor, n_new_classes: int, strip_empty: bool = False
463 | ):
464 | """Remap classes from binary mask to new classes.
465 |
466 | In the case of an overlap of classes for a point, the new class with the highest ID is
467 | assigned to that point. If no class is assigned to a point, the point will have no class
468 | assigned after remapping as well.
469 |
470 | Args:
471 | mask: Binary mask of shape (B, P, K) where K is the number of old classes and P is the
472 | number of points.
473 | new_classes: Tensor of shape (B, K) containing ids of new classes for each old class.
474 | n_new_classes: Number of classes after remapping, i.e. highest class id that can occur.
475 |         strip_empty: Whether to remove the channel corresponding to empty pixels (new class id 0).
476 |
477 | Returns:
478 | Tensor of shape (B, P, J), where J is the new number of classes.
479 | """
480 | assert new_classes.shape[1] == mask.shape[2]
481 | mask_dense = (mask * new_classes.unsqueeze(1)).max(dim=-1).values
482 | mask = torch.nn.functional.one_hot(mask_dense.to(torch.long), num_classes=n_new_classes + 1)
483 |
484 | if strip_empty:
485 | mask = mask[..., 1:]
486 |
487 | return mask
488 |
--------------------------------------------------------------------------------
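Usage note (illustrative sketch, not part of the repository): the snippet below drives UnsupervisedMaskIoUMetric with random tensors whose shapes follow the update() docstring above; ARIMetric is used analogously with one-hot targets. The import path evaluate.eval_utils is assumed from the repository tree.

# Sketch: Hungarian-matched IoU between random class probabilities and one-hot ground truth.
import torch
from evaluate.eval_utils import UnsupervisedMaskIoUMetric  # assumed import path

metric = UnsupervisedMaskIoUMetric(matching="hungarian")

B, C, K, H, W = 2, 5, 3, 16, 16
prediction = torch.softmax(torch.randn(B, C, H, W), dim=1)               # (B, C, H, W) class probabilities
target = torch.nn.functional.one_hot(torch.randint(0, K, (B, H, W)), K)  # (B, H, W, K) one-hot labels
target = target.permute(0, 3, 1, 2)                                      # (B, K, H, W) binary masks

metric.update(prediction, target)
print(metric.compute())  # mean IoU after Hungarian matching, a scalar in [0, 1]
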
/utils.py:
--------------------------------------------------------------------------------
1 | import time
2 | from typing import Tuple, List, Dict, Any
3 |
4 | import faiss
5 | import torch
6 | import numpy as np
7 | import os
8 | from datetime import datetime
9 | import glob
10 | import matplotlib.pyplot as plt
11 | from joblib import Parallel, delayed
12 | from scipy.optimize import linear_sum_assignment
13 | from skimage.measure import label
14 | from collections import deque, defaultdict
15 |
16 | from torch import nn
17 | from torchvision import transforms
18 | from torchvision.transforms import GaussianBlur
19 | from torchmetrics import Metric
20 |
21 |
22 | def save_checkpoint(state, is_best=0, gap=1, filename='models/checkpoint.pth', keep_all=False):
23 | torch.save(state, filename)
24 | last_epoch_path = os.path.join(os.path.dirname(filename),
25 | 'epoch%s.pth' % str(state['epoch'] - gap))
26 | if not keep_all:
27 | try:
28 | os.remove(last_epoch_path)
29 |         except OSError:
30 | pass
31 | if is_best:
32 | past_best = glob.glob(os.path.join(os.path.dirname(filename), 'model_best_*.pth'))
33 | for i in past_best:
34 | try:
35 | os.remove(i)
36 |             except OSError:
37 | pass
38 | torch.save(state, os.path.join(os.path.dirname(filename), 'model_best_epoch%s.pth' % str(state['epoch'])))
39 |
40 |
41 | class PredsmIoU(Metric):
42 | """
43 | Subclasses Metric. Computes mean Intersection over Union (mIoU) given ground-truth and predictions.
44 | .update() can be called repeatedly to add data from multiple validation loops.
45 | """
46 |
47 | def __init__(self,
48 | num_pred_classes: int,
49 | num_gt_classes: int):
50 | """
51 | :param num_pred_classes: The number of predicted classes.
52 | :param num_gt_classes: The number of gt classes.
53 | """
54 | super().__init__(dist_sync_on_step=False, compute_on_step=False)
55 | self.num_pred_classes = num_pred_classes
56 | self.num_gt_classes = num_gt_classes
57 | self.add_state("iou", [])
58 | self.add_state("iou_excludeFirst", [])
59 | self.n_jobs = -1
60 |
61 | def update(self, gt: torch.Tensor, pred: torch.Tensor, many_to_one=True, precision_based=True, linear_probe=False):
62 | pred = pred.cpu().numpy().astype(int)
63 | gt = gt.cpu().numpy().astype(int)
64 | assert len(np.unique(pred)) <= self.num_pred_classes
65 | assert np.max(pred) <= self.num_pred_classes
66 | iou_all, iou_excludeFirst = self.compute_miou(gt, pred, self.num_pred_classes, len(np.unique(gt)),
67 | many_to_one=many_to_one, precision_based=precision_based, linear_probe=linear_probe)
68 | self.iou.append(iou_all)
69 | self.iou_excludeFirst.append(iou_excludeFirst)
70 |
71 | def compute(self):
72 | """
73 | Compute mIoU
74 | """
75 | mIoU = np.mean(self.iou)
76 | mIoU_excludeFirst = np.mean(self.iou_excludeFirst)
77 | print('---mIoU computed---', mIoU)
78 | print('---mIoU exclude first---', mIoU_excludeFirst)
79 | return mIoU
80 |
81 | def compute_miou(self, gt: np.ndarray, pred: np.ndarray, num_pred: int, num_gt: int,
82 | many_to_one=False, precision_based=False, linear_probe=False):
83 | """
84 | Compute mIoU with optional hungarian matching or many-to-one matching (extracts information from labels).
85 | :param gt: numpy array with all flattened ground-truth class assignments per pixel
86 | :param pred: numpy array with all flattened class assignment predictions per pixel
87 | :param num_pred: number of predicted classes
88 | :param num_gt: number of ground truth classes
89 | :param many_to_one: Compute a many-to-one mapping of predicted classes to ground truth instead of hungarian
90 | matching.
91 | :param precision_based: Use precision as matching criteria instead of IoU for assigning predicted class to
92 | ground truth class.
93 | :param linear_probe: Skip hungarian / many-to-one matching. Used for evaluating predictions of fine-tuned heads.
94 |         :return: Tuple of (mIoU over all classes, mIoU over all classes excluding the first
95 |         class, which is typically the background class).
96 | """
97 | assert pred.shape == gt.shape
98 | print(f"unique semantic class = {np.unique(gt)}")
99 | gt_class = np.unique(gt).tolist()
100 | tp = [0] * num_gt
101 | fp = [0] * num_gt
102 | fn = [0] * num_gt
103 | iou = [0] * num_gt
104 |
105 | if linear_probe:
106 | reordered_preds = pred
107 | else:
108 | if many_to_one:
109 | match = self._original_match(num_pred, num_gt, pred, gt, precision_based=precision_based)
110 | # remap predictions
111 | reordered_preds = np.zeros(len(pred))
112 | for target_i, matched_preds in match.items():
113 | for pred_i in matched_preds:
114 | reordered_preds[pred == int(pred_i)] = int(target_i)
115 | else:
116 | match = self._hungarian_match(num_pred, num_gt, pred, gt)
117 | # remap predictions
118 | reordered_preds = np.zeros(len(pred))
119 | for target_i, pred_i in zip(*match):
120 | reordered_preds[pred == int(pred_i)] = int(target_i)
121 | # merge all unmatched predictions to background
122 | for unmatched_pred in np.delete(np.arange(num_pred), np.array(match[1])):
123 | reordered_preds[pred == int(unmatched_pred)] = 0
124 |
125 | # tp, fp, and fn evaluation
126 | for i_part in range(0, num_gt):
127 | tmp_all_gt = (gt == gt_class[i_part])
128 | tmp_pred = (reordered_preds == gt_class[i_part])
129 | tp[i_part] += np.sum(tmp_all_gt & tmp_pred)
130 | fp[i_part] += np.sum(~tmp_all_gt & tmp_pred)
131 | fn[i_part] += np.sum(tmp_all_gt & ~tmp_pred)
132 |
133 | # Calculate IoU per class
134 | for i_part in range(0, num_gt):
135 | iou[i_part] = float(tp[i_part]) / max(float(tp[i_part] + fp[i_part] + fn[i_part]), 1e-8)
136 |
137 | print('\tiou = ', iou, np.mean(iou[1:]))
138 | if len(iou) > 1:
139 | return np.mean(iou), np.mean(iou[1:])
140 | else:
141 | # return np.mean(iou), tp, fp, fn, reordered_preds.astype(int).tolist()
142 | return np.mean(iou), np.mean(iou)
143 |
144 | @staticmethod
145 | def get_score(flat_preds: np.ndarray, flat_targets: np.ndarray, c1: int, c2: int, precision_based: bool = False) \
146 | -> float:
147 | """
148 | Calculates IoU given gt class c1 and prediction class c2.
149 | :param flat_preds: flattened predictions
150 | :param flat_targets: flattened gt
151 | :param c1: ground truth class to match
152 | :param c2: predicted class to match
153 | :param precision_based: flag to calculate precision instead of IoU.
154 | :return: The score if gt-c1 was matched to predicted c2.
155 | """
156 | tmp_all_gt = (flat_targets == c1)
157 | tmp_pred = (flat_preds == c2)
158 | tp = np.sum(tmp_all_gt & tmp_pred)
159 | fp = np.sum(~tmp_all_gt & tmp_pred)
160 | if not precision_based:
161 | fn = np.sum(tmp_all_gt & ~tmp_pred)
162 | jac = float(tp) / max(float(tp + fp + fn), 1e-8)
163 | return jac
164 | else:
165 | prec = float(tp) / max(float(tp + fp), 1e-8)
166 | # print('\tgt, pred = ', c1, c2, ' | precision=', prec)
167 | return prec
168 |
169 | def compute_score_matrix(self, num_pred: int, num_gt: int, pred: np.ndarray, gt: np.ndarray,
170 | precision_based: bool = False) -> np.ndarray:
171 | """
172 | Compute score matrix. Each element i, j of matrix is the score if i was matched j. Computation is parallelized
173 | over self.n_jobs.
174 | :param num_pred: number of predicted classes
175 | :param num_gt: number of ground-truth classes
176 | :param pred: flattened predictions
177 | :param gt: flattened gt
178 | :param precision_based: flag to calculate precision instead of IoU.
179 |         :return: num_gt x num_pred matrix with A[i, j] being the score if ground-truth class i was matched to
180 | predicted class j.
181 | """
182 | # print("Parallelizing iou computation")
183 | # start = time.time()
184 | score_mat = Parallel(n_jobs=self.n_jobs)(delayed(self.get_score)(pred, gt, c1, c2, precision_based=precision_based)
185 | for c2 in range(num_pred) for c1 in np.unique(gt))
186 | # print(f"took {time.time() - start} seconds")
187 | score_mat = np.array(score_mat)
188 | return score_mat.reshape((num_pred, num_gt)).T
189 |
190 | def _hungarian_match(self, num_pred: int, num_gt: int, pred: np.ndarray, gt: np.ndarray):
191 | # do hungarian matching. If num_pred > num_gt match will be partial only.
192 | iou_mat = self.compute_score_matrix(num_pred, num_gt, pred, gt)
193 | match = linear_sum_assignment(1 - iou_mat)
194 | print("Matched clusters to gt classes:")
195 | print(match)
196 | return match
197 |
198 | def _original_match(self, num_pred, num_gt, pred, gt, precision_based=False) -> Dict[int, list]:
199 | score_mat = self.compute_score_matrix(num_pred, num_gt, pred, gt, precision_based=precision_based)
200 | gt_class = np.unique(gt).tolist()
201 | preds_to_gts = {}
202 | preds_to_gt_scores = {}
203 | # Greedily match predicted class to ground-truth class by best score.
204 | for pred_c in range(num_pred):
205 | for gt_i in range(num_gt):
206 | score = score_mat[gt_i, pred_c]
207 | if (pred_c not in preds_to_gts) or (score > preds_to_gt_scores[pred_c]):
208 | preds_to_gts[pred_c] = gt_class[gt_i]
209 | preds_to_gt_scores[pred_c] = score
210 | gt_to_matches = defaultdict(list)
211 | for k, v in preds_to_gts.items():
212 | gt_to_matches[v].append(k)
213 | # print('original match:', gt_to_matches)
214 | return gt_to_matches
215 |
216 |
217 | class PredsmIoUKmeans(PredsmIoU):
218 | """
219 | Used to track k-means cluster correspondence to ground-truth categories during fine-tuning.
220 | """
221 |
222 | def __init__(self,
223 | clustering_granularities: List[int],
224 | num_gt_classes: int,
225 | pca_dim: int = 50):
226 | """
227 | :param clustering_granularities: list of clustering granularities for embeddings
228 | :param num_gt_classes: number of ground-truth classes
229 | :param pca_dim: target dimensionality of PCA
230 | """
231 | super(PredsmIoU, self).__init__(compute_on_step=False, dist_sync_on_step=False) # Init Metric super class
232 | self.pca_dim = pca_dim
233 | self.num_pred_classes = clustering_granularities
234 | self.num_gt_classes = num_gt_classes
235 | self.add_state("masks", [])
236 | self.add_state("embeddings", [])
237 | self.add_state("gt", [])
238 | self.n_jobs = -1 # num_jobs = num_cores
239 | self.num_train_pca = 4000000 # take num_train_pca many vectors at max for training pca
240 |
241 | def update(self, masks: torch.Tensor, embeddings: torch.Tensor, gt: torch.Tensor) -> None:
242 | self.masks.append(masks)
243 | self.embeddings.append(embeddings)
244 | self.gt.append(gt)
245 |
246 |     def compute(self, is_global_zero: bool) -> List[Any]:
247 | if is_global_zero:
248 | # interpolate embeddings to match ground-truth masks spatially
249 | embeddings = torch.cat([e.cpu() for e in self.embeddings], dim=0) # move everything to cpu before catting
250 | valid_masks = torch.cat(self.masks, dim=0).cpu().numpy()
251 | res_w = valid_masks.shape[2]
252 | embeddings = nn.functional.interpolate(embeddings, size=(res_w, res_w), mode='bilinear')
253 | embeddings = embeddings.permute(0, 2, 3, 1).reshape(valid_masks.shape[0] * res_w ** 2, -1).numpy()
254 |
255 | # Normalize embeddings and reduce dims of embeddings by PCA
256 | normalized_embeddings = (embeddings - np.mean(embeddings, axis=0)) / (
257 | np.std(embeddings, axis=0, ddof=0) + 1e-5)
258 | d_orig = embeddings.shape[1]
259 | pca = faiss.PCAMatrix(d_orig, self.pca_dim)
260 | pca.train(normalized_embeddings[:self.num_train_pca])
261 | assert pca.is_trained
262 | transformed_feats = pca.apply_py(normalized_embeddings)
263 |
264 | # Cluster transformed feats with kmeans
265 | results = []
266 | gt = torch.cat(self.gt, dim=0).cpu().numpy()[valid_masks]
267 | for k in self.num_pred_classes: # [500, 300, 21]
268 | kmeans = faiss.Kmeans(self.pca_dim, k, niter=50, nredo=5, seed=1, verbose=True, gpu=False,
269 | spherical=False)
270 | kmeans.train(transformed_feats)
271 | _, pred_labels = kmeans.index.search(transformed_feats, 1)
272 | clusters = pred_labels.squeeze()
273 |
274 | # Filter predictions by valid masks (removes voc boundary gt class)
275 | pred_flattened = clusters.reshape(valid_masks.shape[0], 1, res_w, res_w)[valid_masks]
276 | assert len(np.unique(pred_flattened)) == k
277 | assert np.max(pred_flattened) == k - 1
278 |
279 | # Calculate mIoU. Do many-to-one matching if k > self.num_gt_classes.
280 | if k == self.num_gt_classes:
281 | results.append((k, k, self.compute_miou(gt, pred_flattened, k, self.num_gt_classes,
282 | many_to_one=False)))
283 | else:
284 | results.append((k, k, self.compute_miou(gt, pred_flattened, k, self.num_gt_classes,
285 | many_to_one=True)))
286 | results.append((k, f"{k}_prec", self.compute_miou(gt, pred_flattened, k, self.num_gt_classes,
287 | many_to_one=True, precision_based=True)))
288 | return results
289 |
290 |
291 | def eval_jac(gt: torch.Tensor, pred_mask: torch.Tensor, with_boundary: bool = True) -> float:
292 | """
293 |     Calculate the foreground Intersection over Union averaged over all images. If with_boundary is set, the
294 |     boundary class (255) is kept as foreground instead of being filtered out.
295 | """
296 | jacs = 0
297 | for k, mask in enumerate(gt):
298 | if with_boundary:
299 | gt_fg_mask = (mask != 0).float()
300 | else:
301 | gt_fg_mask = ((mask != 0) & (mask != 255)).float()
302 | intersection = gt_fg_mask * pred_mask[k]
303 | intersection = torch.sum(torch.sum(intersection, dim=-1), dim=-1)
304 | union = (gt_fg_mask + pred_mask[k]) > 0
305 | union = torch.sum(torch.sum(union, dim=-1), dim=-1)
306 | jacs += intersection / union
307 | res = jacs / gt.size(0)
308 | print(res)
309 | return res.item()
310 |
311 |
312 | def process_attentions(attentions: torch.Tensor, spatial_res: int, threshold: float = 0.6, blur_sigma: float = 0.6) \
313 | -> torch.Tensor:
314 | """
315 |     Process [0,1] attentions to a binary 0-1 mask. Applies a Gaussian filter, keeps threshold % of mass and removes
316 | components smaller than 3 pixels.
317 | The code is adapted from https://github.com/facebookresearch/dino/blob/main/visualize_attention.py but removes the
318 |     need for using ground-truth data to find the best-performing head. Instead we simply average all heads' attentions
319 |     so that the foreground mask can be used at training time.
320 | :param attentions: torch 4D-Tensor containing the averaged attentions
321 | :param spatial_res: spatial resolution of the attention map
322 | :param threshold: the percentage of mass to keep as foreground.
323 | :param blur_sigma: standard deviation to be used for creating kernel to perform blurring.
324 | :return: the foreground mask obtained from the ViT's attention.
325 | """
326 | # Blur attentions
327 | attentions = GaussianBlur(7, sigma=(blur_sigma))(attentions)
328 | attentions = attentions.reshape(attentions.size(0), 1, spatial_res ** 2)
329 | # Keep threshold% of mass
330 | val, idx = torch.sort(attentions)
331 | val /= torch.sum(val, dim=-1, keepdim=True)
332 | cumval = torch.cumsum(val, dim=-1)
333 | th_attn = cumval > (1 - threshold)
334 | idx2 = torch.argsort(idx)
335 | th_attn[:, 0] = torch.gather(th_attn[:, 0], dim=1, index=idx2[:, 0])
336 | th_attn = th_attn.reshape(attentions.size(0), 1, spatial_res, spatial_res).float()
337 | # Remove components with less than 3 pixels
338 | for j, th_att in enumerate(th_attn):
339 | labelled = label(th_att.cpu().numpy())
340 | for k in range(1, np.max(labelled) + 1):
341 | mask = labelled == k
342 | if np.sum(mask) <= 2:
343 | th_attn[j, 0][mask] = 0
344 | return th_attn.detach()
345 |
346 |
347 | def neq_load_customized(model, pretrained_dict):
348 | """
349 |     Load a pre-trained state dict whose keys only partially match,
350 |     e.g. when the new model has been partially modified.
351 | """
352 | model_dict = model.state_dict()
353 | tmp = {}
354 | print('\n=======Check Weights Loading======')
355 | print('Weights not used from pretrained file:')
356 | for k, v in pretrained_dict.items():
357 | if k in model_dict:
358 | tmp[k] = v
359 | else:
360 | print(k)
361 |
362 | print('\n-----------------------------------')
363 | print('Weights not loaded into new model:')
364 | for k, v in model_dict.items():
365 | if k not in pretrained_dict:
366 | print(k)
367 | print('===================================\n')
368 |
369 | # pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
370 | del pretrained_dict
371 | model_dict.update(tmp)
372 | del tmp
373 | model.load_state_dict(model_dict)
374 | return model
375 |
376 |
377 | def neq_load_external(model, pretrained_dict):
378 | """
379 |     Load a pre-trained model from an external source, mapping its keys onto the backbone.
380 | """
381 | model_dict = model.state_dict()
382 | tmp = {}
383 | print('\n=======Check Weights Loading======')
384 | print('Weights not used from pretrained file:')
385 | for k, v in pretrained_dict.items():
386 | if k.startswith('model'):
387 | k = k.removeprefix('model.') # for Leopart
388 | if 'backbone.' + k in model_dict:
389 | tmp['backbone.' + k] = v
390 | else:
391 | print(k)
392 |
393 | print('\n-----------------------------------')
394 | print('Weights not loaded into new model:')
395 | for k, v in model_dict.items():
396 | if k not in tmp:
397 | print(k)
398 | print('===================================\n')
399 |
400 | # pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
401 | del pretrained_dict
402 | model_dict.update(tmp)
403 | del tmp
404 | model.load_state_dict(model_dict)
405 | return model
406 |
407 |
408 | def write_log(content, epoch, filename):
409 | if not os.path.exists(filename):
410 | log_file = open(filename, 'w')
411 | else:
412 | log_file = open(filename, 'a')
413 | log_file.write('## Epoch %d:\n' % epoch)
414 | log_file.write('time: %s\n' % str(datetime.now()))
415 | log_file.write(content + '\n\n')
416 | log_file.close()
417 |
418 |
419 | def calc_topk_accuracy(output, target, topk=(1,)):
420 | '''
421 | Given predicted and ground truth labels,
422 | calculate top-k accuracies.
423 | '''
424 | maxk = max(topk)
425 | batch_size = target.size(0)
426 |
427 | _, pred = output.topk(maxk, 1, True, True)
428 | pred = pred.t()
429 | correct = pred.eq(target.view(1, -1).expand_as(pred))
430 |
431 | res = []
432 | for k in topk:
433 | correct_k = correct[:k].contiguous().view(-1).float().sum(0)
434 | res.append(correct_k.mul_(1 / batch_size))
435 | return res
436 |
437 |
438 | def calc_accuracy(output, target):
439 | '''output: (B, N); target: (B)'''
440 | target = target.squeeze()
441 | _, pred = torch.max(output, 1)
442 | return torch.mean((pred == target).float())
443 |
444 |
445 | def calc_accuracy_binary(output, target):
446 | '''output, target: (B, N), output is logits, before sigmoid '''
447 | pred = output > 0
448 | acc = torch.mean((pred == target.byte()).float())
449 | del pred, output, target
450 | return acc
451 |
452 |
453 | def denorm(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
454 | assert len(mean) == len(std) == 3
455 | inv_mean = [-mean[i] / std[i] for i in range(3)]
456 | inv_std = [1 / i for i in std]
457 | return transforms.Normalize(mean=inv_mean, std=inv_std)
458 |
459 |
460 | class AverageMeter(object):
461 | """Computes and stores the average and current value"""
462 |
463 | def __init__(self):
464 | self.reset()
465 |
466 | def reset(self):
467 | self.val = 0
468 | self.avg = 0
469 | self.sum = 0
470 | self.count = 0
471 | self.local_history = deque([])
472 | self.local_avg = 0
473 | self.history = []
474 | self.dict = {} # save all data values here
475 | self.save_dict = {} # save mean and std here, for summary table
476 |
477 | def update(self, val, n=1, history=0, step=5):
478 | self.val = val
479 | self.sum += val * n
480 | self.count += n
481 | self.avg = self.sum / self.count
482 | if history:
483 | self.history.append(val)
484 | if step > 0:
485 | self.local_history.append(val)
486 | if len(self.local_history) > step:
487 | self.local_history.popleft()
488 | self.local_avg = np.average(self.local_history)
489 |
490 | def dict_update(self, val, key):
491 | if key in self.dict.keys():
492 | self.dict[key].append(val)
493 | else:
494 | self.dict[key] = [val]
495 |
496 | def __len__(self):
497 | return self.count
498 |
499 |
500 | class AccuracyTable(object):
501 | '''compute accuracy for each class'''
502 |
503 | def __init__(self):
504 | self.dict = {}
505 |
506 | def update(self, pred, tar):
507 | pred = torch.squeeze(pred)
508 | tar = torch.squeeze(tar)
509 | for i, j in zip(pred, tar):
510 | i = int(i)
511 | j = int(j)
512 | if j not in self.dict.keys():
513 | self.dict[j] = {'count': 0, 'correct': 0}
514 | self.dict[j]['count'] += 1
515 | if i == j:
516 | self.dict[j]['correct'] += 1
517 |
518 | def print_table(self, label):
519 | for key in self.dict.keys():
520 | acc = self.dict[key]['correct'] / self.dict[key]['count']
521 | print('%s: %2d, accuracy: %3d/%3d = %0.6f' \
522 | % (label, key, self.dict[key]['correct'], self.dict[key]['count'], acc))
523 |
524 |
525 | class ConfusionMeter(object):
526 | '''compute and show confusion matrix'''
527 |
528 | def __init__(self, num_class):
529 | self.num_class = num_class
530 | self.mat = np.zeros((num_class, num_class))
531 | self.precision = []
532 | self.recall = []
533 |
534 | def update(self, pred, tar):
535 | pred, tar = pred.cpu().numpy(), tar.cpu().numpy()
536 | pred = np.squeeze(pred)
537 | tar = np.squeeze(tar)
538 | for p, t in zip(pred.flat, tar.flat):
539 | self.mat[p][t] += 1
540 |
541 | def print_mat(self):
542 | print('Confusion Matrix: (target in columns)')
543 | print(self.mat)
544 |
545 | def plot_mat(self, path, dictionary=None, annotate=False):
546 | plt.figure(dpi=600)
547 | plt.imshow(self.mat,
548 | cmap=plt.cm.jet,
549 | interpolation=None,
550 | extent=(0.5, np.shape(self.mat)[0] + 0.5, np.shape(self.mat)[1] + 0.5, 0.5))
551 | width, height = self.mat.shape
552 | if annotate:
553 | for x in range(width):
554 | for y in range(height):
555 | plt.annotate(str(int(self.mat[x][y])), xy=(y + 1, x + 1),
556 | horizontalalignment='center',
557 | verticalalignment='center',
558 | fontsize=8)
559 |
560 | if dictionary is not None:
561 | plt.xticks([i + 1 for i in range(width)],
562 | [dictionary[i] for i in range(width)],
563 | rotation='vertical')
564 | plt.yticks([i + 1 for i in range(height)],
565 | [dictionary[i] for i in range(height)])
566 | plt.xlabel('Ground Truth')
567 | plt.ylabel('Prediction')
568 | plt.colorbar()
569 | plt.tight_layout()
570 | plt.savefig(path, format='svg')
571 | plt.clf()
572 |
573 | # for i in range(width):
574 | # if np.sum(self.mat[i,:]) != 0:
575 | # self.precision.append(self.mat[i,i] / np.sum(self.mat[i,:]))
576 | # if np.sum(self.mat[:,i]) != 0:
577 | # self.recall.append(self.mat[i,i] / np.sum(self.mat[:,i]))
578 | # print('Average Precision: %0.4f' % np.mean(self.precision))
579 | # print('Average Recall: %0.4f' % np.mean(self.recall))
580 |
581 |
582 | if __name__ == '__main__':
583 | pass
584 |
--------------------------------------------------------------------------------
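Usage note (illustrative sketch, not part of the repository): process_attentions above converts averaged CLS attention maps into a binary foreground mask. The snippet below runs it on dummy attention maps; importing utils also pulls in faiss, joblib and scikit-image, so those dependencies must be installed.

# Sketch: binary foreground mask from dummy averaged attention maps.
import torch
from utils import process_attentions  # assumed import path, per the repository layout

spatial_res = 14                                    # 224 / 16 patches per side for ViT-S/16
attn = torch.rand(4, 1, spatial_res, spatial_res)   # averaged attention maps in [0, 1]
fg_mask = process_attentions(attn, spatial_res, threshold=0.6, blur_sigma=0.6)
print(fg_mask.shape)                                # torch.Size([4, 1, 14, 14]), binary values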