├── .gitignore
├── README.md
├── datasets
│   ├── base_dataset.py
│   ├── s3dis.py
│   ├── s3dis_config.py
│   ├── s3dis_regions.json
│   ├── semantic_kitti.py
│   ├── semantic_kitti_areas.json
│   ├── semantic_kitti_config.py
│   └── semantic_kitti_regions.json
├── dino_model
│   ├── LICENSE
│   ├── README.md
│   ├── eval_knn.py
│   ├── eval_linear.py
│   ├── fe_dino.py
│   ├── hubconf.py
│   ├── main_dino.py
│   ├── run_with_submitit.py
│   ├── utils.py
│   ├── video_generation.py
│   ├── vision_transformer.py
│   └── visualize_attention.py
├── images
│   ├── kitti.png
│   ├── s3dis.png
│   └── teaser.png
├── main.py
├── optimization.py
├── requirements.txt
├── s3dis_seed
│   ├── info.txt
│   ├── init_label_region.json
│   ├── init_label_scan.json
│   ├── init_ulabel_region.json
│   └── init_ulabel_scan.json
├── similarity_model.py
└── sk_seed
    ├── init_label_large_region.json
    ├── init_label_scan.json
    ├── init_ulabel_large_region.json
    └── init_ulabel_scan.json
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea/
2 | __pycache__/
3 | eggs/
4 | .eggs/
5 | *.pkl
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # You Never Get a Second Chance To Make a Good First Impression: Seeding Active Learning for 3D Semantic Segmentation
2 |
3 | Official PyTorch implementation of SeedAL.
4 |
5 | > [**You Never Get a Second Chance To Make a Good First Impression:
6 | Seeding Active Learning for 3D Semantic Segmentation**](http://arxiv.org/abs/2304.11762),
7 | > [Nermin Samet](https://nerminsamet.github.io/), [Oriane Siméoni](https://osimeoni.github.io/), [Gilles Puy](https://sites.google.com/site/puygilles/), [Georgy Ponimatkin](https://ponimatkin.github.io/), [Renaud Marlet](http://imagine.enpc.fr/~marletr/), [Vincent Lepetit](https://vincentlepetit.github.io/), \
8 | > *ICCV 2023.*
9 |
10 | ## Summary
11 | Active learning seeds (initial sets) have a
12 | significant effect on the performance of
13 | active learning methods. Below, we show the variability of the results obtained with 20 different random seeds (blue dashed lines), within an initial annotation budget of 3% of the dataset, when using various active learning methods for 3D semantic segmentation of S3DIS.
14 | We compare them to the result obtained with our seed selection strategy (solid red line), SeedAL, which performs
15 | on par with or better than the best (lucky) random seeds among the 20, and “protects” against very bad (unlucky) random seeds.
16 |
17 | ![Variability of active learning results over 20 random seeds vs. SeedAL on S3DIS](images/teaser.png)
18 |
19 | We propose SeedAL for automatically constructing a seed that will ensure good performance for AL.
20 | Assuming that images of the point clouds are available, which is common, our method relies on powerful unsupervised image features
21 | to measure the diversity of the point clouds.
22 | It selects the point clouds for the seed by optimizing the diversity under an annotation budget.
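
For intuition, the pairwise diversity score that feeds the budget-constrained optimization can be sketched as follows (a minimal sketch mirroring `prepare_data_for_optimization` in `datasets/base_dataset.py`; the `intra_sim` and `cross_sim` dictionaries of mean DINO similarities are illustrative names, not part of the actual API):

```python
# Sketch of SeedAL's pair scoring (see datasets/base_dataset.py for the real code).
# intra_sim[s]: mean DINO similarity between the cluster centers of scene s.
# cross_sim[(s_i, s_j)]: mean DINO similarity between the centers of two scenes.
def pair_score(scene_i, scene_j, intra_sim, cross_sim):
    dsim_i = 1.0 - intra_sim[scene_i]                     # diversity within scene i
    dsim_j = 1.0 - intra_sim[scene_j]                     # diversity within scene j
    pairwise_dsim = 1.0 - cross_sim[(scene_i, scene_j)]   # diversity across the pair
    return pairwise_dsim * dsim_i * dsim_j                # high for diverse, complementary pairs
```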
23 |
24 | ## Outdoor Evaluation Results on Semantic KITTI
25 |
26 | ![Evaluation results on Semantic KITTI](images/kitti.png)
27 |
28 | ## Indoor Evaluation Results on S3DIS
29 |
30 | ![Evaluation results on S3DIS](images/s3dis.png)
31 |
32 | ## Installation
33 |
34 | Given images of the point cloud scenes,
35 | SeedAL outputs an initial set (or seed) under a certain budget to start the training
36 | of active learning methods for 3D point cloud semantic segmentation.
37 | In our experiments we used the [ReDAL](https://github.com/tsunghan-wu/ReDAL) framework, which implements several active learning methods.
38 | For S3DIS and Semantic KITTI, SeedAL also follows the [data preparation](https://github.com/tsunghan-wu/ReDAL/tree/main/data_preparation) of ReDAL to stay consistent with the AL trainings.
39 |
40 | ### Environmental Setup
41 |
42 | The codebase was tested on Ubuntu 20.04 with CUDA 11.1.
43 | NVIDIA GPUs are needed for extracting the DINO features.
44 |
45 | ~~~
46 | conda create --name seedal python=3.7
47 | conda activate seedal
48 | conda install pytorch==1.10.2 torchvision==0.11.3 cudatoolkit=11.1 -c pytorch
49 | pip install -r requirements.txt
50 | ~~~
51 |
52 | After setting up the data and environment, you can run SeedAL as follows:
53 |
54 | ~~~
55 | python main.py -d
56 | ~~~
57 |
58 | By default, we save the results of intermediate steps, such as dataset statistics, DINO features, and the computed scene diversities.
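
For example, with the default S3DIS configuration in `datasets/s3dis_config.py` (`save_path = './s3dis_attribute_outputs/'`), the cached files look roughly as follows; the exact feature file name also embeds the similarity model name defined in `similarity_model.py` (shown here as `<model>`):

~~~
s3dis_attribute_outputs/
├── s3dis_<model>_features.pkl          # DINO feature vector per image
├── s3dis_scene_clusters_13.pkl         # k-means cluster centers per scene
├── s3dis_attribute_dict.pkl            # intra-scene similarity statistics
├── s3dis_pairwise_attribute_dict.pkl   # cross-scene (pairwise) similarity statistics
└── s3dis_stats.pkl                     # per-scene point and region counts
~~~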
59 |
60 | ## Acknowledgement
61 |
62 | This work was granted access to the HPC resources of IDRIS under the allocation 2023-AD011013267R1 made by GENCI.
63 |
64 |
65 | ## License
66 |
67 | SeedAL is released under the [Apache 2.0 License](LICENSE).
68 |
69 | ## Citation
70 |
71 | If you find SeedAL useful for your research, please cite our paper as follows.
72 |
73 | > Nermin Samet, Oriane Siméoni, Gilles Puy, Georgy Ponimatkin, Renaud Marlet, Vincent Lepetit, "You Never Get a Second Chance To Make a Good First Impression: Seeding Active Learning for 3D Semantic Segmentation",
74 | > In IEEE International Conference on Computer Vision (ICCV), 2023.
75 |
76 |
77 | BibTeX entry:
78 | ```
79 | @inproceedings{seedal,
80 | author = {Nermin Samet and Oriane Siméoni and Gilles Puy and Georgy Ponimatkin and Renaud Marlet and Vincent Lepetit},
81 | title = {You Never Get a Second Chance To Make a Good First Impression: Seeding Active Learning for 3D Semantic Segmentation},
82 | booktitle = {IEEE International Conference on Computer Vision (ICCV)},
83 | year = {2023},
84 | }
85 | ```
86 |
87 |
--------------------------------------------------------------------------------
/datasets/base_dataset.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('..')
3 | from os import listdir
4 | from os.path import isfile, join
5 | import os
6 | from similarity_model import SimilarityModel
7 | # pickle5 backports pickle protocol 5 to Python 3.7; it is used in place of the stdlib pickle
8 | import pickle5 as pickle
9 | from sklearn.cluster import KMeans
10 | from scipy import stats
11 | import numpy as np
12 |
13 | class PCData:
14 |
15 | def __init__(self, base_config):
16 |
17 | # path related attributes
18 | self.dataset_name = None
19 | self.main_2d_path = None
20 | self.main_3d_path = None
21 | self.scene_relative_path_to_rgb = None
22 | self.save_path = None
23 |
24 | # attributes expected to be set by the inheriting dataset classes
25 | self.train_scenes = None
26 | self.class_num = None
27 | self.cluster_num = None
28 | self.target_point_num = None
29 | self.reduction_size = None
30 | self.seed_name = None
31 |
32 | for k, v in base_config.items():
33 | setattr(self, k, v)
34 |
35 | # create the output directory if it does not exist yet
36 | if not os.path.exists(self.save_path):
37 | os.makedirs(self.save_path)
38 |
39 | self.sim_object = SimilarityModel()
40 |
41 | self.cluster_centers_file = self.save_path + f'{self.dataset_name}_scene_clusters_{self.cluster_num}.pkl'
42 | self.scene_attributes_file = self.save_path + f'{self.dataset_name}_attribute_dict.pkl'
43 | self.pairwise_scene_attributes_file = self.save_path + f'{self.dataset_name}_pairwise_attribute_dict.pkl'
44 | self.data_stats_file = self.save_path + f'{self.dataset_name}_stats.pkl'
45 | self.feature_vec_dict_file = self.save_path + f'{self.dataset_name}_{self.sim_object.sim_model_name}_features.pkl'
46 |
47 | # attributes to be calculated
48 | self.cluster_centers = None
49 | self.scene_attributes = None
50 | self.pairwise_scene_attributes = None
51 | self.data_stats = None
52 | self.feat_vec_image_dict = None
53 |
54 | self.all_images = self.get_all_data_rgb_names()
55 | self.all_scenes = None
56 |
57 | def load_feature_vec_dict(self):
58 | if os.path.exists(self.feature_vec_dict_file):
59 | with open(self.feature_vec_dict_file, 'rb') as handle:
60 | self.feat_vec_image_dict = pickle.load(handle)
61 | print(f'{self.feature_vec_dict_file} features are loaded')
62 | return True
63 | else:
64 | return False
65 |
66 | def load_data_stats(self):
67 | if os.path.exists(self.data_stats_file):
68 | with open(self.data_stats_file, 'rb') as handle:
69 | self.data_stats = pickle.load(handle)
70 | print(f'{self.data_stats_file} features are loaded')
71 | return True
72 | else:
73 | return False
74 |
75 | def load_cluster_centers(self):
76 | if os.path.exists(self.cluster_centers_file):
77 | with open(self.cluster_centers_file, 'rb') as handle:
78 | self.cluster_centers = pickle.load(handle)
79 | print(f'{self.cluster_centers_file} features are loaded')
80 | return True
81 | else:
82 | return False
83 |
84 | def load_scene_attributes(self):
85 | if os.path.exists(self.scene_attributes_file):
86 | with open(self.scene_attributes_file, 'rb') as handle:
87 | self.scene_attributes = pickle.load(handle)
88 | print(f'{self.scene_attributes_file} features are loaded')
89 | return True
90 | else:
91 | return False
92 |
93 | def load_pairwise_scene_attributes(self):
94 | if os.path.exists(self.pairwise_scene_attributes_file):
95 | with open(self.pairwise_scene_attributes_file, 'rb') as handle:
96 | self.pairwise_scene_attributes = pickle.load(handle)
97 | print(f'{self.pairwise_scene_attributes_file} features are loaded')
98 | return True
99 | else:
100 | return False
101 |
102 | def get_all_data_rgb_names(self):
103 | train_scenes = [x.lower() for x in self.train_scenes]
104 |
105 | all_files = []
106 | for folder in train_scenes:
107 | curr_dir = self.main_2d_path + folder + self.scene_relative_path_to_rgb
108 | onlyfiles = [f for f in listdir(curr_dir) if isfile(join(curr_dir, f))]
109 | onlyfiles = [(folder + self.scene_relative_path_to_rgb + '/' + word) for word in onlyfiles]
110 | all_files = all_files + onlyfiles
111 |
112 | return all_files
113 |
114 | def cluster_scene(self, cluster_num, im_feats):
115 |
116 | kmeans = KMeans(n_clusters=cluster_num, random_state=123).fit(im_feats)
117 | return kmeans.cluster_centers_
118 |
119 | def extract_scene_attributes(self):  # intra-scene similarity statistics from the DINO cluster centers
120 |
121 | ret = self.load_scene_attributes()
122 | if ret: # if we managed to load then no need to run again!
123 | return
124 |
125 | all_scenes, all_scene_keys = self.get_scenes()
126 |
127 | self.load_cluster_centers()
128 |
129 | scene_attributes = {}
130 |
131 | for ind, kk in enumerate(all_scene_keys):
132 | xbb = self.cluster_centers.get(kk, None)
133 |
134 | if xbb is not None:
135 | xbb = self.cluster_centers[kk]
136 | curr_dim = len(xbb)
137 | ordered_D = self.sim_object.calculate_dino_aff_matrix_from_feats(xbb)
138 | final_distance_list = list(ordered_D[np.triu_indices(curr_dim, 1)])
139 |
140 | d = {}
141 | d[f"{kk}"] = {}
142 | d[f"{kk}"]['mean'] = stats.describe(final_distance_list).mean
143 | d[f"{kk}"]['variance'] = stats.describe(final_distance_list).variance
144 | d[f"{kk}"]['minmax'] = stats.describe(final_distance_list).minmax
145 | scene_attributes.update(d)
146 |
147 | self.scene_attributes = scene_attributes
148 | f = open(self.scene_attributes_file, "wb")
149 | pickle.dump(self.scene_attributes, f)
150 | f.close()
151 | print(f'Len of final rooms {len(self.scene_attributes)}')
152 |
153 | def extract_pairwise_scene_attributes(self):  # cross-scene similarity statistics for every pair of scenes
154 |
155 | ret = self.load_pairwise_scene_attributes()
156 | if ret: # if we managed to load then no need to run again!
157 | return
158 |
159 | all_scenes, all_scene_keys = self.get_scenes()
160 | total_scene_number = len(all_scene_keys)
161 |
162 | self.load_cluster_centers()
163 |
164 | pairwise_scene_attributes = {}
165 |
166 | for ind, s1 in enumerate(all_scene_keys):
167 | for st in range(ind + 1, total_scene_number):
168 | s2 = all_scene_keys[st]
169 | kk = f'{s1}*{s2}'
170 |
171 | cluster_centers_s1 = self.cluster_centers[s1]
172 | cluster_centers_s2 = self.cluster_centers[s2]
173 |
174 | curr_dim = len(cluster_centers_s1) + len(cluster_centers_s2)
175 | xbb = np.zeros((curr_dim, cluster_centers_s1.shape[1]), dtype=np.float32)
176 | for ii, v in enumerate(cluster_centers_s1):
177 | xbb[ii] = v
178 | for ii, v in enumerate(cluster_centers_s2):
179 | xbb[len(cluster_centers_s1) + ii] = v
180 |
181 | ordered_D = self.sim_object.calculate_dino_aff_matrix_from_feats(xbb)
182 |
183 | final_distance_list = list(
184 | ordered_D[len(cluster_centers_s1):, 0: len(cluster_centers_s2)].flatten())
185 |
186 | d = {}
187 | d[f"{kk}"] = {}
188 | d[f"{kk}"]['mean'] = stats.describe(final_distance_list).mean
189 | d[f"{kk}"]['variance'] = stats.describe(final_distance_list).variance
190 | d[f"{kk}"]['minmax'] = stats.describe(final_distance_list).minmax
191 | pairwise_scene_attributes.update(d)
192 |
193 | self.pairwise_scene_attributes = pairwise_scene_attributes
194 | # create a binary pickle file
195 | f = open(self.pairwise_scene_attributes_file, "wb")
196 | pickle.dump(self.pairwise_scene_attributes, f)
197 | f.close()
198 | print(f'Len: {len(self.pairwise_scene_attributes)}')
199 |
200 | def prepare_data_for_optimization(self):  # pairwise diversity scores fed to the seed-selection optimization
201 |
202 | self.load_scene_attributes()
203 | self.load_data_stats()
204 | self.load_pairwise_scene_attributes()
205 |
206 | all_scenes, all_scene_keys = self.get_scenes()
207 | total_scene_number = len(all_scene_keys)
208 |
209 | pair_scores = []
210 | all_pairs = []
211 | for i, scene_i in enumerate(all_scene_keys):
212 | for j in range(i + 1, total_scene_number):
213 | scene_j = all_scene_keys[j]
214 | dsim_i = 1 - self.scene_attributes[scene_i]['mean']
215 | dsim_j = 1 - self.scene_attributes[scene_j]['mean']
216 |
217 | dsim = dsim_i * dsim_j
218 |
219 | kk = f'{scene_i}*{scene_j}'
220 |
221 | pairwise_sim = self.pairwise_scene_attributes[kk]['mean']
222 | pairwise_dsim = 1 - pairwise_sim
223 | final_score = pairwise_dsim * dsim
224 |
225 | pair_scores.append(final_score)
226 | all_pairs.append((scene_i, scene_j))
227 |
228 | return all_pairs, self.data_stats, pair_scores, self.reduction_size, self.target_point_num
229 |
230 |
231 |
232 |
233 |
--------------------------------------------------------------------------------
/datasets/s3dis.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import pickle
4 | import numpy as np
5 | from PIL import Image
6 | from os import listdir
7 | from os.path import isfile, join
8 | from scipy import stats
9 | from .base_dataset import PCData
10 |
11 |
12 | class S3DIS(PCData):
13 |
14 | def __init__(self, config):
15 |
16 | base_config = getattr(config, "base_parameters")
17 | super().__init__(base_config)
18 |
19 | s3dis_parameters = getattr(config, "s3dis_parameters")
20 | for k, v in s3dis_parameters.items():
21 | setattr(self, k, v)
22 |
23 | def extract_feature_vecs(self):
24 |
25 | ret = self.load_feature_vec_dict()
26 | if ret:
27 | return
28 | feat_vec_image_dict = {}
29 | print("Extracting Feature Vectors!")
30 | for file_name in self.all_images:
31 | print(self.main_2d_path + file_name)
32 | im = Image.open(self.main_2d_path+file_name)
33 | feature_vec = self.sim_object.get_sim_vec_single(im)
34 | name_split = file_name.split('/')
35 | scene = name_split[0]
36 | further_split = name_split[3].split('_')
37 | camera = further_split[1]
38 | room = further_split[2] + '_' + further_split[3]
39 | frame = further_split[5]
40 | feat_vec_image_dict[f"{scene}_{room}_{camera}_{frame}"] = feature_vec
41 |
42 | self.feat_vec_image_dict = feat_vec_image_dict
43 | # lets dump
44 | with open(self.feature_vec_dict_file, 'wb') as handle:
45 | pickle.dump(self.feat_vec_image_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
46 |
47 | def extract_data_stats(self):
48 |
49 | ret = self.load_data_stats()
50 | if ret: # if we managed to load then no need to run again!
51 | return
52 |
53 | f = open('./datasets/s3dis_regions.json')
54 | ulabel_region = json.load(f)
55 |
56 | s3dis_stats = {}
57 | total_point_number = 0
58 | total_region_number = 0
59 | for folder in self.train_scenes:
60 | curr_dir = self.main_3d_path + folder + '/supervoxel'
61 | annot_dir = self.main_3d_path + folder + '/labels'
62 | onlyfiles = [f for f in listdir(curr_dir) if isfile(join(curr_dir, f))]
63 |
64 | for room in onlyfiles:
65 | supvox = np.load(curr_dir + '/' + room)
66 | annots = np.load(annot_dir + '/' + room)
67 | preserving_labels = ulabel_region[f'{folder}#{room[:-4]}']
68 |
69 | (unique, counts) = np.unique(supvox, return_counts=True)
70 | (annot_unique, annot_counts) = np.unique(annots, return_counts=True)
71 | dict_annots = {}
72 | for au, ac in zip(annot_unique, annot_counts):
73 | dict_annots[self.label_2_name[au]] = ac
74 |
75 | indices_preserving_labels = [unique.tolist().index(x) for x in preserving_labels]
76 | unique = unique[indices_preserving_labels]
77 | counts = counts[indices_preserving_labels]
78 |
79 | frequencies = np.asarray((unique, counts)).T
80 | d = {}
81 | key_name = folder + '_' + room[:-4]
82 | d[f"{key_name}"] = {}
83 | d[f"{key_name}"]['area'] = frequencies[:, 1].sum()
84 |
85 | mask = np.isin(supvox, preserving_labels)
86 | if frequencies[:, 1].sum() != mask.sum():
87 | print('Something is wrong about AREA!')
88 |
89 | d[f"{key_name}"]['frequencies'] = frequencies
90 | d[f"{key_name}"]['preserving_labels'] = preserving_labels
91 |
92 | d[f"{key_name}"]['annot_stats'] = dict_annots
93 |
94 | total_point_number += frequencies[:, 1].sum()
95 | total_region_number += len(preserving_labels)
96 |
97 | s3dis_stats.update(d)
98 |
99 | self.data_stats = s3dis_stats
100 | f = open(self.data_stats_file, "wb")
101 | pickle.dump(self.data_stats, f)
102 | f.close()
103 | print(f'Point Num in Total {total_point_number}')
104 | print(f'Region Num in Total {total_region_number}')
105 |
106 | def extract_scene_clusters(self):
107 |
108 | ret = self.load_cluster_centers()
109 | if ret: # if we managed to load then no need to run again!
110 | return
111 |
112 | all_scenes, all_scene_keys = self.get_scenes()
113 | self.cluster_centers = {}
114 |
115 | for ind, s1 in enumerate(all_scene_keys):
116 | im_feats = np.squeeze(np.asarray(all_scenes[s1]), axis=1)
117 |
118 | cluster_num = self.cluster_num
119 | if len(im_feats) < cluster_num:
120 | cluster_num = len(im_feats)
121 |
122 | self.cluster_centers[s1] = self.cluster_scene(cluster_num, im_feats)
123 |
124 | f = open(self.cluster_centers_file, "wb")
125 | pickle.dump(self.cluster_centers, f)
126 | f.close()
127 |
128 | def get_scenes(self):
129 | all_scenes = {}
130 | self.load_feature_vec_dict()
131 | for scene in self.train_scenes:
132 | for room in self.rooms:
133 | for i in range(self.max_room_num):
134 | values = None
135 | kk = scene + '_' + room + str(i) + '_'
136 | values = [value for key, value in self.feat_vec_image_dict.items() if kk.lower() in key.lower()]
137 | if values:
138 | all_scenes[kk[:-1]] = values
139 | return all_scenes, list(all_scenes.keys())
140 |
141 | def create_initial_set(self, scene_list):
142 |
143 | self.load_data_stats()
144 |
145 | scan_num = len(self.data_stats)
146 | all_keys = list(self.data_stats.keys())
147 | all_values = list(self.data_stats.values())
148 |
149 | selected_samples = []
150 | for scn in scene_list:
151 | selected_samples.append(all_keys.index(scn))
152 |
153 | path = os.path.join('.', self.seed_name)
154 | os.mkdir(path)
155 |
156 | f = open(path + "/init_label_scan.json", "w")
157 | fu = open(path + "/init_ulabel_scan.json", "w")
158 | f.write("[\n")
159 | fu.write("[\n")
160 |
161 | for j in range(scan_num):
162 | curr_name = all_keys[j]
163 | splits = curr_name.split('_')
164 | if j in selected_samples:
165 | f.write(f'"{splits[0]}_{splits[1]}/coords/{splits[2]}_{splits[3]}.npy",\n')
166 | else:
167 | fu.write(f'"{splits[0]}_{splits[1]}/coords/{splits[2]}_{splits[3]}.npy",\n')
168 |
169 | f.seek(f.tell() - 2)
170 | fu.seek(fu.tell() - 2)
171 | f.write("\n]\n")
172 | fu.write("\n]\n")
173 | f.close()
174 | fu.close()
175 |
176 | f = open(path+"/init_label_region.json", "w")
177 | fu = open(path+"/init_ulabel_region.json", "w")
178 | f.write("{")
179 | fu.write("{")
180 |
181 | total_point_num = 0
182 | total_region_num = 0
183 | for j in range(scan_num):
184 | curr_name = all_keys[j]
185 | splits = curr_name.split('_')
186 | supervoxel_list = str(all_values[j]['frequencies'][:,0].tolist())
187 | if j in selected_samples:
188 | f.write(f'"{splits[0]}_{splits[1]}#{splits[2]}_{splits[3]}": {supervoxel_list},')
189 | total_point_num+= all_values[j]['area']
190 | total_region_num += len(all_values[j]['frequencies'])
191 | else:
192 | fu.write(f'"{splits[0]}_{splits[1]}#{splits[2]}_{splits[3]}": {supervoxel_list},')
193 |
194 | f.seek(f.tell() - 1)
195 | fu.seek(fu.tell() - 1)
196 | f.write("}")
197 | fu.write("}")
198 | f.close()
199 | fu.close()
200 |
201 | fi = open(path + "/info.txt", "w")
202 | fi.write(f'Region Num: {total_region_num} and Point Num: {total_point_num} in the current set')
203 | fi.close()
204 |
205 |
206 |
207 |
208 |
209 |
210 |
--------------------------------------------------------------------------------
/datasets/s3dis_config.py:
--------------------------------------------------------------------------------
1 |
2 | base_parameters = dict(
3 | dataset_name = 's3dis',
4 | main_2d_path = 'path/to/2D_modalities/',
5 | main_3d_path = 'path/to/S3DIS_processed/',
6 | scene_relative_path_to_rgb = '/data/rgb',
7 | save_path = './s3dis_attribute_outputs/',
8 | train_scenes = ['Area_1', 'Area_2', 'Area_3', 'Area_4', 'Area_6'],
9 | class_num = 13,
10 | cluster_num = 13,
11 | target_point_num = 6500000,
12 | reduction_size = 80,
13 | target_region_num = 790,
14 | seed_name = 's3dis_seed'
15 | )
16 |
17 | s3dis_parameters = dict(
18 | rooms = ['auditorium_', 'conferenceRoom_', 'copyRoom_', 'hallway_', 'lobby_', 'lounge_', 'office_', 'storage_', 'pantry_', 'WC_', 'openspace_'],
19 | max_room_num = 40,
20 | label_2_name = {0: 'ceiling',
21 | 1: 'floor',
22 | 2: 'wall',
23 | 3: 'beam',
24 | 4: 'column',
25 | 5: 'window',
26 | 6: 'door',
27 | 7: 'chair',
28 | 8: 'table',
29 | 9: 'bookcase',
30 | 10: 'sofa',
31 | 11: 'board',
32 | 12: 'clutter'},
33 | )
34 |
--------------------------------------------------------------------------------
/datasets/semantic_kitti.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import pickle
4 | import numpy as np
5 | from PIL import Image
6 | from .base_dataset import PCData
7 |
8 |
9 | class SK(PCData):
10 |
11 | def __init__(self, config):
12 |
13 | base_config = getattr(config, "base_parameters")
14 | super().__init__(base_config)
15 |
16 | sk_parameters = getattr(config, "sk_parameters")
17 | for k, v in sk_parameters.items():
18 | setattr(self, k, v)
19 |
20 | self.cls_feature_vec_dict_file = self.save_path + f'{self.dataset_name}_{self.sim_object.sim_model_name}_cls_features.pkl'
21 | self.cls_feat_vec_image_dict = None
22 |
23 | self.sparsified_dataset_file = self.save_path + f'{self.dataset_name}_sparsified.pkl'
24 | self.sparsified_dataset = None
25 |
26 | def extract_feature_vecs(self):
27 | self.extract_cls_features()
28 | self.sparsify_dataset()
29 | self.extract_patch_features()
30 |
31 | def extract_data_stats(self):
32 |
33 | ret = self.load_data_stats()
34 | if ret: # if we managed to load then no need to run again!
35 | return
36 |
37 | with open('./datasets/semantic_kitti_areas.json', 'r') as f:
38 | supvox_pts = json.load(f)
39 | supvox_keys = list(supvox_pts.keys())
40 |
41 | sk_stats = {}
42 | total_point_number = 0
43 | for im in self.sparsified_dataset:
44 | name_parse = im.split('/')
45 | seq = name_parse[0]
46 | im_id = name_parse[2][:-4]
47 | key_name = f'{seq}/velodyne/{im_id}.bin#'
48 | matching = [s for s in supvox_keys if key_name in s]
49 | area = 0
50 | for m in matching:
51 | area+=supvox_pts[m]
52 | d = {}
53 | d[f"{seq}_{self.scene_relative_path_to_rgb[1:]}_{im_id}"] = {}
54 | d[f"{seq}_{self.scene_relative_path_to_rgb[1:]}_{im_id}"]['area'] = area
55 | total_point_number+=area
56 | sk_stats.update(d)
57 |
58 | self.data_stats = sk_stats
59 | f = open(self.data_stats_file, "wb")
60 | pickle.dump(self.data_stats, f)
61 | f.close()
62 | print(f'Point Num in Total {total_point_number}')
63 |
64 | def extract_scene_clusters(self):
65 | ret = self.load_cluster_centers()
66 | if ret: # if we managed to load then no need to run again!
67 | return
68 | self.load_feature_vec_dict()
69 | self.cluster_centers = {}
70 | all_scene_keys = list(self.feat_vec_image_dict.keys())
71 | for ind, s1 in enumerate(all_scene_keys):
72 | im_feats = self.feat_vec_image_dict[s1][0]
73 |
74 | cluster_num = self.cluster_num
75 | if len(im_feats) < cluster_num:
76 | cluster_num = len(im_feats)
77 | self.cluster_centers[s1] = self.cluster_scene(cluster_num, im_feats)
78 |
79 | f = open(self.cluster_centers_file, "wb")
80 | pickle.dump(self.cluster_centers, f)
81 | f.close()
82 |
83 | def extract_cls_features(self):
84 |
85 | ret = self.load_cls_feature_vec_dict()
86 | if ret:
87 | return
88 | cls_feat_vec_image_dict = {}
89 | print("Extracting Feature Vectors!")
90 | for file_name in self.all_images:
91 | print(self.main_2d_path + file_name)
92 | im = Image.open(self.main_2d_path+file_name)
93 | feature_vec = self.sim_object.get_sim_vec_single(im)
94 | name_split = file_name.split('/')
95 | scene = name_split[0]
96 | camera = name_split[1]
97 | frame = name_split[2].split('.')[0]
98 | cls_feat_vec_image_dict[f"{scene}_{camera}_{frame}"] = feature_vec
99 |
100 | self.cls_feat_vec_image_dict = cls_feat_vec_image_dict
101 | # lets dump
102 | with open(self.cls_feature_vec_dict_file, 'wb') as handle:
103 | pickle.dump(self.cls_feat_vec_image_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
104 |
105 | def sparsify_dataset(self):  # keep a sparse subset of frames based on consecutive [CLS]-feature similarity
106 |
107 | ret = self.load_sparsified_dataset()
108 | if ret:
109 | return
110 |
111 | selected_frames = []
112 | for folder in self.train_scenes:
113 | matching = sorted([s for s in self.all_images if f'{folder}/' in s])
114 | p = 0
115 | for i, img in enumerate(matching):
116 | i = p
117 | key_i = matching[i].replace('/','_')[:-4]
118 | for j in range(i + 1, len(matching), 1):
119 | key_j = matching[j].replace('/','_')[:-4]
120 |
121 | feat_i = self.cls_feat_vec_image_dict[key_i]
122 | feat_j = self.cls_feat_vec_image_dict[key_j]
123 |
124 | feats = np.concatenate((feat_i, feat_j), axis=0)
125 | ordered_D = self.sim_object.calculate_dino_aff_matrix_from_feats(feats)
126 | sim = ordered_D[0][1]
127 | if sim < self.sparsification_similarity_thr:
128 | s_i = int((i + j) / 2)
129 | selected_frames.append(matching[s_i])
130 | p = j
131 | break
132 |
133 | if p > len(matching):
134 | break
135 |
136 | # lets dump
137 | self.sparsified_dataset = selected_frames
138 | with open(self.sparsified_dataset_file, 'wb') as handle:
139 | pickle.dump(self.sparsified_dataset, handle, protocol=pickle.HIGHEST_PROTOCOL)
140 |
141 | def load_cls_feature_vec_dict(self):
142 | if os.path.exists(self.cls_feature_vec_dict_file):
143 | with open(self.cls_feature_vec_dict_file, 'rb') as handle:
144 | self.cls_feat_vec_image_dict = pickle.load(handle)
145 | print(f'{self.cls_feature_vec_dict_file} features are loaded')
146 | return True
147 | else:
148 | return False
149 |
150 | def load_sparsified_dataset(self):
151 | if os.path.exists(self.sparsified_dataset_file):
152 | with open(self.sparsified_dataset_file, 'rb') as handle:
153 | self.sparsified_dataset = pickle.load(handle)
154 | print(f'{self.sparsified_dataset_file} are loaded')
155 | return True
156 | else:
157 | return False
158 |
159 | def extract_patch_features(self):
160 |
161 | ret = self.load_feature_vec_dict()
162 | if ret:
163 | return
164 |
165 | ret = self.load_sparsified_dataset()
166 | if not ret:
167 | print('No sparsified dataset!')
168 | return
169 |
170 | feat_vec_image_dict = {}
171 |
172 | self.sim_object.set_sim_model_feat_type_dino('patch')
173 |
174 | print("Extracting Feature Vectors!")
175 | for file_name in self.sparsified_dataset:
176 | print(self.main_2d_path + file_name)
177 | im = Image.open(self.main_2d_path+file_name)
178 | feature_vec = self.sim_object.get_sim_vec_single(im)
179 | name_split = file_name.split('/')
180 | scene = name_split[0]
181 | camera = name_split[1]
182 | frame = name_split[2].split('.')[0]
183 | feat_vec_image_dict[f"{scene}_{camera}_{frame}"] = feature_vec
184 |
185 | self.feat_vec_image_dict = feat_vec_image_dict
186 | # lets dump
187 | with open(self.feature_vec_dict_file, 'wb') as handle:
188 | pickle.dump(self.feat_vec_image_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
189 |
190 | self.sim_object.set_sim_model_feat_type_dino('cls')
191 |
192 | def create_initial_set(self, selected_samples):
193 |
194 | f = open('./datasets/semantic_kitti_regions.json')
195 | all_regions = json.load(f)
196 |
197 | path = os.path.join('.', self.seed_name)
198 | os.mkdir(path)
199 |
200 | f = open(path+"/init_label_scan.json", "w")
201 | fu = open(path+"/init_ulabel_scan.json", "w")
202 | f.write("[\n")
203 | fu.write("[\n")
204 |
205 | all_keys = list(all_regions.keys())
206 | for scn in all_keys:
207 | search_str = scn.replace('_', '_image_2_')
208 | final_str = scn.replace('_','/velodyne/') + '.bin'
209 | if search_str in selected_samples:
210 | f.write(f' "{final_str}",\n')
211 | else:
212 | fu.write(f' "{final_str}",\n')
213 |
214 | f.seek(f.tell() - 2)
215 | fu.seek(fu.tell() - 2)
216 | f.write("\n]\n")
217 | fu.write("\n]\n")
218 | f.close()
219 | fu.close()
220 |
221 | f = open(path+"/init_label_large_region.json", "w")
222 | fu = open(path+"/init_ulabel_large_region.json", "w")
223 | f.write("{")
224 | fu.write("{")
225 |
226 | for i, scn in enumerate(all_regions):
227 | search_str = scn.replace('_','_image_2_')
228 | supervoxel_list = all_regions[scn]
229 | if search_str in selected_samples:
230 | f.write(f'"{scn}": {supervoxel_list}, ')
231 | else:
232 | fu.write(f'"{scn}": {supervoxel_list}, ')
233 |
234 | f.seek(f.tell() - 2)
235 | fu.seek(fu.tell() - 2)
236 | f.write("}")
237 | fu.write("}")
238 | f.close()
239 | fu.close()
240 |
241 | def get_scenes(self):
242 |
243 | sparsified_dataset = self.sparsified_dataset
244 | sparsified_dataset_keys = []
245 |
246 | for scn in sparsified_dataset:
247 | sparsified_dataset_keys.append(scn.replace('/','_')[:-4])
248 |
249 | return sparsified_dataset, sparsified_dataset_keys
250 |
251 |
252 |
253 |
254 |
255 |
--------------------------------------------------------------------------------
/datasets/semantic_kitti_config.py:
--------------------------------------------------------------------------------
1 | base_parameters = dict(
2 | dataset_name = 'sk',
3 | main_2d_path = 'path/to/SemanticKitti/data_odometry_color/sequences/',
4 | main_3d_path = 'path/to/SemanticKitti/data_odometry_velodyne/dataset/sequences',
5 | scene_relative_path_to_rgb = '/image_2',
6 | save_path = './sk_attribute_outputs/',
7 | train_scenes = ['00', '01', '02', '03', '04', '05', '06', '07', '09', '10'],
8 | class_num = 19,
9 | cluster_num = 19,
10 | target_point_num = 22591773,
11 | reduction_size = 1200,
12 | seed_name = 'sk_seed'
13 | )
14 |
15 | sk_parameters = dict(
16 | sparsification_similarity_thr=0.75,
17 | )
--------------------------------------------------------------------------------
/dino_model/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/dino_model/README.md:
--------------------------------------------------------------------------------
1 | # Self-Supervised Vision Transformers with DINO
2 |
3 | PyTorch implementation and pretrained models for DINO. For details, see **Emerging Properties in Self-Supervised Vision Transformers**.
4 | [[`blogpost`](https://ai.facebook.com/blog/dino-paws-computer-vision-with-self-supervised-transformers-and-10x-more-efficient-training)] [[`arXiv`](https://arxiv.org/abs/2104.14294)] [[`Yannic Kilcher's video`](https://www.youtube.com/watch?v=h3ij3F3cPIk)]
5 |
6 |
7 |
8 |
9 |
10 | ## Pretrained models
11 | You can choose to download only the weights of the pretrained backbone used for downstream tasks, or the full checkpoint which contains backbone and projection head weights for both student and teacher networks. We also provide the training and evaluation logs.
12 |
13 |
77 |
78 | The pretrained models are available on PyTorch Hub.
79 | ```python
80 | import torch
81 | deits16 = torch.hub.load('facebookresearch/dino:main', 'dino_deits16')
82 | deits8 = torch.hub.load('facebookresearch/dino:main', 'dino_deits8')
83 | vitb16 = torch.hub.load('facebookresearch/dino:main', 'dino_vitb16')
84 | vitb8 = torch.hub.load('facebookresearch/dino:main', 'dino_vitb8')
85 | resnet50 = torch.hub.load('facebookresearch/dino:main', 'dino_resnet50')
86 | ```
87 |
88 | ## Training
89 |
90 | ### Documentation
91 | Please install [PyTorch](https://pytorch.org/) and download the [ImageNet](https://imagenet.stanford.edu/) dataset. This codebase has been developed with python version 3.6, PyTorch version 1.7.1, CUDA 11.0 and torchvision 0.8.2. The exact arguments to reproduce the models presented in our paper can be found in the `args` column of the [pretrained models section](https://github.com/facebookresearch/dino#pretrained-models). For a glimpse at the full documentation of DINO training please run:
92 | ```
93 | python main_dino.py --help
94 | ```
95 |
96 | ### Vanilla DINO training :sauropod:
97 | Run DINO with DeiT-small network on a single node with 8 GPUs for 100 epochs with the following command. Training time is 1.75 day and the resulting checkpoint should reach 69.3% on k-NN eval and 74.0% on linear eval. We provide [training](https://dl.fbaipublicfiles.com/dino/example_runs_logs/dino_vanilla_deitsmall16_log.txt) and [linear evaluation](https://dl.fbaipublicfiles.com/dino/example_runs_logs/dino_vanilla_deitsmall16_eval.txt) logs (with batch size 256 at evaluation time) for this run to help reproducibility.
98 | ```
99 | python -m torch.distributed.launch --nproc_per_node=8 main_dino.py --arch deit_small --data_path /path/to/imagenet/train --output_dir /path/to/saving_dir
100 | ```
101 |
102 | ### Multi-node training
103 | We use Slurm and [submitit](https://github.com/facebookincubator/submitit) (`pip install submitit`). To train on 2 nodes with 8 GPUs each (total 16 GPUs):
104 | ```
105 | python run_with_submitit.py --nodes 2 --ngpus 8 --arch deit_small --data_path /path/to/imagenet/train --output_dir /path/to/saving_dir
106 | ```
107 |
108 |
109 |
110 | DINO with ViT-base network.
111 |
112 |
113 | ```
114 | python run_with_submitit.py --nodes 2 --ngpus 8 --use_volta32 --arch vit_base --data_path /path/to/imagenet/train --output_dir /path/to/saving_dir
115 | ```
116 |
117 |
118 |
119 | ### Boosting DINO performance :t-rex:
120 | You can improve the performance of the vanilla run by:
121 | - training for more epochs: `--epochs 300`,
122 | - increasing the teacher temperature: `--teacher_temp 0.07 --warmup_teacher_temp_epochs 30`.
123 | - removing last layer normalization (only safe with `--arch deit_small`): `--norm_last_layer false`,
124 |
125 |
126 |
127 | Full command.
128 |
129 |
130 | ```
131 | python run_with_submitit.py --arch deit_small --epochs 300 --teacher_temp 0.07 --warmup_teacher_temp_epochs 30 --norm_last_layer false --data_path /path/to/imagenet/train --output_dir /path/to/saving_dir
132 | ```
133 |
134 |
135 |
136 | The resulting pretrained model should reach 73.3% on k-NN eval and 76.0% on linear eval. Training time is 2.6 days with 16 GPUs. We provide [training](https://dl.fbaipublicfiles.com/dino/example_runs_logs/dino_boost_deitsmall16_log.txt) and [linear evaluation](https://dl.fbaipublicfiles.com/dino/example_runs_logs/dino_boost_deitsmall16_eval.txt) logs (with batch size 256 at evaluation time) for this run to help reproducibility.
137 |
138 | ### ResNet-50 and other convnets trainings
139 | This code also works for training DINO on convolutional networks, like ResNet-50 for example. We highly recommend to adapt some optimization arguments in this case. For example following is a command to train DINO on ResNet-50 on a single node with 8 GPUs for 100 epochs. We provide [training](https://dl.fbaipublicfiles.com/dino/example_runs_logs/dino_rn50_log.txt) logs for this run.
140 | ```
141 | python -m torch.distributed.launch --nproc_per_node=8 main_dino.py --arch resnet50 --optimizer sgd --weight_decay 1e-4 --weight_decay_end 1e-4 --global_crops_scale 0.14 1 --local_crops_scale 0.05 0.14 --data_path /path/to/imagenet/train --output_dir /path/to/saving_dir
142 | ```
143 |
144 | ## Self-attention visualization
145 | You can look at the self-attention of the [CLS] token on the different heads of the last layer by running:
146 | ```
147 | python visualize_attention.py
148 | ```
149 |
150 | ## Self-attention video generation
151 | You can generate videos like the one on the blog post with `video_generation.py`.
152 |
153 | https://user-images.githubusercontent.com/46140458/116817761-47885e80-ab68-11eb-9975-d61d5a919e13.mp4
154 |
155 | Extract frames from input video and generate attention video:
156 | ```
157 | python video_generation.py --pretrained_weights dino_deitsmall8_pretrain.pth \
158 | --input_path input/video.mp4 \
159 | --output_path output/ \
160 | --fps 25
161 | ```
162 |
163 | Use folder of frames already extracted and generate attention video:
164 | ```
165 | python video_generation.py --pretrained_weights dino_deitsmall8_pretrain.pth \
166 | --input_path output/frames/ \
167 | --output_path output/ \
168 | --resize 256 \
169 | ```
170 |
171 | Only generate video from folder of attention maps images:
172 | ```
173 | python video_generation.py --input_path output/attention \
174 | --output_path output/ \
175 | --video_only \
176 | --video_format avi
177 | ```
178 |
179 | Also, check out [this colab](https://gist.github.com/aquadzn/32ac53aa6e485e7c3e09b1a0914f7422) for a video inference notebook.
180 |
181 |
182 |
183 |
184 |
185 |
186 | ## Evaluation: k-NN classification on ImageNet
187 | To evaluate a simple k-NN classifier with a single GPU on a pre-trained model, run:
188 | ```
189 | python -m torch.distributed.launch --nproc_per_node=1 eval_knn.py --data_path /path/to/imagenet
190 | ```
191 | If you choose not to specify `--pretrained_weights`, then DINO reference weights are used by default. If you want instead to evaluate checkpoints from a run of your own, you can run for example:
192 | ```
193 | python -m torch.distributed.launch --nproc_per_node=1 eval_knn.py --pretrained_weights /path/to/checkpoint.pth --checkpoint_key teacher --data_path /path/to/imagenet
194 | ```
195 |
196 | ## Evaluation: Linear classification on ImageNet
197 | To train a supervised linear classifier on frozen weights on a single node with 8 gpus, run:
198 | ```
199 | python -m torch.distributed.launch --nproc_per_node=8 eval_linear.py --data_path /path/to/imagenet
200 | ```
201 |
202 | ## License
203 | This repository is released under the Apache 2.0 license as found in the [LICENSE](LICENSE) file.
204 |
205 | ## Citation
206 | If you find this repository useful, please consider giving a star :star: and citation :t-rex::
207 | ```
208 | @article{caron2021emerging,
209 | title={Emerging Properties in Self-Supervised Vision Transformers},
210 | author={Caron, Mathilde and Touvron, Hugo and Misra, Ishan and J\'egou, Herv\'e and Mairal, Julien and Bojanowski, Piotr and Joulin, Armand},
211 | journal={arXiv preprint arXiv:2104.14294},
212 | year={2021}
213 | }
214 | ```
215 |
--------------------------------------------------------------------------------
/dino_model/eval_knn.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import os
15 | import argparse
16 |
17 | import torch
18 | from torch import nn
19 | import torch.distributed as dist
20 | import torch.backends.cudnn as cudnn
21 | from torchvision import datasets
22 | from torchvision import transforms as pth_transforms
23 |
24 | import utils
25 | import vision_transformer as vits
26 |
27 |
28 | def extract_feature_pipeline(args):
29 | # ============ preparing data ... ============
30 | transform = pth_transforms.Compose([
31 | pth_transforms.Resize(256, interpolation=3),
32 | pth_transforms.CenterCrop(224),
33 | pth_transforms.ToTensor(),
34 | pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
35 | ])
36 | dataset_train = ReturnIndexDataset(os.path.join(args.data_path, "train"), transform=transform)
37 | dataset_val = ReturnIndexDataset(os.path.join(args.data_path, "val"), transform=transform)
38 | sampler = torch.utils.data.DistributedSampler(dataset_train, shuffle=False)
39 | data_loader_train = torch.utils.data.DataLoader(
40 | dataset_train,
41 | sampler=sampler,
42 | batch_size=args.batch_size_per_gpu,
43 | num_workers=args.num_workers,
44 | pin_memory=True,
45 | drop_last=False,
46 | )
47 | data_loader_val = torch.utils.data.DataLoader(
48 | dataset_val,
49 | batch_size=args.batch_size_per_gpu,
50 | num_workers=args.num_workers,
51 | pin_memory=True,
52 | drop_last=False,
53 | )
54 | print(f"Data loaded with {len(dataset_train)} train and {len(dataset_val)} val imgs.")
55 |
56 | # ============ building network ... ============
57 | model = vits.__dict__[args.arch](patch_size=args.patch_size, num_classes=0)
58 | print(f"Model {args.arch} {args.patch_size}x{args.patch_size} built.")
59 | model.cuda()
60 | utils.load_pretrained_weights(model, args.pretrained_weights, args.checkpoint_key, args.arch, args.patch_size)
61 | model.eval()
62 |
63 | # ============ extract features ... ============
64 | print("Extracting features for train set...")
65 | train_features = extract_features(model, data_loader_train)
66 | print("Extracting features for val set...")
67 | test_features = extract_features(model, data_loader_val)
68 |
69 | if utils.get_rank() == 0:
70 | train_features = nn.functional.normalize(train_features, dim=1, p=2)
71 | test_features = nn.functional.normalize(test_features, dim=1, p=2)
72 |
73 | train_labels = torch.tensor([s[-1] for s in dataset_train.samples]).long()
74 | test_labels = torch.tensor([s[-1] for s in dataset_val.samples]).long()
75 | # save features and labels
76 | if args.dump_features and dist.get_rank() == 0:
77 | torch.save(train_features.cpu(), os.path.join(args.dump_features, "trainfeat.pth"))
78 | torch.save(test_features.cpu(), os.path.join(args.dump_features, "testfeat.pth"))
79 | torch.save(train_labels.cpu(), os.path.join(args.dump_features, "trainlabels.pth"))
80 | torch.save(test_labels.cpu(), os.path.join(args.dump_features, "testlabels.pth"))
81 | return train_features, test_features, train_labels, test_labels
82 |
83 |
84 | @torch.no_grad()
85 | def extract_features(model, data_loader):
86 | metric_logger = utils.MetricLogger(delimiter=" ")
87 | features = None
88 | for samples, index in metric_logger.log_every(data_loader, 10):
89 | samples = samples.cuda(non_blocking=True)
90 | index = index.cuda(non_blocking=True)
91 | feats = model(samples).clone()
92 |
93 | # init storage feature matrix
94 | if dist.get_rank() == 0 and features is None:
95 | features = torch.zeros(len(data_loader.dataset), feats.shape[-1])
96 | if args.use_cuda:
97 | features = features.cuda(non_blocking=True)
98 | print(f"Storing features into tensor of shape {features.shape}")
99 |
100 | # get indexes from all processes
101 | y_all = torch.empty(dist.get_world_size(), index.size(0), dtype=index.dtype, device=index.device)
102 | y_l = list(y_all.unbind(0))
103 | y_all_reduce = torch.distributed.all_gather(y_l, index, async_op=True)
104 | y_all_reduce.wait()
105 | index_all = torch.cat(y_l)
106 |
107 | # share features between processes
108 | feats_all = torch.empty(
109 | dist.get_world_size(),
110 | feats.size(0),
111 | feats.size(1),
112 | dtype=feats.dtype,
113 | device=feats.device,
114 | )
115 | output_l = list(feats_all.unbind(0))
116 | output_all_reduce = torch.distributed.all_gather(output_l, feats, async_op=True)
117 | output_all_reduce.wait()
118 |
119 | # update storage feature matrix
120 | if dist.get_rank() == 0:
121 | if args.use_cuda:
122 | features.index_copy_(0, index_all, torch.cat(output_l))
123 | else:
124 | features.index_copy_(0, index_all.cpu(), torch.cat(output_l).cpu())
125 | return features
126 |
127 |
128 | @torch.no_grad()
129 | def knn_classifier(train_features, train_labels, test_features, test_labels, k, T, num_classes=1000):
130 | top1, top5, total = 0.0, 0.0, 0
131 | train_features = train_features.t()
132 | num_test_images, num_chunks = test_labels.shape[0], 100
133 | imgs_per_chunk = num_test_images // num_chunks
134 | retrieval_one_hot = torch.zeros(k, num_classes).cuda()
135 | for idx in range(0, num_test_images, imgs_per_chunk):
136 | # get the features for test images
137 | features = test_features[
138 | idx : min((idx + imgs_per_chunk), num_test_images), :
139 | ]
140 | targets = test_labels[idx : min((idx + imgs_per_chunk), num_test_images)]
141 | batch_size = targets.shape[0]
142 |
143 | # calculate the dot product and compute top-k neighbors
144 | similarity = torch.mm(features, train_features)
145 | distances, indices = similarity.topk(k, largest=True, sorted=True)
146 | candidates = train_labels.view(1, -1).expand(batch_size, -1)
147 | retrieved_neighbors = torch.gather(candidates, 1, indices)
148 |
149 | retrieval_one_hot.resize_(batch_size * k, num_classes).zero_()
150 | retrieval_one_hot.scatter_(1, retrieved_neighbors.view(-1, 1), 1)
151 | distances_transform = distances.clone().div_(T).exp_()
152 | probs = torch.sum(
153 | torch.mul(
154 | retrieval_one_hot.view(batch_size, -1, num_classes),
155 | distances_transform.view(batch_size, -1, 1),
156 | ),
157 | 1,
158 | )
159 | _, predictions = probs.sort(1, True)
160 |
161 | # find the predictions that match the target
162 | correct = predictions.eq(targets.data.view(-1, 1))
163 | top1 = top1 + correct.narrow(1, 0, 1).sum().item()
164 | top5 = top5 + correct.narrow(1, 0, 5).sum().item()
165 | total += targets.size(0)
166 | top1 = top1 * 100.0 / total
167 | top5 = top5 * 100.0 / total
168 | return top1, top5
169 |
170 |
171 | class ReturnIndexDataset(datasets.ImageFolder):
172 | def __getitem__(self, idx):
173 | img, lab = super(ReturnIndexDataset, self).__getitem__(idx)
174 | return img, idx
175 |
176 |
177 | if __name__ == '__main__':
178 | parser = argparse.ArgumentParser('Evaluation with weighted k-NN on ImageNet')
179 | parser.add_argument('--batch_size_per_gpu', default=128, type=int, help='Per-GPU batch-size')
180 | parser.add_argument('--nb_knn', default=[10, 20, 100, 200], nargs='+', type=int,
181 |         help='Number of NN to use. 20 usually works best.')
182 | parser.add_argument('--temperature', default=0.07, type=float,
183 | help='Temperature used in the voting coefficient')
184 | parser.add_argument('--pretrained_weights', default='', type=str, help="Path to pretrained weights to evaluate.")
185 | parser.add_argument('--use_cuda', default=True, type=utils.bool_flag,
186 | help="Should we store the features on GPU? We recommend setting this to False if you encounter OOM")
187 | parser.add_argument('--arch', default='deit_small', type=str,
188 |         choices=['deit_tiny', 'deit_small', 'vit_base'], help='Architecture (only ViT supported at the moment).')
189 | parser.add_argument('--patch_size', default=16, type=int, help='Patch resolution of the model.')
190 | parser.add_argument("--checkpoint_key", default="teacher", type=str,
191 | help='Key to use in the checkpoint (example: "teacher")')
192 | parser.add_argument('--dump_features', default=None,
193 | help='Path where to save computed features, empty for no saving')
194 | parser.add_argument('--load_features', default=None, help="""If the features have
195 | already been computed, where to find them.""")
196 | parser.add_argument('--num_workers', default=10, type=int, help='Number of data loading workers per GPU.')
197 | parser.add_argument("--dist_url", default="env://", type=str, help="""url used to set up
198 | distributed training; see https://pytorch.org/docs/stable/distributed.html""")
199 | parser.add_argument("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.")
200 | parser.add_argument('--data_path', default='/path/to/imagenet/', type=str)
201 | args = parser.parse_args()
202 |
203 | utils.init_distributed_mode(args)
204 | print("git:\n {}\n".format(utils.get_sha()))
205 | print("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())))
206 | cudnn.benchmark = True
207 |
208 | if args.load_features:
209 | train_features = torch.load(os.path.join(args.load_features, "trainfeat.pth"))
210 | test_features = torch.load(os.path.join(args.load_features, "testfeat.pth"))
211 | train_labels = torch.load(os.path.join(args.load_features, "trainlabels.pth"))
212 | test_labels = torch.load(os.path.join(args.load_features, "testlabels.pth"))
213 | else:
214 | # need to extract features !
215 | train_features, test_features, train_labels, test_labels = extract_feature_pipeline(args)
216 |
217 | if utils.get_rank() == 0:
218 | if args.use_cuda:
219 | train_features = train_features.cuda()
220 | test_features = test_features.cuda()
221 | train_labels = train_labels.cuda()
222 | test_labels = test_labels.cuda()
223 |
224 | print("Features are ready!\nStart the k-NN classification.")
225 | for k in args.nb_knn:
226 | top1, top5 = knn_classifier(train_features, train_labels,
227 | test_features, test_labels, k, args.temperature)
228 | print(f"{k}-NN classifier result: Top1: {top1}, Top5: {top5}")
229 | dist.barrier()
230 |
--------------------------------------------------------------------------------
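
The weighted k-NN classifier in `eval_knn.py` above votes with temperature-scaled cosine similarities: each of the `k` nearest training samples casts a vote for its class, weighted by `exp(similarity / T)`. A minimal self-contained sketch of this voting rule on toy tensors (shapes, seed and values below are illustrative only, not taken from the repository):

```python
import torch

# Toy setup: 5 train samples with 4-dim L2-normalized features, 3 classes, 2 test samples.
torch.manual_seed(0)
train_features = torch.nn.functional.normalize(torch.randn(5, 4), dim=1)  # (N_train, D)
train_labels = torch.tensor([0, 1, 2, 1, 0])
test_features = torch.nn.functional.normalize(torch.randn(2, 4), dim=1)   # (N_test, D)

k, T, num_classes = 3, 0.07, 3
similarity = test_features @ train_features.t()        # cosine similarities, (N_test, N_train)
distances, indices = similarity.topk(k, largest=True)  # k nearest train samples per test sample
neighbor_labels = train_labels[indices]                # (N_test, k)

# One-hot encode the neighbor labels and weight each vote by exp(similarity / T).
one_hot = torch.zeros(test_features.size(0), k, num_classes)
one_hot.scatter_(2, neighbor_labels.unsqueeze(-1), 1)
weights = (distances / T).exp().unsqueeze(-1)
probs = (one_hot * weights).sum(dim=1)                 # per-class scores, (N_test, num_classes)
print(probs.argmax(dim=1))                             # predicted classes
```
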
/dino_model/eval_linear.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import os
15 | import argparse
16 | import json
17 | from pathlib import Path
18 |
19 | import torch
20 | from torch import nn
21 | import torch.distributed as dist
22 | import torch.backends.cudnn as cudnn
23 | from torchvision import datasets
24 | from torchvision import transforms as pth_transforms
25 |
26 | import utils
27 | import vision_transformer as vits
28 |
29 |
30 | def eval_linear(args):
31 | utils.init_distributed_mode(args)
32 | print("git:\n {}\n".format(utils.get_sha()))
33 | print("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())))
34 | cudnn.benchmark = True
35 |
36 | # ============ preparing data ... ============
37 | train_transform = pth_transforms.Compose([
38 | pth_transforms.RandomResizedCrop(224),
39 | pth_transforms.RandomHorizontalFlip(),
40 | pth_transforms.ToTensor(),
41 | pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
42 | ])
43 | val_transform = pth_transforms.Compose([
44 | pth_transforms.Resize(256, interpolation=3),
45 | pth_transforms.CenterCrop(224),
46 | pth_transforms.ToTensor(),
47 | pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
48 | ])
49 | dataset_train = datasets.ImageFolder(os.path.join(args.data_path, "train"), transform=train_transform)
50 | dataset_val = datasets.ImageFolder(os.path.join(args.data_path, "val"), transform=val_transform)
51 | sampler = torch.utils.data.distributed.DistributedSampler(dataset_train)
52 | train_loader = torch.utils.data.DataLoader(
53 | dataset_train,
54 | sampler=sampler,
55 | batch_size=args.batch_size_per_gpu,
56 | num_workers=args.num_workers,
57 | pin_memory=True,
58 | )
59 | val_loader = torch.utils.data.DataLoader(
60 | dataset_val,
61 | batch_size=args.batch_size_per_gpu,
62 | num_workers=args.num_workers,
63 | pin_memory=True,
64 | )
65 | print(f"Data loaded with {len(dataset_train)} train and {len(dataset_val)} val imgs.")
66 |
67 | # ============ building network ... ============
68 | model = vits.__dict__[args.arch](patch_size=args.patch_size, num_classes=0)
69 | model.cuda()
70 | model.eval()
71 | print(f"Model {args.arch} {args.patch_size}x{args.patch_size} built.")
72 | # load weights to evaluate
73 | utils.load_pretrained_weights(model, args.pretrained_weights, args.checkpoint_key, args.arch, args.patch_size)
74 |
75 | linear_classifier = LinearClassifier(model.embed_dim * (args.n_last_blocks + int(args.avgpool_patchtokens)), num_labels=args.num_labels)
76 | linear_classifier = linear_classifier.cuda()
77 | linear_classifier = nn.parallel.DistributedDataParallel(linear_classifier, device_ids=[args.gpu])
78 |
79 | # set optimizer
80 | optimizer = torch.optim.SGD(
81 | linear_classifier.parameters(),
82 | args.lr * (args.batch_size_per_gpu * utils.get_world_size()) / 256., # linear scaling rule
83 | momentum=0.9,
84 | weight_decay=0, # we do not apply weight decay
85 | )
86 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs, eta_min=0)
87 |
88 | # Optionally resume from a checkpoint
89 | to_restore = {"epoch": 0, "best_acc": 0.}
90 | utils.restart_from_checkpoint(
91 | os.path.join(args.output_dir, "checkpoint.pth.tar"),
92 | run_variables=to_restore,
93 | state_dict=linear_classifier,
94 | optimizer=optimizer,
95 | scheduler=scheduler,
96 | )
97 | start_epoch = to_restore["epoch"]
98 | best_acc = to_restore["best_acc"]
99 |
100 | for epoch in range(start_epoch, args.epochs):
101 | train_loader.sampler.set_epoch(epoch)
102 |
103 | train_stats = train(model, linear_classifier, optimizer, train_loader, epoch, args.n_last_blocks, args.avgpool_patchtokens)
104 | scheduler.step()
105 |
106 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
107 | 'epoch': epoch}
108 | if epoch % args.val_freq == 0 or epoch == args.epochs - 1:
109 | test_stats = validate_network(val_loader, model, linear_classifier, args.n_last_blocks, args.avgpool_patchtokens)
110 | print(f"Accuracy at epoch {epoch} of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
111 | best_acc = max(best_acc, test_stats["acc1"])
112 | print(f'Max accuracy so far: {best_acc:.2f}%')
113 | log_stats = {**{k: v for k, v in log_stats.items()},
114 | **{f'test_{k}': v for k, v in test_stats.items()}}
115 | if utils.is_main_process():
116 | with (Path(args.output_dir) / "log.txt").open("a") as f:
117 | f.write(json.dumps(log_stats) + "\n")
118 | save_dict = {
119 | "epoch": epoch + 1,
120 | "state_dict": linear_classifier.state_dict(),
121 | "optimizer": optimizer.state_dict(),
122 | "scheduler": scheduler.state_dict(),
123 | "best_acc": best_acc,
124 | }
125 | torch.save(save_dict, os.path.join(args.output_dir, "checkpoint.pth.tar"))
126 | print("Training of the supervised linear classifier on frozen features completed.\n"
127 | "Top-1 test accuracy: {acc:.1f}".format(acc=best_acc))
128 |
129 |
130 | def train(model, linear_classifier, optimizer, loader, epoch, n, avgpool):
131 | linear_classifier.train()
132 | metric_logger = utils.MetricLogger(delimiter=" ")
133 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
134 | header = 'Epoch: [{}]'.format(epoch)
135 | for (inp, target) in metric_logger.log_every(loader, 20, header):
136 | # move to gpu
137 | inp = inp.cuda(non_blocking=True)
138 | target = target.cuda(non_blocking=True)
139 |
140 | # forward
141 | with torch.no_grad():
142 | output = model.forward_return_n_last_blocks(inp, n, avgpool)
143 | output = linear_classifier(output)
144 |
145 | # compute cross entropy loss
146 | loss = nn.CrossEntropyLoss()(output, target)
147 |
148 | # compute the gradients
149 | optimizer.zero_grad()
150 | loss.backward()
151 |
152 | # step
153 | optimizer.step()
154 |
155 | # log
156 | torch.cuda.synchronize()
157 | metric_logger.update(loss=loss.item())
158 | metric_logger.update(lr=optimizer.param_groups[0]["lr"])
159 | # gather the stats from all processes
160 | metric_logger.synchronize_between_processes()
161 | print("Averaged stats:", metric_logger)
162 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
163 |
164 |
165 | @torch.no_grad()
166 | def validate_network(val_loader, model, linear_classifier, n, avgpool):
167 | linear_classifier.eval()
168 | metric_logger = utils.MetricLogger(delimiter=" ")
169 | header = 'Test:'
170 | for inp, target in metric_logger.log_every(val_loader, 20, header):
171 | # move to gpu
172 | inp = inp.cuda(non_blocking=True)
173 | target = target.cuda(non_blocking=True)
174 |
175 | # compute output
176 | output = model.forward_return_n_last_blocks(inp, n, avgpool)
177 | output = linear_classifier(output)
178 | loss = nn.CrossEntropyLoss()(output, target)
179 |
180 | if linear_classifier.module.num_labels >= 5:
181 | acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
182 | else:
183 | acc1, = utils.accuracy(output, target, topk=(1,))
184 |
185 | batch_size = inp.shape[0]
186 | metric_logger.update(loss=loss.item())
187 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
188 | if linear_classifier.module.num_labels >= 5:
189 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
190 | if linear_classifier.module.num_labels >= 5:
191 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
192 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))
193 | else:
194 | print('* Acc@1 {top1.global_avg:.3f} loss {losses.global_avg:.3f}'
195 | .format(top1=metric_logger.acc1, losses=metric_logger.loss))
196 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
197 |
198 |
199 | class LinearClassifier(nn.Module):
200 | """Linear layer to train on top of frozen features"""
201 | def __init__(self, dim, num_labels=1000):
202 | super(LinearClassifier, self).__init__()
203 | self.num_labels = num_labels
204 | self.linear = nn.Linear(dim, num_labels)
205 | self.linear.weight.data.normal_(mean=0.0, std=0.01)
206 | self.linear.bias.data.zero_()
207 |
208 | def forward(self, x):
209 | # flatten
210 | x = x.view(x.size(0), -1)
211 |
212 | # linear layer
213 | return self.linear(x)
214 |
215 |
216 | if __name__ == '__main__':
217 | parser = argparse.ArgumentParser('Evaluation with linear classification on ImageNet')
218 | parser.add_argument('--n_last_blocks', default=4, type=int, help="""Concatenate [CLS] tokens
219 | for the `n` last blocks. We use `n=4` when evaluating DeiT-Small and `n=1` with ViT-Base.""")
220 | parser.add_argument('--avgpool_patchtokens', default=False, type=utils.bool_flag,
221 |         help="""Whether or not to concatenate the global average pooled features to the [CLS] token.
222 | We typically set this to False for DeiT-Small and to True with ViT-Base.""")
223 | parser.add_argument('--arch', default='deit_small', type=str,
224 |         choices=['deit_tiny', 'deit_small', 'vit_base'], help='Architecture (only ViT supported at the moment).')
225 | parser.add_argument('--patch_size', default=16, type=int, help='Patch resolution of the model.')
226 | parser.add_argument('--pretrained_weights', default='', type=str, help="Path to pretrained weights to evaluate.")
227 | parser.add_argument("--checkpoint_key", default="teacher", type=str, help='Key to use in the checkpoint (example: "teacher")')
228 | parser.add_argument('--epochs', default=100, type=int, help='Number of epochs of training.')
229 | parser.add_argument("--lr", default=0.001, type=float, help="""Learning rate at the beginning of
230 | training (highest LR used during training). The learning rate is linearly scaled
231 | with the batch size, and specified here for a reference batch size of 256.
232 | We recommend tweaking the LR depending on the checkpoint evaluated.""")
233 | parser.add_argument('--batch_size_per_gpu', default=128, type=int, help='Per-GPU batch-size')
234 | parser.add_argument("--dist_url", default="env://", type=str, help="""url used to set up
235 | distributed training; see https://pytorch.org/docs/stable/distributed.html""")
236 | parser.add_argument("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.")
237 | parser.add_argument('--data_path', default='/path/to/imagenet/', type=str)
238 | parser.add_argument('--num_workers', default=10, type=int, help='Number of data loading workers per GPU.')
239 | parser.add_argument('--val_freq', default=1, type=int, help="Epoch frequency for validation.")
240 | parser.add_argument('--output_dir', default=".", help='Path to save logs and checkpoints')
241 | parser.add_argument('--num_labels', default=1000, type=int, help='Number of labels for linear classifier')
242 | args = parser.parse_args()
243 | eval_linear(args)
244 |
--------------------------------------------------------------------------------
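
`eval_linear.py` above trains a single linear layer on features from a frozen backbone: the backbone runs under `torch.no_grad()`, so only the classifier receives gradients. A minimal self-contained sketch of that setup with random stand-in features (the 4 * 384 dimension assumes DeiT-Small with `--n_last_blocks 4`; all names here are illustrative):

```python
import torch
from torch import nn

# Stand-in for the concatenated [CLS] tokens of the last 4 blocks of DeiT-Small.
feat_dim, num_labels, batch = 4 * 384, 1000, 8
linear_probe = nn.Linear(feat_dim, num_labels)

frozen_features = torch.randn(batch, feat_dim)    # would come from the frozen backbone
targets = torch.randint(0, num_labels, (batch,))

logits = linear_probe(frozen_features)
loss = nn.CrossEntropyLoss()(logits, targets)
loss.backward()                                   # gradients flow only into the linear layer
print(logits.shape, round(loss.item(), 3))
```
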
/dino_model/fe_dino.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 |
4 | from torch import nn
5 | from torchvision import transforms as pth_transforms
6 |
7 | # from utils import load_image
8 | import dino_model.utils as utils
9 | import dino_model.vision_transformer as vits
10 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
11 |
12 |
13 | class DinoModel(nn.Module):
14 | def __init__(self, feat_type):
15 | super(DinoModel, self).__init__()
16 |
17 | self.transform = pth_transforms.Compose([
18 | pth_transforms.Resize((224, 224), interpolation=3),
19 | # pth_transforms.CenterCrop(224),
20 | pth_transforms.ToTensor(),
21 | pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
22 | ])
23 |
24 | self.model = self.load_model()
25 | self.feat_type = feat_type
26 |
27 |
28 | def load_model(self):
29 | # ============ building network ... ============
30 | p = 8
31 | model = vits.__dict__["vit_base"](patch_size=p, num_classes=0)
32 | print(f"Model {'vit_base'} {p}x{p} built.")
33 | model.cuda()
34 | utils.load_pretrained_weights(model, "", "", "vit_base", p)
35 | model.eval()
36 | return model
37 |
38 | def forward(self, images, feat_type='cls'):
39 | """Extract the image feature vectors."""
40 | if self.transform is not None:
41 | images = self.transform(images).unsqueeze(0)
42 | with torch.no_grad():
43 |
44 | if self.feat_type == 'cls':
45 | features = self.model(images.to(device))
46 |
47 | elif self.feat_type == 'patch':
48 |
49 | feat_out = {}
50 |
51 | def hook_fn_forward_qkv(module, input, output):
52 | feat_out["qkv"] = output
53 |
54 | self.model._modules["blocks"][-1]._modules["attn"]._modules["qkv"].register_forward_hook(hook_fn_forward_qkv)
55 |
56 | # Forward pass in the model
57 | attentions = self.model.forward_selfattention(images.to(device))
58 |
59 | # Dimensions
60 | nb_im = attentions.shape[0] # Batch size
61 | nh = attentions.shape[1] # Number of heads
62 | nb_tokens = attentions.shape[2] # Number of tokens
63 |
64 | # Extract the qkv features of the last attention layer
65 | qkv = (
66 | feat_out["qkv"]
67 | .reshape(nb_im, nb_tokens, 3, nh, -1 // nh)
68 | .permute(2, 0, 3, 1, 4)
69 | )
70 | q, k, v = qkv[0], qkv[1], qkv[2]
71 | k = k.transpose(1, 2).reshape(nb_im, nb_tokens, -1)[:, 1:, :]
72 | q = q.transpose(1, 2).reshape(nb_im, nb_tokens, -1)[:, 1:, :]
73 | v = v.transpose(1, 2).reshape(nb_im, nb_tokens, -1)[:, 1:, :]
74 |
75 | features = k
76 |
77 | return features
78 |
--------------------------------------------------------------------------------
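
`fe_dino.py` is the wrapper SeedAL uses to extract DINO ViT-B/8 features from scene images, returning either the global [CLS] embedding (`feat_type='cls'`) or the per-patch key features of the last attention layer (`feat_type='patch'`). A minimal usage sketch, assuming the repository root is on `PYTHONPATH`, a CUDA device is available, and `scene.png` is a placeholder image path:

```python
from PIL import Image

from dino_model.fe_dino import DinoModel

extractor = DinoModel(feat_type='cls')           # downloads the reference DINO ViT-B/8 weights
image = Image.open('scene.png').convert('RGB')   # placeholder path to one scene image
features = extractor(image)                      # (1, 768) [CLS] embedding for feat_type='cls'
print(features.shape)
```
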
/dino_model/hubconf.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import torch
15 | from torchvision.models.resnet import resnet50
16 |
17 | import vision_transformer as vits
18 |
19 | dependencies = ["torch", "torchvision"]
20 |
21 |
22 | def dino_deits16(pretrained=True, **kwargs):
23 | """
24 | DeiT-Small/16x16 pre-trained with DINO.
25 | Achieves 74.5% top-1 accuracy on ImageNet with k-NN classification.
26 | """
27 | model = vits.__dict__["deit_small"](patch_size=16, num_classes=0, **kwargs)
28 | if pretrained:
29 | state_dict = torch.hub.load_state_dict_from_url(
30 | url="https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth",
31 | map_location="cpu",
32 | )
33 | model.load_state_dict(state_dict, strict=True)
34 | return model
35 |
36 |
37 | def dino_deits8(pretrained=True, **kwargs):
38 | """
39 | DeiT-Small/8x8 pre-trained with DINO.
40 | Achieves 78.3% top-1 accuracy on ImageNet with k-NN classification.
41 | """
42 | model = vits.__dict__["deit_small"](patch_size=8, num_classes=0, **kwargs)
43 | if pretrained:
44 | state_dict = torch.hub.load_state_dict_from_url(
45 | url="https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth",
46 | map_location="cpu",
47 | )
48 | model.load_state_dict(state_dict, strict=True)
49 | return model
50 |
51 |
52 | def dino_vitb16(pretrained=True, **kwargs):
53 | """
54 | ViT-Base/16x16 pre-trained with DINO.
55 | Achieves 76.1% top-1 accuracy on ImageNet with k-NN classification.
56 | """
57 | model = vits.__dict__["vit_base"](patch_size=16, num_classes=0, **kwargs)
58 | if pretrained:
59 | state_dict = torch.hub.load_state_dict_from_url(
60 | url="https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth",
61 | map_location="cpu",
62 | )
63 | model.load_state_dict(state_dict, strict=True)
64 | return model
65 |
66 |
67 | def dino_vitb8(pretrained=True, **kwargs):
68 | """
69 | ViT-Base/8x8 pre-trained with DINO.
70 | Achieves 77.4% top-1 accuracy on ImageNet with k-NN classification.
71 | """
72 | model = vits.__dict__["vit_base"](patch_size=8, num_classes=0, **kwargs)
73 | if pretrained:
74 | state_dict = torch.hub.load_state_dict_from_url(
75 | url="https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth",
76 | map_location="cpu",
77 | )
78 | model.load_state_dict(state_dict, strict=True)
79 | return model
80 |
81 |
82 | def dino_resnet50(pretrained=True, **kwargs):
83 | """
84 | ResNet-50 pre-trained with DINO.
85 | Achieves 75.3% top-1 accuracy on ImageNet linear evaluation benchmark (requires to train `fc`).
86 | Note that `fc.weight` and `fc.bias` are randomly initialized.
87 | """
88 | model = resnet50(pretrained=False, **kwargs)
89 | if pretrained:
90 | state_dict = torch.hub.load_state_dict_from_url(
91 | url="https://dl.fbaipublicfiles.com/dino/dino_resnet50_pretrain/dino_resnet50_pretrain.pth",
92 | map_location="cpu",
93 | )
94 | model.load_state_dict(state_dict, strict=False)
95 | return model
96 |
--------------------------------------------------------------------------------
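
`hubconf.py` exposes the pretrained DINO backbones as torch.hub entry points. A minimal sketch of loading one of them through this local hubconf (assuming `./dino_model` points to this directory and that your PyTorch version supports `source='local'` in `torch.hub.load`):

```python
import torch

# Load ViT-B/8 pretrained with DINO via the local hubconf entry point.
model = torch.hub.load('./dino_model', 'dino_vitb8', source='local', pretrained=True)
model.eval()

# Dummy forward pass: a 224x224 RGB image yields a 768-dim embedding.
with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 768])
```
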
/dino_model/main_dino.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import argparse
15 | import os
16 | import sys
17 | import datetime
18 | import time
19 | import math
20 | import json
21 | from pathlib import Path
22 |
23 | import numpy as np
24 | from PIL import Image
25 | import torch
26 | import torch.nn as nn
27 | import torch.distributed as dist
28 | import torch.backends.cudnn as cudnn
29 | import torch.nn.functional as F
30 | from torchvision import datasets, transforms
31 | from torchvision import models as torchvision_models
32 |
33 | import utils
34 | import vision_transformer as vits
35 | from vision_transformer import DINOHead
36 |
37 | torchvision_archs = sorted(name for name in torchvision_models.__dict__
38 | if name.islower() and not name.startswith("__")
39 | and callable(torchvision_models.__dict__[name]))
40 |
41 | def get_args_parser():
42 | parser = argparse.ArgumentParser('DINO', add_help=False)
43 |
44 | # Model parameters
45 | parser.add_argument('--arch', default='deit_small', type=str,
46 | choices=['deit_tiny', 'deit_small', 'vit_base'] + torchvision_archs,
47 | help="""Name of architecture to train. For quick experiments with ViTs,
48 | we recommend using deit_tiny or deit_small.""")
49 | parser.add_argument('--patch_size', default=16, type=int, help="""Size in pixels
50 | of input square patches - default 16 (for 16x16 patches). Using smaller
51 | values leads to better performance but requires more memory. Applies only
52 | for ViTs (deit_tiny, deit_small and vit_base). If <16, we recommend disabling
53 |         mixed precision training (--use_fp16 false) to avoid instabilities.""")
54 | parser.add_argument('--out_dim', default=65536, type=int, help="""Dimensionality of
55 |         the DINO head output. For complex and large datasets, large values (like 65k) work well.""")
56 | parser.add_argument('--norm_last_layer', default=True, type=utils.bool_flag,
57 | help="""Whether or not to weight normalize the last layer of the DINO head.
58 | Not normalizing leads to better performance but can make the training unstable.
59 |         In our experiments, we typically set this parameter to False with deit_small and True with vit_base.""")
60 | parser.add_argument('--momentum_teacher', default=0.996, type=float, help="""Base EMA
61 |         parameter for teacher update. The value is increased to 1 during training with a cosine schedule.
62 | We recommend setting a higher value with small batches: for example use 0.9995 with batch size of 256.""")
63 | parser.add_argument('--use_bn_in_head', default=False, type=utils.bool_flag,
64 | help="Whether to use batch normalizations in projection head (Default: False)")
65 |
66 | # Temperature teacher parameters
67 | parser.add_argument('--warmup_teacher_temp', default=0.04, type=float,
68 | help="""Initial value for the teacher temperature: 0.04 works well in most cases.
69 | Try decreasing it if the training loss does not decrease.""")
70 | parser.add_argument('--teacher_temp', default=0.04, type=float, help="""Final value (after linear warmup)
71 | of the teacher temperature. For most experiments, anything above 0.07 is unstable. We recommend
72 |         starting with the default value of 0.04 and increasing it slightly if needed.""")
73 | parser.add_argument('--warmup_teacher_temp_epochs', default=0, type=int,
74 |         help='Number of warmup epochs for the teacher temperature (Default: 0).')
75 |
76 | # Training/Optimization parameters
77 | parser.add_argument('--use_fp16', type=utils.bool_flag, default=True, help="""Whether or not
78 |         to use half precision for training. It speeds up training and reduces memory requirements,
79 |         but can cause instability and a slight drop in performance. We recommend disabling
80 | mixed precision if the loss is unstable, if reducing the patch size or if training with bigger ViTs.""")
81 | parser.add_argument('--weight_decay', type=float, default=0.04, help="""Initial value of the
82 | weight decay. With ViT, a smaller value at the beginning of training works well.""")
83 | parser.add_argument('--weight_decay_end', type=float, default=0.4, help="""Final value of the
84 | weight decay. We use a cosine schedule for WD and using a larger decay by
85 | the end of training improves performance for ViTs.""")
86 | parser.add_argument('--clip_grad', type=float, default=3.0, help="""Maximal parameter
87 | gradient norm if using gradient clipping. Clipping with norm .3 ~ 1.0 can
88 | help optimization for larger ViT architectures. 0 for disabling.""")
89 | parser.add_argument('--batch_size_per_gpu', default=64, type=int,
90 | help='Per-GPU batch-size : number of distinct images loaded on one GPU.')
91 | parser.add_argument('--epochs', default=100, type=int, help='Number of epochs of training.')
92 | parser.add_argument('--freeze_last_layer', default=1, type=int, help="""Number of epochs
93 | during which we keep the output layer fixed. Typically doing so during
94 | the first epoch helps training. Try increasing this value if the loss does not decrease.""")
95 | parser.add_argument("--lr", default=0.0005, type=float, help="""Learning rate at the end of
96 | linear warmup (highest LR used during training). The learning rate is linearly scaled
97 | with the batch size, and specified here for a reference batch size of 256.""")
98 | parser.add_argument("--warmup_epochs", default=10, type=int,
99 | help="Number of epochs for the linear learning-rate warm up.")
100 | parser.add_argument('--min_lr', type=float, default=1e-6, help="""Target LR at the
101 | end of optimization. We use a cosine LR schedule with linear warmup.""")
102 | parser.add_argument('--optimizer', default='adamw', type=str,
103 | choices=['adamw', 'sgd', 'lars'], help="""Type of optimizer. We recommend using adamw with ViTs.""")
104 |
105 | # Multi-crop parameters
106 | parser.add_argument('--global_crops_scale', type=float, nargs='+', default=(0.4, 1.),
107 |         help="""Scale range of the cropped image before resizing, relative to the original image.
108 |         Used for large global view cropping. When disabling multi-crop (--local_crops_number 0), we
109 |         recommend using a wider range of scale ("--global_crops_scale 0.14 1." for example)""")
110 | parser.add_argument('--local_crops_number', type=int, default=8, help="""Number of small
111 | local views to generate. Set this parameter to 0 to disable multi-crop training.
112 |         When disabling multi-crop, we recommend using "--global_crops_scale 0.14 1." """)
113 | parser.add_argument('--local_crops_scale', type=float, nargs='+', default=(0.05, 0.4),
114 |         help="""Scale range of the cropped image before resizing, relative to the original image.
115 | Used for small local view cropping of multi-crop.""")
116 |
117 | # Misc
118 | parser.add_argument('--data_path', default='/path/to/imagenet/train/', type=str,
119 | help='Please specify path to the ImageNet training data.')
120 | parser.add_argument('--output_dir', default=".", type=str, help='Path to save logs and checkpoints.')
121 | parser.add_argument('--saveckp_freq', default=20, type=int, help='Save checkpoint every x epochs.')
122 | parser.add_argument('--seed', default=0, type=int, help='Random seed.')
123 | parser.add_argument('--num_workers', default=10, type=int, help='Number of data loading workers per GPU.')
124 | parser.add_argument("--dist_url", default="env://", type=str, help="""url used to set up
125 | distributed training; see https://pytorch.org/docs/stable/distributed.html""")
126 | parser.add_argument("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.")
127 | return parser
128 |
129 |
130 | def train_dino(args):
131 | utils.init_distributed_mode(args)
132 | utils.fix_random_seeds(args.seed)
133 | print("git:\n {}\n".format(utils.get_sha()))
134 | print("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())))
135 | cudnn.benchmark = True
136 |
137 | # ============ preparing data ... ============
138 | transform = DataAugmentationDINO(
139 | args.global_crops_scale,
140 | args.local_crops_scale,
141 | args.local_crops_number,
142 | )
143 | dataset = datasets.ImageFolder(args.data_path, transform=transform)
144 | sampler = torch.utils.data.DistributedSampler(dataset, shuffle=True)
145 | data_loader = torch.utils.data.DataLoader(
146 | dataset,
147 | sampler=sampler,
148 | batch_size=args.batch_size_per_gpu,
149 | num_workers=args.num_workers,
150 | pin_memory=True,
151 | drop_last=True,
152 | )
153 | print(f"Data loaded: there are {len(dataset)} images.")
154 |
155 | # ============ building student and teacher networks ... ============
156 | # if the network is a vision transformer (i.e. deit_tiny, deit_small, vit_base)
157 | if args.arch in vits.__dict__.keys():
158 | student = vits.__dict__[args.arch](
159 | patch_size=args.patch_size,
160 | drop_path_rate=0.1, # stochastic depth
161 | )
162 | teacher = vits.__dict__[args.arch](patch_size=args.patch_size)
163 | student.head = DINOHead(
164 | student.embed_dim,
165 | args.out_dim,
166 | use_bn=args.use_bn_in_head,
167 | norm_last_layer=args.norm_last_layer,
168 | )
169 | teacher.head = DINOHead(teacher.embed_dim, args.out_dim, args.use_bn_in_head)
170 |
171 | # otherwise, we check if the architecture is in torchvision models
172 | elif args.arch in torchvision_models.__dict__.keys():
173 | student = torchvision_models.__dict__[args.arch]()
174 | teacher = torchvision_models.__dict__[args.arch]()
175 | embed_dim = student.fc.weight.shape[1]
176 | student = utils.MultiCropWrapper(student, DINOHead(
177 | embed_dim,
178 | args.out_dim,
179 | use_bn=args.use_bn_in_head,
180 | norm_last_layer=args.norm_last_layer,
181 | ))
182 | teacher = utils.MultiCropWrapper(
183 | teacher,
184 | DINOHead(embed_dim, args.out_dim, args.use_bn_in_head),
185 | )
186 | else:
187 |         print(f"Unknown architecture: {args.arch}")
188 |
189 | # move networks to gpu
190 | student, teacher = student.cuda(), teacher.cuda()
191 | # synchronize batch norms (if any)
192 | if utils.has_batchnorms(student):
193 | student = nn.SyncBatchNorm.convert_sync_batchnorm(student)
194 | teacher = nn.SyncBatchNorm.convert_sync_batchnorm(teacher)
195 |
196 |         # we need the DDP wrapper for synchronized batch norms to work...
197 | teacher = nn.parallel.DistributedDataParallel(teacher, device_ids=[args.gpu])
198 | teacher_without_ddp = teacher.module
199 | else:
200 | # teacher_without_ddp and teacher are the same thing
201 | teacher_without_ddp = teacher
202 | student = nn.parallel.DistributedDataParallel(student, device_ids=[args.gpu])
203 | # teacher and student start with the same weights
204 | teacher_without_ddp.load_state_dict(student.module.state_dict())
205 | # there is no backpropagation through the teacher, so no need for gradients
206 | for p in teacher.parameters():
207 | p.requires_grad = False
208 |     print(f"Student and Teacher are built: they are both {args.arch} networks.")
209 |
210 | # ============ preparing loss ... ============
211 | dino_loss = DINOLoss(
212 | args.out_dim,
213 | args.local_crops_number + 2, # total number of crops = 2 global crops + local_crops_number
214 | args.warmup_teacher_temp,
215 | args.teacher_temp,
216 | args.warmup_teacher_temp_epochs,
217 | args.epochs,
218 | ).cuda()
219 |
220 | # ============ preparing optimizer ... ============
221 | params_groups = utils.get_params_groups(student)
222 | if args.optimizer == "adamw":
223 | optimizer = torch.optim.AdamW(params_groups) # to use with ViTs
224 | elif args.optimizer == "sgd":
225 | optimizer = torch.optim.SGD(params_groups, lr=0, momentum=0.9) # lr is set by scheduler
226 | elif args.optimizer == "lars":
227 | optimizer = utils.LARS(params_groups) # to use with convnet and large batches
228 | # for mixed precision training
229 | fp16_scaler = None
230 | if args.use_fp16:
231 | fp16_scaler = torch.cuda.amp.GradScaler()
232 |
233 | # ============ init schedulers ... ============
234 | lr_schedule = utils.cosine_scheduler(
235 | args.lr * (args.batch_size_per_gpu * utils.get_world_size()) / 256., # linear scaling rule
236 | args.min_lr,
237 | args.epochs, len(data_loader),
238 | warmup_epochs=args.warmup_epochs,
239 | )
240 | wd_schedule = utils.cosine_scheduler(
241 | args.weight_decay,
242 | args.weight_decay_end,
243 | args.epochs, len(data_loader),
244 | )
245 | # momentum parameter is increased to 1. during training with a cosine schedule
246 | momentum_schedule = utils.cosine_scheduler(args.momentum_teacher, 1,
247 | args.epochs, len(data_loader))
248 | print(f"Loss, optimizer and schedulers ready.")
249 |
250 | # ============ optionally resume training ... ============
251 | to_restore = {"epoch": 0}
252 | utils.restart_from_checkpoint(
253 | os.path.join(args.output_dir, "checkpoint.pth"),
254 | run_variables=to_restore,
255 | student=student,
256 | teacher=teacher,
257 | optimizer=optimizer,
258 | fp16_scaler=fp16_scaler,
259 | dino_loss=dino_loss,
260 | )
261 | start_epoch = to_restore["epoch"]
262 |
263 | start_time = time.time()
264 | print("Starting DINO training !")
265 | for epoch in range(start_epoch, args.epochs):
266 | data_loader.sampler.set_epoch(epoch)
267 |
268 | # ============ training one epoch of DINO ... ============
269 | train_stats = train_one_epoch(student, teacher, teacher_without_ddp, dino_loss,
270 | data_loader, optimizer, lr_schedule, wd_schedule, momentum_schedule,
271 | epoch, fp16_scaler, args)
272 |
273 | # ============ writing logs ... ============
274 | save_dict = {
275 | 'student': student.state_dict(),
276 | 'teacher': teacher.state_dict(),
277 | 'optimizer': optimizer.state_dict(),
278 | 'epoch': epoch + 1,
279 | 'args': args,
280 | 'dino_loss': dino_loss.state_dict(),
281 | }
282 | if fp16_scaler is not None:
283 | save_dict['fp16_scaler'] = fp16_scaler.state_dict()
284 | utils.save_on_master(save_dict, os.path.join(args.output_dir, 'checkpoint.pth'))
285 | if args.saveckp_freq and epoch % args.saveckp_freq == 0:
286 | utils.save_on_master(save_dict, os.path.join(args.output_dir, f'checkpoint{epoch:04}.pth'))
287 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
288 | 'epoch': epoch}
289 | if utils.is_main_process():
290 | with (Path(args.output_dir) / "log.txt").open("a") as f:
291 | f.write(json.dumps(log_stats) + "\n")
292 | total_time = time.time() - start_time
293 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
294 | print('Training time {}'.format(total_time_str))
295 |
296 |
297 | def train_one_epoch(student, teacher, teacher_without_ddp, dino_loss, data_loader,
298 | optimizer, lr_schedule, wd_schedule, momentum_schedule,epoch,
299 | fp16_scaler, args):
300 | metric_logger = utils.MetricLogger(delimiter=" ")
301 | header = 'Epoch: [{}/{}]'.format(epoch, args.epochs)
302 | for it, (images, _) in enumerate(metric_logger.log_every(data_loader, 10, header)):
303 | # update weight decay and learning rate according to their schedule
304 | it = len(data_loader) * epoch + it # global training iteration
305 | for i, param_group in enumerate(optimizer.param_groups):
306 | param_group["lr"] = lr_schedule[it]
307 | if i == 0: # only the first group is regularized
308 | param_group["weight_decay"] = wd_schedule[it]
309 |
310 | # move images to gpu
311 | images = [im.cuda(non_blocking=True) for im in images]
312 | # teacher and student forward passes + compute dino loss
313 | with torch.cuda.amp.autocast(fp16_scaler is not None):
314 | teacher_output = teacher(images[:2]) # only the 2 global views pass through the teacher
315 | student_output = student(images)
316 | loss = dino_loss(student_output, teacher_output, epoch)
317 |
318 | if not math.isfinite(loss.item()):
319 | print("Loss is {}, stopping training".format(loss.item()), force=True)
320 | sys.exit(1)
321 |
322 | # student update
323 | optimizer.zero_grad()
324 | param_norms = None
325 | if fp16_scaler is None:
326 | loss.backward()
327 | if args.clip_grad:
328 | param_norms = utils.clip_gradients(student, args.clip_grad)
329 | utils.cancel_gradients_last_layer(epoch, student,
330 | args.freeze_last_layer)
331 | optimizer.step()
332 | else:
333 | fp16_scaler.scale(loss).backward()
334 | if args.clip_grad:
335 | fp16_scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place
336 | param_norms = utils.clip_gradients(student, args.clip_grad)
337 | utils.cancel_gradients_last_layer(epoch, student,
338 | args.freeze_last_layer)
339 | fp16_scaler.step(optimizer)
340 | fp16_scaler.update()
341 |
342 | # EMA update for the teacher
343 | with torch.no_grad():
344 | m = momentum_schedule[it] # momentum parameter
345 | for param_q, param_k in zip(student.module.parameters(), teacher_without_ddp.parameters()):
346 | param_k.data.mul_(m).add_((1 - m) * param_q.detach().data)
347 |
348 | # logging
349 | torch.cuda.synchronize()
350 | metric_logger.update(loss=loss.item())
351 | metric_logger.update(lr=optimizer.param_groups[0]["lr"])
352 | metric_logger.update(wd=optimizer.param_groups[0]["weight_decay"])
353 | # gather the stats from all processes
354 | metric_logger.synchronize_between_processes()
355 | print("Averaged stats:", metric_logger)
356 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
357 |
358 |
359 | class DINOLoss(nn.Module):
360 | def __init__(self, out_dim, ncrops, warmup_teacher_temp, teacher_temp,
361 | warmup_teacher_temp_epochs, nepochs, student_temp=0.1,
362 | center_momentum=0.9):
363 | super().__init__()
364 | self.student_temp = student_temp
365 | self.center_momentum = center_momentum
366 | self.ncrops = ncrops
367 | self.register_buffer("center", torch.zeros(1, out_dim))
368 |         # we apply a warm-up for the teacher temperature because
369 |         # too high a temperature makes the training unstable at the beginning
370 | self.teacher_temp_schedule = np.concatenate((
371 | np.linspace(warmup_teacher_temp,
372 | teacher_temp, warmup_teacher_temp_epochs),
373 | np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp
374 | ))
375 |
376 | def forward(self, student_output, teacher_output, epoch):
377 | """
378 | Cross-entropy between softmax outputs of the teacher and student networks.
379 | """
380 | student_out = student_output / self.student_temp
381 | student_out = student_out.chunk(self.ncrops)
382 |
383 | # teacher centering and sharpening
384 | temp = self.teacher_temp_schedule[epoch]
385 | teacher_out = F.softmax((teacher_output - self.center) / temp, dim=-1)
386 | teacher_out = teacher_out.detach().chunk(2)
387 |
388 | total_loss = 0
389 | n_loss_terms = 0
390 | for iq, q in enumerate(teacher_out):
391 | for v in range(len(student_out)):
392 | if v == iq:
393 | # we skip cases where student and teacher operate on the same view
394 | continue
395 | loss = torch.sum(-q * F.log_softmax(student_out[v], dim=-1), dim=-1)
396 | total_loss += loss.mean()
397 | n_loss_terms += 1
398 | total_loss /= n_loss_terms
399 | self.update_center(teacher_output)
400 | return total_loss
401 |
402 | @torch.no_grad()
403 | def update_center(self, teacher_output):
404 | """
405 | Update center used for teacher output.
406 | """
407 | batch_center = torch.sum(teacher_output, dim=0, keepdim=True)
408 | dist.all_reduce(batch_center)
409 | batch_center = batch_center / (len(teacher_output) * dist.get_world_size())
410 |
411 | # ema update
412 | self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum)
413 |
414 |
415 | class DataAugmentationDINO(object):
416 | def __init__(self, global_crops_scale, local_crops_scale, local_crops_number):
417 | flip_and_color_jitter = transforms.Compose([
418 | transforms.RandomHorizontalFlip(p=0.5),
419 | transforms.RandomApply(
420 | [transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)],
421 | p=0.8
422 | ),
423 | transforms.RandomGrayscale(p=0.2),
424 | ])
425 | normalize = transforms.Compose([
426 | transforms.ToTensor(),
427 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
428 | ])
429 |
430 | # first global crop
431 | self.global_transfo1 = transforms.Compose([
432 | transforms.RandomResizedCrop(224, scale=global_crops_scale, interpolation=Image.BICUBIC),
433 | flip_and_color_jitter,
434 | utils.GaussianBlur(1.0),
435 | normalize,
436 | ])
437 | # second global crop
438 | self.global_transfo2 = transforms.Compose([
439 | transforms.RandomResizedCrop(224, scale=global_crops_scale, interpolation=Image.BICUBIC),
440 | flip_and_color_jitter,
441 | utils.GaussianBlur(0.1),
442 | utils.Solarization(0.2),
443 | normalize,
444 | ])
445 | # transformation for the local small crops
446 | self.local_crops_number = local_crops_number
447 | self.local_transfo = transforms.Compose([
448 | transforms.RandomResizedCrop(96, scale=local_crops_scale, interpolation=Image.BICUBIC),
449 | flip_and_color_jitter,
450 | utils.GaussianBlur(p=0.5),
451 | normalize,
452 | ])
453 |
454 | def __call__(self, image):
455 | crops = []
456 | crops.append(self.global_transfo1(image))
457 | crops.append(self.global_transfo2(image))
458 | for _ in range(self.local_crops_number):
459 | crops.append(self.local_transfo(image))
460 | return crops
461 |
462 |
463 | if __name__ == '__main__':
464 | parser = argparse.ArgumentParser('DINO', parents=[get_args_parser()])
465 | args = parser.parse_args()
466 | Path(args.output_dir).mkdir(parents=True, exist_ok=True)
467 | train_dino(args)
468 |
--------------------------------------------------------------------------------
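
In `train_one_epoch` above, the teacher is never trained by backpropagation; its parameters are an exponential moving average (EMA) of the student's, with momentum `m` taken from the cosine momentum schedule. A tiny numeric sketch of that update rule on toy tensors (values are illustrative only):

```python
import torch

# Toy "student" and "teacher" parameters.
student_param = torch.tensor([1.0, 2.0, 3.0])
teacher_param = torch.zeros(3)

m = 0.996  # momentum from the cosine schedule; it rises towards 1 over training
# EMA update used in train_one_epoch: teacher <- m * teacher + (1 - m) * student
teacher_param.mul_(m).add_((1 - m) * student_param)
print(teacher_param)  # tensor([0.0040, 0.0080, 0.0120])
```
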
/dino_model/run_with_submitit.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | A script to run multinode training with submitit.
16 | Almost copy-paste from https://github.com/facebookresearch/deit/blob/main/run_with_submitit.py
17 | """
18 | import argparse
19 | import os
20 | import uuid
21 | from pathlib import Path
22 |
23 | import main_dino
24 | import submitit
25 |
26 |
27 | def parse_args():
28 | parser = argparse.ArgumentParser("Submitit for DINO", parents=[main_dino.get_args_parser()])
29 | parser.add_argument("--ngpus", default=8, type=int, help="Number of gpus to request on each node")
30 | parser.add_argument("--nodes", default=2, type=int, help="Number of nodes to request")
31 | parser.add_argument("--timeout", default=2800, type=int, help="Duration of the job")
32 |
33 | parser.add_argument("--partition", default="learnfair", type=str, help="Partition where to submit")
34 |     parser.add_argument("--use_volta32", action='store_true', help="Request 32GB V100 GPUs (recommended for big models)")
35 | parser.add_argument('--comment', default="", type=str,
36 | help='Comment to pass to scheduler, e.g. priority message')
37 | return parser.parse_args()
38 |
39 |
40 | def get_shared_folder() -> Path:
41 | user = os.getenv("USER")
42 | if Path("/checkpoint/").is_dir():
43 | p = Path(f"/checkpoint/{user}/experiments")
44 | p.mkdir(exist_ok=True)
45 | return p
46 | raise RuntimeError("No shared folder available")
47 |
48 |
49 | def get_init_file():
50 |     # Init file must not exist, but its parent dir must exist.
51 | os.makedirs(str(get_shared_folder()), exist_ok=True)
52 | init_file = get_shared_folder() / f"{uuid.uuid4().hex}_init"
53 | if init_file.exists():
54 | os.remove(str(init_file))
55 | return init_file
56 |
57 |
58 | class Trainer(object):
59 | def __init__(self, args):
60 | self.args = args
61 |
62 | def __call__(self):
63 | import main_dino
64 |
65 | self._setup_gpu_args()
66 | main_dino.train_dino(self.args)
67 |
68 | def checkpoint(self):
69 | import os
70 | import submitit
71 |
72 | self.args.dist_url = get_init_file().as_uri()
73 | print("Requeuing ", self.args)
74 | empty_trainer = type(self)(self.args)
75 | return submitit.helpers.DelayedSubmission(empty_trainer)
76 |
77 | def _setup_gpu_args(self):
78 | import submitit
79 | from pathlib import Path
80 |
81 | job_env = submitit.JobEnvironment()
82 | self.args.output_dir = Path(str(self.args.output_dir).replace("%j", str(job_env.job_id)))
83 | self.args.gpu = job_env.local_rank
84 | self.args.rank = job_env.global_rank
85 | self.args.world_size = job_env.num_tasks
86 | print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}")
87 |
88 |
89 | def main():
90 | args = parse_args()
91 | if args.output_dir == "":
92 | args.output_dir = get_shared_folder() / "%j"
93 | Path(args.output_dir).mkdir(parents=True, exist_ok=True)
94 | executor = submitit.AutoExecutor(folder=args.output_dir, slurm_max_num_timeout=30)
95 |
96 | num_gpus_per_node = args.ngpus
97 | nodes = args.nodes
98 | timeout_min = args.timeout
99 |
100 | partition = args.partition
101 | kwargs = {}
102 | if args.use_volta32:
103 | kwargs['slurm_constraint'] = 'volta32gb'
104 | if args.comment:
105 | kwargs['slurm_comment'] = args.comment
106 |
107 | executor.update_parameters(
108 | mem_gb=40 * num_gpus_per_node,
109 | gpus_per_node=num_gpus_per_node,
110 | tasks_per_node=num_gpus_per_node, # one task per GPU
111 | cpus_per_task=10,
112 | nodes=nodes,
113 | timeout_min=timeout_min, # max is 60 * 72
114 | # Below are cluster dependent parameters
115 | slurm_partition=partition,
116 | slurm_signal_delay_s=120,
117 | **kwargs
118 | )
119 |
120 | executor.update_parameters(name="dino")
121 |
122 | args.dist_url = get_init_file().as_uri()
123 |
124 | trainer = Trainer(args)
125 | job = executor.submit(trainer)
126 |
127 | print(f"Submitted job_id: {job.job_id}")
128 | print(f"Logs and checkpoints will be saved at: {args.output_dir}")
129 |
130 |
131 | if __name__ == "__main__":
132 | main()
133 |
--------------------------------------------------------------------------------
/dino_model/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Misc functions.
16 |
17 | Mostly copy-paste from torchvision references or other public repos like DETR:
18 | https://github.com/facebookresearch/detr/blob/master/util/misc.py
19 | """
20 | import argparse, os  # argparse is needed by bool_flag below
21 | import sys
22 | import time
23 | import math
24 | import random
25 | import datetime
26 | import subprocess
27 | from collections import defaultdict, deque
28 |
29 | import numpy as np
30 | import torch
31 | from torch import nn
32 | import torch.distributed as dist
33 | from PIL import ImageFilter, ImageOps
34 |
35 |
36 | class GaussianBlur(object):
37 | """
38 | Apply Gaussian Blur to the PIL image.
39 | """
40 | def __init__(self, p=0.5, radius_min=0.1, radius_max=2.):
41 | self.prob = p
42 | self.radius_min = radius_min
43 | self.radius_max = radius_max
44 |
45 | def __call__(self, img):
46 | do_it = random.random() <= self.prob
47 | if not do_it:
48 | return img
49 |
50 | return img.filter(
51 | ImageFilter.GaussianBlur(
52 | radius=random.uniform(self.radius_min, self.radius_max)
53 | )
54 | )
55 |
56 |
57 | class Solarization(object):
58 | """
59 | Apply Solarization to the PIL image.
60 | """
61 | def __init__(self, p):
62 | self.p = p
63 |
64 | def __call__(self, img):
65 | if random.random() < self.p:
66 | return ImageOps.solarize(img)
67 | else:
68 | return img
69 |
70 |
71 | def load_pretrained_weights(model, pretrained_weights, checkpoint_key, model_name, patch_size):
72 | if os.path.isfile(pretrained_weights):
73 | state_dict = torch.load(pretrained_weights, map_location="cpu")
74 | if checkpoint_key is not None and checkpoint_key in state_dict:
75 | print(f"Take key {checkpoint_key} in provided checkpoint dict")
76 | state_dict = state_dict[checkpoint_key]
77 | state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
78 | msg = model.load_state_dict(state_dict, strict=False)
79 | print('Pretrained weights found at {} and loaded with msg: {}'.format(pretrained_weights, msg))
80 | else:
81 | print("Please use the `--pretrained_weights` argument to indicate the path of the checkpoint to evaluate.")
82 | url = None
83 | if model_name == "deit_small" and patch_size == 16:
84 | url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth"
85 | elif model_name == "deit_small" and patch_size == 8:
86 | url = "dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth"
87 | elif model_name == "vit_base" and patch_size == 16:
88 | url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth"
89 | elif model_name == "vit_base" and patch_size == 8:
90 | url = "dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth"
91 | if url is not None:
92 | print("Since no pretrained weights have been provided, we load the reference pretrained DINO weights.")
93 | state_dict = torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/" + url)
94 | model.load_state_dict(state_dict, strict=True)
95 | else:
96 | print("There is no reference weights available for this model => We use random weights.")
97 |
98 |
99 | def clip_gradients(model, clip):
100 | norms = []
101 | for name, p in model.named_parameters():
102 | if p.grad is not None:
103 | param_norm = p.grad.data.norm(2)
104 | norms.append(param_norm.item())
105 | clip_coef = clip / (param_norm + 1e-6)
106 | if clip_coef < 1:
107 | p.grad.data.mul_(clip_coef)
108 | return norms
109 |
110 |
111 | def cancel_gradients_last_layer(epoch, model, freeze_last_layer):
112 | if epoch >= freeze_last_layer:
113 | return
114 | for n, p in model.named_parameters():
115 | if "last_layer" in n:
116 | p.grad = None
117 |
118 |
119 | def restart_from_checkpoint(ckp_path, run_variables=None, **kwargs):
120 | """
121 | Re-start from checkpoint
122 | """
123 | if not os.path.isfile(ckp_path):
124 | return
125 | print("Found checkpoint at {}".format(ckp_path))
126 |
127 | # open checkpoint file
128 | checkpoint = torch.load(ckp_path, map_location="cpu")
129 |
130 | # key is what to look for in the checkpoint file
131 | # value is the object to load
132 | # example: {'state_dict': model}
133 | for key, value in kwargs.items():
134 | if key in checkpoint and value is not None:
135 | try:
136 | msg = value.load_state_dict(checkpoint[key], strict=False)
137 | print("=> loaded {} from checkpoint '{}' with msg {}".format(key, ckp_path, msg))
138 | except TypeError:
139 | try:
140 | msg = value.load_state_dict(checkpoint[key])
141 | print("=> loaded {} from checkpoint '{}'".format(key, ckp_path))
142 | except ValueError:
143 | print("=> failed to load {} from checkpoint '{}'".format(key, ckp_path))
144 | else:
145 | print("=> failed to load {} from checkpoint '{}'".format(key, ckp_path))
146 |
147 |     # reload variables important for the run
148 | if run_variables is not None:
149 | for var_name in run_variables:
150 | if var_name in checkpoint:
151 | run_variables[var_name] = checkpoint[var_name]
152 |
153 |
154 | def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0):
155 | warmup_schedule = np.array([])
156 | warmup_iters = warmup_epochs * niter_per_ep
157 | if warmup_epochs > 0:
158 | warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
159 |
160 | iters = np.arange(epochs * niter_per_ep - warmup_iters)
161 | schedule = np.array([final_value + 0.5 * (base_value - final_value) * (1 + \
162 | math.cos(math.pi * i / (len(iters)))) for i in iters])
163 |
164 | schedule = np.concatenate((warmup_schedule, schedule))
165 | assert len(schedule) == epochs * niter_per_ep
166 | return schedule
167 |
168 |
169 | def bool_flag(s):
170 | """
171 | Parse boolean arguments from the command line.
172 | """
173 | FALSY_STRINGS = {"off", "false", "0"}
174 | TRUTHY_STRINGS = {"on", "true", "1"}
175 | if s.lower() in FALSY_STRINGS:
176 | return False
177 | elif s.lower() in TRUTHY_STRINGS:
178 | return True
179 | else:
180 | raise argparse.ArgumentTypeError("invalid value for a boolean flag")
181 |
182 |
183 | def fix_random_seeds(seed=31):
184 | """
185 | Fix random seeds.
186 | """
187 | torch.manual_seed(seed)
188 | torch.cuda.manual_seed_all(seed)
189 | np.random.seed(seed)
190 |
191 |
192 | class SmoothedValue(object):
193 | """Track a series of values and provide access to smoothed values over a
194 | window or the global series average.
195 | """
196 |
197 | def __init__(self, window_size=20, fmt=None):
198 | if fmt is None:
199 | fmt = "{median:.6f} ({global_avg:.6f})"
200 | self.deque = deque(maxlen=window_size)
201 | self.total = 0.0
202 | self.count = 0
203 | self.fmt = fmt
204 |
205 | def update(self, value, n=1):
206 | self.deque.append(value)
207 | self.count += n
208 | self.total += value * n
209 |
210 | def synchronize_between_processes(self):
211 | """
212 | Warning: does not synchronize the deque!
213 | """
214 | if not is_dist_avail_and_initialized():
215 | return
216 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
217 | dist.barrier()
218 | dist.all_reduce(t)
219 | t = t.tolist()
220 | self.count = int(t[0])
221 | self.total = t[1]
222 |
223 | @property
224 | def median(self):
225 | d = torch.tensor(list(self.deque))
226 | return d.median().item()
227 |
228 | @property
229 | def avg(self):
230 | d = torch.tensor(list(self.deque), dtype=torch.float32)
231 | return d.mean().item()
232 |
233 | @property
234 | def global_avg(self):
235 | return self.total / self.count
236 |
237 | @property
238 | def max(self):
239 | return max(self.deque)
240 |
241 | @property
242 | def value(self):
243 | return self.deque[-1]
244 |
245 | def __str__(self):
246 | return self.fmt.format(
247 | median=self.median,
248 | avg=self.avg,
249 | global_avg=self.global_avg,
250 | max=self.max,
251 | value=self.value)
252 |
253 |
254 | def reduce_dict(input_dict, average=True):
255 | """
256 | Args:
257 | input_dict (dict): all the values will be reduced
258 | average (bool): whether to do average or sum
259 | Reduce the values in the dictionary from all processes so that all processes
260 | have the averaged results. Returns a dict with the same fields as
261 | input_dict, after reduction.
262 | """
263 | world_size = get_world_size()
264 | if world_size < 2:
265 | return input_dict
266 | with torch.no_grad():
267 | names = []
268 | values = []
269 | # sort the keys so that they are consistent across processes
270 | for k in sorted(input_dict.keys()):
271 | names.append(k)
272 | values.append(input_dict[k])
273 | values = torch.stack(values, dim=0)
274 | dist.all_reduce(values)
275 | if average:
276 | values /= world_size
277 | reduced_dict = {k: v for k, v in zip(names, values)}
278 | return reduced_dict
279 |
280 |
281 | class MetricLogger(object):
282 | def __init__(self, delimiter="\t"):
283 | self.meters = defaultdict(SmoothedValue)
284 | self.delimiter = delimiter
285 |
286 | def update(self, **kwargs):
287 | for k, v in kwargs.items():
288 | if isinstance(v, torch.Tensor):
289 | v = v.item()
290 | assert isinstance(v, (float, int))
291 | self.meters[k].update(v)
292 |
293 | def __getattr__(self, attr):
294 | if attr in self.meters:
295 | return self.meters[attr]
296 | if attr in self.__dict__:
297 | return self.__dict__[attr]
298 | raise AttributeError("'{}' object has no attribute '{}'".format(
299 | type(self).__name__, attr))
300 |
301 | def __str__(self):
302 | loss_str = []
303 | for name, meter in self.meters.items():
304 | loss_str.append(
305 | "{}: {}".format(name, str(meter))
306 | )
307 | return self.delimiter.join(loss_str)
308 |
309 | def synchronize_between_processes(self):
310 | for meter in self.meters.values():
311 | meter.synchronize_between_processes()
312 |
313 | def add_meter(self, name, meter):
314 | self.meters[name] = meter
315 |
316 | def log_every(self, iterable, print_freq, header=None):
317 | i = 0
318 | if not header:
319 | header = ''
320 | start_time = time.time()
321 | end = time.time()
322 | iter_time = SmoothedValue(fmt='{avg:.6f}')
323 | data_time = SmoothedValue(fmt='{avg:.6f}')
324 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
325 | if torch.cuda.is_available():
326 | log_msg = self.delimiter.join([
327 | header,
328 | '[{0' + space_fmt + '}/{1}]',
329 | 'eta: {eta}',
330 | '{meters}',
331 | 'time: {time}',
332 | 'data: {data}',
333 | 'max mem: {memory:.0f}'
334 | ])
335 | else:
336 | log_msg = self.delimiter.join([
337 | header,
338 | '[{0' + space_fmt + '}/{1}]',
339 | 'eta: {eta}',
340 | '{meters}',
341 | 'time: {time}',
342 | 'data: {data}'
343 | ])
344 | MB = 1024.0 * 1024.0
345 | for obj in iterable:
346 | data_time.update(time.time() - end)
347 | yield obj
348 | iter_time.update(time.time() - end)
349 | if i % print_freq == 0 or i == len(iterable) - 1:
350 | eta_seconds = iter_time.global_avg * (len(iterable) - i)
351 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
352 | if torch.cuda.is_available():
353 | print(log_msg.format(
354 | i, len(iterable), eta=eta_string,
355 | meters=str(self),
356 | time=str(iter_time), data=str(data_time),
357 | memory=torch.cuda.max_memory_allocated() / MB))
358 | else:
359 | print(log_msg.format(
360 | i, len(iterable), eta=eta_string,
361 | meters=str(self),
362 | time=str(iter_time), data=str(data_time)))
363 | i += 1
364 | end = time.time()
365 | total_time = time.time() - start_time
366 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
367 | print('{} Total time: {} ({:.6f} s / it)'.format(
368 | header, total_time_str, total_time / len(iterable)))
369 |
370 |
371 | def get_sha():
372 | cwd = os.path.dirname(os.path.abspath(__file__))
373 |
374 | def _run(command):
375 | return subprocess.check_output(command, cwd=cwd).decode('ascii').strip()
376 | sha = 'N/A'
377 | diff = "clean"
378 | branch = 'N/A'
379 | try:
380 | sha = _run(['git', 'rev-parse', 'HEAD'])
381 | subprocess.check_output(['git', 'diff'], cwd=cwd)
382 | diff = _run(['git', 'diff-index', 'HEAD'])
383 |         diff = "has uncommitted changes" if diff else "clean"
384 | branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
385 | except Exception:
386 | pass
387 | message = f"sha: {sha}, status: {diff}, branch: {branch}"
388 | return message
389 |
390 |
391 | def is_dist_avail_and_initialized():
392 | if not dist.is_available():
393 | return False
394 | if not dist.is_initialized():
395 | return False
396 | return True
397 |
398 |
399 | def get_world_size():
400 | if not is_dist_avail_and_initialized():
401 | return 1
402 | return dist.get_world_size()
403 |
404 |
405 | def get_rank():
406 | if not is_dist_avail_and_initialized():
407 | return 0
408 | return dist.get_rank()
409 |
410 |
411 | def is_main_process():
412 | return get_rank() == 0
413 |
414 |
415 | def save_on_master(*args, **kwargs):
416 | if is_main_process():
417 | torch.save(*args, **kwargs)
418 |
419 |
420 | def setup_for_distributed(is_master):
421 | """
422 | This function disables printing when not in master process
423 | """
424 | import builtins as __builtin__
425 | builtin_print = __builtin__.print
426 |
427 | def print(*args, **kwargs):
428 | force = kwargs.pop('force', False)
429 | if is_master or force:
430 | builtin_print(*args, **kwargs)
431 |
432 | __builtin__.print = print
433 |
434 |
435 | def init_distributed_mode(args):
436 | # launched with torch.distributed.launch
437 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
438 | args.rank = int(os.environ["RANK"])
439 | args.world_size = int(os.environ['WORLD_SIZE'])
440 | args.gpu = int(os.environ['LOCAL_RANK'])
441 | # launched with submitit on a slurm cluster
442 | elif 'SLURM_PROCID' in os.environ:
443 | args.rank = int(os.environ['SLURM_PROCID'])
444 | args.gpu = args.rank % torch.cuda.device_count()
445 | # launched naively with `python main_dino.py`
446 | # we manually add MASTER_ADDR and MASTER_PORT to env variables
447 | elif torch.cuda.is_available():
448 | print('Will run the code on one GPU.')
449 | args.rank, args.gpu, args.world_size = 0, 0, 1
450 | os.environ['MASTER_ADDR'] = '127.0.0.1'
451 | os.environ['MASTER_PORT'] = '29500'
452 | else:
453 | print('Does not support training without GPU.')
454 | sys.exit(1)
455 |
456 | dist.init_process_group(
457 | backend="nccl",
458 | init_method=args.dist_url,
459 | world_size=args.world_size,
460 | rank=args.rank,
461 | )
462 |
463 | torch.cuda.set_device(args.gpu)
464 | print('| distributed init (rank {}): {}'.format(
465 | args.rank, args.dist_url), flush=True)
466 | dist.barrier()
467 | setup_for_distributed(args.rank == 0)
468 |
469 |
470 | def accuracy(output, target, topk=(1,)):
471 | """Computes the accuracy over the k top predictions for the specified values of k"""
472 | maxk = max(topk)
473 | batch_size = target.size(0)
474 | _, pred = output.topk(maxk, 1, True, True)
475 | pred = pred.t()
476 | correct = pred.eq(target.reshape(1, -1).expand_as(pred))
477 | return [correct[:k].reshape(-1).float().sum(0) * 100. / batch_size for k in topk]
478 |
479 |
480 | def _no_grad_trunc_normal_(tensor, mean, std, a, b):
481 | # Cut & paste from PyTorch official master until it's in a few official releases - RW
482 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
483 | def norm_cdf(x):
484 | # Computes standard normal cumulative distribution function
485 | return (1. + math.erf(x / math.sqrt(2.))) / 2.
486 |
487 | if (mean < a - 2 * std) or (mean > b + 2 * std):
488 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
489 | "The distribution of values may be incorrect.",
490 | stacklevel=2)
491 |
492 | with torch.no_grad():
493 | # Values are generated by using a truncated uniform distribution and
494 | # then using the inverse CDF for the normal distribution.
495 | # Get upper and lower cdf values
496 | l = norm_cdf((a - mean) / std)
497 | u = norm_cdf((b - mean) / std)
498 |
499 | # Uniformly fill tensor with values from [l, u], then translate to
500 | # [2l-1, 2u-1].
501 | tensor.uniform_(2 * l - 1, 2 * u - 1)
502 |
503 | # Use inverse cdf transform for normal distribution to get truncated
504 | # standard normal
505 | tensor.erfinv_()
506 |
507 | # Transform to proper mean, std
508 | tensor.mul_(std * math.sqrt(2.))
509 | tensor.add_(mean)
510 |
511 | # Clamp to ensure it's in the proper range
512 | tensor.clamp_(min=a, max=b)
513 | return tensor
514 |
515 |
516 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
517 | # type: (Tensor, float, float, float, float) -> Tensor
518 | return _no_grad_trunc_normal_(tensor, mean, std, a, b)
519 |
520 |
521 | class LARS(torch.optim.Optimizer):
522 | """
523 | Almost copy-paste from https://github.com/facebookresearch/barlowtwins/blob/main/main.py
524 | """
525 | def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, eta=0.001,
526 | weight_decay_filter=None, lars_adaptation_filter=None):
527 | defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum,
528 | eta=eta, weight_decay_filter=weight_decay_filter,
529 | lars_adaptation_filter=lars_adaptation_filter)
530 | super().__init__(params, defaults)
531 |
532 | @torch.no_grad()
533 | def step(self):
534 | for g in self.param_groups:
535 | for p in g['params']:
536 | dp = p.grad
537 |
538 | if dp is None:
539 | continue
540 |
541 | if p.ndim != 1:
542 | dp = dp.add(p, alpha=g['weight_decay'])
543 |
544 | if p.ndim != 1:
545 | param_norm = torch.norm(p)
546 | update_norm = torch.norm(dp)
547 | one = torch.ones_like(param_norm)
548 | q = torch.where(param_norm > 0.,
549 | torch.where(update_norm > 0,
550 | (g['eta'] * param_norm / update_norm), one), one)
551 | dp = dp.mul(q)
552 |
553 | param_state = self.state[p]
554 | if 'mu' not in param_state:
555 | param_state['mu'] = torch.zeros_like(p)
556 | mu = param_state['mu']
557 | mu.mul_(g['momentum']).add_(dp)
558 |
559 | p.add_(mu, alpha=-g['lr'])
560 |
561 |
562 | class MultiCropWrapper(nn.Module):
563 | """
564 | Perform forward pass separately on each resolution input.
565 |     The inputs corresponding to a single resolution are grouped together and a
566 |     single forward pass is run on them. Hence we do as many forward passes as
567 |     there are different resolutions used. We then
568 |     concatenate all the output features.
569 | """
570 | def __init__(self, backbone, head):
571 | super(MultiCropWrapper, self).__init__()
572 | backbone.fc = nn.Identity()
573 | self.backbone = backbone
574 | self.head = head
575 |
576 | def forward(self, x):
577 | # convert to list
578 | if not isinstance(x, list):
579 | x = [x]
580 | idx_crops = torch.cumsum(torch.unique_consecutive(
581 | torch.tensor([inp.shape[-1] for inp in x]),
582 | return_counts=True,
583 | )[1], 0)
584 | start_idx = 0
585 | for end_idx in idx_crops:
586 | _out = self.backbone(torch.cat(x[start_idx: end_idx]))
587 | if start_idx == 0:
588 | output = _out
589 | else:
590 | output = torch.cat((output, _out))
591 | start_idx = end_idx
592 | # Run the head forward on the concatenated features.
593 | return self.head(output)
594 |
595 |
596 | def get_params_groups(model):
597 | regularized = []
598 | not_regularized = []
599 | for name, param in model.named_parameters():
600 | if not param.requires_grad:
601 | continue
602 | # we do not regularize biases nor Norm parameters
603 | if name.endswith(".bias") or len(param.shape) == 1:
604 | not_regularized.append(param)
605 | else:
606 | regularized.append(param)
607 | return [{'params': regularized}, {'params': not_regularized, 'weight_decay': 0.}]
608 |
609 |
610 | def has_batchnorms(model):
611 | bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
612 | for name, module in model.named_modules():
613 | if isinstance(module, bn_types):
614 | return True
615 | return False
616 |
--------------------------------------------------------------------------------
/dino_model/video_generation.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import os
15 | import glob
16 | import sys
17 | import argparse
18 | import cv2
19 |
20 | from tqdm import tqdm
21 | import matplotlib.pyplot as plt
22 | import torch
23 | import torch.nn as nn
24 | import torchvision
25 | from torchvision import transforms as pth_transforms
26 | import numpy as np
27 | from PIL import Image
28 |
29 | import utils
30 | import vision_transformer as vits
31 |
32 |
33 | FOURCC = {
34 | "mp4": cv2.VideoWriter_fourcc(*"MP4V"),
35 | "avi": cv2.VideoWriter_fourcc(*"XVID"),
36 | }
37 | DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
38 |
39 |
40 | class VideoGenerator:
41 | def __init__(self, args):
42 | self.args = args
43 | # self.model = None
44 | # Don't need to load model if you only want a video
45 | if not self.args.video_only:
46 | self.model = self.__load_model()
47 |
48 | def run(self):
49 | if self.args.input_path is None:
50 | print(f"Provided input path {self.args.input_path} is non valid.")
51 | sys.exit(1)
52 | else:
53 | if self.args.video_only:
54 | self._generate_video_from_images(
55 | self.args.input_path, self.args.output_path
56 | )
57 | else:
58 | # If input path exists
59 | if os.path.exists(self.args.input_path):
60 | # If input is a video file
61 | if os.path.isfile(self.args.input_path):
62 | frames_folder = os.path.join(self.args.output_path, "frames")
63 | attention_folder = os.path.join(
64 | self.args.output_path, "attention"
65 | )
66 |
67 | os.makedirs(frames_folder, exist_ok=True)
68 | os.makedirs(attention_folder, exist_ok=True)
69 |
70 | self._extract_frames_from_video(
71 | self.args.input_path, frames_folder
72 | )
73 |
74 | self._inference(
75 | frames_folder,
76 | attention_folder,
77 | )
78 |
79 | self._generate_video_from_images(
80 | attention_folder, self.args.output_path
81 | )
82 |
83 | # If input is a folder of already extracted frames
84 | if os.path.isdir(self.args.input_path):
85 | attention_folder = os.path.join(
86 | self.args.output_path, "attention"
87 | )
88 |
89 | os.makedirs(attention_folder, exist_ok=True)
90 |
91 | self._inference(self.args.input_path, attention_folder)
92 |
93 | self._generate_video_from_images(
94 | attention_folder, self.args.output_path
95 | )
96 |
97 |             # If input path doesn't exist
98 | else:
99 | print(f"Provided input path {self.args.input_path} doesn't exists.")
100 | sys.exit(1)
101 |
102 | def _extract_frames_from_video(self, inp: str, out: str):
103 | vidcap = cv2.VideoCapture(inp)
104 | self.args.fps = vidcap.get(cv2.CAP_PROP_FPS)
105 |
106 | print(f"Video: {inp} ({self.args.fps} fps)")
107 | print(f"Extracting frames to {out}")
108 |
109 | success, image = vidcap.read()
110 | count = 0
111 | while success:
112 | cv2.imwrite(
113 | os.path.join(out, f"frame-{count:04}.jpg"),
114 | image,
115 | )
116 | success, image = vidcap.read()
117 | count += 1
118 |
119 | def _generate_video_from_images(self, inp: str, out: str):
120 | img_array = []
121 | attention_images_list = sorted(glob.glob(os.path.join(inp, "attn-*.jpg")))
122 |
123 | # Get size of the first image
124 | with open(attention_images_list[0], "rb") as f:
125 | img = Image.open(f)
126 | img = img.convert("RGB")
127 | size = (img.width, img.height)
128 | img_array.append(cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR))
129 |
130 | print(f"Generating video {size} to {out}")
131 |
132 | for filename in tqdm(attention_images_list[1:]):
133 | with open(filename, "rb") as f:
134 | img = Image.open(f)
135 | img = img.convert("RGB")
136 | img_array.append(cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR))
137 |
138 | out = cv2.VideoWriter(
139 | os.path.join(out, "video." + self.args.video_format),
140 | FOURCC[self.args.video_format],
141 | self.args.fps,
142 | size,
143 | )
144 |
145 | for i in range(len(img_array)):
146 | out.write(img_array[i])
147 | out.release()
148 | print("Done")
149 |
150 | def _inference(self, inp: str, out: str):
151 | print(f"Generating attention images to {out}")
152 |
153 | for img_path in tqdm(sorted(glob.glob(os.path.join(inp, "*.jpg")))):
154 | with open(img_path, "rb") as f:
155 | img = Image.open(f)
156 | img = img.convert("RGB")
157 |
158 | if self.args.resize is not None:
159 | transform = pth_transforms.Compose(
160 | [
161 | pth_transforms.ToTensor(),
162 | pth_transforms.Resize(self.args.resize),
163 | pth_transforms.Normalize(
164 | (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
165 | ),
166 | ]
167 | )
168 | else:
169 | transform = pth_transforms.Compose(
170 | [
171 | pth_transforms.ToTensor(),
172 | pth_transforms.Normalize(
173 | (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
174 | ),
175 | ]
176 | )
177 |
178 | img = transform(img)
179 |
180 | # make the image divisible by the patch size
181 | w, h = (
182 | img.shape[1] - img.shape[1] % self.args.patch_size,
183 | img.shape[2] - img.shape[2] % self.args.patch_size,
184 | )
185 | img = img[:, :w, :h].unsqueeze(0)
186 |
187 | w_featmap = img.shape[-2] // self.args.patch_size
188 | h_featmap = img.shape[-1] // self.args.patch_size
189 |
190 | attentions = self.model.forward_selfattention(img.to(DEVICE))
191 |
192 |             nh = attentions.shape[1]  # number of heads
193 |
194 | # we keep only the output patch attention
195 | attentions = attentions[0, :, 0, 1:].reshape(nh, -1)
196 |
197 | # we keep only a certain percentage of the mass
198 | val, idx = torch.sort(attentions)
199 | val /= torch.sum(val, dim=1, keepdim=True)
200 | cumval = torch.cumsum(val, dim=1)
201 | th_attn = cumval > (1 - self.args.threshold)
202 | idx2 = torch.argsort(idx)
203 | for head in range(nh):
204 | th_attn[head] = th_attn[head][idx2[head]]
205 | th_attn = th_attn.reshape(nh, w_featmap, h_featmap).float()
206 | # interpolate
207 | th_attn = (
208 | nn.functional.interpolate(
209 | th_attn.unsqueeze(0),
210 | scale_factor=self.args.patch_size,
211 | mode="nearest",
212 | )[0]
213 | .cpu()
214 | .numpy()
215 | )
216 |
217 | attentions = attentions.reshape(nh, w_featmap, h_featmap)
218 | attentions = (
219 | nn.functional.interpolate(
220 | attentions.unsqueeze(0),
221 | scale_factor=self.args.patch_size,
222 | mode="nearest",
223 | )[0]
224 | .cpu()
225 | .numpy()
226 | )
227 |
228 | # save attentions heatmaps
229 | fname = os.path.join(out, "attn-" + os.path.basename(img_path))
230 | plt.imsave(
231 | fname=fname,
232 | arr=sum(
233 | attentions[i] * 1 / attentions.shape[0]
234 | for i in range(attentions.shape[0])
235 | ),
236 | cmap="inferno",
237 | format="jpg",
238 | )
239 |
240 | def __load_model(self):
241 | # build model
242 | model = vits.__dict__[self.args.arch](
243 | patch_size=self.args.patch_size, num_classes=0
244 | )
245 | for p in model.parameters():
246 | p.requires_grad = False
247 | model.eval()
248 | model.to(DEVICE)
249 |
250 | if os.path.isfile(self.args.pretrained_weights):
251 | state_dict = torch.load(self.args.pretrained_weights, map_location="cpu")
252 | if (
253 | self.args.checkpoint_key is not None
254 | and self.args.checkpoint_key in state_dict
255 | ):
256 | print(
257 | f"Take key {self.args.checkpoint_key} in provided checkpoint dict"
258 | )
259 | state_dict = state_dict[self.args.checkpoint_key]
260 | state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
261 | msg = model.load_state_dict(state_dict, strict=False)
262 | print(
263 | "Pretrained weights found at {} and loaded with msg: {}".format(
264 | self.args.pretrained_weights, msg
265 | )
266 | )
267 | else:
268 | print(
269 | "Please use the `--pretrained_weights` argument to indicate the path of the checkpoint to evaluate."
270 | )
271 | url = None
272 | if self.args.arch == "deit_small" and self.args.patch_size == 16:
273 | url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth"
274 | elif self.args.arch == "deit_small" and self.args.patch_size == 8:
275 | url = "dino_deitsmall8_300ep_pretrain/dino_deitsmall8_300ep_pretrain.pth" # model used for visualizations in our paper
276 | elif self.args.arch == "vit_base" and self.args.patch_size == 16:
277 | url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth"
278 | elif self.args.arch == "vit_base" and self.args.patch_size == 8:
279 | url = "dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth"
280 | if url is not None:
281 | print(
282 | "Since no pretrained weights have been provided, we load the reference pretrained DINO weights."
283 | )
284 | state_dict = torch.hub.load_state_dict_from_url(
285 | url="https://dl.fbaipublicfiles.com/dino/" + url
286 | )
287 | model.load_state_dict(state_dict, strict=True)
288 | else:
289 | print(
290 | "There is no reference weights available for this model => We use random weights."
291 | )
292 | return model
293 |
294 |
295 | def parse_args():
296 | parser = argparse.ArgumentParser("Generation self-attention video")
297 | parser.add_argument(
298 | "--arch",
299 | default="deit_small",
300 | type=str,
301 | choices=["deit_tiny", "deit_small", "vit_base"],
302 | help="Architecture (support only ViT atm).",
303 | )
304 | parser.add_argument(
305 | "--patch_size", default=8, type=int, help="Patch resolution of the self.model."
306 | )
307 | parser.add_argument(
308 | "--pretrained_weights",
309 | default="",
310 | type=str,
311 | help="Path to pretrained weights to load.",
312 | )
313 | parser.add_argument(
314 | "--checkpoint_key",
315 | default="teacher",
316 | type=str,
317 | help='Key to use in the checkpoint (example: "teacher")',
318 | )
319 | parser.add_argument(
320 | "--input_path",
321 | required=True,
322 | type=str,
323 | help="""Path to a video file if you want to extract frames
324 | or to a folder of images already extracted by yourself.
325 | or to a folder of attention images.""",
326 | )
327 | parser.add_argument(
328 | "--output_path",
329 | default="./",
330 | type=str,
331 | help="""Path to store a folder of frames and / or a folder of attention images.
332 | and / or a final video. Default to current directory.""",
333 | )
334 | parser.add_argument(
335 | "--threshold",
336 | type=float,
337 | default=0.6,
338 | help="""We visualize masks
339 | obtained by thresholding the self-attention maps to keep xx percent of the mass.""",
340 | )
341 | parser.add_argument(
342 | "--resize",
343 | default=None,
344 | type=int,
345 | nargs="+",
346 | help="""Apply a resize transformation to input image(s). Use if OOM error.
347 | Usage (single or W H): --resize 512, --resize 720 1280""",
348 | )
349 | parser.add_argument(
350 | "--video_only",
351 | action="store_true",
352 | help="""Use this flag if you only want to generate a video and not all attention images.
353 | If used, --input_path must be set to the folder of attention images. Ex: ./attention/""",
354 | )
355 | parser.add_argument(
356 | "--fps",
357 | default=30.0,
358 | type=float,
359 | help="FPS of input / output video. Automatically set if you extract frames from a video.",
360 | )
361 | parser.add_argument(
362 | "--video_format",
363 | default="mp4",
364 | type=str,
365 | choices=["mp4", "avi"],
366 | help="Format of generated video (mp4 or avi).",
367 | )
368 |
369 | return parser.parse_args()
370 |
371 |
372 | if __name__ == "__main__":
373 | args = parse_args()
374 |
375 | vg = VideoGenerator(args)
376 | vg.run()
377 |
--------------------------------------------------------------------------------
/dino_model/vision_transformer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Mostly copy-paste from timm library.
16 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
17 | """
18 | import math
19 | from functools import partial
20 |
21 | import torch
22 | import torch.nn as nn
23 |
24 | from dino_model.utils import trunc_normal_
25 |
26 |
27 | def drop_path(x, drop_prob: float = 0., training: bool = False):
28 | if drop_prob == 0. or not training:
29 | return x
30 | keep_prob = 1 - drop_prob
31 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
32 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
33 | random_tensor.floor_() # binarize
34 | output = x.div(keep_prob) * random_tensor
35 | return output
36 |
37 |
38 | class DropPath(nn.Module):
39 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
40 | """
41 | def __init__(self, drop_prob=None):
42 | super(DropPath, self).__init__()
43 | self.drop_prob = drop_prob
44 |
45 | def forward(self, x):
46 | return drop_path(x, self.drop_prob, self.training)
47 |
48 |
49 | class Mlp(nn.Module):
50 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
51 | super().__init__()
52 | out_features = out_features or in_features
53 | hidden_features = hidden_features or in_features
54 | self.fc1 = nn.Linear(in_features, hidden_features)
55 | self.act = act_layer()
56 | self.fc2 = nn.Linear(hidden_features, out_features)
57 | self.drop = nn.Dropout(drop)
58 |
59 | def forward(self, x):
60 | x = self.fc1(x)
61 | x = self.act(x)
62 | x = self.drop(x)
63 | x = self.fc2(x)
64 | x = self.drop(x)
65 | return x
66 |
67 |
68 | class Attention(nn.Module):
69 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
70 | super().__init__()
71 | self.num_heads = num_heads
72 | head_dim = dim // num_heads
73 | self.scale = qk_scale or head_dim ** -0.5
74 |
75 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
76 | self.attn_drop = nn.Dropout(attn_drop)
77 | self.proj = nn.Linear(dim, dim)
78 | self.proj_drop = nn.Dropout(proj_drop)
79 |
80 | def forward(self, x):
81 | B, N, C = x.shape
82 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
83 | q, k, v = qkv[0], qkv[1], qkv[2]
84 |
85 | attn = (q @ k.transpose(-2, -1)) * self.scale
86 | attn = attn.softmax(dim=-1)
87 | attn = self.attn_drop(attn)
88 |
89 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
90 | x = self.proj(x)
91 | x = self.proj_drop(x)
92 | return x, attn
93 |
94 |
95 | class Block(nn.Module):
96 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
97 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
98 | super().__init__()
99 | self.norm1 = norm_layer(dim)
100 | self.attn = Attention(
101 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
102 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
103 | self.norm2 = norm_layer(dim)
104 | mlp_hidden_dim = int(dim * mlp_ratio)
105 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
106 |
107 | def forward(self, x, return_attention=False):
108 | y, attn = self.attn(self.norm1(x))
109 | if return_attention:
110 | return attn
111 | x = x + self.drop_path(y)
112 | x = x + self.drop_path(self.mlp(self.norm2(x)))
113 | return x
114 |
115 |
116 | class PatchEmbed(nn.Module):
117 | """ Image to Patch Embedding
118 | """
119 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
120 | super().__init__()
121 | num_patches = (img_size // patch_size) * (img_size // patch_size)
122 | self.img_size = img_size
123 | self.patch_size = patch_size
124 | self.num_patches = num_patches
125 |
126 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
127 |
128 | def forward(self, x):
129 | B, C, H, W = x.shape
130 | x = self.proj(x).flatten(2).transpose(1, 2)
131 | return x
132 |
133 |
134 | class VisionTransformer(nn.Module):
135 | """ Vision Transformer """
136 | def __init__(self, img_size=[224], patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12,
137 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
138 | drop_path_rate=0., norm_layer=nn.LayerNorm, **kwargs):
139 | super().__init__()
140 | self.num_features = self.embed_dim = embed_dim
141 |
142 | self.patch_embed = PatchEmbed(
143 | img_size=img_size[0], patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
144 | num_patches = self.patch_embed.num_patches
145 |
146 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
147 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
148 | self.pos_drop = nn.Dropout(p=drop_rate)
149 |
150 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
151 | self.blocks = nn.ModuleList([
152 | Block(
153 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
154 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
155 | for i in range(depth)])
156 | self.norm = norm_layer(embed_dim)
157 |
158 | # Classifier head
159 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
160 |
161 | trunc_normal_(self.pos_embed, std=.02)
162 | trunc_normal_(self.cls_token, std=.02)
163 | self.apply(self._init_weights)
164 |
165 | def _init_weights(self, m):
166 | if isinstance(m, nn.Linear):
167 | trunc_normal_(m.weight, std=.02)
168 | if isinstance(m, nn.Linear) and m.bias is not None:
169 | nn.init.constant_(m.bias, 0)
170 | elif isinstance(m, nn.LayerNorm):
171 | nn.init.constant_(m.bias, 0)
172 | nn.init.constant_(m.weight, 1.0)
173 |
174 | def forward(self, x):
175 | # convert to list
176 | if not isinstance(x, list):
177 | x = [x]
178 | # Perform forward pass separately on each resolution input.
179 |         # The inputs corresponding to a single resolution are grouped together and a
180 |         # single forward pass is run on them. Hence we do as many forward passes as
181 |         # there are different resolutions used. We then
182 |         # concatenate all the output features.
183 | idx_crops = torch.cumsum(torch.unique_consecutive(
184 | torch.tensor([inp.shape[-1] for inp in x]),
185 | return_counts=True,
186 | )[1], 0)
187 | start_idx = 0
188 | for end_idx in idx_crops:
189 | _out = self.forward_features(torch.cat(x[start_idx: end_idx]))
190 | if start_idx == 0:
191 | output = _out
192 | else:
193 | output = torch.cat((output, _out))
194 | start_idx = end_idx
195 | # Run the head forward on the concatenated features.
196 | return self.head(output)
197 |
198 | def forward_features(self, x):
199 | B = x.shape[0]
200 | x = self.patch_embed(x)
201 |
202 | cls_tokens = self.cls_token.expand(B, -1, -1)
203 | x = torch.cat((cls_tokens, x), dim=1)
204 | pos_embed = self.interpolate_pos_encoding(x, self.pos_embed)
205 | x = x + pos_embed
206 | x = self.pos_drop(x)
207 |
208 | for blk in self.blocks:
209 | x = blk(x)
210 | if self.norm is not None:
211 | x = self.norm(x)
212 |
213 | return x[:, 0]
214 |
215 | def interpolate_pos_encoding(self, x, pos_embed):
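        # Editor's note (not in the original source): when the input yields a different
        # number of patches than the model was trained with, the patch position embeddings
        # are reshaped to their original sqrt(N) x sqrt(N) grid, bicubically rescaled to the
        # new patch count, and the class-token embedding is re-prepended unchanged.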
216 | npatch = x.shape[1] - 1
217 | N = pos_embed.shape[1] - 1
218 | if npatch == N:
219 | return pos_embed
220 | class_emb = pos_embed[:, 0]
221 | pos_embed = pos_embed[:, 1:]
222 | dim = x.shape[-1]
223 | pos_embed = nn.functional.interpolate(
224 | pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
225 | scale_factor=math.sqrt(npatch / N),
226 | mode='bicubic',
227 | )
228 | pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
229 | return torch.cat((class_emb.unsqueeze(0), pos_embed), dim=1)
230 |
231 | def forward_selfattention(self, x):
232 | B, nc, w, h = x.shape
233 | N = self.pos_embed.shape[1] - 1
234 | x = self.patch_embed(x)
235 |
236 | # interpolate patch embeddings
237 | dim = x.shape[-1]
238 | w0 = w // self.patch_embed.patch_size
239 | h0 = h // self.patch_embed.patch_size
240 | class_pos_embed = self.pos_embed[:, 0]
241 | patch_pos_embed = self.pos_embed[:, 1:]
242 | patch_pos_embed = nn.functional.interpolate(
243 | patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
244 | scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
245 | mode='bicubic',
246 | )
247 | # sometimes there is a floating point error in the interpolation and so
248 | # we need to pad the patch positional encoding.
249 | if w0 != patch_pos_embed.shape[-2]:
250 | helper = torch.zeros(h0)[None, None, None, :].repeat(1, dim, w0 - patch_pos_embed.shape[-2], 1).to(x.device)
251 | patch_pos_embed = torch.cat((patch_pos_embed, helper), dim=-2)
252 | if h0 != patch_pos_embed.shape[-1]:
253 | helper = torch.zeros(w0)[None, None, :, None].repeat(1, dim, 1, h0 - patch_pos_embed.shape[-1]).to(x.device)
254 | patch_pos_embed = torch.cat((patch_pos_embed, helper), dim=-1)
255 |
256 | patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
257 | pos_embed = torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
258 | cls_tokens = self.cls_token.expand(B, -1, -1)
259 | x = torch.cat((cls_tokens, x), dim=1)
260 | x = x + pos_embed
261 | x = self.pos_drop(x)
262 |
263 | for i, blk in enumerate(self.blocks):
264 | if i < len(self.blocks) - 1:
265 | x = blk(x)
266 | else:
267 | return blk(x, return_attention=True)
268 |
269 | def forward_return_n_last_blocks(self, x, n=1, return_patch_avgpool=False):
270 | B = x.shape[0]
271 | x = self.patch_embed(x)
272 |
273 | cls_tokens = self.cls_token.expand(B, -1, -1)
274 | x = torch.cat((cls_tokens, x), dim=1)
275 | pos_embed = self.interpolate_pos_encoding(x, self.pos_embed)
276 | x = x + pos_embed
277 | x = self.pos_drop(x)
278 |
279 | # we will return the [CLS] tokens from the `n` last blocks
280 | output = []
281 | for i, blk in enumerate(self.blocks):
282 | x = blk(x)
283 | if len(self.blocks) - i <= n:
284 | output.append(self.norm(x)[:, 0])
285 | if return_patch_avgpool:
286 | x = self.norm(x)
287 | # In addition to the [CLS] tokens from the `n` last blocks, we also return
288 | # the patch tokens from the last block. This is useful for linear eval.
289 | output.append(torch.mean(x[:, 1:], dim=1))
290 | return torch.cat(output, dim=-1)
291 |
292 |
293 | def deit_tiny(patch_size=16, **kwargs):
294 | model = VisionTransformer(
295 | patch_size=patch_size, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4,
296 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
297 | return model
298 |
299 |
300 | def deit_small(patch_size=16, **kwargs):
301 | model = VisionTransformer(
302 | patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4,
303 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
304 | return model
305 |
306 |
307 | def vit_base(patch_size=16, **kwargs):
308 | model = VisionTransformer(
309 | patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
310 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
311 | return model
312 |
313 |
314 | class DINOHead(nn.Module):
315 | def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048, bottleneck_dim=256):
316 | super().__init__()
317 | nlayers = max(nlayers, 1)
318 | if nlayers == 1:
319 | self.mlp = nn.Linear(in_dim, bottleneck_dim)
320 | else:
321 | layers = [nn.Linear(in_dim, hidden_dim)]
322 | if use_bn:
323 | layers.append(nn.BatchNorm1d(hidden_dim))
324 | layers.append(nn.GELU())
325 | for _ in range(nlayers - 2):
326 | layers.append(nn.Linear(hidden_dim, hidden_dim))
327 | if use_bn:
328 | layers.append(nn.BatchNorm1d(hidden_dim))
329 | layers.append(nn.GELU())
330 | layers.append(nn.Linear(hidden_dim, bottleneck_dim))
331 | self.mlp = nn.Sequential(*layers)
332 | self.apply(self._init_weights)
333 | self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
334 | self.last_layer.weight_g.data.fill_(1)
335 | if norm_last_layer:
336 | self.last_layer.weight_g.requires_grad = False
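            # Editor's note (not in the original source): weight_norm factors the last
            # layer into a direction (weight_v) and a per-row magnitude (weight_g); with
            # norm_last_layer=True the magnitude stays frozen at 1, so only the direction
            # of the output prototypes is learned.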
337 |
338 | def _init_weights(self, m):
339 | if isinstance(m, nn.Linear):
340 | trunc_normal_(m.weight, std=.02)
341 | if isinstance(m, nn.Linear) and m.bias is not None:
342 | nn.init.constant_(m.bias, 0)
343 |
344 | def forward(self, x):
345 | x = self.mlp(x)
346 | x = nn.functional.normalize(x, dim=-1, p=2)
347 | x = self.last_layer(x)
348 | return x
349 |
--------------------------------------------------------------------------------
/dino_model/visualize_attention.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import os
15 | import sys
16 | import argparse
17 | import cv2
18 | import random
19 | import colorsys
20 | import requests
21 | from io import BytesIO
22 |
23 | import skimage.io
24 | from skimage.measure import find_contours
25 | import matplotlib.pyplot as plt
26 | from matplotlib.patches import Polygon
27 | import torch
28 | import torch.nn as nn
29 | import torchvision
30 | from torchvision import transforms as pth_transforms
31 | import numpy as np
32 | from PIL import Image
33 |
34 | import utils
35 | import vision_transformer as vits
36 |
37 |
38 | def apply_mask(image, mask, color, alpha=0.5):
39 | for c in range(3):
40 | image[:, :, c] = image[:, :, c] * (1 - alpha * mask) + alpha * mask * color[c] * 255
41 | return image
42 |
43 |
44 | def random_colors(N, bright=True):
45 | """
46 | Generate random colors.
47 | """
48 | brightness = 1.0 if bright else 0.7
49 | hsv = [(i / N, 1, brightness) for i in range(N)]
50 | colors = list(map(lambda c: colorsys.hsv_to_rgb(*c), hsv))
51 | random.shuffle(colors)
52 | return colors
53 |
54 |
55 | def display_instances(image, mask, fname="test", figsize=(5, 5), blur=False, contour=True, alpha=0.5):
56 | fig = plt.figure(figsize=figsize, frameon=False)
57 | ax = plt.Axes(fig, [0., 0., 1., 1.])
58 | ax.set_axis_off()
59 | fig.add_axes(ax)
60 | ax = plt.gca()
61 |
62 | N = 1
63 | mask = mask[None, :, :]
64 | # Generate random colors
65 | colors = random_colors(N)
66 |
67 | # Show area outside image boundaries.
68 | height, width = image.shape[:2]
69 | margin = 0
70 | ax.set_ylim(height + margin, -margin)
71 | ax.set_xlim(-margin, width + margin)
72 | ax.axis('off')
73 | masked_image = image.astype(np.uint32).copy()
74 | for i in range(N):
75 | color = colors[i]
76 | _mask = mask[i]
77 | if blur:
78 |             _mask = cv2.blur(_mask, (10, 10))
79 | # Mask
80 | masked_image = apply_mask(masked_image, _mask, color, alpha)
81 | # Mask Polygon
82 | # Pad to ensure proper polygons for masks that touch image edges.
83 | if contour:
84 | padded_mask = np.zeros((_mask.shape[0] + 2, _mask.shape[1] + 2))
85 | padded_mask[1:-1, 1:-1] = _mask
86 | contours = find_contours(padded_mask, 0.5)
87 | for verts in contours:
88 | # Subtract the padding and flip (y, x) to (x, y)
89 | verts = np.fliplr(verts) - 1
90 | p = Polygon(verts, facecolor="none", edgecolor=color)
91 | ax.add_patch(p)
92 | ax.imshow(masked_image.astype(np.uint8), aspect='auto')
93 | fig.savefig(fname)
94 | print(f"{fname} saved.")
95 | return
96 |
97 |
98 | if __name__ == '__main__':
99 | parser = argparse.ArgumentParser('Visualize Self-Attention maps')
100 | parser.add_argument('--arch', default='deit_small', type=str,
101 | choices=['deit_tiny', 'deit_small', 'vit_base'], help='Architecture (support only ViT atm).')
102 | parser.add_argument('--patch_size', default=8, type=int, help='Patch resolution of the model.')
103 | parser.add_argument('--pretrained_weights', default='', type=str,
104 | help="Path to pretrained weights to load.")
105 | parser.add_argument("--checkpoint_key", default="teacher", type=str,
106 | help='Key to use in the checkpoint (example: "teacher")')
107 | parser.add_argument("--image_path", default=None, type=str, help="Path of the image to load.")
108 | parser.add_argument('--output_dir', default='.', help='Path where to save visualizations.')
109 | parser.add_argument("--threshold", type=float, default=0.6, help="""We visualize masks
110 | obtained by thresholding the self-attention maps to keep xx% of the mass.""")
111 | args = parser.parse_args()
112 |
113 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
114 | # build model
115 | model = vits.__dict__[args.arch](patch_size=args.patch_size, num_classes=0)
116 | for p in model.parameters():
117 | p.requires_grad = False
118 | model.eval()
119 | model.to(device)
120 | if os.path.isfile(args.pretrained_weights):
121 | state_dict = torch.load(args.pretrained_weights, map_location="cpu")
122 | if args.checkpoint_key is not None and args.checkpoint_key in state_dict:
123 | print(f"Take key {args.checkpoint_key} in provided checkpoint dict")
124 | state_dict = state_dict[args.checkpoint_key]
125 | state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
126 | msg = model.load_state_dict(state_dict, strict=False)
127 | print('Pretrained weights found at {} and loaded with msg: {}'.format(args.pretrained_weights, msg))
128 | else:
129 | print("Please use the `--pretrained_weights` argument to indicate the path of the checkpoint to evaluate.")
130 | url = None
131 | if args.arch == "deit_small" and args.patch_size == 16:
132 | url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth"
133 | elif args.arch == "deit_small" and args.patch_size == 8:
134 | url = "dino_deitsmall8_300ep_pretrain/dino_deitsmall8_300ep_pretrain.pth" # model used for visualizations in our paper
135 | elif args.arch == "vit_base" and args.patch_size == 16:
136 | url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth"
137 | elif args.arch == "vit_base" and args.patch_size == 8:
138 | url = "dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth"
139 | if url is not None:
140 | print("Since no pretrained weights have been provided, we load the reference pretrained DINO weights.")
141 | state_dict = torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/" + url)
142 | model.load_state_dict(state_dict, strict=True)
143 | else:
144 | print("There is no reference weights available for this model => We use random weights.")
145 |
146 | # open image
147 | if args.image_path is None:
148 | # user has not specified any image - we use our own image
149 | print("Please use the `--image_path` argument to indicate the path of the image you wish to visualize.")
150 | print("Since no image path have been provided, we take the first image in our paper.")
151 | response = requests.get("https://dl.fbaipublicfiles.com/dino/img.png")
152 | img = Image.open(BytesIO(response.content))
153 | img = img.convert('RGB')
154 | elif os.path.isfile(args.image_path):
155 | with open(args.image_path, 'rb') as f:
156 | img = Image.open(f)
157 | img = img.convert('RGB')
158 | else:
159 | print(f"Provided image path {args.image_path} is non valid.")
160 | sys.exit(1)
161 | transform = pth_transforms.Compose([
162 | pth_transforms.ToTensor(),
163 | pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
164 | ])
165 | img = transform(img)
166 |
167 | # make the image divisible by the patch size
168 | w, h = img.shape[1] - img.shape[1] % args.patch_size, img.shape[2] - img.shape[2] % args.patch_size
169 | img = img[:, :w, :h].unsqueeze(0)
170 |
171 | w_featmap = img.shape[-2] // args.patch_size
172 | h_featmap = img.shape[-1] // args.patch_size
173 |
174 | attentions = model.forward_selfattention(img.to(device))
175 |
176 |     nh = attentions.shape[1]  # number of heads
177 |
178 | # we keep only the output patch attention
179 | attentions = attentions[0, :, 0, 1:].reshape(nh, -1)
180 |
181 | # we keep only a certain percentage of the mass
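    # Editor's note (not in the original source): per head, attention values are sorted
    # ascending and normalized to sum to 1; the cumulative sum then drops the low-attention
    # patches whose combined mass stays below (1 - threshold), keeping the top-attended
    # patches that together hold roughly `threshold` of the mass, and argsort(idx) maps
    # that mask back to the original patch order.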
182 | val, idx = torch.sort(attentions)
183 | val /= torch.sum(val, dim=1, keepdim=True)
184 | cumval = torch.cumsum(val, dim=1)
185 | th_attn = cumval > (1 - args.threshold)
186 | idx2 = torch.argsort(idx)
187 | for head in range(nh):
188 | th_attn[head] = th_attn[head][idx2[head]]
189 | th_attn = th_attn.reshape(nh, w_featmap, h_featmap).float()
190 | # interpolate
191 | th_attn = nn.functional.interpolate(th_attn.unsqueeze(0), scale_factor=args.patch_size, mode="nearest")[0].cpu().numpy()
192 |
193 | attentions = attentions.reshape(nh, w_featmap, h_featmap)
194 | attentions = nn.functional.interpolate(attentions.unsqueeze(0), scale_factor=args.patch_size, mode="nearest")[0].cpu().numpy()
195 |
196 | # save attentions heatmaps
197 | os.makedirs(args.output_dir, exist_ok=True)
198 | torchvision.utils.save_image(torchvision.utils.make_grid(img, normalize=True, scale_each=True), os.path.join(args.output_dir, "img.png"))
199 | for j in range(nh):
200 | fname = os.path.join(args.output_dir, "attn-head" + str(j) + ".png")
201 | plt.imsave(fname=fname, arr=attentions[j], format='png')
202 | print(f"{fname} saved.")
203 |
204 | image = skimage.io.imread(os.path.join(args.output_dir, "img.png"))
205 | for j in range(nh):
206 | display_instances(image, th_attn[j], fname=os.path.join(args.output_dir, "mask_th" + str(args.threshold) + "_head" + str(j) +".png"), blur=False)
207 |
--------------------------------------------------------------------------------
/images/kitti.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nerminsamet/seedal/52fdeeb42cf83b4b65814ae8b81efb2f7d61bb07/images/kitti.png
--------------------------------------------------------------------------------
/images/s3dis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nerminsamet/seedal/52fdeeb42cf83b4b65814ae8b81efb2f7d61bb07/images/s3dis.png
--------------------------------------------------------------------------------
/images/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nerminsamet/seedal/52fdeeb42cf83b4b65814ae8b81efb2f7d61bb07/images/teaser.png
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | from datasets.s3dis import S3DIS
2 | from datasets.semantic_kitti import SK
3 | from optimization import apply_optimization
4 | from importlib import import_module
5 | import argparse
6 |
7 | def main(args):
8 |
9 | config = import_module(f'datasets.{args.dataset}_config')
10 |
11 | if args.dataset == 's3dis':
12 |         dataset_instance = S3DIS(config)
13 | elif args.dataset == 'semantic_kitti':
14 |         dataset_instance = SK(config)
15 | else:
16 |         raise NotImplementedError("Only the s3dis and semantic_kitti datasets are supported!")
17 |
18 | dataset_instance.extract_feature_vecs()
19 | dataset_instance.extract_data_stats()
20 | dataset_instance.extract_scene_clusters()
21 | dataset_instance.extract_scene_attributes()
22 | dataset_instance.extract_pairwise_scene_attributes()
23 |
24 | all_pairs, data_stats, pair_scores, reduction_size, area_threshold = dataset_instance.prepare_data_for_optimization()
25 | selected_scenes = apply_optimization(all_pairs, data_stats, pair_scores, reduction_size, area_threshold)
26 |
27 | dataset_instance.create_initial_set(selected_scenes)
28 |
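# Editor's note, example invocation (assumes the dataset-specific settings in
# datasets/<name>_config.py are filled in, as done by the dataset classes above):
#   python main.py --dataset s3dis
#   python main.py --dataset semantic_kitti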
29 | if __name__ == '__main__':
30 |
31 | parser = argparse.ArgumentParser(description='SeedAL framework.')
32 | parser.add_argument('-d', '--dataset', choices=['s3dis', 'semantic_kitti'], default='semantic_kitti')
33 |
34 | args = parser.parse_args()
35 | main(args)
36 |
--------------------------------------------------------------------------------
/optimization.py:
--------------------------------------------------------------------------------
1 | from ortools.linear_solver import pywraplp
2 | import numpy as np
3 | from ortools.sat.python import cp_model
4 |
5 |
6 | def apply_optimization(all_pairs, data_stats, pair_scores, reduction_size, area_threshold):
7 |
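    # Editor's note (not in the original source), a sketch of the two-stage selection below:
    #   1) lineer_solver picks the `reduction_size` highest-scoring scene pairs, which
    #      defines a reduced sub-graph of candidate scenes;
    #   2) lineer_solver_by_node_weight then selects scenes (nodes) from that sub-graph,
    #      maximizing the sum of pairwise scores under the area budget `area_threshold`.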
8 | selected_indexes = lineer_solver(pair_scores, edge_number=reduction_size)
9 | ls_scans = [all_pairs[ind] for ind in selected_indexes]
10 |
11 | unique_scans = []
12 | for scan_t in ls_scans:
13 | if scan_t[0] not in unique_scans:
14 | unique_scans.append(scan_t[0])
15 | if scan_t[1] not in unique_scans:
16 | unique_scans.append(scan_t[1])
17 |
18 | node_len = len(unique_scans)
19 | sub_graph_areas = []
20 | sub_graph_scores = []
21 | sub_graph_pairs = []
22 | sub_aff_mat = np.zeros((node_len,node_len))
23 | for ind, us in enumerate(unique_scans):
24 | for ind2 in range(ind+1,len(unique_scans)):
25 | if (unique_scans[ind], unique_scans[ind2]) in all_pairs:
26 | ii = all_pairs.index((unique_scans[ind], unique_scans[ind2]))
27 | else:
28 | ii = all_pairs.index((unique_scans[ind2], unique_scans[ind]))
29 | sc = pair_scores[ii]
30 | sub_graph_scores.append(sc)
31 | sub_aff_mat[ind, ind2] = sc
32 | sub_graph_pairs.append(all_pairs[ii])
33 |
34 | for us in unique_scans:
35 | sub_graph_areas.append(data_stats[us]['area'])
36 |
37 | selected_edge_indexes, selected_node_indexes =\
38 | lineer_solver_by_node_weight(sub_aff_mat, sub_graph_areas, area_threshold=area_threshold)
39 |
40 | total_area = 0
41 | final_scenes = []
42 | for i in selected_node_indexes:
43 | print(unique_scans[i])
44 | final_scenes.append(unique_scans[i])
45 | total_area += data_stats[unique_scans[i]]['area']
46 |
47 | print(final_scenes)
48 | print(total_area)
49 |
50 | return final_scenes
51 |
52 |
53 | def create_reduction_data_model(edges, edge_number):
54 | data = {}
55 | edges = [np.float64(i) for i in edges]
56 | num_var = len(edges)
57 | num_of_xs = [1.] * num_var
58 |
59 |     # constraint on the number of selected pairs (edges)
60 | data['constraint_coeffs'] = [
61 | num_of_xs
62 | ]
63 | data['bounds'] = [edge_number]
64 |
65 | data['obj_coeffs'] = edges # similarity edges
66 |
67 | data['num_vars'] = num_var
68 | data['num_constraints'] = 1
69 |
70 | return data
71 |
72 |
73 | def create_linear_data_model(edge_weights, node_weights, area_threshold):
74 | data = {}
75 | data['edge_weights'] = edge_weights
76 | data['node_weights'] = node_weights
77 | data['area_threshold'] = area_threshold
78 | data['len'] = len(data['node_weights'])
79 | return data
80 |
81 |
82 | def lineer_solver(edges, edge_number):
83 |
84 | data = create_reduction_data_model(edges, edge_number)
85 | # Create the mip solver with the SCIP backend.
86 | solver = pywraplp.Solver.CreateSolver('SCIP')
87 | if not solver:
88 | return
89 |
90 | # infinity = solver.infinity()
91 | x = {}
92 | for j in range(data['num_vars']):
93 | x[j] = solver.BoolVar('x[%i]' % j)
94 |
95 | print('Number of variables =', solver.NumVariables())
96 |
97 | for i in range(data['num_constraints']):
98 | constraint = solver.RowConstraint(0, data['bounds'][i], '')
99 | for j in range(data['num_vars']):
100 | constraint.SetCoefficient(x[j], data['constraint_coeffs'][i][j])
101 | print('Number of constraints =', solver.NumConstraints())
102 |
103 | objective = solver.Objective()
104 | for j in range(data['num_vars']):
105 | objective.SetCoefficient(x[j], data['obj_coeffs'][j])
106 | objective.SetMaximization()
107 |
108 | status = solver.Solve()
109 |     selected_indexes = []
110 | if status == pywraplp.Solver.OPTIMAL:
111 | print('Objective value =', solver.Objective().Value())
112 | for j in range(data['num_vars']):
113 | if x[j].solution_value() > 0.0:
114 | print(x[j].name(), ' = ', x[j].solution_value())
115 | selected_indexes.append(j)
116 | print()
117 | print('Problem solved in %f milliseconds' % solver.wall_time())
118 | print('Problem solved in %d iterations' % solver.iterations())
119 | print('Problem solved in %d branch-and-bound nodes' % solver.nodes())
120 | else:
121 | print('The problem does not have an optimal solution.')
122 |
123 | return selected_indexes
124 |
125 |
126 | def lineer_solver_by_node_weight(edge_weights, node_weights, area_threshold):
127 |
128 | data = create_linear_data_model(edge_weights,node_weights,area_threshold)
129 | solver = pywraplp.Solver.CreateSolver('SCIP')
130 |
131 | node_len = data['len']
132 | if not solver:
133 | return
134 |
135 | # Variables
136 | y = {}
137 | for i in range(node_len):
138 | for j in range(i+1, node_len):
139 | y[(i, j)] = solver.BoolVar(f'y_{i}_{j}')
140 |
141 | x = {}
142 | for j in range(node_len):
143 | x[j] = solver.BoolVar(f'x[{j}]')
144 |
145 | # Constraints
146 | for i in range(node_len):
147 | for j in range(i+1, node_len):
148 | solver.Add(y[(i, j)] <= x[i])
149 | solver.Add(y[(i, j)] <= x[j])
150 |
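    # Editor's note (not in the original source): y[(i, j)] can only be 1 when both endpoint
    # nodes are selected; since the edge weights here are assumed non-negative and the
    # objective is maximized, the solver sets y to x_i AND x_j at the optimum without an
    # explicit y >= x_i + x_j - 1 constraint. The single constraint below caps the total
    # area of the selected scenes at the annotation budget.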
151 | solver.Add(sum(x[i] * data['node_weights'][i] for i in range(node_len)) <= data['area_threshold'])
152 |
153 | solver.Maximize(solver.Sum([y[i,j]*data['edge_weights'][i,j] for i in range(node_len) for j in range(i+1,node_len)]))
154 |
155 | status = solver.Solve()
156 |
157 | selected_node_indexes = []
158 | selected_edge_indexes = []
159 |     if status == pywraplp.Solver.OPTIMAL or status == pywraplp.Solver.FEASIBLE:
160 | print('Objective value =', solver.Objective().Value())
161 | for j in range(node_len):
162 | if x[j].solution_value() > 0.0:
163 | print(x[j].name(), ' = ', x[j].solution_value())
164 | selected_node_indexes.append(j)
165 |
166 | for i in range(node_len):
167 | for j in range(i+1, node_len):
168 | if y[i,j].solution_value() > 0.0:
169 | print(y[i,j].name(), ' = ', y[i,j].solution_value())
170 | selected_edge_indexes.append((i,j))
171 |
172 | print()
173 | print('Problem solved in %f milliseconds' % solver.wall_time())
174 | print('Problem solved in %d iterations' % solver.iterations())
175 | print('Problem solved in %d branch-and-bound nodes' % solver.nodes())
176 |
177 |
178 | print('Time = ', solver.WallTime(), ' milliseconds')
179 | else:
180 | print('The problem does not have an optimal solution.')
181 |
182 | return selected_edge_indexes, selected_node_indexes
183 |
184 |
185 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | importlib
2 | numpy
3 | pillow
4 | faiss
5 | ortools
6 | pickle5
7 | scikit-learn
8 | scipy
--------------------------------------------------------------------------------
/s3dis_seed/info.txt:
--------------------------------------------------------------------------------
1 | Region Num: 859 and Point Num: 6470008 in the current set
--------------------------------------------------------------------------------
/s3dis_seed/init_label_region.json:
--------------------------------------------------------------------------------
1 | {"Area_2#storage_5": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46],"Area_2#office_9": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67],"Area_2#storage_6": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33],"Area_2#office_1": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52],"Area_2#hallway_11": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39],"Area_2#storage_1": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29],"Area_2#storage_3": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54],"Area_2#storage_7": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56],"Area_3#storage_1": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],"Area_4#office_14": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39],"Area_4#storage_1": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26],"Area_4#storage_3": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52],"Area_4#storage_2": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53],"Area_4#storage_4": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40],"Area_6#office_10": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97],"Area_6#office_12": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96],"Area_6#office_15": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 
42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64]}
--------------------------------------------------------------------------------
/s3dis_seed/init_label_scan.json:
--------------------------------------------------------------------------------
1 | [
2 | "Area_2/coords/storage_5.npy",
3 | "Area_2/coords/office_9.npy",
4 | "Area_2/coords/storage_6.npy",
5 | "Area_2/coords/office_1.npy",
6 | "Area_2/coords/hallway_11.npy",
7 | "Area_2/coords/storage_1.npy",
8 | "Area_2/coords/storage_3.npy",
9 | "Area_2/coords/storage_7.npy",
10 | "Area_3/coords/storage_1.npy",
11 | "Area_4/coords/office_14.npy",
12 | "Area_4/coords/storage_1.npy",
13 | "Area_4/coords/storage_3.npy",
14 | "Area_4/coords/storage_2.npy",
15 | "Area_4/coords/storage_4.npy",
16 | "Area_6/coords/office_10.npy",
17 | "Area_6/coords/office_12.npy",
18 | "Area_6/coords/office_15.npy"
19 | ]
20 |
--------------------------------------------------------------------------------
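A minimal illustrative sketch (not part of the repository) of how the two S3DIS seed files above fit together: `init_label_scan.json` lists the initially labeled scans, while the mapping from `"Area_X#room"` keys to region ids shown further above (presumably `init_label_region.json`) gives the labeled regions of each of those scans. File paths are assumed relative to the repository root.

```python
# Illustrative sketch: relate the labeled-scan list to the per-scene region mapping.
import json

with open('s3dis_seed/init_label_scan.json') as f:
    label_scans = json.load(f)            # e.g. "Area_2/coords/storage_5.npy"
with open('s3dis_seed/init_label_region.json') as f:
    label_regions = json.load(f)          # e.g. "Area_2#storage_5": [1, 2, ..., 46]

for scan in label_scans:
    area, _, room = scan[:-len('.npy')].split('/')
    key = f'{area}#{room}'
    print(key, len(label_regions.get(key, [])), 'labeled regions')
```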
/s3dis_seed/init_ulabel_scan.json:
--------------------------------------------------------------------------------
1 | [
2 | "Area_1/coords/hallway_6.npy",
3 | "Area_1/coords/office_10.npy",
4 | "Area_1/coords/office_21.npy",
5 | "Area_1/coords/hallway_3.npy",
6 | "Area_1/coords/office_23.npy",
7 | "Area_1/coords/hallway_7.npy",
8 | "Area_1/coords/office_22.npy",
9 | "Area_1/coords/WC_1.npy",
10 | "Area_1/coords/office_9.npy",
11 | "Area_1/coords/office_29.npy",
12 | "Area_1/coords/office_6.npy",
13 | "Area_1/coords/hallway_2.npy",
14 | "Area_1/coords/copyRoom_1.npy",
15 | "Area_1/coords/office_12.npy",
16 | "Area_1/coords/office_17.npy",
17 | "Area_1/coords/hallway_8.npy",
18 | "Area_1/coords/hallway_5.npy",
19 | "Area_1/coords/office_8.npy",
20 | "Area_1/coords/conferenceRoom_2.npy",
21 | "Area_1/coords/office_13.npy",
22 | "Area_1/coords/office_1.npy",
23 | "Area_1/coords/office_27.npy",
24 | "Area_1/coords/office_16.npy",
25 | "Area_1/coords/office_24.npy",
26 | "Area_1/coords/office_14.npy",
27 | "Area_1/coords/office_15.npy",
28 | "Area_1/coords/office_11.npy",
29 | "Area_1/coords/office_28.npy",
30 | "Area_1/coords/office_2.npy",
31 | "Area_1/coords/office_18.npy",
32 | "Area_1/coords/hallway_1.npy",
33 | "Area_1/coords/office_31.npy",
34 | "Area_1/coords/office_25.npy",
35 | "Area_1/coords/office_5.npy",
36 | "Area_1/coords/conferenceRoom_1.npy",
37 | "Area_1/coords/office_19.npy",
38 | "Area_1/coords/pantry_1.npy",
39 | "Area_1/coords/office_30.npy",
40 | "Area_1/coords/hallway_4.npy",
41 | "Area_1/coords/office_26.npy",
42 | "Area_1/coords/office_7.npy",
43 | "Area_1/coords/office_20.npy",
44 | "Area_1/coords/office_3.npy",
45 | "Area_1/coords/office_4.npy",
46 | "Area_2/coords/hallway_6.npy",
47 | "Area_2/coords/office_10.npy",
48 | "Area_2/coords/hallway_3.npy",
49 | "Area_2/coords/hallway_7.npy",
50 | "Area_2/coords/WC_1.npy",
51 | "Area_2/coords/hallway_9.npy",
52 | "Area_2/coords/office_6.npy",
53 | "Area_2/coords/hallway_2.npy",
54 | "Area_2/coords/office_12.npy",
55 | "Area_2/coords/hallway_8.npy",
56 | "Area_2/coords/hallway_5.npy",
57 | "Area_2/coords/office_8.npy",
58 | "Area_2/coords/office_13.npy",
59 | "Area_2/coords/storage_9.npy",
60 | "Area_2/coords/hallway_10.npy",
61 | "Area_2/coords/storage_8.npy",
62 | "Area_2/coords/office_14.npy",
63 | "Area_2/coords/auditorium_2.npy",
64 | "Area_2/coords/office_11.npy",
65 | "Area_2/coords/office_2.npy",
66 | "Area_2/coords/hallway_1.npy",
67 | "Area_2/coords/office_5.npy",
68 | "Area_2/coords/storage_2.npy",
69 | "Area_2/coords/conferenceRoom_1.npy",
70 | "Area_2/coords/WC_2.npy",
71 | "Area_2/coords/auditorium_1.npy",
72 | "Area_2/coords/hallway_12.npy",
73 | "Area_2/coords/hallway_4.npy",
74 | "Area_2/coords/office_7.npy",
75 | "Area_2/coords/office_3.npy",
76 | "Area_2/coords/office_4.npy",
77 | "Area_2/coords/storage_4.npy",
78 | "Area_3/coords/hallway_6.npy",
79 | "Area_3/coords/office_10.npy",
80 | "Area_3/coords/hallway_3.npy",
81 | "Area_3/coords/WC_1.npy",
82 | "Area_3/coords/office_9.npy",
83 | "Area_3/coords/office_6.npy",
84 | "Area_3/coords/hallway_2.npy",
85 | "Area_3/coords/hallway_5.npy",
86 | "Area_3/coords/office_8.npy",
87 | "Area_3/coords/office_1.npy",
88 | "Area_3/coords/office_2.npy",
89 | "Area_3/coords/hallway_1.npy",
90 | "Area_3/coords/office_5.npy",
91 | "Area_3/coords/storage_2.npy",
92 | "Area_3/coords/conferenceRoom_1.npy",
93 | "Area_3/coords/WC_2.npy",
94 | "Area_3/coords/lounge_2.npy",
95 | "Area_3/coords/hallway_4.npy",
96 | "Area_3/coords/office_7.npy",
97 | "Area_3/coords/lounge_1.npy",
98 | "Area_3/coords/office_3.npy",
99 | "Area_3/coords/office_4.npy",
100 | "Area_4/coords/hallway_6.npy",
101 | "Area_4/coords/hallway_14.npy",
102 | "Area_4/coords/office_10.npy",
103 | "Area_4/coords/office_21.npy",
104 | "Area_4/coords/hallway_3.npy",
105 | "Area_4/coords/hallway_7.npy",
106 | "Area_4/coords/office_22.npy",
107 | "Area_4/coords/WC_1.npy",
108 | "Area_4/coords/hallway_9.npy",
109 | "Area_4/coords/office_9.npy",
110 | "Area_4/coords/office_6.npy",
111 | "Area_4/coords/hallway_2.npy",
112 | "Area_4/coords/office_12.npy",
113 | "Area_4/coords/office_17.npy",
114 | "Area_4/coords/hallway_8.npy",
115 | "Area_4/coords/hallway_5.npy",
116 | "Area_4/coords/office_8.npy",
117 | "Area_4/coords/conferenceRoom_2.npy",
118 | "Area_4/coords/office_13.npy",
119 | "Area_4/coords/lobby_1.npy",
120 | "Area_4/coords/office_1.npy",
121 | "Area_4/coords/office_16.npy",
122 | "Area_4/coords/hallway_10.npy",
123 | "Area_4/coords/hallway_11.npy",
124 | "Area_4/coords/office_15.npy",
125 | "Area_4/coords/office_11.npy",
126 | "Area_4/coords/conferenceRoom_3.npy",
127 | "Area_4/coords/office_2.npy",
128 | "Area_4/coords/office_18.npy",
129 | "Area_4/coords/hallway_1.npy",
130 | "Area_4/coords/office_5.npy",
131 | "Area_4/coords/conferenceRoom_1.npy",
132 | "Area_4/coords/WC_2.npy",
133 | "Area_4/coords/office_19.npy",
134 | "Area_4/coords/lobby_2.npy",
135 | "Area_4/coords/hallway_12.npy",
136 | "Area_4/coords/WC_4.npy",
137 | "Area_4/coords/hallway_4.npy",
138 | "Area_4/coords/office_7.npy",
139 | "Area_4/coords/office_20.npy",
140 | "Area_4/coords/office_3.npy",
141 | "Area_4/coords/office_4.npy",
142 | "Area_4/coords/hallway_13.npy",
143 | "Area_4/coords/WC_3.npy",
144 | "Area_6/coords/hallway_6.npy",
145 | "Area_6/coords/office_21.npy",
146 | "Area_6/coords/hallway_3.npy",
147 | "Area_6/coords/office_23.npy",
148 | "Area_6/coords/office_22.npy",
149 | "Area_6/coords/office_9.npy",
150 | "Area_6/coords/office_35.npy",
151 | "Area_6/coords/office_29.npy",
152 | "Area_6/coords/office_32.npy",
153 | "Area_6/coords/office_6.npy",
154 | "Area_6/coords/hallway_2.npy",
155 | "Area_6/coords/copyRoom_1.npy",
156 | "Area_6/coords/office_17.npy",
157 | "Area_6/coords/hallway_5.npy",
158 | "Area_6/coords/office_8.npy",
159 | "Area_6/coords/office_13.npy",
160 | "Area_6/coords/office_1.npy",
161 | "Area_6/coords/office_27.npy",
162 | "Area_6/coords/office_16.npy",
163 | "Area_6/coords/office_24.npy",
164 | "Area_6/coords/office_36.npy",
165 | "Area_6/coords/office_14.npy",
166 | "Area_6/coords/office_34.npy",
167 | "Area_6/coords/office_11.npy",
168 | "Area_6/coords/office_28.npy",
169 | "Area_6/coords/office_2.npy",
170 | "Area_6/coords/office_18.npy",
171 | "Area_6/coords/hallway_1.npy",
172 | "Area_6/coords/office_31.npy",
173 | "Area_6/coords/office_25.npy",
174 | "Area_6/coords/office_5.npy",
175 | "Area_6/coords/conferenceRoom_1.npy",
176 | "Area_6/coords/office_19.npy",
177 | "Area_6/coords/pantry_1.npy",
178 | "Area_6/coords/openspace_1.npy",
179 | "Area_6/coords/office_30.npy",
180 | "Area_6/coords/hallway_4.npy",
181 | "Area_6/coords/office_37.npy",
182 | "Area_6/coords/office_33.npy",
183 | "Area_6/coords/office_26.npy",
184 | "Area_6/coords/office_7.npy",
185 | "Area_6/coords/lounge_1.npy",
186 | "Area_6/coords/office_20.npy",
187 | "Area_6/coords/office_3.npy",
188 | "Area_6/coords/office_4.npy"
189 | ]
190 |
--------------------------------------------------------------------------------
/similarity_model.py:
--------------------------------------------------------------------------------
1 | from dino_model.fe_dino import DinoModel
2 | import numpy as np
3 | from PIL import Image
4 | import faiss
5 |
6 |
7 | class SimilarityModel:
8 | def __init__(self, sim_model_name='dino', feat_type='cls'):
9 | self.sim_model_name = sim_model_name
10 | self.feat_type = feat_type
11 |         self.d = None  # feature dimension; set to 768 by load_sim_model()
12 |         self.load_sim_model()
13 |
14 | def get_sim_vec_single(self, im_name_or_data):
15 | # Prepare an image
16 | image = self.load_image(im_name_or_data)
17 | if image is None:
18 | print('No Image')
19 | return np.zeros((1, self.d), dtype=np.float32)
20 |         # extract the DINO feature for the image
21 |         feature_vec = self.sim_model(image).cpu().data.numpy().astype(np.float32)  # convert to numpy array
22 | return feature_vec
23 |
24 |     def load_sim_model(self):
25 |         if self.sim_model_name == 'dino':
26 |             self.sim_model = self.load_sim_model_dino()
27 |             self.d = 768  # dimension of the extracted DINO features
28 |         else:
29 |             raise Exception("sim_model_name must be dino!")
30 | 
31 |     def load_sim_model_dino(self):
32 |         # Build the DINO feature extractor (used in eval mode)
33 |         sim_model = DinoModel(self.feat_type)
34 | return sim_model
35 |
36 | def set_sim_model_feat_type_dino(self, feat_type):
37 | self.feat_type = feat_type
38 | self.sim_model.feat_type = self.feat_type
39 |
40 | def load_image(self, image_path_or_data):
41 | try:
42 | if isinstance(image_path_or_data, str):
43 | img = Image.open(image_path_or_data)
44 | image = img.convert('RGB')
45 | elif isinstance(image_path_or_data, Image.Image):
46 | image = image_path_or_data.convert('RGB')
47 | else:
48 | raise Exception("image type must be str or PIL.Image!")
49 |
50 | return image
51 |
52 |         except Exception:
53 | return None
54 |
55 |     def calculate_dino_aff_matrix_from_feats(self, feats):
56 |         curr_dim = len(feats)  # number of feature vectors (N)
57 |         faiss.normalize_L2(feats)  # in-place L2 normalization, so inner products are cosine similarities
58 |         index = faiss.IndexFlatIP(feats.shape[1])
59 |         index.add(feats)  # add vectors to the index
60 |         lims, D, I = index.range_search(feats, -1)  # radius -1, so every pair is returned
61 |         ordered_D = np.reshape(D, (curr_dim, curr_dim))  # full N x N affinity matrix
62 |         return ordered_D
63 |
64 |
65 |
66 |
--------------------------------------------------------------------------------
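For reference, a minimal usage sketch of `SimilarityModel` (not part of the repository). It assumes the DINO backbone can be loaded via `dino_model`, and uses placeholder image paths standing in for renderings of point-cloud scenes.

```python
# Hypothetical usage sketch of SimilarityModel; the image paths below are placeholders.
import numpy as np
from similarity_model import SimilarityModel

sim = SimilarityModel(sim_model_name='dino', feat_type='cls')

# Renderings of two point-cloud scenes (placeholder paths).
image_paths = ['renders/scene_a.png', 'renders/scene_b.png']

# One (1, 768) DINO feature per image; zeros are returned when an image cannot be read.
feats = np.concatenate([sim.get_sim_vec_single(p) for p in image_paths], axis=0)

# Pairwise cosine similarities; the features are L2-normalized in place by faiss.
aff = sim.calculate_dino_aff_matrix_from_feats(feats)
print(aff.shape)  # (2, 2), with self-similarity close to 1.0 on the diagonal
```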
/sk_seed/init_label_scan.json:
--------------------------------------------------------------------------------
1 | [
2 | "05/velodyne/001429.bin",
3 | "06/velodyne/000711.bin",
4 | "01/velodyne/000876.bin",
5 | "05/velodyne/001092.bin",
6 | "02/velodyne/002455.bin",
7 | "09/velodyne/000415.bin",
8 | "06/velodyne/000175.bin",
9 | "05/velodyne/001679.bin",
10 | "00/velodyne/002698.bin",
11 | "00/velodyne/000988.bin",
12 | "02/velodyne/003218.bin",
13 | "00/velodyne/000932.bin",
14 | "06/velodyne/000683.bin",
15 | "05/velodyne/002075.bin",
16 | "06/velodyne/000078.bin",
17 | "09/velodyne/000855.bin",
18 | "00/velodyne/000440.bin",
19 | "10/velodyne/000178.bin",
20 | "07/velodyne/000585.bin",
21 | "00/velodyne/000397.bin",
22 | "00/velodyne/001569.bin",
23 | "00/velodyne/001872.bin",
24 | "06/velodyne/000877.bin",
25 | "10/velodyne/000363.bin",
26 | "00/velodyne/004369.bin",
27 | "05/velodyne/002224.bin",
28 | "05/velodyne/001342.bin",
29 | "06/velodyne/000064.bin",
30 | "05/velodyne/001075.bin",
31 | "05/velodyne/001499.bin",
32 | "05/velodyne/001979.bin",
33 | "10/velodyne/000283.bin",
34 | "06/velodyne/000263.bin",
35 | "09/velodyne/001225.bin",
36 | "07/velodyne/000640.bin",
37 | "05/velodyne/002395.bin",
38 | "00/velodyne/001824.bin",
39 | "03/velodyne/000662.bin",
40 | "00/velodyne/003118.bin",
41 | "06/velodyne/001007.bin",
42 | "06/velodyne/000504.bin",
43 | "05/velodyne/002215.bin",
44 | "09/velodyne/000910.bin",
45 | "00/velodyne/001059.bin",
46 | "06/velodyne/000492.bin",
47 | "06/velodyne/000323.bin",
48 | "07/velodyne/000177.bin",
49 | "06/velodyne/000898.bin",
50 | "07/velodyne/000649.bin",
51 | "10/velodyne/000981.bin",
52 | "05/velodyne/002405.bin",
53 | "07/velodyne/000401.bin",
54 | "01/velodyne/000376.bin",
55 | "09/velodyne/001127.bin",
56 | "05/velodyne/002298.bin",
57 | "02/velodyne/002078.bin",
58 | "05/velodyne/002417.bin",
59 | "05/velodyne/001193.bin",
60 | "05/velodyne/002630.bin",
61 | "05/velodyne/000396.bin",
62 | "01/velodyne/000006.bin",
63 | "06/velodyne/000892.bin",
64 | "10/velodyne/000165.bin",
65 | "05/velodyne/002623.bin",
66 | "06/velodyne/000198.bin",
67 | "06/velodyne/001042.bin",
68 | "10/velodyne/000495.bin",
69 | "00/velodyne/003240.bin",
70 | "00/velodyne/003262.bin",
71 | "09/velodyne/001099.bin",
72 | "09/velodyne/000212.bin",
73 | "00/velodyne/003003.bin",
74 | "02/velodyne/002556.bin",
75 | "05/velodyne/001957.bin",
76 | "05/velodyne/001318.bin",
77 | "06/velodyne/000069.bin",
78 | "05/velodyne/001728.bin",
79 | "02/velodyne/002546.bin",
80 | "07/velodyne/000917.bin",
81 | "05/velodyne/002106.bin",
82 | "05/velodyne/002353.bin",
83 | "06/velodyne/000149.bin",
84 | "00/velodyne/001262.bin",
85 | "04/velodyne/000148.bin",
86 | "06/velodyne/000983.bin",
87 | "00/velodyne/004361.bin",
88 | "02/velodyne/002073.bin",
89 | "01/velodyne/000217.bin",
90 | "02/velodyne/000058.bin",
91 | "06/velodyne/000207.bin",
92 | "00/velodyne/002543.bin",
93 | "02/velodyne/002874.bin",
94 | "01/velodyne/001063.bin",
95 | "09/velodyne/000721.bin",
96 | "10/velodyne/000369.bin",
97 | "05/velodyne/000527.bin",
98 | "00/velodyne/000126.bin",
99 | "07/velodyne/000026.bin",
100 | "10/velodyne/000329.bin",
101 | "05/velodyne/000125.bin",
102 | "05/velodyne/000844.bin",
103 | "06/velodyne/000663.bin",
104 | "06/velodyne/000551.bin",
105 | "09/velodyne/000568.bin",
106 | "05/velodyne/001121.bin",
107 | "10/velodyne/001083.bin",
108 | "00/velodyne/000068.bin",
109 | "06/velodyne/000411.bin",
110 | "07/velodyne/000142.bin",
111 | "10/velodyne/001014.bin",
112 | "05/velodyne/002535.bin",
113 | "10/velodyne/000898.bin",
114 | "02/velodyne/001387.bin",
115 | "00/velodyne/000108.bin",
116 | "05/velodyne/001854.bin",
117 | "06/velodyne/001035.bin",
118 | "00/velodyne/003276.bin",
119 | "09/velodyne/001172.bin",
120 | "02/velodyne/001533.bin",
121 | "00/velodyne/002709.bin",
122 | "10/velodyne/000688.bin",
123 | "06/velodyne/000575.bin",
124 | "10/velodyne/000356.bin",
125 | "00/velodyne/004414.bin",
126 | "02/velodyne/001931.bin",
127 | "09/velodyne/000021.bin",
128 | "00/velodyne/002806.bin",
129 | "07/velodyne/000041.bin",
130 | "00/velodyne/002815.bin",
131 | "06/velodyne/000398.bin",
132 | "07/velodyne/000294.bin",
133 | "09/velodyne/001477.bin",
134 | "10/velodyne/000115.bin",
135 | "01/velodyne/000640.bin",
136 | "10/velodyne/000559.bin",
137 | "05/velodyne/000360.bin",
138 | "09/velodyne/001349.bin",
139 | "06/velodyne/000675.bin",
140 | "06/velodyne/000219.bin",
141 | "00/velodyne/003875.bin",
142 | "09/velodyne/001181.bin",
143 | "10/velodyne/000995.bin",
144 | "01/velodyne/000018.bin",
145 | "05/velodyne/002376.bin",
146 | "00/velodyne/002726.bin",
147 | "03/velodyne/000515.bin",
148 | "05/velodyne/002497.bin",
149 | "00/velodyne/002848.bin",
150 | "10/velodyne/000878.bin",
151 | "05/velodyne/001800.bin",
152 | "05/velodyne/000981.bin",
153 | "05/velodyne/000339.bin",
154 | "05/velodyne/002430.bin",
155 | "02/velodyne/002908.bin",
156 | "00/velodyne/001421.bin",
157 | "05/velodyne/000678.bin",
158 | "09/velodyne/001113.bin",
159 | "10/velodyne/000570.bin",
160 | "00/velodyne/003911.bin",
161 | "07/velodyne/000644.bin",
162 | "00/velodyne/004513.bin",
163 | "00/velodyne/003928.bin",
164 | "06/velodyne/000049.bin",
165 | "09/velodyne/000183.bin",
166 | "10/velodyne/000301.bin",
167 | "09/velodyne/000405.bin",
168 | "00/velodyne/002741.bin",
169 | "00/velodyne/001124.bin",
170 | "09/velodyne/000622.bin",
171 | "02/velodyne/004591.bin",
172 | "10/velodyne/000581.bin",
173 | "05/velodyne/002357.bin",
174 | "05/velodyne/001814.bin",
175 | "01/velodyne/000034.bin",
176 | "02/velodyne/000553.bin",
177 | "05/velodyne/000917.bin",
178 | "02/velodyne/004509.bin",
179 | "09/velodyne/000591.bin",
180 | "10/velodyne/000374.bin",
181 | "06/velodyne/001027.bin",
182 | "00/velodyne/003505.bin",
183 | "05/velodyne/000136.bin",
184 | "07/velodyne/000634.bin",
185 | "09/velodyne/000392.bin"
186 | ]
187 |
--------------------------------------------------------------------------------
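A small illustrative sketch (not part of the repository): the SemanticKITTI seed lists scans as `<sequence>/velodyne/<frame>.bin`, so one can quickly check how many seed scans each sequence contributes. The path is assumed relative to the repository root.

```python
# Count seed scans per SemanticKITTI sequence (illustrative only).
import json
from collections import Counter

with open('sk_seed/init_label_scan.json') as f:
    scans = json.load(f)

per_sequence = Counter(path.split('/')[0] for path in scans)
for seq, count in sorted(per_sequence.items()):
    print(f'sequence {seq}: {count} seed scans')
```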