├── .gitattributes ├── .gitignore ├── README.md └── deepdive ├── feature_extraction.py ├── feature_reduction.py ├── mapping_methods.py ├── model_accuracy.csv ├── model_metadata.csv ├── model_metadata.ipynb ├── model_metadata.py ├── model_options.py ├── model_opts_utils.py ├── model_scores.csv ├── model_statistics ├── __init__.py ├── timm_imagenet_accuracies.csv └── torchvision_imagenet_accuracies.csv ├── model_typology.csv ├── model_typology.ipynb └── ridge_gcv_mod.py /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Specific Files # 2 | ################### 3 | github_rsync.sh 4 | 5 | # Data Files # 6 | ################### 7 | *testing/* 8 | *results/* 9 | 10 | # Checkpoints # 11 | ################### 12 | .ipynb_checkpoints/ 13 | __pycache__/ 14 | *.ckpt 15 | *.pyc 16 | *.pyo 17 | 18 | # Compiled source # 19 | ################### 20 | *.com 21 | *.class 22 | *.dll 23 | *.exe 24 | *.o 25 | *.so 26 | 27 | # Packages # 28 | ############ 29 | # it's better to unpack these files and commit the raw source 30 | # git has its own built in compression methods 31 | *.7z 32 | *.dmg 33 | *.gz 34 | *.iso 35 | *.jar 36 | *.rar 37 | *.tar 38 | *.zip 39 | 40 | # Logs and databases # 41 | ###################### 42 | *.log 43 | *.sql 44 | *.sqlite 45 | 46 | # OS generated files # 47 | ###################### 48 | .DS_Store 49 | .DS_Store? 50 | ._* 51 | .Spotlight-V100 52 | .Trashes 53 | ehthumbs.db 54 | Thumbs.db 55 | 56 | */.DS_Store 57 | **/.DS_Store 58 | .DS_Store 59 | .DS_Store? 60 | 61 | 62 | # Byte-compiled / optimized / DLL files 63 | __pycache__/ 64 | *.py[cod] 65 | *$py.class 66 | 67 | # C extensions 68 | *.so 69 | 70 | # Distribution / packaging 71 | .Python 72 | build/ 73 | develop-eggs/ 74 | dist/ 75 | downloads/ 76 | eggs/ 77 | .eggs/ 78 | lib/ 79 | lib64/ 80 | parts/ 81 | sdist/ 82 | var/ 83 | wheels/ 84 | pip-wheel-metadata/ 85 | share/python-wheels/ 86 | *.egg-info/ 87 | .installed.cfg 88 | *.egg 89 | MANIFEST 90 | 91 | # PyInstaller 92 | # Usually these files are written by a python script from a template 93 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 94 | *.manifest 95 | *.spec 96 | 97 | # Installer logs 98 | pip-log.txt 99 | pip-delete-this-directory.txt 100 | 101 | # Unit test / coverage reports 102 | htmlcov/ 103 | .tox/ 104 | .nox/ 105 | .coverage 106 | .coverage.* 107 | .cache 108 | nosetests.xml 109 | coverage.xml 110 | *.cover 111 | *.py,cover 112 | .hypothesis/ 113 | .pytest_cache/ 114 | 115 | # Translations 116 | *.mo 117 | *.pot 118 | 119 | # Django stuff: 120 | *.log 121 | local_settings.py 122 | db.sqlite3 123 | db.sqlite3-journal 124 | 125 | # Flask stuff: 126 | instance/ 127 | .webassets-cache 128 | 129 | # Scrapy stuff: 130 | .scrapy 131 | 132 | # Sphinx documentation 133 | docs/_build/ 134 | 135 | # PyBuilder 136 | target/ 137 | 138 | # Jupyter Notebook 139 | .ipynb_checkpoints 140 | 141 | # IPython 142 | profile_default/ 143 | ipython_config.py 144 | 145 | # pyenv 146 | .python-version 147 | 148 | # pipenv 149 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
150 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
151 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
152 | # install all needed dependencies.
153 | #Pipfile.lock
154 | 
155 | # celery beat schedule file
156 | celerybeat-schedule
157 | 
158 | # SageMath parsed files
159 | *.sage.py
160 | 
161 | # Environments
162 | .env
163 | .venv
164 | env/
165 | venv/
166 | ENV/
167 | env.bak/
168 | venv.bak/
169 | 
170 | # Spyder project settings
171 | .spyderproject
172 | .spyproject
173 | 
174 | # Rope project settings
175 | .ropeproject
176 | 
177 | # mkdocs documentation
178 | /site
179 | 
180 | # mypy
181 | .mypy_cache/
182 | .dmypy.json
183 | dmypy.json
184 | 
185 | # Pyre type checker
186 | .pyre/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DeepDive (into Deep Neural Networks)
2 | 
3 | Designed for deep net feature extraction, dimensionality reduction, and benchmarking, this repo contains a number of convenience functions for loading and instrumenting a variety of (PyTorch) models. Models available include those from:
4 | 
5 | - the [PyTorch-Image-Models](https://github.com/rwightman/pytorch-image-models) (Timm) library
6 | - the [Torchvision](https://pytorch.org/vision/stable/models.html) model zoo
7 | - the [Taskonomy](http://taskonomy.stanford.edu/) (visual_priors) project
8 | - the [VISSL](https://vissl.ai/) (self-supervised) model zoo
9 | - the [Detectron2](https://github.com/facebookresearch/detectron2) model zoo
10 | - ISL's [MiDaS](https://github.com/isl-org/MiDaS) models, Facebook's [DINO](https://github.com/facebookresearch/dino) models, and more
11 | 
12 | Check out the companion repos that benchmark these models on [human fMRI](https://github.com/ColinConwell/DeepNSD) and [mouse optical physiology](https://github.com/ColinConwell/DeepMouseTrap) data.
13 | 
14 | A tutorial demonstrating the main functionality of this pipeline on both behavioral and brain data may be found [here](https://colab.research.google.com/drive/1CvOpeKL4xRDbHkpPXGlSDs-JyD-438vl#scrollTo=Jd9vyENcvsIg).
15 | 
16 | This repository is a work in progress; please feel free to file any issues you find.
17 | 
18 | If you find this repository useful, please consider citing the work that fueled its most recent development:
19 | 
20 | ```bibtex
21 | @article{conwell2023pressures,
22 |   title={What can 1.8 billion regressions tell us about the pressures shaping high-level visual representation in brains and machines?},
23 |   author={Conwell, Colin and Prince, Jacob S and Kay, Kendrick N and Alvarez, George A and Konkle, Talia},
24 |   journal={bioRxiv},
25 |   year={2023}
26 | }
27 | ```
28 | 
29 | (Also remember to cite any of the specific models you use by referring to their original sources, linked in the model_typology.csv file.)
30 | 
31 | ## 2024 Update: *DeepDive* to *DeepJuice*
32 | 
33 | + **Squeezing your deep nets for science!**
34 | 
35 | Recently, our team has been working on a new, highly-accelerated version of this codebase called **DeepJuice** -- effectively, a bottom-up reimplementation of all DeepDive functionality that allows for end-to-end benchmarking (feature extraction, SRP, PCA, CKA, RSA, and regression) without ever removing data from the GPU.
36 | 
37 | **DeepJuice** is currently in private beta, but if you're interested in trying it out, please feel free to contact me (Colin Conwell) by email: conwell[at]g[dot]harvard[dot]edu
38 | 
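39 | ## Quick Example
40 | 
41 | A minimal sketch of the core workflow (the model string and import paths here are illustrative; see `get_model_options()` and the tutorial above for the exact names available in your install):
42 | 
43 | ```python
44 | import torch
45 | from deepdive.feature_extraction import get_prepped_model, get_all_feature_maps
46 | from deepdive.feature_reduction import srp_extraction
47 | 
48 | inputs = torch.rand(10, 3, 224, 224)  # stand-in for a real, preprocessed stimulus set
49 | model = get_prepped_model('alexnet')  # any key from get_model_options()
50 | 
51 | # feature_maps: {layer_name: (n_stimuli x n_features) numpy array}
52 | feature_maps = get_all_feature_maps(model, inputs)
53 | 
54 | # sparse random projection (SRP) of each layer's features to a common dimension
55 | srp_feature_maps = srp_extraction('alexnet', model=model, inputs=inputs)
56 | ```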
--------------------------------------------------------------------------------
/deepdive/feature_extraction.py:
--------------------------------------------------------------------------------
1 | import os, sys, shutil
2 | import numpy as np
3 | import pandas as pd
4 | import matplotlib.pyplot as plt
5 | from logging import warning
6 | from tqdm.auto import tqdm as tqdm
7 | from collections import defaultdict, OrderedDict
8 | 
9 | from PIL import Image
10 | import torch.nn as nn
11 | import torch, torchvision
12 | import torchvision.transforms as transforms
13 | from torch.utils.data import Dataset, DataLoader
14 | from torch.autograd import Variable
15 | 
16 | from model_options import *
17 | 
18 | def get_prepped_model(model_string):
19 |     model_options = get_model_options()
20 |     model_call = model_options[model_string]['call']
21 |     model = eval(model_call)
22 |     model = model.eval()
23 |     if torch.cuda.is_available():
24 |         model = model.cuda()
25 | 
26 |     return(model)
27 | 
28 | def check_model(model_string, model = None):
29 |     if not isinstance(model_string, str):
30 |         model = model_string
31 |     model_options = get_model_options()
32 |     if model_string not in model_options and model is None:
33 |         raise ValueError('model_string not available in prepped models. Please supply model object.')
34 | 
35 | def prep_model_for_extraction(model, inputs = None):
36 |     if model.training:
37 |         model = model.eval()
38 |     if not next(model.parameters()).is_cuda:
39 |         if torch.cuda.is_available():
40 |             model = model.cuda()
41 | 
42 |     if inputs is None:
43 |         return(model)
44 | 
45 |     if inputs is not None:
46 |         if next(model.parameters()).is_cuda:
47 |             if isinstance(inputs, dict):
48 |                 inputs = {k:v.cuda() for k,v in inputs.items()}
49 |             if not isinstance(inputs, dict):
50 |                 inputs = inputs.cuda()
51 | 
52 |         return(model, inputs)
53 | 
54 | def convert_relu(parent):
55 |     for child_name, child in parent.named_children():
56 |         if isinstance(child, nn.ReLU):
57 |             setattr(parent, child_name, nn.ReLU(inplace=False))
58 |         elif len(list(child.children())) > 0:
59 |             convert_relu(child)
60 | 
61 | # Method 1: Flatten model; extract features by layer
62 | 
63 | class SaveFeatures():
64 |     def __init__(self, module):
65 |         self.hook = module.register_forward_hook(self.hook_fn)
66 |     def hook_fn(self, module, input, output):
67 |         self.out = output.clone().detach().requires_grad_(True)  # stays on the module's device
68 |     def close(self):
69 |         self.hook.remove()
70 |     def extract(self):
71 |         return self.out
72 | 
73 | def get_layer_names(layers):
74 |     layer_names = []
75 |     for layer in layers:
76 |         layer_name = str(layer).split('(')[0]
77 |         layer_names.append(layer_name + '-' + str(sum(layer_name in string for string in layer_names) + 1))
78 |     return layer_names
79 | 
80 | def get_features_by_layer(model, target_layer, img_tensor):
81 |     model = prep_model_for_extraction(model)
82 |     features = SaveFeatures(target_layer)
83 |     model(img_tensor)
84 |     features.close()
85 |     return features.extract()
86 | 
87 | # Method 2: Hook all layers simultaneously; remove duplicates
88 | 
89 | def get_inputs_sample(inputs, n = 3):
90 |     if isinstance(inputs, torch.Tensor):
91 |         input_sample = inputs[:n]
92 | 
93 |     if isinstance(inputs, DataLoader):
94 |         input_sample = next(iter(inputs))[:n]
95 | 
96 |     return input_sample
97 | 
98 | def get_module_name(module, module_list):
99 |     class_name = 
str(module.__class__).split(".")[-1].split("'")[0] 100 | class_count = str(sum(class_name in module for module in module_list) + 1) 101 | 102 | return '-'.join([class_name, class_count]) 103 | 104 | def _get_feature_maps(model, inputs): 105 | model, inputs = prep_model_for_extraction(model, inputs) 106 | 107 | def register_hook(module): 108 | def hook(module, input, output): 109 | module_name = get_module_name(module, feature_maps) 110 | feature_maps[module_name] = output 111 | 112 | if not isinstance(module, nn.Sequential): 113 | if not isinstance(module, nn.ModuleList): 114 | hooks.append(module.register_forward_hook(hook)) 115 | 116 | if next(model.parameters()).is_cuda: 117 | inputs = inputs.cuda() 118 | 119 | feature_maps = OrderedDict() 120 | hooks = [] 121 | 122 | model.apply(register_hook) 123 | with torch.no_grad(): 124 | model(**inputs) if isinstance(inputs, dict) else model(inputs) 125 | 126 | for hook in hooks: 127 | hook.remove() 128 | 129 | return(feature_maps) 130 | 131 | def remove_duplicate_feature_maps(feature_maps, method = 'hashkey', return_matches = False, use_tqdm = False): 132 | matches, layer_names = [], list(feature_maps.keys()) 133 | 134 | if method == 'iterative': 135 | 136 | target_iterator = tqdm(range(len(layer_names)), leave = False) if use_tqdm else range(len(layer_names)) 137 | 138 | for i in target_iterator: 139 | for j in range(i+1,len(layer_names)): 140 | layer1 = feature_maps[layer_names[i]].flatten() 141 | layer2 = feature_maps[layer_names[j]].flatten() 142 | if layer1.shape == layer2.shape and torch.all(torch.eq(layer1,layer2)): 143 | if layer_names[j] not in matches: 144 | matches.append(layer_names[j]) 145 | 146 | deduplicated_feature_maps = {key:value for (key,value) in feature_maps.items() 147 | if key not in matches} 148 | 149 | if method == 'hashkey': 150 | 151 | target_iterator = tqdm(layer_names, leave = False) if use_tqdm else layer_names 152 | layer_lengths = [len(tensor.flatten()) for tensor in feature_maps.values()] 153 | random_tensor = torch.rand(np.array(layer_lengths).max()) 154 | 155 | tensor_dict = defaultdict(lambda:[]) 156 | for layer_name in target_iterator: 157 | target_tensor = feature_maps[layer_name].flatten() 158 | tensor_dot = torch.dot(target_tensor, random_tensor[:len(target_tensor)]) 159 | tensor_hash = np.array(tensor_dot).tobytes() 160 | tensor_dict[tensor_hash].append(layer_name) 161 | 162 | matches = [match for match in list(tensor_dict.values()) if len(match) > 1] 163 | layers_to_keep = [tensor_dict[tensor_hash][0] for tensor_hash in tensor_dict] 164 | 165 | deduplicated_feature_maps = {key:value for (key,value) in feature_maps.items() 166 | if key in layers_to_keep} 167 | 168 | if return_matches: 169 | return(deduplicated_feature_maps, matches) 170 | 171 | if not return_matches: 172 | return(deduplicated_feature_maps) 173 | 174 | def check_for_input_axis(feature_map, input_size): 175 | axis_match = [dim for dim in feature_map.shape if dim == input_size] 176 | return True if len(axis_match) == 1 else False 177 | 178 | def reset_input_axis(feature_map, input_size): 179 | input_axis = feature_map.shape.index(input_size) 180 | return torch.swapaxes(feature_map, 0, input_axis) 181 | 182 | def get_feature_maps(model, inputs, layers_to_retain = None, remove_duplicates = True, enforce_input_shape = True): 183 | 184 | model, inputs = prep_model_for_extraction(model, inputs) 185 | 186 | if layers_to_retain: 187 | if not isinstance(layers_to_retain, list): 188 | layers_to_retain = [layers_to_retain] 189 | 190 | def 
fix_outputs_shape(inputs, outputs, module_name): 191 | if len(outputs.shape) == 0: 192 | warning('Output in {} is empty. Skipping...'.format(module_name)) 193 | return None 194 | if enforce_input_shape: 195 | if outputs.shape[0] == inputs.shape[0]: 196 | return outputs 197 | if outputs.shape[0] != inputs.shape[0]: 198 | if check_for_input_axis(outputs, inputs.shape[0]): 199 | return reset_input_axis(outputs, inputs.shape[0]) 200 | if not check_for_input_axis(outputs, inputs.shape[0]): 201 | warning('Ambiguous input axis in {}. Skipping...'.format(module_name)) 202 | return None 203 | if not enforce_input_shape: 204 | return outputs 205 | 206 | def register_hook(module): 207 | def hook(module, input, output): 208 | def process_output(output, module_name): 209 | if layers_to_retain is None or module_name in layers_to_retain: 210 | if isinstance(output, torch.Tensor): 211 | outputs = output.cpu().detach().type(torch.FloatTensor) 212 | outputs = fix_outputs_shape(inputs, outputs, module_name) 213 | feature_maps[module_name] = outputs 214 | if layers_to_retain is not None and module_name not in layers_to_retain: 215 | feature_maps[module_name] = None 216 | 217 | module_name = get_module_name(module, feature_maps) 218 | 219 | if not any([isinstance(output, type_) for type_ in (tuple,list)]): 220 | process_output(output, module_name) 221 | 222 | if any([isinstance(output, type_) for type_ in (tuple,list)]): 223 | for output_i, output_ in enumerate(output): 224 | module_name_ = '-'.join([module_name, str(output_i+1)]) 225 | process_output(output_, module_name_) 226 | 227 | if (not isinstance(module, nn.Sequential) and not isinstance(module, nn.ModuleList)): 228 | hooks.append(module.register_forward_hook(hook)) 229 | 230 | feature_maps = OrderedDict() 231 | hooks = [] 232 | 233 | model.apply(convert_relu) 234 | model.apply(register_hook) 235 | with torch.no_grad(): 236 | model(**inputs) if isinstance(inputs, dict) else model(inputs) 237 | 238 | for hook in hooks: 239 | hook.remove() 240 | 241 | feature_maps = {map:features for (map,features) in feature_maps.items() 242 | if features is not None} 243 | 244 | if remove_duplicates == True: 245 | feature_maps = remove_duplicate_feature_maps(feature_maps) 246 | 247 | return(feature_maps) 248 | 249 | def get_empty_feature_maps(model, inputs = None, input_size=(3,224,224), dataset_size=3, 250 | layers_to_retain = None, remove_duplicates = True, names_only=False): 251 | 252 | check_model(model) 253 | if isinstance(model, str): 254 | model = get_prepped_model(model) 255 | 256 | if inputs is not None: 257 | inputs = get_inputs_sample(inputs) 258 | 259 | if inputs is None: 260 | inputs = torch.rand(3, *input_size) 261 | 262 | empty_feature_maps = get_feature_maps(model, inputs, layers_to_retain, remove_duplicates) 263 | 264 | for map_key in empty_feature_maps: 265 | empty_feature_maps[map_key] = torch.empty(dataset_size, *empty_feature_maps[map_key].shape[1:]) 266 | 267 | if names_only == True: 268 | return list(empty_feature_maps.keys()) 269 | 270 | if names_only == False: 271 | return empty_feature_maps 272 | 273 | 274 | def get_all_feature_maps(model, inputs, layers_to_retain=None, remove_duplicates=True, 275 | include_input_space = False, flatten=True, numpy=True, use_tqdm = True): 276 | 277 | check_model(model) 278 | if isinstance(model, str): 279 | model = get_prepped_model(model) 280 | 281 | if isinstance(inputs, DataLoader): 282 | input_size, dataset_size, start_index = inputs.dataset[0].shape, len(inputs.dataset), 0 283 | feature_maps = 
get_empty_feature_maps(model, next(iter(inputs))[:3], input_size,
284 |                                            dataset_size, layers_to_retain, remove_duplicates)
285 | 
286 |         if include_input_space:
287 |             input_map = {'Input': torch.empty(dataset_size, *input_size)}
288 |             feature_maps = {**input_map, **feature_maps}
289 | 
290 | 
291 |         for imgs in tqdm(inputs, desc = 'Feature Extraction (Batch)') if use_tqdm else inputs:
292 |             imgs = imgs.cuda() if next(model.parameters()).is_cuda else imgs
293 |             batch_feature_maps = get_feature_maps(model, imgs, layers_to_retain, remove_duplicates = False)
294 | 
295 |             if include_input_space:
296 |                 batch_feature_maps['Input'] = imgs.cpu()
297 | 
298 |             for map_i, map_key in enumerate(feature_maps):
299 |                 feature_maps[map_key][start_index:start_index+imgs.shape[0],...] = batch_feature_maps[map_key]
300 |             start_index += imgs.shape[0]
301 | 
302 |     if not isinstance(inputs, DataLoader):
303 |         if isinstance(inputs, torch.Tensor):
304 |             inputs = inputs.cuda() if next(model.parameters()).is_cuda else inputs
305 |         feature_maps = get_feature_maps(model, inputs, layers_to_retain, remove_duplicates)
306 | 
307 |         if include_input_space:
308 |             feature_maps = {**{'Input': inputs.cpu()}, **feature_maps}
309 | 
310 |     if remove_duplicates == True:
311 |         feature_maps = remove_duplicate_feature_maps(feature_maps)
312 | 
313 |     if flatten == True:
314 |         for map_key in feature_maps:
315 |             incoming_map = feature_maps[map_key]
316 |             feature_maps[map_key] = incoming_map.reshape(incoming_map.shape[0], -1)
317 | 
318 |     if numpy == True:
319 |         for map_key in feature_maps:
320 |             feature_maps[map_key] = feature_maps[map_key].numpy()
321 | 
322 |     return feature_maps
323 | 
324 | def get_feature_map_metadata(model, input_size=(3,224,224), remove_duplicates = False):
325 |     model = prep_model_for_extraction(model)
326 |     enforce_input_shape = True
327 | 
328 |     inputs = torch.rand(3, *input_size)
329 |     if next(model.parameters()).is_cuda:
330 |         inputs = inputs.cuda()
331 | 
332 |     def register_hook(module):
333 |         def hook(module, input, output):
334 |             def process_output(output, module_name):
335 |                 if isinstance(output, torch.Tensor):
336 |                     outputs = output.cpu().detach().type(torch.FloatTensor)
337 |                     if not enforce_input_shape:
338 |                         map_data[module_name] = outputs
339 |                     if enforce_input_shape:
340 |                         if outputs.shape[0] == inputs.shape[0]:
341 |                             map_data[module_name] = outputs
342 |                         if outputs.shape[0] != inputs.shape[0]:
343 |                             if check_for_input_axis(outputs, inputs.shape[0]):
344 |                                 outputs = reset_input_axis(outputs, inputs.shape[0])
345 |                                 map_data[module_name] = outputs
346 |                             if not check_for_input_axis(outputs, inputs.shape[0]):
347 |                                 map_data[module_name] = None
348 |                                 warning('Ambiguous input axis in {}. Skipping...'.format(module_name))
349 | 
350 |                 if module_name in map_data:
351 |                     module_name = get_module_name(module, metadata)
352 |                     feature_map = output.cpu().detach()
353 |                     map_data[module_name] = feature_map
354 |                     metadata[module_name] = {}
355 | 
356 |                     metadata[module_name]['feature_map_shape'] = feature_map.numpy().shape[1:]
357 |                     metadata[module_name]['feature_count'] = feature_map.numpy().reshape(1, -1).shape[1]
358 | 
359 |                     params = 0
360 |                     if hasattr(module, "weight") and hasattr(module.weight, "size"):
361 |                         params += torch.prod(torch.LongTensor(list(module.weight.size())))
362 |                     if hasattr(module, "bias") and hasattr(module.bias, "size"):
363 |                         params += torch.prod(torch.LongTensor(list(module.bias.size())))
364 |                     if isinstance(params, torch.Tensor):
365 |                         params = params.item()
366 |                     metadata[module_name]['parameter_count'] = params
367 | 
368 |             module_name = get_module_name(module, metadata)
369 | 
370 |             if not any([isinstance(output, type_) for type_ in (tuple,list)]):
371 |                 process_output(output, module_name)
372 | 
373 |             if any([isinstance(output, type_) for type_ in (tuple,list)]):
374 |                 for output_i, output_ in enumerate(output):
375 |                     module_name_ = '-'.join([module_name, str(output_i+1)])
376 |                     process_output(output_, module_name_)
377 | 
378 |         if (not isinstance(module, nn.Sequential) and not isinstance(module, nn.ModuleList)):
379 |             hooks.append(module.register_forward_hook(hook))
380 | 
381 |     map_data = OrderedDict()
382 |     metadata = OrderedDict()
383 |     hooks = []
384 | 
385 |     model.apply(register_hook)
386 |     with torch.no_grad():
387 |         model(**inputs) if isinstance(inputs, dict) else model(inputs)
388 | 
389 |     for hook in hooks:
390 |         hook.remove()
391 | 
392 |     if remove_duplicates:
393 |         map_data = remove_duplicate_feature_maps(map_data)
394 |         metadata = {k:v for (k,v) in metadata.items() if k in map_data}
395 | 
396 |     return(metadata)
397 | 
398 | def get_feature_map_names(model, inputs = None, remove_duplicates = True):
399 |     feature_map_names = get_empty_feature_maps(model, inputs, names_only = True,
400 |                                                remove_duplicates = remove_duplicates)
401 | 
402 |     return(feature_map_names)
403 | 
404 | def get_feature_map_count(model, inputs = None, remove_duplicates = True):
405 |     feature_map_names = get_feature_map_names(model, inputs, remove_duplicates)
406 | 
407 |     return(len(feature_map_names))
408 | 
409 | # Helpers: Dataloaders and functions for facilitating feature extraction
410 | 
411 | class StimulusSet(Dataset):
412 |     def __init__(self, image_paths, transforms=None):
413 |         self.images = image_paths
414 |         self.transforms = transforms
415 | 
416 |     def __getitem__(self, index):
417 |         img = Image.open(self.images[index]).convert('RGB')
418 |         if self.transforms:
419 |             img = self.transforms(img)
420 |         return img
421 | 
422 |     def __len__(self):
423 |         return len(self.images)
424 | 
425 | def get_feature_map_size(feature_maps, layer=None):
426 |     total_size = 0
427 |     if layer is None:
428 |         for map_key in feature_maps:
429 |             if isinstance(feature_maps[map_key], np.ndarray):
430 |                 total_size += feature_maps[map_key].nbytes / 1000000
431 |             elif torch.is_tensor(feature_maps[map_key]):
432 |                 total_size += feature_maps[map_key].numpy().nbytes / 1000000
433 |         return total_size
434 | 
435 |     if layer is not None:
436 |         if isinstance(feature_maps[layer], np.ndarray):
437 |             return feature_maps[layer].nbytes / 1000000
438 |         elif torch.is_tensor(feature_maps[layer]):
439 |             return feature_maps[layer].numpy().nbytes / 1000000
440 | 
441 | class CSV2StimulusSet(Dataset):
442 |     def __init__(self, csv, root_dir, transforms=None):
443 | 
444 |         self.root = os.path.expanduser(root_dir)
445 |         self.transforms = transforms
446 | 
447 |         if isinstance(csv, pd.DataFrame):
448 |             self.df = csv
449 |         if isinstance(csv, str):
450 |             self.df = pd.read_csv(csv)
451 | 
452 |         self.images = self.df.image_name
453 | 
454 |     def __getitem__(self, index):
455 |         filename = os.path.join(self.root, self.images.iloc[index])
456 |         img = Image.open(filename).convert('RGB')
457 | 
458 |         if self.transforms:
459 |             img = self.transforms(img)
460 | 
461 |         return img
462 | 
463 |     def __len__(self):
464 |         return len(self.images)
465 | 
466 | class Array2StimulusSet(Dataset):
467 |     def __init__(self, img_array, transforms=None):
468 |         self.transforms = transforms
469 |         if isinstance(img_array, np.ndarray):
470 |             self.images = img_array
471 |         if isinstance(img_array, str):
472 |             self.images = np.load(img_array)
473 | 
474 |     def __getitem__(self, index):
475 |         img = Image.fromarray(self.images[index]).convert('RGB')
476 |         if self.transforms:
477 |             img = self.transforms(img)
478 |         return img
479 | 
480 |     def __len__(self):
481 |         return self.images.shape[0]
482 | 
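483 | # Example usage (a minimal sketch; the image paths and model string below are placeholders):
484 | #
485 | # preprocess = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
486 | # stimulus_set = StimulusSet(['img1.jpg', 'img2.jpg'], transforms = preprocess)
487 | # stimulus_loader = DataLoader(stimulus_set, batch_size = 2)
488 | # model = get_prepped_model('alexnet')  # any key in get_model_options()
489 | # feature_maps = get_all_feature_maps(model, stimulus_loader)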
--------------------------------------------------------------------------------
/deepdive/feature_reduction.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | from tqdm.auto import tqdm as tqdm
4 | import os, sys, time, pickle, argparse
5 | sys.path.append('..')
6 | 
7 | import torch as torch
8 | from torch.autograd import Variable
9 | from torch.utils.data import Dataset, DataLoader
10 | 
11 | from sklearn.random_projection import johnson_lindenstrauss_min_dim
12 | from sklearn.random_projection import SparseRandomProjection
13 | from sklearn.decomposition import PCA
14 | 
15 | from feature_extraction import *
16 | from model_options import *
17 | 
18 | def check_reduction_inputs(feature_maps = None, model_inputs = None):
19 |     if feature_maps is None and model_inputs is None:
20 |         raise ValueError('Neither feature_maps nor model_inputs are defined.')
21 | 
22 |     if model_inputs is not None and not isinstance(model_inputs, (DataLoader, torch.Tensor)):
23 |         raise ValueError('model_inputs not supplied in recognizable format.')
24 | 
25 | def get_feature_map_filepaths(feature_map_names, output_dir):
26 |     return {feature_map_name: os.path.join(output_dir, feature_map_name + '.npy')
27 |             for feature_map_name in feature_map_names}
28 | 
29 | # source: stackoverflow.com/questions/26774892
30 | def recursive_delete_if_empty(path):
31 |     if not os.path.isdir(path):
32 |         return False
33 | 
34 |     recurse_list = [recursive_delete_if_empty(os.path.join(path, filename))
35 |                     for filename in os.listdir(path)]
36 | 
37 |     if all(recurse_list):
38 |         os.rmdir(path)
39 |         return True
40 |     if not all(recurse_list):
41 |         return False
42 | 
43 | def delete_saved_output(output_filepaths, output_dir = None, remove_empty_output_dir = False):
44 |     for file_path in output_filepaths:
45 |         os.remove(output_filepaths[file_path])
46 |     if output_dir is not None and remove_empty_output_dir:
47 |         output_dir = output_dir.split('/')[0]
48 |         recursive_delete_if_empty(output_dir)
49 | 
50 | 
51 | def torch_corrcoef(m):
52 |     # calculate the covariance matrix
53 |     m_exp = torch.mean(m, dim=1)
54 |     x = m - m_exp[:, None]
55 |     cov_m = 1 / (x.size(1) - 1) * x.mm(x.t())
56 | 
57 |     # convert covariance to correlation
58 |     d = torch.diag(cov_m)
59 |     sigma = torch.pow(d, 0.5)
60 |     cor_m = cov_m.div(sigma.expand_as(cov_m))
61 |     cor_m = 
cor_m.div(sigma.expand_as(cor_m).t()) 62 | cor_m = torch.clamp(cor_m, -1.0, 1.0) 63 | return cor_m 64 | 65 | 66 | #### Sparse Random Projection ------------------------------------------------------------------- 67 | 68 | def get_feature_map_srps(feature_maps, n_projections = None, upsampling = True, eps=0.1, seed = 0, 69 | save_outputs = False, output_dir = 'temp_data/srp', 70 | delete_originals = False, delete_saved_outputs = True): 71 | 72 | if n_projections is None: 73 | if isinstance(feature_maps, np.ndarray): 74 | n_samples = feature_maps.shape[0] 75 | if isinstance(feature_maps, dict): 76 | n_samples = next(iter(feature_maps.values())).shape[0] 77 | n_projections = johnson_lindenstrauss_min_dim(n_samples, eps=eps) 78 | 79 | srp = SparseRandomProjection(n_projections, random_state=seed) 80 | 81 | def get_srps(feature_map): 82 | if feature_map.shape[1] <= n_projections and not upsampling: 83 | srp_feature_map = feature_map 84 | if feature_map.shape[1] >= n_projections or upsampling: 85 | srp_feature_map = srp.fit_transform(feature_map) 86 | 87 | return srp_feature_map 88 | 89 | if isinstance(feature_maps, np.ndarray) and save_outputs: 90 | raise ValueError('Please provide a dictionary of the form {feature_map_name: feature_map}' + 91 | 'in order to save_outputs.') 92 | 93 | if isinstance(feature_maps, np.ndarray) and not save_outputs: 94 | return srp.fit_transform(feature_maps) 95 | 96 | if isinstance(feature_maps, dict) and not save_outputs: 97 | srp_feature_maps = {} 98 | for feature_map_name in tqdm(list(feature_maps), desc = 'SRP Extraction (Layer)'): 99 | srp_feature_maps[feature_map_name] = get_srps(feature_maps[feature_map_name]) 100 | 101 | if delete_originals: 102 | feature_maps.pop(feature_map_name) 103 | 104 | return srp_feature_maps 105 | 106 | if isinstance(feature_maps, dict) and save_outputs: 107 | output_dir = os.path.join(output_dir, '_'.join(['projections', str(n_projections), 'seed', str(seed)])) 108 | output_filepaths = get_feature_map_filepaths(feature_maps, output_dir) 109 | if not os.path.exists(output_dir): 110 | os.makedirs(output_dir) 111 | 112 | srp_feature_maps = {} 113 | for feature_map_name in tqdm(list(feature_maps), desc = 'SRP Extraction (Layer)'): 114 | output_filepath = output_filepaths[feature_map_name] 115 | if not os.path.exists(output_filepath): 116 | srp_feature_maps[feature_map_name] = get_srps(feature_maps[feature_map_name]) 117 | np.save(output_filepath, srp_feature_maps[feature_map_name]) 118 | if os.path.exists(output_filepath): 119 | srp_feature_maps[feature_map_name] = np.load(output_filepath, allow_pickle=True) 120 | 121 | if delete_originals: 122 | feature_maps.pop(feature_map_name) 123 | 124 | if delete_saved_outputs: 125 | delete_saved_output(output_filepaths, output_dir, remove_empty_output_dir = True) 126 | 127 | return srp_feature_maps 128 | 129 | def srp_extraction(model_string, model = None, inputs = None, feature_maps = None, 130 | n_projections = None, upsampling = True, eps=0.1, seed = 0, 131 | output_dir='temp_data/srp_arrays', delete_saved_outputs = True, 132 | delete_original_feature_maps = False, verbose = False): 133 | 134 | check_reduction_inputs(feature_maps, inputs) 135 | output_dir_stem = os.path.join(output_dir, model_string.replace('/','-')) 136 | 137 | device_name = 'CPU' if not torch.cuda.is_available() else torch.cuda.get_device_name() 138 | 139 | if n_projections is None: 140 | if feature_maps is None: 141 | if isinstance(inputs, torch.Tensor): 142 | n_samples = len(inputs) 143 | if isinstance(inputs, 
DataLoader): 144 | n_samples = len(inputs.dataset) 145 | if feature_maps is not None: 146 | n_samples = next(iter(feature_maps.values())).shape[0] 147 | n_projections = johnson_lindenstrauss_min_dim(n_samples, eps=eps) 148 | 149 | if verbose: 150 | print('Computing {} SRPs for {}; using {} for feature extraction...' 151 | .format(n_projections, model_string, device_name)) 152 | 153 | output_dir_ext = '_'.join(['projections', str(n_projections), 'seed', str(seed)]) 154 | output_dir = os.path.join(output_dir_stem, output_dir_ext) 155 | 156 | if feature_maps is None or isinstance(feature_maps, list): 157 | check_model(model_string, model) 158 | 159 | if model == None: 160 | model = get_prepped_model(model_string) 161 | 162 | model = prep_model_for_extraction(model) 163 | feature_maps = get_all_feature_maps(model, inputs) 164 | 165 | srp_args = {'feature_maps': feature_maps, 'n_projections': n_projections, 166 | 'upsampling': upsampling, 'eps': eps, 'seed': seed, 167 | 'save_outputs': True, 'output_dir': output_dir_stem, 168 | 'delete_saved_outputs': delete_saved_outputs, 169 | 'delete_originals': delete_original_feature_maps} 170 | 171 | return get_feature_map_srps(**srp_args) 172 | 173 | 174 | #### Principal Components Analysis ------------------------------------------------------------------- 175 | 176 | def get_feature_map_pcs(feature_maps, n_components = None, return_pca_object = False, 177 | save_outputs = False, output_dir = 'temp_data/pca', 178 | delete_originals = False, delete_saved_outputs = True): 179 | 180 | def get_pca(feature_map): 181 | n_samples, n_features = feature_map.shape 182 | n_components_ = n_components 183 | if n_components_ is not None: 184 | if n_components_ > n_samples: 185 | n_components_ = n_samples 186 | print('More components requested than samples. 
Reducing...')
187 |         pca = PCA(n_components_, random_state=0)
188 |         if return_pca_object:
189 |             return pca.fit(feature_map)
190 |         if not return_pca_object:
191 |             return pca.fit_transform(feature_map)
192 | 
193 |     if return_pca_object and save_outputs:
194 |         raise ValueError('Saving fitted PCA objects is not currently supported.')
195 | 
196 |     if isinstance(feature_maps, np.ndarray) and save_outputs:
197 |         raise ValueError('Please provide a dictionary of the form {feature_map_name: feature_map}' +
198 |                          'in order to save_outputs.')
199 | 
200 |     if isinstance(feature_maps, np.ndarray) and not save_outputs:
201 |         return get_pca(feature_maps)
202 | 
203 |     if isinstance(feature_maps, dict) and not save_outputs:
204 |         pca_feature_maps = {}
205 |         for feature_map_name in tqdm(list(feature_maps), desc = 'PCA Extraction (Layer)'):
206 |             pca_feature_maps[feature_map_name] = get_pca(feature_maps[feature_map_name])
207 | 
208 |             if delete_originals:
209 |                 feature_maps.pop(feature_map_name)
210 | 
211 |         return pca_feature_maps
212 | 
213 |     if isinstance(feature_maps, dict) and save_outputs:
214 |         output_filepaths = get_feature_map_filepaths(feature_maps, output_dir)
215 |         if not os.path.exists(output_dir):
216 |             os.makedirs(output_dir)
217 | 
218 |         pca_feature_maps = {}
219 |         for feature_map_name in tqdm(list(feature_maps), desc = 'PCA Extraction (Layer)'):
220 |             output_filepath = output_filepaths[feature_map_name]
221 |             if not os.path.exists(output_filepath):
222 |                 pca_feature_maps[feature_map_name] = get_pca(feature_maps[feature_map_name])
223 |                 np.save(output_filepath, pca_feature_maps[feature_map_name])
224 |             if os.path.exists(output_filepath):
225 |                 pca_feature_maps[feature_map_name] = np.load(output_filepath, allow_pickle=True)
226 | 
227 |             if delete_originals:
228 |                 feature_maps.pop(feature_map_name)
229 | 
230 |         if delete_saved_outputs:
231 |             delete_saved_output(output_filepaths, output_dir, remove_empty_output_dir = True)
232 | 
233 |         return pca_feature_maps
234 | 
235 | 
236 | def pca_extraction(model_string, model = None, inputs = None, feature_maps = None,
237 |                    n_components = None, aux_inputs = None, aux_feature_maps = None,
238 |                    output_dir='temp_data/pca_arrays', delete_saved_outputs = True,
239 |                    delete_original_feature_maps = False, verbose = False):
240 | 
241 |     check_reduction_inputs(feature_maps, inputs)
242 | 
243 |     use_aux_pca = aux_inputs is not None or aux_feature_maps is not None
244 | 
245 |     if feature_maps is None:
246 |         if isinstance(inputs, torch.Tensor):
247 |             n_samples = len(inputs)
248 |         if isinstance(inputs, DataLoader):
249 |             n_samples = len(inputs.dataset)
250 |     if feature_maps is not None:
251 |         n_samples = next(iter(feature_maps.values())).shape[0]
252 | 
253 |     if aux_feature_maps is None:
254 |         if aux_inputs is not None:
255 |             if isinstance(aux_inputs, torch.Tensor):
256 |                 n_aux_samples = len(aux_inputs)
257 |             if isinstance(aux_inputs, DataLoader):
258 |                 n_aux_samples = len(aux_inputs.dataset)
259 |     if aux_feature_maps is not None:
260 |         n_aux_samples = next(iter(aux_feature_maps.values())).shape[0]
261 | 
262 |     if n_components is not None:
263 |         if use_aux_pca and n_components > n_aux_samples:
264 |             raise ValueError('Requesting more components than are available with PCs from auxiliary sample.')
265 |         if n_components > n_samples:
266 |             raise ValueError('Requesting more components than are available with stimulus set sample size.')
267 | 
268 |     if n_components is None:
269 |         if use_aux_pca:
270 |             n_components = n_aux_samples
271 |         if not use_aux_pca:
272 |             n_components = n_samples
273 | 
274 |     device_name = 'CPU' if not torch.cuda.is_available() else torch.cuda.get_device_name()
275 | 
276 |     pca_type = 'auxiliary_input_pcs' if use_aux_pca else 'stimulus_direct'
277 |     pca_printout = ('{} Independent PCs' if use_aux_pca else 'up to {} Stimulus PCs').format(n_components)
278 | 
279 |     if verbose:
280 |         print('Computing {} for {}; using {} for feature extraction...'
281 |               .format(pca_printout, model_string, device_name))
282 | 
283 |     output_dir = os.path.join(output_dir, model_string.replace('/','-'), pca_type)
284 |     if not os.path.exists(output_dir):
285 |         os.makedirs(output_dir)
286 | 
287 |     if (feature_maps is None) or (aux_feature_maps is None and use_aux_pca):
288 |         check_model(model_string, model)
289 | 
290 |         if model is None:
291 |             model = get_prepped_model(model_string)
292 | 
293 |         model = prep_model_for_extraction(model)
294 | 
295 |         if feature_maps is None:
296 |             feature_maps = get_all_feature_maps(model, inputs)
297 | 
298 |         if aux_feature_maps is None and use_aux_pca:
299 |             aux_feature_maps = get_all_feature_maps(model, aux_inputs, layers_to_retain = list(feature_maps.keys()))
300 | 
301 |     if use_aux_pca:
302 |         pca_args = {'feature_maps': aux_feature_maps, 'n_components': n_components,
303 |                     'return_pca_object': True, 'save_outputs': False,
304 |                     'delete_originals': delete_original_feature_maps}
305 | 
306 |         # note: save_outputs is not supported in the auxiliary-PCA path;
307 |         # the fitted PCs are applied in memory and nothing is written to disk.
308 | 
309 |         aux_pcas = get_feature_map_pcs(**pca_args)
310 | 
311 |         pca_feature_maps = {}
312 |         for feature_map in feature_maps:
313 |             pca_feature_maps[feature_map] = aux_pcas[feature_map].transform(feature_maps[feature_map])
314 | 
315 |     if not use_aux_pca:
316 |         pca_args = {'feature_maps': feature_maps, 'n_components': n_components,
317 |                     'return_pca_object': False, 'save_outputs': False, 'output_dir': output_dir,
318 |                     'delete_saved_outputs': delete_saved_outputs,
319 |                     'delete_originals': delete_original_feature_maps}
320 | 
321 |         pca_feature_maps = get_feature_map_pcs(**pca_args)
322 | 
323 |     return(pca_feature_maps)
324 | 
325 | 
326 | #### Representational Similarity Analysis -------------------------------------------------------------------
327 | 
328 | 
329 | def rdm_extraction(model_string, model = None, model_inputs = None, feature_maps = None,
330 |                    use_torch_corr = False, append_filename_suffix = False,
331 |                    output_dir='temp_data/rdm', delete_saved_outputs = True,
332 |                    delete_original_feature_maps = False, verbose = True):
333 | 
334 |     check_reduction_inputs(feature_maps, model_inputs)
335 | 
336 |     device_name = 'CPU' if not torch.cuda.is_available() else torch.cuda.get_device_name()
337 | 
338 |     if verbose: print('Computing RDMs for {}; using {} for feature extraction...'
339 |                       .format(model_string, device_name))
340 | 
341 |     if not os.path.exists(output_dir):
342 |         os.makedirs(output_dir)
343 | 
344 |     output_file = os.path.join(output_dir, model_string + '_rdms.pkl'
345 |                                if append_filename_suffix else model_string + '.pkl')
346 | 
347 |     if os.path.exists(output_file):
348 |         model_rdms = pickle.load(open(output_file,'rb'))
349 | 
350 |     if not os.path.exists(output_file):
351 |         if feature_maps is None:
352 |             check_model(model_string, model)
353 | 
354 |             if model is None:
355 |                 model = get_prepped_model(model_string)
356 | 
357 |             model = prep_model_for_extraction(model)
358 |             feature_maps = get_all_feature_maps(model, model_inputs, numpy = not use_torch_corr)
359 | 
360 |         model_rdms = {}
361 |         for model_layer in tqdm(list(feature_maps), leave=False):
362 |             if use_torch_corr:
363 |                 feature_map = feature_maps[model_layer]
364 |                 if torch.cuda.is_available():
365 |                     feature_map = feature_map.cuda()
366 |                 model_rdm = 1 - torch_corrcoef(feature_map).cpu()
367 |                 model_rdms[model_layer] = model_rdm.numpy()
368 |             if not use_torch_corr:
369 |                 model_rdms[model_layer] = 1 - np.corrcoef(feature_maps[model_layer])
370 |             if delete_original_feature_maps:
371 |                 feature_maps.pop(model_layer)
372 |         with open(output_file, 'wb') as file:
373 |             pickle.dump(model_rdms, file)
374 | 
375 |     return(model_rdms)
376 | 
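377 | # Example usage (a minimal sketch; the model string is a placeholder):
378 | #
379 | # inputs = torch.rand(50, 3, 224, 224)  # stand-in for a preprocessed stimulus set
380 | # srp_feature_maps = srp_extraction('alexnet', inputs = inputs)  # SRPs per layer
381 | # pca_feature_maps = pca_extraction('alexnet', inputs = inputs)  # PCs per layer
382 | # model_rdms = rdm_extraction('alexnet', model_inputs = inputs)  # (1 - Pearson r) RDMs per layer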
--------------------------------------------------------------------------------
/deepdive/mapping_methods.py:
--------------------------------------------------------------------------------
1 | import warnings; warnings.filterwarnings("ignore")
2 | 
3 | import numpy as np
4 | import pandas as pd
5 | from tqdm.auto import tqdm as tqdm
6 | 
7 | from sklearn.metrics import r2_score, explained_variance_score
8 | from sklearn.linear_model import LinearRegression, ElasticNet
9 | from sklearn.linear_model import Ridge, RidgeCV
10 | from sklearn.cross_decomposition import PLSRegression
11 | from sklearn.model_selection import KFold, RepeatedKFold
12 | from sklearn.preprocessing import scale
13 | from scipy.stats import pearsonr, spearmanr
14 | 
15 | ### Response Transforms -----------------------------------------------------
16 | 
17 | def anscombe_transform(x):
18 |     return 2.0*np.sqrt(x + 3.0/8.0)
19 | 
20 | ### Scoring Options ---------------------------------------------------------
21 | 
22 | pearsonr_vec = np.vectorize(pearsonr, signature='(n),(n)->(),()')
23 | 
24 | def pearson_r_score(y_true, y_pred, multioutput=None):
25 |     y_true_ = y_true.transpose()
26 |     y_pred_ = y_pred.transpose()
27 |     return(pearsonr_vec(y_true_, y_pred_)[0])
28 | 
29 | def pearson_r2_score(y_true, y_pred, multioutput=None):
30 |     return(pearson_r_score(y_true, y_pred)**2)
31 | 
32 | def get_predicted_values(y_true, y_pred, transform = None, multioutput = None):
33 |     if transform is None:
34 |         return(y_pred)
35 | 
36 | scoring_options = {'r2': r2_score, 'pearson_r': pearson_r_score, 'pearson_r2': pearson_r2_score,
37 |                    'explained_variance': explained_variance_score, 'predicted_values': get_predicted_values}
38 | 
39 | def get_scoring_options():
40 |     return scoring_options
41 | 
42 | def score_func(y_true, y_pred, score_type='pearson_r'):
43 |     if not isinstance(score_type, list):
44 |         return(scoring_options[score_type](y_true, y_pred, multioutput='raw_values'))
45 | 
46 |     if isinstance(score_type, list):
47 |         scoring_dict = {}
48 |         for score_type_i in score_type:
49 |             scoring_dict[score_type_i] = scoring_options[score_type_i](y_true, y_pred, multioutput='raw_values')
50 | 
51 |         return(scoring_dict)
52 | 
53 | ### Neural Regression Methods ---------------------------------------------------------
54 | 
55 | def kfold_regression(X, y, regression, n_splits, score_type, use_tqdm):
56 |     if regression == 'ridge':
57 |         regression = Ridge(alpha = 1.0)
58 |     if regression == 'pls':
59 |         regression = PLSRegression(n_components = 10)
60 |     if isinstance(regression, str) and regression not in ('ridge','pls'):
61 |         raise ValueError("Unknown regression string. Please use one of ('ridge', 'pls') or an sklearn regression object.")
62 | 
63 |     kfolds = KFold(n_splits, shuffle=False).split(np.arange(y.shape[0]))
64 |     kfolds = tqdm(kfolds, total = n_splits, leave=False) if use_tqdm else kfolds
65 | 
66 |     y_pred = np.zeros((y.shape[0],y.shape[1]))
67 |     for train_indices, test_indices in kfolds:
68 |         X_train, X_test = X[train_indices, :], X[test_indices, :]
69 |         y_train, y_test = y[train_indices], y[test_indices]
70 |         regression = regression.fit(X_train, y_train)
71 |         y_pred[test_indices] = regression.predict(X_test)
72 | 
73 |     return score_func(y, y_pred, score_type)
74 | 
75 | def gcv_ridge_regression(X,y, score_type, alphas = [1.0], return_best_alpha = False):
76 |     regression = RidgeCV(alphas=alphas, store_cv_values = True,
77 |                          scoring = 'explained_variance').fit(X,y)
78 | 
79 |     y_pred = regression.cv_values_.squeeze()
80 | 
81 |     if len(alphas) > 1:
82 |         best_alpha_index = 0
83 |         current_best_score = 0
84 |         for alpha_index, alpha in enumerate(alphas):
85 |             y_pred = regression.cv_values_[:,:,alpha_index]
86 |             score = score_func(y, y_pred, score_type).mean()
87 |             if score > current_best_score:
88 |                 current_best_score = score
89 |                 best_alpha_index = alpha_index
90 | 
91 |         y_pred = regression.cv_values_[:,:,best_alpha_index]
92 | 
93 |     scores = score_func(y, y_pred, score_type)
94 | 
95 |     if not return_best_alpha:
96 |         return scores
97 |     if return_best_alpha:
98 |         return (scores, regression.alpha_)
99 | 
100 | def neural_regression(feature_map, neural_response, regression = Ridge(alpha=1.0), cv_splits = 5,
101 |                       score_type = 'pearson_r', use_tqdm = False, **kwargs):
102 | 
103 |     if cv_splits == 'gcv' and regression != 'ridge' and not isinstance(regression, Ridge):
104 |         raise ValueError("gcv mode selected, but regression is not ridge.")
105 | 
106 |     X,y = feature_map, neural_response
107 | 
108 |     if cv_splits is None:
109 |         warnings.warn('No cv_splits selected. Returning fitted regression object...')
110 |         return regression.fit(X,y)
111 | 
112 |     if cv_splits == 'gcv':
113 |         return gcv_ridge_regression(X, y, score_type, **kwargs)
114 | 
115 |     if isinstance(cv_splits, int):
116 |         return kfold_regression(X, y, regression, cv_splits, score_type, use_tqdm)
117 | 
118 | ### Classic Representational Similarity --------------------------------------------------
119 | 
120 | def compare_rdms(rdm1, rdm2, dist_type = 'pearson'):
121 |     rdm1_triu = rdm1[np.triu_indices(rdm1.shape[0], k=1)]
122 |     rdm2_triu = rdm2[np.triu_indices(rdm2.shape[0], k=1)]
123 | 
124 |     if dist_type == 'pearson':
125 |         return pearsonr(rdm1_triu, rdm2_triu)[0]
126 |     if dist_type == 'spearman':
127 |         return spearmanr(rdm1_triu, rdm2_triu)[0]
128 | 
129 | ### Representational Similarity Regression --------------------------------------------------
130 | 
131 | def rdm_regression(target_rdm, model_rdms, regression_type='linear',
132 |                    n_splits=10, n_repeats=None, random_state=None):
133 |     '''Non-negative least squares linear regression on RDMs with k-fold cross-validation.
134 |     Parameters
135 |     ----------
136 |     target_rdm: your brain data RDM (n_samples x n_samples)
137 |     model_rdms: your model layer RDMs (n_samples x n_samples x n_layers)
138 |     n_splits: how many cross_validated folds
139 |     n_repeats: how many times to perform k-fold splits
140 |     random_state: used if you want to use a particular set of random splits
141 |     Attributes
142 |     ----------
143 |     r : correlation between predicted and actual RDM
144 |     coefficients : the coefficients across k-fold splits
145 |     intercepts : the intercepts across k-fold splits
146 |     '''
147 |     n_items = target_rdm.shape[0]
148 | 
149 |     predicted_rdm = np.zeros(target_rdm.shape)
150 |     predicted_sum = np.zeros(target_rdm.shape)
151 |     predicted_count = np.zeros(target_rdm.shape)
152 | 
153 |     coefficients = []
154 |     intercepts = []
155 |     i,j = np.triu_indices(target_rdm.shape[0],k=1)
156 |     if n_repeats is None:
157 |         kf = KFold(n_splits=n_splits, shuffle=True, random_state=random_state)
158 |     if n_repeats is not None:
159 |         kf = RepeatedKFold(n_splits=n_splits, n_repeats=n_repeats, random_state=random_state)
160 | 
161 |     for train_indices, test_indices in kf.split(list(range(n_items))):
162 | 
163 |         # indices for training and test cells of matrix
164 |         test_idx = (np.isin(i, test_indices) | np.isin(j, test_indices))
165 |         train_idx = ~test_idx
166 | 
167 |         # target data (excluding test_indices)
168 |         y_train = target_rdm[i[train_idx], j[train_idx]]
169 | 
170 |         # model data (excluding test_indices)
171 |         X_train = model_rdms[i[train_idx], j[train_idx], :]
172 | 
173 |         # test data (test_indices)
174 |         X_test = model_rdms[i[test_idx], j[test_idx], :]
175 | 
176 |         # fit the regression model
177 |         if regression_type == 'linear':
178 |             regression = LinearRegression(fit_intercept=True, positive=True)
179 |             regression.fit(X_train, y_train)
180 |         if regression_type == 'ridge':
181 |             regression = ElasticNet(alpha = 1.0, l1_ratio = 0, positive = True)
182 |             regression.fit(X_train, y_train)
183 | 
184 |         # predict the held out cells
185 |         # note that for a k-fold procedure, some cells are predicted more than once
186 |         # so we keep a sum and count, and later will average (sum/count) these predictions
187 |         predicted_sum[i[test_idx],j[test_idx]] += regression.predict(X_test)
188 |         predicted_count[i[test_idx],j[test_idx]] += 1
189 | 
190 |         # save the regression coefficients
191 |         coefficients.append(regression.coef_)
192 |         intercepts.append(regression.intercept_)
193 | 
194 |     predicted_rdm = predicted_sum / predicted_count
195 |     coefficients = np.stack(coefficients)
196 |     intercepts = np.stack(intercepts)
197 | 
198 |     # make sure each cell received one value
199 |     cell_counts = predicted_count[np.triu_indices(target_rdm.shape[0], k=1)]
200 |     assert cell_counts.min()>=1, "A cell of the predicted matrix contains less than one value."
201 | 
202 |     # compute correlation between target and predicted upper triangle
203 |     target = target_rdm[np.triu_indices(target_rdm.shape[0], k=1)]
204 |     predicted = predicted_rdm[np.triu_indices(predicted_rdm.shape[0], k=1)]
205 | 
206 |     r = pearsonr(target, predicted)[0]
207 | 
208 |     return r, coefficients, intercepts
209 | 
210 | ### Data Transforms ---------------------------------------------------------
211 | 
212 | 
213 | def max_transform(df, group_vars, measure_var = 'score', deduplicate=True):
214 |     if not isinstance(group_vars, list):
215 |         group_vars = [group_vars]
216 | 
217 |     max_df = (df[df.groupby(group_vars)[measure_var]
218 |                  .transform(max) == df[measure_var]]).reset_index(drop=True)
219 | 
220 |     if deduplicate:
221 |         max_df = max_df[~max_df.duplicated(group_vars + [measure_var])]
222 | 
223 |     return max_df
224 | 
225 | def min_transform(df, group_vars, measure_var = 'score', deduplicate=True):
226 |     if not isinstance(group_vars, list):
227 |         group_vars = [group_vars]
228 | 
229 |     min_df = (df[df.groupby(group_vars)[measure_var]
230 |                  .transform(min) == df[measure_var]]).reset_index(drop=True)
231 | 
232 |     if deduplicate:
233 |         min_df = min_df[~min_df.duplicated(group_vars + [measure_var])]
234 | 
235 |     return min_df
236 | 
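237 | # Example usage (a minimal sketch with random data):
238 | #
239 | # X = np.random.rand(100, 512)  # feature_map: n_stimuli x n_features
240 | # y = np.random.rand(100, 10)   # neural_response: n_stimuli x n_voxels
241 | # scores = neural_regression(X, y, cv_splits = 5, score_type = 'pearson_r')
242 | # print(scores.mean())  # mean cross-validated Pearson r across voxels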
--------------------------------------------------------------------------------
/deepdive/model_accuracy.csv:
--------------------------------------------------------------------------------
1 | model,top1_accuracy,top5_accuracy
2 | beit_large_patch16_512,88.6,98.656
3 | beit_large_patch16_384,88.404,98.608
4 | tf_efficientnet_l2_ns,88.348,98.648
5 | tf_efficientnet_l2_ns_475,88.232,98.546
6 | convnext_xlarge_384_in22ft1k,87.546,98.486
7 | beit_large_patch16_224,87.474,98.304
8 | convnext_large_384_in22ft1k,87.396,98.368
9 | swin_large_patch4_window12_384,87.15,98.24
10 | vit_large_patch16_384,87.08,98.3
11 | volo_d5_512,87.042,97.968
12 | convnext_xlarge_in22ft1k,87.002,98.212
13 | volo_d5_448,86.952,97.94
14 | tf_efficientnet_b7_ns,86.838,98.096
15 | beit_base_patch16_384,86.798,98.136
16 | volo_d4_448,86.79,97.882
17 | convnext_large_in22ft1k,86.636,98.028
18 | convnext_base_384_in22ft1k,86.544,98.19
19 | volo_d3_448,86.496,97.71
20 | cait_m48_448,86.486,97.752
21 | tf_efficientnet_b6_ns,86.454,97.882
22 | swin_base_patch4_window12_384,86.432,98.056
23 | tf_efficientnetv2_xl_in21ft1k,86.418,97.866
24 | swin_large_patch4_window7_224,86.314,97.892
25 | tf_efficientnetv2_l_in21ft1k,86.304,97.978
26 | vit_large_r50_s32_384,86.182,97.92
27 | dm_nfnet_f6,86.142,97.73
28 | tf_efficientnet_b5_ns,86.09,97.75
29 | volo_d5_224,86.072,97.578
30 | cait_m36_384,86.052,97.73
31 | volo_d2_384,86.04,97.572
32 | vit_base_patch16_384,86.01,98.002
33 | xcit_large_24_p8_384_dist,85.998,97.684
34 | volo_d4_224,85.874,97.468
35 | vit_large_patch16_224,85.838,97.82
36 | convnext_base_in22ft1k,85.824,97.866
37 | dm_nfnet_f5,85.816,97.486
38 | xcit_medium_24_p8_384_dist,85.814,97.592
39 | vit_base_patch8_224,85.794,97.794
40 | xcit_large_24_p16_384_dist,85.754,97.538
41 | dm_nfnet_f4,85.714,97.522
42 | tf_efficientnetv2_m_in21ft1k,85.59,97.744
43 | xcit_small_24_p8_384_dist,85.556,97.572
44 | dm_nfnet_f3,85.522,97.462
45 | tf_efficientnetv2_l,85.49,97.372
46 | cait_s36_384,85.462,97.48
47 | ig_resnext101_32x48d,85.436,97.576
48 | xcit_medium_24_p16_384_dist,85.422,97.406
49 | deit_base_distilled_patch16_384,85.422,97.332
50 | volo_d3_224,85.412,97.28
51 | xcit_large_24_p8_224_dist,85.398,97.41
52 | tf_efficientnet_b8_ap,85.372,97.292
53 | tf_efficientnet_b8,85.368,97.39
54 | 
swin_base_patch4_window7_224,85.248,97.562 55 | volo_d1_384,85.248,97.214 56 | beit_base_patch16_224,85.228,97.658 57 | volo_d2_224,85.194,97.188 58 | tf_efficientnet_b4_ns,85.162,97.47 59 | tf_efficientnet_b7_ap,85.12,97.252 60 | ig_resnext101_32x32d,85.1,97.434 61 | xcit_small_24_p16_384_dist,85.094,97.31 62 | xcit_small_12_p8_384_dist,85.08,97.28 63 | xcit_medium_24_p8_224_dist,85.072,97.28 64 | dm_nfnet_f2,85.062,97.24 65 | cait_s24_384,85.046,97.346 66 | tf_efficientnetv2_m,85.038,97.278 67 | regnetz_e8,85.03,97.264 68 | resnetrs420,85.01,97.124 69 | ecaresnet269d,84.976,97.226 70 | vit_base_r50_s16_384,84.97,97.29 71 | tf_efficientnet_b7,84.936,97.204 72 | xcit_large_24_p16_224_dist,84.922,97.132 73 | resnetv2_152x4_bitm,84.916,97.442 74 | xcit_small_24_p8_224_dist,84.876,97.188 75 | efficientnetv2_rw_m,84.812,97.146 76 | tf_efficientnet_b6_ap,84.786,97.138 77 | resnetrs350,84.718,96.988 78 | xcit_small_12_p16_384_dist,84.71,97.118 79 | eca_nfnet_l2,84.696,97.264 80 | dm_nfnet_f1,84.624,97.1 81 | vit_base_patch16_224,84.528,97.294 82 | resnest269e,84.52,96.986 83 | resnetv2_152x2_bitm,84.506,97.434 84 | regnetz_040h,84.496,97.006 85 | resnetv2_101x3_bitm,84.442,97.382 86 | resnetrs200,84.438,97.08 87 | resnetrs270,84.436,96.97 88 | vit_large_r50_s32_224,84.424,97.166 89 | resmlp_big_24_224_in22ft1k,84.394,97.12 90 | xcit_large_24_p8_224,84.394,96.656 91 | seresnet152d,84.36,97.04 92 | tf_efficientnetv2_s_in21ft1k,84.298,97.254 93 | convnext_large,84.292,96.894 94 | swsl_resnext101_32x8d,84.29,97.18 95 | xcit_medium_24_p16_224_dist,84.274,96.94 96 | vit_base_patch16_224_miil,84.272,96.802 97 | tf_efficientnet_b5_ap,84.256,96.976 98 | xcit_small_12_p8_224_dist,84.236,96.874 99 | regnetz_040,84.234,96.932 100 | seresnext101_32x8d,84.204,96.876 101 | crossvit_18_dagger_408,84.194,96.818 102 | ig_resnext101_32x16d,84.17,97.198 103 | volo_d1_224,84.162,96.776 104 | pit_b_distilled_224,84.14,96.856 105 | tf_efficientnet_b6,84.11,96.888 106 | cait_xs24_384,84.062,96.89 107 | regnetz_d8_evos,84.054,96.996 108 | regnetz_d8,84.052,96.994 109 | tf_efficientnet_b3_ns,84.048,96.908 110 | vit_small_r26_s32_384,84.042,97.328 111 | regnetz_d32,84.022,96.866 112 | resnetv2_50x3_bitm,84.014,97.124 113 | eca_nfnet_l1,84.012,97.028 114 | resnet200d,83.964,96.824 115 | swin_s3_base_224,83.932,96.66 116 | regnety_080,83.926,96.888 117 | tf_efficientnetv2_s,83.886,96.696 118 | xcit_small_24_p16_224_dist,83.868,96.724 119 | resnetv2_152x2_bit_teacher_384,83.844,97.118 120 | convnext_base,83.838,96.75 121 | xcit_small_24_p8_224,83.838,96.636 122 | crossvit_15_dagger_408,83.836,96.784 123 | resnest200e,83.828,96.892 124 | tf_efficientnet_b5,83.812,96.748 125 | efficientnetv2_rw_s,83.81,96.722 126 | vit_small_patch16_384,83.804,97.102 127 | swin_s3_small_224,83.768,96.452 128 | xcit_tiny_24_p8_384_dist,83.742,96.71 129 | xcit_medium_24_p8_224,83.736,96.394 130 | regnety_064,83.72,96.722 131 | resnetrs152,83.714,96.614 132 | regnetv_064,83.714,96.746 133 | regnety_160,83.69,96.776 134 | twins_svt_large,83.68,96.594 135 | resnet152d,83.678,96.738 136 | resmlp_big_24_distilled_224,83.588,96.648 137 | jx_nest_base,83.556,96.362 138 | cait_s24_224,83.458,96.564 139 | efficientnet_b4,83.424,96.596 140 | deit_base_distilled_patch16_224,83.392,96.486 141 | dm_nfnet_f0,83.386,96.574 142 | swsl_resnext101_32x16d,83.35,96.844 143 | xcit_small_12_p16_224_dist,83.35,96.414 144 | vit_base_patch32_384,83.348,96.834 145 | xcit_small_12_p8_224,83.344,96.48 146 | tf_efficientnet_b4_ap,83.252,96.394 147 | 
swsl_resnext101_32x4d,83.236,96.764 148 | swin_small_patch4_window7_224,83.216,96.324 149 | regnetv_040,83.2,96.662 150 | xception65,83.18,96.592 151 | convnext_small,83.15,96.432 152 | resnext101_64x4d,83.14,96.37 153 | twins_svt_base,83.138,96.42 154 | twins_pcpvt_large,83.134,96.604 155 | xception65p,83.126,96.478 156 | jx_nest_small,83.118,96.33 157 | deit_base_patch16_384,83.106,96.368 158 | tresnet_m,83.07,96.12 159 | tresnet_xl_448,83.054,96.172 160 | regnety_040,83.036,96.506 161 | tf_efficientnet_b4,83.024,96.3 162 | resnet101d,83.022,96.448 163 | xcit_large_24_p16_224,82.894,95.882 164 | resnest101e,82.89,96.318 165 | resnetv2_152x2_bit_teacher,82.872,96.57 166 | resnetv2_50x1_bit_distilled,82.822,96.528 167 | resnet152,82.82,96.13 168 | pnasnet5large,82.79,96.04 169 | nfnet_l0,82.752,96.518 170 | regnety_032,82.726,96.424 171 | twins_pcpvt_base,82.708,96.35 172 | ig_resnext101_32x8d,82.7,96.63 173 | xcit_medium_24_p16_224,82.638,95.974 174 | regnetz_c16_evos,82.632,96.476 175 | nasnetalarge,82.626,96.046 176 | levit_384,82.588,96.022 177 | xcit_small_24_p16_224,82.58,96.006 178 | eca_nfnet_l0,82.576,96.49 179 | xcit_tiny_24_p16_384_dist,82.572,96.288 180 | xcit_tiny_24_p8_224_dist,82.564,96.17 181 | resnet61q,82.526,96.134 182 | crossvit_18_dagger_240,82.52,96.07 183 | regnetz_c16,82.516,96.36 184 | gc_efficientnetv2_rw_t,82.466,96.296 185 | poolformer_m48,82.462,95.958 186 | pit_b_224,82.446,95.71 187 | crossvit_18_240,82.398,96.058 188 | xcit_tiny_12_p8_384_dist,82.394,96.22 189 | tf_efficientnet_b2_ns,82.38,96.248 190 | resnet51q,82.362,96.18 191 | ecaresnet50t,82.348,96.138 192 | efficientnetv2_rw_t,82.344,96.196 193 | resnetv2_101x1_bitm,82.336,96.516 194 | crossvit_15_dagger_240,82.33,95.958 195 | mixer_b16_224_miil,82.308,95.718 196 | coat_lite_small,82.304,95.848 197 | resnetrs101,82.288,96.01 198 | convit_base,82.286,95.938 199 | tresnet_l_448,82.268,95.982 200 | efficientnet_b3,82.24,96.114 201 | convnext_tiny_hnf,82.222,95.866 202 | crossvit_base_240,82.216,95.834 203 | cait_xxs36_384,82.194,96.144 204 | swsl_resnext50_32x4d,82.176,96.232 205 | ecaresnet101d,82.172,96.046 206 | swin_s3_tiny_224,82.126,95.95 207 | poolformer_m36,82.112,95.69 208 | visformer_small,82.106,95.874 209 | convnext_tiny,82.064,95.852 210 | halo2botnet50ts_256,82.06,95.642 211 | tresnet_xl,82.058,95.936 212 | fbnetv3_g,82.046,96.064 213 | resnetv2_101,82.042,95.864 214 | deit_base_patch16_224,81.996,95.732 215 | pit_s_distilled_224,81.994,95.798 216 | resnetv2_50d_evos,81.98,95.91 217 | xcit_small_12_p16_224,81.976,95.818 218 | tf_efficientnetv2_b3,81.966,95.78 219 | xception41p,81.96,95.794 220 | resnet101,81.932,95.77 221 | xcit_tiny_24_p8_224,81.892,95.976 222 | vit_small_r26_s32_224,81.856,96.02 223 | ssl_resnext101_32x16d,81.854,96.096 224 | tf_efficientnet_b3_ap,81.826,95.622 225 | resnetv2_50d_gn,81.818,95.922 226 | tresnet_m_448,81.704,95.572 227 | twins_svt_small,81.68,95.67 228 | halonet50ts,81.66,95.612 229 | tf_efficientnet_b3,81.636,95.718 230 | rexnet_200,81.63,95.668 231 | ssl_resnext101_32x8d,81.608,96.042 232 | lamhalobotnet50ts_256,81.546,95.502 233 | crossvit_15_240,81.542,95.69 234 | tf_efficientnet_lite4,81.536,95.668 235 | tnt_s_patch16_224,81.52,95.744 236 | levit_256,81.506,95.492 237 | vit_large_patch32_384,81.506,96.094 238 | tresnet_l,81.492,95.624 239 | wide_resnet50_2,81.452,95.53 240 | convit_small,81.42,95.74 241 | jx_nest_tiny,81.42,95.618 242 | poolformer_s36,81.418,95.45 243 | vit_small_patch16_224,81.396,96.132 244 | tf_efficientnet_b1_ns,81.388,95.736 245 | 
swin_tiny_patch4_window7_224,81.374,95.544 246 | convmixer_1536_20,81.366,95.614 247 | gernet_l,81.344,95.532 248 | efficientnet_el,81.31,95.53 249 | legacy_senet154,81.31,95.496 250 | coat_mini,81.266,95.394 251 | seresnext50_32x4d,81.258,95.63 252 | gluon_senet154,81.232,95.348 253 | deit_small_distilled_patch16_224,81.208,95.374 254 | xcit_tiny_12_p8_224_dist,81.208,95.6 255 | swsl_resnet50,81.174,95.978 256 | sebotnet33ts_256,81.156,95.17 257 | resmlp_36_distilled_224,81.154,95.488 258 | lambda_resnet50ts,81.146,95.102 259 | resnest50d_4s2x40d,81.11,95.564 260 | resnext50_32x4d,81.108,95.326 261 | pit_s_224,81.1,95.33 262 | twins_pcpvt_small,81.09,95.64 263 | haloregnetz_b,81.05,95.196 264 | resmlp_big_24_224,81.032,95.02 265 | crossvit_small_240,81.018,95.456 266 | gluon_resnet152_v1s,81.016,95.412 267 | resnest50d_1s4x24d,80.99,95.324 268 | resnest50d,80.982,95.38 269 | cait_xxs24_384,80.966,95.646 270 | sehalonet33ts,80.964,95.272 271 | xcit_tiny_12_p16_384_dist,80.944,95.412 272 | gcresnet50t,80.942,95.454 273 | ssl_resnext101_32x4d,80.924,95.726 274 | gluon_seresnext101_32x4d,80.906,95.294 275 | gluon_seresnext101_64x4d,80.878,95.298 276 | efficientnet_b3_pruned,80.858,95.244 277 | ecaresnet101d_pruned,80.814,95.63 278 | regnety_320,80.81,95.244 279 | resmlp_24_distilled_224,80.764,95.224 280 | gernet_m,80.744,95.184 281 | vit_base_patch32_224,80.722,95.566 282 | regnetz_b16,80.714,95.478 283 | nf_resnet50,80.654,95.334 284 | efficientnet_b2,80.614,95.316 285 | gluon_resnext101_64x4d,80.604,94.992 286 | ecaresnet50d,80.6,95.32 287 | gcresnext50ts,80.578,95.17 288 | resnet50d,80.522,95.162 289 | repvgg_b3,80.496,95.264 290 | vit_small_patch32_384,80.484,95.6 291 | gluon_resnet152_v1d,80.476,95.204 292 | mixnet_xl,80.474,94.934 293 | inception_resnet_v2,80.46,95.308 294 | ecaresnetlight,80.452,95.25 295 | xcit_tiny_24_p16_224_dist,80.446,95.216 296 | resnetv2_50,80.42,95.074 297 | gluon_resnet101_v1d,80.42,95.016 298 | resnet50,80.376,94.616 299 | regnety_120,80.376,95.126 300 | seresnet33ts,80.35,95.106 301 | gluon_resnext101_32x4d,80.344,94.926 302 | resnetv2_50x1_bitm,80.342,95.68 303 | ssl_resnext50_32x4d,80.316,95.41 304 | poolformer_s24,80.314,95.046 305 | rexnet_150,80.31,95.166 306 | tf_efficientnet_b2_ap,80.302,95.028 307 | efficientnet_el_pruned,80.302,95.216 308 | gluon_resnet101_v1s,80.298,95.164 309 | seresnet50,80.264,95.072 310 | tf_efficientnet_el,80.25,95.122 311 | vit_base_patch16_224_sam,80.242,94.754 312 | regnetx_320,80.24,95.022 313 | legacy_seresnext101_32x4d,80.224,95.01 314 | repvgg_b3g4,80.212,95.106 315 | tf_efficientnetv2_b2,80.206,95.042 316 | dpn107,80.172,94.906 317 | convmixer_768_32,80.164,95.072 318 | inception_v4,80.162,94.966 319 | skresnext50_32x4d,80.152,94.644 320 | eca_resnet33ts,80.08,94.97 321 | gcresnet33ts,80.08,95.0 322 | tf_efficientnet_b2,80.08,94.908 323 | cspdarknet53,80.062,95.084 324 | resnet50_gn,80.054,94.948 325 | cspresnext50,80.05,94.946 326 | dpn92,80.016,94.824 327 | ens_adv_inception_resnet_v2,79.978,94.938 328 | efficientnet_b2_pruned,79.916,94.854 329 | gluon_seresnext50_32x4d,79.914,94.832 330 | gluon_resnet152_v1c,79.908,94.848 331 | resnetrs50,79.886,94.966 332 | xception71,79.876,94.922 333 | deit_small_patch16_224,79.86,95.046 334 | regnetx_160,79.85,94.83 335 | ecaresnet26t,79.848,95.086 336 | levit_192,79.832,94.786 337 | dpn131,79.824,94.708 338 | tf_efficientnet_lite3,79.82,94.912 339 | resmlp_36_224,79.768,94.886 340 | cait_xxs36_224,79.748,94.866 341 | gluon_xception65,79.716,94.86 342 | 
ecaresnet50d_pruned,79.708,94.88 343 | xcit_tiny_12_p8_224,79.69,95.054 344 | fbnetv3_d,79.682,94.948 345 | gluon_resnet152_v1b,79.68,94.738 346 | resnext50d_32x4d,79.67,94.864 347 | dpn98,79.646,94.596 348 | gmlp_s16_224,79.64,94.624 349 | regnetx_120,79.592,94.734 350 | cspresnet50,79.582,94.704 351 | gluon_resnet101_v1c,79.534,94.58 352 | rexnet_130,79.5,94.684 353 | eca_halonext26ts,79.49,94.598 354 | hrnet_w64,79.472,94.652 355 | tf_efficientnetv2_b1,79.464,94.724 356 | dla102x2,79.446,94.632 357 | xcit_tiny_24_p16_224,79.444,94.884 358 | resmlp_24_224,79.382,94.546 359 | repvgg_b2g4,79.37,94.688 360 | gluon_resnext50_32x4d,79.364,94.426 361 | resnext101_32x8d,79.316,94.518 362 | ese_vovnet39b,79.312,94.714 363 | pit_xs_distilled_224,79.306,94.364 364 | tf_efficientnet_cc_b1_8e,79.306,94.372 365 | resnetblur50,79.304,94.634 366 | gluon_resnet101_v1b,79.302,94.52 367 | hrnet_w48,79.302,94.512 368 | nf_regnet_b1,79.288,94.748 369 | tf_efficientnet_b1_ap,79.28,94.304 370 | eca_botnext26ts_256,79.274,94.616 371 | botnet26t_256,79.252,94.528 372 | efficientnet_em,79.25,94.794 373 | ssl_resnet50,79.226,94.836 374 | dpn68b,79.22,94.418 375 | resnet33ts,79.21,94.572 376 | regnetx_080,79.202,94.554 377 | res2net101_26w_4s,79.196,94.436 378 | fbnetv3_b,79.148,94.746 379 | halonet26t,79.116,94.31 380 | lambda_resnet26t,79.098,94.588 381 | coat_lite_mini,79.096,94.604 382 | gluon_resnet50_v1d,79.076,94.472 383 | legacy_seresnext50_32x4d,79.068,94.434 384 | regnetx_064,79.066,94.458 385 | xception,79.05,94.392 386 | resnet32ts,79.012,94.358 387 | res2net50_26w_8s,78.98,94.294 388 | mixnet_l,78.976,94.178 389 | lambda_resnet26rpt_256,78.968,94.428 390 | hrnet_w40,78.916,94.474 391 | hrnet_w44,78.9,94.374 392 | wide_resnet101_2,78.854,94.29 393 | tf_efficientnet_b1,78.828,94.198 394 | gluon_inception_v3,78.804,94.37 395 | efficientnet_b1,78.796,94.342 396 | repvgg_b2,78.792,94.418 397 | tf_mixnet_l,78.774,93.996 398 | gluon_resnet50_v1s,78.712,94.24 399 | dla169,78.692,94.34 400 | tf_efficientnet_b0_ns,78.658,94.378 401 | legacy_seresnet152,78.652,94.37 402 | xcit_tiny_12_p16_224_dist,78.576,94.196 403 | res2net50_26w_6s,78.566,94.134 404 | dla102x,78.516,94.226 405 | xception41,78.51,94.278 406 | levit_128,78.492,94.006 407 | regnetx_040,78.482,94.244 408 | resnest26d,78.476,94.292 409 | dla60_res2net,78.462,94.206 410 | hrnet_w32,78.448,94.194 411 | dla60_res2next,78.44,94.15 412 | vit_tiny_patch16_384,78.434,94.542 413 | coat_tiny,78.43,94.04 414 | selecsls60b,78.412,94.174 415 | legacy_seresnet101,78.388,94.264 416 | cait_xxs24_224,78.384,94.31 417 | repvgg_b1,78.368,94.096 418 | tf_efficientnetv2_b0,78.36,94.02 419 | tv_resnet152,78.316,94.034 420 | mobilevit_s,78.312,94.152 421 | res2next50,78.252,93.886 422 | bat_resnext26ts,78.25,94.098 423 | dla60x,78.244,94.018 424 | efficientnet_b1_pruned,78.24,93.834 425 | hrnet_w30,78.198,94.224 426 | pit_xs_224,78.186,94.164 427 | regnetx_032,78.172,94.088 428 | res2net50_14w_8s,78.144,93.848 429 | tf_efficientnet_em,78.132,94.044 430 | hardcorenas_f,78.098,93.802 431 | efficientnet_es,78.056,93.936 432 | gmixer_24_224,78.036,93.67 433 | dla102,78.03,93.948 434 | gluon_resnet50_v1c,78.012,93.99 435 | seresnext26t_32x4d,77.976,93.746 436 | selecsls60,77.976,93.83 437 | res2net50_26w_4s,77.96,93.852 438 | resmlp_12_distilled_224,77.942,93.558 439 | mobilenetv3_large_100_miil,77.916,92.906 440 | tf_efficientnet_cc_b0_8e,77.906,93.656 441 | resnet26t,77.862,93.844 442 | regnety_016,77.86,93.722 443 | rexnet_100,77.858,93.87 444 | 
tf_inception_v3,77.856,93.64 445 | seresnext26ts,77.852,93.79 446 | gcresnext26ts,77.82,93.83 447 | xcit_nano_12_p8_384_dist,77.818,94.044 448 | hardcorenas_e,77.794,93.696 449 | efficientnet_b0,77.69,93.53 450 | tinynet_a,77.65,93.536 451 | legacy_seresnet50,77.63,93.75 452 | tv_resnext50_32x4d,77.616,93.7 453 | seresnext26d_32x4d,77.604,93.608 454 | repvgg_b1g4,77.586,93.83 455 | adv_inception_v3,77.582,93.736 456 | gluon_resnet50_v1b,77.58,93.722 457 | res2net50_48w_2s,77.52,93.552 458 | coat_lite_tiny,77.514,93.916 459 | tf_efficientnet_lite2,77.468,93.756 460 | eca_resnext26ts,77.454,93.566 461 | inception_v3,77.438,93.474 462 | hardcorenas_d,77.43,93.482 463 | tv_resnet101,77.378,93.542 464 | densenet161,77.354,93.636 465 | tf_efficientnet_cc_b0_4e,77.302,93.334 466 | mobilenetv2_120d,77.294,93.496 467 | densenet201,77.29,93.478 468 | mixnet_m,77.264,93.424 469 | poolformer_s12,77.236,93.504 470 | selecsls42b,77.174,93.392 471 | xcit_tiny_12_p16_224,77.126,93.716 472 | resnet34d,77.114,93.38 473 | legacy_seresnext26_32x4d,77.106,93.318 474 | tf_efficientnet_b0_ap,77.094,93.256 475 | hardcorenas_c,77.05,93.158 476 | dla60,77.03,93.32 477 | crossvit_9_dagger_240,76.982,93.61 478 | regnetx_016,76.95,93.422 479 | convmixer_1024_20_ks9_p14,76.942,93.356 480 | tf_mixnet_m,76.942,93.154 481 | gernet_s,76.908,93.132 482 | skresnet34,76.904,93.32 483 | tf_efficientnet_b0,76.844,93.228 484 | ese_vovnet19b_dw,76.802,93.272 485 | resnext26ts,76.78,93.128 486 | hrnet_w18,76.754,93.44 487 | resnet26d,76.704,93.15 488 | resmlp_12_224,76.656,93.18 489 | tf_efficientnet_lite1,76.64,93.22 490 | mixer_b16_224,76.612,92.228 491 | tf_efficientnet_es,76.596,93.204 492 | densenetblur121d,76.584,93.192 493 | hardcorenas_b,76.536,92.754 494 | mobilenetv2_140,76.522,92.996 495 | levit_128s,76.52,92.872 496 | repvgg_a2,76.458,93.01 497 | xcit_nano_12_p8_224_dist,76.32,93.088 498 | regnety_008,76.31,93.07 499 | dpn68,76.306,92.974 500 | tv_resnet50,76.134,92.868 501 | mixnet_s,75.992,92.798 502 | vit_small_patch32_224,75.986,93.27 503 | vit_tiny_r_s16_p8_384,75.954,93.264 504 | hardcorenas_a,75.92,92.52 505 | densenet169,75.898,93.03 506 | mobilenetv3_large_100,75.766,92.544 507 | tf_mixnet_s,75.65,92.628 508 | mobilenetv3_rw,75.632,92.708 509 | densenet121,75.584,92.652 510 | tf_mobilenetv3_large_100,75.518,92.604 511 | resnest14d,75.504,92.52 512 | efficientnet_lite0,75.476,92.512 513 | vit_tiny_patch16_224,75.462,92.844 514 | xcit_nano_12_p16_384_dist,75.456,92.69 515 | semnasnet_100,75.45,92.6 516 | resnet26,75.3,92.578 517 | regnety_006,75.25,92.534 518 | repvgg_b0,75.16,92.418 519 | fbnetc_100,75.13,92.386 520 | hrnet_w18_small_v2,75.118,92.416 521 | resnet34,75.114,92.284 522 | mobilenetv2_110d,75.038,92.184 523 | regnetx_008,75.034,92.34 524 | efficientnet_es_pruned,74.996,92.44 525 | tinynet_b,74.976,92.184 526 | tf_efficientnet_lite0,74.832,92.174 527 | legacy_seresnet34,74.808,92.126 528 | tv_densenet121,74.744,92.152 529 | mnasnet_100,74.658,92.112 530 | mobilevit_xs,74.644,92.356 531 | dla34,74.62,92.072 532 | gluon_resnet34_v1b,74.588,91.988 533 | pit_ti_distilled_224,74.532,92.096 534 | deit_tiny_distilled_patch16_224,74.512,91.886 535 | vgg19_bn,74.214,91.848 536 | spnasnet_100,74.084,91.82 537 | regnety_004,74.024,91.754 538 | ghostnet_100,73.974,91.46 539 | crossvit_9_240,73.96,91.968 540 | xcit_nano_12_p8_224,73.91,92.168 541 | regnetx_006,73.86,91.672 542 | vit_base_patch32_224_sam,73.694,91.01 543 | tf_mobilenetv3_large_075,73.436,91.344 544 | vgg16_bn,73.35,91.504 545 | 
crossvit_tiny_240,73.332,91.914 546 | tv_resnet34,73.306,91.424 547 | swsl_resnet18,73.276,91.736 548 | convit_tiny,73.114,91.714 549 | skresnet18,73.036,91.168 550 | semnasnet_075,72.972,91.136 551 | mobilenetv2_100,72.97,91.02 552 | pit_ti_224,72.912,91.406 553 | ssl_resnet18,72.608,91.424 554 | regnetx_004,72.392,90.832 555 | vgg19,72.366,90.87 556 | hrnet_w18_small,72.338,90.68 557 | xcit_nano_12_p16_224_dist,72.302,90.858 558 | resnet18d,72.25,90.688 559 | tf_mobilenetv3_large_minimal_100,72.25,90.63 560 | deit_tiny_patch16_224,72.172,91.114 561 | lcnet_100,72.104,90.376 562 | mixer_l16_224,72.054,87.662 563 | vit_tiny_r_s16_p8_224,71.792,90.822 564 | legacy_seresnet18,71.742,90.332 565 | vgg13_bn,71.594,90.376 566 | vgg16,71.59,90.382 567 | tinynet_c,71.228,89.75 568 | gluon_resnet18_v1b,70.834,89.762 569 | vgg11_bn,70.36,89.802 570 | regnety_002,70.254,89.532 571 | xcit_nano_12_p16_224,69.954,89.754 572 | vgg13,69.926,89.246 573 | resnet18,69.744,89.082 574 | vgg11,69.028,88.626 575 | mobilevit_xxs,68.92,88.944 576 | lcnet_075,68.816,88.37 577 | regnetx_002,68.756,88.556 578 | tf_mobilenetv3_small_100,67.924,87.664 579 | dla60x_c,67.892,88.426 580 | mobilenetv3_small_100,67.656,87.634 581 | tinynet_d,66.958,87.064 582 | mnasnet_small,66.206,86.508 583 | dla46x_c,65.97,86.98 584 | mobilenetv2_050,65.942,86.082 585 | tf_mobilenetv3_small_075,65.714,86.134 586 | mobilenetv3_small_075,65.242,85.438 587 | dla46_c,64.866,86.294 588 | lcnet_050,63.1,84.382 589 | tf_mobilenetv3_small_minimal_100,62.908,84.234 590 | tinynet_e,59.856,81.764 591 | mobilenetv3_small_050,57.89,80.194 592 | alexnet,56.522,79.066 593 | vgg11,69.02,88.628 594 | vgg13,69.928,89.246 595 | vgg16,71.592,90.382 596 | vgg19,72.376,90.876 597 | vgg11_bn,70.37,89.81 598 | vgg13_bn,71.586,90.374 599 | vgg16_bn,73.36,91.516 600 | vgg19_bn,74.218,91.842 601 | resnet18,69.758,89.078 602 | resnet34,73.314,91.42 603 | resnet50,76.13,92.862 604 | resnet101,77.374,93.546 605 | resnet152,78.312,94.046 606 | squeezenet_1_0,58.092,80.42 607 | squeezenet_1_1,58.178,80.624 608 | densenet121,74.434,91.972 609 | densenet169,75.6,92.806 610 | densenet201,76.896,93.37 611 | densenet161,77.138,93.56 612 | inception_v3,77.294,93.45 613 | googlenet,69.778,89.53 614 | shufflenet_v2_x0_5,60.552,81.746 615 | shufflenet_v2_x1_0,69.362,88.316 616 | mobilenet_v2,71.878,90.286 617 | mobilenet_v3_large,74.042,91.34 618 | mobilenet_v3_small,67.668,87.402 619 | resnext50_32x4d,77.618,93.698 620 | resnext101_32x8d,79.312,94.526 621 | wide_resnet50_2,78.468,94.086 622 | wide_resnet101_2,78.848,94.284 623 | mnasnet0_5,67.734,87.49 624 | mnasnet1_0,73.456,91.51 625 | efficientnet_b0,77.692,93.532 626 | efficientnet_b1,78.642,94.186 627 | efficientnet_b2,80.608,95.31 628 | efficientnet_b3,82.008,96.054 629 | efficientnet_b4,83.384,96.594 630 | efficientnet_b5,83.444,96.628 631 | efficientnet_b6,84.008,96.916 632 | efficientnet_b7,84.122,96.908 633 | regnet_x_400mf,72.834,90.95 634 | regnet_x_800mf,75.212,92.348 635 | regnet_x_1_6gf,77.04,93.44 636 | regnet_x_3_2gf,78.364,93.992 637 | regnet_x_8gf,79.344,94.686 638 | regnet_x_16gf,80.058,94.944 639 | regnet_x_32gf,80.622,95.248 640 | regnet_y_400mf,74.046,91.716 641 | regnet_y_800mf,76.42,93.136 642 | regnet_y_1_6gf,77.95,93.966 643 | regnet_y_3_2gf,78.948,94.576 644 | regnet_y_8gf,80.032,95.048 645 | regnet_y_16gf,80.424,95.24 646 | regnet_y_32gf,80.878,95.34 647 | vit_b_16,81.072,95.318 648 | vit_b_32,75.912,92.466 649 | vit_l_16,79.662,94.638 650 | vit_l_32,76.972,93.07 651 | convnext_tiny,82.52,96.146 652 
| convnext_small,83.616,96.65 653 | convnext_base,84.062,96.87 654 | convnext_large,84.414,96.976 655 | -------------------------------------------------------------------------------- /deepdive/model_metadata.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import warnings; warnings.filterwarnings(\"ignore\")" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 2, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import os, sys, json\n", 19 | "import numpy as np\n", 20 | "import pandas as pd\n", 21 | "from tqdm.auto import tqdm as tqdm" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 3, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "os.environ['CUDA_VISIBLE_DEVICES'] = ''" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": 4, 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "sys.path.append('../model_options')\n", 40 | "from feature_extraction import *\n", 41 | "from model_options import *" 42 | ] 43 | }, 44 | { 45 | "cell_type": "markdown", 46 | "metadata": {}, 47 | "source": [ 48 | "### Model Metadata Build" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": 5, 54 | "metadata": {}, 55 | "outputs": [], 56 | "source": [ 57 | "def get_model_metadata(model_option, convert_to_dataframe = True):\n", 58 | " model_name = model_option['model_name']\n", 59 | " train_type = model_option['train_type']\n", 60 | " model = eval(model_option['call'])\n", 61 | " model = prep_model_for_extraction(model)\n", 62 | " \n", 63 | " layer_metadata = get_feature_map_metadata(model)\n", 64 | " layer_count = len(layer_metadata.keys())\n", 65 | " layer_order = [model_layer for model_layer in layer_metadata]\n", 66 | " feature_counts = [layer_metadata[layer]['feature_count'] / 3\n", 67 | " for layer in layer_metadata]\n", 68 | " parameter_counts = [layer_metadata[layer]['parameter_count'] \n", 69 | " for layer in layer_metadata]\n", 70 | " feature_map_shapes = [layer_metadata[layer]['feature_map_shape'] \n", 71 | " for layer in layer_metadata]\n", 72 | " total_feature_count = int(np.array(feature_counts).sum())\n", 73 | " total_parameter_count = int(np.array(parameter_counts).sum())\n", 74 | " model_metadata = {'total_feature_count': total_feature_count,\n", 75 | " 'total_parameter_count': total_parameter_count,\n", 76 | " 'layer_count': layer_count,\n", 77 | " 'layer_metadata': layer_metadata}\n", 78 | " \n", 79 | " if not convert_to_dataframe:\n", 80 | " return(model_metadata)\n", 81 | " \n", 82 | " if convert_to_dataframe:\n", 83 | "\n", 84 | " model_metadata_dictlist = []\n", 85 | " \n", 86 | " for layer_index, layer in enumerate(layer_metadata):\n", 87 | " model_metadata_dictlist.append({'model': model_name, 'train_type': train_type,\n", 88 | " 'model_layer': layer, 'model_layer_index': layer_index + 1,\n", 89 | " 'model_layer_depth': (layer_index + 1) / layer_count,\n", 90 | " 'feature_count': layer_metadata[layer]['feature_count'],\n", 91 | " 'parameter_count': layer_metadata[layer]['parameter_count']})\n", 92 | "\n", 93 | " return(pd.DataFrame(model_metadata_dictlist))" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": null, 99 | "metadata": {}, 100 | "outputs": [], 101 | "source": [ 102 | "model_options = {**get_model_options(model_source = 'torchvision'),\n", 103 | " 
**get_model_options(model_source = 'timm'),\n", 104 | " **get_model_options(model_source = 'detectron'),\n", 105 | " **get_model_options(train_type = 'taskonomy'),\n", 106 | " **get_model_options(model_source = 'clip'),\n", 107 | " **get_model_options(model_source = 'slip'),\n", 108 | " **get_model_options(model_source = 'dino'),\n", 109 | " **get_model_options(model_source = 'midas'),\n", 110 | " **get_model_options(model_source = 'yolo'),\n", 111 | " **get_model_options(model_source = 'vissl'),\n", 112 | " **get_model_options(model_source = 'bit_expert')}\n", 113 | "\n", 114 | "model_metadata_dflist = []\n", 115 | "\n", 116 | "def process(model_option):\n", 117 | " incoming_metadata = get_model_metadata(model_options[model_option])\n", 118 | " model_metadata_dflist.append(incoming_metadata)\n", 119 | " \n", 120 | "problematic_model_options = []\n", 121 | "\n", 122 | "def remark(model_option):\n", 123 | " problematic_model_options.append(model_option)\n", 124 | " \n", 125 | "model_option_iterator = tqdm(model_options)\n", 126 | "for model_option in model_option_iterator:\n", 127 | " model_option_iterator.set_description(model_option)\n", 128 | " try: process(model_option)\n", 129 | " except: remark(model_option)" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": 7, 135 | "metadata": {}, 136 | "outputs": [], 137 | "source": [ 138 | "pd.concat(model_metadata_dflist).to_csv('model_metadata.csv', index = None)\n", 139 | "\n", 140 | "if 'model_metadata_json' in globals():\n", 141 | " with open('model_metadata.json', 'w') as file:\n", 142 | " json.dump(model_metadata_json, file)" 143 | ] 144 | }, 145 | { 146 | "cell_type": "code", 147 | "execution_count": 8, 148 | "metadata": {}, 149 | "outputs": [], 150 | "source": [ 151 | "model_accuracy_files = {'timm': 'timm_imagenet_accuracies.csv',\n", 152 | " 'torchvision': 'torchvision_imagenet_accuracies.csv'}\n", 153 | "\n", 154 | "model_accuracy_dflist = []\n", 155 | "for source in model_accuracy_files:\n", 156 | " model_accuracy_dflist.append(pd.read_csv('model_statistics/' + model_accuracy_files[source]))\n", 157 | " \n", 158 | "model_accuracy = pd.concat(model_accuracy_dflist)" 159 | ] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "execution_count": 9, 164 | "metadata": {}, 165 | "outputs": [], 166 | "source": [ 167 | "model_accuracy.to_csv('model_accuracy.csv', index = None)" 168 | ] 169 | } 170 | ], 171 | "metadata": { 172 | "kernelspec": { 173 | "display_name": "Synthese", 174 | "language": "python", 175 | "name": "synthese" 176 | }, 177 | "language_info": { 178 | "codemirror_mode": { 179 | "name": "ipython", 180 | "version": 3 181 | }, 182 | "file_extension": ".py", 183 | "mimetype": "text/x-python", 184 | "name": "python", 185 | "nbconvert_exporter": "python", 186 | "pygments_lexer": "ipython3", 187 | "version": "3.7.6" 188 | } 189 | }, 190 | "nbformat": 4, 191 | "nbformat_minor": 4 192 | } 193 | -------------------------------------------------------------------------------- /deepdive/model_metadata.py: -------------------------------------------------------------------------------- 1 | import os, sys, json 2 | import numpy as np 3 | import pandas as pd 4 | from tqdm.auto import tqdm as tqdm 5 | 6 | from feature_extraction import * 7 | from model_options import * 8 | 9 | def count_parameters(model): 10 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 11 | 12 | def get_model_metadata(model_option, convert_to_dataframe = True): 13 | model_name = model_option['model_name'] 14 | train_type = 
model_option['train_type'] 15 | model = eval(model_option['call']) 16 | model = prep_model_for_extraction(model) 17 | 18 | layer_metadata = get_feature_map_metadata(model) 19 | layer_count = len(layer_metadata.keys()) 20 | layer_order = [model_layer for model_layer in layer_metadata] 21 | feature_counts = [layer_metadata[layer]['feature_count'] 22 | for layer in layer_metadata] 23 | parameter_counts = [layer_metadata[layer]['parameter_count'] 24 | for layer in layer_metadata] 25 | feature_map_shapes = [layer_metadata[layer]['feature_map_shape'] 26 | for layer in layer_metadata] 27 | total_feature_count = int(np.array(feature_counts).sum()) 28 | total_parameter_count = int(np.array(parameter_counts).sum()) 29 | model_metadata = {'total_feature_count': total_feature_count, 30 | 'total_parameter_count': total_parameter_count, 31 | 'layer_count': layer_count, 32 | 'layer_metadata': layer_metadata} 33 | 34 | if not convert_to_dataframe: 35 | return(model_metadata) 36 | 37 | if convert_to_dataframe: 38 | 39 | model_metadata_dictlist = [] 40 | 41 | for layer_index, layer in enumerate(layer_metadata): 42 | model_metadata_dictlist.append({'model': model_name, 'train_type': train_type, 43 | 'model_layer': layer, 'model_layer_index': layer_index + 1, 44 | 'model_layer_depth': (layer_index + 1) / layer_count, 45 | 'feature_count': layer_metadata[layer]['feature_count'] / 3, 46 | 'parameter_count': layer_metadata[layer]['parameter_count']}) 47 | 48 | return(pd.DataFrame(model_metadata_dictlist)) 49 | 50 | 51 | 52 | if __name__ == "__main__": 53 | 54 | model_options = {**get_model_options(train_type = 'classification', model_source = 'torchvision'), 55 | **get_model_options(train_type = 'classification', model_source = 'timm'), 56 | **get_model_options(train_type = 'random'), 57 | **get_model_options(model_source = 'detectron'), 58 | **get_model_options(train_type = 'taskonomy'), 59 | **get_model_options(model_source = 'clip'), 60 | **get_model_options(model_source = 'slip'), 61 | **get_model_options(model_source = 'dino'), 62 | **get_model_options(model_source = 'midas'), 63 | **get_model_options(model_source = 'yolo'), 64 | **get_model_options(model_source = 'vissl'), 65 | **get_model_options(model_source = 'bit_expert')} 66 | 67 | model_metadata_dflist = [] 68 | 69 | def process(model_option): 70 | incoming_metadata = get_model_metadata(model_options[model_option]) 71 | model_metadata_dflist.append(incoming_metadata) 72 | 73 | problematic_model_options = [] 74 | 75 | def remark(model_option): 76 | problematic_model_options.append(model_option) 77 | 78 | model_option_iterator = tqdm(model_options) 79 | for model_option in model_option_iterator: 80 | model_option_iterator.set_description(model_option) 81 | try: process(model_option) 82 | except Exception: remark(model_option) 83 | 84 | print(problematic_model_options) 85 | 86 | pd.concat(model_metadata_dflist).to_csv('model_metadata.csv', index = None) 87 | 88 | 89 | 90 | 91 | -------------------------------------------------------------------------------- /deepdive/model_options.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import os, sys, torch 4 | import importlib 5 | 6 | filepath = os.path.dirname(os.path.abspath(__file__)) 7 | model_typology = pd.read_csv(filepath + '/model_typology.csv') 8 | model_typology['model_name'] = model_typology['model'] 9 | 10 | def subset_typology(model_source): 11 | return model_typology[model_typology['model_source'] == 
model_source].copy() 12 | 13 | _CACHED_TRANSFORMS = {} # a dictionary for storing transforms when they're included with the model 14 | 15 | # Torchvision Options --------------------------------------------------------------------------- 16 | 17 | from torch.hub import load_state_dict_from_url 18 | 19 | def get_torchvision_model(model_path, pretrained = True): 20 | from torchvision import models 21 | return eval(f'{model_path}(pretrained = {pretrained})') 22 | 23 | def define_torchvision_options(): 24 | torchvision_options = {} 25 | 26 | model_types = ['classification','segmentation', 'detection', 'video'] 27 | torchvision_dirs = dict(zip(model_types, ['.','.segmentation.', '.detection.', '.video.'])) 28 | 29 | def get_torchvision_directory(model_name): 30 | from torchvision import models 31 | for torchvision_dir in torchvision_dirs.values(): 32 | if model_name in eval('models{}__dict__'.format(torchvision_dir)): 33 | return torchvision_dir 34 | 35 | torchvision_typology = model_typology[model_typology['model_source'] == 'torchvision'].copy() 36 | training_calls = {'random': False, 'pretrained': True} 37 | for index, row in torchvision_typology.iterrows(): 38 | model_name = row['model_name'] 39 | train_type = row['train_type'] 40 | train_data = row['train_data'] 41 | if train_type == 'random': 42 | train_data = 'None' 43 | model_source = 'torchvision' 44 | torchvision_dir = get_torchvision_directory(model_name) 45 | model_string = '_'.join([model_name, train_type]) 46 | training = 'random' if train_type == 'random' else 'pretrained' 47 | model_path = 'models' + torchvision_dir + model_name 48 | model_call = "get_torchvision_model('{}', {})".format(model_path, training_calls[training]) 49 | torchvision_options[model_string] = {'model_name': model_name, 'train_type': train_type, 50 | 'train_data': train_data, 'model_source': model_source, 'call': model_call} 51 | 52 | return torchvision_options 53 | 54 | import torchvision.transforms as transforms 55 | 56 | def get_torchvision_transforms(train_type, input_type = 'PIL'): 57 | imagenet_stats = {'mean': [0.485, 0.456, 0.406], 'std': [0.229, 0.224, 0.225]} 58 | 59 | base_transforms = [transforms.Resize((224,224)), transforms.ToTensor()] 60 | 61 | if train_type == 'random': specific_transforms = base_transforms 62 | 63 | if train_type == 'classification' or train_type == 'imagenet': 64 | specific_transforms = base_transforms + [transforms.Normalize(**imagenet_stats)] 65 | 66 | if input_type == 'PIL': 67 | recommended_transforms = specific_transforms 68 | if input_type == 'numpy': 69 | recommended_transforms = [transforms.ToPILImage()] + specific_transforms 70 | 71 | return transforms.Compose(recommended_transforms) 72 | 73 | # Timm Options --------------------------------------------------------------------------- 74 | 75 | def get_timm_model(model_name, pretrained = True): 76 | from timm import create_model 77 | return create_model(model_name, pretrained) 78 | 79 | def define_timm_options(): 80 | timm_options = {} 81 | 82 | timm_typology = model_typology[model_typology['model_source'] == 'timm'].copy() 83 | for index, row in timm_typology.iterrows(): 84 | model_name = row['model_name'] 85 | train_type = row['train_type'] 86 | train_data = row['train_data'] 87 | if train_type == 'random': 88 | train_data = 'None' 89 | model_source = 'timm' 90 | model_string = '_'.join([model_name, train_type]) 91 | train_bool = False if train_type == 'random' else True 92 | model_call = "get_timm_model('{}', pretrained = {})".format(model_name, train_bool) 93 | 
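# every define_*_options() builder in this file emits records of this same shape,
# keyed by a '{model_name}_{train_type}' string -- e.g., a hypothetical timm entry:
# {'model_name': 'resnet50', 'train_type': 'classification', 'train_data': 'imagenet',
#  'model_source': 'timm', 'call': "get_timm_model('resnet50', pretrained = True)"}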
timm_options[model_string] = ({'model_name': model_name, 'train_type': train_type, 94 | 'train_data': train_data, 'model_source': model_source, 'call': model_call}) 95 | 96 | return timm_options 97 | 98 | def modify_timm_transform(timm_transform): 99 | 100 | transform_list = timm_transform.transforms 101 | 102 | crop_index, crop_size = next((index, transform.size) for index, transform 103 | in enumerate(transform_list) if 'CenterCrop' in str(transform)) 104 | resize_index, resize_size = next((index, transform.size) for index, transform 105 | in enumerate(transform_list) if 'Resize' in str(transform)) 106 | 107 | transform_list[resize_index].size = crop_size 108 | transform_list.pop(crop_index) 109 | return transforms.Compose(transform_list) 110 | 111 | def get_timm_transforms(model_name, input_type = 'PIL'): 112 | from timm.data.transforms_factory import create_transform 113 | from timm.data import resolve_data_config 114 | 115 | config = resolve_data_config({}, model = model_name) 116 | timm_transforms = create_transform(**config) 117 | timm_transform = modify_timm_transform(timm_transforms) 118 | 119 | if input_type == 'PIL': 120 | recommended_transforms = timm_transform.transforms 121 | if input_type == 'numpy': 122 | recommended_transforms = [transforms.ToPILImage()] + timm_transform.transforms 123 | 124 | return transforms.Compose(recommended_transforms) 125 | 126 | 127 | # Taskonomy Options --------------------------------------------------------------------------- 128 | 129 | def get_taskonomy_encoder(model_name, pretrained = True, verbose = False): 130 | from visualpriors.taskonomy_network import TASKONOMY_PRETRAINED_URLS 131 | from visualpriors import taskonomy_network 132 | 133 | weights_url = TASKONOMY_PRETRAINED_URLS[model_name + '_encoder'] 134 | weights = torch.utils.model_zoo.load_url(weights_url) 135 | if verbose: print('{} weights loaded successfully.'.format(model_name)) 136 | model = taskonomy_network.TaskonomyEncoder() 137 | model.load_state_dict(weights['state_dict']) 138 | 139 | return model 140 | 141 | def random_taskonomy_encoder(): 142 | from visualpriors import taskonomy_network 143 | return taskonomy_network.TaskonomyEncoder() 144 | 145 | def define_taskonomy_options(): 146 | taskonomy_options = {} 147 | 148 | task_typology = model_typology[model_typology['model_source'] == 'taskonomy'].copy() 149 | for index, row in task_typology.iterrows(): 150 | model_name = row['model_name'] 151 | train_type = row['train_type'] 152 | train_data = row['train_data'] 153 | model_source = 'taskonomy' 154 | model_string = model_name + '_' + train_type 155 | model_call = "get_taskonomy_encoder('{}')".format(model_name) 156 | taskonomy_options[model_string] = ({'model_name': model_name, 'train_type': train_type, 157 | 'train_data': train_data, 'model_source': model_source, 'call': model_call}) 158 | 159 | taskonomy_options['random_weights_taskonomy'] = {'model_name': 'random_weights', 'train_type': 'taskonomy', 160 | 'train_data': 'None', 'model_source': 'taskonomy', 161 | 'call': 'random_taskonomy_encoder()'} 162 | 163 | return taskonomy_options 164 | 165 | import torchvision.transforms.functional as functional 166 | 167 | def taskonomy_transform(image): 168 | return (functional.to_tensor(functional.resize(image, (256,256))) * 2 - 1)#.unsqueeze_(0) 169 | 170 | def get_taskonomy_transforms(input_type = 'PIL'): 171 | recommended_transforms = taskonomy_transform 172 | if input_type == 'PIL': 173 | return recommended_transforms 174 | if input_type == 'numpy': 175 | def 
functional_from_numpy(image): 176 | image = functional.to_pil_image(image) 177 | return recommended_transforms(image) 178 | return functional_from_numpy 179 | 180 | # CLIP Options --------------------------------------------------------------------------- 181 | 182 | def get_clip_model(model_name): 183 | import clip; model, _ = clip.load(model_name, device='cpu') 184 | return model.visual 185 | 186 | def define_clip_options(): 187 | clip_options = {} 188 | 189 | clip_typology = model_typology[model_typology['model_source'] == 'clip'].copy() 190 | for index, row in clip_typology.iterrows(): 191 | model_name = row['model_name'] 192 | train_type = row['train_type'] 193 | train_data = row['train_data'] 194 | model_source = 'clip' 195 | model_string = '_'.join([model_name, train_type]) 196 | model_call = "get_clip_model('{}')".format(model_name) 197 | clip_options[model_string] = ({'model_name': model_name, 'train_type': train_type, 198 | 'train_data': train_data, 'model_source': model_source, 'call': model_call}) 199 | 200 | return clip_options 201 | 202 | def get_clip_transforms(model_name, input_type = 'PIL'): 203 | import clip; _, preprocess = clip.load(model_name, device = 'cpu') 204 | if input_type == 'PIL': 205 | recommended_transforms = preprocess.transforms 206 | if input_type == 'numpy': 207 | recommended_transforms = [transforms.ToPILImage()] + preprocess.transforms 208 | recommended_transforms = transforms.Compose(recommended_transforms) 209 | 210 | using_half_tensor_models = False 211 | if using_half_tensor_models: 212 | if 'ViT' in model_name: 213 | def transform_plus_retype(image_input): 214 | return recommended_transforms(image_input).type(torch.HalfTensor) 215 | return transform_plus_retype 216 | if 'ViT' not in model_name: 217 | return recommended_transforms 218 | 219 | if not using_half_tensor_models: 220 | return recommended_transforms 221 | 222 | # VISSL Options --------------------------------------------------------------------------- 223 | 224 | def get_vissl_model(model_name): 225 | vissl_data = (model_typology[model_typology['model_source'] == 'vissl'] 226 | .set_index('model_name').to_dict('index')) 227 | 228 | weights = load_state_dict_from_url(vissl_data[model_name]['weights_url'], map_location = torch.device('cpu')) 229 | 230 | def replace_module_prefix(state_dict, prefix, replace_with = ''): 231 | return {(key.replace(prefix, replace_with, 1) if key.startswith(prefix) else key): val 232 | for (key, val) in state_dict.items()} 233 | 234 | def convert_model_weights(model): 235 | if "classy_state_dict" in model.keys(): 236 | model_trunk = model["classy_state_dict"]["base_model"]["model"]["trunk"] 237 | elif "model_state_dict" in model.keys(): 238 | model_trunk = model["model_state_dict"] 239 | else: 240 | model_trunk = model 241 | return replace_module_prefix(model_trunk, "_feature_blocks.") 242 | 243 | converted_weights = convert_model_weights(weights) 244 | excess_weights = ['fc','projection', 'prototypes'] 245 | converted_weights = {key:value for (key,value) in converted_weights.items() 246 | if not any([x in key for x in excess_weights])} 247 | 248 | if 'module' in next(iter(converted_weights)): 249 | converted_weights = {key.replace('module.',''):value for (key,value) in converted_weights.items() 250 | if 'fc' not in key} 251 | 252 | from torchvision.models import resnet50 253 | import torch.nn as nn 254 | 255 | class Identity(nn.Module): 256 | def __init__(self): 257 | super(Identity, self).__init__() 258 | 259 | def forward(self, x): 260 | return x 261 | 262 | 
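    # the converted trunk weights above now follow torchvision's resnet50 naming
    # scheme, so they can be loaded into a stock resnet50 whose classification
    # head is swapped for the pass-through Identity module defined just above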
model = resnet50() 263 | model.fc = Identity() 264 | 265 | model.load_state_dict(converted_weights) 266 | 267 | return model 268 | 269 | 270 | def define_vissl_options(): 271 | vissl_options = {} 272 | 273 | vissl_typology = model_typology[model_typology['model_source'] == 'vissl'].copy() 274 | for index, row in vissl_typology.iterrows(): 275 | model_name = row['model_name'] 276 | train_type = row['train_type'] 277 | train_data = row['train_data'] 278 | model_source = 'vissl' 279 | model_string = '_'.join([model_name, train_type]) 280 | model_call = "get_vissl_model('{}')".format(model_name) 281 | vissl_options[model_string] = ({'model_name': model_name, 'train_type': train_type, 282 | 'train_data': train_data, 'model_source': model_source, 'call': model_call}) 283 | 284 | return vissl_options 285 | 286 | 287 | def get_vissl_transforms(input_type = 'PIL'): 288 | imagenet_stats = {'mean': [0.485, 0.456, 0.406], 'std': [0.229, 0.224, 0.225]} 289 | 290 | base_transforms = [transforms.Resize((224,224)), transforms.ToTensor()] 291 | specific_transforms = base_transforms + [transforms.Normalize(**imagenet_stats)] 292 | 293 | if input_type == 'PIL': 294 | recommended_transforms = specific_transforms 295 | if input_type == 'numpy': 296 | recommended_transforms = [transforms.ToPILImage()] + specific_transforms 297 | 298 | return transforms.Compose(recommended_transforms) 299 | 300 | # Dino Options --------------------------------------------------------------------------- 301 | 302 | def define_dino_options(): 303 | dino_options = {} 304 | 305 | dino_typology = model_typology[model_typology['model_source'] == 'dino'].copy() 306 | for index, row in dino_typology.iterrows(): 307 | model_name = row['model_name'] 308 | train_type = row['train_type'] 309 | train_data = row['train_data'] 310 | model_source = 'dino' 311 | model_string = '_'.join([model_name, train_type]) 312 | model_call = "torch.hub.load('facebookresearch/dino:main', '{}')".format(model_name) 313 | dino_options[model_string] = ({'model_name': model_name, 'train_type': train_type, 314 | 'train_data': train_data, 'model_source': model_source, 'call': model_call}) 315 | 316 | return dino_options 317 | 318 | def get_dino_transforms(input_type = 'PIL'): 319 | imagenet_stats = {'mean': [0.485, 0.456, 0.406], 'std': [0.229, 0.224, 0.225]} 320 | 321 | base_transforms = [transforms.Resize((224,224)), transforms.ToTensor()] 322 | 323 | specific_transforms = base_transforms + [transforms.Normalize(**imagenet_stats)] 324 | 325 | if input_type == 'PIL': 326 | recommended_transforms = specific_transforms 327 | if input_type == 'numpy': 328 | recommended_transforms = [transforms.ToPILImage()] + specific_transforms 329 | 330 | return transforms.Compose(recommended_transforms) 331 | 332 | 333 | # MiDas Options --------------------------------------------------------------------------- 334 | 335 | def define_midas_options(): 336 | midas_options = {} 337 | 338 | midas_typology = model_typology[model_typology['model_source'] == 'midas'].copy() 339 | for index, row in midas_typology.iterrows(): 340 | model_name = row['model'] 341 | train_type = row['train_type'] 342 | train_data = row['train_data'] 343 | model_source = 'midas' 344 | model_string = '_'.join([model_name, train_type]) 345 | model_call = "torch.hub.load('intel-isl/MiDaS','{}')".format(model_name) 346 | midas_options[model_string] = ({'model_name': model_name, 'train_type': train_type, 347 | 'train_data': train_data, 'model_source': model_source, 'call': model_call}) 348 | 349 | return 
midas_options 350 | 351 | def get_midas_transforms(model_name, input_type = 'PIL'): 352 | midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms") 353 | 354 | if model_name in ['DPT_Large', 'DPT_Hybrid']: 355 | transform = midas_transforms.dpt_transform 356 | if model_name not in ['DPT_Large', 'DPT_Hybrid']: 357 | transform = midas_transforms.small_transform 358 | 359 | transforms_lambda = [lambda img: np.array(img)] + transform.transforms 360 | transforms_lambda += [lambda tensor: tensor.squeeze()] 361 | 362 | if input_type == 'PIL': 363 | recommended_transforms = transforms_lambda 364 | if input_type == 'numpy': 365 | recommended_transforms = [transforms.ToPILImage()] + transforms_lambda 366 | 367 | return transforms.Compose(recommended_transforms) 368 | 369 | # YOLO Options --------------------------------------------------------------------------- 370 | 371 | def define_yolo_options(): 372 | yolo_options = {} 373 | 374 | yolo_typology = model_typology[model_typology['model_source'] == 'yolo'].copy() 375 | for index, row in yolo_typology.iterrows(): 376 | model_name = row['model'] 377 | train_type = row['train_type'] 378 | train_data = row['train_data'] 379 | model_source = 'yolo' 380 | model_string = '_'.join([model_name, train_type]) 381 | model_call = "torch.hub.load('ultralytics/yolov5','{}', autoshape = False, force_reload = True)".format(model_name) 382 | yolo_options[model_string] = ({'model_name': model_name, 'train_type': train_type, 383 | 'train_data': train_data, 'model_source': model_source, 'call': model_call}) 384 | 385 | return yolo_options 386 | 387 | def get_yolo_transforms(model_name, input_type = 'PIL'): 388 | assert input_type == 'PIL', "YOLOv5 models require input_type == 'PIL'" 389 | from PIL import Image 390 | 391 | def yolo_transforms(pil_image, size = (256,256)): 392 | img = np.asarray(pil_image.resize(size, Image.BICUBIC)) 393 | if img.shape[0] < 5: # image in CHW 394 | img = img.transpose((1, 2, 0)) 395 | img = img[:, :, :3] if img.ndim == 3 else np.tile(img[:, :, None], 3) 396 | img = img if img.data.contiguous else np.ascontiguousarray(img) 397 | img = np.ascontiguousarray(img.transpose((2, 0, 1))) 398 | img_tensor = torch.from_numpy(img) / 255. 
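        # img is now a contiguous uint8 CHW array; dividing by 255 yields the
        # float tensor in [0, 1] that YOLOv5 models consume when loaded with
        # autoshape = False (as in the model_call string above)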
399 | 400 | return img_tensor 401 | 402 | return yolo_transforms 403 | 404 | # Detectron2 Options --------------------------------------------------------------------------- 405 | 406 | def get_detectron_model(model_name, downsize = True, backbone_only = True): 407 | from detectron2.modeling import build_model 408 | from detectron2 import model_zoo 409 | from detectron2.checkpoint import DetectionCheckpointer 410 | 411 | detectron_data = subset_typology('detectron') 412 | detectron_dict = (detectron_data.set_index('model').to_dict('index')) 413 | weights_path = detectron_dict[model_name]['weights_url'] 414 | 415 | cfg = model_zoo.get_config(weights_path) 416 | cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(weights_path) 417 | 418 | cfg_clone = cfg.clone() 419 | model = build_model(cfg_clone) 420 | model = model.eval() 421 | 422 | checkpointer = DetectionCheckpointer(model) 423 | checkpointer.load(cfg.MODEL.WEIGHTS) 424 | 425 | if backbone_only: 426 | return model.backbone 427 | if not backbone_only: 428 | return model 429 | 430 | 431 | def define_detectron_options(): 432 | detectron_options = {} 433 | 434 | detectron_typology = model_typology[model_typology['model_source'] == 'detectron'].copy() 435 | for index, row in detectron_typology.iterrows(): 436 | model_name = row['model'] 437 | train_type = row['train_type'] 438 | train_data = row['train_data'] 439 | model_source = 'detectron' 440 | model_string = '_'.join([model_name, train_type]) 441 | model_call = "get_detectron_model('{}')".format(model_name) 442 | detectron_options[model_string] = ({'model_name': model_name, 'train_type': train_type, 443 | 'train_data': train_data, 'model_source': model_source, 'call': model_call}) 444 | 445 | return detectron_options 446 | 447 | def get_detectron_transforms(model_name, input_type = 'PIL', downsize = True): 448 | import detectron2.data.transforms as detectron_transform 449 | from detectron2 import model_zoo 450 | 451 | detectron_data = subset_typology('detectron') 452 | detectron_dict = (detectron_data.set_index('model').to_dict('index')) 453 | weights_path = detectron_dict[model_name]['weights_url'] 454 | 455 | cfg = model_zoo.get_config(weights_path) 456 | model = get_detectron_model(model_name, backbone_only = False) 457 | 458 | augment = detectron_transform.ResizeShortestEdge([cfg.INPUT.MIN_SIZE_TEST, 459 | cfg.INPUT.MIN_SIZE_TEST], 460 | cfg.INPUT.MAX_SIZE_TEST) 461 | 462 | if downsize: 463 | augment = detectron_transform.ResizeShortestEdge([224,224], 256) 464 | 465 | def detectron_transforms(original_image): 466 | if input_type == 'PIL': 467 | original_image = np.asarray(original_image) 468 | original_image = original_image[:, :, ::-1] 469 | height, width = original_image.shape[:2] 470 | image = augment.get_transform(original_image).apply_image(original_image) 471 | image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1)) 472 | 473 | inputs = {"image": image, "height": height, "width": width} 474 | return model.preprocess_image([inputs]).tensor.squeeze() 475 | 476 | return detectron_transforms 477 | 478 | # Aggregate Options --------------------------------------------------------------------------- 479 | 480 | if importlib.util.find_spec('custom_models') is not None: 481 | from custom_models import * 482 | 483 | def get_model_options(train_type=None, train_data = None, model_source=None): 484 | model_options = {**define_torchvision_options(), 485 | **define_taskonomy_options(), 486 | **define_timm_options(), 487 | **define_clip_options(), 488 | **define_vissl_options(), 
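                     # dict-union semantics: if two builders register the same
                     # model_string key, the later entry in this literal silently
                     # overrides the earlier one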
489 | **define_yolo_options(), 490 | **define_dino_options(), 491 | **define_midas_options(), 492 | **define_detectron_options()} 493 | 494 | if importlib.util.find_spec('custom_models') is not None: 495 | model_options = {**model_options, 496 | **get_custom_model_options()} 497 | 498 | if train_type is not None: 499 | model_options = {string: info for (string, info) in model_options.items() 500 | if model_options[string]['train_type'] == train_type} 501 | 502 | if train_data is not None: 503 | model_options = {string: info for (string, info) in model_options.items() 504 | if model_options[string]['train_data'] == train_data} 505 | 506 | if model_source is not None: 507 | model_options = {string: info for (string, info) in model_options.items() 508 | if model_options[string]['model_source'] == model_source} 509 | 510 | return model_options 511 | 512 | 513 | transform_options = {'torchvision': get_torchvision_transforms, 514 | 'timm': get_timm_transforms, 515 | 'taskonomy': get_taskonomy_transforms, 516 | 'clip': get_clip_transforms, 517 | 'vissl': get_vissl_transforms, 518 | 'yolo': get_yolo_transforms, 519 | 'dino': get_dino_transforms, 520 | 'midas': get_midas_transforms, 521 | 'detectron': get_detectron_transforms} 522 | 523 | def get_transform_options(): 524 | return transform_options 525 | 526 | def get_transform_types(): 527 | return list(transform_options.keys()) 528 | 529 | def get_recommended_transforms(model_query, input_type = 'PIL'): 530 | cached_model_types = ['imagenet','taskonomy','vissl'] 531 | model_types = model_typology['train_type'].unique() 532 | if model_query in get_model_options(): 533 | model_option = get_model_options()[model_query] 534 | model_type = model_option['train_type'] 535 | model_name = model_option['model_name'] 536 | model_source = model_option['model_source'] 537 | if model_query in model_types: 538 | model_type = model_query 539 | 540 | if model_type in cached_model_types: 541 | if model_type == 'imagenet': 542 | return get_torchvision_transforms('imagenet', input_type) 543 | if model_type == 'vissl': 544 | return get_vissl_transforms(input_type) 545 | if model_type == 'taskonomy': 546 | return get_taskonomy_transforms(input_type) 547 | 548 | if model_type not in cached_model_types: 549 | if model_source == 'torchvision': 550 | return transform_options[model_source](model_type, input_type) 551 | if model_source in ['timm', 'clip', 'detectron']: 552 | return transform_options[model_source](model_name, input_type) 553 | if model_source in ['taskonomy', 'vissl', 'dino', 'yolo', 'midas']: 554 | return transform_options[model_source](input_type) 555 | 556 | if importlib.util.find_spec('custom_models') is not None: 557 | if model_query in (list(get_custom_model_options()) + 558 | list(custom_transform_options)): 559 | return get_custom_transforms(model_query, input_type) 560 | 561 | if model_query not in list(get_model_options()) + list(model_types): 562 | raise ValueError('No reference available for this model query.') 563 | -------------------------------------------------------------------------------- /deepdive/model_opts_utils.py: -------------------------------------------------------------------------------- 1 | ### Auxiliary Functions: Feature Extraction --------------------------------------------------------- 2 | 3 | from feature_extraction import * 4 | 5 | def chunk_list(lst, n): 6 | for i in range(0, len(lst), n): 7 | yield lst[i:i + n] 8 | 9 | def get_image_transforms(): 10 | imagenet_stats = {'mean': [0.485, 0.456, 0.406], 11 | 'std': [0.229, 
0.224, 0.225]} 12 | 13 | options = {'imagenet': [transforms.Resize((224,224)), 14 | transforms.ToTensor(), 15 | transforms.Normalize(**imagenet_stats)]} 16 | 17 | return {key:transforms.Compose(value) 18 | for (key, value) in options.items()} 19 | 20 | def get_weights_dtype(model): 21 | module = list(model.children())[0] 22 | if not isinstance(module, nn.Sequential) and not isinstance(module, nn.ModuleList): 23 | return module.weight.dtype 24 | if isinstance(module, nn.Sequential) or isinstance(module, nn.ModuleList): 25 | return get_weights_dtype(module) 26 | 27 | def reverse_typical_transforms(img_array): 28 | if torch.is_tensor(img_array): 29 | img_array = img_array.numpy() 30 | if len(img_array.shape) == 3: 31 | img_array = img_array.transpose((1,2,0)) 32 | if len(img_array.shape) == 4: 33 | img_array = img_array.transpose((0,2,3,1)) 34 | 35 | return(img_array) 36 | 37 | def reverse_imagenet_transforms(img_array): 38 | if torch.is_tensor(img_array): 39 | img_array = img_array.numpy() 40 | if len(img_array.shape) == 3: 41 | img_array = img_array.transpose((1,2,0)) 42 | if len(img_array.shape) == 4: 43 | img_array = img_array.transpose((0,2,3,1)) 44 | mean = np.array([0.485, 0.456, 0.406]) 45 | std = np.array([0.229, 0.224, 0.225]) 46 | img_array = np.clip(std * img_array + mean, 0, 1) 47 | 48 | return(img_array) 49 | 50 | def numpy_to_pil(img_array): 51 | img_dim = np.array(img_array.shape) 52 | if (img_dim[-1] not in (1,3)) & (len(img_dim) == 3): 53 | img_array = img_array.transpose(1,2,0) 54 | if (img_dim[-1] not in (1,3)) & (len(img_dim) == 4): 55 | img_array = img_array.transpose(0,2,3,1) 56 | if ((img_array >= 0) & (img_array <= 1)).all(): 57 | img_array = img_array * 255 58 | if img_array.dtype != 'uint8': 59 | img_array = np.uint8(img_array) 60 | 61 | return (img_array) 62 | 63 | from torchvision.utils import make_grid 64 | import matplotlib.pyplot as plt # used by get_dataloader_sample below 65 | def get_dataloader_sample(dataloader, nrow = 5, figsize = (5,5), title=None, 66 | reverse_transforms = reverse_imagenet_transforms): 67 | 68 | image_batch = next(iter(dataloader)) 69 | batch_size = image_batch.shape[0] 70 | image_grid = make_grid(image_batch, nrow = batch_size // nrow) 71 | if reverse_transforms: 72 | image_grid = reverse_transforms(image_grid) 73 | plt.figure(figsize=figsize) 74 | plt.imshow(image_grid) 75 | plt.axis('off') 76 | if title is not None: 77 | plt.title(title) 78 | plt.pause(0.001) 79 | 80 | ### Auxiliary Functions: Mapping Methods --------------------------------------------------------- 81 | 82 | from mapping_methods import * 83 | 84 | def get_best_alpha_index(regression, y_true, score_func): # e.g. a fitted RidgeCV(store_cv_values = True), targets, and a (y_true, y_pred) -> scores callable 85 | best_score = -np.inf; best_alpha_index = 0 86 | for alpha_index, alpha_value in enumerate(regression.alphas): 87 | score = score_func(y_true, regression.cv_values_[:, :, alpha_index].squeeze()).mean() 88 | if score >= best_score: 89 | best_alpha_index, best_score = alpha_index, score 90 | 91 | return best_alpha_index -------------------------------------------------------------------------------- /deepdive/model_scores.csv: -------------------------------------------------------------------------------- 1 | model,imagenet_top1_error,imagenet_top5_error,brain_score 2 | alexnet,43.45,20.91,0.324290212 3 | densenet121,25.35,7.83,0.358620416 4 | densenet161,22.35,6.2, 5 | densenet169,24.0,7.0,0.364866221 6 | densenet201,22.8,6.43,0.362708788 7 | googlenet,30.22,10.47, 8 | mnasnet0_5,,, 9 | mnasnet1_0,26.49,8.456,0.323444394 10 | mobilenet_v2,28.12,9.71, 11 | resnet101,22.63,6.44,0.36466118 12 | resnet152,21.69,5.94,0.362920606 13 | 
resnet18,30.24,10.92,0.34994673299999995 14 | resnet34,26.7,8.58,0.343839329 15 | resnet50,23.85,7.13,0.363578741 16 | resnext101_32x8d,20.69,5.47, 17 | resnext50_32x4d,22.38,6.3, 18 | shufflenet_v2_x0_5,,, 19 | shufflenet_v2_x1_0,30.64,11.68, 20 | squeezenet1_0,41.9,19.58,0.29096129 21 | squeezenet1_1,41.81,19.38,0.308516645 22 | vgg11,30.98,11.37, 23 | vgg11_bn,29.62,10.19, 24 | vgg13,30.07,10.75, 25 | vgg13_bn,28.45,9.63, 26 | vgg16,28.41,9.62,0.34200869100000003 27 | vgg16_bn,26.63,8.5, 28 | vgg19,27.62,9.12,0.342096768 29 | vgg19_bn,25.76,8.15, 30 | wide_resnet101_2,21.16,5.72, 31 | wide_resnet50_2,21.49,5.91, 32 | -------------------------------------------------------------------------------- /deepdive/model_statistics/__init__.py: -------------------------------------------------------------------------------- 1 | __name__ ='deepdive' 2 | 3 | from .feature_extraction import * 4 | from .feature_reduction import * 5 | from .model_metadata import * 6 | from .mapping_methods import * 7 | from .model_opts_utils import * 8 | from .ridge_gcv_mod import * -------------------------------------------------------------------------------- /deepdive/model_statistics/timm_imagenet_accuracies.csv: -------------------------------------------------------------------------------- 1 | model,top1_accuracy,top5_accuracy 2 | beit_large_patch16_512,88.600,98.656 3 | beit_large_patch16_384,88.404,98.608 4 | tf_efficientnet_l2_ns,88.348,98.648 5 | tf_efficientnet_l2_ns_475,88.232,98.546 6 | convnext_xlarge_384_in22ft1k,87.546,98.486 7 | beit_large_patch16_224,87.474,98.304 8 | convnext_large_384_in22ft1k,87.396,98.368 9 | swin_large_patch4_window12_384,87.150,98.240 10 | vit_large_patch16_384,87.080,98.300 11 | volo_d5_512,87.042,97.968 12 | convnext_xlarge_in22ft1k,87.002,98.212 13 | volo_d5_448,86.952,97.940 14 | tf_efficientnet_b7_ns,86.838,98.096 15 | beit_base_patch16_384,86.798,98.136 16 | volo_d4_448,86.790,97.882 17 | convnext_large_in22ft1k,86.636,98.028 18 | convnext_base_384_in22ft1k,86.544,98.190 19 | volo_d3_448,86.496,97.710 20 | cait_m48_448,86.486,97.752 21 | tf_efficientnet_b6_ns,86.454,97.882 22 | swin_base_patch4_window12_384,86.432,98.056 23 | tf_efficientnetv2_xl_in21ft1k,86.418,97.866 24 | swin_large_patch4_window7_224,86.314,97.892 25 | tf_efficientnetv2_l_in21ft1k,86.304,97.978 26 | vit_large_r50_s32_384,86.182,97.920 27 | dm_nfnet_f6,86.142,97.730 28 | tf_efficientnet_b5_ns,86.090,97.750 29 | volo_d5_224,86.072,97.578 30 | cait_m36_384,86.052,97.730 31 | volo_d2_384,86.040,97.572 32 | vit_base_patch16_384,86.010,98.002 33 | xcit_large_24_p8_384_dist,85.998,97.684 34 | volo_d4_224,85.874,97.468 35 | vit_large_patch16_224,85.838,97.820 36 | convnext_base_in22ft1k,85.824,97.866 37 | dm_nfnet_f5,85.816,97.486 38 | xcit_medium_24_p8_384_dist,85.814,97.592 39 | vit_base_patch8_224,85.794,97.794 40 | xcit_large_24_p16_384_dist,85.754,97.538 41 | dm_nfnet_f4,85.714,97.522 42 | tf_efficientnetv2_m_in21ft1k,85.590,97.744 43 | xcit_small_24_p8_384_dist,85.556,97.572 44 | dm_nfnet_f3,85.522,97.462 45 | tf_efficientnetv2_l,85.490,97.372 46 | cait_s36_384,85.462,97.480 47 | ig_resnext101_32x48d,85.436,97.576 48 | xcit_medium_24_p16_384_dist,85.422,97.406 49 | deit_base_distilled_patch16_384,85.422,97.332 50 | volo_d3_224,85.412,97.280 51 | xcit_large_24_p8_224_dist,85.398,97.410 52 | tf_efficientnet_b8_ap,85.372,97.292 53 | tf_efficientnet_b8,85.368,97.390 54 | swin_base_patch4_window7_224,85.248,97.562 55 | volo_d1_384,85.248,97.214 56 | beit_base_patch16_224,85.228,97.658 57 | 
volo_d2_224,85.194,97.188 58 | tf_efficientnet_b4_ns,85.162,97.470 59 | tf_efficientnet_b7_ap,85.120,97.252 60 | ig_resnext101_32x32d,85.100,97.434 61 | xcit_small_24_p16_384_dist,85.094,97.310 62 | xcit_small_12_p8_384_dist,85.080,97.280 63 | xcit_medium_24_p8_224_dist,85.072,97.280 64 | dm_nfnet_f2,85.062,97.240 65 | cait_s24_384,85.046,97.346 66 | tf_efficientnetv2_m,85.038,97.278 67 | regnetz_e8,85.030,97.264 68 | resnetrs420,85.010,97.124 69 | ecaresnet269d,84.976,97.226 70 | vit_base_r50_s16_384,84.970,97.290 71 | tf_efficientnet_b7,84.936,97.204 72 | xcit_large_24_p16_224_dist,84.922,97.132 73 | resnetv2_152x4_bitm,84.916,97.442 74 | xcit_small_24_p8_224_dist,84.876,97.188 75 | efficientnetv2_rw_m,84.812,97.146 76 | tf_efficientnet_b6_ap,84.786,97.138 77 | resnetrs350,84.718,96.988 78 | xcit_small_12_p16_384_dist,84.710,97.118 79 | eca_nfnet_l2,84.696,97.264 80 | dm_nfnet_f1,84.624,97.100 81 | vit_base_patch16_224,84.528,97.294 82 | resnest269e,84.520,96.986 83 | resnetv2_152x2_bitm,84.506,97.434 84 | regnetz_040h,84.496,97.006 85 | resnetv2_101x3_bitm,84.442,97.382 86 | resnetrs200,84.438,97.080 87 | resnetrs270,84.436,96.970 88 | vit_large_r50_s32_224,84.424,97.166 89 | resmlp_big_24_224_in22ft1k,84.394,97.120 90 | xcit_large_24_p8_224,84.394,96.656 91 | seresnet152d,84.360,97.040 92 | tf_efficientnetv2_s_in21ft1k,84.298,97.254 93 | convnext_large,84.292,96.894 94 | swsl_resnext101_32x8d,84.290,97.180 95 | xcit_medium_24_p16_224_dist,84.274,96.940 96 | vit_base_patch16_224_miil,84.272,96.802 97 | tf_efficientnet_b5_ap,84.256,96.976 98 | xcit_small_12_p8_224_dist,84.236,96.874 99 | regnetz_040,84.234,96.932 100 | seresnext101_32x8d,84.204,96.876 101 | crossvit_18_dagger_408,84.194,96.818 102 | ig_resnext101_32x16d,84.170,97.198 103 | volo_d1_224,84.162,96.776 104 | pit_b_distilled_224,84.140,96.856 105 | tf_efficientnet_b6,84.110,96.888 106 | cait_xs24_384,84.062,96.890 107 | regnetz_d8_evos,84.054,96.996 108 | regnetz_d8,84.052,96.994 109 | tf_efficientnet_b3_ns,84.048,96.908 110 | vit_small_r26_s32_384,84.042,97.328 111 | regnetz_d32,84.022,96.866 112 | resnetv2_50x3_bitm,84.014,97.124 113 | eca_nfnet_l1,84.012,97.028 114 | resnet200d,83.964,96.824 115 | swin_s3_base_224,83.932,96.660 116 | regnety_080,83.926,96.888 117 | tf_efficientnetv2_s,83.886,96.696 118 | xcit_small_24_p16_224_dist,83.868,96.724 119 | resnetv2_152x2_bit_teacher_384,83.844,97.118 120 | convnext_base,83.838,96.750 121 | xcit_small_24_p8_224,83.838,96.636 122 | crossvit_15_dagger_408,83.836,96.784 123 | resnest200e,83.828,96.892 124 | tf_efficientnet_b5,83.812,96.748 125 | efficientnetv2_rw_s,83.810,96.722 126 | vit_small_patch16_384,83.804,97.102 127 | swin_s3_small_224,83.768,96.452 128 | xcit_tiny_24_p8_384_dist,83.742,96.710 129 | xcit_medium_24_p8_224,83.736,96.394 130 | regnety_064,83.720,96.722 131 | resnetrs152,83.714,96.614 132 | regnetv_064,83.714,96.746 133 | regnety_160,83.690,96.776 134 | twins_svt_large,83.680,96.594 135 | resnet152d,83.678,96.738 136 | resmlp_big_24_distilled_224,83.588,96.648 137 | jx_nest_base,83.556,96.362 138 | cait_s24_224,83.458,96.564 139 | efficientnet_b4,83.424,96.596 140 | deit_base_distilled_patch16_224,83.392,96.486 141 | dm_nfnet_f0,83.386,96.574 142 | swsl_resnext101_32x16d,83.350,96.844 143 | xcit_small_12_p16_224_dist,83.350,96.414 144 | vit_base_patch32_384,83.348,96.834 145 | xcit_small_12_p8_224,83.344,96.480 146 | tf_efficientnet_b4_ap,83.252,96.394 147 | swsl_resnext101_32x4d,83.236,96.764 148 | swin_small_patch4_window7_224,83.216,96.324 149 | 
149 | regnetv_040,83.200,96.662
150 | xception65,83.180,96.592
151 | convnext_small,83.150,96.432
152 | resnext101_64x4d,83.140,96.370
153 | twins_svt_base,83.138,96.420
154 | twins_pcpvt_large,83.134,96.604
155 | xception65p,83.126,96.478
156 | jx_nest_small,83.118,96.330
157 | deit_base_patch16_384,83.106,96.368
158 | tresnet_m,83.070,96.120
159 | tresnet_xl_448,83.054,96.172
160 | regnety_040,83.036,96.506
161 | tf_efficientnet_b4,83.024,96.300
162 | resnet101d,83.022,96.448
163 | xcit_large_24_p16_224,82.894,95.882
164 | resnest101e,82.890,96.318
165 | resnetv2_152x2_bit_teacher,82.872,96.570
166 | resnetv2_50x1_bit_distilled,82.822,96.528
167 | resnet152,82.820,96.130
168 | pnasnet5large,82.790,96.040
169 | nfnet_l0,82.752,96.518
170 | regnety_032,82.726,96.424
171 | twins_pcpvt_base,82.708,96.350
172 | ig_resnext101_32x8d,82.700,96.630
173 | xcit_medium_24_p16_224,82.638,95.974
174 | regnetz_c16_evos,82.632,96.476
175 | nasnetalarge,82.626,96.046
176 | levit_384,82.588,96.022
177 | xcit_small_24_p16_224,82.580,96.006
178 | eca_nfnet_l0,82.576,96.490
179 | xcit_tiny_24_p16_384_dist,82.572,96.288
180 | xcit_tiny_24_p8_224_dist,82.564,96.170
181 | resnet61q,82.526,96.134
182 | crossvit_18_dagger_240,82.520,96.070
183 | regnetz_c16,82.516,96.360
184 | gc_efficientnetv2_rw_t,82.466,96.296
185 | poolformer_m48,82.462,95.958
186 | pit_b_224,82.446,95.710
187 | crossvit_18_240,82.398,96.058
188 | xcit_tiny_12_p8_384_dist,82.394,96.220
189 | tf_efficientnet_b2_ns,82.380,96.248
190 | resnet51q,82.362,96.180
191 | ecaresnet50t,82.348,96.138
192 | efficientnetv2_rw_t,82.344,96.196
193 | resnetv2_101x1_bitm,82.336,96.516
194 | crossvit_15_dagger_240,82.330,95.958
195 | mixer_b16_224_miil,82.308,95.718
196 | coat_lite_small,82.304,95.848
197 | resnetrs101,82.288,96.010
198 | convit_base,82.286,95.938
199 | tresnet_l_448,82.268,95.982
200 | efficientnet_b3,82.240,96.114
201 | convnext_tiny_hnf,82.222,95.866
202 | crossvit_base_240,82.216,95.834
203 | cait_xxs36_384,82.194,96.144
204 | swsl_resnext50_32x4d,82.176,96.232
205 | ecaresnet101d,82.172,96.046
206 | swin_s3_tiny_224,82.126,95.950
207 | poolformer_m36,82.112,95.690
208 | visformer_small,82.106,95.874
209 | convnext_tiny,82.064,95.852
210 | halo2botnet50ts_256,82.060,95.642
211 | tresnet_xl,82.058,95.936
212 | fbnetv3_g,82.046,96.064
213 | resnetv2_101,82.042,95.864
214 | deit_base_patch16_224,81.996,95.732
215 | pit_s_distilled_224,81.994,95.798
216 | resnetv2_50d_evos,81.980,95.910
217 | xcit_small_12_p16_224,81.976,95.818
218 | tf_efficientnetv2_b3,81.966,95.780
219 | xception41p,81.960,95.794
220 | resnet101,81.932,95.770
221 | xcit_tiny_24_p8_224,81.892,95.976
222 | vit_small_r26_s32_224,81.856,96.020
223 | ssl_resnext101_32x16d,81.854,96.096
224 | tf_efficientnet_b3_ap,81.826,95.622
225 | resnetv2_50d_gn,81.818,95.922
226 | tresnet_m_448,81.704,95.572
227 | twins_svt_small,81.680,95.670
228 | halonet50ts,81.660,95.612
229 | tf_efficientnet_b3,81.636,95.718
230 | rexnet_200,81.630,95.668
231 | ssl_resnext101_32x8d,81.608,96.042
232 | lamhalobotnet50ts_256,81.546,95.502
233 | crossvit_15_240,81.542,95.690
234 | tf_efficientnet_lite4,81.536,95.668
235 | tnt_s_patch16_224,81.520,95.744
236 | levit_256,81.506,95.492
237 | vit_large_patch32_384,81.506,96.094
238 | tresnet_l,81.492,95.624
239 | wide_resnet50_2,81.452,95.530
240 | convit_small,81.420,95.740
241 | jx_nest_tiny,81.420,95.618
242 | poolformer_s36,81.418,95.450
243 | vit_small_patch16_224,81.396,96.132
244 | tf_efficientnet_b1_ns,81.388,95.736
245 | swin_tiny_patch4_window7_224,81.374,95.544
246 | convmixer_1536_20,81.366,95.614
247 | gernet_l,81.344,95.532
248 | efficientnet_el,81.310,95.530
249 | legacy_senet154,81.310,95.496
250 | coat_mini,81.266,95.394
251 | seresnext50_32x4d,81.258,95.630
252 | gluon_senet154,81.232,95.348
253 | deit_small_distilled_patch16_224,81.208,95.374
254 | xcit_tiny_12_p8_224_dist,81.208,95.600
255 | swsl_resnet50,81.174,95.978
256 | sebotnet33ts_256,81.156,95.170
257 | resmlp_36_distilled_224,81.154,95.488
258 | lambda_resnet50ts,81.146,95.102
259 | resnest50d_4s2x40d,81.110,95.564
260 | resnext50_32x4d,81.108,95.326
261 | pit_s_224,81.100,95.330
262 | twins_pcpvt_small,81.090,95.640
263 | haloregnetz_b,81.050,95.196
264 | resmlp_big_24_224,81.032,95.020
265 | crossvit_small_240,81.018,95.456
266 | gluon_resnet152_v1s,81.016,95.412
267 | resnest50d_1s4x24d,80.990,95.324
268 | resnest50d,80.982,95.380
269 | cait_xxs24_384,80.966,95.646
270 | sehalonet33ts,80.964,95.272
271 | xcit_tiny_12_p16_384_dist,80.944,95.412
272 | gcresnet50t,80.942,95.454
273 | ssl_resnext101_32x4d,80.924,95.726
274 | gluon_seresnext101_32x4d,80.906,95.294
275 | gluon_seresnext101_64x4d,80.878,95.298
276 | efficientnet_b3_pruned,80.858,95.244
277 | ecaresnet101d_pruned,80.814,95.630
278 | regnety_320,80.810,95.244
279 | resmlp_24_distilled_224,80.764,95.224
280 | gernet_m,80.744,95.184
281 | vit_base_patch32_224,80.722,95.566
282 | regnetz_b16,80.714,95.478
283 | nf_resnet50,80.654,95.334
284 | efficientnet_b2,80.614,95.316
285 | gluon_resnext101_64x4d,80.604,94.992
286 | ecaresnet50d,80.600,95.320
287 | gcresnext50ts,80.578,95.170
288 | resnet50d,80.522,95.162
289 | repvgg_b3,80.496,95.264
290 | vit_small_patch32_384,80.484,95.600
291 | gluon_resnet152_v1d,80.476,95.204
292 | mixnet_xl,80.474,94.934
293 | inception_resnet_v2,80.460,95.308
294 | ecaresnetlight,80.452,95.250
295 | xcit_tiny_24_p16_224_dist,80.446,95.216
296 | resnetv2_50,80.420,95.074
297 | gluon_resnet101_v1d,80.420,95.016
298 | resnet50,80.376,94.616
299 | regnety_120,80.376,95.126
300 | seresnet33ts,80.350,95.106
301 | gluon_resnext101_32x4d,80.344,94.926
302 | resnetv2_50x1_bitm,80.342,95.680
303 | ssl_resnext50_32x4d,80.316,95.410
304 | poolformer_s24,80.314,95.046
305 | rexnet_150,80.310,95.166
306 | tf_efficientnet_b2_ap,80.302,95.028
307 | efficientnet_el_pruned,80.302,95.216
308 | gluon_resnet101_v1s,80.298,95.164
309 | seresnet50,80.264,95.072
310 | tf_efficientnet_el,80.250,95.122
311 | vit_base_patch16_224_sam,80.242,94.754
312 | regnetx_320,80.240,95.022
313 | legacy_seresnext101_32x4d,80.224,95.010
314 | repvgg_b3g4,80.212,95.106
315 | tf_efficientnetv2_b2,80.206,95.042
316 | dpn107,80.172,94.906
317 | convmixer_768_32,80.164,95.072
318 | inception_v4,80.162,94.966
319 | skresnext50_32x4d,80.152,94.644
320 | eca_resnet33ts,80.080,94.970
321 | gcresnet33ts,80.080,95.000
322 | tf_efficientnet_b2,80.080,94.908
323 | cspdarknet53,80.062,95.084
324 | resnet50_gn,80.054,94.948
325 | cspresnext50,80.050,94.946
326 | dpn92,80.016,94.824
327 | ens_adv_inception_resnet_v2,79.978,94.938
328 | efficientnet_b2_pruned,79.916,94.854
329 | gluon_seresnext50_32x4d,79.914,94.832
330 | gluon_resnet152_v1c,79.908,94.848
331 | resnetrs50,79.886,94.966
332 | xception71,79.876,94.922
333 | deit_small_patch16_224,79.860,95.046
334 | regnetx_160,79.850,94.830
335 | ecaresnet26t,79.848,95.086
336 | levit_192,79.832,94.786
337 | dpn131,79.824,94.708
338 | tf_efficientnet_lite3,79.820,94.912
339 | resmlp_36_224,79.768,94.886
340 | cait_xxs36_224,79.748,94.866
341 | gluon_xception65,79.716,94.860
342 | ecaresnet50d_pruned,79.708,94.880
343 | xcit_tiny_12_p8_224,79.690,95.054
344 | fbnetv3_d,79.682,94.948
345 | gluon_resnet152_v1b,79.680,94.738
346 | resnext50d_32x4d,79.670,94.864
347 | dpn98,79.646,94.596
348 | gmlp_s16_224,79.640,94.624
349 | regnetx_120,79.592,94.734
350 | cspresnet50,79.582,94.704
351 | gluon_resnet101_v1c,79.534,94.580
352 | rexnet_130,79.500,94.684
353 | eca_halonext26ts,79.490,94.598
354 | hrnet_w64,79.472,94.652
355 | tf_efficientnetv2_b1,79.464,94.724
356 | dla102x2,79.446,94.632
357 | xcit_tiny_24_p16_224,79.444,94.884
358 | resmlp_24_224,79.382,94.546
359 | repvgg_b2g4,79.370,94.688
360 | gluon_resnext50_32x4d,79.364,94.426
361 | resnext101_32x8d,79.316,94.518
362 | ese_vovnet39b,79.312,94.714
363 | pit_xs_distilled_224,79.306,94.364
364 | tf_efficientnet_cc_b1_8e,79.306,94.372
365 | resnetblur50,79.304,94.634
366 | gluon_resnet101_v1b,79.302,94.520
367 | hrnet_w48,79.302,94.512
368 | nf_regnet_b1,79.288,94.748
369 | tf_efficientnet_b1_ap,79.280,94.304
370 | eca_botnext26ts_256,79.274,94.616
371 | botnet26t_256,79.252,94.528
372 | efficientnet_em,79.250,94.794
373 | ssl_resnet50,79.226,94.836
374 | dpn68b,79.220,94.418
375 | resnet33ts,79.210,94.572
376 | regnetx_080,79.202,94.554
377 | res2net101_26w_4s,79.196,94.436
378 | fbnetv3_b,79.148,94.746
379 | halonet26t,79.116,94.310
380 | lambda_resnet26t,79.098,94.588
381 | coat_lite_mini,79.096,94.604
382 | gluon_resnet50_v1d,79.076,94.472
383 | legacy_seresnext50_32x4d,79.068,94.434
384 | regnetx_064,79.066,94.458
385 | xception,79.050,94.392
386 | resnet32ts,79.012,94.358
387 | res2net50_26w_8s,78.980,94.294
388 | mixnet_l,78.976,94.178
389 | lambda_resnet26rpt_256,78.968,94.428
390 | hrnet_w40,78.916,94.474
391 | hrnet_w44,78.900,94.374
392 | wide_resnet101_2,78.854,94.290
393 | tf_efficientnet_b1,78.828,94.198
394 | gluon_inception_v3,78.804,94.370
395 | efficientnet_b1,78.796,94.342
396 | repvgg_b2,78.792,94.418
397 | tf_mixnet_l,78.774,93.996
398 | gluon_resnet50_v1s,78.712,94.240
399 | dla169,78.692,94.340
400 | tf_efficientnet_b0_ns,78.658,94.378
401 | legacy_seresnet152,78.652,94.370
402 | xcit_tiny_12_p16_224_dist,78.576,94.196
403 | res2net50_26w_6s,78.566,94.134
404 | dla102x,78.516,94.226
405 | xception41,78.510,94.278
406 | levit_128,78.492,94.006
407 | regnetx_040,78.482,94.244
408 | resnest26d,78.476,94.292
409 | dla60_res2net,78.462,94.206
410 | hrnet_w32,78.448,94.194
411 | dla60_res2next,78.440,94.150
412 | vit_tiny_patch16_384,78.434,94.542
413 | coat_tiny,78.430,94.040
414 | selecsls60b,78.412,94.174
415 | legacy_seresnet101,78.388,94.264
416 | cait_xxs24_224,78.384,94.310
417 | repvgg_b1,78.368,94.096
418 | tf_efficientnetv2_b0,78.360,94.020
419 | tv_resnet152,78.316,94.034
420 | mobilevit_s,78.312,94.152
421 | res2next50,78.252,93.886
422 | bat_resnext26ts,78.250,94.098
423 | dla60x,78.244,94.018
424 | efficientnet_b1_pruned,78.240,93.834
425 | hrnet_w30,78.198,94.224
426 | pit_xs_224,78.186,94.164
427 | regnetx_032,78.172,94.088
428 | res2net50_14w_8s,78.144,93.848
429 | tf_efficientnet_em,78.132,94.044
430 | hardcorenas_f,78.098,93.802
431 | efficientnet_es,78.056,93.936
432 | gmixer_24_224,78.036,93.670
433 | dla102,78.030,93.948
434 | gluon_resnet50_v1c,78.012,93.990
435 | seresnext26t_32x4d,77.976,93.746
436 | selecsls60,77.976,93.830
437 | res2net50_26w_4s,77.960,93.852
438 | resmlp_12_distilled_224,77.942,93.558
439 | mobilenetv3_large_100_miil,77.916,92.906
440 | tf_efficientnet_cc_b0_8e,77.906,93.656
441 | resnet26t,77.862,93.844
442 | regnety_016,77.860,93.722
443 | rexnet_100,77.858,93.870
444 | tf_inception_v3,77.856,93.640
445 | seresnext26ts,77.852,93.790
446 | gcresnext26ts,77.820,93.830
447 | xcit_nano_12_p8_384_dist,77.818,94.044
448 | hardcorenas_e,77.794,93.696
449 | efficientnet_b0,77.690,93.530
450 | tinynet_a,77.650,93.536
451 | legacy_seresnet50,77.630,93.750
452 | tv_resnext50_32x4d,77.616,93.700
453 | seresnext26d_32x4d,77.604,93.608
454 | repvgg_b1g4,77.586,93.830
455 | adv_inception_v3,77.582,93.736
456 | gluon_resnet50_v1b,77.580,93.722
457 | res2net50_48w_2s,77.520,93.552
458 | coat_lite_tiny,77.514,93.916
459 | tf_efficientnet_lite2,77.468,93.756
460 | eca_resnext26ts,77.454,93.566
461 | inception_v3,77.438,93.474
462 | hardcorenas_d,77.430,93.482
463 | tv_resnet101,77.378,93.542
464 | densenet161,77.354,93.636
465 | tf_efficientnet_cc_b0_4e,77.302,93.334
466 | mobilenetv2_120d,77.294,93.496
467 | densenet201,77.290,93.478
468 | mixnet_m,77.264,93.424
469 | poolformer_s12,77.236,93.504
470 | selecsls42b,77.174,93.392
471 | xcit_tiny_12_p16_224,77.126,93.716
472 | resnet34d,77.114,93.380
473 | legacy_seresnext26_32x4d,77.106,93.318
474 | tf_efficientnet_b0_ap,77.094,93.256
475 | hardcorenas_c,77.050,93.158
476 | dla60,77.030,93.320
477 | crossvit_9_dagger_240,76.982,93.610
478 | regnetx_016,76.950,93.422
479 | convmixer_1024_20_ks9_p14,76.942,93.356
480 | tf_mixnet_m,76.942,93.154
481 | gernet_s,76.908,93.132
482 | skresnet34,76.904,93.320
483 | tf_efficientnet_b0,76.844,93.228
484 | ese_vovnet19b_dw,76.802,93.272
485 | resnext26ts,76.780,93.128
486 | hrnet_w18,76.754,93.440
487 | resnet26d,76.704,93.150
488 | resmlp_12_224,76.656,93.180
489 | tf_efficientnet_lite1,76.640,93.220
490 | mixer_b16_224,76.612,92.228
491 | tf_efficientnet_es,76.596,93.204
492 | densenetblur121d,76.584,93.192
493 | hardcorenas_b,76.536,92.754
494 | mobilenetv2_140,76.522,92.996
495 | levit_128s,76.520,92.872
496 | repvgg_a2,76.458,93.010
497 | xcit_nano_12_p8_224_dist,76.320,93.088
498 | regnety_008,76.310,93.070
499 | dpn68,76.306,92.974
500 | tv_resnet50,76.134,92.868
501 | mixnet_s,75.992,92.798
502 | vit_small_patch32_224,75.986,93.270
503 | vit_tiny_r_s16_p8_384,75.954,93.264
504 | hardcorenas_a,75.920,92.520
505 | densenet169,75.898,93.030
506 | mobilenetv3_large_100,75.766,92.544
507 | tf_mixnet_s,75.650,92.628
508 | mobilenetv3_rw,75.632,92.708
509 | densenet121,75.584,92.652
510 | tf_mobilenetv3_large_100,75.518,92.604
511 | resnest14d,75.504,92.520
512 | efficientnet_lite0,75.476,92.512
513 | vit_tiny_patch16_224,75.462,92.844
514 | xcit_nano_12_p16_384_dist,75.456,92.690
515 | semnasnet_100,75.450,92.600
516 | resnet26,75.300,92.578
517 | regnety_006,75.250,92.534
518 | repvgg_b0,75.160,92.418
519 | fbnetc_100,75.130,92.386
520 | hrnet_w18_small_v2,75.118,92.416
521 | resnet34,75.114,92.284
522 | mobilenetv2_110d,75.038,92.184
523 | regnetx_008,75.034,92.340
524 | efficientnet_es_pruned,74.996,92.440
525 | tinynet_b,74.976,92.184
526 | tf_efficientnet_lite0,74.832,92.174
527 | legacy_seresnet34,74.808,92.126
528 | tv_densenet121,74.744,92.152
529 | mnasnet_100,74.658,92.112
530 | mobilevit_xs,74.644,92.356
531 | dla34,74.620,92.072
532 | gluon_resnet34_v1b,74.588,91.988
533 | pit_ti_distilled_224,74.532,92.096
534 | deit_tiny_distilled_patch16_224,74.512,91.886
535 | vgg19_bn,74.214,91.848
536 | spnasnet_100,74.084,91.820
537 | regnety_004,74.024,91.754
538 | ghostnet_100,73.974,91.460
539 | crossvit_9_240,73.960,91.968
540 | xcit_nano_12_p8_224,73.910,92.168
541 | regnetx_006,73.860,91.672
542 | vit_base_patch32_224_sam,73.694,91.010
543 | tf_mobilenetv3_large_075,73.436,91.344
544 | vgg16_bn,73.350,91.504
545 | crossvit_tiny_240,73.332,91.914
546 | tv_resnet34,73.306,91.424
547 | swsl_resnet18,73.276,91.736
548 | convit_tiny,73.114,91.714
549 | skresnet18,73.036,91.168
550 | semnasnet_075,72.972,91.136
551 | mobilenetv2_100,72.970,91.020
552 | pit_ti_224,72.912,91.406
553 | ssl_resnet18,72.608,91.424
554 | regnetx_004,72.392,90.832
555 | vgg19,72.366,90.870
556 | hrnet_w18_small,72.338,90.680
557 | xcit_nano_12_p16_224_dist,72.302,90.858
558 | resnet18d,72.250,90.688
559 | tf_mobilenetv3_large_minimal_100,72.250,90.630
560 | deit_tiny_patch16_224,72.172,91.114
561 | lcnet_100,72.104,90.376
562 | mixer_l16_224,72.054,87.662
563 | vit_tiny_r_s16_p8_224,71.792,90.822
564 | legacy_seresnet18,71.742,90.332
565 | vgg13_bn,71.594,90.376
566 | vgg16,71.590,90.382
567 | tinynet_c,71.228,89.750
568 | gluon_resnet18_v1b,70.834,89.762
569 | vgg11_bn,70.360,89.802
570 | regnety_002,70.254,89.532
571 | xcit_nano_12_p16_224,69.954,89.754
572 | vgg13,69.926,89.246
573 | resnet18,69.744,89.082
574 | vgg11,69.028,88.626
575 | mobilevit_xxs,68.920,88.944
576 | lcnet_075,68.816,88.370
577 | regnetx_002,68.756,88.556
578 | tf_mobilenetv3_small_100,67.924,87.664
579 | dla60x_c,67.892,88.426
580 | mobilenetv3_small_100,67.656,87.634
581 | tinynet_d,66.958,87.064
582 | mnasnet_small,66.206,86.508
583 | dla46x_c,65.970,86.980
584 | mobilenetv2_050,65.942,86.082
585 | tf_mobilenetv3_small_075,65.714,86.134
586 | mobilenetv3_small_075,65.242,85.438
587 | dla46_c,64.866,86.294
588 | lcnet_050,63.100,84.382
589 | tf_mobilenetv3_small_minimal_100,62.908,84.234
590 | tinynet_e,59.856,81.764
591 | mobilenetv3_small_050,57.890,80.194
--------------------------------------------------------------------------------
/deepdive/model_statistics/torchvision_imagenet_accuracies.csv:
--------------------------------------------------------------------------------
1 | model,top1_accuracy,top5_accuracy
2 | alexnet,56.522,79.066
3 | vgg11,69.02,88.628
4 | vgg13,69.928,89.246
5 | vgg16,71.592,90.382
6 | vgg19,72.376,90.876
7 | vgg11_bn,70.37,89.81
8 | vgg13_bn,71.586,90.374
9 | vgg16_bn,73.36,91.516
10 | vgg19_bn,74.218,91.842
11 | resnet18,69.758,89.078
12 | resnet34,73.314,91.42
13 | resnet50,76.13,92.862
14 | resnet101,77.374,93.546
15 | resnet152,78.312,94.046
16 | squeezenet1_0,58.092,80.42
17 | squeezenet1_1,58.178,80.624
18 | densenet121,74.434,91.972
19 | densenet169,75.6,92.806
20 | densenet201,76.896,93.37
21 | densenet161,77.138,93.56
22 | inception_v3,77.294,93.45
23 | googlenet,69.778,89.53
24 | shufflenet_v2_x0_5,60.552,81.746
25 | shufflenet_v2_x1_0,69.362,88.316
26 | mobilenet_v2,71.878,90.286
27 | mobilenet_v3_large,74.042,91.34
28 | mobilenet_v3_small,67.668,87.402
29 | resnext50_32x4d,77.618,93.698
30 | resnext101_32x8d,79.312,94.526
31 | wide_resnet50_2,78.468,94.086
32 | wide_resnet101_2,78.848,94.284
33 | mnasnet0_5,67.734,87.49
34 | mnasnet1_0,73.456,91.51
35 | efficientnet_b0,77.692,93.532
36 | efficientnet_b1,78.642,94.186
37 | efficientnet_b2,80.608,95.31
38 | efficientnet_b3,82.008,96.054
39 | efficientnet_b4,83.384,96.594
40 | efficientnet_b5,83.444,96.628
41 | efficientnet_b6,84.008,96.916
42 | efficientnet_b7,84.122,96.908
43 | regnet_x_400mf,72.834,90.95
44 | regnet_x_800mf,75.212,92.348
45 | regnet_x_1_6gf,77.04,93.44
46 | regnet_x_3_2gf,78.364,93.992
47 | regnet_x_8gf,79.344,94.686
48 | regnet_x_16gf,80.058,94.944
49 | regnet_x_32gf,80.622,95.248
50 | regnet_y_400mf,74.046,91.716
51 | regnet_y_800mf,76.42,93.136
52 | regnet_y_1_6gf,77.95,93.966
53 | regnet_y_3_2gf,78.948,94.576
54 | regnet_y_8gf,80.032,95.048
55 | regnet_y_16gf,80.424,95.24
56 | regnet_y_32gf,80.878,95.34
57 | vit_b_16,81.072,95.318
58 | vit_b_32,75.912,92.466
59 | vit_l_16,79.662,94.638
60 | vit_l_32,76.972,93.07
61 | convnext_tiny,82.52,96.146
62 | convnext_small,83.616,96.65
63 | convnext_base,84.062,96.87
64 | convnext_large,84.414,96.976
--------------------------------------------------------------------------------
/deepdive/ridge_gcv_mod.py:
--------------------------------------------------------------------------------
import numpy as np
from scipy import sparse

from sklearn.linear_model import Ridge, RidgeClassifier
from sklearn.linear_model._ridge import MultiOutputMixin, RegressorMixin
from sklearn.linear_model._ridge import _RidgeGCV, _BaseRidgeCV
from sklearn.linear_model._ridge import is_classifier, _check_gcv_mode
from sklearn.linear_model._ridge import _IdentityRegressor, safe_sparse_dot
from sklearn.linear_model._base import _preprocess_data, _rescale_data
from sklearn.model_selection import GridSearchCV
from sklearn.utils.validation import _check_sample_weight

from sklearn.metrics import explained_variance_score
from scipy.stats import pearsonr

# Vectorized pearsonr over the rows of two (n_targets, n_samples) arrays
pearsonr_vec = np.vectorize(pearsonr, signature='(n),(n)->(),()')

def pearson_r_score(y_true, y_pred, multioutput=None):
    # Per-target Pearson r between the columns of y_true and y_pred;
    # multioutput is accepted only for signature compatibility with
    # sklearn's multioutput metrics.
    y_true_ = y_true.transpose()
    y_pred_ = y_pred.transpose()
    return pearsonr_vec(y_true_, y_pred_)[0]


class _RidgeGCVMod(_RidgeGCV):
    """Ridge regression with built-in leave-one-out cross-validation.

    A modified copy of sklearn's _RidgeGCV in which candidate alphas are
    scored with per-target Pearson r or explained variance rather than
    the default squared error.
    """

    def __init__(
        self,
        alphas=(0.1, 1.0, 10.0),
        *,
        fit_intercept=True,
        scoring=None,
        copy_X=True,
        gcv_mode=None,
        store_cv_values=False,
        is_clf=False,
        alpha_per_target=False,
    ):
        self.alphas = np.asarray(alphas)
        self.fit_intercept = fit_intercept
        self.scoring = scoring
        self.copy_X = copy_X
        self.gcv_mode = gcv_mode
        self.store_cv_values = store_cv_values
        self.is_clf = is_clf
        self.alpha_per_target = alpha_per_target

    def fit(self, X, y, sample_weight=None):
        X, y = self._validate_data(
            X,
            y,
            accept_sparse=["csr", "csc", "coo"],
            dtype=[np.float64],
            multi_output=True,
            y_numeric=True,
        )

        # alpha_per_target cannot be used in classifier mode. All subclasses
        # of _RidgeGCV that are classifiers keep alpha_per_target at its
        # default value: False, so the condition below should never happen.
        assert not (self.is_clf and self.alpha_per_target)

        if sample_weight is not None:
            sample_weight = _check_sample_weight(sample_weight, X, dtype=X.dtype)

        if np.any(self.alphas <= 0):
            raise ValueError(
                "alphas must be strictly positive. Got {} containing some "
                "negative or null value instead.".format(self.alphas)
            )

        X, y, X_offset, y_offset, X_scale = _preprocess_data(
            X,
            y,
            self.fit_intercept,
            self.copy_X,
            sample_weight=sample_weight,
        )

        gcv_mode = _check_gcv_mode(X, self.gcv_mode)

        if gcv_mode == "eigen":
            decompose = self._eigen_decompose_gram
            solve = self._solve_eigen_gram
        elif gcv_mode == "svd":
            if sparse.issparse(X):
                decompose = self._eigen_decompose_covariance
                solve = self._solve_eigen_covariance
            else:
                decompose = self._svd_decompose_design_matrix
                solve = self._solve_svd_design_matrix

        n_samples = X.shape[0]

        if sample_weight is not None:
            X, y = _rescale_data(X, y, sample_weight)
            sqrt_sw = np.sqrt(sample_weight)
        else:
            sqrt_sw = np.ones(n_samples, dtype=X.dtype)

        X_mean, *decomposition = decompose(X, y, sqrt_sw)

        if self.scoring not in ['pearson_r', 'explained_variance']:
            raise ValueError(
                "modified RidgeCV scoring requires one of "
                "['pearson_r', 'explained_variance']"
            )

        n_y = 1 if len(y.shape) == 1 else y.shape[1]
        n_alphas = 1 if np.ndim(self.alphas) == 0 else len(self.alphas)

        if self.store_cv_values:
            self.cv_values_ = np.empty((n_samples * n_y, n_alphas), dtype=X.dtype)

        best_coef, best_score, best_alpha = None, None, None

        for i, alpha in enumerate(np.atleast_1d(self.alphas)):
            G_inverse_diag, c = solve(float(alpha), y, sqrt_sw, X_mean, *decomposition)
            # Leave-one-out predictions for this alpha
            predictions = y - (c / G_inverse_diag)
            if self.store_cv_values:
                self.cv_values_[:, i] = predictions.ravel()

            if self.alpha_per_target:
                if self.scoring == 'pearson_r':
                    alpha_score = pearson_r_score(y, predictions)
                elif self.scoring == 'explained_variance':
                    alpha_score = explained_variance_score(y, predictions, multioutput='raw_values')
            else:
                if self.scoring == 'pearson_r':
                    alpha_score = pearson_r_score(y, predictions).mean()
                elif self.scoring == 'explained_variance':
                    alpha_score = explained_variance_score(y, predictions, multioutput='uniform_average')

            # Keep track of the best model
            if best_score is None:
                if self.alpha_per_target and n_y > 1:
                    best_coef = c
                    best_score = np.atleast_1d(alpha_score)
                    best_alpha = np.full(n_y, alpha)
                else:
                    best_coef = c
                    best_score = alpha_score
                    best_alpha = alpha
            else:
                if self.alpha_per_target and n_y > 1:
                    to_update = alpha_score > best_score
                    best_coef[:, to_update] = c[:, to_update]
                    best_score[to_update] = alpha_score[to_update]
                    best_alpha[to_update] = alpha
                elif alpha_score > best_score:
                    best_coef, best_score, best_alpha = c, alpha_score, alpha

        self.alpha_ = best_alpha
        self.best_score_ = best_score
        self.dual_coef_ = best_coef
        self.coef_ = safe_sparse_dot(self.dual_coef_.T, X)

        X_offset += X_mean * X_scale
        self._set_intercept(X_offset, y_offset, X_scale)

        if self.store_cv_values:
            if len(y.shape) == 1:
                cv_values_shape = n_samples, n_alphas
            else:
                cv_values_shape = n_samples, n_y, n_alphas
            self.cv_values_ = self.cv_values_.reshape(cv_values_shape)

        return self


class _BaseRidgeCVMod(_BaseRidgeCV):
    def fit(self, X, y, sample_weight=None):
        cv = self.cv
        if cv is None:
            estimator = _RidgeGCVMod(
                self.alphas,
                fit_intercept=self.fit_intercept,
                scoring=self.scoring,
                gcv_mode=self.gcv_mode,
                store_cv_values=self.store_cv_values,
                is_clf=is_classifier(self),
                alpha_per_target=self.alpha_per_target,
            )
            estimator.fit(X, y, sample_weight=sample_weight)
            self.alpha_ = estimator.alpha_
            self.best_score_ = estimator.best_score_
            if self.store_cv_values:
                self.cv_values_ = estimator.cv_values_
        else:
            if self.store_cv_values:
                raise ValueError("cv!=None and store_cv_values=True are incompatible")
            if self.alpha_per_target:
                raise ValueError("cv!=None and alpha_per_target=True are incompatible")
            # With an explicit cv, fall back to a standard grid search; the
            # modified 'pearson_r' option applies only to the LOO path above,
            # so scoring must then be a valid sklearn scorer string.
            parameters = {"alpha": self.alphas}
            solver = "sparse_cg" if sparse.issparse(X) else "auto"
            model = RidgeClassifier if is_classifier(self) else Ridge
            gs = GridSearchCV(
                model(
                    fit_intercept=self.fit_intercept,
                    solver=solver,
                ),
                parameters,
                cv=cv,
                scoring=self.scoring,
            )
            gs.fit(X, y, sample_weight=sample_weight)
            estimator = gs.best_estimator_
            self.alpha_ = gs.best_estimator_.alpha
            self.best_score_ = gs.best_score_

        self.coef_ = estimator.coef_
        self.intercept_ = estimator.intercept_
        self.n_features_in_ = estimator.n_features_in_
        if hasattr(estimator, "feature_names_in_"):
            self.feature_names_in_ = estimator.feature_names_in_

        return self


class RidgeCVMod(MultiOutputMixin, RegressorMixin, _BaseRidgeCVMod):
    """Ridge regression with built-in cross-validation, scored by
    per-target Pearson r or explained variance (see _RidgeGCVMod)."""
--------------------------------------------------------------------------------
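A minimal usage sketch of the modified estimator, for orientation: the data shapes and random inputs below are made up, and the import path assumes the repo root is on `PYTHONPATH`.

```python
import numpy as np
from deepdive.ridge_gcv_mod import RidgeCVMod  # assumes repo root on PYTHONPATH

# Hypothetical data: 200 stimuli, 512-d deep net features, 10 target units
rng = np.random.default_rng(0)
X = rng.standard_normal((200, 512))
y = rng.standard_normal((200, 10))

# Leave-one-out ridge CV (cv=None path), selecting one alpha per target
# and scoring candidate alphas by per-target Pearson r
ridge = RidgeCVMod(alphas=np.logspace(-1, 5, 7),
                   scoring='pearson_r',
                   alpha_per_target=True)
ridge.fit(X, y)

print(ridge.alpha_.shape)       # (10,): best alpha per target
print(ridge.best_score_.shape)  # (10,): LOO Pearson r at the best alpha
```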