├── .gitignore
├── README.md
├── datasets
│   ├── base_dataset.py
│   ├── s3dis.py
│   ├── s3dis_config.py
│   ├── s3dis_regions.json
│   ├── semantic_kitti.py
│   ├── semantic_kitti_areas.json
│   ├── semantic_kitti_config.py
│   └── semantic_kitti_regions.json
├── dino_model
│   ├── LICENSE
│   ├── README.md
│   ├── eval_knn.py
│   ├── eval_linear.py
│   ├── fe_dino.py
│   ├── hubconf.py
│   ├── main_dino.py
│   ├── run_with_submitit.py
│   ├── utils.py
│   ├── video_generation.py
│   ├── vision_transformer.py
│   └── visualize_attention.py
├── images
│   ├── kitti.png
│   ├── s3dis.png
│   └── teaser.png
├── main.py
├── optimization.py
├── requirements.txt
├── s3dis_seed
│   ├── info.txt
│   ├── init_label_region.json
│   ├── init_label_scan.json
│   ├── init_ulabel_region.json
│   └── init_ulabel_scan.json
├── similarity_model.py
└── sk_seed
    ├── init_label_large_region.json
    ├── init_label_scan.json
    ├── init_ulabel_large_region.json
    └── init_ulabel_scan.json

/.gitignore:
--------------------------------------------------------------------------------
1 | .idea/
2 | __pycache__/
3 | eggs/
4 | .eggs/
5 | *.pkl
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # You Never Get a Second Chance To Make a Good First Impression: Seeding Active Learning for 3D Semantic Segmentation
2 |
3 | Official PyTorch implementation of SeedAL.
4 |
5 | > [**You Never Get a Second Chance To Make a Good First Impression:
6 | Seeding Active Learning for 3D Semantic Segmentation**](http://arxiv.org/abs/2304.11762),
7 | > [Nermin Samet](https://nerminsamet.github.io/), [Oriane Siméoni](https://osimeoni.github.io/), [Gilles Puy](https://sites.google.com/site/puygilles/), [Georgy Ponimatkin](https://ponimatkin.github.io/), [Renaud Marlet](http://imagine.enpc.fr/~marletr/), [Vincent Lepetit](https://vincentlepetit.github.io/), \
8 | > *ICCV 2023.*
9 |
10 | ## Summary
11 | Active learning seeds (initial sets) have a
12 | significant effect on the performance of
13 | active learning methods. Below, we show the variability of results obtained with 20 different random seeds (blue dashed lines), within an initial annotation budget of 3% of the dataset, when using various active learning methods for 3D semantic segmentation of S3DIS.
14 | We compare them with the result obtained with our seed selection strategy (solid red line), SeedAL, which performs
15 | better than or on par with the best (lucky) random seeds among the 20, and “protects” against very bad (unlucky) random seeds.
16 |
17 |
18 |
19 | We propose SeedAL for automatically constructing a seed that will ensure good performance for AL.
20 | Assuming that images of the point clouds are available, which is common, our method relies on powerful unsupervised image features
21 | to measure the diversity of the point clouds.
22 | It selects the point clouds for the seed by optimizing the diversity under an annotation budget.
23 |
24 | ## Outdoor Evaluation Results on Semantic KITTI
25 |
26 |
27 |
28 | ## Indoor Evaluation Results on S3DIS
29 |
30 |
31 |
32 | ## Installation
33 |
34 | Given images of the point cloud scenes,
35 | SeedAL outputs an initial set (or seed) under a given budget to start the training
36 | of active learning methods for 3D point cloud semantic segmentation.
37 | In our experiments, we used the [ReDAL](https://github.com/tsunghan-wu/ReDAL) framework, which implements several active learning methods.
38 | For S3DIS and Semantic KITTI, SeedAL also follows ReDAL's [data preparation](https://github.com/tsunghan-wu/ReDAL/tree/main/data_preparation) to stay consistent with the AL training setup.
39 |
40 | ### Environment Setup
41 |
42 | The codebase was tested on Ubuntu 20.04 with CUDA 11.1.
43 | An NVIDIA GPU is required to extract the DINO features.
44 |
45 | ~~~
46 | conda create --name seedal python=3.7
47 | conda activate seedal
48 | conda install pytorch==1.10.2 torchvision==0.11.3 cudatoolkit=11.1 -c pytorch
49 | pip install -r requirements.txt
50 | ~~~
51 |
52 | After setting up the data and environment, you can run SeedAL as follows:
53 |
54 | ~~~
55 | python main.py -d
56 | ~~~
57 |
58 | By default, we save the results of intermediate steps, such as dataset statistics, DINO features, and computed scene diversities.
59 |
60 | ## Acknowledgement
61 |
62 | This work was granted access to the HPC resources of IDRIS under the allocation 2023-AD011013267R1 made by GENCI.
63 |
64 |
65 | ## License
66 |
67 | SeedAL is released under the [Apache 2.0 License](LICENSE).
68 |
69 | ## Citation
70 |
71 | If you find SeedAL useful for your research, please cite our paper as follows.
72 |
73 | > Nermin Samet, Oriane Siméoni, Gilles Puy, Georgy Ponimatkin, Renaud Marlet, Vincent Lepetit, "You Never Get a Second Chance To Make a Good First Impression: Seeding Active Learning for 3D Semantic Segmentation",
74 | > In IEEE International Conference on Computer Vision (ICCV), 2023.
75 |
76 |
77 | BibTeX entry:
78 | ```
79 | @inproceedings{seedal,
80 |   author    = {Nermin Samet and Oriane Siméoni and Gilles Puy and Georgy Ponimatkin and Renaud Marlet and Vincent Lepetit},
81 |   title     = {You Never Get a Second Chance To Make a Good First Impression: Seeding Active Learning for 3D Semantic Segmentation},
82 |   booktitle = {IEEE International Conference on Computer Vision (ICCV)},
83 |   year      = {2023},
84 | }
85 | ```
86 |
87 |
--------------------------------------------------------------------------------
/datasets/base_dataset.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('..')
3 | from os import listdir
4 | from os.path import isfile, join
5 | import os
6 | from similarity_model import SimilarityModel
7 | import pickle
8 | import pickle5 as pickle  # protocol-5 backport for Python 3.7; rebinds the name `pickle`
9 | from sklearn.cluster import KMeans
10 | from scipy import stats
11 | import numpy as np
12 |
13 | class PCData:
14 |
15 |     def __init__(self, base_config):
16 |
17 |         # path related attributes
18 |         self.dataset_name = None
19 |         self.main_2d_path = None
20 |         self.main_3d_path = None
21 |         self.scene_relative_path_to_rgb = None
22 |         self.save_path = None
23 |
24 |         # dataset-specific attributes, filled in from the subclass config below
25 |         self.train_scenes = None
26 |         self.class_num = None
27 |         self.cluster_num = None
28 |         self.target_point_num = None
29 |         self.reduction_size = None
30 |         self.seed_name = None
31 |
32 |         for k, v in base_config.items():
33 |             setattr(self, k, v)
34 |
35 |         isExist = os.path.exists(self.save_path)
36 |         if not isExist:
37 |             os.makedirs(self.save_path)
38 |
39 |         self.sim_object = SimilarityModel()
40 |
41 |         self.cluster_centers_file = self.save_path + f'{self.dataset_name}_scene_clusters_{self.cluster_num}.pkl'
42 |         self.scene_attributes_file = self.save_path + f'{self.dataset_name}_attribute_dict.pkl'
43 |         self.pairwise_scene_attributes_file = self.save_path + f'{self.dataset_name}_pairwise_attribute_dict.pkl'
44 |         self.data_stats_file = self.save_path + f'{self.dataset_name}_stats.pkl'
45 |         self.feature_vec_dict_file = self.save_path + f'{self.dataset_name}_{self.sim_object.sim_model_name}_features.pkl'
46 |
47 |         # attributes to be calculated
48 |         self.cluster_centers = None
49 |         self.scene_attributes = None
50 |         self.pairwise_scene_attributes = None
51 |         self.data_stats = None
52 |         self.feat_vec_image_dict = None
53 |
54 |         self.all_images = self.get_all_data_rgb_names()
55 |         self.all_scenes = None
56 |
57 |     def load_feature_vec_dict(self):
58 |         if os.path.exists(self.feature_vec_dict_file):
59 |             with open(self.feature_vec_dict_file, 'rb') as handle:
60 |                 self.feat_vec_image_dict = pickle.load(handle)
61 |             print(f'{self.feature_vec_dict_file} features are loaded')
62 |             return True
63 |         else:
64 |             return False
65 |
66 |     def load_data_stats(self):
67 |         if os.path.exists(self.data_stats_file):
68 |             with open(self.data_stats_file, 'rb') as handle:
69 |                 self.data_stats = pickle.load(handle)
70 |             print(f'{self.data_stats_file} is loaded')
71 |             return True
72 |         else:
73 |             return False
74 |
75 |     def load_cluster_centers(self):
76 |         if os.path.exists(self.cluster_centers_file):
77 |             with open(self.cluster_centers_file, 'rb') as handle:
78 |                 self.cluster_centers = pickle.load(handle)
79 |             print(f'{self.cluster_centers_file} is loaded')
80 |             return True
81 |         else:
82 |             return False
83 |
84 |     def load_scene_attributes(self):
85 |         if os.path.exists(self.scene_attributes_file):
86 |             with open(self.scene_attributes_file, 'rb') as handle:
87 |                 self.scene_attributes = pickle.load(handle)
88 |             print(f'{self.scene_attributes_file} is loaded')
89 |             return True
90 |         else:
91 |             return False
92 |
93 |     def load_pairwise_scene_attributes(self):
94 |         if os.path.exists(self.pairwise_scene_attributes_file):
95 |             with open(self.pairwise_scene_attributes_file, 'rb') as handle:
96 |                 self.pairwise_scene_attributes = pickle.load(handle)
97 |             print(f'{self.pairwise_scene_attributes_file} is loaded')
98 |             return True
99 |         else:
100 |             return False
101 |
102 |     def get_all_data_rgb_names(self):
103 |         train_scenes = [x.lower() for x in self.train_scenes]
104 |
105 |         all_files = []
106 |         for folder in train_scenes:
107 |             curr_dir = self.main_2d_path + folder + self.scene_relative_path_to_rgb
108 |             onlyfiles = [f for f in listdir(curr_dir) if isfile(join(curr_dir, f))]
109 |             onlyfiles = [(folder + self.scene_relative_path_to_rgb + '/' + word) for word in onlyfiles]
110 |             all_files = all_files + onlyfiles
111 |
112 |         return all_files
113 |
114 |     def cluster_scene(self, cluster_num, im_feats):
115 |
116 |         kmeans = KMeans(n_clusters=cluster_num, random_state=123).fit(im_feats)
117 |         return kmeans.cluster_centers_
118 |
119 |     def extract_scene_attributes(self):
120 |
121 |         ret = self.load_scene_attributes()
122 |         if ret:  # if we managed to load then no need to run again!
123 |             return
124 |
125 |         all_scenes, all_scene_keys = self.get_scenes()
126 |
127 |         self.load_cluster_centers()
128 |
129 |         scene_attributes = {}
130 |
131 |         for ind, kk in enumerate(all_scene_keys):
132 |             xbb = self.cluster_centers.get(kk, None)
133 |
134 |             if xbb is not None:
135 |                 # xbb holds this scene's cluster centers, one row per center
136 |                 curr_dim = len(xbb)
137 |                 ordered_D = self.sim_object.calculate_dino_aff_matrix_from_feats(xbb)
138 |                 final_distance_list = list(ordered_D[np.triu_indices(curr_dim, 1)])
139 |
140 |                 d = {}
141 |                 d[f"{kk}"] = {}
142 |                 d[f"{kk}"]['mean'] = stats.describe(final_distance_list).mean
143 |                 d[f"{kk}"]['variance'] = stats.describe(final_distance_list).variance
144 |                 d[f"{kk}"]['minmax'] = stats.describe(final_distance_list).minmax
145 |                 scene_attributes.update(d)
146 |
147 |         self.scene_attributes = scene_attributes
148 |         f = open(self.scene_attributes_file, "wb")
149 |         pickle.dump(self.scene_attributes, f)
150 |         f.close()
151 |         print(f'Len of final rooms {len(self.scene_attributes)}')
152 |
153 |     def extract_pairwise_scene_attributes(self):
154 |
155 |         ret = self.load_pairwise_scene_attributes()
156 |         if ret:  # if we managed to load then no need to run again!
157 |             return
158 |
159 |         all_scenes, all_scene_keys = self.get_scenes()
160 |         total_scene_number = len(all_scene_keys)
161 |
162 |         self.load_cluster_centers()
163 |
164 |         pairwise_scene_attributes = {}
165 |
166 |         for ind, s1 in enumerate(all_scene_keys):
167 |             for st in range(ind + 1, total_scene_number):
168 |                 s2 = all_scene_keys[st]
169 |                 kk = f'{s1}*{s2}'
170 |
171 |                 cluster_centers_s1 = self.cluster_centers[s1]
172 |                 cluster_centers_s2 = self.cluster_centers[s2]
173 |
174 |                 curr_dim = len(cluster_centers_s1) + len(cluster_centers_s2)
175 |                 xbb = np.zeros((curr_dim, cluster_centers_s1.shape[1]), dtype=np.float32)
176 |                 for ii, v in enumerate(cluster_centers_s1):
177 |                     xbb[ii] = v
178 |                 for ii, v in enumerate(cluster_centers_s2):
179 |                     xbb[len(cluster_centers_s1) + ii] = v
180 |
181 |                 ordered_D = self.sim_object.calculate_dino_aff_matrix_from_feats(xbb)
182 |                 # cross-scene block: rows are the s2 centers, columns are the s1 centers
183 |                 final_distance_list = list(
184 |                     ordered_D[len(cluster_centers_s1):, :len(cluster_centers_s1)].flatten())
185 |
186 |                 d = {}
187 |                 d[f"{kk}"] = {}
188 |                 d[f"{kk}"]['mean'] = stats.describe(final_distance_list).mean
189 |                 d[f"{kk}"]['variance'] = stats.describe(final_distance_list).variance
190 |                 d[f"{kk}"]['minmax'] = stats.describe(final_distance_list).minmax
191 |                 pairwise_scene_attributes.update(d)
192 |
193 |         self.pairwise_scene_attributes = pairwise_scene_attributes
194 |         # create a binary pickle file
195 |         f = open(self.pairwise_scene_attributes_file, "wb")
196 |         pickle.dump(self.pairwise_scene_attributes, f)
197 |         f.close()
198 |         print(f'Len: {len(self.pairwise_scene_attributes)}')
199 |
200 |     def prepare_data_for_optimization(self):
201 |
202 |         self.load_scene_attributes()
203 |         self.load_data_stats()
204 |         self.load_pairwise_scene_attributes()
205 |
206 |         all_scenes, all_scene_keys = self.get_scenes()
207 |         total_scene_number = len(all_scene_keys)
208 |
209 |         pair_scores = []
210 |         all_pairs = []
all_pairs = [] 211 | for i, scene_i in enumerate(all_scene_keys): 212 | for j in range(i + 1, total_scene_number): 213 | scene_j = all_scene_keys[j] 214 | dsim_i = 1 - self.scene_attributes[scene_i]['mean'] 215 | dsim_j = 1 - self.scene_attributes[scene_j]['mean'] 216 | 217 | dsim = dsim_i * dsim_j 218 | 219 | kk = f'{scene_i}*{scene_j}' 220 | 221 | pairwise_sim = self.pairwise_scene_attributes[kk]['mean'] 222 | pairwise_dsim = 1 - pairwise_sim 223 | final_score = pairwise_dsim * dsim 224 | 225 | pair_scores.append(final_score) 226 | all_pairs.append((scene_i, scene_j)) 227 | 228 | return all_pairs, self.data_stats, pair_scores, self.reduction_size, self.target_point_num 229 | 230 | 231 | 232 | 233 | -------------------------------------------------------------------------------- /datasets/s3dis.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import pickle 4 | import numpy as np 5 | from PIL import Image 6 | from os import listdir 7 | from os.path import isfile, join 8 | from scipy import stats 9 | from .base_dataset import PCData 10 | 11 | 12 | class S3DIS(PCData): 13 | 14 | def __init__(self, config): 15 | 16 | base_config = getattr(config, "base_parameters") 17 | super().__init__(base_config) 18 | 19 | s3dis_parameters = getattr(config, "s3dis_parameters") 20 | for k, v in s3dis_parameters.items(): 21 | setattr(self, k, v) 22 | 23 | def extract_feature_vecs(self): 24 | 25 | ret = self.load_feature_vec_dict() 26 | if ret: 27 | return 28 | feat_vec_image_dict = {} 29 | print("Extracting Feature Vectors!") 30 | for file_name in self.all_images: 31 | print(self.main_2d_path + file_name) 32 | im = Image.open(self.main_2d_path+file_name) 33 | feature_vec = self.sim_object.get_sim_vec_single(im) 34 | name_split = file_name.split('/') 35 | scene = name_split[0] 36 | further_split = name_split[3].split('_') 37 | camera = further_split[1] 38 | room = further_split[2] + '_' + further_split[3] 39 | frame 
= further_split[5] 40 | feat_vec_image_dict[f"{scene}_{room}_{camera}_{frame}"] = feature_vec 41 | 42 | self.feat_vec_image_dict = feat_vec_image_dict 43 | # lets dump 44 | with open(self.feature_vec_dict_file, 'wb') as handle: 45 | pickle.dump(self.feat_vec_image_dict, handle, protocol=pickle.HIGHEST_PROTOCOL) 46 | 47 | def extract_data_stats(self): 48 | 49 | ret = self.load_data_stats() 50 | if ret: # if we managed to load then no need to run again! 51 | return 52 | 53 | f = open('./datasets/s3dis_regions.json') 54 | ulabel_region = json.load(f) 55 | 56 | s3dis_stats = {} 57 | total_point_number = 0 58 | total_region_number = 0 59 | for folder in self.train_scenes: 60 | curr_dir = self.main_3d_path + folder + '/supervoxel' 61 | annot_dir = self.main_3d_path + folder + '/labels' 62 | onlyfiles = [f for f in listdir(curr_dir) if isfile(join(curr_dir, f))] 63 | 64 | for room in onlyfiles: 65 | supvox = np.load(curr_dir + '/' + room) 66 | annots = np.load(annot_dir + '/' + room) 67 | preserving_labels = ulabel_region[f'{folder}#{room[:-4]}'] 68 | 69 | (unique, counts) = np.unique(supvox, return_counts=True) 70 | (annot_unique, annot_counts) = np.unique(annots, return_counts=True) 71 | dict_annots = {} 72 | for au, ac in zip(annot_unique, annot_counts): 73 | dict_annots[self.label_2_name[au]] = ac 74 | 75 | indices_preserving_labels = [unique.tolist().index(x) for x in preserving_labels] 76 | unique = unique[indices_preserving_labels] 77 | counts = counts[indices_preserving_labels] 78 | 79 | frequencies = np.asarray((unique, counts)).T 80 | d = {} 81 | key_name = folder + '_' + room[:-4] 82 | d[f"{key_name}"] = {} 83 | d[f"{key_name}"]['area'] = frequencies[:, 1].sum() 84 | 85 | mask = np.isin(supvox, preserving_labels) 86 | if frequencies[:, 1].sum() != mask.sum(): 87 | print('Something is wrong about AREA!') 88 | 89 | d[f"{key_name}"]['frequencies'] = frequencies 90 | d[f"{key_name}"]['preserving_labels'] = preserving_labels 91 | 92 | d[f"{key_name}"]['annot_stats'] 
= dict_annots 93 | 94 | total_point_number += frequencies[:, 1].sum() 95 | total_region_number += len(preserving_labels) 96 | 97 | s3dis_stats.update(d) 98 | 99 | self.data_stats = s3dis_stats 100 | f = open(self.data_stats_file, "wb") 101 | pickle.dump(self.data_stats, f) 102 | f.close() 103 | print(f'Point Num in Total {total_point_number}') 104 | print(f'Region Num in Total {total_region_number}') 105 | 106 | def extract_scene_clusters(self): 107 | 108 | ret = self.load_cluster_centers() 109 | if ret: # if we managed to load then no need to run again! 110 | return 111 | 112 | all_scenes, all_scene_keys = self.get_scenes() 113 | self.cluster_centers = {} 114 | 115 | for ind, s1 in enumerate(all_scene_keys): 116 | im_feats = np.squeeze(np.asarray(all_scenes[s1]), axis=1) 117 | 118 | cluster_num = self.cluster_num 119 | if len(im_feats) < cluster_num: 120 | cluster_num = len(im_feats) 121 | 122 | self.cluster_centers[s1] = self.cluster_scene(cluster_num, im_feats) 123 | 124 | f = open(self.cluster_centers_file, "wb") 125 | pickle.dump(self.cluster_centers, f) 126 | f.close() 127 | 128 | def get_scenes(self): 129 | all_scenes = {} 130 | self.load_feature_vec_dict() 131 | for scene in self.train_scenes: 132 | for room in self.rooms: 133 | for i in range(self.max_room_num): 134 | values = None 135 | kk = scene + '_' + room + str(i) + '_' 136 | values = [value for key, value in self.feat_vec_image_dict.items() if kk.lower() in key.lower()] 137 | if values: 138 | all_scenes[kk[:-1]] = values 139 | return all_scenes, list(all_scenes.keys()) 140 | 141 | def create_initial_set(self, scene_list): 142 | 143 | self.load_data_stats() 144 | 145 | scan_num = len(self.data_stats) 146 | all_keys = list(self.data_stats.keys()) 147 | all_values = list(self.data_stats.values()) 148 | 149 | selected_samples = [] 150 | for scn in scene_list: 151 | selected_samples.append(all_keys.index(scn)) 152 | 153 | path = os.path.join('.', self.seed_name) 154 | os.mkdir(path) 155 | 156 | f = 
open(path + "/init_label_scan.json", "w") 157 | fu = open(path + "/init_ulabel_scan.json", "w") 158 | f.write("[\n") 159 | fu.write("[\n") 160 | 161 | for j in range(scan_num): 162 | curr_name = all_keys[j] 163 | splits = curr_name.split('_') 164 | if j in selected_samples: 165 | f.write(f'"{splits[0]}_{splits[1]}/coords/{splits[2]}_{splits[3]}.npy",\n') 166 | else: 167 | fu.write(f'"{splits[0]}_{splits[1]}/coords/{splits[2]}_{splits[3]}.npy",\n') 168 | 169 | f.seek(f.tell() - 2) 170 | fu.seek(fu.tell() - 2) 171 | f.write("\n]\n") 172 | fu.write("\n]\n") 173 | f.close() 174 | fu.close() 175 | 176 | f = open(path+"/init_label_region.json", "w") 177 | fu = open(path+"/init_ulabel_region.json", "w") 178 | f.write("{") 179 | fu.write("{") 180 | 181 | total_point_num = 0 182 | total_region_num = 0 183 | for j in range(scan_num): 184 | curr_name = all_keys[j] 185 | splits = curr_name.split('_') 186 | supervoxel_list = str(all_values[j]['frequencies'][:,0].tolist()) 187 | if j in selected_samples: 188 | f.write(f'"{splits[0]}_{splits[1]}#{splits[2]}_{splits[3]}": {supervoxel_list},') 189 | total_point_num+= all_values[j]['area'] 190 | total_region_num += len(all_values[j]['frequencies']) 191 | else: 192 | fu.write(f'"{splits[0]}_{splits[1]}#{splits[2]}_{splits[3]}": {supervoxel_list},') 193 | 194 | f.seek(f.tell() - 1) 195 | fu.seek(fu.tell() - 1) 196 | f.write("}") 197 | fu.write("}") 198 | f.close() 199 | fu.close() 200 | 201 | fi = open(path + "/info.txt", "w") 202 | fi.write(f'Region Num: {total_region_num} and Point Num: {total_point_num} in the current set') 203 | fi.close() 204 | 205 | 206 | 207 | 208 | 209 | 210 | -------------------------------------------------------------------------------- /datasets/s3dis_config.py: -------------------------------------------------------------------------------- 1 | 2 | base_parameters = dict( 3 | dataset_name = 's3dis', 4 | main_2d_path = 'path/to/2D_modalities/', 5 | main_3d_path = 'path/to/S3DIS_processed/', 6 | 
scene_relative_path_to_rgb = '/data/rgb', 7 | save_path = './s3dis_attribute_outputs/', 8 | train_scenes = ['Area_1', 'Area_2', 'Area_3', 'Area_4', 'Area_6'], 9 | class_num = 13, 10 | cluster_num = 13, 11 | target_point_num = 6500000, 12 | reduction_size = 80, 13 | target_region_num = 790, 14 | seed_name = 's3dis_seed' 15 | ) 16 | 17 | s3dis_parameters = dict( 18 | rooms = ['auditorium_', 'conferenceRoom_', 'copyRoom_', 'hallway_', 'lobby_', 'lounge_', 'office_', 'storage_', 'pantry_', 'WC_', 'openspace_'], 19 | max_room_num = 40, 20 | label_2_name = {0: 'ceiling', 21 | 1: 'floor', 22 | 2: 'wall', 23 | 3: 'beam', 24 | 4: 'column', 25 | 5: 'window', 26 | 6: 'door', 27 | 7: 'chair', 28 | 8: 'table', 29 | 9: 'bookcase', 30 | 10: 'sofa', 31 | 11: 'board', 32 | 12: 'clutter'}, 33 | ) 34 | -------------------------------------------------------------------------------- /datasets/semantic_kitti.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import pickle 4 | import numpy as np 5 | from PIL import Image 6 | from .base_dataset import PCData 7 | 8 | 9 | class SK(PCData): 10 | 11 | def __init__(self, config): 12 | 13 | base_config = getattr(config, "base_parameters") 14 | super().__init__(base_config) 15 | 16 | s3dis_parameters = getattr(config, "sk_parameters") 17 | for k, v in s3dis_parameters.items(): 18 | setattr(self, k, v) 19 | 20 | self.cls_feature_vec_dict_file = self.save_path + f'{self.dataset_name}_{self.sim_object.sim_model_name}_cls_features.pkl' 21 | self.cls_feat_vec_image_dict = None 22 | 23 | self.sparsified_dataset_file = self.save_path + f'{self.dataset_name}_sparsified.pkl' 24 | self.sparsified_dataset = None 25 | 26 | def extract_feature_vecs(self): 27 | self.extract_cls_features() 28 | self.sparsify_dataset() 29 | self.extract_patch_features() 30 | 31 | def extract_data_stats(self): 32 | 33 | ret = self.load_data_stats() 34 | if ret: # if we managed to load then no need to run again! 
35 |             return
36 |
37 |         with open('./datasets/semantic_kitti_areas.json', 'r') as f:
38 |             supvox_pts = json.load(f)
39 |         supvox_keys = list(supvox_pts.keys())
40 |
41 |         sk_stats = {}
42 |         total_point_number = 0
43 |         for im in self.sparsified_dataset:
44 |             name_parse = im.split('/')
45 |             seq = name_parse[0]
46 |             im_id = name_parse[2][:-4]
47 |             key_name = f'{seq}/velodyne/{im_id}.bin#'
48 |             matching = [s for s in supvox_keys if key_name in s]
49 |             area = 0
50 |             for m in matching:
51 |                 area += supvox_pts[m]
52 |             d = {}
53 |             d[f"{seq}_{self.scene_relative_path_to_rgb[1:]}_{im_id}"] = {}
54 |             d[f"{seq}_{self.scene_relative_path_to_rgb[1:]}_{im_id}"]['area'] = area
55 |             total_point_number += area
56 |             sk_stats.update(d)
57 |
58 |         self.data_stats = sk_stats
59 |         f = open(self.data_stats_file, "wb")
60 |         pickle.dump(self.data_stats, f)
61 |         f.close()
62 |         print(f'Point Num in Total {total_point_number}')
63 |
64 |     def extract_scene_clusters(self):
65 |         ret = self.load_cluster_centers()
66 |         if ret:  # if we managed to load then no need to run again!
67 |             return
68 |         self.load_feature_vec_dict()
69 |         self.cluster_centers = {}
70 |         all_scene_keys = list(self.feat_vec_image_dict.keys())
71 |         for ind, s1 in enumerate(all_scene_keys):
72 |             im_feats = self.feat_vec_image_dict[s1][0]
73 |
74 |             cluster_num = self.cluster_num
75 |             if len(im_feats) < cluster_num:
76 |                 cluster_num = len(im_feats)
77 |             self.cluster_centers[s1] = self.cluster_scene(cluster_num, im_feats)
78 |
79 |         f = open(self.cluster_centers_file, "wb")
80 |         pickle.dump(self.cluster_centers, f)
81 |         f.close()
82 |
83 |     def extract_cls_features(self):
84 |
85 |         ret = self.load_cls_feature_vec_dict()
86 |         if ret:
87 |             return
88 |         cls_feat_vec_image_dict = {}
89 |         print("Extracting Feature Vectors!")
90 |         for file_name in self.all_images:
91 |             print(self.main_2d_path + file_name)
92 |             im = Image.open(self.main_2d_path + file_name)
93 |             feature_vec = self.sim_object.get_sim_vec_single(im)
94 |             name_split = file_name.split('/')
95 |             scene = name_split[0]
96 |             camera = name_split[1]
97 |             frame = name_split[2].split('.')[0]
98 |             cls_feat_vec_image_dict[f"{scene}_{camera}_{frame}"] = feature_vec
99 |
100 |         self.cls_feat_vec_image_dict = cls_feat_vec_image_dict
101 |         # cache the features to disk
102 |         with open(self.cls_feature_vec_dict_file, 'wb') as handle:
103 |             pickle.dump(self.cls_feat_vec_image_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
104 |
105 |     def sparsify_dataset(self):
106 |
107 |         ret = self.load_sparsified_dataset()
108 |         if ret:
109 |             return
110 |
111 |         selected_frames = []
112 |         for folder in self.train_scenes:
113 |             matching = sorted([s for s in self.all_images if f'{folder}/' in s])
114 |             p = 0
115 |             for i, img in enumerate(matching):
116 |                 i = p  # resume the scan from the last split point
117 |                 key_i = matching[i].replace('/', '_')[:-4]
118 |                 for j in range(i + 1, len(matching), 1):
119 |                     key_j = matching[j].replace('/', '_')[:-4]
120 |
121 |                     feat_i = self.cls_feat_vec_image_dict[key_i]
122 |                     feat_j = self.cls_feat_vec_image_dict[key_j]
123 |
124 |                     feats = np.concatenate((feat_i, feat_j), axis=0)
125 |                     ordered_D = self.sim_object.calculate_dino_aff_matrix_from_feats(feats)
126 |                     sim = ordered_D[0][1]
127 |                     if sim < self.sparsification_similarity_thr:
128 |                         s_i = int((i + j) / 2)
129 |                         selected_frames.append(matching[s_i])
130 |                         p = j
131 |                         break
132 |                 else:  # no later frame is dissimilar enough from frame i: this sequence is done
133 |                     break
134 |
135 |
136 |         # cache the selected frames to disk
137 |         self.sparsified_dataset = selected_frames
138 |         with open(self.sparsified_dataset_file, 'wb') as handle:
139 |             pickle.dump(self.sparsified_dataset, handle, protocol=pickle.HIGHEST_PROTOCOL)
140 |
141 |     def load_cls_feature_vec_dict(self):
142 |         if os.path.exists(self.cls_feature_vec_dict_file):
143 |             with open(self.cls_feature_vec_dict_file, 'rb') as handle:
144 |                 self.cls_feat_vec_image_dict = pickle.load(handle)
145 |             print(f'{self.cls_feature_vec_dict_file} features are loaded')
146 |             return True
147 |         else:
148 |             return False
149 |
150 |     def load_sparsified_dataset(self):
151 |         if os.path.exists(self.sparsified_dataset_file):
152 |             with open(self.sparsified_dataset_file, 'rb') as handle:
153 |                 self.sparsified_dataset = pickle.load(handle)
154 |             print(f'{self.sparsified_dataset_file} is loaded')
155 |             return True
156 |         else:
157 |             return False
158 |
159 |     def extract_patch_features(self):
160 |
161 |         ret = self.load_feature_vec_dict()
162 |         if ret:
163 |             return
164 |
165 |         ret = self.load_sparsified_dataset()
166 |         if not ret:
167 |             print('No sparsified dataset!')
168 |             return
169 |
170 |         feat_vec_image_dict = {}
171 |
172 |         self.sim_object.set_sim_model_feat_type_dino('patch')
173 |
174 |         print("Extracting Feature Vectors!")
175 |         for file_name in self.sparsified_dataset:
176 |             print(self.main_2d_path + file_name)
177 |             im = Image.open(self.main_2d_path + file_name)
178 |             feature_vec = self.sim_object.get_sim_vec_single(im)
179 |             name_split = file_name.split('/')
180 |             scene = name_split[0]
181 |             camera = name_split[1]
182 |             frame = name_split[2].split('.')[0]
183 |             feat_vec_image_dict[f"{scene}_{camera}_{frame}"] = feature_vec
184 |
185 |         self.feat_vec_image_dict = feat_vec_image_dict
186 |         # cache the features to disk
187 |         with open(self.feature_vec_dict_file, 'wb') as handle:
188 |             pickle.dump(self.feat_vec_image_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
189 |
190 |         self.sim_object.set_sim_model_feat_type_dino('cls')
191 |
192 |     def create_initial_set(self, selected_samples):
193 |
194 |         with open('./datasets/semantic_kitti_regions.json') as f:
195 |             all_regions = json.load(f)
196 |
197 |         path = os.path.join('.', self.seed_name)
198 |         os.mkdir(path)
199 |
200 |         f = open(path + "/init_label_scan.json", "w")
201 |         fu = open(path + "/init_ulabel_scan.json", "w")
202 |         f.write("[\n")
203 |         fu.write("[\n")
204 |
205 |         all_keys = list(all_regions.keys())
206 |         for scn in all_keys:
207 |             search_str = scn.replace('_', '_image_2_')
208 |             final_str = scn.replace('_', '/velodyne/') + '.bin'
209 |             if search_str in selected_samples:
210 |                 f.write(f' "{final_str}",\n')
211 |             else:
212 |                 fu.write(f' "{final_str}",\n')
213 |
214 |         f.seek(f.tell() - 2)
215 |         fu.seek(fu.tell() - 2)
216 |         f.write("\n]\n")
217 |         fu.write("\n]\n")
218 |         f.close()
219 |         fu.close()
220 |
221 |         f = open(path + "/init_label_large_region.json", "w")
222 |         fu = open(path + "/init_ulabel_large_region.json", "w")
223 |         f.write("{")
224 |         fu.write("{")
225 |
226 |         for i, scn in enumerate(all_regions):
227 |             search_str = scn.replace('_', '_image_2_')
228 |             supervoxel_list = all_regions[scn]
229 |             if search_str in selected_samples:
230 |                 f.write(f'"{scn}": {supervoxel_list}, ')
231 |             else:
232 |                 fu.write(f'"{scn}": {supervoxel_list}, ')
233 |
234 |         f.seek(f.tell() - 2)
235 |         fu.seek(fu.tell() - 2)
236 |         f.write("}")
237 |         fu.write("}")
238 |         f.close()
239 |         fu.close()
240 |
241 |     def get_scenes(self):
242 |
243 |         sparsified_dataset = self.sparsified_dataset
244 |         sparsified_dataset_keys = []
245 |
246 |         for scn in sparsified_dataset:
247 |             sparsified_dataset_keys.append(scn.replace('/', '_')[:-4])
248 |
249 |         return sparsified_dataset, sparsified_dataset_keys
250 |
251 |
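The frame selection in `sparsify_dataset` is a greedy scan over each time-ordered sequence: starting from frame `i`, it walks forward to the first frame `j` whose similarity to frame `i` drops below `sparsification_similarity_thr`, keeps the middle frame of `[i, j]`, and restarts from `j`. The following is a minimal standalone sketch of that scan, not the repository's code. It assumes cosine similarity between global (CLS) feature vectors, because `SimilarityModel.calculate_dino_aff_matrix_from_feats` is not shown here; the helper names `cosine_similarity` and `sparsify_sequence` are hypothetical.

```python
import numpy as np


def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    """Cosine similarity of two 1-D vectors (assumed stand-in for the DINO affinity)."""
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))


def sparsify_sequence(feats: np.ndarray, thr: float) -> list:
    """Greedy sparsification of one time-ordered sequence (hypothetical helper).

    feats: (num_frames, dim) array, one global feature per frame.
    thr:   similarity below which two frames count as different.
    Returns the indices of the kept frames.
    """
    selected = []
    p = 0  # index of the frame that starts the current segment
    while True:
        i = p
        for j in range(i + 1, len(feats)):
            if cosine_similarity(feats[i], feats[j]) < thr:
                # Frame j has drifted far enough from frame i:
                # keep the middle frame of [i, j] and restart the scan at j.
                selected.append((i + j) // 2)
                p = j
                break
        else:
            # No later frame is dissimilar enough from frame i: stop.
            break
    return selected


if __name__ == "__main__":
    # Toy 2-D "features": unit vectors at increasing angles (degrees).
    angles = np.deg2rad([0, 20, 40, 60, 80, 100, 120])
    feats = np.stack([np.cos(angles), np.sin(angles)], axis=1)
    print(sparsify_sequence(feats, thr=0.75))  # [1, 4]
```

With a threshold of 0.75, frames must be more than about 41 degrees apart to count as different, so the toy sequence keeps one representative per 60-degree segment. The `for ... else` exit stops the scan as soon as no dissimilar frame remains.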
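`SK` inherits `prepare_data_for_optimization` from `PCData`, which scores each pair of scenes as `(1 - cross-scene similarity) * (1 - intra-scene similarity of scene i) * (1 - intra-scene similarity of scene j)`. Each similarity is a mean over an affinity matrix built from the scenes' KMeans cluster centers: the upper triangle for one scene, and the block pairing every center of one scene with every center of the other. The sketch below reproduces that scoring on cluster centers passed in directly. It assumes the affinity is cosine similarity, and all function names are hypothetical; it is not the repository's `SimilarityModel`.

```python
import numpy as np


def affinity(feats: np.ndarray) -> np.ndarray:
    """Cosine-similarity matrix between the rows of `feats` (assumed affinity)."""
    normed = feats / np.linalg.norm(feats, axis=1, keepdims=True)
    return normed @ normed.T


def intra_scene_similarity(centers: np.ndarray) -> float:
    """Mean similarity over all distinct pairs of one scene's cluster centers."""
    aff = affinity(centers)
    rows, cols = np.triu_indices(len(centers), k=1)
    return float(aff[rows, cols].mean())


def cross_scene_similarity(c1: np.ndarray, c2: np.ndarray) -> float:
    """Mean similarity between every center of scene 1 and every center of scene 2."""
    aff = affinity(np.concatenate([c1, c2], axis=0))
    # Rows are scene-2 centers, columns are scene-1 centers.
    return float(aff[len(c1):, :len(c1)].mean())


def pair_score(c1: np.ndarray, c2: np.ndarray) -> float:
    """Pair diversity score in the form used by prepare_data_for_optimization."""
    dsim_1 = 1.0 - intra_scene_similarity(c1)
    dsim_2 = 1.0 - intra_scene_similarity(c2)
    pairwise_dsim = 1.0 - cross_scene_similarity(c1, c2)
    return pairwise_dsim * dsim_1 * dsim_2


if __name__ == "__main__":
    print(pair_score(np.eye(2), np.eye(2)))  # 0.5
```

Because the three factors multiply, a pair scores high only when both scenes are internally diverse and dissimilar from each other; a scene whose centers all coincide has zero internal dissimilarity and drives every pair containing it to a score of zero.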
252 |
253 |
254 |
255 |
--------------------------------------------------------------------------------
/datasets/semantic_kitti_config.py:
--------------------------------------------------------------------------------
1 | base_parameters = dict(
2 |     dataset_name = 'sk',
3 |     main_2d_path = 'path/to/SemanticKitti/data_odometry_color/sequences/',
4 |     main_3d_path = 'path/to/SemanticKitti/data_odometry_velodyne/dataset/sequences',
5 |     scene_relative_path_to_rgb = '/image_2',
6 |     save_path = './sk_attribute_outputs/',
7 |     train_scenes = ['00', '01', '02', '03', '04', '05', '06', '07', '09', '10'],
8 |     class_num = 19,
9 |     cluster_num = 19,
10 |     target_point_num = 22591773,
11 |     reduction_size = 1200,
12 |     seed_name = 'sk_seed'
13 | )
14 |
15 | sk_parameters = dict(
16 |     sparsification_similarity_thr=0.75,
17 | )
--------------------------------------------------------------------------------
/dino_model/LICENSE:
--------------------------------------------------------------------------------
1 |                                  Apache License
2 |                            Version 2.0, January 2004
3 |                         http://www.apache.org/licenses/
4 |
5 |    TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 |    1. Definitions.
8 |
9 |       "License" shall mean the terms and conditions for use, reproduction,
10 |       and distribution as defined by Sections 1 through 9 of this document.
11 |
12 |       "Licensor" shall mean the copyright owner or entity authorized by
13 |       the copyright owner that is granting the License.
14 |
15 |       "Legal Entity" shall mean the union of the acting entity and all
16 |       other entities that control, are controlled by, or are under common
17 |       control with that entity. For the purposes of this definition,
18 |       "control" means (i) the power, direct or indirect, to cause the
19 |       direction or management of such entity, whether by contract or
20 |       otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 |       outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 |       "You" (or "Your") shall mean an individual or Legal Entity
24 |       exercising permissions granted by this License.
25 |
26 |       "Source" form shall mean the preferred form for making modifications,
27 |       including but not limited to software source code, documentation
28 |       source, and configuration files.
29 |
30 |       "Object" form shall mean any form resulting from mechanical
31 |       transformation or translation of a Source form, including but
32 |       not limited to compiled object code, generated documentation,
33 |       and conversions to other media types.
34 |
35 |       "Work" shall mean the work of authorship, whether in Source or
36 |       Object form, made available under the License, as indicated by a
37 |       copyright notice that is included in or attached to the work
38 |       (an example is provided in the Appendix below).
39 |
40 |       "Derivative Works" shall mean any work, whether in Source or Object
41 |       form, that is based on (or derived from) the Work and for which the
42 |       editorial revisions, annotations, elaborations, or other modifications
43 |       represent, as a whole, an original work of authorship. For the purposes
44 |       of this License, Derivative Works shall not include works that remain
45 |       separable from, or merely link (or bind by name) to the interfaces of,
46 |       the Work and Derivative Works thereof.
47 |
48 |       "Contribution" shall mean any work of authorship, including
49 |       the original version of the Work and any modifications or additions
50 |       to that Work or Derivative Works thereof, that is intentionally
51 |       submitted to Licensor for inclusion in the Work by the copyright owner
52 |       or by an individual or Legal Entity authorized to submit on behalf of
53 |       the copyright owner. For the purposes of this definition, "submitted"
54 |       means any form of electronic, verbal, or written communication sent
55 |       to the Licensor or its representatives, including but not limited to
56 |       communication on electronic mailing lists, source code control systems,
57 |       and issue tracking systems that are managed by, or on behalf of, the
58 |       Licensor for the purpose of discussing and improving the Work, but
59 |       excluding communication that is conspicuously marked or otherwise
60 |       designated in writing by the copyright owner as "Not a Contribution."
61 |
62 |       "Contributor" shall mean Licensor and any individual or Legal Entity
63 |       on behalf of whom a Contribution has been received by Licensor and
64 |       subsequently incorporated within the Work.
65 |
66 |    2. Grant of Copyright License. Subject to the terms and conditions of
67 |       this License, each Contributor hereby grants to You a perpetual,
68 |       worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 |       copyright license to reproduce, prepare Derivative Works of,
70 |       publicly display, publicly perform, sublicense, and distribute the
71 |       Work and such Derivative Works in Source or Object form.
72 |
73 |    3. Grant of Patent License. Subject to the terms and conditions of
74 |       this License, each Contributor hereby grants to You a perpetual,
75 |       worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 |       (except as stated in this section) patent license to make, have made,
77 |       use, offer to sell, sell, import, and otherwise transfer the Work,
78 |       where such license applies only to those patent claims licensable
79 |       by such Contributor that are necessarily infringed by their
80 |       Contribution(s) alone or by combination of their Contribution(s)
81 |       with the Work to which such Contribution(s) was submitted. If You
82 |       institute patent litigation against any entity (including a
83 |       cross-claim or counterclaim in a lawsuit) alleging that the Work
84 |       or a Contribution incorporated within the Work constitutes direct
85 |       or contributory patent infringement, then any patent licenses
86 |       granted to You under this License for that Work shall terminate
87 |       as of the date such litigation is filed.
88 |
89 |    4. Redistribution. You may reproduce and distribute copies of the
90 |       Work or Derivative Works thereof in any medium, with or without
91 |       modifications, and in Source or Object form, provided that You
92 |       meet the following conditions:
93 |
94 |       (a) You must give any other recipients of the Work or
95 |           Derivative Works a copy of this License; and
96 |
97 |       (b) You must cause any modified files to carry prominent notices
98 |           stating that You changed the files; and
99 |
100 |       (c) You must retain, in the Source form of any Derivative Works
101 |           that You distribute, all copyright, patent, trademark, and
102 |           attribution notices from the Source form of the Work,
103 |           excluding those notices that do not pertain to any part of
104 |           the Derivative Works; and
105 |
106 |       (d) If the Work includes a "NOTICE" text file as part of its
107 |           distribution, then any Derivative Works that You distribute must
108 |           include a readable copy of the attribution notices contained
109 |           within such NOTICE file, excluding those notices that do not
110 |           pertain to any part of the Derivative Works, in at least one
111 |           of the following places: within a NOTICE text file distributed
112 |           as part of the Derivative Works; within the Source form or
113 |           documentation, if provided along with the Derivative Works; or,
114 |           within a display generated by the Derivative Works, if and
115 |           wherever such third-party notices normally appear. The contents
116 |           of the NOTICE file are for informational purposes only and
117 |           do not modify the License.
You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. 
(Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /dino_model/README.md: -------------------------------------------------------------------------------- 1 | # Self-Supervised Vision Transformers with DINO 2 | 3 | PyTorch implementation and pretrained models for DINO. For details, see **Emerging Properties in Self-Supervised Vision Transformers**. 4 | [[`blogpost`](https://ai.facebook.com/blog/dino-paws-computer-vision-with-self-supervised-transformers-and-10x-more-efficient-training)] [[`arXiv`](https://arxiv.org/abs/2104.14294)] [[`Yannic Kilcher's video`](https://www.youtube.com/watch?v=h3ij3F3cPIk)] 5 | 6 |
7 | DINO illustration 8 |
9 | 10 | ## Pretrained models 11 | You can choose to download only the weights of the pretrained backbone used for downstream tasks, or the full checkpoint, which contains the backbone and projection head weights for both the student and teacher networks. We also provide the training and evaluation logs. 12 |
| arch | params | k-NN | linear | download |
|---|---|---|---|---|
| DeiT-S/16 | 21M | 74.5% | 77.0% | backbone only, full checkpoint, args, logs, eval logs |
| DeiT-S/8 | 21M | 78.3% | 79.7% | backbone only, full checkpoint, args, logs, eval logs |
| ViT-B/16 | 85M | 76.1% | 78.2% | backbone only, full checkpoint, args, logs, eval logs |
| ViT-B/8 | 85M | 77.4% | 80.1% | backbone only, full checkpoint, args, logs, eval logs |
| ResNet-50 | 23M | 67.5% | 75.3% | backbone only, full checkpoint, args, logs, eval logs |
77 | 78 | The pretrained models are available on PyTorch Hub. 79 | ```python 80 | import torch 81 | deits16 = torch.hub.load('facebookresearch/dino:main', 'dino_deits16') 82 | deits8 = torch.hub.load('facebookresearch/dino:main', 'dino_deits8') 83 | vitb16 = torch.hub.load('facebookresearch/dino:main', 'dino_vitb16') 84 | vitb8 = torch.hub.load('facebookresearch/dino:main', 'dino_vitb8') 85 | resnet50 = torch.hub.load('facebookresearch/dino:main', 'dino_resnet50') 86 | ``` 87 | 88 | ## Training 89 | 90 | ### Documentation 91 | Please install [PyTorch](https://pytorch.org/) and download the [ImageNet](https://imagenet.stanford.edu/) dataset. This codebase was developed with Python 3.6, PyTorch 1.7.1, CUDA 11.0 and torchvision 0.8.2. The exact arguments to reproduce the models presented in our paper can be found in the `args` column of the [pretrained models section](https://github.com/facebookresearch/dino#pretrained-models). For the full documentation of the DINO training options, run: 92 | ``` 93 | python main_dino.py --help 94 | ``` 95 | 96 | ### Vanilla DINO training :sauropod: 97 | Run DINO with a DeiT-small network on a single node with 8 GPUs for 100 epochs with the following command. Training takes 1.75 days and the resulting checkpoint should reach 69.3% on k-NN eval and 74.0% on linear eval. We provide [training](https://dl.fbaipublicfiles.com/dino/example_runs_logs/dino_vanilla_deitsmall16_log.txt) and [linear evaluation](https://dl.fbaipublicfiles.com/dino/example_runs_logs/dino_vanilla_deitsmall16_eval.txt) logs (with batch size 256 at evaluation time) for this run to help reproducibility. 98 | ``` 99 | python -m torch.distributed.launch --nproc_per_node=8 main_dino.py --arch deit_small --data_path /path/to/imagenet/train --output_dir /path/to/saving_dir 100 | ``` 101 | 102 | ### Multi-node training 103 | We use Slurm and [submitit](https://github.com/facebookincubator/submitit) (`pip install submitit`).
To train on 2 nodes with 8 GPUs each (total 16 GPUs): 104 | ``` 105 | python run_with_submitit.py --nodes 2 --ngpus 8 --arch deit_small --data_path /path/to/imagenet/train --output_dir /path/to/saving_dir 106 | ``` 107 | 108 |
109 | 110 | DINO with ViT-base network. 111 | 112 | 113 | ``` 114 | python run_with_submitit.py --nodes 2 --ngpus 8 --use_volta32 --arch vit_base --data_path /path/to/imagenet/train --output_dir /path/to/saving_dir 115 | ``` 116 | 117 |
118 | 119 | ### Boosting DINO performance :t-rex: 120 | You can improve the performance of the vanilla run by: 121 | - training for more epochs: `--epochs 300`; 122 | - increasing the teacher temperature: `--teacher_temp 0.07 --warmup_teacher_temp_epochs 30`; 123 | - removing the last-layer normalization (only safe with `--arch deit_small`): `--norm_last_layer false`. 124 | 125 |
126 | 127 | Full command. 128 | 129 | 130 | ``` 131 | python run_with_submitit.py --arch deit_small --epochs 300 --teacher_temp 0.07 --warmup_teacher_temp_epochs 30 --norm_last_layer false --data_path /path/to/imagenet/train --output_dir /path/to/saving_dir 132 | ``` 133 | 134 |
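The `--teacher_temp 0.07 --warmup_teacher_temp_epochs 30` flags describe a teacher temperature that is ramped up over the first 30 epochs and then held fixed. The sketch below builds such a per-epoch schedule under stated assumptions. The ramp is assumed to be linear. The helper name `teacher_temp_schedule` is hypothetical. The starting value of 0.04 is also an assumption, so check `python main_dino.py --help` for the actual defaults.

```python
import numpy as np


def teacher_temp_schedule(warmup_temp, final_temp, warmup_epochs, total_epochs):
    """Hypothetical helper returning one teacher temperature per epoch.

    Linearly interpolates from warmup_temp to final_temp over the first
    warmup_epochs epochs (assumed linear ramp), then holds final_temp.
    """
    if not 0 < warmup_epochs <= total_epochs:
        raise ValueError("need 0 < warmup_epochs <= total_epochs")
    warmup = np.linspace(warmup_temp, final_temp, warmup_epochs)
    constant = np.full(total_epochs - warmup_epochs, final_temp)
    return np.concatenate((warmup, constant))


# Flag values from the boosted run; 0.04 as the starting value is an assumption.
schedule = teacher_temp_schedule(0.04, 0.07, warmup_epochs=30, total_epochs=300)
print(schedule[0], schedule[29], schedule[-1])
```

Precomputing the whole schedule as an array turns the per-epoch lookup into a simple index, `schedule[epoch]`.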
135 | 136 | The resulting pretrained model should reach 73.3% on k-NN eval and 76.0% on linear eval. Training takes 2.6 days on 16 GPUs. We provide [training](https://dl.fbaipublicfiles.com/dino/example_runs_logs/dino_boost_deitsmall16_log.txt) and [linear evaluation](https://dl.fbaipublicfiles.com/dino/example_runs_logs/dino_boost_deitsmall16_eval.txt) logs (with batch size 256 at evaluation time) for this run to help reproducibility. 137 | 138 | ### Training ResNet-50 and other convnets 139 | This code also works for training DINO on convolutional networks such as ResNet-50. In this case, we highly recommend adapting some optimization arguments. For example, the following command trains DINO on ResNet-50 on a single node with 8 GPUs for 100 epochs. We provide [training](https://dl.fbaipublicfiles.com/dino/example_runs_logs/dino_rn50_log.txt) logs for this run. 140 | ``` 141 | python -m torch.distributed.launch --nproc_per_node=8 main_dino.py --arch resnet50 --optimizer sgd --weight_decay 1e-4 --weight_decay_end 1e-4 --global_crops_scale 0.14 1 --local_crops_scale 0.05 0.14 --data_path /path/to/imagenet/train --output_dir /path/to/saving_dir 142 | ``` 143 | 144 | ## Self-attention visualization 145 | You can visualize the self-attention of the [CLS] token on the different heads of the last layer by running: 146 | ``` 147 | python visualize_attention.py 148 | ``` 149 | 150 | ## Self-attention video generation 151 | You can generate videos like the one in the blog post with `video_generation.py`.
152 | 153 | https://user-images.githubusercontent.com/46140458/116817761-47885e80-ab68-11eb-9975-d61d5a919e13.mp4 154 | 155 | Extract frames from an input video and generate the attention video: 156 | ``` 157 | python video_generation.py --pretrained_weights dino_deitsmall8_pretrain.pth \ 158 | --input_path input/video.mp4 \ 159 | --output_path output/ \ 160 | --fps 25 161 | ``` 162 | 163 | Use a folder of already extracted frames and generate the attention video: 164 | ``` 165 | python video_generation.py --pretrained_weights dino_deitsmall8_pretrain.pth \ 166 | --input_path output/frames/ \ 167 | --output_path output/ \ 168 | --resize 256 169 | ``` 170 | 171 | Only generate a video from a folder of attention-map images: 172 | ``` 173 | python video_generation.py --input_path output/attention \ 174 | --output_path output/ \ 175 | --video_only \ 176 | --video_format avi 177 | ``` 178 | 179 | Also, check out [this colab](https://gist.github.com/aquadzn/32ac53aa6e485e7c3e09b1a0914f7422) for a video inference notebook. 180 | 181 |
182 | Self-attention from a Vision Transformer with 8x8 patches trained with DINO 183 |
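`knn_classifier` in `eval_knn.py` does not take a plain majority vote among the `k` nearest training images. Each neighbour votes for its own label with weight `exp(similarity / T)`. Here the similarity is the dot product of L2-normalized features, and `T` is the `--temperature` argument.

The NumPy sketch below reproduces that weighting on toy data. It omits chunking, GPU storage and distributed feature gathering. The function name `weighted_knn_predict` and the toy features are illustrative only and are not part of the repository.

```python
import numpy as np


def weighted_knn_predict(train_feats, train_labels, test_feats, k, T, num_classes):
    """Illustrative temperature-weighted k-NN voting (not the repository's code)."""
    # L2-normalize so that dot products are cosine similarities.
    train = train_feats / np.linalg.norm(train_feats, axis=1, keepdims=True)
    test = test_feats / np.linalg.norm(test_feats, axis=1, keepdims=True)
    similarity = test @ train.T  # shape: (n_test, n_train)

    # Indices, similarities and labels of the k most similar training samples.
    nn_idx = np.argsort(-similarity, axis=1)[:, :k]
    nn_sim = np.take_along_axis(similarity, nn_idx, axis=1)
    nn_labels = train_labels[nn_idx]  # shape: (n_test, k)

    # Each neighbour adds exp(similarity / T) to the score of its label.
    weights = np.exp(nn_sim / T)
    votes = np.zeros((test.shape[0], num_classes))
    rows = np.arange(test.shape[0])[:, None]
    np.add.at(votes, (rows, nn_labels), weights)
    return votes.argmax(axis=1)


train_feats = np.array([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]])
train_labels = np.array([0, 0, 1, 1])
test_feats = np.array([[1.0, 0.05], [0.05, 1.0]])
preds = weighted_knn_predict(train_feats, train_labels, test_feats, k=3, T=0.07, num_classes=2)
print(preds)
```

With a small `T`, such as the default `--temperature 0.07`, the closest neighbours dominate the vote. A large `T` makes the weights nearly uniform, so the vote approaches a plain majority.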
184 | 185 | 186 | ## Evaluation: k-NN classification on ImageNet 187 | To evaluate a simple k-NN classifier with a single GPU on a pre-trained model, run: 188 | ``` 189 | python -m torch.distributed.launch --nproc_per_node=1 eval_knn.py --data_path /path/to/imagenet 190 | ``` 191 | If you do not specify `--pretrained_weights`, the DINO reference weights are used by default. To evaluate a checkpoint from one of your own runs instead, you can run, for example: 192 | ``` 193 | python -m torch.distributed.launch --nproc_per_node=1 eval_knn.py --pretrained_weights /path/to/checkpoint.pth --checkpoint_key teacher --data_path /path/to/imagenet 194 | ``` 195 | 196 | ## Evaluation: Linear classification on ImageNet 197 | To train a supervised linear classifier on top of frozen features on a single node with 8 GPUs, run: 198 | ``` 199 | python -m torch.distributed.launch --nproc_per_node=8 eval_linear.py --data_path /path/to/imagenet 200 | ``` 201 | 202 | ## License 203 | This repository is released under the Apache 2.0 license as found in the [LICENSE](LICENSE) file. 204 | 205 | ## Citation 206 | If you find this repository useful, please consider giving a star :star: and a citation :t-rex:: 207 | ``` 208 | @article{caron2021emerging, 209 | title={Emerging Properties in Self-Supervised Vision Transformers}, 210 | author={Caron, Mathilde and Touvron, Hugo and Misra, Ishan and J\'egou, Herv\'e and Mairal, Julien and Bojanowski, Piotr and Joulin, Armand}, 211 | journal={arXiv preprint arXiv:2104.14294}, 212 | year={2021} 213 | } 214 | ``` 215 | -------------------------------------------------------------------------------- /dino_model/eval_knn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import os 15 | import argparse 16 | 17 | import torch 18 | from torch import nn 19 | import torch.distributed as dist 20 | import torch.backends.cudnn as cudnn 21 | from torchvision import datasets 22 | from torchvision import transforms as pth_transforms 23 | 24 | import utils 25 | import vision_transformer as vits 26 | 27 | 28 | def extract_feature_pipeline(args): 29 | # ============ preparing data ... ============ 30 | transform = pth_transforms.Compose([ 31 | pth_transforms.Resize(256, interpolation=3), 32 | pth_transforms.CenterCrop(224), 33 | pth_transforms.ToTensor(), 34 | pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 35 | ]) 36 | dataset_train = ReturnIndexDataset(os.path.join(args.data_path, "train"), transform=transform) 37 | dataset_val = ReturnIndexDataset(os.path.join(args.data_path, "val"), transform=transform) 38 | sampler = torch.utils.data.DistributedSampler(dataset_train, shuffle=False) 39 | data_loader_train = torch.utils.data.DataLoader( 40 | dataset_train, 41 | sampler=sampler, 42 | batch_size=args.batch_size_per_gpu, 43 | num_workers=args.num_workers, 44 | pin_memory=True, 45 | drop_last=False, 46 | ) 47 | data_loader_val = torch.utils.data.DataLoader( 48 | dataset_val, 49 | batch_size=args.batch_size_per_gpu, 50 | num_workers=args.num_workers, 51 | pin_memory=True, 52 | drop_last=False, 53 | ) 54 | print(f"Data loaded with {len(dataset_train)} train and {len(dataset_val)} val imgs.") 55 | 56 | # ============ building network ... 
============ 57 | model = vits.__dict__[args.arch](patch_size=args.patch_size, num_classes=0) 58 | print(f"Model {args.arch} {args.patch_size}x{args.patch_size} built.") 59 | model.cuda() 60 | utils.load_pretrained_weights(model, args.pretrained_weights, args.checkpoint_key, args.arch, args.patch_size) 61 | model.eval() 62 | 63 | # ============ extract features ... ============ 64 | print("Extracting features for train set...") 65 | train_features = extract_features(model, data_loader_train) 66 | print("Extracting features for val set...") 67 | test_features = extract_features(model, data_loader_val) 68 | 69 | if utils.get_rank() == 0: 70 | train_features = nn.functional.normalize(train_features, dim=1, p=2) 71 | test_features = nn.functional.normalize(test_features, dim=1, p=2) 72 | 73 | train_labels = torch.tensor([s[-1] for s in dataset_train.samples]).long() 74 | test_labels = torch.tensor([s[-1] for s in dataset_val.samples]).long() 75 | # save features and labels 76 | if args.dump_features and dist.get_rank() == 0: 77 | torch.save(train_features.cpu(), os.path.join(args.dump_features, "trainfeat.pth")) 78 | torch.save(test_features.cpu(), os.path.join(args.dump_features, "testfeat.pth")) 79 | torch.save(train_labels.cpu(), os.path.join(args.dump_features, "trainlabels.pth")) 80 | torch.save(test_labels.cpu(), os.path.join(args.dump_features, "testlabels.pth")) 81 | return train_features, test_features, train_labels, test_labels 82 | 83 | 84 | @torch.no_grad() 85 | def extract_features(model, data_loader): 86 | metric_logger = utils.MetricLogger(delimiter=" ") 87 | features = None 88 | for samples, index in metric_logger.log_every(data_loader, 10): 89 | samples = samples.cuda(non_blocking=True) 90 | index = index.cuda(non_blocking=True) 91 | feats = model(samples).clone() 92 | 93 | # init storage feature matrix 94 | if dist.get_rank() == 0 and features is None: 95 | features = torch.zeros(len(data_loader.dataset), feats.shape[-1]) 96 | if args.use_cuda: 97 | 
features = features.cuda(non_blocking=True) 98 | print(f"Storing features into tensor of shape {features.shape}") 99 | 100 | # get indexes from all processes 101 | y_all = torch.empty(dist.get_world_size(), index.size(0), dtype=index.dtype, device=index.device) 102 | y_l = list(y_all.unbind(0)) 103 | y_all_reduce = torch.distributed.all_gather(y_l, index, async_op=True) 104 | y_all_reduce.wait() 105 | index_all = torch.cat(y_l) 106 | 107 | # share features between processes 108 | feats_all = torch.empty( 109 | dist.get_world_size(), 110 | feats.size(0), 111 | feats.size(1), 112 | dtype=feats.dtype, 113 | device=feats.device, 114 | ) 115 | output_l = list(feats_all.unbind(0)) 116 | output_all_reduce = torch.distributed.all_gather(output_l, feats, async_op=True) 117 | output_all_reduce.wait() 118 | 119 | # update storage feature matrix 120 | if dist.get_rank() == 0: 121 | if args.use_cuda: 122 | features.index_copy_(0, index_all, torch.cat(output_l)) 123 | else: 124 | features.index_copy_(0, index_all.cpu(), torch.cat(output_l).cpu()) 125 | return features 126 | 127 | 128 | @torch.no_grad() 129 | def knn_classifier(train_features, train_labels, test_features, test_labels, k, T, num_classes=1000): 130 | top1, top5, total = 0.0, 0.0, 0 131 | train_features = train_features.t() 132 | num_test_images, num_chunks = test_labels.shape[0], 100 133 | imgs_per_chunk = max(1, num_test_images // num_chunks) # avoid a zero range step when there are fewer than 100 test images 134 | retrieval_one_hot = torch.zeros(k, num_classes).to(train_features.device) # keep on the features' device (CPU when --use_cuda is false) 135 | for idx in range(0, num_test_images, imgs_per_chunk): 136 | # get the features for test images 137 | features = test_features[ 138 | idx : min((idx + imgs_per_chunk), num_test_images), : 139 | ] 140 | targets = test_labels[idx : min((idx + imgs_per_chunk), num_test_images)] 141 | batch_size = targets.shape[0] 142 | 143 | # calculate the dot product and compute top-k neighbors 144 | similarity = torch.mm(features, train_features) 145 | distances, indices = similarity.topk(k, largest=True, sorted=True) 146 | candidates
= train_labels.view(1, -1).expand(batch_size, -1) 147 | retrieved_neighbors = torch.gather(candidates, 1, indices) 148 | 149 | retrieval_one_hot.resize_(batch_size * k, num_classes).zero_() 150 | retrieval_one_hot.scatter_(1, retrieved_neighbors.view(-1, 1), 1) 151 | distances_transform = distances.clone().div_(T).exp_() 152 | probs = torch.sum( 153 | torch.mul( 154 | retrieval_one_hot.view(batch_size, -1, num_classes), 155 | distances_transform.view(batch_size, -1, 1), 156 | ), 157 | 1, 158 | ) 159 | _, predictions = probs.sort(1, True) 160 | 161 | # find the predictions that match the target 162 | correct = predictions.eq(targets.data.view(-1, 1)) 163 | top1 = top1 + correct.narrow(1, 0, 1).sum().item() 164 | top5 = top5 + correct.narrow(1, 0, 5).sum().item() 165 | total += targets.size(0) 166 | top1 = top1 * 100.0 / total 167 | top5 = top5 * 100.0 / total 168 | return top1, top5 169 | 170 | 171 | class ReturnIndexDataset(datasets.ImageFolder): 172 | def __getitem__(self, idx): 173 | img, lab = super(ReturnIndexDataset, self).__getitem__(idx) 174 | return img, idx 175 | 176 | 177 | if __name__ == '__main__': 178 | parser = argparse.ArgumentParser('Evaluation with weighted k-NN on ImageNet') 179 | parser.add_argument('--batch_size_per_gpu', default=128, type=int, help='Per-GPU batch-size') 180 | parser.add_argument('--nb_knn', default=[10, 20, 100, 200], nargs='+', type=int, 181 | help='Number of NN to use. 20 is usually working the best.') 182 | parser.add_argument('--temperature', default=0.07, type=float, 183 | help='Temperature used in the voting coefficient') 184 | parser.add_argument('--pretrained_weights', default='', type=str, help="Path to pretrained weights to evaluate.") 185 | parser.add_argument('--use_cuda', default=True, type=utils.bool_flag, 186 | help="Should we store the features on GPU? 
We recommend setting this to False if you encounter OOM") 187 | parser.add_argument('--arch', default='deit_small', type=str, 188 | choices=['deit_tiny', 'deit_small', 'vit_base'], help='Architecture (support only ViT atm).') 189 | parser.add_argument('--patch_size', default=16, type=int, help='Patch resolution of the model.') 190 | parser.add_argument("--checkpoint_key", default="teacher", type=str, 191 | help='Key to use in the checkpoint (example: "teacher")') 192 | parser.add_argument('--dump_features', default=None, 193 | help='Path where to save computed features, empty for no saving') 194 | parser.add_argument('--load_features', default=None, help="""If the features have 195 | already been computed, where to find them.""") 196 | parser.add_argument('--num_workers', default=10, type=int, help='Number of data loading workers per GPU.') 197 | parser.add_argument("--dist_url", default="env://", type=str, help="""url used to set up 198 | distributed training; see https://pytorch.org/docs/stable/distributed.html""") 199 | parser.add_argument("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.") 200 | parser.add_argument('--data_path', default='/path/to/imagenet/', type=str) 201 | args = parser.parse_args() 202 | 203 | utils.init_distributed_mode(args) 204 | print("git:\n {}\n".format(utils.get_sha())) 205 | print("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items()))) 206 | cudnn.benchmark = True 207 | 208 | if args.load_features: 209 | train_features = torch.load(os.path.join(args.load_features, "trainfeat.pth")) 210 | test_features = torch.load(os.path.join(args.load_features, "testfeat.pth")) 211 | train_labels = torch.load(os.path.join(args.load_features, "trainlabels.pth")) 212 | test_labels = torch.load(os.path.join(args.load_features, "testlabels.pth")) 213 | else: 214 | # need to extract features ! 
215 | train_features, test_features, train_labels, test_labels = extract_feature_pipeline(args) 216 | 217 | if utils.get_rank() == 0: 218 | if args.use_cuda: 219 | train_features = train_features.cuda() 220 | test_features = test_features.cuda() 221 | train_labels = train_labels.cuda() 222 | test_labels = test_labels.cuda() 223 | 224 | print("Features are ready!\nStart the k-NN classification.") 225 | for k in args.nb_knn: 226 | top1, top5 = knn_classifier(train_features, train_labels, 227 | test_features, test_labels, k, args.temperature) 228 | print(f"{k}-NN classifier result: Top1: {top1}, Top5: {top5}") 229 | dist.barrier() 230 | -------------------------------------------------------------------------------- /dino_model/eval_linear.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | import os 15 | import argparse 16 | import json 17 | from pathlib import Path 18 | 19 | import torch 20 | from torch import nn 21 | import torch.distributed as dist 22 | import torch.backends.cudnn as cudnn 23 | from torchvision import datasets 24 | from torchvision import transforms as pth_transforms 25 | 26 | import utils 27 | import vision_transformer as vits 28 | 29 | 30 | def eval_linear(args): 31 | utils.init_distributed_mode(args) 32 | print("git:\n {}\n".format(utils.get_sha())) 33 | print("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items()))) 34 | cudnn.benchmark = True 35 | 36 | # ============ preparing data ... ============ 37 | train_transform = pth_transforms.Compose([ 38 | pth_transforms.RandomResizedCrop(224), 39 | pth_transforms.RandomHorizontalFlip(), 40 | pth_transforms.ToTensor(), 41 | pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 42 | ]) 43 | val_transform = pth_transforms.Compose([ 44 | pth_transforms.Resize(256, interpolation=3), 45 | pth_transforms.CenterCrop(224), 46 | pth_transforms.ToTensor(), 47 | pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 48 | ]) 49 | dataset_train = datasets.ImageFolder(os.path.join(args.data_path, "train"), transform=train_transform) 50 | dataset_val = datasets.ImageFolder(os.path.join(args.data_path, "val"), transform=val_transform) 51 | sampler = torch.utils.data.distributed.DistributedSampler(dataset_train) 52 | train_loader = torch.utils.data.DataLoader( 53 | dataset_train, 54 | sampler=sampler, 55 | batch_size=args.batch_size_per_gpu, 56 | num_workers=args.num_workers, 57 | pin_memory=True, 58 | ) 59 | val_loader = torch.utils.data.DataLoader( 60 | dataset_val, 61 | batch_size=args.batch_size_per_gpu, 62 | num_workers=args.num_workers, 63 | pin_memory=True, 64 | ) 65 | print(f"Data loaded with {len(dataset_train)} train and {len(dataset_val)} val imgs.") 66 | 67 | # ============ building network ... 
============ 68 | model = vits.__dict__[args.arch](patch_size=args.patch_size, num_classes=0) 69 | model.cuda() 70 | model.eval() 71 | print(f"Model {args.arch} {args.patch_size}x{args.patch_size} built.") 72 | # load weights to evaluate 73 | utils.load_pretrained_weights(model, args.pretrained_weights, args.checkpoint_key, args.arch, args.patch_size) 74 | 75 | linear_classifier = LinearClassifier(model.embed_dim * (args.n_last_blocks + int(args.avgpool_patchtokens)), num_labels=args.num_labels) 76 | linear_classifier = linear_classifier.cuda() 77 | linear_classifier = nn.parallel.DistributedDataParallel(linear_classifier, device_ids=[args.gpu]) 78 | 79 | # set optimizer 80 | optimizer = torch.optim.SGD( 81 | linear_classifier.parameters(), 82 | args.lr * (args.batch_size_per_gpu * utils.get_world_size()) / 256., # linear scaling rule 83 | momentum=0.9, 84 | weight_decay=0, # we do not apply weight decay 85 | ) 86 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs, eta_min=0) 87 | 88 | # Optionally resume from a checkpoint 89 | to_restore = {"epoch": 0, "best_acc": 0.} 90 | utils.restart_from_checkpoint( 91 | os.path.join(args.output_dir, "checkpoint.pth.tar"), 92 | run_variables=to_restore, 93 | state_dict=linear_classifier, 94 | optimizer=optimizer, 95 | scheduler=scheduler, 96 | ) 97 | start_epoch = to_restore["epoch"] 98 | best_acc = to_restore["best_acc"] 99 | 100 | for epoch in range(start_epoch, args.epochs): 101 | train_loader.sampler.set_epoch(epoch) 102 | 103 | train_stats = train(model, linear_classifier, optimizer, train_loader, epoch, args.n_last_blocks, args.avgpool_patchtokens) 104 | scheduler.step() 105 | 106 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 107 | 'epoch': epoch} 108 | if epoch % args.val_freq == 0 or epoch == args.epochs - 1: 109 | test_stats = validate_network(val_loader, model, linear_classifier, args.n_last_blocks, args.avgpool_patchtokens) 110 | print(f"Accuracy at epoch {epoch} of 
the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") 111 | best_acc = max(best_acc, test_stats["acc1"]) 112 | print(f'Max accuracy so far: {best_acc:.2f}%') 113 | log_stats = {**{k: v for k, v in log_stats.items()}, 114 | **{f'test_{k}': v for k, v in test_stats.items()}} 115 | if utils.is_main_process(): 116 | with (Path(args.output_dir) / "log.txt").open("a") as f: 117 | f.write(json.dumps(log_stats) + "\n") 118 | save_dict = { 119 | "epoch": epoch + 1, 120 | "state_dict": linear_classifier.state_dict(), 121 | "optimizer": optimizer.state_dict(), 122 | "scheduler": scheduler.state_dict(), 123 | "best_acc": best_acc, 124 | } 125 | torch.save(save_dict, os.path.join(args.output_dir, "checkpoint.pth.tar")) 126 | print("Training of the supervised linear classifier on frozen features completed.\n" 127 | "Top-1 test accuracy: {acc:.1f}".format(acc=best_acc)) 128 | 129 | 130 | def train(model, linear_classifier, optimizer, loader, epoch, n, avgpool): 131 | linear_classifier.train() 132 | metric_logger = utils.MetricLogger(delimiter=" ") 133 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 134 | header = 'Epoch: [{}]'.format(epoch) 135 | for (inp, target) in metric_logger.log_every(loader, 20, header): 136 | # move to gpu 137 | inp = inp.cuda(non_blocking=True) 138 | target = target.cuda(non_blocking=True) 139 | 140 | # forward 141 | with torch.no_grad(): 142 | output = model.forward_return_n_last_blocks(inp, n, avgpool) 143 | output = linear_classifier(output) 144 | 145 | # compute cross entropy loss 146 | loss = nn.CrossEntropyLoss()(output, target) 147 | 148 | # compute the gradients 149 | optimizer.zero_grad() 150 | loss.backward() 151 | 152 | # step 153 | optimizer.step() 154 | 155 | # log 156 | torch.cuda.synchronize() 157 | metric_logger.update(loss=loss.item()) 158 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 159 | # gather the stats from all processes 160 | 
metric_logger.synchronize_between_processes() 161 | print("Averaged stats:", metric_logger) 162 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 163 | 164 | 165 | @torch.no_grad() 166 | def validate_network(val_loader, model, linear_classifier, n, avgpool): 167 | linear_classifier.eval() 168 | metric_logger = utils.MetricLogger(delimiter=" ") 169 | header = 'Test:' 170 | for inp, target in metric_logger.log_every(val_loader, 20, header): 171 | # move to gpu 172 | inp = inp.cuda(non_blocking=True) 173 | target = target.cuda(non_blocking=True) 174 | 175 | # compute output 176 | output = model.forward_return_n_last_blocks(inp, n, avgpool) 177 | output = linear_classifier(output) 178 | loss = nn.CrossEntropyLoss()(output, target) 179 | 180 | if linear_classifier.module.num_labels >= 5: 181 | acc1, acc5 = utils.accuracy(output, target, topk=(1, 5)) 182 | else: 183 | acc1, = utils.accuracy(output, target, topk=(1,)) 184 | 185 | batch_size = inp.shape[0] 186 | metric_logger.update(loss=loss.item()) 187 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) 188 | if linear_classifier.module.num_labels >= 5: 189 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) 190 | if linear_classifier.module.num_labels >= 5: 191 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}' 192 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss)) 193 | else: 194 | print('* Acc@1 {top1.global_avg:.3f} loss {losses.global_avg:.3f}' 195 | .format(top1=metric_logger.acc1, losses=metric_logger.loss)) 196 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 197 | 198 | 199 | class LinearClassifier(nn.Module): 200 | """Linear layer to train on top of frozen features""" 201 | def __init__(self, dim, num_labels=1000): 202 | super(LinearClassifier, self).__init__() 203 | self.num_labels = num_labels 204 | self.linear = nn.Linear(dim, num_labels) 205 | 
self.linear.weight.data.normal_(mean=0.0, std=0.01) 206 | self.linear.bias.data.zero_() 207 | 208 | def forward(self, x): 209 | # flatten 210 | x = x.view(x.size(0), -1) 211 | 212 | # linear layer 213 | return self.linear(x) 214 | 215 | 216 | if __name__ == '__main__': 217 | parser = argparse.ArgumentParser('Evaluation with linear classification on ImageNet') 218 | parser.add_argument('--n_last_blocks', default=4, type=int, help="""Concatenate [CLS] tokens 219 | for the `n` last blocks. We use `n=4` when evaluating DeiT-Small and `n=1` with ViT-Base.""") 220 | parser.add_argument('--avgpool_patchtokens', default=False, type=utils.bool_flag, 221 | help="""Whether or not to concatenate the global average pooled features to the [CLS] token. 222 | We typically set this to False for DeiT-Small and to True with ViT-Base.""") 223 | parser.add_argument('--arch', default='deit_small', type=str, 224 | choices=['deit_tiny', 'deit_small', 'vit_base'], help='Architecture (only ViTs are supported at the moment).') 225 | parser.add_argument('--patch_size', default=16, type=int, help='Patch resolution of the model.') 226 | parser.add_argument('--pretrained_weights', default='', type=str, help="Path to pretrained weights to evaluate.") 227 | parser.add_argument("--checkpoint_key", default="teacher", type=str, help='Key to use in the checkpoint (example: "teacher")') 228 | parser.add_argument('--epochs', default=100, type=int, help='Number of epochs of training.') 229 | parser.add_argument("--lr", default=0.001, type=float, help="""Learning rate at the beginning of 230 | training (highest LR used during training). The learning rate is linearly scaled 231 | with the batch size, and specified here for a reference batch size of 256.
232 | We recommend tweaking the LR depending on the checkpoint evaluated.""") 233 | parser.add_argument('--batch_size_per_gpu', default=128, type=int, help='Per-GPU batch-size') 234 | parser.add_argument("--dist_url", default="env://", type=str, help="""url used to set up 235 | distributed training; see https://pytorch.org/docs/stable/distributed.html""") 236 | parser.add_argument("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.") 237 | parser.add_argument('--data_path', default='/path/to/imagenet/', type=str) 238 | parser.add_argument('--num_workers', default=10, type=int, help='Number of data loading workers per GPU.') 239 | parser.add_argument('--val_freq', default=1, type=int, help="Epoch frequency for validation.") 240 | parser.add_argument('--output_dir', default=".", help='Path to save logs and checkpoints') 241 | parser.add_argument('--num_labels', default=1000, type=int, help='Number of labels for linear classifier') 242 | args = parser.parse_args() 243 | eval_linear(args) 244 | -------------------------------------------------------------------------------- /dino_model/fe_dino.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | 4 | from torch import nn 5 | from torchvision import transforms as pth_transforms 6 | 7 | # from utils import load_image 8 | import dino_model.utils as utils 9 | import dino_model.vision_transformer as vits 10 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 11 | 12 | 13 | class DinoModel(nn.Module): 14 | def __init__(self, feat_type): 15 | super(DinoModel, self).__init__() 16 | 17 | self.transform = pth_transforms.Compose([ 18 | pth_transforms.Resize((224, 224), interpolation=3), 19 | # pth_transforms.CenterCrop(224), 20 | pth_transforms.ToTensor(), 21 | pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 22 | ]) 23 | 24 | self.model = self.load_model() 25 | self.feat_type = feat_type 
26 | 27 | 28 | def load_model(self): 29 | # ============ building network ... ============ 30 | p = 8 31 | model = vits.__dict__["vit_base"](patch_size=p, num_classes=0) 32 | print(f"Model {'vit_base'} {p}x{p} built.") 33 | model.to(device) # use the module-level device (falls back to CPU when CUDA is unavailable) 34 | utils.load_pretrained_weights(model, "", "", "vit_base", p) 35 | model.eval() 36 | return model 37 | 38 | def forward(self, images, feat_type='cls'): 39 | """Extract the image feature vectors. The feature type is taken from self.feat_type; the feat_type argument is unused.""" 40 | if self.transform is not None: 41 | images = self.transform(images).unsqueeze(0) 42 | with torch.no_grad(): 43 | 44 | if self.feat_type == 'cls': 45 | features = self.model(images.to(device)) 46 | 47 | elif self.feat_type == 'patch': 48 | 49 | feat_out = {} 50 | 51 | def hook_fn_forward_qkv(module, input, output): 52 | feat_out["qkv"] = output 53 | 54 | handle = self.model._modules["blocks"][-1]._modules["attn"]._modules["qkv"].register_forward_hook(hook_fn_forward_qkv) 55 | 56 | # Forward pass in the model 57 | attentions = self.model.forward_selfattention(images.to(device)) 58 | handle.remove() # detach the hook so repeated calls do not stack hooks 59 | # Dimensions 60 | nb_im = attentions.shape[0] # Batch size 61 | nh = attentions.shape[1] # Number of heads 62 | nb_tokens = attentions.shape[2] # Number of tokens 63 | 64 | # Extract the qkv features of the last attention layer 65 | qkv = ( 66 | feat_out["qkv"] 67 | .reshape(nb_im, nb_tokens, 3, nh, -1) # -1 infers the per-head dimension 68 | .permute(2, 0, 3, 1, 4) 69 | ) 70 | q, k, v = qkv[0], qkv[1], qkv[2] 71 | k = k.transpose(1, 2).reshape(nb_im, nb_tokens, -1)[:, 1:, :] 72 | q = q.transpose(1, 2).reshape(nb_im, nb_tokens, -1)[:, 1:, :] 73 | v = v.transpose(1, 2).reshape(nb_im, nb_tokens, -1)[:, 1:, :] 74 | 75 | features = k 76 | 77 | return features 78 | -------------------------------------------------------------------------------- /dino_model/hubconf.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import torch 15 | from torchvision.models.resnet import resnet50 16 | 17 | import vision_transformer as vits 18 | 19 | dependencies = ["torch", "torchvision"] 20 | 21 | 22 | def dino_deits16(pretrained=True, **kwargs): 23 | """ 24 | DeiT-Small/16x16 pre-trained with DINO. 25 | Achieves 74.5% top-1 accuracy on ImageNet with k-NN classification. 26 | """ 27 | model = vits.__dict__["deit_small"](patch_size=16, num_classes=0, **kwargs) 28 | if pretrained: 29 | state_dict = torch.hub.load_state_dict_from_url( 30 | url="https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth", 31 | map_location="cpu", 32 | ) 33 | model.load_state_dict(state_dict, strict=True) 34 | return model 35 | 36 | 37 | def dino_deits8(pretrained=True, **kwargs): 38 | """ 39 | DeiT-Small/8x8 pre-trained with DINO. 40 | Achieves 78.3% top-1 accuracy on ImageNet with k-NN classification. 41 | """ 42 | model = vits.__dict__["deit_small"](patch_size=8, num_classes=0, **kwargs) 43 | if pretrained: 44 | state_dict = torch.hub.load_state_dict_from_url( 45 | url="https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth", 46 | map_location="cpu", 47 | ) 48 | model.load_state_dict(state_dict, strict=True) 49 | return model 50 | 51 | 52 | def dino_vitb16(pretrained=True, **kwargs): 53 | """ 54 | ViT-Base/16x16 pre-trained with DINO. 
55 | Achieves 76.1% top-1 accuracy on ImageNet with k-NN classification. 56 | """ 57 | model = vits.__dict__["vit_base"](patch_size=16, num_classes=0, **kwargs) 58 | if pretrained: 59 | state_dict = torch.hub.load_state_dict_from_url( 60 | url="https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth", 61 | map_location="cpu", 62 | ) 63 | model.load_state_dict(state_dict, strict=True) 64 | return model 65 | 66 | 67 | def dino_vitb8(pretrained=True, **kwargs): 68 | """ 69 | ViT-Base/8x8 pre-trained with DINO. 70 | Achieves 77.4% top-1 accuracy on ImageNet with k-NN classification. 71 | """ 72 | model = vits.__dict__["vit_base"](patch_size=8, num_classes=0, **kwargs) 73 | if pretrained: 74 | state_dict = torch.hub.load_state_dict_from_url( 75 | url="https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth", 76 | map_location="cpu", 77 | ) 78 | model.load_state_dict(state_dict, strict=True) 79 | return model 80 | 81 | 82 | def dino_resnet50(pretrained=True, **kwargs): 83 | """ 84 | ResNet-50 pre-trained with DINO. 85 | Achieves 75.3% top-1 accuracy on the ImageNet linear evaluation benchmark (requires training `fc`). 86 | Note that `fc.weight` and `fc.bias` are randomly initialized. 87 | """ 88 | model = resnet50(pretrained=False, **kwargs) 89 | if pretrained: 90 | state_dict = torch.hub.load_state_dict_from_url( 91 | url="https://dl.fbaipublicfiles.com/dino/dino_resnet50_pretrain/dino_resnet50_pretrain.pth", 92 | map_location="cpu", 93 | ) 94 | model.load_state_dict(state_dict, strict=False) 95 | return model 96 | -------------------------------------------------------------------------------- /dino_model/main_dino.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import argparse 15 | import os 16 | import sys 17 | import datetime 18 | import time 19 | import math 20 | import json 21 | from pathlib import Path 22 | 23 | import numpy as np 24 | from PIL import Image 25 | import torch 26 | import torch.nn as nn 27 | import torch.distributed as dist 28 | import torch.backends.cudnn as cudnn 29 | import torch.nn.functional as F 30 | from torchvision import datasets, transforms 31 | from torchvision import models as torchvision_models 32 | 33 | import utils 34 | import vision_transformer as vits 35 | from vision_transformer import DINOHead 36 | 37 | torchvision_archs = sorted(name for name in torchvision_models.__dict__ 38 | if name.islower() and not name.startswith("__") 39 | and callable(torchvision_models.__dict__[name])) 40 | 41 | def get_args_parser(): 42 | parser = argparse.ArgumentParser('DINO', add_help=False) 43 | 44 | # Model parameters 45 | parser.add_argument('--arch', default='deit_small', type=str, 46 | choices=['deit_tiny', 'deit_small', 'vit_base'] + torchvision_archs, 47 | help="""Name of architecture to train. For quick experiments with ViTs, 48 | we recommend using deit_tiny or deit_small.""") 49 | parser.add_argument('--patch_size', default=16, type=int, help="""Size in pixels 50 | of input square patches - default 16 (for 16x16 patches). Using smaller 51 | values leads to better performance but requires more memory. Applies only 52 | for ViTs (deit_tiny, deit_small and vit_base). 
If <16, we recommend disabling 53 | mixed precision training (--use_fp16 false) to avoid instabilities.""") 54 | parser.add_argument('--out_dim', default=65536, type=int, help="""Dimensionality of 55 | the DINO head output. For complex and large datasets large values (like 65k) work well.""") 56 | parser.add_argument('--norm_last_layer', default=True, type=utils.bool_flag, 57 | help="""Whether or not to weight normalize the last layer of the DINO head. 58 | Not normalizing leads to better performance but can make the training unstable. 59 | In our experiments, we typically set this parameter to False with deit_small and True with vit_base.""") 60 | parser.add_argument('--momentum_teacher', default=0.996, type=float, help="""Base EMA 61 | parameter for teacher update. The value is increased to 1 during training with cosine schedule. 62 | We recommend setting a higher value with small batches: for example use 0.9995 with a batch size of 256.""") 63 | parser.add_argument('--use_bn_in_head', default=False, type=utils.bool_flag, 64 | help="Whether to use batch normalizations in projection head (Default: False)") 65 | 66 | # Temperature teacher parameters 67 | parser.add_argument('--warmup_teacher_temp', default=0.04, type=float, 68 | help="""Initial value for the teacher temperature: 0.04 works well in most cases. 69 | Try decreasing it if the training loss does not decrease.""") 70 | parser.add_argument('--teacher_temp', default=0.04, type=float, help="""Final value (after linear warmup) 71 | of the teacher temperature. For most experiments, anything above 0.07 is unstable.
We recommend 72 | starting with the default value of 0.04 and increasing this slightly if needed.""") 73 | parser.add_argument('--warmup_teacher_temp_epochs', default=0, type=int, 74 | help='Number of warmup epochs for the teacher temperature (Default: 0).') 75 | 76 | # Training/Optimization parameters 77 | parser.add_argument('--use_fp16', type=utils.bool_flag, default=True, help="""Whether or not 78 | to use half precision for training. Improves training time and memory requirements, 79 | but can provoke instability and slight decay of performance. We recommend disabling 80 | mixed precision if the loss is unstable, if reducing the patch size or if training with bigger ViTs.""") 81 | parser.add_argument('--weight_decay', type=float, default=0.04, help="""Initial value of the 82 | weight decay. With ViT, a smaller value at the beginning of training works well.""") 83 | parser.add_argument('--weight_decay_end', type=float, default=0.4, help="""Final value of the 84 | weight decay. We use a cosine schedule for WD and using a larger decay by 85 | the end of training improves performance for ViTs.""") 86 | parser.add_argument('--clip_grad', type=float, default=3.0, help="""Maximal parameter 87 | gradient norm if using gradient clipping. Clipping with norm .3 ~ 1.0 can 88 | help optimization for larger ViT architectures. 0 for disabling.""") 89 | parser.add_argument('--batch_size_per_gpu', default=64, type=int, 90 | help='Per-GPU batch size: number of distinct images loaded on one GPU.') 91 | parser.add_argument('--epochs', default=100, type=int, help='Number of epochs of training.') 92 | parser.add_argument('--freeze_last_layer', default=1, type=int, help="""Number of epochs 93 | during which we keep the output layer fixed. Typically doing so during 94 | the first epoch helps training.
Try increasing this value if the loss does not decrease.""") 95 | parser.add_argument("--lr", default=0.0005, type=float, help="""Learning rate at the end of 96 | linear warmup (highest LR used during training). The learning rate is linearly scaled 97 | with the batch size, and specified here for a reference batch size of 256.""") 98 | parser.add_argument("--warmup_epochs", default=10, type=int, 99 | help="Number of epochs for the linear learning-rate warmup.") 100 | parser.add_argument('--min_lr', type=float, default=1e-6, help="""Target LR at the 101 | end of optimization. We use a cosine LR schedule with linear warmup.""") 102 | parser.add_argument('--optimizer', default='adamw', type=str, 103 | choices=['adamw', 'sgd', 'lars'], help="""Type of optimizer. We recommend using adamw with ViTs.""") 104 | 105 | # Multi-crop parameters 106 | parser.add_argument('--global_crops_scale', type=float, nargs='+', default=(0.4, 1.), 107 | help="""Scale range of the cropped image before resizing, relative to the original image. 108 | Used for large global view cropping. When disabling multi-crop (--local_crops_number 0), we 109 | recommend using a wider range of scale ("--global_crops_scale 0.14 1." for example)""") 110 | parser.add_argument('--local_crops_number', type=int, default=8, help="""Number of small 111 | local views to generate. Set this parameter to 0 to disable multi-crop training. 112 | When disabling multi-crop we recommend using "--global_crops_scale 0.14 1." """) 113 | parser.add_argument('--local_crops_scale', type=float, nargs='+', default=(0.05, 0.4), 114 | help="""Scale range of the cropped image before resizing, relative to the original image.
115 | Used for small local view cropping of multi-crop.""") 116 | 117 | # Misc 118 | parser.add_argument('--data_path', default='/path/to/imagenet/train/', type=str, 119 | help='Please specify path to the ImageNet training data.') 120 | parser.add_argument('--output_dir', default=".", type=str, help='Path to save logs and checkpoints.') 121 | parser.add_argument('--saveckp_freq', default=20, type=int, help='Save checkpoint every x epochs.') 122 | parser.add_argument('--seed', default=0, type=int, help='Random seed.') 123 | parser.add_argument('--num_workers', default=10, type=int, help='Number of data loading workers per GPU.') 124 | parser.add_argument("--dist_url", default="env://", type=str, help="""url used to set up 125 | distributed training; see https://pytorch.org/docs/stable/distributed.html""") 126 | parser.add_argument("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.") 127 | return parser 128 | 129 | 130 | def train_dino(args): 131 | utils.init_distributed_mode(args) 132 | utils.fix_random_seeds(args.seed) 133 | print("git:\n {}\n".format(utils.get_sha())) 134 | print("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items()))) 135 | cudnn.benchmark = True 136 | 137 | # ============ preparing data ... ============ 138 | transform = DataAugmentationDINO( 139 | args.global_crops_scale, 140 | args.local_crops_scale, 141 | args.local_crops_number, 142 | ) 143 | dataset = datasets.ImageFolder(args.data_path, transform=transform) 144 | sampler = torch.utils.data.DistributedSampler(dataset, shuffle=True) 145 | data_loader = torch.utils.data.DataLoader( 146 | dataset, 147 | sampler=sampler, 148 | batch_size=args.batch_size_per_gpu, 149 | num_workers=args.num_workers, 150 | pin_memory=True, 151 | drop_last=True, 152 | ) 153 | print(f"Data loaded: there are {len(dataset)} images.") 154 | 155 | # ============ building student and teacher networks ... 
============ 156 | # if the network is a vision transformer (i.e. deit_tiny, deit_small, vit_base) 157 | if args.arch in vits.__dict__.keys(): 158 | student = vits.__dict__[args.arch]( 159 | patch_size=args.patch_size, 160 | drop_path_rate=0.1, # stochastic depth 161 | ) 162 | teacher = vits.__dict__[args.arch](patch_size=args.patch_size) 163 | student.head = DINOHead( 164 | student.embed_dim, 165 | args.out_dim, 166 | use_bn=args.use_bn_in_head, 167 | norm_last_layer=args.norm_last_layer, 168 | ) 169 | teacher.head = DINOHead(teacher.embed_dim, args.out_dim, args.use_bn_in_head) 170 | 171 | # otherwise, we check if the architecture is in torchvision models 172 | elif args.arch in torchvision_models.__dict__.keys(): 173 | student = torchvision_models.__dict__[args.arch]() 174 | teacher = torchvision_models.__dict__[args.arch]() 175 | embed_dim = student.fc.weight.shape[1] 176 | student = utils.MultiCropWrapper(student, DINOHead( 177 | embed_dim, 178 | args.out_dim, 179 | use_bn=args.use_bn_in_head, 180 | norm_last_layer=args.norm_last_layer, 181 | )) 182 | teacher = utils.MultiCropWrapper( 183 | teacher, 184 | DINOHead(embed_dim, args.out_dim, args.use_bn_in_head), 185 | ) 186 | else: 187 | raise ValueError(f"Unknown architecture: {args.arch}") 188 | 189 | # move networks to gpu 190 | student, teacher = student.cuda(), teacher.cuda() 191 | # synchronize batch norms (if any) 192 | if utils.has_batchnorms(student): 193 | student = nn.SyncBatchNorm.convert_sync_batchnorm(student) 194 | teacher = nn.SyncBatchNorm.convert_sync_batchnorm(teacher) 195 | 196 | # we need the DDP wrapper to have synchronized batch norms working...
197 | teacher = nn.parallel.DistributedDataParallel(teacher, device_ids=[args.gpu]) 198 | teacher_without_ddp = teacher.module 199 | else: 200 | # teacher_without_ddp and teacher are the same thing 201 | teacher_without_ddp = teacher 202 | student = nn.parallel.DistributedDataParallel(student, device_ids=[args.gpu]) 203 | # teacher and student start with the same weights 204 | teacher_without_ddp.load_state_dict(student.module.state_dict()) 205 | # there is no backpropagation through the teacher, so no need for gradients 206 | for p in teacher.parameters(): 207 | p.requires_grad = False 208 | print(f"Student and Teacher are built: they are both {args.arch} network.") 209 | 210 | # ============ preparing loss ... ============ 211 | dino_loss = DINOLoss( 212 | args.out_dim, 213 | args.local_crops_number + 2, # total number of crops = 2 global crops + local_crops_number 214 | args.warmup_teacher_temp, 215 | args.teacher_temp, 216 | args.warmup_teacher_temp_epochs, 217 | args.epochs, 218 | ).cuda() 219 | 220 | # ============ preparing optimizer ... ============ 221 | params_groups = utils.get_params_groups(student) 222 | if args.optimizer == "adamw": 223 | optimizer = torch.optim.AdamW(params_groups) # to use with ViTs 224 | elif args.optimizer == "sgd": 225 | optimizer = torch.optim.SGD(params_groups, lr=0, momentum=0.9) # lr is set by scheduler 226 | elif args.optimizer == "lars": 227 | optimizer = utils.LARS(params_groups) # to use with convnet and large batches 228 | # for mixed precision training 229 | fp16_scaler = None 230 | if args.use_fp16: 231 | fp16_scaler = torch.cuda.amp.GradScaler() 232 | 233 | # ============ init schedulers ... 
============ 234 | lr_schedule = utils.cosine_scheduler( 235 | args.lr * (args.batch_size_per_gpu * utils.get_world_size()) / 256., # linear scaling rule 236 | args.min_lr, 237 | args.epochs, len(data_loader), 238 | warmup_epochs=args.warmup_epochs, 239 | ) 240 | wd_schedule = utils.cosine_scheduler( 241 | args.weight_decay, 242 | args.weight_decay_end, 243 | args.epochs, len(data_loader), 244 | ) 245 | # momentum parameter is increased to 1. during training with a cosine schedule 246 | momentum_schedule = utils.cosine_scheduler(args.momentum_teacher, 1, 247 | args.epochs, len(data_loader)) 248 | print(f"Loss, optimizer and schedulers ready.") 249 | 250 | # ============ optionally resume training ... ============ 251 | to_restore = {"epoch": 0} 252 | utils.restart_from_checkpoint( 253 | os.path.join(args.output_dir, "checkpoint.pth"), 254 | run_variables=to_restore, 255 | student=student, 256 | teacher=teacher, 257 | optimizer=optimizer, 258 | fp16_scaler=fp16_scaler, 259 | dino_loss=dino_loss, 260 | ) 261 | start_epoch = to_restore["epoch"] 262 | 263 | start_time = time.time() 264 | print("Starting DINO training !") 265 | for epoch in range(start_epoch, args.epochs): 266 | data_loader.sampler.set_epoch(epoch) 267 | 268 | # ============ training one epoch of DINO ... ============ 269 | train_stats = train_one_epoch(student, teacher, teacher_without_ddp, dino_loss, 270 | data_loader, optimizer, lr_schedule, wd_schedule, momentum_schedule, 271 | epoch, fp16_scaler, args) 272 | 273 | # ============ writing logs ... 
============ 274 | save_dict = { 275 | 'student': student.state_dict(), 276 | 'teacher': teacher.state_dict(), 277 | 'optimizer': optimizer.state_dict(), 278 | 'epoch': epoch + 1, 279 | 'args': args, 280 | 'dino_loss': dino_loss.state_dict(), 281 | } 282 | if fp16_scaler is not None: 283 | save_dict['fp16_scaler'] = fp16_scaler.state_dict() 284 | utils.save_on_master(save_dict, os.path.join(args.output_dir, 'checkpoint.pth')) 285 | if args.saveckp_freq and epoch % args.saveckp_freq == 0: 286 | utils.save_on_master(save_dict, os.path.join(args.output_dir, f'checkpoint{epoch:04}.pth')) 287 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 288 | 'epoch': epoch} 289 | if utils.is_main_process(): 290 | with (Path(args.output_dir) / "log.txt").open("a") as f: 291 | f.write(json.dumps(log_stats) + "\n") 292 | total_time = time.time() - start_time 293 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 294 | print('Training time {}'.format(total_time_str)) 295 | 296 | 297 | def train_one_epoch(student, teacher, teacher_without_ddp, dino_loss, data_loader, 298 | optimizer, lr_schedule, wd_schedule, momentum_schedule,epoch, 299 | fp16_scaler, args): 300 | metric_logger = utils.MetricLogger(delimiter=" ") 301 | header = 'Epoch: [{}/{}]'.format(epoch, args.epochs) 302 | for it, (images, _) in enumerate(metric_logger.log_every(data_loader, 10, header)): 303 | # update weight decay and learning rate according to their schedule 304 | it = len(data_loader) * epoch + it # global training iteration 305 | for i, param_group in enumerate(optimizer.param_groups): 306 | param_group["lr"] = lr_schedule[it] 307 | if i == 0: # only the first group is regularized 308 | param_group["weight_decay"] = wd_schedule[it] 309 | 310 | # move images to gpu 311 | images = [im.cuda(non_blocking=True) for im in images] 312 | # teacher and student forward passes + compute dino loss 313 | with torch.cuda.amp.autocast(fp16_scaler is not None): 314 | teacher_output = 
teacher(images[:2]) # only the 2 global views pass through the teacher 315 | student_output = student(images) 316 | loss = dino_loss(student_output, teacher_output, epoch) 317 | 318 | if not math.isfinite(loss.item()): 319 | print("Loss is {}, stopping training".format(loss.item()), force=True) 320 | sys.exit(1) 321 | 322 | # student update 323 | optimizer.zero_grad() 324 | param_norms = None 325 | if fp16_scaler is None: 326 | loss.backward() 327 | if args.clip_grad: 328 | param_norms = utils.clip_gradients(student, args.clip_grad) 329 | utils.cancel_gradients_last_layer(epoch, student, 330 | args.freeze_last_layer) 331 | optimizer.step() 332 | else: 333 | fp16_scaler.scale(loss).backward() 334 | if args.clip_grad: 335 | fp16_scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place 336 | param_norms = utils.clip_gradients(student, args.clip_grad) 337 | utils.cancel_gradients_last_layer(epoch, student, 338 | args.freeze_last_layer) 339 | fp16_scaler.step(optimizer) 340 | fp16_scaler.update() 341 | 342 | # EMA update for the teacher 343 | with torch.no_grad(): 344 | m = momentum_schedule[it] # momentum parameter 345 | for param_q, param_k in zip(student.module.parameters(), teacher_without_ddp.parameters()): 346 | param_k.data.mul_(m).add_((1 - m) * param_q.detach().data) 347 | 348 | # logging 349 | torch.cuda.synchronize() 350 | metric_logger.update(loss=loss.item()) 351 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 352 | metric_logger.update(wd=optimizer.param_groups[0]["weight_decay"]) 353 | # gather the stats from all processes 354 | metric_logger.synchronize_between_processes() 355 | print("Averaged stats:", metric_logger) 356 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 357 | 358 | 359 | class DINOLoss(nn.Module): 360 | def __init__(self, out_dim, ncrops, warmup_teacher_temp, teacher_temp, 361 | warmup_teacher_temp_epochs, nepochs, student_temp=0.1, 362 | center_momentum=0.9): 363 | 
super().__init__() 364 | self.student_temp = student_temp 365 | self.center_momentum = center_momentum 366 | self.ncrops = ncrops 367 | self.register_buffer("center", torch.zeros(1, out_dim)) 368 | # we apply a warmup to the teacher temperature because 369 | # too high a temperature makes training unstable at the beginning 370 | self.teacher_temp_schedule = np.concatenate(( 371 | np.linspace(warmup_teacher_temp, 372 | teacher_temp, warmup_teacher_temp_epochs), 373 | np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp 374 | )) 375 | 376 | def forward(self, student_output, teacher_output, epoch): 377 | """ 378 | Cross-entropy between softmax outputs of the teacher and student networks. 379 | """ 380 | student_out = student_output / self.student_temp 381 | student_out = student_out.chunk(self.ncrops) 382 | 383 | # teacher centering and sharpening 384 | temp = self.teacher_temp_schedule[epoch] 385 | teacher_out = F.softmax((teacher_output - self.center) / temp, dim=-1) 386 | teacher_out = teacher_out.detach().chunk(2) 387 | 388 | total_loss = 0 389 | n_loss_terms = 0 390 | for iq, q in enumerate(teacher_out): 391 | for v in range(len(student_out)): 392 | if v == iq: 393 | # we skip cases where student and teacher operate on the same view 394 | continue 395 | loss = torch.sum(-q * F.log_softmax(student_out[v], dim=-1), dim=-1) 396 | total_loss += loss.mean() 397 | n_loss_terms += 1 398 | total_loss /= n_loss_terms 399 | self.update_center(teacher_output) 400 | return total_loss 401 | 402 | @torch.no_grad() 403 | def update_center(self, teacher_output): 404 | """ 405 | Update center used for teacher output.
406 | """ 407 | batch_center = torch.sum(teacher_output, dim=0, keepdim=True) 408 | dist.all_reduce(batch_center) 409 | batch_center = batch_center / (len(teacher_output) * dist.get_world_size()) 410 | 411 | # ema update 412 | self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum) 413 | 414 | 415 | class DataAugmentationDINO(object): 416 | def __init__(self, global_crops_scale, local_crops_scale, local_crops_number): 417 | flip_and_color_jitter = transforms.Compose([ 418 | transforms.RandomHorizontalFlip(p=0.5), 419 | transforms.RandomApply( 420 | [transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)], 421 | p=0.8 422 | ), 423 | transforms.RandomGrayscale(p=0.2), 424 | ]) 425 | normalize = transforms.Compose([ 426 | transforms.ToTensor(), 427 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 428 | ]) 429 | 430 | # first global crop 431 | self.global_transfo1 = transforms.Compose([ 432 | transforms.RandomResizedCrop(224, scale=global_crops_scale, interpolation=Image.BICUBIC), 433 | flip_and_color_jitter, 434 | utils.GaussianBlur(1.0), 435 | normalize, 436 | ]) 437 | # second global crop 438 | self.global_transfo2 = transforms.Compose([ 439 | transforms.RandomResizedCrop(224, scale=global_crops_scale, interpolation=Image.BICUBIC), 440 | flip_and_color_jitter, 441 | utils.GaussianBlur(0.1), 442 | utils.Solarization(0.2), 443 | normalize, 444 | ]) 445 | # transformation for the local small crops 446 | self.local_crops_number = local_crops_number 447 | self.local_transfo = transforms.Compose([ 448 | transforms.RandomResizedCrop(96, scale=local_crops_scale, interpolation=Image.BICUBIC), 449 | flip_and_color_jitter, 450 | utils.GaussianBlur(p=0.5), 451 | normalize, 452 | ]) 453 | 454 | def __call__(self, image): 455 | crops = [] 456 | crops.append(self.global_transfo1(image)) 457 | crops.append(self.global_transfo2(image)) 458 | for _ in range(self.local_crops_number): 459 | 
crops.append(self.local_transfo(image)) 460 | return crops 461 | 462 | 463 | if __name__ == '__main__': 464 | parser = argparse.ArgumentParser('DINO', parents=[get_args_parser()]) 465 | args = parser.parse_args() 466 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 467 | train_dino(args) 468 | -------------------------------------------------------------------------------- /dino_model/run_with_submitit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | A script to run multinode training with submitit. 
16 | Almost copy-paste from https://github.com/facebookresearch/deit/blob/main/run_with_submitit.py 17 | """ 18 | import argparse 19 | import os 20 | import uuid 21 | from pathlib import Path 22 | 23 | import main_dino 24 | import submitit 25 | 26 | 27 | def parse_args(): 28 | parser = argparse.ArgumentParser("Submitit for DINO", parents=[main_dino.get_args_parser()]) 29 | parser.add_argument("--ngpus", default=8, type=int, help="Number of GPUs to request on each node") 30 | parser.add_argument("--nodes", default=2, type=int, help="Number of nodes to request") 31 | parser.add_argument("--timeout", default=2800, type=int, help="Duration of the job in minutes") 32 | 33 | parser.add_argument("--partition", default="learnfair", type=str, help="Partition where to submit") 34 | parser.add_argument("--use_volta32", action='store_true', help="Big models? Use this") 35 | parser.add_argument('--comment', default="", type=str, 36 | help='Comment to pass to scheduler, e.g. priority message') 37 | return parser.parse_args() 38 | 39 | 40 | def get_shared_folder() -> Path: 41 | user = os.getenv("USER") 42 | if Path("/checkpoint/").is_dir(): 43 | p = Path(f"/checkpoint/{user}/experiments") 44 | p.mkdir(exist_ok=True) 45 | return p 46 | raise RuntimeError("No shared folder available") 47 | 48 | 49 | def get_init_file(): 50 | # Init file must not exist, but its parent dir must exist.
51 | os.makedirs(str(get_shared_folder()), exist_ok=True) 52 | init_file = get_shared_folder() / f"{uuid.uuid4().hex}_init" 53 | if init_file.exists(): 54 | os.remove(str(init_file)) 55 | return init_file 56 | 57 | 58 | class Trainer(object): 59 | def __init__(self, args): 60 | self.args = args 61 | 62 | def __call__(self): 63 | import main_dino 64 | 65 | self._setup_gpu_args() 66 | main_dino.train_dino(self.args) 67 | 68 | def checkpoint(self): 69 | import os 70 | import submitit 71 | 72 | self.args.dist_url = get_init_file().as_uri() 73 | print("Requeuing ", self.args) 74 | empty_trainer = type(self)(self.args) 75 | return submitit.helpers.DelayedSubmission(empty_trainer) 76 | 77 | def _setup_gpu_args(self): 78 | import submitit 79 | from pathlib import Path 80 | 81 | job_env = submitit.JobEnvironment() 82 | self.args.output_dir = Path(str(self.args.output_dir).replace("%j", str(job_env.job_id))) 83 | self.args.gpu = job_env.local_rank 84 | self.args.rank = job_env.global_rank 85 | self.args.world_size = job_env.num_tasks 86 | print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") 87 | 88 | 89 | def main(): 90 | args = parse_args() 91 | if args.output_dir == "": 92 | args.output_dir = get_shared_folder() / "%j" 93 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 94 | executor = submitit.AutoExecutor(folder=args.output_dir, slurm_max_num_timeout=30) 95 | 96 | num_gpus_per_node = args.ngpus 97 | nodes = args.nodes 98 | timeout_min = args.timeout 99 | 100 | partition = args.partition 101 | kwargs = {} 102 | if args.use_volta32: 103 | kwargs['slurm_constraint'] = 'volta32gb' 104 | if args.comment: 105 | kwargs['slurm_comment'] = args.comment 106 | 107 | executor.update_parameters( 108 | mem_gb=40 * num_gpus_per_node, 109 | gpus_per_node=num_gpus_per_node, 110 | tasks_per_node=num_gpus_per_node, # one task per GPU 111 | cpus_per_task=10, 112 | nodes=nodes, 113 | timeout_min=timeout_min, # max is 60 * 72 114 | # Below are cluster 
dependent parameters 115 | slurm_partition=partition, 116 | slurm_signal_delay_s=120, 117 | **kwargs 118 | ) 119 | 120 | executor.update_parameters(name="dino") 121 | 122 | args.dist_url = get_init_file().as_uri() 123 | 124 | trainer = Trainer(args) 125 | job = executor.submit(trainer) 126 | 127 | print(f"Submitted job_id: {job.job_id}") 128 | print(f"Logs and checkpoints will be saved at: {args.output_dir}") 129 | 130 | 131 | if __name__ == "__main__": 132 | main() 133 | -------------------------------------------------------------------------------- /dino_model/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Misc functions. 16 | 17 | Mostly copy-paste from torchvision references or other public repos like DETR: 18 | https://github.com/facebookresearch/detr/blob/master/util/misc.py 19 | """ 20 | import os 21 | import sys 22 | import time 23 | import math 24 | import random 25 | import datetime 26 | import subprocess 27 | from collections import defaultdict, deque 28 | import argparse, warnings  # used by bool_flag() and _no_grad_trunc_normal_() 29 | import numpy as np 30 | import torch 31 | from torch import nn 32 | import torch.distributed as dist 33 | from PIL import ImageFilter, ImageOps 34 | 35 | 36 | class GaussianBlur(object): 37 | """ 38 | Apply Gaussian Blur to the PIL image.
39 | """ 40 | def __init__(self, p=0.5, radius_min=0.1, radius_max=2.): 41 | self.prob = p 42 | self.radius_min = radius_min 43 | self.radius_max = radius_max 44 | 45 | def __call__(self, img): 46 | do_it = random.random() <= self.prob 47 | if not do_it: 48 | return img 49 | 50 | return img.filter( 51 | ImageFilter.GaussianBlur( 52 | radius=random.uniform(self.radius_min, self.radius_max) 53 | ) 54 | ) 55 | 56 | 57 | class Solarization(object): 58 | """ 59 | Apply Solarization to the PIL image. 60 | """ 61 | def __init__(self, p): 62 | self.p = p 63 | 64 | def __call__(self, img): 65 | if random.random() < self.p: 66 | return ImageOps.solarize(img) 67 | else: 68 | return img 69 | 70 | 71 | def load_pretrained_weights(model, pretrained_weights, checkpoint_key, model_name, patch_size): 72 | if os.path.isfile(pretrained_weights): 73 | state_dict = torch.load(pretrained_weights, map_location="cpu") 74 | if checkpoint_key is not None and checkpoint_key in state_dict: 75 | print(f"Take key {checkpoint_key} in provided checkpoint dict") 76 | state_dict = state_dict[checkpoint_key] 77 | state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} 78 | msg = model.load_state_dict(state_dict, strict=False) 79 | print('Pretrained weights found at {} and loaded with msg: {}'.format(pretrained_weights, msg)) 80 | else: 81 | print("Please use the `--pretrained_weights` argument to indicate the path of the checkpoint to evaluate.") 82 | url = None 83 | if model_name == "deit_small" and patch_size == 16: 84 | url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth" 85 | elif model_name == "deit_small" and patch_size == 8: 86 | url = "dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth" 87 | elif model_name == "vit_base" and patch_size == 16: 88 | url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth" 89 | elif model_name == "vit_base" and patch_size == 8: 90 | url = "dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth" 91 | if url is not None: 92 | 
print("Since no pretrained weights have been provided, we load the reference pretrained DINO weights.") 93 | state_dict = torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/" + url) 94 | model.load_state_dict(state_dict, strict=True) 95 | else: 96 | print("There are no reference weights available for this model => We use random weights.") 97 | 98 | 99 | def clip_gradients(model, clip): 100 | norms = [] 101 | for name, p in model.named_parameters(): 102 | if p.grad is not None: 103 | param_norm = p.grad.data.norm(2) 104 | norms.append(param_norm.item()) 105 | clip_coef = clip / (param_norm + 1e-6) 106 | if clip_coef < 1: 107 | p.grad.data.mul_(clip_coef) 108 | return norms 109 | 110 | 111 | def cancel_gradients_last_layer(epoch, model, freeze_last_layer): 112 | if epoch >= freeze_last_layer: 113 | return 114 | for n, p in model.named_parameters(): 115 | if "last_layer" in n: 116 | p.grad = None 117 | 118 | 119 | def restart_from_checkpoint(ckp_path, run_variables=None, **kwargs): 120 | """ 121 | Restart from checkpoint 122 | """ 123 | if not os.path.isfile(ckp_path): 124 | return 125 | print("Found checkpoint at {}".format(ckp_path)) 126 | 127 | # open checkpoint file 128 | checkpoint = torch.load(ckp_path, map_location="cpu") 129 | 130 | # key is what to look for in the checkpoint file 131 | # value is the object to load 132 | # example: {'state_dict': model} 133 | for key, value in kwargs.items(): 134 | if key in checkpoint and value is not None: 135 | try: 136 | msg = value.load_state_dict(checkpoint[key], strict=False) 137 | print("=> loaded {} from checkpoint '{}' with msg {}".format(key, ckp_path, msg)) 138 | except TypeError: 139 | try: 140 | msg = value.load_state_dict(checkpoint[key]) 141 | print("=> loaded {} from checkpoint '{}'".format(key, ckp_path)) 142 | except ValueError: 143 | print("=> failed to load {} from checkpoint '{}'".format(key, ckp_path)) 144 | else: 145 | print("=> failed to load {} from checkpoint
'{}'".format(key, ckp_path)) 146 | 147 | # reload variables important for the run 148 | if run_variables is not None: 149 | for var_name in run_variables: 150 | if var_name in checkpoint: 151 | run_variables[var_name] = checkpoint[var_name] 152 | 153 | 154 | def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0): 155 | warmup_schedule = np.array([]) 156 | warmup_iters = warmup_epochs * niter_per_ep 157 | if warmup_epochs > 0: 158 | warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters) 159 | 160 | iters = np.arange(epochs * niter_per_ep - warmup_iters) 161 | schedule = np.array([final_value + 0.5 * (base_value - final_value) * (1 + \ 162 | math.cos(math.pi * i / (len(iters)))) for i in iters]) 163 | 164 | schedule = np.concatenate((warmup_schedule, schedule)) 165 | assert len(schedule) == epochs * niter_per_ep 166 | return schedule 167 | 168 | 169 | def bool_flag(s): 170 | """ 171 | Parse boolean arguments from the command line. 172 | """ 173 | FALSY_STRINGS = {"off", "false", "0"} 174 | TRUTHY_STRINGS = {"on", "true", "1"} 175 | if s.lower() in FALSY_STRINGS: 176 | return False 177 | elif s.lower() in TRUTHY_STRINGS: 178 | return True 179 | else: 180 | raise argparse.ArgumentTypeError("invalid value for a boolean flag") 181 | 182 | 183 | def fix_random_seeds(seed=31): 184 | """ 185 | Fix random seeds. 186 | """ 187 | torch.manual_seed(seed) 188 | torch.cuda.manual_seed_all(seed) 189 | np.random.seed(seed) 190 | 191 | 192 | class SmoothedValue(object): 193 | """Track a series of values and provide access to smoothed values over a 194 | window or the global series average.
195 | """ 196 | 197 | def __init__(self, window_size=20, fmt=None): 198 | if fmt is None: 199 | fmt = "{median:.6f} ({global_avg:.6f})" 200 | self.deque = deque(maxlen=window_size) 201 | self.total = 0.0 202 | self.count = 0 203 | self.fmt = fmt 204 | 205 | def update(self, value, n=1): 206 | self.deque.append(value) 207 | self.count += n 208 | self.total += value * n 209 | 210 | def synchronize_between_processes(self): 211 | """ 212 | Warning: does not synchronize the deque! 213 | """ 214 | if not is_dist_avail_and_initialized(): 215 | return 216 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 217 | dist.barrier() 218 | dist.all_reduce(t) 219 | t = t.tolist() 220 | self.count = int(t[0]) 221 | self.total = t[1] 222 | 223 | @property 224 | def median(self): 225 | d = torch.tensor(list(self.deque)) 226 | return d.median().item() 227 | 228 | @property 229 | def avg(self): 230 | d = torch.tensor(list(self.deque), dtype=torch.float32) 231 | return d.mean().item() 232 | 233 | @property 234 | def global_avg(self): 235 | return self.total / self.count 236 | 237 | @property 238 | def max(self): 239 | return max(self.deque) 240 | 241 | @property 242 | def value(self): 243 | return self.deque[-1] 244 | 245 | def __str__(self): 246 | return self.fmt.format( 247 | median=self.median, 248 | avg=self.avg, 249 | global_avg=self.global_avg, 250 | max=self.max, 251 | value=self.value) 252 | 253 | 254 | def reduce_dict(input_dict, average=True): 255 | """ 256 | Args: 257 | input_dict (dict): all the values will be reduced 258 | average (bool): whether to do average or sum 259 | Reduce the values in the dictionary from all processes so that all processes 260 | have the averaged results. Returns a dict with the same fields as 261 | input_dict, after reduction. 
262 | """ 263 | world_size = get_world_size() 264 | if world_size < 2: 265 | return input_dict 266 | with torch.no_grad(): 267 | names = [] 268 | values = [] 269 | # sort the keys so that they are consistent across processes 270 | for k in sorted(input_dict.keys()): 271 | names.append(k) 272 | values.append(input_dict[k]) 273 | values = torch.stack(values, dim=0) 274 | dist.all_reduce(values) 275 | if average: 276 | values /= world_size 277 | reduced_dict = {k: v for k, v in zip(names, values)} 278 | return reduced_dict 279 | 280 | 281 | class MetricLogger(object): 282 | def __init__(self, delimiter="\t"): 283 | self.meters = defaultdict(SmoothedValue) 284 | self.delimiter = delimiter 285 | 286 | def update(self, **kwargs): 287 | for k, v in kwargs.items(): 288 | if isinstance(v, torch.Tensor): 289 | v = v.item() 290 | assert isinstance(v, (float, int)) 291 | self.meters[k].update(v) 292 | 293 | def __getattr__(self, attr): 294 | if attr in self.meters: 295 | return self.meters[attr] 296 | if attr in self.__dict__: 297 | return self.__dict__[attr] 298 | raise AttributeError("'{}' object has no attribute '{}'".format( 299 | type(self).__name__, attr)) 300 | 301 | def __str__(self): 302 | loss_str = [] 303 | for name, meter in self.meters.items(): 304 | loss_str.append( 305 | "{}: {}".format(name, str(meter)) 306 | ) 307 | return self.delimiter.join(loss_str) 308 | 309 | def synchronize_between_processes(self): 310 | for meter in self.meters.values(): 311 | meter.synchronize_between_processes() 312 | 313 | def add_meter(self, name, meter): 314 | self.meters[name] = meter 315 | 316 | def log_every(self, iterable, print_freq, header=None): 317 | i = 0 318 | if not header: 319 | header = '' 320 | start_time = time.time() 321 | end = time.time() 322 | iter_time = SmoothedValue(fmt='{avg:.6f}') 323 | data_time = SmoothedValue(fmt='{avg:.6f}') 324 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 325 | if torch.cuda.is_available(): 326 | log_msg = 
self.delimiter.join([ 327 | header, 328 | '[{0' + space_fmt + '}/{1}]', 329 | 'eta: {eta}', 330 | '{meters}', 331 | 'time: {time}', 332 | 'data: {data}', 333 | 'max mem: {memory:.0f}' 334 | ]) 335 | else: 336 | log_msg = self.delimiter.join([ 337 | header, 338 | '[{0' + space_fmt + '}/{1}]', 339 | 'eta: {eta}', 340 | '{meters}', 341 | 'time: {time}', 342 | 'data: {data}' 343 | ]) 344 | MB = 1024.0 * 1024.0 345 | for obj in iterable: 346 | data_time.update(time.time() - end) 347 | yield obj 348 | iter_time.update(time.time() - end) 349 | if i % print_freq == 0 or i == len(iterable) - 1: 350 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 351 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 352 | if torch.cuda.is_available(): 353 | print(log_msg.format( 354 | i, len(iterable), eta=eta_string, 355 | meters=str(self), 356 | time=str(iter_time), data=str(data_time), 357 | memory=torch.cuda.max_memory_allocated() / MB)) 358 | else: 359 | print(log_msg.format( 360 | i, len(iterable), eta=eta_string, 361 | meters=str(self), 362 | time=str(iter_time), data=str(data_time))) 363 | i += 1 364 | end = time.time() 365 | total_time = time.time() - start_time 366 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 367 | print('{} Total time: {} ({:.6f} s / it)'.format( 368 | header, total_time_str, total_time / len(iterable))) 369 | 370 | 371 | def get_sha(): 372 | cwd = os.path.dirname(os.path.abspath(__file__)) 373 | 374 | def _run(command): 375 | return subprocess.check_output(command, cwd=cwd).decode('ascii').strip() 376 | sha = 'N/A' 377 | diff = "clean" 378 | branch = 'N/A' 379 | try: 380 | sha = _run(['git', 'rev-parse', 'HEAD']) 381 | subprocess.check_output(['git', 'diff'], cwd=cwd) 382 | diff = _run(['git', 'diff-index', 'HEAD']) 383 | diff = "has uncommitted changes" if diff else "clean" 384 | branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD']) 385 | except Exception: 386 | pass 387 | message = f"sha: {sha}, status:
{diff}, branch: {branch}" 388 | return message 389 | 390 | 391 | def is_dist_avail_and_initialized(): 392 | if not dist.is_available(): 393 | return False 394 | if not dist.is_initialized(): 395 | return False 396 | return True 397 | 398 | 399 | def get_world_size(): 400 | if not is_dist_avail_and_initialized(): 401 | return 1 402 | return dist.get_world_size() 403 | 404 | 405 | def get_rank(): 406 | if not is_dist_avail_and_initialized(): 407 | return 0 408 | return dist.get_rank() 409 | 410 | 411 | def is_main_process(): 412 | return get_rank() == 0 413 | 414 | 415 | def save_on_master(*args, **kwargs): 416 | if is_main_process(): 417 | torch.save(*args, **kwargs) 418 | 419 | 420 | def setup_for_distributed(is_master): 421 | """ 422 | This function disables printing when not in master process 423 | """ 424 | import builtins as __builtin__ 425 | builtin_print = __builtin__.print 426 | 427 | def print(*args, **kwargs): 428 | force = kwargs.pop('force', False) 429 | if is_master or force: 430 | builtin_print(*args, **kwargs) 431 | 432 | __builtin__.print = print 433 | 434 | 435 | def init_distributed_mode(args): 436 | # launched with torch.distributed.launch 437 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 438 | args.rank = int(os.environ["RANK"]) 439 | args.world_size = int(os.environ['WORLD_SIZE']) 440 | args.gpu = int(os.environ['LOCAL_RANK']) 441 | # launched with submitit on a slurm cluster 442 | elif 'SLURM_PROCID' in os.environ: 443 | args.rank = int(os.environ['SLURM_PROCID']) 444 | args.gpu = args.rank % torch.cuda.device_count() 445 | # launched naively with `python main_dino.py` 446 | # we manually add MASTER_ADDR and MASTER_PORT to env variables 447 | elif torch.cuda.is_available(): 448 | print('Will run the code on one GPU.') 449 | args.rank, args.gpu, args.world_size = 0, 0, 1 450 | os.environ['MASTER_ADDR'] = '127.0.0.1' 451 | os.environ['MASTER_PORT'] = '29500' 452 | else: 453 | print('Does not support training without GPU.') 454 | 
sys.exit(1) 455 | 456 | dist.init_process_group( 457 | backend="nccl", 458 | init_method=args.dist_url, 459 | world_size=args.world_size, 460 | rank=args.rank, 461 | ) 462 | 463 | torch.cuda.set_device(args.gpu) 464 | print('| distributed init (rank {}): {}'.format( 465 | args.rank, args.dist_url), flush=True) 466 | dist.barrier() 467 | setup_for_distributed(args.rank == 0) 468 | 469 | 470 | def accuracy(output, target, topk=(1,)): 471 | """Computes the accuracy over the k top predictions for the specified values of k""" 472 | maxk = max(topk) 473 | batch_size = target.size(0) 474 | _, pred = output.topk(maxk, 1, True, True) 475 | pred = pred.t() 476 | correct = pred.eq(target.reshape(1, -1).expand_as(pred)) 477 | return [correct[:k].reshape(-1).float().sum(0) * 100. / batch_size for k in topk] 478 | 479 | 480 | def _no_grad_trunc_normal_(tensor, mean, std, a, b): 481 | # Cut & paste from PyTorch official master until it's in a few official releases - RW 482 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 483 | def norm_cdf(x): 484 | # Computes standard normal cumulative distribution function 485 | return (1. + math.erf(x / math.sqrt(2.))) / 2. 486 | 487 | if (mean < a - 2 * std) or (mean > b + 2 * std): 488 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " 489 | "The distribution of values may be incorrect.", 490 | stacklevel=2) 491 | 492 | with torch.no_grad(): 493 | # Values are generated by using a truncated uniform distribution and 494 | # then using the inverse CDF for the normal distribution. 495 | # Get upper and lower cdf values 496 | l = norm_cdf((a - mean) / std) 497 | u = norm_cdf((b - mean) / std) 498 | 499 | # Uniformly fill tensor with values from [l, u], then translate to 500 | # [2l-1, 2u-1]. 
501 | tensor.uniform_(2 * l - 1, 2 * u - 1) 502 | 503 | # Use inverse cdf transform for normal distribution to get truncated 504 | # standard normal 505 | tensor.erfinv_() 506 | 507 | # Transform to proper mean, std 508 | tensor.mul_(std * math.sqrt(2.)) 509 | tensor.add_(mean) 510 | 511 | # Clamp to ensure it's in the proper range 512 | tensor.clamp_(min=a, max=b) 513 | return tensor 514 | 515 | 516 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): 517 | # type: (Tensor, float, float, float, float) -> Tensor 518 | return _no_grad_trunc_normal_(tensor, mean, std, a, b) 519 | 520 | 521 | class LARS(torch.optim.Optimizer): 522 | """ 523 | Almost copy-paste from https://github.com/facebookresearch/barlowtwins/blob/main/main.py 524 | """ 525 | def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, eta=0.001, 526 | weight_decay_filter=None, lars_adaptation_filter=None): 527 | defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, 528 | eta=eta, weight_decay_filter=weight_decay_filter, 529 | lars_adaptation_filter=lars_adaptation_filter) 530 | super().__init__(params, defaults) 531 | 532 | @torch.no_grad() 533 | def step(self): 534 | for g in self.param_groups: 535 | for p in g['params']: 536 | dp = p.grad 537 | 538 | if dp is None: 539 | continue 540 | 541 | if p.ndim != 1: 542 | dp = dp.add(p, alpha=g['weight_decay']) 543 | 544 | if p.ndim != 1: 545 | param_norm = torch.norm(p) 546 | update_norm = torch.norm(dp) 547 | one = torch.ones_like(param_norm) 548 | q = torch.where(param_norm > 0., 549 | torch.where(update_norm > 0, 550 | (g['eta'] * param_norm / update_norm), one), one) 551 | dp = dp.mul(q) 552 | 553 | param_state = self.state[p] 554 | if 'mu' not in param_state: 555 | param_state['mu'] = torch.zeros_like(p) 556 | mu = param_state['mu'] 557 | mu.mul_(g['momentum']).add_(dp) 558 | 559 | p.add_(mu, alpha=-g['lr']) 560 | 561 | 562 | class MultiCropWrapper(nn.Module): 563 | """ 564 | Perform forward pass separately on each resolution 
input. 565 | The inputs corresponding to a single resolution are clubbed and single 566 | forward is run on the same resolution inputs. Hence we do several 567 | forward passes = number of different resolutions used. We then 568 | concatenate all the output features. 569 | """ 570 | def __init__(self, backbone, head): 571 | super(MultiCropWrapper, self).__init__() 572 | backbone.fc = nn.Identity() 573 | self.backbone = backbone 574 | self.head = head 575 | 576 | def forward(self, x): 577 | # convert to list 578 | if not isinstance(x, list): 579 | x = [x] 580 | idx_crops = torch.cumsum(torch.unique_consecutive( 581 | torch.tensor([inp.shape[-1] for inp in x]), 582 | return_counts=True, 583 | )[1], 0) 584 | start_idx = 0 585 | for end_idx in idx_crops: 586 | _out = self.backbone(torch.cat(x[start_idx: end_idx])) 587 | if start_idx == 0: 588 | output = _out 589 | else: 590 | output = torch.cat((output, _out)) 591 | start_idx = end_idx 592 | # Run the head forward on the concatenated features. 
593 | return self.head(output) 594 | 595 | 596 | def get_params_groups(model): 597 | regularized = [] 598 | not_regularized = [] 599 | for name, param in model.named_parameters(): 600 | if not param.requires_grad: 601 | continue 602 | # we do not regularize biases nor Norm parameters 603 | if name.endswith(".bias") or len(param.shape) == 1: 604 | not_regularized.append(param) 605 | else: 606 | regularized.append(param) 607 | return [{'params': regularized}, {'params': not_regularized, 'weight_decay': 0.}] 608 | 609 | 610 | def has_batchnorms(model): 611 | bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm) 612 | for name, module in model.named_modules(): 613 | if isinstance(module, bn_types): 614 | return True 615 | return False 616 | -------------------------------------------------------------------------------- /dino_model/video_generation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | import os 15 | import glob 16 | import sys 17 | import argparse 18 | import cv2 19 | 20 | from tqdm import tqdm 21 | import matplotlib.pyplot as plt 22 | import torch 23 | import torch.nn as nn 24 | import torchvision 25 | from torchvision import transforms as pth_transforms 26 | import numpy as np 27 | from PIL import Image 28 | 29 | import utils 30 | import vision_transformer as vits 31 | 32 | 33 | FOURCC = { 34 | "mp4": cv2.VideoWriter_fourcc(*"MP4V"), 35 | "avi": cv2.VideoWriter_fourcc(*"XVID"), 36 | } 37 | DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 38 | 39 | 40 | class VideoGenerator: 41 | def __init__(self, args): 42 | self.args = args 43 | # self.model = None 44 | # Don't need to load model if you only want a video 45 | if not self.args.video_only: 46 | self.model = self.__load_model() 47 | 48 | def run(self): 49 | if self.args.input_path is None: 50 | print(f"Provided input path {self.args.input_path} is not valid.") 51 | sys.exit(1) 52 | else: 53 | if self.args.video_only: 54 | self._generate_video_from_images( 55 | self.args.input_path, self.args.output_path 56 | ) 57 | else: 58 | # If input path exists 59 | if os.path.exists(self.args.input_path): 60 | # If input is a video file 61 | if os.path.isfile(self.args.input_path): 62 | frames_folder = os.path.join(self.args.output_path, "frames") 63 | attention_folder = os.path.join( 64 | self.args.output_path, "attention" 65 | ) 66 | 67 | os.makedirs(frames_folder, exist_ok=True) 68 | os.makedirs(attention_folder, exist_ok=True) 69 | 70 | self._extract_frames_from_video( 71 | self.args.input_path, frames_folder 72 | ) 73 | 74 | self._inference( 75 | frames_folder, 76 | attention_folder, 77 | ) 78 | 79 | self._generate_video_from_images( 80 | attention_folder, self.args.output_path 81 | ) 82 | 83 | # If input is a folder of already extracted frames 84 | if os.path.isdir(self.args.input_path): 85 | attention_folder = os.path.join( 86 | self.args.output_path, "attention" 87 | ) 88 | 89 | os.makedirs(attention_folder, exist_ok=True) 90 | 91 | self._inference(self.args.input_path, attention_folder) 92 | 93 | self._generate_video_from_images( 94 | attention_folder, self.args.output_path 95 | ) 96 | 97 | # If input path doesn't exist 98 | else: 99 | print(f"Provided input path {self.args.input_path} doesn't exist.") 100 | sys.exit(1) 101 | 102 | def _extract_frames_from_video(self, inp: str, out: str): 103 | vidcap = cv2.VideoCapture(inp) 104 | self.args.fps = vidcap.get(cv2.CAP_PROP_FPS) 105 | 106 | print(f"Video: {inp} ({self.args.fps} fps)") 107 | print(f"Extracting frames to {out}") 108 | 109 | success, image = vidcap.read() 110 | count = 0 111 | while success: 112 | cv2.imwrite( 113 | os.path.join(out, f"frame-{count:04}.jpg"), 114 | image, 115 | ) 116 | success, image = vidcap.read() 117 | count += 1 118 | 119 | def _generate_video_from_images(self, inp: str, out: str): 120 | img_array = [] 121 | attention_images_list = sorted(glob.glob(os.path.join(inp, "attn-*.jpg"))) 122 | 123 | # Get size of the first image 124 | with open(attention_images_list[0], "rb") as f: 125 | img = Image.open(f) 126 | img = img.convert("RGB") 127 | size = (img.width, img.height) 128 | img_array.append(cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)) 129 | 130 | print(f"Generating video {size} to {out}") 131 | 132 | for filename in tqdm(attention_images_list[1:]): 133 | with open(filename, "rb") as f: 134 | img = Image.open(f) 135 | img = img.convert("RGB") 136 | img_array.append(cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)) 137 | 138 | out = cv2.VideoWriter( 139 | os.path.join(out, "video."
+ self.args.video_format), 140 | FOURCC[self.args.video_format], 141 | self.args.fps, 142 | size, 143 | ) 144 | 145 | for i in range(len(img_array)): 146 | out.write(img_array[i]) 147 | out.release() 148 | print("Done") 149 | 150 | def _inference(self, inp: str, out: str): 151 | print(f"Generating attention images to {out}") 152 | 153 | for img_path in tqdm(sorted(glob.glob(os.path.join(inp, "*.jpg")))): 154 | with open(img_path, "rb") as f: 155 | img = Image.open(f) 156 | img = img.convert("RGB") 157 | 158 | if self.args.resize is not None: 159 | transform = pth_transforms.Compose( 160 | [ 161 | pth_transforms.ToTensor(), 162 | pth_transforms.Resize(self.args.resize), 163 | pth_transforms.Normalize( 164 | (0.485, 0.456, 0.406), (0.229, 0.224, 0.225) 165 | ), 166 | ] 167 | ) 168 | else: 169 | transform = pth_transforms.Compose( 170 | [ 171 | pth_transforms.ToTensor(), 172 | pth_transforms.Normalize( 173 | (0.485, 0.456, 0.406), (0.229, 0.224, 0.225) 174 | ), 175 | ] 176 | ) 177 | 178 | img = transform(img) 179 | 180 | # make the image divisible by the patch size 181 | w, h = ( 182 | img.shape[1] - img.shape[1] % self.args.patch_size, 183 | img.shape[2] - img.shape[2] % self.args.patch_size, 184 | ) 185 | img = img[:, :w, :h].unsqueeze(0) 186 | 187 | w_featmap = img.shape[-2] // self.args.patch_size 188 | h_featmap = img.shape[-1] // self.args.patch_size 189 | 190 | attentions = self.model.forward_selfattention(img.to(DEVICE)) 191 | 192 | nh = attentions.shape[1] # number of head 193 | 194 | # we keep only the output patch attention 195 | attentions = attentions[0, :, 0, 1:].reshape(nh, -1) 196 | 197 | # we keep only a certain percentage of the mass 198 | val, idx = torch.sort(attentions) 199 | val /= torch.sum(val, dim=1, keepdim=True) 200 | cumval = torch.cumsum(val, dim=1) 201 | th_attn = cumval > (1 - self.args.threshold) 202 | idx2 = torch.argsort(idx) 203 | for head in range(nh): 204 | th_attn[head] = th_attn[head][idx2[head]] 205 | th_attn = 
th_attn.reshape(nh, w_featmap, h_featmap).float() 206 | # interpolate 207 | th_attn = ( 208 | nn.functional.interpolate( 209 | th_attn.unsqueeze(0), 210 | scale_factor=self.args.patch_size, 211 | mode="nearest", 212 | )[0] 213 | .cpu() 214 | .numpy() 215 | ) 216 | 217 | attentions = attentions.reshape(nh, w_featmap, h_featmap) 218 | attentions = ( 219 | nn.functional.interpolate( 220 | attentions.unsqueeze(0), 221 | scale_factor=self.args.patch_size, 222 | mode="nearest", 223 | )[0] 224 | .cpu() 225 | .numpy() 226 | ) 227 | 228 | # save attentions heatmaps 229 | fname = os.path.join(out, "attn-" + os.path.basename(img_path)) 230 | plt.imsave( 231 | fname=fname, 232 | arr=sum( 233 | attentions[i] * 1 / attentions.shape[0] 234 | for i in range(attentions.shape[0]) 235 | ), 236 | cmap="inferno", 237 | format="jpg", 238 | ) 239 | 240 | def __load_model(self): 241 | # build model 242 | model = vits.__dict__[self.args.arch]( 243 | patch_size=self.args.patch_size, num_classes=0 244 | ) 245 | for p in model.parameters(): 246 | p.requires_grad = False 247 | model.eval() 248 | model.to(DEVICE) 249 | 250 | if os.path.isfile(self.args.pretrained_weights): 251 | state_dict = torch.load(self.args.pretrained_weights, map_location="cpu") 252 | if ( 253 | self.args.checkpoint_key is not None 254 | and self.args.checkpoint_key in state_dict 255 | ): 256 | print( 257 | f"Take key {self.args.checkpoint_key} in provided checkpoint dict" 258 | ) 259 | state_dict = state_dict[self.args.checkpoint_key] 260 | state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} 261 | msg = model.load_state_dict(state_dict, strict=False) 262 | print( 263 | "Pretrained weights found at {} and loaded with msg: {}".format( 264 | self.args.pretrained_weights, msg 265 | ) 266 | ) 267 | else: 268 | print( 269 | "Please use the `--pretrained_weights` argument to indicate the path of the checkpoint to evaluate." 
270 | ) 271 | url = None 272 | if self.args.arch == "deit_small" and self.args.patch_size == 16: 273 | url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth" 274 | elif self.args.arch == "deit_small" and self.args.patch_size == 8: 275 | url = "dino_deitsmall8_300ep_pretrain/dino_deitsmall8_300ep_pretrain.pth" # model used for visualizations in our paper 276 | elif self.args.arch == "vit_base" and self.args.patch_size == 16: 277 | url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth" 278 | elif self.args.arch == "vit_base" and self.args.patch_size == 8: 279 | url = "dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth" 280 | if url is not None: 281 | print( 282 | "Since no pretrained weights have been provided, we load the reference pretrained DINO weights." 283 | ) 284 | state_dict = torch.hub.load_state_dict_from_url( 285 | url="https://dl.fbaipublicfiles.com/dino/" + url 286 | ) 287 | model.load_state_dict(state_dict, strict=True) 288 | else: 289 | print( 290 | "There are no reference weights available for this model => we use random weights." 291 | ) 292 | return model 293 | 294 | 295 | def parse_args(): 296 | parser = argparse.ArgumentParser("Generate self-attention video") 297 | parser.add_argument( 298 | "--arch", 299 | default="deit_small", 300 | type=str, 301 | choices=["deit_tiny", "deit_small", "vit_base"], 302 | help="Architecture (support only ViT atm).", 303 | ) 304 | parser.add_argument( 305 | "--patch_size", default=8, type=int, help="Patch resolution of the model."
306 | ) 307 | parser.add_argument( 308 | "--pretrained_weights", 309 | default="", 310 | type=str, 311 | help="Path to pretrained weights to load.", 312 | ) 313 | parser.add_argument( 314 | "--checkpoint_key", 315 | default="teacher", 316 | type=str, 317 | help='Key to use in the checkpoint (example: "teacher")', 318 | ) 319 | parser.add_argument( 320 | "--input_path", 321 | required=True, 322 | type=str, 323 | help="""Path to a video file if you want to extract frames, 324 | or to a folder of images already extracted by yourself, 325 | or to a folder of attention images.""", 326 | ) 327 | parser.add_argument( 328 | "--output_path", 329 | default="./", 330 | type=str, 331 | help="""Path to store a folder of frames and / or a folder of attention images, 332 | and / or a final video. Defaults to the current directory.""", 333 | ) 334 | parser.add_argument( 335 | "--threshold", 336 | type=float, 337 | default=0.6, 338 | help="""We visualize masks 339 | obtained by thresholding the self-attention maps to keep xx percent of the mass.""", 340 | ) 341 | parser.add_argument( 342 | "--resize", 343 | default=None, 344 | type=int, 345 | nargs="+", 346 | help="""Apply a resize transformation to input image(s). Use it in case of OOM errors. 347 | Usage (single value or H W): --resize 512, --resize 720 1280""", 348 | ) 349 | parser.add_argument( 350 | "--video_only", 351 | action="store_true", 352 | help="""Use this flag if you only want to generate a video and not all attention images. 353 | If used, --input_path must be set to the folder of attention images. Ex: ./attention/""", 354 | ) 355 | parser.add_argument( 356 | "--fps", 357 | default=30.0, 358 | type=float, 359 | help="FPS of input / output video.
Automatically set if you extract frames from a video.", 360 | ) 361 | parser.add_argument( 362 | "--video_format", 363 | default="mp4", 364 | type=str, 365 | choices=["mp4", "avi"], 366 | help="Format of generated video (mp4 or avi).", 367 | ) 368 | 369 | return parser.parse_args() 370 | 371 | 372 | if __name__ == "__main__": 373 | args = parse_args() 374 | 375 | vg = VideoGenerator(args) 376 | vg.run() 377 | -------------------------------------------------------------------------------- /dino_model/vision_transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Mostly copy-paste from timm library. 16 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py 17 | """ 18 | import math 19 | from functools import partial 20 | 21 | import torch 22 | import torch.nn as nn 23 | 24 | from dino_model.utils import trunc_normal_ 25 | 26 | 27 | def drop_path(x, drop_prob: float = 0., training: bool = False): 28 | if drop_prob == 0. 
or not training: 29 | return x 30 | keep_prob = 1 - drop_prob 31 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 32 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) 33 | random_tensor.floor_() # binarize 34 | output = x.div(keep_prob) * random_tensor 35 | return output 36 | 37 | 38 | class DropPath(nn.Module): 39 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 40 | """ 41 | def __init__(self, drop_prob=None): 42 | super(DropPath, self).__init__() 43 | self.drop_prob = drop_prob 44 | 45 | def forward(self, x): 46 | return drop_path(x, self.drop_prob, self.training) 47 | 48 | 49 | class Mlp(nn.Module): 50 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 51 | super().__init__() 52 | out_features = out_features or in_features 53 | hidden_features = hidden_features or in_features 54 | self.fc1 = nn.Linear(in_features, hidden_features) 55 | self.act = act_layer() 56 | self.fc2 = nn.Linear(hidden_features, out_features) 57 | self.drop = nn.Dropout(drop) 58 | 59 | def forward(self, x): 60 | x = self.fc1(x) 61 | x = self.act(x) 62 | x = self.drop(x) 63 | x = self.fc2(x) 64 | x = self.drop(x) 65 | return x 66 | 67 | 68 | class Attention(nn.Module): 69 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 70 | super().__init__() 71 | self.num_heads = num_heads 72 | head_dim = dim // num_heads 73 | self.scale = qk_scale or head_dim ** -0.5 74 | 75 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 76 | self.attn_drop = nn.Dropout(attn_drop) 77 | self.proj = nn.Linear(dim, dim) 78 | self.proj_drop = nn.Dropout(proj_drop) 79 | 80 | def forward(self, x): 81 | B, N, C = x.shape 82 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 83 | q, k, v = qkv[0], qkv[1], qkv[2] 84 | 85 | attn = (q @ k.transpose(-2, -1)) * 
self.scale 86 | attn = attn.softmax(dim=-1) 87 | attn = self.attn_drop(attn) 88 | 89 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 90 | x = self.proj(x) 91 | x = self.proj_drop(x) 92 | return x, attn 93 | 94 | 95 | class Block(nn.Module): 96 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 97 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 98 | super().__init__() 99 | self.norm1 = norm_layer(dim) 100 | self.attn = Attention( 101 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 102 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 103 | self.norm2 = norm_layer(dim) 104 | mlp_hidden_dim = int(dim * mlp_ratio) 105 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 106 | 107 | def forward(self, x, return_attention=False): 108 | y, attn = self.attn(self.norm1(x)) 109 | if return_attention: 110 | return attn 111 | x = x + self.drop_path(y) 112 | x = x + self.drop_path(self.mlp(self.norm2(x))) 113 | return x 114 | 115 | 116 | class PatchEmbed(nn.Module): 117 | """ Image to Patch Embedding 118 | """ 119 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): 120 | super().__init__() 121 | num_patches = (img_size // patch_size) * (img_size // patch_size) 122 | self.img_size = img_size 123 | self.patch_size = patch_size 124 | self.num_patches = num_patches 125 | 126 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 127 | 128 | def forward(self, x): 129 | B, C, H, W = x.shape 130 | x = self.proj(x).flatten(2).transpose(1, 2) 131 | return x 132 | 133 | 134 | class VisionTransformer(nn.Module): 135 | """ Vision Transformer """ 136 | def __init__(self, img_size=[224], patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12, 137 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., 
attn_drop_rate=0., 138 | drop_path_rate=0., norm_layer=nn.LayerNorm, **kwargs): 139 | super().__init__() 140 | self.num_features = self.embed_dim = embed_dim 141 | 142 | self.patch_embed = PatchEmbed( 143 | img_size=img_size[0], patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 144 | num_patches = self.patch_embed.num_patches 145 | 146 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 147 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) 148 | self.pos_drop = nn.Dropout(p=drop_rate) 149 | 150 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 151 | self.blocks = nn.ModuleList([ 152 | Block( 153 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 154 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) 155 | for i in range(depth)]) 156 | self.norm = norm_layer(embed_dim) 157 | 158 | # Classifier head 159 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() 160 | 161 | trunc_normal_(self.pos_embed, std=.02) 162 | trunc_normal_(self.cls_token, std=.02) 163 | self.apply(self._init_weights) 164 | 165 | def _init_weights(self, m): 166 | if isinstance(m, nn.Linear): 167 | trunc_normal_(m.weight, std=.02) 168 | if isinstance(m, nn.Linear) and m.bias is not None: 169 | nn.init.constant_(m.bias, 0) 170 | elif isinstance(m, nn.LayerNorm): 171 | nn.init.constant_(m.bias, 0) 172 | nn.init.constant_(m.weight, 1.0) 173 | 174 | def forward(self, x): 175 | # convert to list 176 | if not isinstance(x, list): 177 | x = [x] 178 | # Perform forward pass separately on each resolution input. 179 | # The inputs corresponding to a single resolution are clubbed and single 180 | # forward is run on the same resolution inputs. Hence we do several 181 | # forward passes = number of different resolutions used. We then 182 | # concatenate all the output features. 
183 | idx_crops = torch.cumsum(torch.unique_consecutive( 184 | torch.tensor([inp.shape[-1] for inp in x]), 185 | return_counts=True, 186 | )[1], 0) 187 | start_idx = 0 188 | for end_idx in idx_crops: 189 | _out = self.forward_features(torch.cat(x[start_idx: end_idx])) 190 | if start_idx == 0: 191 | output = _out 192 | else: 193 | output = torch.cat((output, _out)) 194 | start_idx = end_idx 195 | # Run the head forward on the concatenated features. 196 | return self.head(output) 197 | 198 | def forward_features(self, x): 199 | B = x.shape[0] 200 | x = self.patch_embed(x) 201 | 202 | cls_tokens = self.cls_token.expand(B, -1, -1) 203 | x = torch.cat((cls_tokens, x), dim=1) 204 | pos_embed = self.interpolate_pos_encoding(x, self.pos_embed) 205 | x = x + pos_embed 206 | x = self.pos_drop(x) 207 | 208 | for blk in self.blocks: 209 | x = blk(x) 210 | if self.norm is not None: 211 | x = self.norm(x) 212 | 213 | return x[:, 0] 214 | 215 | def interpolate_pos_encoding(self, x, pos_embed): 216 | npatch = x.shape[1] - 1 217 | N = pos_embed.shape[1] - 1 218 | if npatch == N: 219 | return pos_embed 220 | class_emb = pos_embed[:, 0] 221 | pos_embed = pos_embed[:, 1:] 222 | dim = x.shape[-1] 223 | pos_embed = nn.functional.interpolate( 224 | pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), 225 | scale_factor=math.sqrt(npatch / N), 226 | mode='bicubic', 227 | ) 228 | pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) 229 | return torch.cat((class_emb.unsqueeze(0), pos_embed), dim=1) 230 | 231 | def forward_selfattention(self, x): 232 | B, nc, w, h = x.shape 233 | N = self.pos_embed.shape[1] - 1 234 | x = self.patch_embed(x) 235 | 236 | # interpolate patch embeddings 237 | dim = x.shape[-1] 238 | w0 = w // self.patch_embed.patch_size 239 | h0 = h // self.patch_embed.patch_size 240 | class_pos_embed = self.pos_embed[:, 0] 241 | patch_pos_embed = self.pos_embed[:, 1:] 242 | patch_pos_embed = nn.functional.interpolate( 243 | 
patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), 244 | scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), 245 | mode='bicubic', 246 | ) 247 | # sometimes there is a floating point error in the interpolation and so 248 | # we need to pad the patch positional encoding. 249 | if w0 != patch_pos_embed.shape[-2]: 250 | helper = torch.zeros(h0)[None, None, None, :].repeat(1, dim, w0 - patch_pos_embed.shape[-2], 1).to(x.device) 251 | patch_pos_embed = torch.cat((patch_pos_embed, helper), dim=-2) 252 | if h0 != patch_pos_embed.shape[-1]: 253 | helper = torch.zeros(w0)[None, None, :, None].repeat(1, dim, 1, h0 - patch_pos_embed.shape[-1]).to(x.device) 254 | patch_pos_embed = torch.cat((patch_pos_embed, helper), dim=-1) 255 | 256 | patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) 257 | pos_embed = torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) 258 | cls_tokens = self.cls_token.expand(B, -1, -1) 259 | x = torch.cat((cls_tokens, x), dim=1) 260 | x = x + pos_embed 261 | x = self.pos_drop(x) 262 | 263 | for i, blk in enumerate(self.blocks): 264 | if i < len(self.blocks) - 1: 265 | x = blk(x) 266 | else: 267 | return blk(x, return_attention=True) 268 | 269 | def forward_return_n_last_blocks(self, x, n=1, return_patch_avgpool=False): 270 | B = x.shape[0] 271 | x = self.patch_embed(x) 272 | 273 | cls_tokens = self.cls_token.expand(B, -1, -1) 274 | x = torch.cat((cls_tokens, x), dim=1) 275 | pos_embed = self.interpolate_pos_encoding(x, self.pos_embed) 276 | x = x + pos_embed 277 | x = self.pos_drop(x) 278 | 279 | # we will return the [CLS] tokens from the `n` last blocks 280 | output = [] 281 | for i, blk in enumerate(self.blocks): 282 | x = blk(x) 283 | if len(self.blocks) - i <= n: 284 | output.append(self.norm(x)[:, 0]) 285 | if return_patch_avgpool: 286 | x = self.norm(x) 287 | # In addition to the [CLS] tokens from the `n` last blocks, we also return 288 | # the patch tokens from the last 
block. This is useful for linear eval. 289 | output.append(torch.mean(x[:, 1:], dim=1)) 290 | return torch.cat(output, dim=-1) 291 | 292 | 293 | def deit_tiny(patch_size=16, **kwargs): 294 | model = VisionTransformer( 295 | patch_size=patch_size, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, 296 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 297 | return model 298 | 299 | 300 | def deit_small(patch_size=16, **kwargs): 301 | model = VisionTransformer( 302 | patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, 303 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 304 | return model 305 | 306 | 307 | def vit_base(patch_size=16, **kwargs): 308 | model = VisionTransformer( 309 | patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, 310 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 311 | return model 312 | 313 | 314 | class DINOHead(nn.Module): 315 | def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048, bottleneck_dim=256): 316 | super().__init__() 317 | nlayers = max(nlayers, 1) 318 | if nlayers == 1: 319 | self.mlp = nn.Linear(in_dim, bottleneck_dim) 320 | else: 321 | layers = [nn.Linear(in_dim, hidden_dim)] 322 | if use_bn: 323 | layers.append(nn.BatchNorm1d(hidden_dim)) 324 | layers.append(nn.GELU()) 325 | for _ in range(nlayers - 2): 326 | layers.append(nn.Linear(hidden_dim, hidden_dim)) 327 | if use_bn: 328 | layers.append(nn.BatchNorm1d(hidden_dim)) 329 | layers.append(nn.GELU()) 330 | layers.append(nn.Linear(hidden_dim, bottleneck_dim)) 331 | self.mlp = nn.Sequential(*layers) 332 | self.apply(self._init_weights) 333 | self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False)) 334 | self.last_layer.weight_g.data.fill_(1) 335 | if norm_last_layer: 336 | self.last_layer.weight_g.requires_grad = False 337 | 338 | def _init_weights(self, m): 339 | if isinstance(m, nn.Linear): 340 
| trunc_normal_(m.weight, std=.02) 341 | if isinstance(m, nn.Linear) and m.bias is not None: 342 | nn.init.constant_(m.bias, 0) 343 | 344 | def forward(self, x): 345 | x = self.mlp(x) 346 | x = nn.functional.normalize(x, dim=-1, p=2) 347 | x = self.last_layer(x) 348 | return x 349 | -------------------------------------------------------------------------------- /dino_model/visualize_attention.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import os 15 | import sys 16 | import argparse 17 | import cv2 18 | import random 19 | import colorsys 20 | import requests 21 | from io import BytesIO 22 | 23 | import skimage.io 24 | from skimage.measure import find_contours 25 | import matplotlib.pyplot as plt 26 | from matplotlib.patches import Polygon 27 | import torch 28 | import torch.nn as nn 29 | import torchvision 30 | from torchvision import transforms as pth_transforms 31 | import numpy as np 32 | from PIL import Image 33 | 34 | import utils 35 | import vision_transformer as vits 36 | 37 | 38 | def apply_mask(image, mask, color, alpha=0.5): 39 | for c in range(3): 40 | image[:, :, c] = image[:, :, c] * (1 - alpha * mask) + alpha * mask * color[c] * 255 41 | return image 42 | 43 | 44 | def random_colors(N, bright=True): 45 | """ 46 | Generate random colors. 
47 | """ 48 | brightness = 1.0 if bright else 0.7 49 | hsv = [(i / N, 1, brightness) for i in range(N)] 50 | colors = list(map(lambda c: colorsys.hsv_to_rgb(*c), hsv)) 51 | random.shuffle(colors) 52 | return colors 53 | 54 | 55 | def display_instances(image, mask, fname="test", figsize=(5, 5), blur=False, contour=True, alpha=0.5): 56 | fig = plt.figure(figsize=figsize, frameon=False) 57 | ax = plt.Axes(fig, [0., 0., 1., 1.]) 58 | ax.set_axis_off() 59 | fig.add_axes(ax) 60 | ax = plt.gca() 61 | 62 | N = 1 63 | mask = mask[None, :, :] 64 | # Generate random colors 65 | colors = random_colors(N) 66 | 67 | # Show area outside image boundaries. 68 | height, width = image.shape[:2] 69 | margin = 0 70 | ax.set_ylim(height + margin, -margin) 71 | ax.set_xlim(-margin, width + margin) 72 | ax.axis('off') 73 | masked_image = image.astype(np.uint32).copy() 74 | for i in range(N): 75 | color = colors[i] 76 | _mask = mask[i] 77 | if blur: 78 | _mask = cv2.blur(_mask,(10,10)) 79 | # Mask 80 | masked_image = apply_mask(masked_image, _mask, color, alpha) 81 | # Mask Polygon 82 | # Pad to ensure proper polygons for masks that touch image edges. 
83 | if contour: 84 | padded_mask = np.zeros((_mask.shape[0] + 2, _mask.shape[1] + 2)) 85 | padded_mask[1:-1, 1:-1] = _mask 86 | contours = find_contours(padded_mask, 0.5) 87 | for verts in contours: 88 | # Subtract the padding and flip (y, x) to (x, y) 89 | verts = np.fliplr(verts) - 1 90 | p = Polygon(verts, facecolor="none", edgecolor=color) 91 | ax.add_patch(p) 92 | ax.imshow(masked_image.astype(np.uint8), aspect='auto') 93 | fig.savefig(fname) 94 | print(f"{fname} saved.") 95 | return 96 | 97 | 98 | if __name__ == '__main__': 99 | parser = argparse.ArgumentParser('Visualize Self-Attention maps') 100 | parser.add_argument('--arch', default='deit_small', type=str, 101 | choices=['deit_tiny', 'deit_small', 'vit_base'], help='Architecture (support only ViT atm).') 102 | parser.add_argument('--patch_size', default=8, type=int, help='Patch resolution of the model.') 103 | parser.add_argument('--pretrained_weights', default='', type=str, 104 | help="Path to pretrained weights to load.") 105 | parser.add_argument("--checkpoint_key", default="teacher", type=str, 106 | help='Key to use in the checkpoint (example: "teacher")') 107 | parser.add_argument("--image_path", default=None, type=str, help="Path of the image to load.") 108 | parser.add_argument('--output_dir', default='.', help='Path where to save visualizations.') 109 | parser.add_argument("--threshold", type=float, default=0.6, help="""We visualize masks 110 | obtained by thresholding the self-attention maps to keep xx%% of the mass.""") 111 | args = parser.parse_args() 112 | 113 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 114 | # build model 115 | model = vits.__dict__[args.arch](patch_size=args.patch_size, num_classes=0) 116 | for p in model.parameters(): 117 | p.requires_grad = False 118 | model.eval() 119 | model.to(device) 120 | if os.path.isfile(args.pretrained_weights): 121 | state_dict = torch.load(args.pretrained_weights, map_location="cpu") 122 | if
args.checkpoint_key is not None and args.checkpoint_key in state_dict: 123 | print(f"Take key {args.checkpoint_key} in provided checkpoint dict") 124 | state_dict = state_dict[args.checkpoint_key] 125 | state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} 126 | msg = model.load_state_dict(state_dict, strict=False) 127 | print('Pretrained weights found at {} and loaded with msg: {}'.format(args.pretrained_weights, msg)) 128 | else: 129 | print("Please use the `--pretrained_weights` argument to indicate the path of the checkpoint to evaluate.") 130 | url = None 131 | if args.arch == "deit_small" and args.patch_size == 16: 132 | url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth" 133 | elif args.arch == "deit_small" and args.patch_size == 8: 134 | url = "dino_deitsmall8_300ep_pretrain/dino_deitsmall8_300ep_pretrain.pth" # model used for visualizations in our paper 135 | elif args.arch == "vit_base" and args.patch_size == 16: 136 | url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth" 137 | elif args.arch == "vit_base" and args.patch_size == 8: 138 | url = "dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth" 139 | if url is not None: 140 | print("Since no pretrained weights have been provided, we load the reference pretrained DINO weights.") 141 | state_dict = torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/" + url) 142 | model.load_state_dict(state_dict, strict=True) 143 | else: 144 | print("There are no reference weights available for this model => we use random weights.") 145 | 146 | # open image 147 | if args.image_path is None: 148 | # user has not specified any image - we use our own image 149 | print("Please use the `--image_path` argument to indicate the path of the image you wish to visualize.") 150 | print("Since no image path has been provided, we take the first image in our paper.") 151 | response = requests.get("https://dl.fbaipublicfiles.com/dino/img.png") 152 | img =
Image.open(BytesIO(response.content)) 153 | img = img.convert('RGB') 154 | elif os.path.isfile(args.image_path): 155 | with open(args.image_path, 'rb') as f: 156 | img = Image.open(f) 157 | img = img.convert('RGB') 158 | else: 159 | print(f"Provided image path {args.image_path} is not valid.") 160 | sys.exit(1) 161 | transform = pth_transforms.Compose([ 162 | pth_transforms.ToTensor(), 163 | pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 164 | ]) 165 | img = transform(img) 166 | 167 | # make the image divisible by the patch size 168 | w, h = img.shape[1] - img.shape[1] % args.patch_size, img.shape[2] - img.shape[2] % args.patch_size 169 | img = img[:, :w, :h].unsqueeze(0) 170 | 171 | w_featmap = img.shape[-2] // args.patch_size 172 | h_featmap = img.shape[-1] // args.patch_size 173 | 174 | attentions = model.forward_selfattention(img.to(device)) 175 | 176 | nh = attentions.shape[1] # number of heads 177 | 178 | # we keep only the output patch attention 179 | attentions = attentions[0, :, 0, 1:].reshape(nh, -1) 180 | 181 | # we keep only a certain percentage of the mass 182 | val, idx = torch.sort(attentions) 183 | val /= torch.sum(val, dim=1, keepdim=True) 184 | cumval = torch.cumsum(val, dim=1) 185 | th_attn = cumval > (1 - args.threshold) 186 | idx2 = torch.argsort(idx) 187 | for head in range(nh): 188 | th_attn[head] = th_attn[head][idx2[head]] 189 | th_attn = th_attn.reshape(nh, w_featmap, h_featmap).float() 190 | # interpolate 191 | th_attn = nn.functional.interpolate(th_attn.unsqueeze(0), scale_factor=args.patch_size, mode="nearest")[0].cpu().numpy() 192 | 193 | attentions = attentions.reshape(nh, w_featmap, h_featmap) 194 | attentions = nn.functional.interpolate(attentions.unsqueeze(0), scale_factor=args.patch_size, mode="nearest")[0].cpu().numpy() 195 | 196 | # save attention heatmaps 197 | os.makedirs(args.output_dir, exist_ok=True) 198 | torchvision.utils.save_image(torchvision.utils.make_grid(img, normalize=True,
scale_each=True), os.path.join(args.output_dir, "img.png")) 199 | for j in range(nh): 200 | fname = os.path.join(args.output_dir, "attn-head" + str(j) + ".png") 201 | plt.imsave(fname=fname, arr=attentions[j], format='png') 202 | print(f"{fname} saved.") 203 | 204 | image = skimage.io.imread(os.path.join(args.output_dir, "img.png")) 205 | for j in range(nh): 206 | display_instances(image, th_attn[j], fname=os.path.join(args.output_dir, "mask_th" + str(args.threshold) + "_head" + str(j) +".png"), blur=False) 207 | -------------------------------------------------------------------------------- /images/kitti.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nerminsamet/seedal/52fdeeb42cf83b4b65814ae8b81efb2f7d61bb07/images/kitti.png -------------------------------------------------------------------------------- /images/s3dis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nerminsamet/seedal/52fdeeb42cf83b4b65814ae8b81efb2f7d61bb07/images/s3dis.png -------------------------------------------------------------------------------- /images/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nerminsamet/seedal/52fdeeb42cf83b4b65814ae8b81efb2f7d61bb07/images/teaser.png -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from datasets.s3dis import S3DIS 2 | from datasets.semantic_kitti import SK 3 | from optimization import apply_optimization 4 | from importlib import import_module 5 | import argparse 6 | 7 | def main(args): 8 | 9 | config = import_module(f'datasets.{args.dataset}_config') 10 | 11 | if args.dataset == 's3dis': 12 | dataset_instance = S3DIS(config) 13 | elif args.dataset == 'semantic_kitti': 14 |
dataset_instance = SK(config) 15 | else: 16 | raise NotImplementedError("Only the s3dis and semantic_kitti datasets are supported!") 17 | 18 | dataset_instance.extract_feature_vecs() 19 | dataset_instance.extract_data_stats() 20 | dataset_instance.extract_scene_clusters() 21 | dataset_instance.extract_scene_attributes() 22 | dataset_instance.extract_pairwise_scene_attributes() 23 | 24 | all_pairs, data_stats, pair_scores, reduction_size, area_threshold = dataset_instance.prepare_data_for_optimization() 25 | selected_scenes = apply_optimization(all_pairs, data_stats, pair_scores, reduction_size, area_threshold) 26 | 27 | dataset_instance.create_initial_set(selected_scenes) 28 | 29 | if __name__ == '__main__': 30 | 31 | parser = argparse.ArgumentParser(description='SeedAL framework.') 32 | parser.add_argument('-d', '--dataset', choices=['s3dis', 'semantic_kitti'], default='semantic_kitti') 33 | 34 | args = parser.parse_args() 35 | main(args) 36 | -------------------------------------------------------------------------------- /optimization.py: -------------------------------------------------------------------------------- 1 | from ortools.linear_solver import pywraplp 2 | import numpy as np 3 | from ortools.sat.python import cp_model 4 | 5 | 6 | def apply_optimization(all_pairs, data_stats, pair_scores, reduction_size, area_threshold): 7 | 8 | selected_indexes = lineer_solver(pair_scores, edge_number=reduction_size) 9 | ls_scans = [all_pairs[ind] for ind in selected_indexes] 10 | 11 | unique_scans = [] 12 | for scan_t in ls_scans: 13 | if scan_t[0] not in unique_scans: 14 | unique_scans.append(scan_t[0]) 15 | if scan_t[1] not in unique_scans: 16 | unique_scans.append(scan_t[1]) 17 | 18 | node_len = len(unique_scans) 19 | sub_graph_areas = [] 20 | sub_graph_scores = [] 21 | sub_graph_pairs = [] 22 | sub_aff_mat = np.zeros((node_len,node_len)) 23 | for ind, us in enumerate(unique_scans): 24 | for ind2 in range(ind+1,len(unique_scans)): 25 | if
(unique_scans[ind], unique_scans[ind2]) in all_pairs: 26 | ii = all_pairs.index((unique_scans[ind], unique_scans[ind2])) 27 | else: 28 | ii = all_pairs.index((unique_scans[ind2], unique_scans[ind])) 29 | sc = pair_scores[ii] 30 | sub_graph_scores.append(sc) 31 | sub_aff_mat[ind, ind2] = sc 32 | sub_graph_pairs.append(all_pairs[ii]) 33 | 34 | for us in unique_scans: 35 | sub_graph_areas.append(data_stats[us]['area']) 36 | 37 | selected_edge_indexes, selected_node_indexes =\ 38 | lineer_solver_by_node_weight(sub_aff_mat, sub_graph_areas, area_threshold=area_threshold) 39 | 40 | total_area = 0 41 | final_scenes = [] 42 | for i in selected_node_indexes: 43 | print(unique_scans[i]) 44 | final_scenes.append(unique_scans[i]) 45 | total_area += data_stats[unique_scans[i]]['area'] 46 | 47 | print(final_scenes) 48 | print(total_area) 49 | 50 | return final_scenes 51 | 52 | 53 | def create_reduction_data_model(edges, edge_number): 54 | data = {} 55 | edges = [np.float64(i) for i in edges] 56 | num_var = len(edges) 57 | num_of_xs = [1.] * num_var 58 | 59 | # constraits on sum of pairwise areas 60 | data['constraint_coeffs'] = [ 61 | num_of_xs 62 | ] 63 | data['bounds'] = [edge_number] 64 | 65 | data['obj_coeffs'] = edges # similarity edges 66 | 67 | data['num_vars'] = num_var 68 | data['num_constraints'] = 1 69 | 70 | return data 71 | 72 | 73 | def create_linear_data_model(edge_weights, node_weights, area_threshold): 74 | data = {} 75 | data['edge_weights'] = edge_weights 76 | data['node_weights'] = node_weights 77 | data['area_threshold'] = area_threshold 78 | data['len'] = len(data['node_weights']) 79 | return data 80 | 81 | 82 | def lineer_solver(edges, edge_number): 83 | 84 | data = create_reduction_data_model(edges, edge_number) 85 | # Create the mip solver with the SCIP backend. 
86 | solver = pywraplp.Solver.CreateSolver('SCIP') 87 | if not solver: 88 | return 89 | 90 | # infinity = solver.infinity() 91 | x = {} 92 | for j in range(data['num_vars']): 93 | x[j] = solver.BoolVar('x[%i]' % j) 94 | 95 | print('Number of variables =', solver.NumVariables()) 96 | 97 | for i in range(data['num_constraints']): 98 | constraint = solver.RowConstraint(0, data['bounds'][i], '') 99 | for j in range(data['num_vars']): 100 | constraint.SetCoefficient(x[j], data['constraint_coeffs'][i][j]) 101 | print('Number of constraints =', solver.NumConstraints()) 102 | 103 | objective = solver.Objective() 104 | for j in range(data['num_vars']): 105 | objective.SetCoefficient(x[j], data['obj_coeffs'][j]) 106 | objective.SetMaximization() 107 | 108 | status = solver.Solve() 109 | selected_indexes= [] 110 | if status == pywraplp.Solver.OPTIMAL: 111 | print('Objective value =', solver.Objective().Value()) 112 | for j in range(data['num_vars']): 113 | if x[j].solution_value() > 0.0: 114 | print(x[j].name(), ' = ', x[j].solution_value()) 115 | selected_indexes.append(j) 116 | print() 117 | print('Problem solved in %f milliseconds' % solver.wall_time()) 118 | print('Problem solved in %d iterations' % solver.iterations()) 119 | print('Problem solved in %d branch-and-bound nodes' % solver.nodes()) 120 | else: 121 | print('The problem does not have an optimal solution.') 122 | 123 | return selected_indexes 124 | 125 | 126 | def lineer_solver_by_node_weight(edge_weights, node_weights, area_threshold): 127 | 128 | data = create_linear_data_model(edge_weights,node_weights,area_threshold) 129 | solver = pywraplp.Solver.CreateSolver('SCIP') 130 | 131 | node_len = data['len'] 132 | if not solver: 133 | return 134 | 135 | # Variables 136 | y = {} 137 | for i in range(node_len): 138 | for j in range(i+1, node_len): 139 | y[(i, j)] = solver.BoolVar(f'y_{i}_{j}') 140 | 141 | x = {} 142 | for j in range(node_len): 143 | x[j] = solver.BoolVar(f'x[{j}]') 144 | 145 | # Constraints 146 | 
for i in range(node_len): 147 | for j in range(i+1, node_len): 148 | solver.Add(y[(i, j)] <= x[i]) 149 | solver.Add(y[(i, j)] <= x[j]) 150 | 151 | solver.Add(sum(x[i] * data['node_weights'][i] for i in range(node_len)) <= data['area_threshold']) 152 | 153 | solver.Maximize(solver.Sum([y[i,j]*data['edge_weights'][i,j] for i in range(node_len) for j in range(i+1,node_len)])) 154 | 155 | status = solver.Solve() 156 | 157 | selected_node_indexes = [] 158 | selected_edge_indexes = [] 159 | if status == pywraplp.Solver.OPTIMAL or status == cp_model.FEASIBLE: 160 | print('Objective value =', solver.Objective().Value()) 161 | for j in range(node_len): 162 | if x[j].solution_value() > 0.0: 163 | print(x[j].name(), ' = ', x[j].solution_value()) 164 | selected_node_indexes.append(j) 165 | 166 | for i in range(node_len): 167 | for j in range(i+1, node_len): 168 | if y[i,j].solution_value() > 0.0: 169 | print(y[i,j].name(), ' = ', y[i,j].solution_value()) 170 | selected_edge_indexes.append((i,j)) 171 | 172 | print() 173 | print('Problem solved in %f milliseconds' % solver.wall_time()) 174 | print('Problem solved in %d iterations' % solver.iterations()) 175 | print('Problem solved in %d branch-and-bound nodes' % solver.nodes()) 176 | print('Number of bins used:') 177 | 178 | print('Time = ', solver.WallTime(), ' milliseconds') 179 | else: 180 | print('The problem does not have an optimal solution.') 181 | 182 | return selected_edge_indexes, selected_node_indexes 183 | 184 | 185 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | importlib 2 | numpy 3 | pillow 4 | faiss 5 | ortools 6 | pickle5 7 | scikit-learn 8 | scipy -------------------------------------------------------------------------------- /s3dis_seed/info.txt: -------------------------------------------------------------------------------- 1 | Region Num: 859 and Point Num: 6470008 in 
the current set -------------------------------------------------------------------------------- /s3dis_seed/init_label_region.json: -------------------------------------------------------------------------------- 1 | {"Area_2#storage_5": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46],"Area_2#office_9": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67],"Area_2#storage_6": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33],"Area_2#office_1": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52],"Area_2#hallway_11": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39],"Area_2#storage_1": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29],"Area_2#storage_3": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54],"Area_2#storage_7": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56],"Area_3#storage_1": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],"Area_4#office_14": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 
24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39],"Area_4#storage_1": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26],"Area_4#storage_3": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52],"Area_4#storage_2": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53],"Area_4#storage_4": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40],"Area_6#office_10": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97],"Area_6#office_12": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96],"Area_6#office_15": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64]} -------------------------------------------------------------------------------- /s3dis_seed/init_label_scan.json: 
-------------------------------------------------------------------------------- 1 | [ 2 | "Area_2/coords/storage_5.npy", 3 | "Area_2/coords/office_9.npy", 4 | "Area_2/coords/storage_6.npy", 5 | "Area_2/coords/office_1.npy", 6 | "Area_2/coords/hallway_11.npy", 7 | "Area_2/coords/storage_1.npy", 8 | "Area_2/coords/storage_3.npy", 9 | "Area_2/coords/storage_7.npy", 10 | "Area_3/coords/storage_1.npy", 11 | "Area_4/coords/office_14.npy", 12 | "Area_4/coords/storage_1.npy", 13 | "Area_4/coords/storage_3.npy", 14 | "Area_4/coords/storage_2.npy", 15 | "Area_4/coords/storage_4.npy", 16 | "Area_6/coords/office_10.npy", 17 | "Area_6/coords/office_12.npy", 18 | "Area_6/coords/office_15.npy" 19 | ] 20 | -------------------------------------------------------------------------------- /s3dis_seed/init_ulabel_scan.json: -------------------------------------------------------------------------------- 1 | [ 2 | "Area_1/coords/hallway_6.npy", 3 | "Area_1/coords/office_10.npy", 4 | "Area_1/coords/office_21.npy", 5 | "Area_1/coords/hallway_3.npy", 6 | "Area_1/coords/office_23.npy", 7 | "Area_1/coords/hallway_7.npy", 8 | "Area_1/coords/office_22.npy", 9 | "Area_1/coords/WC_1.npy", 10 | "Area_1/coords/office_9.npy", 11 | "Area_1/coords/office_29.npy", 12 | "Area_1/coords/office_6.npy", 13 | "Area_1/coords/hallway_2.npy", 14 | "Area_1/coords/copyRoom_1.npy", 15 | "Area_1/coords/office_12.npy", 16 | "Area_1/coords/office_17.npy", 17 | "Area_1/coords/hallway_8.npy", 18 | "Area_1/coords/hallway_5.npy", 19 | "Area_1/coords/office_8.npy", 20 | "Area_1/coords/conferenceRoom_2.npy", 21 | "Area_1/coords/office_13.npy", 22 | "Area_1/coords/office_1.npy", 23 | "Area_1/coords/office_27.npy", 24 | "Area_1/coords/office_16.npy", 25 | "Area_1/coords/office_24.npy", 26 | "Area_1/coords/office_14.npy", 27 | "Area_1/coords/office_15.npy", 28 | "Area_1/coords/office_11.npy", 29 | "Area_1/coords/office_28.npy", 30 | "Area_1/coords/office_2.npy", 31 | "Area_1/coords/office_18.npy", 32 | 
"Area_1/coords/hallway_1.npy", 33 | "Area_1/coords/office_31.npy", 34 | "Area_1/coords/office_25.npy", 35 | "Area_1/coords/office_5.npy", 36 | "Area_1/coords/conferenceRoom_1.npy", 37 | "Area_1/coords/office_19.npy", 38 | "Area_1/coords/pantry_1.npy", 39 | "Area_1/coords/office_30.npy", 40 | "Area_1/coords/hallway_4.npy", 41 | "Area_1/coords/office_26.npy", 42 | "Area_1/coords/office_7.npy", 43 | "Area_1/coords/office_20.npy", 44 | "Area_1/coords/office_3.npy", 45 | "Area_1/coords/office_4.npy", 46 | "Area_2/coords/hallway_6.npy", 47 | "Area_2/coords/office_10.npy", 48 | "Area_2/coords/hallway_3.npy", 49 | "Area_2/coords/hallway_7.npy", 50 | "Area_2/coords/WC_1.npy", 51 | "Area_2/coords/hallway_9.npy", 52 | "Area_2/coords/office_6.npy", 53 | "Area_2/coords/hallway_2.npy", 54 | "Area_2/coords/office_12.npy", 55 | "Area_2/coords/hallway_8.npy", 56 | "Area_2/coords/hallway_5.npy", 57 | "Area_2/coords/office_8.npy", 58 | "Area_2/coords/office_13.npy", 59 | "Area_2/coords/storage_9.npy", 60 | "Area_2/coords/hallway_10.npy", 61 | "Area_2/coords/storage_8.npy", 62 | "Area_2/coords/office_14.npy", 63 | "Area_2/coords/auditorium_2.npy", 64 | "Area_2/coords/office_11.npy", 65 | "Area_2/coords/office_2.npy", 66 | "Area_2/coords/hallway_1.npy", 67 | "Area_2/coords/office_5.npy", 68 | "Area_2/coords/storage_2.npy", 69 | "Area_2/coords/conferenceRoom_1.npy", 70 | "Area_2/coords/WC_2.npy", 71 | "Area_2/coords/auditorium_1.npy", 72 | "Area_2/coords/hallway_12.npy", 73 | "Area_2/coords/hallway_4.npy", 74 | "Area_2/coords/office_7.npy", 75 | "Area_2/coords/office_3.npy", 76 | "Area_2/coords/office_4.npy", 77 | "Area_2/coords/storage_4.npy", 78 | "Area_3/coords/hallway_6.npy", 79 | "Area_3/coords/office_10.npy", 80 | "Area_3/coords/hallway_3.npy", 81 | "Area_3/coords/WC_1.npy", 82 | "Area_3/coords/office_9.npy", 83 | "Area_3/coords/office_6.npy", 84 | "Area_3/coords/hallway_2.npy", 85 | "Area_3/coords/hallway_5.npy", 86 | "Area_3/coords/office_8.npy", 87 | 
"Area_3/coords/office_1.npy", 88 | "Area_3/coords/office_2.npy", 89 | "Area_3/coords/hallway_1.npy", 90 | "Area_3/coords/office_5.npy", 91 | "Area_3/coords/storage_2.npy", 92 | "Area_3/coords/conferenceRoom_1.npy", 93 | "Area_3/coords/WC_2.npy", 94 | "Area_3/coords/lounge_2.npy", 95 | "Area_3/coords/hallway_4.npy", 96 | "Area_3/coords/office_7.npy", 97 | "Area_3/coords/lounge_1.npy", 98 | "Area_3/coords/office_3.npy", 99 | "Area_3/coords/office_4.npy", 100 | "Area_4/coords/hallway_6.npy", 101 | "Area_4/coords/hallway_14.npy", 102 | "Area_4/coords/office_10.npy", 103 | "Area_4/coords/office_21.npy", 104 | "Area_4/coords/hallway_3.npy", 105 | "Area_4/coords/hallway_7.npy", 106 | "Area_4/coords/office_22.npy", 107 | "Area_4/coords/WC_1.npy", 108 | "Area_4/coords/hallway_9.npy", 109 | "Area_4/coords/office_9.npy", 110 | "Area_4/coords/office_6.npy", 111 | "Area_4/coords/hallway_2.npy", 112 | "Area_4/coords/office_12.npy", 113 | "Area_4/coords/office_17.npy", 114 | "Area_4/coords/hallway_8.npy", 115 | "Area_4/coords/hallway_5.npy", 116 | "Area_4/coords/office_8.npy", 117 | "Area_4/coords/conferenceRoom_2.npy", 118 | "Area_4/coords/office_13.npy", 119 | "Area_4/coords/lobby_1.npy", 120 | "Area_4/coords/office_1.npy", 121 | "Area_4/coords/office_16.npy", 122 | "Area_4/coords/hallway_10.npy", 123 | "Area_4/coords/hallway_11.npy", 124 | "Area_4/coords/office_15.npy", 125 | "Area_4/coords/office_11.npy", 126 | "Area_4/coords/conferenceRoom_3.npy", 127 | "Area_4/coords/office_2.npy", 128 | "Area_4/coords/office_18.npy", 129 | "Area_4/coords/hallway_1.npy", 130 | "Area_4/coords/office_5.npy", 131 | "Area_4/coords/conferenceRoom_1.npy", 132 | "Area_4/coords/WC_2.npy", 133 | "Area_4/coords/office_19.npy", 134 | "Area_4/coords/lobby_2.npy", 135 | "Area_4/coords/hallway_12.npy", 136 | "Area_4/coords/WC_4.npy", 137 | "Area_4/coords/hallway_4.npy", 138 | "Area_4/coords/office_7.npy", 139 | "Area_4/coords/office_20.npy", 140 | "Area_4/coords/office_3.npy", 141 | 
"Area_4/coords/office_4.npy", 142 | "Area_4/coords/hallway_13.npy", 143 | "Area_4/coords/WC_3.npy", 144 | "Area_6/coords/hallway_6.npy", 145 | "Area_6/coords/office_21.npy", 146 | "Area_6/coords/hallway_3.npy", 147 | "Area_6/coords/office_23.npy", 148 | "Area_6/coords/office_22.npy", 149 | "Area_6/coords/office_9.npy", 150 | "Area_6/coords/office_35.npy", 151 | "Area_6/coords/office_29.npy", 152 | "Area_6/coords/office_32.npy", 153 | "Area_6/coords/office_6.npy", 154 | "Area_6/coords/hallway_2.npy", 155 | "Area_6/coords/copyRoom_1.npy", 156 | "Area_6/coords/office_17.npy", 157 | "Area_6/coords/hallway_5.npy", 158 | "Area_6/coords/office_8.npy", 159 | "Area_6/coords/office_13.npy", 160 | "Area_6/coords/office_1.npy", 161 | "Area_6/coords/office_27.npy", 162 | "Area_6/coords/office_16.npy", 163 | "Area_6/coords/office_24.npy", 164 | "Area_6/coords/office_36.npy", 165 | "Area_6/coords/office_14.npy", 166 | "Area_6/coords/office_34.npy", 167 | "Area_6/coords/office_11.npy", 168 | "Area_6/coords/office_28.npy", 169 | "Area_6/coords/office_2.npy", 170 | "Area_6/coords/office_18.npy", 171 | "Area_6/coords/hallway_1.npy", 172 | "Area_6/coords/office_31.npy", 173 | "Area_6/coords/office_25.npy", 174 | "Area_6/coords/office_5.npy", 175 | "Area_6/coords/conferenceRoom_1.npy", 176 | "Area_6/coords/office_19.npy", 177 | "Area_6/coords/pantry_1.npy", 178 | "Area_6/coords/openspace_1.npy", 179 | "Area_6/coords/office_30.npy", 180 | "Area_6/coords/hallway_4.npy", 181 | "Area_6/coords/office_37.npy", 182 | "Area_6/coords/office_33.npy", 183 | "Area_6/coords/office_26.npy", 184 | "Area_6/coords/office_7.npy", 185 | "Area_6/coords/lounge_1.npy", 186 | "Area_6/coords/office_20.npy", 187 | "Area_6/coords/office_3.npy", 188 | "Area_6/coords/office_4.npy" 189 | ] 190 |
--------------------------------------------------------------------------------
/similarity_model.py:
--------------------------------------------------------------------------------
from dino_model.fe_dino import DinoModel
import numpy as np
from PIL import Image
import faiss


class SimilarityModel:
    def __init__(self, sim_model_name='dino', feat_type='cls'):
        self.sim_model_name = sim_model_name
        self.feat_type = feat_type
        # initialise the feature dimension before load_sim_model() sets it,
        # otherwise the assignment would overwrite d = 768 with None
        self.d = None
        self.load_sim_model()

    def get_sim_vec_single(self, im_name_or_data):
        # Prepare an image
        image = self.load_image(im_name_or_data)
        if image is None:
            print('No Image')
            return np.zeros((1, self.d), dtype=np.float32)
        # extract the feature vector and convert it to a float32 numpy array
        feature_vec = self.sim_model(image).cpu().data.numpy().astype(np.float32)
        return feature_vec

    def load_sim_model(self):
        if self.sim_model_name == 'dino':
            self.sim_model = self.load_sim_model_dino()
            self.d = 768
        else:
            raise Exception("sim_model_name must be dino!")

    def load_sim_model_dino(self):
        # Build models
        sim_model = DinoModel(self.feat_type)  # eval mode (batch norm uses moving mean/variance)
        return sim_model

    def set_sim_model_feat_type_dino(self, feat_type):
        self.feat_type = feat_type
        self.sim_model.feat_type = self.feat_type

    def load_image(self, image_path_or_data):
        try:
            if isinstance(image_path_or_data, str):
                img = Image.open(image_path_or_data)
                image = img.convert('RGB')
            elif isinstance(image_path_or_data, Image.Image):
                image = image_path_or_data.convert('RGB')
            else:
                raise Exception("image type must be str or PIL.Image!")

            return image

        except Exception:
            return None

    def calculate_dino_aff_matrix_from_feats(self, feats):
        curr_dim = len(feats)
        faiss.normalize_L2(feats)
        index = faiss.IndexFlatIP(feats.shape[1])
        index.add(feats)  # add vectors to the index
        # radius -1: for L2-normalised features every inner product is >= -1, so each
        # query is expected to return all n entries; the reshape below relies on that
        lims, D, I = index.range_search(feats, -1)
        ordered_D = np.reshape(D, (curr_dim, curr_dim))
        return ordered_D
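As a rough cross-check of `calculate_dino_aff_matrix_from_feats`: after L2 normalisation, the inner product of two feature vectors equals their cosine similarity, so the full n×n affinity matrix can also be written as one matrix product. The sketch below uses NumPy only (no faiss). `cosine_affinity` is a hypothetical helper, not part of this repository, and it assumes `feats` is an (n, d) array with no all-zero rows.

```python
import numpy as np


def cosine_affinity(feats: np.ndarray) -> np.ndarray:
    """Dense n x n cosine-similarity matrix for an (n, d) feature array.

    Hypothetical helper (not in the repo): intended to mirror
    faiss.normalize_L2 followed by an all-pairs inner-product search.
    Assumes no row of `feats` is all zeros.
    """
    feats = np.asarray(feats, dtype=np.float32)
    # L2-normalise each row so that inner products become cosine similarities
    norms = np.linalg.norm(feats, axis=1, keepdims=True)
    unit = feats / norms
    # (n, d) @ (d, n) -> (n, n); entries lie in [-1, 1] up to rounding
    return unit @ unit.T


if __name__ == "__main__":
    # random vectors standing in for 768-d DINO CLS features of 4 scenes
    rng = np.random.default_rng(0)
    toy = rng.standard_normal((4, 768)).astype(np.float32)
    aff = cosine_affinity(toy)
    print(aff.shape)
    print(np.round(np.diag(aff), 4))
```

One design difference worth noting: the matrix product always yields every pair, whereas the faiss version relies on `range_search` returning all n results per query so that the reshape to (n, n) succeeds; for ordinary image features the two are expected to agree.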
-------------------------------------------------------------------------------- /sk_seed/init_label_scan.json: -------------------------------------------------------------------------------- 1 | [ 2 | "05/velodyne/001429.bin", 3 | "06/velodyne/000711.bin", 4 | "01/velodyne/000876.bin", 5 | "05/velodyne/001092.bin", 6 | "02/velodyne/002455.bin", 7 | "09/velodyne/000415.bin", 8 | "06/velodyne/000175.bin", 9 | "05/velodyne/001679.bin", 10 | "00/velodyne/002698.bin", 11 | "00/velodyne/000988.bin", 12 | "02/velodyne/003218.bin", 13 | "00/velodyne/000932.bin", 14 | "06/velodyne/000683.bin", 15 | "05/velodyne/002075.bin", 16 | "06/velodyne/000078.bin", 17 | "09/velodyne/000855.bin", 18 | "00/velodyne/000440.bin", 19 | "10/velodyne/000178.bin", 20 | "07/velodyne/000585.bin", 21 | "00/velodyne/000397.bin", 22 | "00/velodyne/001569.bin", 23 | "00/velodyne/001872.bin", 24 | "06/velodyne/000877.bin", 25 | "10/velodyne/000363.bin", 26 | "00/velodyne/004369.bin", 27 | "05/velodyne/002224.bin", 28 | "05/velodyne/001342.bin", 29 | "06/velodyne/000064.bin", 30 | "05/velodyne/001075.bin", 31 | "05/velodyne/001499.bin", 32 | "05/velodyne/001979.bin", 33 | "10/velodyne/000283.bin", 34 | "06/velodyne/000263.bin", 35 | "09/velodyne/001225.bin", 36 | "07/velodyne/000640.bin", 37 | "05/velodyne/002395.bin", 38 | "00/velodyne/001824.bin", 39 | "03/velodyne/000662.bin", 40 | "00/velodyne/003118.bin", 41 | "06/velodyne/001007.bin", 42 | "06/velodyne/000504.bin", 43 | "05/velodyne/002215.bin", 44 | "09/velodyne/000910.bin", 45 | "00/velodyne/001059.bin", 46 | "06/velodyne/000492.bin", 47 | "06/velodyne/000323.bin", 48 | "07/velodyne/000177.bin", 49 | "06/velodyne/000898.bin", 50 | "07/velodyne/000649.bin", 51 | "10/velodyne/000981.bin", 52 | "05/velodyne/002405.bin", 53 | "07/velodyne/000401.bin", 54 | "01/velodyne/000376.bin", 55 | "09/velodyne/001127.bin", 56 | "05/velodyne/002298.bin", 57 | "02/velodyne/002078.bin", 58 | "05/velodyne/002417.bin", 59 | "05/velodyne/001193.bin", 60 | 
"05/velodyne/002630.bin", 61 | "05/velodyne/000396.bin", 62 | "01/velodyne/000006.bin", 63 | "06/velodyne/000892.bin", 64 | "10/velodyne/000165.bin", 65 | "05/velodyne/002623.bin", 66 | "06/velodyne/000198.bin", 67 | "06/velodyne/001042.bin", 68 | "10/velodyne/000495.bin", 69 | "00/velodyne/003240.bin", 70 | "00/velodyne/003262.bin", 71 | "09/velodyne/001099.bin", 72 | "09/velodyne/000212.bin", 73 | "00/velodyne/003003.bin", 74 | "02/velodyne/002556.bin", 75 | "05/velodyne/001957.bin", 76 | "05/velodyne/001318.bin", 77 | "06/velodyne/000069.bin", 78 | "05/velodyne/001728.bin", 79 | "02/velodyne/002546.bin", 80 | "07/velodyne/000917.bin", 81 | "05/velodyne/002106.bin", 82 | "05/velodyne/002353.bin", 83 | "06/velodyne/000149.bin", 84 | "00/velodyne/001262.bin", 85 | "04/velodyne/000148.bin", 86 | "06/velodyne/000983.bin", 87 | "00/velodyne/004361.bin", 88 | "02/velodyne/002073.bin", 89 | "01/velodyne/000217.bin", 90 | "02/velodyne/000058.bin", 91 | "06/velodyne/000207.bin", 92 | "00/velodyne/002543.bin", 93 | "02/velodyne/002874.bin", 94 | "01/velodyne/001063.bin", 95 | "09/velodyne/000721.bin", 96 | "10/velodyne/000369.bin", 97 | "05/velodyne/000527.bin", 98 | "00/velodyne/000126.bin", 99 | "07/velodyne/000026.bin", 100 | "10/velodyne/000329.bin", 101 | "05/velodyne/000125.bin", 102 | "05/velodyne/000844.bin", 103 | "06/velodyne/000663.bin", 104 | "06/velodyne/000551.bin", 105 | "09/velodyne/000568.bin", 106 | "05/velodyne/001121.bin", 107 | "10/velodyne/001083.bin", 108 | "00/velodyne/000068.bin", 109 | "06/velodyne/000411.bin", 110 | "07/velodyne/000142.bin", 111 | "10/velodyne/001014.bin", 112 | "05/velodyne/002535.bin", 113 | "10/velodyne/000898.bin", 114 | "02/velodyne/001387.bin", 115 | "00/velodyne/000108.bin", 116 | "05/velodyne/001854.bin", 117 | "06/velodyne/001035.bin", 118 | "00/velodyne/003276.bin", 119 | "09/velodyne/001172.bin", 120 | "02/velodyne/001533.bin", 121 | "00/velodyne/002709.bin", 122 | "10/velodyne/000688.bin", 123 | 
"06/velodyne/000575.bin", 124 | "10/velodyne/000356.bin", 125 | "00/velodyne/004414.bin", 126 | "02/velodyne/001931.bin", 127 | "09/velodyne/000021.bin", 128 | "00/velodyne/002806.bin", 129 | "07/velodyne/000041.bin", 130 | "00/velodyne/002815.bin", 131 | "06/velodyne/000398.bin", 132 | "07/velodyne/000294.bin", 133 | "09/velodyne/001477.bin", 134 | "10/velodyne/000115.bin", 135 | "01/velodyne/000640.bin", 136 | "10/velodyne/000559.bin", 137 | "05/velodyne/000360.bin", 138 | "09/velodyne/001349.bin", 139 | "06/velodyne/000675.bin", 140 | "06/velodyne/000219.bin", 141 | "00/velodyne/003875.bin", 142 | "09/velodyne/001181.bin", 143 | "10/velodyne/000995.bin", 144 | "01/velodyne/000018.bin", 145 | "05/velodyne/002376.bin", 146 | "00/velodyne/002726.bin", 147 | "03/velodyne/000515.bin", 148 | "05/velodyne/002497.bin", 149 | "00/velodyne/002848.bin", 150 | "10/velodyne/000878.bin", 151 | "05/velodyne/001800.bin", 152 | "05/velodyne/000981.bin", 153 | "05/velodyne/000339.bin", 154 | "05/velodyne/002430.bin", 155 | "02/velodyne/002908.bin", 156 | "00/velodyne/001421.bin", 157 | "05/velodyne/000678.bin", 158 | "09/velodyne/001113.bin", 159 | "10/velodyne/000570.bin", 160 | "00/velodyne/003911.bin", 161 | "07/velodyne/000644.bin", 162 | "00/velodyne/004513.bin", 163 | "00/velodyne/003928.bin", 164 | "06/velodyne/000049.bin", 165 | "09/velodyne/000183.bin", 166 | "10/velodyne/000301.bin", 167 | "09/velodyne/000405.bin", 168 | "00/velodyne/002741.bin", 169 | "00/velodyne/001124.bin", 170 | "09/velodyne/000622.bin", 171 | "02/velodyne/004591.bin", 172 | "10/velodyne/000581.bin", 173 | "05/velodyne/002357.bin", 174 | "05/velodyne/001814.bin", 175 | "01/velodyne/000034.bin", 176 | "02/velodyne/000553.bin", 177 | "05/velodyne/000917.bin", 178 | "02/velodyne/004509.bin", 179 | "09/velodyne/000591.bin", 180 | "10/velodyne/000374.bin", 181 | "06/velodyne/001027.bin", 182 | "00/velodyne/003505.bin", 183 | "05/velodyne/000136.bin", 184 | "07/velodyne/000634.bin", 185 | 
"09/velodyne/000392.bin" 186 | ] 187 | --------------------------------------------------------------------------------