├── assets ├── test.txt ├── title_figure.pdf └── title_figure.png ├── requirements.txt ├── .gitmodules ├── lee ├── e2e_lee.py ├── e2e_other.py ├── lie_derivs.py ├── layerwise_other.py ├── layerwise_lee.py ├── transforms.py └── loader.py ├── LICENSE ├── README.md ├── .gitignore ├── sweep_configs ├── layerwise_configs.py └── e2e_configs.py ├── exps_e2e.py └── exps_layerwise.py /assets/test.txt: -------------------------------------------------------------------------------- 1 | empty 2 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.4.0 2 | torchvision>=0.5.0 3 | pyyaml 4 | e2cnn 5 | pandas 6 | wandb -------------------------------------------------------------------------------- /assets/title_figure.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ngruver/lie-deriv/HEAD/assets/title_figure.pdf -------------------------------------------------------------------------------- /assets/title_figure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ngruver/lie-deriv/HEAD/assets/title_figure.png -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "pytorch-image-models"] 2 | path = pytorch-image-models 3 | url = https://github.com/rwightman/pytorch-image-models 4 | [submodule "stylegan3"] 5 | path = stylegan3 6 | url = https://github.com/NVlabs/stylegan3 7 | -------------------------------------------------------------------------------- /lee/e2e_lee.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pandas as pd 3 | from .lie_derivs import * 4 | 5 | def get_equivariance_metrics(model, minibatch): 6 | x, y = minibatch 7 | if torch.cuda.is_available(): 8 | model = model.cuda() 9 | x, y = x.cuda(), y.cuda() 10 | 11 | model = model.eval() 12 | 13 | model_probs = lambda x: F.softmax(model(x), dim=-1) 14 | 15 | errs = { 16 | "trans_x_deriv": translation_lie_deriv(model_probs, x, axis="x"), 17 | "trans_y_deriv": translation_lie_deriv(model_probs, x, axis="y"), 18 | "rot_deriv": rotation_lie_deriv(model_probs, x), 19 | "shear_x_deriv": shear_lie_deriv(model_probs, x, axis="x"), 20 | "shear_y_deriv": shear_lie_deriv(model_probs, x, axis="y"), 21 | "stretch_x_deriv": stretch_lie_deriv(model_probs, x, axis="x"), 22 | "stretch_y_deriv": stretch_lie_deriv(model_probs, x, axis="y"), 23 | "saturate_err": saturate_lie_deriv(model_probs, x), 24 | } 25 | 26 | metrics = {x: pd.Series(errs[x].abs().cpu().data.numpy().mean(-1)) for x in errs} 27 | df = pd.DataFrame.from_dict(metrics) 28 | return df 29 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Nate Gruver 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to 
do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # The Lie Derivative for Learned Equivariance 2 |
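The Lie derivative measures how a trained model $f$ changes under a continuous transformation $g_t$ of its input (translation, rotation, shear, stretch, saturation). In the convention used by `lee/lie_derivs.py`, it is the derivative at $t = 0$ of the transformed-then-untransformed output,

$$
(\mathcal{L}f)(x) \;=\; \left.\frac{d}{dt}\right|_{t=0} g_t^{-1}\, f\!\left(g_t\, x\right),
$$

where the inverse action $g_t^{-1}$ is only applied when the output is itself image-like. A perfectly equivariant model has Lie derivative zero, so its norm serves as a local measure of learned equivariance.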

3 | ![title figure](assets/title_figure.png) 4 |

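Both sets of experiments below build on the primitives in `lee/lie_derivs.py` and `lee/e2e_lee.py`. The snippet that follows is only an illustrative sketch of what they compute: the `resnet50` choice and the random input are placeholders rather than part of the experiment scripts, which draw real ImageNet images through `lee/loader.py`.

```python
import torch
import torch.nn.functional as F
import timm

from lee.lie_derivs import translation_lie_deriv
from lee.e2e_lee import get_equivariance_metrics

model = timm.create_model("resnet50", pretrained=True).eval()
model_probs = lambda x: F.softmax(model(x), dim=-1)

# placeholder input; the experiment scripts load real ImageNet images via lee/loader.py
x = torch.randn(1, 3, 224, 224)
y = torch.zeros(1, dtype=torch.long)

# d/dt of softmax(model(translate(x, t))) at t = 0, one value per class
deriv = translation_lie_deriv(model_probs, x, axis="x")
print(deriv.abs().mean())

# all end-to-end Lie derivative metrics (translation, rotation, shear, stretch, saturation)
print(get_equivariance_metrics(model, (x, y)).mean())
```

Internally, `translation_lie_deriv` differentiates through the transformation with the double-backward Jacobian-vector product defined in `lee/lie_derivs.py`, and `get_equivariance_metrics` collects the analogous quantities for the other transformations into a pandas DataFrame.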
5 | 6 | # Installation instructions 7 | 8 | Clone submodules and install requirements using 9 | ```bash 10 | git clone --recurse-submodules https://github.com/ngruver/lie-deriv.git 11 | cd lie-deriv 12 | pip install -r requirements.txt 13 | ``` 14 | 15 | # Layer-wise equivariance experiments 16 | 17 | The equivariance of individual models can be calculated using `exps_layerwise.py`, for example 18 | ```bash 19 | python exps_layerwise.py \ 20 | --modelname=resnet50 \ 21 | --output_dir=$HOME/lee_results \ 22 | --num_imgs=20 \ 23 | --num_probes=100 \ 24 | --transform=translation 25 | ``` 26 | Default models and transforms are available in our wandb [sweep configuration](https://github.com/ngruver/lie-deriv/blob/main/sweep_configs/layerwise_configs.py) 27 | 28 | # End-to-end equivariance experiments 29 | 30 | ```bash 31 | python exps_e2e.py \ 32 | --modelname=resnet50 \ 33 | --output_dir=$HOME/lee_results \ 34 | --num_datapoints=100 35 | ``` 36 | We also include the wandb [sweep configuration](https://github.com/ngruver/lie-deriv/blob/main/sweep_configs/e2e_configs.py) for our end-to-end equivariance experiments. 37 | 38 | # Plotting and Visualization 39 | 40 | We make our results and plotting code available in a [google colab notebook](https://colab.research.google.com/drive/1ehsYp4t5AnXJwBypmXcEE9vzkdslByTw?usp=sharing) 41 | 42 | # Citation 43 | 44 | If you find our work helpful, please cite it with 45 | 46 | ``` 47 | @misc{https://doi.org/10.48550/arxiv.2210.02984, 48 | doi = {10.48550/ARXIV.2210.02984}, 49 | url = {https://arxiv.org/abs/2210.02984}, 50 | author = {Gruver, Nate and Finzi, Marc and Goldblum, Micah and Wilson, Andrew Gordon}, 51 | title = {The Lie Derivative for Measuring Learned Equivariance}, 52 | publisher = {arXiv}, 53 | year = {2022}, 54 | copyright = {Creative Commons Attribution 4.0 International} 55 | } 56 | ``` 57 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /sweep_configs/layerwise_configs.py: -------------------------------------------------------------------------------- 1 | minimal_models = { 2 | "name": "minimal_models", 3 | "method": "grid", 4 | "parameters": { 5 | "modelname": {"values": [ 6 | "vgg13", 7 | "inception_resnet_v2", 8 | "resnet50", 9 | "wide_resnet50_2", 10 | "densenet121", 11 | "efficientnet_b1", 12 | ]}, 13 | } 14 | } 15 | 16 | core_models = { 17 | "name": "core_models", 18 | "method": "grid", 19 | "parameters": { 20 | "modelname": {"values": [ 21 | 'vit_tiny_patch16_224', 22 | 'vit_tiny_patch16_384', 23 | 'vit_small_patch32_224', 24 | 'vit_small_patch32_384', 25 | 'vit_small_patch16_224', 26 | 'vit_small_patch16_384', 27 | 'vit_base_patch32_224', 28 | 'vit_base_patch32_384', 29 | 'vit_base_patch16_224', 30 | 'vit_base_patch16_384', 31 | 'vit_large_patch16_224', 32 | 'vit_large_patch32_384', 33 | 'vit_large_patch16_384', 34 | 'vit_base_patch16_224_miil', 35 | 'swin_base_patch4_window12_384', 36 | 'swin_base_patch4_window7_224', 37 | 'swin_large_patch4_window12_384', 38 | 'swin_large_patch4_window7_224', 39 | 'swin_small_patch4_window7_224', 40 | 'swin_tiny_patch4_window7_224', 41 | 'mixer_b16_224', 42 | 'mixer_l16_224', 43 | 'mixer_b16_224_miil', 44 | 'mixer_b16_224_in21k', 45 | 'mixer_l16_224_in21k', 46 | 'resmlp_12_224', 47 | 'resmlp_24_224', 48 | 'resmlp_36_224', 49 | 'resmlp_big_24_224', 50 | 'resmlp_12_distilled_224', 51 | 'resmlp_24_distilled_224', 52 | 'resmlp_36_distilled_224', 53 | 'resmlp_big_24_distilled_224', 54 | 'resmlp_big_24_224_in22ft1k', 55 | 'vgg11', 56 | 'vgg13', 57 | 'vgg16', 58 | 'vgg19', 59 | 'vgg11_bn', 60 | 'vgg13_bn', 61 | 
'vgg16_bn', 62 | 'vgg19_bn', 63 | 'inception_resnet_v2', 64 | 'inception_v3', 65 | 'inception_v4', 66 | 'densenet121', 67 | 'densenet161', 68 | 'densenet169', 69 | 'densenet201', 70 | 'densenetblur121d', 71 | 'tv_densenet121', 72 | 'tv_resnet34', 73 | 'tv_resnet50', 74 | 'tv_resnet101', 75 | 'tv_resnet152', 76 | 'resnet34', 77 | 'resnet50', 78 | 'resnet50d', 79 | 'resnet101', 80 | 'resnet101d', 81 | 'resnet152', 82 | 'resnet152d', 83 | 'resnetrs101', 84 | 'resnetrs152', 85 | 'wide_resnet50_2', 86 | 'wide_resnet101_2', 87 | 'resnetblur50', 88 | 'ig_resnext101_32x16d', 89 | 'ig_resnext101_32x32d', 90 | 'ssl_resnext101_32x16d', 91 | 'convmixer_768_32', 92 | 'convmixer_1536_20', 93 | 'mobilenetv2_100', 94 | 'mobilenetv2_110d', 95 | 'mobilenetv2_120d', 96 | 'mobilenetv2_140', 97 | 'efficientnet_b0', 98 | 'efficientnet_b1', 99 | 'efficientnet_b2', 100 | 'efficientnet_b3', 101 | 'efficientnet_b4', 102 | "convnext_tiny", 103 | "convnext_small", 104 | "convnext_base", 105 | "convnext_large", 106 | "convnext_tiny_in22ft1k", 107 | "convnext_small_in22ft1k", 108 | "convnext_base_in22ft1k", 109 | "convnext_large_in22ft1k", 110 | "convnext_xlarge_in22ft1k", 111 | 'beit_base_patch16_224', 112 | 'beit_base_patch16_384', 113 | 'beit_large_patch16_224', 114 | 'beit_large_patch16_384', 115 | 'beit_large_patch16_512', 116 | 'convit_small', 117 | 'convit_base', 118 | ]}, 119 | } 120 | } -------------------------------------------------------------------------------- /exps_e2e.py: -------------------------------------------------------------------------------- 1 | import os 2 | import gc 3 | import wandb 4 | import numpy as np 5 | import argparse 6 | import pandas as pd 7 | from functools import partial 8 | 9 | import sys 10 | sys.path.append('pytorch-image-models') 11 | import timm 12 | 13 | from lee.e2e_lee import get_equivariance_metrics as get_lee_metrics 14 | from lee.e2e_other import get_equivariance_metrics as get_discrete_metrics 15 | from lee.loader import get_loaders, eval_average_metrics_wstd 16 | 17 | def numparams(model): 18 | return sum(p.numel() for p in model.parameters()) 19 | 20 | def get_metrics(args, key, loader, model, max_mbs=400): 21 | discrete_metrics = eval_average_metrics_wstd( 22 | loader, partial(get_discrete_metrics, model), max_mbs=max_mbs, 23 | ) 24 | lee_metrics = eval_average_metrics_wstd( 25 | loader, partial(get_lee_metrics, model), max_mbs=max_mbs, 26 | ) 27 | metrics = pd.concat([lee_metrics, discrete_metrics], axis=1) 28 | 29 | metrics["dataset"] = key 30 | metrics["model"] = args.modelname 31 | metrics["params"] = numparams(model) 32 | 33 | return metrics 34 | 35 | def get_args_parser(): 36 | parser = argparse.ArgumentParser(description='Training Config', add_help=False) 37 | parser.add_argument('--output_dir', metavar='NAME', default='equivariance_metrics_cnns',help='experiment name') 38 | parser.add_argument('--modelname', metavar='NAME', default='resnet18', help='model name') 39 | parser.add_argument('--num_datapoints', type=int, default=60, help='use pretrained model') 40 | return parser 41 | 42 | def main(args): 43 | wandb.init(project="LieDerivEquivariance", config=args) 44 | args.__dict__.update(wandb.config) 45 | 46 | print(args) 47 | 48 | if not os.path.exists(args.output_dir): 49 | os.makedirs(args.output_dir) 50 | 51 | print(args.modelname) 52 | 53 | model = getattr(timm.models, args.modelname)(pretrained=True) 54 | model.eval() 55 | 56 | evaluated_metrics = [] 57 | 58 | imagenet_train_loader, imagenet_test_loader = get_loaders( 59 | model, 60 | 
dataset="imagenet", 61 | data_dir="/imagenet", 62 | batch_size=1, 63 | num_train=args.num_datapoints, 64 | num_val=args.num_datapoints, 65 | args=args, 66 | train_split='train', 67 | val_split='validation', 68 | ) 69 | 70 | evaluated_metrics += [ 71 | get_metrics(args, "Imagenet_train", imagenet_train_loader, model), 72 | get_metrics(args, "Imagenet_test", imagenet_test_loader, model) 73 | ] 74 | gc.collect() 75 | 76 | # _, cifar_test_loader = get_loaders( 77 | # model, 78 | # dataset="torch/cifar100", 79 | # data_dir="/scratch/nvg7279/cifar", 80 | # batch_size=1, 81 | # num_train=args.num_datapoints, 82 | # num_val=args.num_datapoints, 83 | # args=args, 84 | # train_split='train', 85 | # val_split='validation', 86 | # ) 87 | 88 | # evaluated_metrics += [get_metrics(args, "cifar100", cifar_test_loader, model, max_mbs=args.num_datapoints)] 89 | # gc.collect() 90 | 91 | # _, retinopathy_loader = get_loaders( 92 | # model, 93 | # dataset="tfds/diabetic_retinopathy_detection", 94 | # data_dir="/scratch/nvg7279/tfds", 95 | # batch_size=1, 96 | # num_train=1e8, 97 | # num_val=1e8, 98 | # args=args, 99 | # train_split="train", 100 | # val_split="train", 101 | # ) 102 | 103 | # evaluated_metrics += [get_metrics(args, "retinopathy", retinopathy_loader, model, max_mbs=args.num_datapoints)] 104 | # gc.collect() 105 | 106 | # _, histology_loader = get_loaders( 107 | # model, 108 | # dataset="tfds/colorectal_histology", 109 | # data_dir="/scratch/nvg7279/tfds", 110 | # batch_size=1, 111 | # num_train=1e8, 112 | # num_val=1e8, 113 | # args=args, 114 | # train_split="train", 115 | # val_split="train", 116 | # ) 117 | 118 | # evaluated_metrics += [get_metrics(args, "histology", histology_loader, model, max_mbs=args.num_datapoints)] 119 | # gc.collect() 120 | 121 | df = pd.concat(evaluated_metrics) 122 | df.to_csv(os.path.join(args.output_dir, args.modelname + ".csv")) 123 | 124 | if __name__ == "__main__": 125 | args = get_args_parser().parse_args() 126 | main(args) 127 | -------------------------------------------------------------------------------- /exps_layerwise.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tqdm 3 | import copy 4 | import argparse 5 | import warnings 6 | import pandas as pd 7 | from functools import partial 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | import lee.layerwise_lee as lee 14 | import lee.layerwise_other as other_metrics 15 | from lee.loader import get_loaders 16 | 17 | import sys 18 | sys.path.append('pytorch-image-models') 19 | import timm 20 | 21 | def convert_inplace_relu_to_relu(model): 22 | for child_name, child in model.named_children(): 23 | if isinstance(child, nn.ReLU): 24 | setattr(model, child_name, nn.ReLU()) 25 | else: 26 | convert_inplace_relu_to_relu(child) 27 | 28 | def get_layerwise(args, model, loader, func): 29 | errlist = [] 30 | for idx, (x, _) in tqdm.tqdm( 31 | enumerate(loader), total=len(loader) 32 | ): 33 | if idx >= args.num_imgs: 34 | break 35 | 36 | img = x.to(torch.device("cuda")) 37 | errors = func(model, img, num_probes=args.num_probes) 38 | errors["img_idx"] = idx 39 | errlist.append(errors) 40 | 41 | df = pd.concat(errlist, axis=0) 42 | df["model"] = args.modelname 43 | return df 44 | 45 | def main(args): 46 | os.makedirs(args.output_dir, exist_ok=True) 47 | 48 | print(args.modelname) 49 | print(args.transform) 50 | 51 | model = getattr(timm.models, args.modelname)(pretrained=True) 52 | 53 | convert_inplace_relu_to_relu(model) 
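    # note: in-place ReLUs are presumably swapped for out-of-place ones because the layer-wise
    # metrics hook into every module, cache its input/output, and differentiate through it,
    # which an in-place op would overwrite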
54 | model = model.eval() 55 | model = model.to(torch.device("cuda")) 56 | 57 | _, loader = get_loaders( 58 | model, 59 | dataset="imagenet", 60 | data_dir="/imagenet", 61 | batch_size=1, 62 | num_train=args.num_imgs, 63 | num_val=args.num_imgs, 64 | args=args, 65 | ) 66 | 67 | lee_transforms = ["translation","rotation","hyper_rotation","scale","saturate"] 68 | if args.use_lee and (args.transform in lee_transforms): 69 | lee_model = copy.deepcopy(model) 70 | lee.apply_hooks(lee_model, args.transform) 71 | lee_metrics = get_layerwise( 72 | args, lee_model, loader, func=lee.compute_equivariance_attribution 73 | ) 74 | 75 | lee_output_dir = os.path.join(args.output_dir, "lee_" + args.transform) 76 | os.makedirs(lee_output_dir, exist_ok=True) 77 | lee_metrics.to_csv(os.path.join(lee_output_dir, args.modelname + ".csv")) 78 | 79 | other_metrics_transforms = ["integer_translation","translation","rotation"] 80 | if (not args.use_lee) and (args.transform in other_metrics_transforms): 81 | other_metrics_model = copy.deepcopy(model) 82 | other_metrics.apply_hooks(other_metrics_model) 83 | func = partial(other_metrics.compute_equivariance_attribution, args.transform) 84 | other_metrics_results = get_layerwise( 85 | args, other_metrics_model, loader, func=func 86 | ) 87 | 88 | other_metrics_output_dir = os.path.join(args.output_dir, "stylegan3_" + args.transform) 89 | os.makedirs(other_metrics_output_dir, exist_ok=True) 90 | results_fn = args.modelname + "_norm_sqrt" + ".csv" 91 | other_metrics_results.to_csv(os.path.join(other_metrics_output_dir, results_fn)) 92 | 93 | 94 | def get_args_parser(): 95 | parser = argparse.ArgumentParser(description="PyTorch ImageNet Training") 96 | 97 | parser.add_argument( 98 | '--output_dir', metavar='NAME', default='equivariance_metrics_cnns',help='experiment name' 99 | ) 100 | parser.add_argument( 101 | "--modelname", metavar="NAME", default="resnet50", help="model name" 102 | ) 103 | parser.add_argument( 104 | "--num_imgs", type=int, default=20, help="Number of images to evaluate over" 105 | ) 106 | parser.add_argument( 107 | "--num_probes", 108 | type=int, 109 | default=100, 110 | help="Number of probes to use in the estimator", 111 | ) 112 | parser.add_argument( 113 | "--transform", 114 | metavar="NAME", 115 | default="translation", 116 | help="translation or rotation", 117 | ) 118 | parser.add_argument( 119 | "--use_lee", type=int, default=0, help="Use LEE (rather than metric not in limit)" 120 | ) 121 | return parser 122 | 123 | if __name__ == "__main__": 124 | args = get_args_parser().parse_args() 125 | main(args) 126 | # with warnings.catch_warnings(): 127 | # warnings.simplefilter("ignore") 128 | # main(args) 129 | -------------------------------------------------------------------------------- /lee/e2e_other.py: -------------------------------------------------------------------------------- 1 | from email import message_from_bytes 2 | from mailbox import Message 3 | import torch 4 | import torch.nn.functional as F 5 | import pandas as pd 6 | import numpy as np 7 | 8 | import sys 9 | sys.path.append("stylegan3") 10 | from stylegan3.metrics.equivariance import ( 11 | apply_integer_translation, 12 | apply_fractional_translation, 13 | apply_fractional_rotation, 14 | ) 15 | 16 | from .transforms import ( 17 | translate, 18 | rotate, 19 | ) 20 | 21 | def EQ_T(model, img, model_out, translate_max=0.125): 22 | d = [] 23 | for _ in range(10): 24 | t = (torch.rand(2, device='cuda') * 2 - 1) * translate_max 25 | t = (t * img.shape[-2] * img.shape[-1]).round() / 
(img.shape[-2] * img.shape[-1]) 26 | ref, _ = apply_integer_translation(img, t[0], t[1]) 27 | t_model_out = model(ref) 28 | d.append((t_model_out - model_out).square().mean()) 29 | #psnr = np.log10(2) * 20 - mse.log10() * 10 30 | return torch.stack(d).mean() 31 | 32 | def EQ_T_frac(model, img, model_out, translate_max=0.125): 33 | d = [] 34 | for _ in range(10): 35 | t = (torch.rand(2, device='cuda') * 2 - 1) * translate_max 36 | ref, _ = apply_fractional_translation(img, t[0], t[1]) 37 | t_model_out = model(ref) 38 | d.append((t_model_out - model_out).square().mean()) 39 | # psnr = np.log10(2) * 20 - mse.log10() * 10 40 | return torch.stack(d).mean() 41 | 42 | def EQ_R(model, img, model_out, rotate_max=1.0): 43 | d = [] 44 | for _ in range(10): 45 | angle = (torch.rand([], device='cuda') * 2 - 1) * (rotate_max * np.pi) 46 | ref, _ = apply_fractional_rotation(img, angle) 47 | t_model_out = model(ref) 48 | d.append((t_model_out - model_out).square().mean()) 49 | # psnr = np.log10(2) * 20 - mse.log10() * 10 50 | return torch.stack(d).mean() 51 | 52 | 53 | def translation_sample_invariance(model,inp_imgs,model_out,axis='x',eta=2.0): 54 | """ Lie derivative of model with respect to translation vector, assumes scalar output """ 55 | shifted_model = lambda t: model(translate(inp_imgs,t,axis)) 56 | d = [] 57 | for _ in range(10): 58 | t_sample = (2 * eta) * torch.rand(1) - eta 59 | d.append((model_out - shifted_model(t_sample)).pow(2).mean()) 60 | return torch.stack(d).mean(0).unsqueeze(0) 61 | 62 | def rotation_sample_invariance(model,inp_imgs,model_out,eta=np.pi//16): 63 | """ Lie derivative of model with respect to rotation, assumes scalar output """ 64 | rotated_model = lambda theta: model(rotate(inp_imgs,theta)) 65 | d = [] 66 | for _ in range(10): 67 | theta_sample = (2 * eta) * torch.rand(1) - eta 68 | d.append((model_out - rotated_model(theta_sample)).pow(2).mean()) 69 | return torch.stack(d).mean(0).unsqueeze(0) 70 | 71 | 72 | def get_equivariance_metrics(model, minibatch, num_probes=20): 73 | x, y = minibatch 74 | if torch.cuda.is_available(): 75 | model = model.cuda() 76 | x, y = x.cuda(), y.cuda() 77 | 78 | model = model.eval() 79 | 80 | model_probs = lambda x: F.softmax(model(x), dim=-1) 81 | model_out = model_probs(x) 82 | 83 | yhat = model_out.argmax(dim=1) # .cpu() 84 | acc = (yhat == y).cpu().float().data.numpy() 85 | 86 | metrics = {} 87 | metrics["acc"] = pd.Series(acc) 88 | 89 | with torch.no_grad(): 90 | for shift_x in range(8): 91 | rolled_img = torch.roll(x, shift_x, 2) 92 | rolled_yhat = model(rolled_img).argmax(dim=1) 93 | consistency = (rolled_yhat == yhat).cpu().data.numpy() 94 | metrics["consistency_x" + str(shift_x)] = pd.Series(consistency) 95 | for shift_y in range(8): 96 | rolled_img = torch.roll(x, shift_y, 3) 97 | rolled_yhat = model(rolled_img).argmax(dim=1) 98 | consistency = (rolled_yhat == yhat).cpu().data.numpy() 99 | metrics["consistency_y" + str(shift_y)] = pd.Series(consistency) 100 | 101 | eq_t = torch.stack([EQ_T(model_probs, x, model_out) for _ in range(num_probes)], dim=0).mean(0) 102 | eq_t_frac = torch.stack([EQ_T_frac(model_probs, x, model_out) for _ in range(num_probes)], dim=0).mean(0) 103 | eq_r = torch.stack([EQ_R(model_probs, x, model_out) for _ in range(num_probes)], dim=0).mean(0) 104 | 105 | metrics["eq_t"] = eq_t.cpu().data.numpy() 106 | metrics["eq_t_frac"] = eq_t_frac.cpu().data.numpy() 107 | metrics["eq_r"] = eq_r.cpu().data.numpy() 108 | 109 | metrics['trans_x_sample'] = 
translation_sample_invariance(model_probs,x,model_out,axis='x').abs().cpu().data.numpy() 110 | metrics['trans_y_sample'] = translation_sample_invariance(model_probs,x,model_out,axis='y').abs().cpu().data.numpy() 111 | metrics['rotate_sample'] = rotation_sample_invariance(model_probs,x,model_out).abs().cpu().data.numpy() 112 | 113 | df = pd.DataFrame.from_dict(metrics) 114 | return df 115 | -------------------------------------------------------------------------------- /lee/lie_derivs.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .transforms import * 3 | 4 | def jvp(f, x, u): 5 | """Jacobian vector product Df(x)u vs typical autograd VJP vTDF(x). 6 | Uses two backwards passes: computes (vTDF(x))u and then derivative wrt to v to get DF(x)u""" 7 | with torch.enable_grad(): 8 | y = f(x) 9 | v = torch.ones_like( 10 | y, requires_grad=True 11 | ) # Dummy variable (could take any value) 12 | vJ = torch.autograd.grad(y, [x], [v], create_graph=True) 13 | Ju = torch.autograd.grad(vJ, [v], [u], create_graph=True) 14 | return Ju[0] 15 | 16 | 17 | def translation_lie_deriv(model, inp_imgs, axis="x"): 18 | """Lie derivative of model with respect to translation vector, output can be a scalar or an image""" 19 | # vector = vector.to(inp_imgs.device) 20 | if not img_like(inp_imgs.shape): 21 | return 0.0 22 | 23 | def shifted_model(t): 24 | # print("Input shape",inp_imgs.shape) 25 | shifted_img = translate(inp_imgs, t, axis) 26 | z = model(shifted_img) 27 | # print("Output shape",z.shape) 28 | # if model produces an output image, shift it back 29 | if img_like(z.shape): 30 | z = translate(z, -t, axis) 31 | # print('zshape',z.shape) 32 | return z 33 | 34 | t = torch.zeros(1, requires_grad=True, device=inp_imgs.device) 35 | lie_deriv = jvp(shifted_model, t, torch.ones_like(t, requires_grad=True)) 36 | # print('Liederiv shape',lie_deriv.shape) 37 | # print(model.__class__.__name__) 38 | # print('') 39 | return lie_deriv 40 | 41 | 42 | def rotation_lie_deriv(model, inp_imgs): 43 | """Lie derivative of model with respect to rotation, assumes scalar output""" 44 | if not img_like(inp_imgs.shape): 45 | return 0.0 46 | 47 | def rotated_model(t): 48 | rotated_img = rotate(inp_imgs, t) 49 | z = model(rotated_img) 50 | if img_like(z.shape): 51 | z = rotate(z, -t) 52 | return z 53 | 54 | t = torch.zeros(1, requires_grad=True, device=inp_imgs.device) 55 | lie_deriv = jvp(rotated_model, t, torch.ones_like(t)) 56 | return lie_deriv 57 | 58 | 59 | def hyperbolic_rotation_lie_deriv(model, inp_imgs): 60 | """Lie derivative of model with respect to rotation, assumes scalar output""" 61 | if not img_like(inp_imgs.shape): 62 | return 0.0 63 | 64 | def rotated_model(t): 65 | rotated_img = hyperbolic_rotate(inp_imgs, t) 66 | z = model(rotated_img) 67 | if img_like(z.shape): 68 | z = hyperbolic_rotate(z, -t) 69 | return z 70 | 71 | t = torch.zeros(1, requires_grad=True, device=inp_imgs.device) 72 | lie_deriv = jvp(rotated_model, t, torch.ones_like(t)) 73 | return lie_deriv 74 | 75 | 76 | def scale_lie_deriv(model, inp_imgs): 77 | """Lie derivative of model with respect to rotation, assumes scalar output""" 78 | if not img_like(inp_imgs.shape): 79 | return 0.0 80 | 81 | def scaled_model(t): 82 | scaled_img = scale(inp_imgs, t) 83 | z = model(scaled_img) 84 | if img_like(z.shape): 85 | z = scale(z, -t) 86 | return z 87 | 88 | t = torch.zeros(1, requires_grad=True, device=inp_imgs.device) 89 | lie_deriv = jvp(scaled_model, t, torch.ones_like(t)) 90 | return lie_deriv 91 
| 92 | 93 | def shear_lie_deriv(model, inp_imgs, axis="x"): 94 | """Lie derivative of model with respect to shear, assumes scalar output""" 95 | if not img_like(inp_imgs.shape): 96 | return 0.0 97 | 98 | def sheared_model(t): 99 | sheared_img = shear(inp_imgs, t, axis) 100 | z = model(sheared_img) 101 | if img_like(z.shape): 102 | z = shear(z, -t, axis) 103 | return z 104 | 105 | t = torch.zeros(1, requires_grad=True, device=inp_imgs.device) 106 | lie_deriv = jvp(sheared_model, t, torch.ones_like(t)) 107 | return lie_deriv 108 | 109 | 110 | def stretch_lie_deriv(model, inp_imgs, axis="x"): 111 | """Lie derivative of model with respect to stretch, assumes scalar output""" 112 | if not img_like(inp_imgs.shape): 113 | return 0.0 114 | 115 | def stretched_model(t): 116 | stretched_img = stretch(inp_imgs, t, axis) 117 | z = model(stretched_img) 118 | if img_like(z.shape): 119 | z = stretch(z, -t, axis) 120 | return z 121 | 122 | t = torch.zeros(1, requires_grad=True, device=inp_imgs.device) 123 | lie_deriv = jvp(stretched_model, t, torch.ones_like(t)) 124 | return lie_deriv 125 | 126 | 127 | def saturate_lie_deriv(model, inp_imgs): 128 | """Lie derivative of model with respect to saturation, assumes scalar output""" 129 | if not img_like(inp_imgs.shape): 130 | return 0.0 131 | 132 | def saturated_model(t): 133 | saturated_img = saturate(inp_imgs, t) 134 | z = model(saturated_img) 135 | if img_like(z.shape): 136 | z = saturate(z, -t) 137 | return z 138 | 139 | t = torch.zeros(1, requires_grad=True, device=inp_imgs.device) 140 | lie_deriv = jvp(saturated_model, t, torch.ones_like(t)) 141 | return lie_deriv 142 | 143 | -------------------------------------------------------------------------------- /lee/layerwise_other.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from re import U 3 | import pandas as pd 4 | import numpy as np 5 | from collections import defaultdict 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | 10 | import sys 11 | sys.path.append("stylegan3") 12 | from metrics.equivariance import ( 13 | apply_integer_translation, 14 | apply_fractional_translation, 15 | apply_fractional_rotation, 16 | ) 17 | 18 | from .layerwise_lee import selective_apply 19 | from .transforms import img_like 20 | 21 | def EQ_T(module, img, model_out, translate_max=0.125): 22 | t = (torch.rand(2, device='cuda') * 2 - 1) * translate_max 23 | t_in = (t * img.shape[-2] * img.shape[-1]).round() / (img.shape[-2] * img.shape[-1]) 24 | ref, _ = apply_integer_translation(img, t_in[0], t_in[1]) 25 | with torch.no_grad(): 26 | t_model_out = module(ref) 27 | t_out = (t * t_model_out.shape[-2] * t_model_out.shape[-1]).round() / (t_model_out.shape[-2] * t_model_out.shape[-1]) 28 | t_model_out, out_mask = apply_integer_translation(t_model_out, -t_out[0], -t_out[1]) 29 | squared_err = (t_model_out - model_out).square() 30 | denom = out_mask.sum().sqrt() if out_mask.sum() > 0 else 1.0 31 | mse = (squared_err * out_mask).sum() / denom 32 | if module.__class__.__name__ == 'BatchNorm2d': 33 | mse = mse * 0 34 | # psnr = np.log10(2) * 20 - mse.log10() * 10 35 | return mse 36 | 37 | def EQ_T_frac(module, img, model_out, translate_max=0.125): 38 | t = (torch.rand(2, device='cuda') * 2 - 1) * translate_max 39 | t_in = (t * img.shape[-2] * img.shape[-1]).round() / (img.shape[-2] * img.shape[-1]) 40 | ref, _ = apply_fractional_translation(img, t_in[0], t_in[1]) 41 | with torch.no_grad(): 42 | t_model_out = module(ref) 43 | t_out = (t * t_model_out.shape[-2] * 
t_model_out.shape[-1]).round() / (t_model_out.shape[-2] * t_model_out.shape[-1]) 44 | t_model_out, out_mask = apply_fractional_translation(t_model_out, -t_out[0], -t_out[1]) 45 | squared_err = (t_model_out - model_out).square() 46 | denom = out_mask.sum().sqrt() if out_mask.sum() > 0 else 1.0 47 | mse = (squared_err * out_mask).sum() / denom 48 | if module.__class__.__name__ == 'BatchNorm2d': 49 | mse = mse * 0 50 | # psnr = np.log10(2) * 20 - mse.log10() * 10 51 | return mse 52 | 53 | def EQ_R(module, img, model_out, rotate_max=1.0): 54 | angle = (torch.rand(device='cuda') * 2 - 1) * (rotate_max * np.pi) 55 | ref, _ = apply_fractional_rotation(img, angle) 56 | with torch.no_grad(): 57 | t_model_out = module(ref) 58 | t_model_out, out_mask = apply_fractional_rotation(t_model_out, -angle) 59 | # model_out, pseudo_mask = apply_fractional_pseudo_rotation(model_out, -angle) 60 | squared_err = (t_model_out - model_out).square() 61 | denom = out_mask.sum() if out_mask.sum() > 0 else 1.0 62 | mse = (squared_err * out_mask).sum()# / denom 63 | # psnr = np.log10(2) * 20 - mse.log10() * 10 64 | return mse 65 | 66 | 67 | def store_inputs(self, inputs, outputs): 68 | self._cached_input = inputs 69 | self._cached_output = outputs 70 | 71 | def apply_hooks(model): 72 | selective_apply( 73 | model, lambda m: m.register_forward_hook(store_inputs) 74 | ) 75 | 76 | def reset(self): 77 | try: 78 | del self._cached_input 79 | except AttributeError: 80 | pass 81 | 82 | def compute_equivariance_attribution(transform, model, img_batch, num_probes=20): 83 | model = model.eval() 84 | 85 | equiv_metric = { 86 | "integer_translation": EQ_T, 87 | "translation": EQ_T_frac, 88 | "rotation": EQ_R, 89 | }[transform] 90 | 91 | model_fn = lambda z: F.softmax(model(z), dim=1) 92 | with torch.no_grad(): 93 | model_fn(img_batch) 94 | 95 | name_counter = defaultdict(int) 96 | cached_info = [] 97 | for i, (name, module) in enumerate(model.named_modules()): 98 | name_counter[name] += 1 99 | 100 | if not hasattr(module, "_cached_input") or \ 101 | not hasattr(module, "_cached_output"): 102 | continue 103 | 104 | model_in = copy.deepcopy(module._cached_input) 105 | model_in = model_in[0] if isinstance(model_in, tuple) else model_in 106 | model_out = copy.deepcopy(module._cached_output) 107 | model_out = model_out[0] if isinstance(model_out, tuple) else model_out 108 | 109 | cached_info.append((name, module, model_in, model_out)) 110 | 111 | model.apply(reset) 112 | 113 | all_errs = [] 114 | for j in range(num_probes): 115 | errs = {} 116 | for i, (name, module, model_in, model_out) in enumerate(cached_info): 117 | num_tag = f"{name_counter[name]}" if name_counter[name] else "" 118 | mod = module.__class__.__name__ 119 | 120 | if not img_like(model_in.shape) or not img_like(model_out.shape): 121 | equiv_err = 0.0 122 | else: 123 | with torch.no_grad(): 124 | equiv_err = equiv_metric(module, model_in, model_out) 125 | equiv_err = equiv_err.cpu().data.numpy() 126 | 127 | errs[ 128 | (name + (num_tag), mod, i, name_counter[name]) 129 | ] = equiv_err 130 | 131 | errs = pd.Series(errs, index=pd.MultiIndex.from_tuples(errs.keys()), name=j) 132 | all_errs.append(errs) 133 | 134 | df = pd.DataFrame(all_errs) 135 | print(df) 136 | return df -------------------------------------------------------------------------------- /lee/layerwise_lee.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tqdm 3 | import argparse 4 | import pandas as pd 5 | from functools import partial 6 | 7 | 
import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | from .lie_derivs import * 12 | 13 | class flag: 14 | pass 15 | 16 | singleton = flag 17 | singleton.compute_lie = True 18 | singleton.op_counter = 0 19 | singleton.fwd = True 20 | 21 | # TODO: record call order 22 | # TODO: if module is called 2nd time, create a copy and store separately 23 | # change self inplace? 24 | def store_inputs(lie_deriv_type, self, inputs, outputs): 25 | if not singleton.fwd or not singleton.compute_lie: 26 | return 27 | # if hasattr(self,'_lie_norm_sum'): return 28 | with torch.no_grad(): 29 | singleton.compute_lie = False 30 | if not hasattr(self, "_lie_norm_sum"): 31 | self._lie_norm_sum = [] 32 | self._lie_norm_sum_sq = [] 33 | self._num_probes = [] 34 | self._op_counter = [] 35 | self._fwd_counter = 0 36 | self._lie_deriv_output = [] 37 | if self._fwd_counter == len(self._lie_norm_sum): 38 | self._lie_norm_sum.append(0) 39 | self._lie_norm_sum_sq.append(0) 40 | self._num_probes.append(0) 41 | self._op_counter.append(singleton.op_counter) 42 | self._lie_deriv_output.append(0) 43 | singleton.op_counter += 1 44 | assert len(inputs) == 1 45 | (x,) = inputs 46 | x = x + torch.zeros_like(x) 47 | self._lie_deriv_output[self._fwd_counter] = lie_deriv_type(self, x) 48 | 49 | self._fwd_counter += 1 50 | singleton.compute_lie = True 51 | 52 | def store_estimator(self, grad_input, grad_output): 53 | if singleton.compute_lie: 54 | with torch.no_grad(): 55 | assert len(grad_output) == 1 56 | bs = grad_output[0].shape[0] 57 | self._fwd_counter -= 1 # reverse of forward ordering 58 | i = self._fwd_counter 59 | # if i!=0: return 60 | # print('bwd',self._fwd_counter) 61 | estimator = ( 62 | ( 63 | (grad_output[0] * self._lie_deriv_output[i]).reshape(bs, -1).sum(-1) 64 | ** 2 65 | ) 66 | .cpu() 67 | .data.numpy() 68 | ) 69 | self._lie_norm_sum[i] += estimator 70 | self._lie_norm_sum_sq[i] += estimator**2 71 | self._num_probes[i] += 1 72 | # print("finished bwd",self) 73 | 74 | 75 | from timm.models.vision_transformer import Attention as A1 76 | # from timm.models.vision_transformer_wconvs import Attention as A2 77 | from timm.models.mlp_mixer import MixerBlock, Affine, SpatialGatingBlock 78 | from timm.models.layers import PatchEmbed, Mlp, DropPath, BlurPool2d 79 | 80 | # from timm.models.layers import FastAdaptiveAvgPool2d,AdaptiveAvgMaxPool2d 81 | from timm.models.layers import GatherExcite, EvoNormBatch2d 82 | from timm.models.senet import SEModule 83 | from timm.models.efficientnet_blocks import SqueezeExcite 84 | from timm.models.convit import MHSA, GPSA 85 | 86 | leaflist = ( 87 | A1, 88 | # A2, 89 | MixerBlock, 90 | Affine, 91 | SpatialGatingBlock, 92 | PatchEmbed, 93 | Mlp, 94 | DropPath, 95 | BlurPool2d, 96 | ) 97 | # leaflist += (nn.AdaptiveAvgPool2d,nn.MaxPool2d,nn.AvgPool2d) 98 | leaflist += ( 99 | GatherExcite, 100 | EvoNormBatch2d, 101 | nn.BatchNorm2d, 102 | nn.BatchNorm1d, 103 | nn.LayerNorm, 104 | nn.GroupNorm, 105 | SEModule, 106 | SqueezeExcite, 107 | ) 108 | leaflist += (MHSA, GPSA) 109 | 110 | def is_leaf(m): 111 | return (not hasattr(m, "children") or not list(m.children())) or isinstance( 112 | m, leaflist 113 | ) 114 | 115 | def is_excluded(m): 116 | excluded_list = nn.Dropout 117 | return isinstance(m, excluded_list) 118 | 119 | def selective_apply(m, fn): 120 | if is_leaf(m): 121 | if not is_excluded(m): 122 | fn(m) 123 | else: 124 | for c in m.children(): 125 | selective_apply(c, fn) 126 | 127 | def apply_hooks(model, lie_deriv_type): 128 | lie_deriv = { 129 | 
"translation": translation_lie_deriv, 130 | "rotation": rotation_lie_deriv, 131 | "hyper_rotation": hyperbolic_rotation_lie_deriv, 132 | "scale": scale_lie_deriv, 133 | "saturate": saturate_lie_deriv, 134 | }[lie_deriv_type] 135 | 136 | selective_apply( 137 | model, lambda m: m.register_forward_hook(partial(store_inputs, lie_deriv)) 138 | ) 139 | selective_apply(model, lambda m: m.register_backward_hook(store_estimator)) 140 | 141 | 142 | def reset(self): 143 | try: 144 | # del self._input 145 | del self._lie_norm_sum 146 | del self._lie_norm_sum_sq 147 | del self._num_probes 148 | del self._op_counter 149 | del self._bwd_counter 150 | del self._lie_deriv_output 151 | except AttributeError: 152 | pass 153 | 154 | 155 | def reset2(self): 156 | self._lie_norm_sum = [0.0] * len(self._lie_norm_sum) 157 | self._lie_norm_sum_sq = [0.0] * len(self._lie_norm_sum) 158 | self._num_probes = [0.0] * len(self._lie_norm_sum) 159 | 160 | 161 | def compute_equivariance_attribution(model, img_batch, num_probes=100): 162 | model_fn = lambda z: F.softmax(model(z), dim=1) 163 | all_errs = [] 164 | order = [] 165 | for j in range(num_probes): 166 | singleton.fwd = True 167 | y = model_fn(img_batch) 168 | z = torch.randn(img_batch.shape[0], 1000).to(torch.device("cuda")) 169 | loss = (z * y).sum() 170 | singleton.fwd = False 171 | loss.backward() 172 | model.zero_grad() 173 | singleton.op_counter = 0 174 | 175 | errs = {} 176 | for name, module in model.named_modules(): 177 | if hasattr(module, "_lie_norm_sum"): 178 | for i in range(len(module._lie_norm_sum)): 179 | assert module._num_probes[i] == 1 180 | 181 | lie_norm = module._lie_norm_sum[i] / module._num_probes[i] 182 | mod = module.__class__.__name__ 183 | errs[ 184 | (name + (f"{i}" if i else ""), mod, module._op_counter[i], i) 185 | ] = lie_norm.item() 186 | 187 | errs = pd.Series(errs, index=pd.MultiIndex.from_tuples(errs.keys()), name=j) 188 | all_errs.append(errs) 189 | selective_apply(model, reset2) 190 | model.apply(reset) 191 | df = pd.DataFrame(all_errs) 192 | return df -------------------------------------------------------------------------------- /lee/transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | 5 | def grid_sample(image, optical): 6 | N, C, IH, IW = image.shape 7 | _, H, W, _ = optical.shape 8 | 9 | ix = optical[..., 0] 10 | iy = optical[..., 1] 11 | 12 | ix = ((ix + 1) / 2) * (IW - 1) 13 | iy = ((iy + 1) / 2) * (IH - 1) 14 | with torch.no_grad(): 15 | ix_nw = torch.floor(ix) 16 | iy_nw = torch.floor(iy) 17 | ix_ne = ix_nw + 1 18 | iy_ne = iy_nw 19 | ix_sw = ix_nw 20 | iy_sw = iy_nw + 1 21 | ix_se = ix_nw + 1 22 | iy_se = iy_nw + 1 23 | 24 | nw = (ix_se - ix) * (iy_se - iy) 25 | ne = (ix - ix_sw) * (iy_sw - iy) 26 | sw = (ix_ne - ix) * (iy - iy_ne) 27 | se = (ix - ix_nw) * (iy - iy_nw) 28 | 29 | with torch.no_grad(): 30 | ix_nw = IW - 1 - (IW - 1 - ix_nw.abs()).abs() 31 | iy_nw = IH - 1 - (IH - 1 - iy_nw.abs()).abs() 32 | 33 | ix_ne = IW - 1 - (IW - 1 - ix_ne.abs()).abs() 34 | iy_ne = IH - 1 - (IH - 1 - iy_ne.abs()).abs() 35 | 36 | ix_sw = IW - 1 - (IW - 1 - ix_sw.abs()).abs() 37 | iy_sw = IH - 1 - (IH - 1 - iy_sw.abs()).abs() 38 | 39 | ix_se = IW - 1 - (IW - 1 - ix_se.abs()).abs() 40 | iy_se = IH - 1 - (IH - 1 - iy_se.abs()).abs() 41 | 42 | image = image.view(N, C, IH * IW) 43 | 44 | nw_val = torch.gather( 45 | image, 2, (iy_nw * IW + ix_nw).long().view(N, 1, H * W).repeat(1, C, 1) 46 | ) 47 | ne_val = 
torch.gather( 48 | image, 2, (iy_ne * IW + ix_ne).long().view(N, 1, H * W).repeat(1, C, 1) 49 | ) 50 | sw_val = torch.gather( 51 | image, 2, (iy_sw * IW + ix_sw).long().view(N, 1, H * W).repeat(1, C, 1) 52 | ) 53 | se_val = torch.gather( 54 | image, 2, (iy_se * IW + ix_se).long().view(N, 1, H * W).repeat(1, C, 1) 55 | ) 56 | 57 | out_val = ( 58 | nw_val.view(N, C, H, W) * nw.view(N, 1, H, W) 59 | + ne_val.view(N, C, H, W) * ne.view(N, 1, H, W) 60 | + sw_val.view(N, C, H, W) * sw.view(N, 1, H, W) 61 | + se_val.view(N, C, H, W) * se.view(N, 1, H, W) 62 | ) 63 | 64 | return out_val 65 | 66 | def img_like(img_shape): 67 | bchw = len(img_shape) == 4 and img_shape[-2:] != (1, 1) 68 | is_square = int(int(np.sqrt(img_shape[1])) + 0.5) ** 2 == img_shape[1] 69 | is_one_off_square = int(int(np.sqrt(img_shape[1])) + 0.5) ** 2 == img_shape[1] - 1 70 | is_two_off_square = int(int(np.sqrt(img_shape[1])) + 0.5) ** 2 == img_shape[1] - 2 71 | bnc = ( 72 | len(img_shape) == 3 73 | and img_shape[1] != 1 74 | and (is_square or is_one_off_square or is_two_off_square) 75 | ) 76 | return bchw or bnc 77 | 78 | def num_tokens(img_shape): 79 | if len(img_shape) == 4 and img_shape[-2:] != (1, 1): 80 | return 0 81 | is_one_off_square = int(int(np.sqrt(img_shape[1])) + 0.5) ** 2 == img_shape[1] - 1 82 | is_two_off_square = int(int(np.sqrt(img_shape[1])) + 0.5) ** 2 == img_shape[1] - 2 83 | return int(is_one_off_square * 1 or is_two_off_square * 2) 84 | 85 | 86 | def bnc2bchw(bnc, num_tokens): 87 | b, n, c = bnc.shape 88 | h = w = int(np.sqrt(n)) 89 | extra = bnc[:, :num_tokens, :] 90 | img = bnc[:, num_tokens:, :] 91 | return img.reshape(b, h, w, c).permute(0, 3, 1, 2), extra 92 | 93 | 94 | def bchw2bnc(bchw, tokens): 95 | b, c, h, w = bchw.shape 96 | n = h * w 97 | bnc = bchw.permute(0, 2, 3, 1).reshape(b, n, c) 98 | return torch.cat([tokens, bnc], dim=1) # assumes tokens are at the start 99 | 100 | 101 | def affine_transform(affineMatrices, img): 102 | assert img_like(img.shape) 103 | if len(img.shape) == 3: 104 | ntokens = num_tokens(img.shape) 105 | x, extra = bnc2bchw(img, ntokens) 106 | else: 107 | x = img 108 | flowgrid = F.affine_grid( 109 | affineMatrices, size=x.size(), align_corners=True 110 | ) # .double() 111 | # uses manual grid sample implementation to be able to compute 2nd derivatives 112 | # img_out = F.grid_sample(img, flowgrid,padding_mode="reflection",align_corners=True) 113 | transformed = grid_sample(x, flowgrid) 114 | if len(img.shape) == 3: 115 | transformed = bchw2bnc(transformed, extra) 116 | return transformed 117 | 118 | 119 | def translate(img, t, axis="x"): 120 | """Translates an image by a fraction of the size (sx,sy) in (0,1)""" 121 | affineMatrices = torch.zeros(img.shape[0], 2, 3).to(img.device) 122 | affineMatrices[:, 0, 0] = 1 123 | affineMatrices[:, 1, 1] = 1 124 | if axis == "x": 125 | affineMatrices[:, 0, 2] = t 126 | else: 127 | affineMatrices[:, 1, 2] = t 128 | return affine_transform(affineMatrices, img) 129 | 130 | 131 | def rotate(img, angle): 132 | """Rotates an image by angle""" 133 | affineMatrices = torch.zeros(img.shape[0], 2, 3).to(img.device) 134 | affineMatrices[:, 0, 0] = torch.cos(angle) 135 | affineMatrices[:, 0, 1] = torch.sin(angle) 136 | affineMatrices[:, 1, 0] = -torch.sin(angle) 137 | affineMatrices[:, 1, 1] = torch.cos(angle) 138 | return affine_transform(affineMatrices, img) 139 | 140 | 141 | def shear(img, t, axis="x"): 142 | """Shear an image by an amount t""" 143 | affineMatrices = torch.zeros(img.shape[0], 2, 3).to(img.device) 144 | affineMatrices[:, 0, 
0] = 1 145 | affineMatrices[:, 1, 1] = 1 146 | if axis == "x": 147 | affineMatrices[:, 0, 1] = t 148 | affineMatrices[:, 1, 0] = 0 149 | else: 150 | affineMatrices[:, 0, 1] = 0 151 | affineMatrices[:, 1, 0] = t 152 | return affine_transform(affineMatrices, img) 153 | 154 | 155 | def stretch(img, x, axis="x"): 156 | """Stretch an image by an amount t""" 157 | affineMatrices = torch.zeros(img.shape[0], 2, 3).to(img.device) 158 | if axis == "x": 159 | affineMatrices[:, 0, 0] = 1 * (1 + x) 160 | else: 161 | affineMatrices[:, 1, 1] = 1 * (1 + x) 162 | return affine_transform(affineMatrices, img) 163 | 164 | 165 | def hyperbolic_rotate(img, angle): 166 | affineMatrices = torch.zeros(img.shape[0], 2, 3).to(img.device) 167 | affineMatrices[:, 0, 0] = torch.cosh(angle) 168 | affineMatrices[:, 0, 1] = torch.sinh(angle) 169 | affineMatrices[:, 1, 0] = torch.sinh(angle) 170 | affineMatrices[:, 1, 1] = torch.cosh(angle) 171 | return affine_transform(affineMatrices, img) 172 | 173 | 174 | def scale(img, s): 175 | affineMatrices = torch.zeros(img.shape[0], 2, 3).to(img.device) 176 | affineMatrices[:, 0, 0] = 1 - s 177 | affineMatrices[:, 1, 1] = 1 - s 178 | return affine_transform(affineMatrices, img) 179 | 180 | 181 | def saturate(img, t): 182 | img = img.clone() 183 | img *= 1 + t 184 | return img -------------------------------------------------------------------------------- /lee/loader.py: -------------------------------------------------------------------------------- 1 | import tqdm 2 | import pandas as pd 3 | import random 4 | from functools import partial 5 | from typing import Callable 6 | 7 | import torch 8 | import torch.utils.data 9 | import numpy as np 10 | 11 | import sys 12 | sys.path.append("pytorch-image-models") 13 | from timm.data import create_dataset, resolve_data_config 14 | from timm.data.transforms_factory import create_transform 15 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 16 | from timm.data.distributed_sampler import OrderedDistributedSampler 17 | 18 | def _worker_init(worker_id, worker_seeding='all'): 19 | worker_info = torch.utils.data.get_worker_info() 20 | assert worker_info.id == worker_id 21 | if isinstance(worker_seeding, Callable): 22 | seed = worker_seeding(worker_info) 23 | random.seed(seed) 24 | torch.manual_seed(seed) 25 | np.random.seed(seed % (2 ** 32 - 1)) 26 | else: 27 | assert worker_seeding in ('all', 'part') 28 | # random / torch seed already called in dataloader iter class w/ worker_info.seed 29 | # to reproduce some old results (same seed + hparam combo), partial seeding is required (skip numpy re-seed) 30 | if worker_seeding == 'all': 31 | np.random.seed(worker_info.seed % (2 ** 32 - 1)) 32 | 33 | def create_loader( 34 | dataset, 35 | input_size, 36 | batch_size, 37 | is_training=False, 38 | no_aug=False, 39 | re_prob=0., 40 | re_mode='const', 41 | re_count=1, 42 | re_split=False, 43 | scale=None, 44 | ratio=None, 45 | hflip=0.5, 46 | vflip=0., 47 | color_jitter=0.4, 48 | auto_augment=None, 49 | num_aug_splits=0, 50 | interpolation='bilinear', 51 | mean=IMAGENET_DEFAULT_MEAN, 52 | std=IMAGENET_DEFAULT_STD, 53 | num_workers=1, 54 | crop_pct=None, 55 | collate_fn=None, 56 | pin_memory=False, 57 | tf_preprocessing=False, 58 | persistent_workers=True, 59 | worker_seeding='all', 60 | sampler=None, 61 | ): 62 | inner_dataset = dataset.dataset if isinstance(dataset, torch.utils.data.dataset.Subset) \ 63 | else dataset 64 | 65 | re_num_splits = 0 66 | if re_split: 67 | # apply RE to second half of batch if no aug split 
otherwise line up with aug split 68 | re_num_splits = num_aug_splits or 2 69 | inner_dataset.transform = create_transform( 70 | input_size, 71 | is_training=is_training, 72 | use_prefetcher=False, 73 | no_aug=no_aug, 74 | scale=scale, 75 | ratio=ratio, 76 | hflip=hflip, 77 | vflip=vflip, 78 | color_jitter=color_jitter, 79 | auto_augment=auto_augment, 80 | interpolation=interpolation, 81 | mean=mean, 82 | std=std, 83 | crop_pct=crop_pct, 84 | tf_preprocessing=tf_preprocessing, 85 | re_prob=re_prob, 86 | re_mode=re_mode, 87 | re_count=re_count, 88 | re_num_splits=re_num_splits, 89 | separate=num_aug_splits > 0, 90 | ) 91 | 92 | if collate_fn is None: 93 | collate_fn = torch.utils.data.dataloader.default_collate 94 | 95 | loader_class = torch.utils.data.DataLoader 96 | 97 | loader_args = dict( 98 | batch_size=batch_size, 99 | shuffle=False, 100 | num_workers=num_workers, 101 | sampler=sampler, 102 | collate_fn=collate_fn, 103 | pin_memory=pin_memory, 104 | drop_last=is_training, 105 | worker_init_fn=partial(_worker_init, worker_seeding=worker_seeding), 106 | persistent_workers=persistent_workers 107 | ) 108 | try: 109 | loader = loader_class(dataset, **loader_args) 110 | except TypeError as e: 111 | loader_args.pop('persistent_workers') # only in Pytorch 1.7+ 112 | loader = loader_class(dataset, **loader_args) 113 | 114 | return loader 115 | 116 | def eval_average_metrics_wstd(loader, metrics, max_mbs=None): 117 | total = len(loader) if max_mbs is None else min(max_mbs, len(loader)) 118 | dfs = [] 119 | with torch.no_grad(): 120 | for idx, minibatch in tqdm.tqdm(enumerate(loader), total=total): 121 | dfs.append(metrics(minibatch)) 122 | if max_mbs is not None and idx >= max_mbs: 123 | break 124 | df = pd.concat(dfs) 125 | return df 126 | 127 | def get_loaders( 128 | model, 129 | dataset, 130 | data_dir, 131 | batch_size, 132 | num_train, 133 | num_val, 134 | args, 135 | train_split="train", 136 | val_split="val", 137 | ): 138 | 139 | dataset_train = create_dataset( 140 | dataset, 141 | root=data_dir, 142 | split=train_split, 143 | is_training=False, 144 | batch_size=batch_size, 145 | ) 146 | if num_train < len(dataset_train): 147 | dataset_train, _ = torch.utils.data.random_split( 148 | dataset_train, 149 | [num_train, len(dataset_train) - num_train], 150 | generator=torch.Generator().manual_seed(42), 151 | ) 152 | 153 | dataset_eval = create_dataset( 154 | dataset, 155 | root=data_dir, 156 | split=val_split, 157 | is_training=False, 158 | batch_size=batch_size, 159 | ) 160 | if num_val < len(dataset_eval): 161 | dataset_eval, _ = torch.utils.data.random_split( 162 | dataset_eval, 163 | [num_val, len(dataset_eval) - num_val], 164 | generator=torch.Generator().manual_seed(42), 165 | ) 166 | 167 | data_config = resolve_data_config(vars(args), model=model, verbose=True) 168 | 169 | print(data_config) 170 | 171 | train_loader = create_loader( 172 | dataset_train, 173 | input_size=data_config["input_size"], 174 | batch_size=batch_size, 175 | is_training=False, 176 | interpolation=data_config["interpolation"], 177 | mean=data_config["mean"], 178 | std=data_config["std"], 179 | num_workers=1, 180 | crop_pct=data_config["crop_pct"], 181 | pin_memory=False, 182 | no_aug=True, 183 | hflip=0.0, 184 | color_jitter=0.0, 185 | ) 186 | 187 | eval_loader = create_loader( 188 | dataset_eval, 189 | input_size=data_config["input_size"], 190 | batch_size=batch_size, 191 | is_training=False, 192 | interpolation=data_config["interpolation"], 193 | mean=data_config["mean"], 194 | std=data_config["std"], 195 | 
num_workers=1, 196 | crop_pct=data_config["crop_pct"], 197 | pin_memory=False, 198 | no_aug=True, 199 | ) 200 | 201 | return train_loader, eval_loader -------------------------------------------------------------------------------- /sweep_configs/e2e_configs.py: -------------------------------------------------------------------------------- 1 | core_models = { 2 | "name": "core_models", 3 | "method": "grid", 4 | "parameters": { 5 | "modelname": {"values": [ 6 | 'vit_tiny_patch16_224', 7 | 'vit_tiny_patch16_384', 8 | 'vit_small_patch32_224', 9 | 'vit_small_patch32_384', 10 | 'vit_small_patch16_224', 11 | 'vit_small_patch16_384', 12 | 'vit_base_patch32_224', 13 | 'vit_base_patch32_384', 14 | 'vit_base_patch16_224', 15 | 'vit_base_patch16_384', 16 | 'vit_large_patch16_224', 17 | 'vit_large_patch32_384', 18 | 'vit_large_patch16_384', 19 | 'vit_base_patch16_224_miil', 20 | 'mixer_b16_224', 21 | 'mixer_l16_224', 22 | 'mixer_b16_224_miil', 23 | 'mixer_b16_224_in21k', 24 | 'mixer_l16_224_in21k', 25 | 'resmlp_12_224', 26 | 'resmlp_24_224', 27 | 'resmlp_36_224', 28 | 'resmlp_big_24_224', 29 | 'resmlp_12_distilled_224', 30 | 'resmlp_24_distilled_224', 31 | 'resmlp_36_distilled_224', 32 | 'resmlp_big_24_distilled_224', 33 | 'resmlp_big_24_224_in22ft1k', 34 | 'vgg11', 35 | 'vgg13', 36 | 'vgg16', 37 | 'vgg19', 38 | 'vgg11_bn', 39 | 'vgg13_bn', 40 | 'vgg16_bn', 41 | 'vgg19_bn', 42 | 'inception_resnet_v2', 43 | 'inception_v3', 44 | 'inception_v4', 45 | 'densenet121', 46 | 'densenet161', 47 | 'densenet169', 48 | 'densenet201', 49 | 'densenetblur121d', 50 | 'tv_densenet121', 51 | 'tv_resnet34', 52 | 'tv_resnet50', 53 | 'tv_resnet101', 54 | 'tv_resnet152', 55 | 'resnet34', 56 | 'resnet50', 57 | 'resnet50d', 58 | 'resnet101', 59 | 'resnet101d', 60 | 'resnet152', 61 | 'resnet152d', 62 | 'resnetrs101', 63 | 'resnetrs152', 64 | 'wide_resnet50_2', 65 | 'wide_resnet101_2', 66 | 'resnetblur50', 67 | 'convmixer_768_32', 68 | 'convmixer_1536_20', 69 | 'mobilenetv2_100', 70 | 'mobilenetv2_110d', 71 | 'mobilenetv2_120d', 72 | 'mobilenetv2_140', 73 | 'efficientnet_b0', 74 | 'efficientnet_b1', 75 | 'efficientnet_b2', 76 | 'efficientnet_b3', 77 | 'efficientnet_b4', 78 | "convnext_tiny", 79 | "convnext_small", 80 | "convnext_base", 81 | "convnext_large", 82 | "convnext_tiny_in22ft1k", 83 | "convnext_small_in22ft1k", 84 | "convnext_base_in22ft1k", 85 | "convnext_large_in22ft1k", 86 | "convnext_xlarge_in22ft1k", 87 | 'convit_small', 88 | 'convit_base', 89 | ]}, 90 | } 91 | } 92 | 93 | all_models = { 94 | "name": "all_models", 95 | "method": "grid", 96 | "parameters": { 97 | "modelname": {"values": [ 98 | 'vit_tiny_patch16_224', 99 | 'vit_tiny_patch16_384', 100 | 'vit_small_patch32_224', 101 | 'vit_small_patch32_384', 102 | 'vit_small_patch16_224', 103 | 'vit_small_patch16_384', 104 | 'vit_base_patch32_224', 105 | 'vit_base_patch32_384', 106 | 'vit_base_patch16_224', 107 | 'vit_base_patch16_384', 108 | 'vit_large_patch16_224', 109 | 'vit_large_patch32_384', 110 | 'vit_large_patch16_384', 111 | 'vit_base_patch16_224_miil', 112 | 'swin_base_patch4_window12_384', 113 | 'swin_base_patch4_window7_224', 114 | 'swin_large_patch4_window12_384', 115 | 'swin_large_patch4_window7_224', 116 | 'swin_small_patch4_window7_224', 117 | 'swin_tiny_patch4_window7_224', 118 | 'mixer_b16_224', 119 | 'mixer_l16_224', 120 | 'mixer_b16_224_miil', 121 | 'mixer_b16_224_in21k', 122 | 'mixer_l16_224_in21k', 123 | 'resmlp_12_224', 124 | 'resmlp_24_224', 125 | 'resmlp_36_224', 126 | 'resmlp_big_24_224', 127 | 'resmlp_12_distilled_224', 128 | 
'resmlp_24_distilled_224', 129 | 'resmlp_36_distilled_224', 130 | 'resmlp_big_24_distilled_224', 131 | 'resmlp_big_24_224_in22ft1k', 132 | 'vgg11', 133 | 'vgg13', 134 | 'vgg16', 135 | 'vgg19', 136 | 'vgg11_bn', 137 | 'vgg13_bn', 138 | 'vgg16_bn', 139 | 'vgg19_bn', 140 | 'inception_resnet_v2', 141 | 'inception_v3', 142 | 'inception_v4', 143 | 'densenet121', 144 | 'densenet161', 145 | 'densenet169', 146 | 'densenet201', 147 | 'densenetblur121d', 148 | 'tv_densenet121', 149 | 'tv_resnet34', 150 | 'tv_resnet50', 151 | 'tv_resnet101', 152 | 'tv_resnet152', 153 | 'resnet34', 154 | 'resnet50', 155 | 'resnet50d', 156 | 'resnet101', 157 | 'resnet101d', 158 | 'resnet152', 159 | 'resnet152d', 160 | 'resnetrs101', 161 | 'resnetrs152', 162 | 'wide_resnet50_2', 163 | 'wide_resnet101_2', 164 | 'resnetblur50', 165 | 'ig_resnext101_32x16d', 166 | 'ig_resnext101_32x32d', 167 | 'ssl_resnext101_32x16d', 168 | 'convmixer_768_32', 169 | 'convmixer_1536_20', 170 | 'mobilenetv2_100', 171 | 'mobilenetv2_110d', 172 | 'mobilenetv2_120d', 173 | 'mobilenetv2_140', 174 | 'efficientnet_b0', 175 | 'efficientnet_b1', 176 | 'efficientnet_b2', 177 | 'efficientnet_b3', 178 | 'efficientnet_b4', 179 | "convnext_tiny", 180 | "convnext_small", 181 | "convnext_base", 182 | "convnext_large", 183 | "convnext_tiny_in22ft1k", 184 | "convnext_small_in22ft1k", 185 | "convnext_base_in22ft1k", 186 | "convnext_large_in22ft1k", 187 | "convnext_xlarge_in22ft1k", 188 | 'beit_base_patch16_224', 189 | 'beit_base_patch16_384', 190 | 'beit_large_patch16_224', 191 | 'beit_large_patch16_384', 192 | 'beit_large_patch16_512', 193 | 'convit_small', 194 | 'convit_base', 195 | 'resnet18','tv_resnet34','tv_resnet50','tv_resnet101','tv_resnet152', 196 | 'resnet18d','resnet34','resnet34d','resnet26','resnet26d','resnet26t', 197 | 'resnet50','resnet50d','resnet101','resnet101d','resnet152d', 198 | 'wide_resnet50_2','wide_resnet101_2', 199 | 'resnetblur50', 200 | 'resnext50_32x4d','resnext50d_32x4d', 201 | 'resnext101_32x8d', 'tv_resnext50_32x4d', 202 | 'ig_resnext101_32x8d','ig_resnext101_32x16d','ig_resnext101_32x32d','ig_resnext101_32x48d', 203 | 'ssl_resnet18','ssl_resnet50','ssl_resnext50_32x4d', 204 | 'ssl_resnext101_32x4d','ssl_resnext101_32x8d','ssl_resnext101_32x16d', 205 | 'resnetrs50','resnetrs101','resnetrs152','resnetrs200', 206 | 'resnetrs270','resnetrs350','resnetrs420', 207 | 'vit_tiny_patch16_224','vit_tiny_patch16_384', 208 | 'vit_small_patch32_224','vit_small_patch32_384','vit_small_patch16_224','vit_small_patch16_384', 209 | 'vit_base_patch32_224','vit_base_patch32_384','vit_base_patch16_224','vit_base_patch16_384', 210 | 'vit_large_patch16_224','vit_large_patch32_384','vit_large_patch16_384', 211 | 'vit_base_patch32_sam_224','vit_base_patch16_sam_224', 212 | 'deit_tiny_patch16_224','deit_small_patch16_224', 213 | 'deit_base_patch16_224','deit_base_patch16_384', 214 | 'deit_tiny_distilled_patch16_224','deit_small_distilled_patch16_224', 215 | 'deit_base_distilled_patch16_224','deit_base_distilled_patch16_384', 216 | 'vit_base_patch16_224_miil', 217 | 'vit_tiny_r_s16_p8_224','vit_tiny_r_s16_p8_384', 218 | 'vit_small_r26_s32_224','vit_small_r26_s32_384', 219 | 'vit_base_r50_s16_384', 220 | 'vit_large_r50_s32_224','vit_large_r50_s32_384', 221 | 'crossvit_15_240','crossvit_15_dagger_240','crossvit_15_dagger_408', 222 | 'crossvit_18_240','crossvit_18_dagger_240','crossvit_18_dagger_408', 223 | 'crossvit_9_240','crossvit_9_dagger_240', 224 | 'crossvit_base_240','crossvit_small_240','crossvit_tiny_240', 225 | 
'beit_base_patch16_224','beit_base_patch16_384', 226 | 'beit_large_patch16_224','beit_large_patch16_384','beit_large_patch16_512', 227 | 'coat_tiny','coat_mini','coat_lite_tiny','coat_lite_mini','coat_lite_small', 228 | 'cait_xxs24_224','cait_xxs24_224','cait_xxs24_384', 229 | 'cait_xxs36_224','cait_xxs36_384','cait_xs24_384', 230 | 'cait_s24_224','cait_s24_384','cait_s36_384','cait_m36_384','cait_m48_448', 231 | 'convit_tiny','convit_small','convit_base', 232 | 'levit_128s','levit_128','levit_192','levit_256','levit_384', 233 | 'mixer_b16_224','mixer_l16_224','mixer_b16_224_miil', 234 | 'gmixer_24_224', 235 | 'resmlp_12_224','resmlp_24_224','resmlp_36_224','resmlp_big_24_224', 236 | 'resmlp_12_distilled_224','resmlp_24_distilled_224', 237 | 'resmlp_36_distilled_224','resmlp_big_24_distilled_224', 238 | 'resmlp_big_24_224_in22ft1k', 239 | 'gmlp_s16_224', 240 | 'pit_ti_224','pit_xs_224','pit_s_224','pit_b_224', 241 | 'pit_ti_distilled_224','pit_xs_distilled_224','pit_s_distilled_224','pit_b_distilled_224', 242 | 'swin_base_patch4_window12_384','swin_base_patch4_window7_224', 243 | 'swin_large_patch4_window12_384','swin_large_patch4_window7_224', 244 | 'swin_small_patch4_window7_224','swin_tiny_patch4_window7_224', 245 | 'xcit_nano_12_p16_224','xcit_nano_12_p16_224_dist','xcit_nano_12_p16_384_dist', 246 | 'xcit_tiny_12_p16_224','xcit_tiny_12_p16_224_dist','xcit_tiny_12_p16_384_dist', 247 | 'xcit_tiny_24_p16_224','xcit_tiny_24_p16_224_dist','xcit_tiny_24_p16_384_dist', 248 | 'xcit_small_12_p16_224','xcit_small_12_p16_224_dist','xcit_small_12_p16_384_dist', 249 | 'xcit_small_24_p16_224','xcit_small_24_p16_224_dist','xcit_small_24_p16_384_dist', 250 | 'xcit_medium_24_p16_224','xcit_medium_24_p16_224_dist','xcit_medium_24_p16_384_dist', 251 | 'xcit_large_24_p16_224','xcit_large_24_p16_224_dist','xcit_large_24_p16_384_dist', 252 | 'xcit_nano_12_p8_224','xcit_nano_12_p8_224_dist','xcit_nano_12_p8_384_dist', 253 | 'xcit_tiny_12_p8_224','xcit_tiny_12_p8_224_dist','xcit_tiny_12_p8_384_dist', 254 | 'xcit_tiny_24_p8_224','xcit_tiny_24_p8_224_dist','xcit_tiny_24_p8_384_dist', 255 | 'xcit_small_12_p8_224','xcit_small_12_p8_224_dist','xcit_small_12_p8_384_dist', 256 | 'xcit_small_24_p8_224','xcit_small_24_p8_224_dist','xcit_small_24_p8_384_dist', 257 | 'xcit_medium_24_p8_224','xcit_medium_24_p8_224_dist','xcit_medium_24_p8_384_dist', 258 | 'xcit_large_24_p8_224','xcit_large_24_p8_224_dist','xcit_large_24_p8_384_dist', 259 | 'vgg11_bn','vgg13_bn','vgg16_bn','vgg19_bn', 260 | 'inception_resnet_v2','inception_v3','inception_v4', 261 | 'densenet121','densenetblur121d','tv_densenet121', 262 | 'densenet169','densenet201','densenet161', 263 | 'mobilenetv2_110d','mobilenetv2_120d', 264 | 'efficientnet_b1','efficientnet_b2','efficientnet_b3','efficientnet_b4', 265 | 'gernet_s','gernet_m','gernet_l', 266 | 'repvgg_a2','repvgg_b0','repvgg_b1','repvgg_b1g4', 267 | 'repvgg_b2','repvgg_b2g4','repvgg_b3','repvgg_b3g4', 268 | 'resnet51q','resnet61q', 269 | 'resnext26ts','gcresnext26ts','seresnext26ts','eca_resnext26ts','bat_resnext26ts', 270 | 'resnet32ts','resnet33ts','gcresnet33ts','seresnet33ts','eca_resnet33ts', 271 | 'gcresnet50t','gcresnext50ts', 272 | 'cspresnet50','cspresnext50','cspdarknet53', 273 | 'dla34','dla46_c','dla46x_c','dla60x_c','dla60','dla60x', 274 | 'dla102','dla102x','dla102x2','dla169', 275 | 'dla60_res2net','dla60_res2next', 276 | 'dpn68','dpn68b','dpn92','dpn98','dpn131','dpn107', 277 | 'mnasnet_100','semnasnet_100', 278 | 
'mobilenetv2_100','mobilenetv2_110d','mobilenetv2_120d','mobilenetv2_140', 279 | 'fbnetc_100','spnasnet_100', 280 | 'efficientnet_b0','efficientnet_b1','efficientnet_b2','efficientnet_b3','efficientnet_b4', 281 | 'efficientnet_es','efficientnet_em','efficientnet_el', 282 | 'efficientnet_es_pruned','efficientnet_el_pruned','efficientnet_lite0', 283 | 'efficientnet_b1_pruned','efficientnet_b2_pruned','efficientnet_b3_pruned', 284 | 'efficientnetv2_rw_t','gc_efficientnetv2_rw_t','efficientnetv2_rw_s','efficientnetv2_rw_m', 285 | 'tf_efficientnet_b0','tf_efficientnet_b1','tf_efficientnet_b2','tf_efficientnet_b3', 286 | 'tf_efficientnet_b4','tf_efficientnet_b5','tf_efficientnet_b6','tf_efficientnet_b7','tf_efficientnet_b8', 287 | 'tf_efficientnet_b0_ap','tf_efficientnet_b1_ap','tf_efficientnet_b2_ap', 288 | 'tf_efficientnet_b3_ap','tf_efficientnet_b4_ap','tf_efficientnet_b5_ap', 289 | 'tf_efficientnet_b6_ap','tf_efficientnet_b7_ap','tf_efficientnet_b8_ap', 290 | 'tf_efficientnet_b0_ns','tf_efficientnet_b1_ns','tf_efficientnet_b2_ns', 291 | 'tf_efficientnet_b3_ns','tf_efficientnet_b4_ns','tf_efficientnet_b5_ns', 292 | 'tf_efficientnet_b6_ns','tf_efficientnet_b7_ns', 293 | 'tf_efficientnet_l2_ns_475','tf_efficientnet_l2_ns', 294 | 'tf_efficientnet_es','tf_efficientnet_em','tf_efficientnet_el', 295 | 'tf_efficientnet_cc_b0_4e','tf_efficientnet_cc_b0_8e','tf_efficientnet_cc_b1_8e', 296 | 'tf_efficientnet_lite0','tf_efficientnet_lite1','tf_efficientnet_lite2', 297 | 'tf_efficientnet_lite3','tf_efficientnet_lite4', 298 | 'tf_efficientnetv2_s','tf_efficientnetv2_m','tf_efficientnetv2_l', 299 | 'tf_efficientnetv2_s_in21ft1k','tf_efficientnetv2_m_in21ft1k', 300 | 'tf_efficientnetv2_l_in21ft1k','tf_efficientnetv2_xl_in21ft1k', 301 | 'tf_efficientnetv2_b0','tf_efficientnetv2_b1', 302 | 'tf_efficientnetv2_b2','tf_efficientnetv2_b3', 303 | 'mixnet_s','mixnet_m','mixnet_l','mixnet_xl','mixnet_xxl', 304 | 'tf_mixnet_s','tf_mixnet_m','tf_mixnet_l', 305 | 'ghostnet_100', 306 | 'gluon_resnet18_v1b','gluon_resnet34_v1b','gluon_resnet50_v1b', 307 | 'gluon_resnet101_v1b','gluon_resnet152_v1b', 308 | 'gluon_resnet50_v1c','gluon_resnet101_v1c','gluon_resnet152_v1c', 309 | 'gluon_resnet50_v1d','gluon_resnet101_v1d','gluon_resnet152_v1d', 310 | 'gluon_resnet50_v1s','gluon_resnet101_v1s','gluon_resnet152_v1s', 311 | 'gluon_resnext50_32x4d','gluon_resnext101_32x4d','gluon_resnext101_64x4d', 312 | 'gluon_seresnext50_32x4d','gluon_seresnext101_32x4d','gluon_seresnext101_64x4d', 313 | 'gluon_senet154', 314 | 'gluon_xception65', 315 | 'hrnet_w18_small','hrnet_w18_small_v2','hrnet_w18','hrnet_w30', 316 | 'hrnet_w32','hrnet_w40','hrnet_w44','hrnet_w48','hrnet_w64', 317 | 'inception_resnet_v2','ens_adv_inception_resnet_v2', 318 | 'inception_v3','tf_inception_v3', 'adv_inception_v3','gluon_inception_v3', 319 | 'inception_v4', 320 | 'mobilenetv3_large_100','mobilenetv3_large_100_miil','mobilenetv3_small_075','mobilenetv3_small_100', 321 | 'mobilenetv3_rw', 322 | 'tf_mobilenetv3_large_075','tf_mobilenetv3_large_100','tf_mobilenetv3_large_minimal_100', 323 | 'tf_mobilenetv3_small_075','tf_mobilenetv3_small_100','tf_mobilenetv3_small_minimal_100', 324 | 'nasnetalarge', 325 | 'jx_nest_base','jx_nest_small','jx_nest_tiny', 326 | 'dm_nfnet_f0','dm_nfnet_f1','dm_nfnet_f2','dm_nfnet_f3', 327 | 'dm_nfnet_f4','dm_nfnet_f5','dm_nfnet_f6', 328 | 'nfnet_l0','eca_nfnet_l0','eca_nfnet_l1','eca_nfnet_l2', 329 | 'nf_regnet_b1','nf_resnet50', 330 | 'pnasnet5large', 331 | 'regnetx_002','regnetx_004','regnetx_006','regnetx_008','regnetx_016', 332 | 
'regnetx_032','regnetx_040','regnetx_064','regnetx_080','regnetx_120', 333 | 'regnetx_160','regnetx_320','regnety_002','regnety_004','regnety_006', 334 | 'regnety_008','regnety_016','regnety_032','regnety_040','regnety_064', 335 | 'regnety_080','regnety_120','regnety_160','regnety_320', 336 | 'res2net50_26w_4s','res2net50_48w_2s','res2net50_14w_8s', 337 | 'res2net50_26w_6s','res2net50_26w_8s','res2net101_26w_4s','res2next50', 338 | 'resnest14d','resnest26d','resnest50d','resnest101e', 339 | 'resnest200e','resnest269e','resnest50d_4s2x40d','resnest50d_1s4x24d', 340 | 'rexnet_100', 'rexnet_130', 'rexnet_150', 'rexnet_200', 341 | 'selecsls42b','selecsls60','selecsls60b', 342 | 'legacy_senet154','legacy_seresnet18','legacy_seresnet34','legacy_seresnet50', 343 | 'legacy_seresnet101','legacy_seresnet152','legacy_seresnext26_32x4d', 344 | 'legacy_seresnext50_32x4d','legacy_seresnext101_32x4d', 345 | 'skresnet18','skresnet34','skresnext50_32x4d', 346 | 'tnt_s_patch16_224', 347 | 'tresnet_m','tresnet_l','tresnet_xl', 348 | 'tresnet_m_448','tresnet_l_448','tresnet_xl_448', 349 | 'twins_pcpvt_small','twins_pcpvt_base','twins_pcpvt_large', 350 | 'twins_svt_small','twins_svt_base','twins_svt_large', 351 | 'vgg11','vgg13','vgg16','vgg19','vgg11_bn','vgg13_bn','vgg16_bn','vgg19_bn', 352 | 'ese_vovnet19b_dw','ese_vovnet39b', 353 | 'xception','xception41','xception65','xception71', 354 | ]}, 355 | } 356 | } 357 | --------------------------------------------------------------------------------
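Note on usage: the dictionaries in sweep_configs/e2e_configs.py follow the Weights & Biases sweep-config schema (a grid search over a single categorical "modelname" parameter). The sketch below shows how such a config could plausibly be registered and consumed; it is an illustration only. The run function, project name, and import path are assumptions and are not taken from the repository's actual entry point (exps_e2e.py drives the real experiments).

# Hypothetical usage sketch (not part of the repository).
# Assumes sweep_configs is importable as a package and that the
# W&B project name "lie-deriv" is acceptable; both are illustrative.
import wandb
import timm

from sweep_configs.e2e_configs import core_models

def run_one_model():
    # wandb.agent calls this once per grid point; the swept "modelname"
    # value from the config dict above arrives through wandb.config.
    run = wandb.init()
    model = timm.create_model(wandb.config.modelname, pretrained=True)
    # ... evaluate the model's equivariance metrics on a loader and log them here ...
    run.finish()

if __name__ == "__main__":
    sweep_id = wandb.sweep(core_models, project="lie-deriv")  # project name is illustrative
    wandb.agent(sweep_id, function=run_one_model)

Because the sweep method is "grid" with one categorical parameter, the agent simply iterates over every model name listed above, so the total number of runs equals the length of the values list.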