├── assets
│   ├── test.txt
│   ├── title_figure.pdf
│   └── title_figure.png
├── requirements.txt
├── .gitmodules
├── lee
│   ├── e2e_lee.py
│   ├── e2e_other.py
│   ├── lie_derivs.py
│   ├── layerwise_other.py
│   ├── layerwise_lee.py
│   ├── transforms.py
│   └── loader.py
├── LICENSE
├── README.md
├── .gitignore
├── sweep_configs
│   ├── layerwise_configs.py
│   └── e2e_configs.py
├── exps_e2e.py
└── exps_layerwise.py
/assets/test.txt:
--------------------------------------------------------------------------------
1 | empty
2 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=1.4.0
2 | torchvision>=0.5.0
3 | pyyaml
4 | e2cnn
5 | pandas
6 | wandb
--------------------------------------------------------------------------------
/assets/title_figure.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ngruver/lie-deriv/HEAD/assets/title_figure.pdf
--------------------------------------------------------------------------------
/assets/title_figure.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ngruver/lie-deriv/HEAD/assets/title_figure.png
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "pytorch-image-models"]
2 | path = pytorch-image-models
3 | url = https://github.com/rwightman/pytorch-image-models
4 | [submodule "stylegan3"]
5 | path = stylegan3
6 | url = https://github.com/NVlabs/stylegan3
7 |
--------------------------------------------------------------------------------
/lee/e2e_lee.py:
--------------------------------------------------------------------------------
1 | import torch, torch.nn.functional as F  # F is used for the softmax below
2 | import pandas as pd
3 | from .lie_derivs import *
4 |
5 | def get_equivariance_metrics(model, minibatch):
6 | x, y = minibatch
7 | if torch.cuda.is_available():
8 | model = model.cuda()
9 | x, y = x.cuda(), y.cuda()
10 |
11 | model = model.eval()
12 |
13 | model_probs = lambda x: F.softmax(model(x), dim=-1)
14 |
15 | errs = {
16 | "trans_x_deriv": translation_lie_deriv(model_probs, x, axis="x"),
17 | "trans_y_deriv": translation_lie_deriv(model_probs, x, axis="y"),
18 | "rot_deriv": rotation_lie_deriv(model_probs, x),
19 | "shear_x_deriv": shear_lie_deriv(model_probs, x, axis="x"),
20 | "shear_y_deriv": shear_lie_deriv(model_probs, x, axis="y"),
21 | "stretch_x_deriv": stretch_lie_deriv(model_probs, x, axis="x"),
22 | "stretch_y_deriv": stretch_lie_deriv(model_probs, x, axis="y"),
23 | "saturate_err": saturate_lie_deriv(model_probs, x),
24 | }
25 |
26 | metrics = {x: pd.Series(errs[x].abs().cpu().data.numpy().mean(-1)) for x in errs}
27 | df = pd.DataFrame.from_dict(metrics)
28 | return df
29 |
--------------------------------------------------------------------------------
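A minimal usage sketch for `get_equivariance_metrics` above (an editorial example, not a file in the repository); it assumes `timm` is importable (via pip or the pytorch-image-models submodule on `sys.path`) and uses a random minibatch, printing the per-transform mean absolute Lie derivative of the softmax outputs:

```python
# Hypothetical usage sketch: score one random minibatch with the metrics above.
import torch
import timm
from lee.e2e_lee import get_equivariance_metrics

model = timm.create_model("resnet18", pretrained=False)
x = torch.randn(2, 3, 224, 224)    # two fake images
y = torch.randint(0, 1000, (2,))   # two fake labels (only used for the device move)
df = get_equivariance_metrics(model, (x, y))  # one row per image, one column per transform
print(df.mean())                   # mean |Lie derivative| per transform
```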
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Nate Gruver
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # The Lie Derivative for Learned Equivariance
2 |
3 |
4 |
5 |
6 | # Installation instructions
7 |
8 | Clone the repository with its submodules and install the requirements:
9 | ```bash
10 | git clone --recurse-submodules https://github.com/ngruver/lie-deriv.git
11 | cd lie-deriv
12 | pip install -r requirements.txt
13 | ```
14 |
15 | # Layer-wise equivariance experiments
16 |
17 | The layer-wise equivariance error of individual models can be computed with `exps_layerwise.py`, for example:
18 | ```bash
19 | python exps_layerwise.py \
20 | --modelname=resnet50 \
21 | --output_dir=$HOME/lee_results \
22 | --num_imgs=20 \
23 | --num_probes=100 \
24 | --transform=translation
25 | ```
26 | Default models and transforms are available in our wandb [sweep configuration](https://github.com/ngruver/lie-deriv/blob/main/sweep_configs/layerwise_configs.py).
27 |
28 | # End-to-end equivariance experiments
29 | End-to-end equivariance metrics for a model can be computed with `exps_e2e.py`, for example:
30 | ```bash
31 | python exps_e2e.py \
32 | --modelname=resnet50 \
33 | --output_dir=$HOME/lee_results \
34 | --num_datapoints=100
35 | ```
36 | We also include the wandb [sweep configuration](https://github.com/ngruver/lie-deriv/blob/main/sweep_configs/e2e_configs.py) for our end-to-end equivariance experiments.
37 |
38 | # Plotting and Visualization
39 |
40 | We make our results and plotting code available in a [Google Colab notebook](https://colab.research.google.com/drive/1ehsYp4t5AnXJwBypmXcEE9vzkdslByTw?usp=sharing).
41 |
42 | # Citation
43 |
44 | If you find our work helpful, please cite it with
45 |
46 | ```
47 | @misc{https://doi.org/10.48550/arxiv.2210.02984,
48 | doi = {10.48550/ARXIV.2210.02984},
49 | url = {https://arxiv.org/abs/2210.02984},
50 | author = {Gruver, Nate and Finzi, Marc and Goldblum, Micah and Wilson, Andrew Gordon},
51 | title = {The Lie Derivative for Measuring Learned Equivariance},
52 | publisher = {arXiv},
53 | year = {2022},
54 | copyright = {Creative Commons Attribution 4.0 International}
55 | }
56 | ```
57 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/sweep_configs/layerwise_configs.py:
--------------------------------------------------------------------------------
1 | minimal_models = {
2 | "name": "minimal_models",
3 | "method": "grid",
4 | "parameters": {
5 | "modelname": {"values": [
6 | "vgg13",
7 | "inception_resnet_v2",
8 | "resnet50",
9 | "wide_resnet50_2",
10 | "densenet121",
11 | "efficientnet_b1",
12 | ]},
13 | }
14 | }
15 |
16 | core_models = {
17 | "name": "core_models",
18 | "method": "grid",
19 | "parameters": {
20 | "modelname": {"values": [
21 | 'vit_tiny_patch16_224',
22 | 'vit_tiny_patch16_384',
23 | 'vit_small_patch32_224',
24 | 'vit_small_patch32_384',
25 | 'vit_small_patch16_224',
26 | 'vit_small_patch16_384',
27 | 'vit_base_patch32_224',
28 | 'vit_base_patch32_384',
29 | 'vit_base_patch16_224',
30 | 'vit_base_patch16_384',
31 | 'vit_large_patch16_224',
32 | 'vit_large_patch32_384',
33 | 'vit_large_patch16_384',
34 | 'vit_base_patch16_224_miil',
35 | 'swin_base_patch4_window12_384',
36 | 'swin_base_patch4_window7_224',
37 | 'swin_large_patch4_window12_384',
38 | 'swin_large_patch4_window7_224',
39 | 'swin_small_patch4_window7_224',
40 | 'swin_tiny_patch4_window7_224',
41 | 'mixer_b16_224',
42 | 'mixer_l16_224',
43 | 'mixer_b16_224_miil',
44 | 'mixer_b16_224_in21k',
45 | 'mixer_l16_224_in21k',
46 | 'resmlp_12_224',
47 | 'resmlp_24_224',
48 | 'resmlp_36_224',
49 | 'resmlp_big_24_224',
50 | 'resmlp_12_distilled_224',
51 | 'resmlp_24_distilled_224',
52 | 'resmlp_36_distilled_224',
53 | 'resmlp_big_24_distilled_224',
54 | 'resmlp_big_24_224_in22ft1k',
55 | 'vgg11',
56 | 'vgg13',
57 | 'vgg16',
58 | 'vgg19',
59 | 'vgg11_bn',
60 | 'vgg13_bn',
61 | 'vgg16_bn',
62 | 'vgg19_bn',
63 | 'inception_resnet_v2',
64 | 'inception_v3',
65 | 'inception_v4',
66 | 'densenet121',
67 | 'densenet161',
68 | 'densenet169',
69 | 'densenet201',
70 | 'densenetblur121d',
71 | 'tv_densenet121',
72 | 'tv_resnet34',
73 | 'tv_resnet50',
74 | 'tv_resnet101',
75 | 'tv_resnet152',
76 | 'resnet34',
77 | 'resnet50',
78 | 'resnet50d',
79 | 'resnet101',
80 | 'resnet101d',
81 | 'resnet152',
82 | 'resnet152d',
83 | 'resnetrs101',
84 | 'resnetrs152',
85 | 'wide_resnet50_2',
86 | 'wide_resnet101_2',
87 | 'resnetblur50',
88 | 'ig_resnext101_32x16d',
89 | 'ig_resnext101_32x32d',
90 | 'ssl_resnext101_32x16d',
91 | 'convmixer_768_32',
92 | 'convmixer_1536_20',
93 | 'mobilenetv2_100',
94 | 'mobilenetv2_110d',
95 | 'mobilenetv2_120d',
96 | 'mobilenetv2_140',
97 | 'efficientnet_b0',
98 | 'efficientnet_b1',
99 | 'efficientnet_b2',
100 | 'efficientnet_b3',
101 | 'efficientnet_b4',
102 | "convnext_tiny",
103 | "convnext_small",
104 | "convnext_base",
105 | "convnext_large",
106 | "convnext_tiny_in22ft1k",
107 | "convnext_small_in22ft1k",
108 | "convnext_base_in22ft1k",
109 | "convnext_large_in22ft1k",
110 | "convnext_xlarge_in22ft1k",
111 | 'beit_base_patch16_224',
112 | 'beit_base_patch16_384',
113 | 'beit_large_patch16_224',
114 | 'beit_large_patch16_384',
115 | 'beit_large_patch16_512',
116 | 'convit_small',
117 | 'convit_base',
118 | ]},
119 | }
120 | }
--------------------------------------------------------------------------------
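These dictionaries are plain W&B grid-sweep configurations over `modelname`. A hypothetical launch sketch (not part of the repository), assuming a `program` entry is added so that sweep agents invoke `exps_layerwise.py` with `--modelname` supplied by the sweep and the remaining flags left at their argparse defaults:

```python
# Hypothetical sweep launch: run exps_layerwise.py once per model in the grid.
import wandb
from sweep_configs.layerwise_configs import minimal_models

sweep_config = dict(minimal_models, program="exps_layerwise.py")  # assumed entry point
sweep_id = wandb.sweep(sweep_config, project="LieDerivEquivariance")
wandb.agent(sweep_id)  # the agent passes --modelname=<value> to the program
```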
/exps_e2e.py:
--------------------------------------------------------------------------------
1 | import os
2 | import gc
3 | import wandb
4 | import numpy as np
5 | import argparse
6 | import pandas as pd
7 | from functools import partial
8 |
9 | import sys
10 | sys.path.append('pytorch-image-models')
11 | import timm
12 |
13 | from lee.e2e_lee import get_equivariance_metrics as get_lee_metrics
14 | from lee.e2e_other import get_equivariance_metrics as get_discrete_metrics
15 | from lee.loader import get_loaders, eval_average_metrics_wstd
16 |
17 | def numparams(model):
18 | return sum(p.numel() for p in model.parameters())
19 |
20 | def get_metrics(args, key, loader, model, max_mbs=400):
21 | discrete_metrics = eval_average_metrics_wstd(
22 | loader, partial(get_discrete_metrics, model), max_mbs=max_mbs,
23 | )
24 | lee_metrics = eval_average_metrics_wstd(
25 | loader, partial(get_lee_metrics, model), max_mbs=max_mbs,
26 | )
27 | metrics = pd.concat([lee_metrics, discrete_metrics], axis=1)
28 |
29 | metrics["dataset"] = key
30 | metrics["model"] = args.modelname
31 | metrics["params"] = numparams(model)
32 |
33 | return metrics
34 |
35 | def get_args_parser():
36 | parser = argparse.ArgumentParser(description='Training Config', add_help=False)
37 | parser.add_argument('--output_dir', metavar='NAME', default='equivariance_metrics_cnns', help='directory for output CSVs')
38 | parser.add_argument('--modelname', metavar='NAME', default='resnet18', help='model name')
39 | parser.add_argument('--num_datapoints', type=int, default=60, help='number of datapoints to evaluate per split')
40 | return parser
41 |
42 | def main(args):
43 | wandb.init(project="LieDerivEquivariance", config=args)
44 | args.__dict__.update(wandb.config)
45 |
46 | print(args)
47 |
48 | if not os.path.exists(args.output_dir):
49 | os.makedirs(args.output_dir)
50 |
51 | print(args.modelname)
52 |
53 | model = getattr(timm.models, args.modelname)(pretrained=True)
54 | model.eval()
55 |
56 | evaluated_metrics = []
57 |
58 | imagenet_train_loader, imagenet_test_loader = get_loaders(
59 | model,
60 | dataset="imagenet",
61 | data_dir="/imagenet",
62 | batch_size=1,
63 | num_train=args.num_datapoints,
64 | num_val=args.num_datapoints,
65 | args=args,
66 | train_split='train',
67 | val_split='validation',
68 | )
69 |
70 | evaluated_metrics += [
71 | get_metrics(args, "Imagenet_train", imagenet_train_loader, model),
72 | get_metrics(args, "Imagenet_test", imagenet_test_loader, model)
73 | ]
74 | gc.collect()
75 |
76 | # _, cifar_test_loader = get_loaders(
77 | # model,
78 | # dataset="torch/cifar100",
79 | # data_dir="/scratch/nvg7279/cifar",
80 | # batch_size=1,
81 | # num_train=args.num_datapoints,
82 | # num_val=args.num_datapoints,
83 | # args=args,
84 | # train_split='train',
85 | # val_split='validation',
86 | # )
87 |
88 | # evaluated_metrics += [get_metrics(args, "cifar100", cifar_test_loader, model, max_mbs=args.num_datapoints)]
89 | # gc.collect()
90 |
91 | # _, retinopathy_loader = get_loaders(
92 | # model,
93 | # dataset="tfds/diabetic_retinopathy_detection",
94 | # data_dir="/scratch/nvg7279/tfds",
95 | # batch_size=1,
96 | # num_train=1e8,
97 | # num_val=1e8,
98 | # args=args,
99 | # train_split="train",
100 | # val_split="train",
101 | # )
102 |
103 | # evaluated_metrics += [get_metrics(args, "retinopathy", retinopathy_loader, model, max_mbs=args.num_datapoints)]
104 | # gc.collect()
105 |
106 | # _, histology_loader = get_loaders(
107 | # model,
108 | # dataset="tfds/colorectal_histology",
109 | # data_dir="/scratch/nvg7279/tfds",
110 | # batch_size=1,
111 | # num_train=1e8,
112 | # num_val=1e8,
113 | # args=args,
114 | # train_split="train",
115 | # val_split="train",
116 | # )
117 |
118 | # evaluated_metrics += [get_metrics(args, "histology", histology_loader, model, max_mbs=args.num_datapoints)]
119 | # gc.collect()
120 |
121 | df = pd.concat(evaluated_metrics)
122 | df.to_csv(os.path.join(args.output_dir, args.modelname + ".csv"))
123 |
124 | if __name__ == "__main__":
125 | args = get_args_parser().parse_args()
126 | main(args)
127 |
--------------------------------------------------------------------------------
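The script above writes one CSV per model into `--output_dir`. A hypothetical aggregation sketch (not part of the repository; the published plotting code lives in the Colab notebook linked in the README) that ranks models by their mean translation Lie derivative:

```python
# Hypothetical aggregation of the per-model CSVs written by exps_e2e.py.
import glob, os
import pandas as pd

results_dir = os.path.expanduser("~/lee_results")  # assumed --output_dir
frames = [pd.read_csv(path) for path in glob.glob(os.path.join(results_dir, "*.csv"))]
df = pd.concat(frames, ignore_index=True)
print(df.groupby("model")["trans_x_deriv"].mean().sort_values())
```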
/exps_layerwise.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tqdm
3 | import copy
4 | import argparse
5 | import warnings
6 | import pandas as pd
7 | from functools import partial
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 |
13 | import lee.layerwise_lee as lee
14 | import lee.layerwise_other as other_metrics
15 | from lee.loader import get_loaders
16 |
17 | import sys
18 | sys.path.append('pytorch-image-models')
19 | import timm
20 |
21 | def convert_inplace_relu_to_relu(model):
22 | for child_name, child in model.named_children():
23 | if isinstance(child, nn.ReLU):
24 | setattr(model, child_name, nn.ReLU())
25 | else:
26 | convert_inplace_relu_to_relu(child)
27 |
28 | def get_layerwise(args, model, loader, func):
29 | errlist = []
30 | for idx, (x, _) in tqdm.tqdm(
31 | enumerate(loader), total=len(loader)
32 | ):
33 | if idx >= args.num_imgs:
34 | break
35 |
36 | img = x.to(torch.device("cuda"))
37 | errors = func(model, img, num_probes=args.num_probes)
38 | errors["img_idx"] = idx
39 | errlist.append(errors)
40 |
41 | df = pd.concat(errlist, axis=0)
42 | df["model"] = args.modelname
43 | return df
44 |
45 | def main(args):
46 | os.makedirs(args.output_dir, exist_ok=True)
47 |
48 | print(args.modelname)
49 | print(args.transform)
50 |
51 | model = getattr(timm.models, args.modelname)(pretrained=True)
52 |
53 | convert_inplace_relu_to_relu(model)
54 | model = model.eval()
55 | model = model.to(torch.device("cuda"))
56 |
57 | _, loader = get_loaders(
58 | model,
59 | dataset="imagenet",
60 | data_dir="/imagenet",
61 | batch_size=1,
62 | num_train=args.num_imgs,
63 | num_val=args.num_imgs,
64 | args=args,
65 | )
66 |
67 | lee_transforms = ["translation","rotation","hyper_rotation","scale","saturate"]
68 | if args.use_lee and (args.transform in lee_transforms):
69 | lee_model = copy.deepcopy(model)
70 | lee.apply_hooks(lee_model, args.transform)
71 | lee_metrics = get_layerwise(
72 | args, lee_model, loader, func=lee.compute_equivariance_attribution
73 | )
74 |
75 | lee_output_dir = os.path.join(args.output_dir, "lee_" + args.transform)
76 | os.makedirs(lee_output_dir, exist_ok=True)
77 | lee_metrics.to_csv(os.path.join(lee_output_dir, args.modelname + ".csv"))
78 |
79 | other_metrics_transforms = ["integer_translation","translation","rotation"]
80 | if (not args.use_lee) and (args.transform in other_metrics_transforms):
81 | other_metrics_model = copy.deepcopy(model)
82 | other_metrics.apply_hooks(other_metrics_model)
83 | func = partial(other_metrics.compute_equivariance_attribution, args.transform)
84 | other_metrics_results = get_layerwise(
85 | args, other_metrics_model, loader, func=func
86 | )
87 |
88 | other_metrics_output_dir = os.path.join(args.output_dir, "stylegan3_" + args.transform)
89 | os.makedirs(other_metrics_output_dir, exist_ok=True)
90 | results_fn = args.modelname + "_norm_sqrt" + ".csv"
91 | other_metrics_results.to_csv(os.path.join(other_metrics_output_dir, results_fn))
92 |
93 |
94 | def get_args_parser():
95 | parser = argparse.ArgumentParser(description="PyTorch ImageNet Training")
96 |
97 | parser.add_argument(
98 | '--output_dir', metavar='NAME', default='equivariance_metrics_cnns', help='directory for output CSVs'
99 | )
100 | parser.add_argument(
101 | "--modelname", metavar="NAME", default="resnet50", help="model name"
102 | )
103 | parser.add_argument(
104 | "--num_imgs", type=int, default=20, help="Number of images to evaluate over"
105 | )
106 | parser.add_argument(
107 | "--num_probes",
108 | type=int,
109 | default=100,
110 | help="Number of probes to use in the estimator",
111 | )
112 | parser.add_argument(
113 | "--transform",
114 | metavar="NAME",
115 | default="translation",
116 | help="translation or rotation",
117 | )
118 | parser.add_argument(
119 | "--use_lee", type=int, default=0, help="Use LEE (rather than metric not in limit)"
120 | )
121 | return parser
122 |
123 | if __name__ == "__main__":
124 | args = get_args_parser().parse_args()
125 | main(args)
126 | # with warnings.catch_warnings():
127 | # warnings.simplefilter("ignore")
128 | # main(args)
129 |
--------------------------------------------------------------------------------
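For reference, a hypothetical single-image sketch (not part of the repository) that mirrors `main()` above without the ImageNet loader: it hooks a small model, feeds one random image, and prints the layers with the largest translation-LEE contributions. It assumes `timm` is importable and a CUDA device is available, since `compute_equivariance_attribution` places its probe vectors on `cuda`.

```python
# Hypothetical single-image layer-wise LEE run on a random input (CUDA required).
import torch
import torch.nn as nn
import timm
import lee.layerwise_lee as lee_layers

def de_inplace(m):  # same idea as convert_inplace_relu_to_relu above
    for name, child in m.named_children():
        if isinstance(child, nn.ReLU):
            setattr(m, name, nn.ReLU())
        else:
            de_inplace(child)

model = timm.create_model("resnet18", pretrained=False)
de_inplace(model)
model = model.eval().cuda()
lee_layers.apply_hooks(model, "translation")

img = torch.randn(1, 3, 224, 224, device="cuda")
df = lee_layers.compute_equivariance_attribution(model, img, num_probes=4)
print(df.mean(0).sort_values(ascending=False).head())  # largest per-layer contributions
```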
/lee/e2e_other.py:
--------------------------------------------------------------------------------
1 | # End-to-end equivariance metrics that do not use the Lie derivative:
2 | # StyleGAN3 EQ-T/EQ-R probes, prediction consistency under pixel shifts, and sample-based invariance.
3 | import torch
4 | import torch.nn.functional as F
5 | import pandas as pd
6 | import numpy as np
7 |
8 | import sys
9 | sys.path.append("stylegan3")
10 | from stylegan3.metrics.equivariance import (
11 | apply_integer_translation,
12 | apply_fractional_translation,
13 | apply_fractional_rotation,
14 | )
15 |
16 | from .transforms import (
17 | translate,
18 | rotate,
19 | )
20 |
21 | def EQ_T(model, img, model_out, translate_max=0.125):
22 | d = []
23 | for _ in range(10):
24 | t = (torch.rand(2, device='cuda') * 2 - 1) * translate_max
25 | t = (t * img.shape[-2] * img.shape[-1]).round() / (img.shape[-2] * img.shape[-1])
26 | ref, _ = apply_integer_translation(img, t[0], t[1])
27 | t_model_out = model(ref)
28 | d.append((t_model_out - model_out).square().mean())
29 | #psnr = np.log10(2) * 20 - mse.log10() * 10
30 | return torch.stack(d).mean()
31 |
32 | def EQ_T_frac(model, img, model_out, translate_max=0.125):
33 | d = []
34 | for _ in range(10):
35 | t = (torch.rand(2, device='cuda') * 2 - 1) * translate_max
36 | ref, _ = apply_fractional_translation(img, t[0], t[1])
37 | t_model_out = model(ref)
38 | d.append((t_model_out - model_out).square().mean())
39 | # psnr = np.log10(2) * 20 - mse.log10() * 10
40 | return torch.stack(d).mean()
41 |
42 | def EQ_R(model, img, model_out, rotate_max=1.0):
43 | d = []
44 | for _ in range(10):
45 | angle = (torch.rand([], device='cuda') * 2 - 1) * (rotate_max * np.pi)
46 | ref, _ = apply_fractional_rotation(img, angle)
47 | t_model_out = model(ref)
48 | d.append((t_model_out - model_out).square().mean())
49 | # psnr = np.log10(2) * 20 - mse.log10() * 10
50 | return torch.stack(d).mean()
51 |
52 |
53 | def translation_sample_invariance(model,inp_imgs,model_out,axis='x',eta=2.0):
54 | """ Lie derivative of model with respect to translation vector, assumes scalar output """
55 | shifted_model = lambda t: model(translate(inp_imgs,t,axis))
56 | d = []
57 | for _ in range(10):
58 | t_sample = (2 * eta) * torch.rand(1) - eta
59 | d.append((model_out - shifted_model(t_sample)).pow(2).mean())
60 | return torch.stack(d).mean(0).unsqueeze(0)
61 |
62 | def rotation_sample_invariance(model,inp_imgs,model_out,eta=np.pi/16):
63 | """ Sample-based rotation invariance error (not a Lie derivative); assumes scalar output """
64 | rotated_model = lambda theta: model(rotate(inp_imgs,theta))
65 | d = []
66 | for _ in range(10):
67 | theta_sample = (2 * eta) * torch.rand(1) - eta
68 | d.append((model_out - rotated_model(theta_sample)).pow(2).mean())
69 | return torch.stack(d).mean(0).unsqueeze(0)
70 |
71 |
72 | def get_equivariance_metrics(model, minibatch, num_probes=20):
73 | x, y = minibatch
74 | if torch.cuda.is_available():
75 | model = model.cuda()
76 | x, y = x.cuda(), y.cuda()
77 |
78 | model = model.eval()
79 |
80 | model_probs = lambda x: F.softmax(model(x), dim=-1)
81 | model_out = model_probs(x)
82 |
83 | yhat = model_out.argmax(dim=1) # .cpu()
84 | acc = (yhat == y).cpu().float().data.numpy()
85 |
86 | metrics = {}
87 | metrics["acc"] = pd.Series(acc)
88 |
89 | with torch.no_grad():
90 | for shift_x in range(8):
91 | rolled_img = torch.roll(x, shift_x, 2)
92 | rolled_yhat = model(rolled_img).argmax(dim=1)
93 | consistency = (rolled_yhat == yhat).cpu().data.numpy()
94 | metrics["consistency_x" + str(shift_x)] = pd.Series(consistency)
95 | for shift_y in range(8):
96 | rolled_img = torch.roll(x, shift_y, 3)
97 | rolled_yhat = model(rolled_img).argmax(dim=1)
98 | consistency = (rolled_yhat == yhat).cpu().data.numpy()
99 | metrics["consistency_y" + str(shift_y)] = pd.Series(consistency)
100 |
101 | eq_t = torch.stack([EQ_T(model_probs, x, model_out) for _ in range(num_probes)], dim=0).mean(0)
102 | eq_t_frac = torch.stack([EQ_T_frac(model_probs, x, model_out) for _ in range(num_probes)], dim=0).mean(0)
103 | eq_r = torch.stack([EQ_R(model_probs, x, model_out) for _ in range(num_probes)], dim=0).mean(0)
104 |
105 | metrics["eq_t"] = eq_t.cpu().data.numpy()
106 | metrics["eq_t_frac"] = eq_t_frac.cpu().data.numpy()
107 | metrics["eq_r"] = eq_r.cpu().data.numpy()
108 |
109 | metrics['trans_x_sample'] = translation_sample_invariance(model_probs,x,model_out,axis='x').abs().cpu().data.numpy()
110 | metrics['trans_y_sample'] = translation_sample_invariance(model_probs,x,model_out,axis='y').abs().cpu().data.numpy()
111 | metrics['rotate_sample'] = rotation_sample_invariance(model_probs,x,model_out).abs().cpu().data.numpy()
112 |
113 | df = pd.DataFrame.from_dict(metrics)
114 | return df
115 |
--------------------------------------------------------------------------------
/lee/lie_derivs.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .transforms import *
3 |
4 | def jvp(f, x, u):
5 | """Jacobian vector product Df(x)u vs typical autograd VJP vTDF(x).
6 | Uses two backwards passes: computes (vTDF(x))u and then derivative wrt to v to get DF(x)u"""
7 | with torch.enable_grad():
8 | y = f(x)
9 | v = torch.ones_like(
10 | y, requires_grad=True
11 | ) # Dummy variable (could take any value)
12 | vJ = torch.autograd.grad(y, [x], [v], create_graph=True)
13 | Ju = torch.autograd.grad(vJ, [v], [u], create_graph=True)
14 | return Ju[0]
15 |
16 |
17 | def translation_lie_deriv(model, inp_imgs, axis="x"):
18 | """Lie derivative of model with respect to translation vector, output can be a scalar or an image"""
19 | # vector = vector.to(inp_imgs.device)
20 | if not img_like(inp_imgs.shape):
21 | return 0.0
22 |
23 | def shifted_model(t):
24 | # print("Input shape",inp_imgs.shape)
25 | shifted_img = translate(inp_imgs, t, axis)
26 | z = model(shifted_img)
27 | # print("Output shape",z.shape)
28 | # if model produces an output image, shift it back
29 | if img_like(z.shape):
30 | z = translate(z, -t, axis)
31 | # print('zshape',z.shape)
32 | return z
33 |
34 | t = torch.zeros(1, requires_grad=True, device=inp_imgs.device)
35 | lie_deriv = jvp(shifted_model, t, torch.ones_like(t, requires_grad=True))
36 | # print('Liederiv shape',lie_deriv.shape)
37 | # print(model.__class__.__name__)
38 | # print('')
39 | return lie_deriv
40 |
41 |
42 | def rotation_lie_deriv(model, inp_imgs):
43 | """Lie derivative of model with respect to rotation, assumes scalar output"""
44 | if not img_like(inp_imgs.shape):
45 | return 0.0
46 |
47 | def rotated_model(t):
48 | rotated_img = rotate(inp_imgs, t)
49 | z = model(rotated_img)
50 | if img_like(z.shape):
51 | z = rotate(z, -t)
52 | return z
53 |
54 | t = torch.zeros(1, requires_grad=True, device=inp_imgs.device)
55 | lie_deriv = jvp(rotated_model, t, torch.ones_like(t))
56 | return lie_deriv
57 |
58 |
59 | def hyperbolic_rotation_lie_deriv(model, inp_imgs):
60 | """Lie derivative of model with respect to rotation, assumes scalar output"""
61 | if not img_like(inp_imgs.shape):
62 | return 0.0
63 |
64 | def rotated_model(t):
65 | rotated_img = hyperbolic_rotate(inp_imgs, t)
66 | z = model(rotated_img)
67 | if img_like(z.shape):
68 | z = hyperbolic_rotate(z, -t)
69 | return z
70 |
71 | t = torch.zeros(1, requires_grad=True, device=inp_imgs.device)
72 | lie_deriv = jvp(rotated_model, t, torch.ones_like(t))
73 | return lie_deriv
74 |
75 |
76 | def scale_lie_deriv(model, inp_imgs):
77 | """Lie derivative of model with respect to rotation, assumes scalar output"""
78 | if not img_like(inp_imgs.shape):
79 | return 0.0
80 |
81 | def scaled_model(t):
82 | scaled_img = scale(inp_imgs, t)
83 | z = model(scaled_img)
84 | if img_like(z.shape):
85 | z = scale(z, -t)
86 | return z
87 |
88 | t = torch.zeros(1, requires_grad=True, device=inp_imgs.device)
89 | lie_deriv = jvp(scaled_model, t, torch.ones_like(t))
90 | return lie_deriv
91 |
92 |
93 | def shear_lie_deriv(model, inp_imgs, axis="x"):
94 | """Lie derivative of model with respect to shear, assumes scalar output"""
95 | if not img_like(inp_imgs.shape):
96 | return 0.0
97 |
98 | def sheared_model(t):
99 | sheared_img = shear(inp_imgs, t, axis)
100 | z = model(sheared_img)
101 | if img_like(z.shape):
102 | z = shear(z, -t, axis)
103 | return z
104 |
105 | t = torch.zeros(1, requires_grad=True, device=inp_imgs.device)
106 | lie_deriv = jvp(sheared_model, t, torch.ones_like(t))
107 | return lie_deriv
108 |
109 |
110 | def stretch_lie_deriv(model, inp_imgs, axis="x"):
111 | """Lie derivative of model with respect to stretch, assumes scalar output"""
112 | if not img_like(inp_imgs.shape):
113 | return 0.0
114 |
115 | def stretched_model(t):
116 | stretched_img = stretch(inp_imgs, t, axis)
117 | z = model(stretched_img)
118 | if img_like(z.shape):
119 | z = stretch(z, -t, axis)
120 | return z
121 |
122 | t = torch.zeros(1, requires_grad=True, device=inp_imgs.device)
123 | lie_deriv = jvp(stretched_model, t, torch.ones_like(t))
124 | return lie_deriv
125 |
126 |
127 | def saturate_lie_deriv(model, inp_imgs):
128 | """Lie derivative of model with respect to saturation, assumes scalar output"""
129 | if not img_like(inp_imgs.shape):
130 | return 0.0
131 |
132 | def saturated_model(t):
133 | saturated_img = saturate(inp_imgs, t)
134 | z = model(saturated_img)
135 | if img_like(z.shape):
136 | z = saturate(z, -t)
137 | return z
138 |
139 | t = torch.zeros(1, requires_grad=True, device=inp_imgs.device)
140 | lie_deriv = jvp(saturated_model, t, torch.ones_like(t))
141 | return lie_deriv
142 |
143 |
--------------------------------------------------------------------------------
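A small self-check sketch (not part of the repository) for the double-backward `jvp` above: for the elementwise cube f(x) = x³, the Jacobian-vector product Df(x)u must equal 3x²·u.

```python
# Hypothetical self-check of jvp against an analytic Jacobian-vector product.
import torch
from lee.lie_derivs import jvp

x = torch.randn(5, requires_grad=True)
u = torch.randn(5)
ju = jvp(lambda t: t ** 3, x, u)                 # Df(x) u via two backward passes
assert torch.allclose(ju, 3 * x ** 2 * u, atol=1e-5)
```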
/lee/layerwise_other.py:
--------------------------------------------------------------------------------
1 | import copy
2 | # Layer-wise versions of the StyleGAN3 equivariance metrics (EQ_T, EQ_T_frac, EQ_R)
3 | import pandas as pd
4 | import numpy as np
5 | from collections import defaultdict
6 |
7 | import torch
8 | import torch.nn.functional as F
9 |
10 | import sys
11 | sys.path.append("stylegan3")
12 | from metrics.equivariance import (
13 | apply_integer_translation,
14 | apply_fractional_translation,
15 | apply_fractional_rotation,
16 | )
17 |
18 | from .layerwise_lee import selective_apply
19 | from .transforms import img_like
20 |
21 | def EQ_T(module, img, model_out, translate_max=0.125):
22 | t = (torch.rand(2, device='cuda') * 2 - 1) * translate_max
23 | t_in = (t * img.shape[-2] * img.shape[-1]).round() / (img.shape[-2] * img.shape[-1])
24 | ref, _ = apply_integer_translation(img, t_in[0], t_in[1])
25 | with torch.no_grad():
26 | t_model_out = module(ref)
27 | t_out = (t * t_model_out.shape[-2] * t_model_out.shape[-1]).round() / (t_model_out.shape[-2] * t_model_out.shape[-1])
28 | t_model_out, out_mask = apply_integer_translation(t_model_out, -t_out[0], -t_out[1])
29 | squared_err = (t_model_out - model_out).square()
30 | denom = out_mask.sum().sqrt() if out_mask.sum() > 0 else 1.0
31 | mse = (squared_err * out_mask).sum() / denom
32 | if module.__class__.__name__ == 'BatchNorm2d':
33 | mse = mse * 0
34 | # psnr = np.log10(2) * 20 - mse.log10() * 10
35 | return mse
36 |
37 | def EQ_T_frac(module, img, model_out, translate_max=0.125):
38 | t = (torch.rand(2, device='cuda') * 2 - 1) * translate_max
39 | t_in = (t * img.shape[-2] * img.shape[-1]).round() / (img.shape[-2] * img.shape[-1])
40 | ref, _ = apply_fractional_translation(img, t_in[0], t_in[1])
41 | with torch.no_grad():
42 | t_model_out = module(ref)
43 | t_out = (t * t_model_out.shape[-2] * t_model_out.shape[-1]).round() / (t_model_out.shape[-2] * t_model_out.shape[-1])
44 | t_model_out, out_mask = apply_fractional_translation(t_model_out, -t_out[0], -t_out[1])
45 | squared_err = (t_model_out - model_out).square()
46 | denom = out_mask.sum().sqrt() if out_mask.sum() > 0 else 1.0
47 | mse = (squared_err * out_mask).sum() / denom
48 | if module.__class__.__name__ == 'BatchNorm2d':
49 | mse = mse * 0
50 | # psnr = np.log10(2) * 20 - mse.log10() * 10
51 | return mse
52 |
53 | def EQ_R(module, img, model_out, rotate_max=1.0):
54 | angle = (torch.rand([], device='cuda') * 2 - 1) * (rotate_max * np.pi)
55 | ref, _ = apply_fractional_rotation(img, angle)
56 | with torch.no_grad():
57 | t_model_out = module(ref)
58 | t_model_out, out_mask = apply_fractional_rotation(t_model_out, -angle)
59 | # model_out, pseudo_mask = apply_fractional_pseudo_rotation(model_out, -angle)
60 | squared_err = (t_model_out - model_out).square()
61 | denom = out_mask.sum() if out_mask.sum() > 0 else 1.0
62 | mse = (squared_err * out_mask).sum()# / denom
63 | # psnr = np.log10(2) * 20 - mse.log10() * 10
64 | return mse
65 |
66 |
67 | def store_inputs(self, inputs, outputs):
68 | self._cached_input = inputs
69 | self._cached_output = outputs
70 |
71 | def apply_hooks(model):
72 | selective_apply(
73 | model, lambda m: m.register_forward_hook(store_inputs)
74 | )
75 |
76 | def reset(self):
77 | try:
78 | del self._cached_input
79 | except AttributeError:
80 | pass
81 |
82 | def compute_equivariance_attribution(transform, model, img_batch, num_probes=20):
83 | model = model.eval()
84 |
85 | equiv_metric = {
86 | "integer_translation": EQ_T,
87 | "translation": EQ_T_frac,
88 | "rotation": EQ_R,
89 | }[transform]
90 |
91 | model_fn = lambda z: F.softmax(model(z), dim=1)
92 | with torch.no_grad():
93 | model_fn(img_batch)
94 |
95 | name_counter = defaultdict(int)
96 | cached_info = []
97 | for i, (name, module) in enumerate(model.named_modules()):
98 | name_counter[name] += 1
99 |
100 | if not hasattr(module, "_cached_input") or \
101 | not hasattr(module, "_cached_output"):
102 | continue
103 |
104 | model_in = copy.deepcopy(module._cached_input)
105 | model_in = model_in[0] if isinstance(model_in, tuple) else model_in
106 | model_out = copy.deepcopy(module._cached_output)
107 | model_out = model_out[0] if isinstance(model_out, tuple) else model_out
108 |
109 | cached_info.append((name, module, model_in, model_out))
110 |
111 | model.apply(reset)
112 |
113 | all_errs = []
114 | for j in range(num_probes):
115 | errs = {}
116 | for i, (name, module, model_in, model_out) in enumerate(cached_info):
117 | num_tag = f"{name_counter[name]}" if name_counter[name] else ""
118 | mod = module.__class__.__name__
119 |
120 | if not img_like(model_in.shape) or not img_like(model_out.shape):
121 | equiv_err = 0.0
122 | else:
123 | with torch.no_grad():
124 | equiv_err = equiv_metric(module, model_in, model_out)
125 | equiv_err = equiv_err.cpu().data.numpy()
126 |
127 | errs[
128 | (name + (num_tag), mod, i, name_counter[name])
129 | ] = equiv_err
130 |
131 | errs = pd.Series(errs, index=pd.MultiIndex.from_tuples(errs.keys()), name=j)
132 | all_errs.append(errs)
133 |
134 | df = pd.DataFrame(all_errs)
135 | print(df)
136 | return df
--------------------------------------------------------------------------------
/lee/layerwise_lee.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tqdm
3 | import argparse
4 | import pandas as pd
5 | from functools import partial
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 |
11 | from .lie_derivs import *
12 |
13 | class flag:
14 | pass
15 |
16 | singleton = flag
17 | singleton.compute_lie = True
18 | singleton.op_counter = 0
19 | singleton.fwd = True
20 |
21 | # TODO: record call order
22 | # TODO: if module is called 2nd time, create a copy and store separately
23 | # change self inplace?
24 | def store_inputs(lie_deriv_type, self, inputs, outputs):
25 | if not singleton.fwd or not singleton.compute_lie:
26 | return
27 | # if hasattr(self,'_lie_norm_sum'): return
28 | with torch.no_grad():
29 | singleton.compute_lie = False
30 | if not hasattr(self, "_lie_norm_sum"):
31 | self._lie_norm_sum = []
32 | self._lie_norm_sum_sq = []
33 | self._num_probes = []
34 | self._op_counter = []
35 | self._fwd_counter = 0
36 | self._lie_deriv_output = []
37 | if self._fwd_counter == len(self._lie_norm_sum):
38 | self._lie_norm_sum.append(0)
39 | self._lie_norm_sum_sq.append(0)
40 | self._num_probes.append(0)
41 | self._op_counter.append(singleton.op_counter)
42 | self._lie_deriv_output.append(0)
43 | singleton.op_counter += 1
44 | assert len(inputs) == 1
45 | (x,) = inputs
46 | x = x + torch.zeros_like(x)
47 | self._lie_deriv_output[self._fwd_counter] = lie_deriv_type(self, x)
48 |
49 | self._fwd_counter += 1
50 | singleton.compute_lie = True
51 |
52 | def store_estimator(self, grad_input, grad_output):
53 | if singleton.compute_lie:
54 | with torch.no_grad():
55 | assert len(grad_output) == 1
56 | bs = grad_output[0].shape[0]
57 | self._fwd_counter -= 1 # reverse of forward ordering
58 | i = self._fwd_counter
59 | # if i!=0: return
60 | # print('bwd',self._fwd_counter)
61 | estimator = (
62 | (
63 | (grad_output[0] * self._lie_deriv_output[i]).reshape(bs, -1).sum(-1)
64 | ** 2
65 | )
66 | .cpu()
67 | .data.numpy()
68 | )
69 | self._lie_norm_sum[i] += estimator
70 | self._lie_norm_sum_sq[i] += estimator**2
71 | self._num_probes[i] += 1
72 | # print("finished bwd",self)
73 |
74 |
75 | from timm.models.vision_transformer import Attention as A1
76 | # from timm.models.vision_transformer_wconvs import Attention as A2
77 | from timm.models.mlp_mixer import MixerBlock, Affine, SpatialGatingBlock
78 | from timm.models.layers import PatchEmbed, Mlp, DropPath, BlurPool2d
79 |
80 | # from timm.models.layers import FastAdaptiveAvgPool2d,AdaptiveAvgMaxPool2d
81 | from timm.models.layers import GatherExcite, EvoNormBatch2d
82 | from timm.models.senet import SEModule
83 | from timm.models.efficientnet_blocks import SqueezeExcite
84 | from timm.models.convit import MHSA, GPSA
85 |
86 | leaflist = (
87 | A1,
88 | # A2,
89 | MixerBlock,
90 | Affine,
91 | SpatialGatingBlock,
92 | PatchEmbed,
93 | Mlp,
94 | DropPath,
95 | BlurPool2d,
96 | )
97 | # leaflist += (nn.AdaptiveAvgPool2d,nn.MaxPool2d,nn.AvgPool2d)
98 | leaflist += (
99 | GatherExcite,
100 | EvoNormBatch2d,
101 | nn.BatchNorm2d,
102 | nn.BatchNorm1d,
103 | nn.LayerNorm,
104 | nn.GroupNorm,
105 | SEModule,
106 | SqueezeExcite,
107 | )
108 | leaflist += (MHSA, GPSA)
109 |
110 | def is_leaf(m):
111 | return (not hasattr(m, "children") or not list(m.children())) or isinstance(
112 | m, leaflist
113 | )
114 |
115 | def is_excluded(m):
116 | excluded_list = nn.Dropout
117 | return isinstance(m, excluded_list)
118 |
119 | def selective_apply(m, fn):
120 | if is_leaf(m):
121 | if not is_excluded(m):
122 | fn(m)
123 | else:
124 | for c in m.children():
125 | selective_apply(c, fn)
126 |
127 | def apply_hooks(model, lie_deriv_type):
128 | lie_deriv = {
129 | "translation": translation_lie_deriv,
130 | "rotation": rotation_lie_deriv,
131 | "hyper_rotation": hyperbolic_rotation_lie_deriv,
132 | "scale": scale_lie_deriv,
133 | "saturate": saturate_lie_deriv,
134 | }[lie_deriv_type]
135 |
136 | selective_apply(
137 | model, lambda m: m.register_forward_hook(partial(store_inputs, lie_deriv))
138 | )
139 | selective_apply(model, lambda m: m.register_backward_hook(store_estimator))
140 |
141 |
142 | def reset(self):
143 | try:
144 | # del self._input
145 | del self._lie_norm_sum
146 | del self._lie_norm_sum_sq
147 | del self._num_probes
148 | del self._op_counter
149 | del self._fwd_counter
150 | del self._lie_deriv_output
151 | except AttributeError:
152 | pass
153 |
154 |
155 | def reset2(self):
156 | self._lie_norm_sum = [0.0] * len(self._lie_norm_sum)
157 | self._lie_norm_sum_sq = [0.0] * len(self._lie_norm_sum)
158 | self._num_probes = [0.0] * len(self._lie_norm_sum)
159 |
160 |
161 | def compute_equivariance_attribution(model, img_batch, num_probes=100):
162 | model_fn = lambda z: F.softmax(model(z), dim=1)
163 | all_errs = []
164 | order = []
165 | for j in range(num_probes):
166 | singleton.fwd = True
167 | y = model_fn(img_batch)
168 | z = torch.randn(img_batch.shape[0], 1000).to(torch.device("cuda"))
169 | loss = (z * y).sum()
170 | singleton.fwd = False
171 | loss.backward()
172 | model.zero_grad()
173 | singleton.op_counter = 0
174 |
175 | errs = {}
176 | for name, module in model.named_modules():
177 | if hasattr(module, "_lie_norm_sum"):
178 | for i in range(len(module._lie_norm_sum)):
179 | assert module._num_probes[i] == 1
180 |
181 | lie_norm = module._lie_norm_sum[i] / module._num_probes[i]
182 | mod = module.__class__.__name__
183 | errs[
184 | (name + (f"{i}" if i else ""), mod, module._op_counter[i], i)
185 | ] = lie_norm.item()
186 |
187 | errs = pd.Series(errs, index=pd.MultiIndex.from_tuples(errs.keys()), name=j)
188 | all_errs.append(errs)
189 | selective_apply(model, reset2)
190 | model.apply(reset)
191 | df = pd.DataFrame(all_errs)
192 | return df
--------------------------------------------------------------------------------
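The probe loop in `compute_equivariance_attribution` relies on the identity E_z[(z·v)²] = ‖v‖² for z ~ N(0, I): contracting the backward gradient of a random projection of the softmax output with each module's cached Lie derivative gives an unbiased estimate of that module's squared contribution. A tiny numeric illustration of the identity (not part of the repository):

```python
# Hypothetical numeric check of the Hutchinson-style identity E[(z.v)^2] = ||v||^2.
import torch

torch.manual_seed(0)
v = torch.randn(1000)
z = torch.randn(100_000, 1000)                              # probes z ~ N(0, I)
print((z @ v).pow(2).mean().item(), v.pow(2).sum().item())  # the two numbers should be close
```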
/lee/transforms.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import numpy as np
4 |
5 | def grid_sample(image, optical):
6 | N, C, IH, IW = image.shape
7 | _, H, W, _ = optical.shape
8 |
9 | ix = optical[..., 0]
10 | iy = optical[..., 1]
11 |
12 | ix = ((ix + 1) / 2) * (IW - 1)
13 | iy = ((iy + 1) / 2) * (IH - 1)
14 | with torch.no_grad():
15 | ix_nw = torch.floor(ix)
16 | iy_nw = torch.floor(iy)
17 | ix_ne = ix_nw + 1
18 | iy_ne = iy_nw
19 | ix_sw = ix_nw
20 | iy_sw = iy_nw + 1
21 | ix_se = ix_nw + 1
22 | iy_se = iy_nw + 1
23 |
24 | nw = (ix_se - ix) * (iy_se - iy)
25 | ne = (ix - ix_sw) * (iy_sw - iy)
26 | sw = (ix_ne - ix) * (iy - iy_ne)
27 | se = (ix - ix_nw) * (iy - iy_nw)
28 |
29 | with torch.no_grad():
30 | ix_nw = IW - 1 - (IW - 1 - ix_nw.abs()).abs()
31 | iy_nw = IH - 1 - (IH - 1 - iy_nw.abs()).abs()
32 |
33 | ix_ne = IW - 1 - (IW - 1 - ix_ne.abs()).abs()
34 | iy_ne = IH - 1 - (IH - 1 - iy_ne.abs()).abs()
35 |
36 | ix_sw = IW - 1 - (IW - 1 - ix_sw.abs()).abs()
37 | iy_sw = IH - 1 - (IH - 1 - iy_sw.abs()).abs()
38 |
39 | ix_se = IW - 1 - (IW - 1 - ix_se.abs()).abs()
40 | iy_se = IH - 1 - (IH - 1 - iy_se.abs()).abs()
41 |
42 | image = image.view(N, C, IH * IW)
43 |
44 | nw_val = torch.gather(
45 | image, 2, (iy_nw * IW + ix_nw).long().view(N, 1, H * W).repeat(1, C, 1)
46 | )
47 | ne_val = torch.gather(
48 | image, 2, (iy_ne * IW + ix_ne).long().view(N, 1, H * W).repeat(1, C, 1)
49 | )
50 | sw_val = torch.gather(
51 | image, 2, (iy_sw * IW + ix_sw).long().view(N, 1, H * W).repeat(1, C, 1)
52 | )
53 | se_val = torch.gather(
54 | image, 2, (iy_se * IW + ix_se).long().view(N, 1, H * W).repeat(1, C, 1)
55 | )
56 |
57 | out_val = (
58 | nw_val.view(N, C, H, W) * nw.view(N, 1, H, W)
59 | + ne_val.view(N, C, H, W) * ne.view(N, 1, H, W)
60 | + sw_val.view(N, C, H, W) * sw.view(N, 1, H, W)
61 | + se_val.view(N, C, H, W) * se.view(N, 1, H, W)
62 | )
63 |
64 | return out_val
65 |
66 | def img_like(img_shape):
67 | bchw = len(img_shape) == 4 and img_shape[-2:] != (1, 1)
68 | is_square = int(int(np.sqrt(img_shape[1])) + 0.5) ** 2 == img_shape[1]
69 | is_one_off_square = int(int(np.sqrt(img_shape[1])) + 0.5) ** 2 == img_shape[1] - 1
70 | is_two_off_square = int(int(np.sqrt(img_shape[1])) + 0.5) ** 2 == img_shape[1] - 2
71 | bnc = (
72 | len(img_shape) == 3
73 | and img_shape[1] != 1
74 | and (is_square or is_one_off_square or is_two_off_square)
75 | )
76 | return bchw or bnc
77 |
78 | def num_tokens(img_shape):
79 | if len(img_shape) == 4 and img_shape[-2:] != (1, 1):
80 | return 0
81 | is_one_off_square = int(int(np.sqrt(img_shape[1])) + 0.5) ** 2 == img_shape[1] - 1
82 | is_two_off_square = int(int(np.sqrt(img_shape[1])) + 0.5) ** 2 == img_shape[1] - 2
83 | return int(is_one_off_square * 1 or is_two_off_square * 2)
84 |
85 |
86 | def bnc2bchw(bnc, num_tokens):
87 | b, n, c = bnc.shape
88 | h = w = int(np.sqrt(n))
89 | extra = bnc[:, :num_tokens, :]
90 | img = bnc[:, num_tokens:, :]
91 | return img.reshape(b, h, w, c).permute(0, 3, 1, 2), extra
92 |
93 |
94 | def bchw2bnc(bchw, tokens):
95 | b, c, h, w = bchw.shape
96 | n = h * w
97 | bnc = bchw.permute(0, 2, 3, 1).reshape(b, n, c)
98 | return torch.cat([tokens, bnc], dim=1) # assumes tokens are at the start
99 |
100 |
101 | def affine_transform(affineMatrices, img):
102 | assert img_like(img.shape)
103 | if len(img.shape) == 3:
104 | ntokens = num_tokens(img.shape)
105 | x, extra = bnc2bchw(img, ntokens)
106 | else:
107 | x = img
108 | flowgrid = F.affine_grid(
109 | affineMatrices, size=x.size(), align_corners=True
110 | ) # .double()
111 | # uses manual grid sample implementation to be able to compute 2nd derivatives
112 | # img_out = F.grid_sample(img, flowgrid,padding_mode="reflection",align_corners=True)
113 | transformed = grid_sample(x, flowgrid)
114 | if len(img.shape) == 3:
115 | transformed = bchw2bnc(transformed, extra)
116 | return transformed
117 |
118 |
119 | def translate(img, t, axis="x"):
120 | """Translates an image by a fraction of the size (sx,sy) in (0,1)"""
121 | affineMatrices = torch.zeros(img.shape[0], 2, 3).to(img.device)
122 | affineMatrices[:, 0, 0] = 1
123 | affineMatrices[:, 1, 1] = 1
124 | if axis == "x":
125 | affineMatrices[:, 0, 2] = t
126 | else:
127 | affineMatrices[:, 1, 2] = t
128 | return affine_transform(affineMatrices, img)
129 |
130 |
131 | def rotate(img, angle):
132 | """Rotates an image by angle"""
133 | affineMatrices = torch.zeros(img.shape[0], 2, 3).to(img.device)
134 | affineMatrices[:, 0, 0] = torch.cos(angle)
135 | affineMatrices[:, 0, 1] = torch.sin(angle)
136 | affineMatrices[:, 1, 0] = -torch.sin(angle)
137 | affineMatrices[:, 1, 1] = torch.cos(angle)
138 | return affine_transform(affineMatrices, img)
139 |
140 |
141 | def shear(img, t, axis="x"):
142 | """Shear an image by an amount t"""
143 | affineMatrices = torch.zeros(img.shape[0], 2, 3).to(img.device)
144 | affineMatrices[:, 0, 0] = 1
145 | affineMatrices[:, 1, 1] = 1
146 | if axis == "x":
147 | affineMatrices[:, 0, 1] = t
148 | affineMatrices[:, 1, 0] = 0
149 | else:
150 | affineMatrices[:, 0, 1] = 0
151 | affineMatrices[:, 1, 0] = t
152 | return affine_transform(affineMatrices, img)
153 |
154 |
155 | def stretch(img, x, axis="x"):
156 | """Stretch an image by an amount t"""
157 | affineMatrices = torch.zeros(img.shape[0], 2, 3).to(img.device)
158 | if axis == "x":
159 | affineMatrices[:, 0, 0] = 1 * (1 + x)
160 | else:
161 | affineMatrices[:, 1, 1] = 1 * (1 + x)
162 | return affine_transform(affineMatrices, img)
163 |
164 |
165 | def hyperbolic_rotate(img, angle):
166 | affineMatrices = torch.zeros(img.shape[0], 2, 3).to(img.device)
167 | affineMatrices[:, 0, 0] = torch.cosh(angle)
168 | affineMatrices[:, 0, 1] = torch.sinh(angle)
169 | affineMatrices[:, 1, 0] = torch.sinh(angle)
170 | affineMatrices[:, 1, 1] = torch.cosh(angle)
171 | return affine_transform(affineMatrices, img)
172 |
173 |
174 | def scale(img, s):
175 | affineMatrices = torch.zeros(img.shape[0], 2, 3).to(img.device)
176 | affineMatrices[:, 0, 0] = 1 - s
177 | affineMatrices[:, 1, 1] = 1 - s
178 | return affine_transform(affineMatrices, img)
179 |
180 |
181 | def saturate(img, t):
182 | img = img.clone()
183 | img *= 1 + t
184 | return img
--------------------------------------------------------------------------------
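A small shape sketch (not part of the repository) for the token handling above: ViT-style (B, N, C) activations with a class token are folded to (B, C, H, W), spatially transformed, and unfolded again, which is what lets `affine_transform` act on transformer blocks as well as CNN feature maps.

```python
# Hypothetical shape check: round-trip a ViT-style token tensor through the helpers.
import torch
from lee.transforms import bnc2bchw, bchw2bnc, num_tokens, translate

bnc = torch.randn(2, 1 + 14 * 14, 768)      # (batch, class token + 14x14 patches, channels)
ntok = num_tokens(bnc.shape)                # -> 1 extra (class) token
bchw, extra = bnc2bchw(bnc, ntok)           # (2, 768, 14, 14) plus the class token
shifted = translate(bchw, 0.1)              # shift by 10% of the width
out = bchw2bnc(shifted, extra)              # back to (2, 197, 768)
assert out.shape == bnc.shape
```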
/lee/loader.py:
--------------------------------------------------------------------------------
1 | import tqdm
2 | import pandas as pd
3 | import random
4 | from functools import partial
5 | from typing import Callable
6 |
7 | import torch
8 | import torch.utils.data
9 | import numpy as np
10 |
11 | import sys
12 | sys.path.append("pytorch-image-models")
13 | from timm.data import create_dataset, resolve_data_config
14 | from timm.data.transforms_factory import create_transform
15 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
16 | from timm.data.distributed_sampler import OrderedDistributedSampler
17 |
18 | def _worker_init(worker_id, worker_seeding='all'):
19 | worker_info = torch.utils.data.get_worker_info()
20 | assert worker_info.id == worker_id
21 | if isinstance(worker_seeding, Callable):
22 | seed = worker_seeding(worker_info)
23 | random.seed(seed)
24 | torch.manual_seed(seed)
25 | np.random.seed(seed % (2 ** 32 - 1))
26 | else:
27 | assert worker_seeding in ('all', 'part')
28 | # random / torch seed already called in dataloader iter class w/ worker_info.seed
29 | # to reproduce some old results (same seed + hparam combo), partial seeding is required (skip numpy re-seed)
30 | if worker_seeding == 'all':
31 | np.random.seed(worker_info.seed % (2 ** 32 - 1))
32 |
33 | def create_loader(
34 | dataset,
35 | input_size,
36 | batch_size,
37 | is_training=False,
38 | no_aug=False,
39 | re_prob=0.,
40 | re_mode='const',
41 | re_count=1,
42 | re_split=False,
43 | scale=None,
44 | ratio=None,
45 | hflip=0.5,
46 | vflip=0.,
47 | color_jitter=0.4,
48 | auto_augment=None,
49 | num_aug_splits=0,
50 | interpolation='bilinear',
51 | mean=IMAGENET_DEFAULT_MEAN,
52 | std=IMAGENET_DEFAULT_STD,
53 | num_workers=1,
54 | crop_pct=None,
55 | collate_fn=None,
56 | pin_memory=False,
57 | tf_preprocessing=False,
58 | persistent_workers=True,
59 | worker_seeding='all',
60 | sampler=None,
61 | ):
62 | inner_dataset = dataset.dataset if isinstance(dataset, torch.utils.data.dataset.Subset) \
63 | else dataset
64 |
65 | re_num_splits = 0
66 | if re_split:
67 | # apply RE to second half of batch if no aug split otherwise line up with aug split
68 | re_num_splits = num_aug_splits or 2
69 | inner_dataset.transform = create_transform(
70 | input_size,
71 | is_training=is_training,
72 | use_prefetcher=False,
73 | no_aug=no_aug,
74 | scale=scale,
75 | ratio=ratio,
76 | hflip=hflip,
77 | vflip=vflip,
78 | color_jitter=color_jitter,
79 | auto_augment=auto_augment,
80 | interpolation=interpolation,
81 | mean=mean,
82 | std=std,
83 | crop_pct=crop_pct,
84 | tf_preprocessing=tf_preprocessing,
85 | re_prob=re_prob,
86 | re_mode=re_mode,
87 | re_count=re_count,
88 | re_num_splits=re_num_splits,
89 | separate=num_aug_splits > 0,
90 | )
91 |
92 | if collate_fn is None:
93 | collate_fn = torch.utils.data.dataloader.default_collate
94 |
95 | loader_class = torch.utils.data.DataLoader
96 |
97 | loader_args = dict(
98 | batch_size=batch_size,
99 | shuffle=False,
100 | num_workers=num_workers,
101 | sampler=sampler,
102 | collate_fn=collate_fn,
103 | pin_memory=pin_memory,
104 | drop_last=is_training,
105 | worker_init_fn=partial(_worker_init, worker_seeding=worker_seeding),
106 | persistent_workers=persistent_workers
107 | )
108 | try:
109 | loader = loader_class(dataset, **loader_args)
110 | except TypeError as e:
111 | loader_args.pop('persistent_workers') # only in Pytorch 1.7+
112 | loader = loader_class(dataset, **loader_args)
113 |
114 | return loader
115 |
116 | def eval_average_metrics_wstd(loader, metrics, max_mbs=None):
117 | total = len(loader) if max_mbs is None else min(max_mbs, len(loader))
118 | dfs = []
119 | with torch.no_grad():
120 | for idx, minibatch in tqdm.tqdm(enumerate(loader), total=total):
121 | dfs.append(metrics(minibatch))
122 | if max_mbs is not None and idx >= max_mbs:
123 | break
124 | df = pd.concat(dfs)
125 | return df
126 |
127 | def get_loaders(
128 | model,
129 | dataset,
130 | data_dir,
131 | batch_size,
132 | num_train,
133 | num_val,
134 | args,
135 | train_split="train",
136 | val_split="val",
137 | ):
138 |
139 | dataset_train = create_dataset(
140 | dataset,
141 | root=data_dir,
142 | split=train_split,
143 | is_training=False,
144 | batch_size=batch_size,
145 | )
146 | if num_train < len(dataset_train):
147 | dataset_train, _ = torch.utils.data.random_split(
148 | dataset_train,
149 | [num_train, len(dataset_train) - num_train],
150 | generator=torch.Generator().manual_seed(42),
151 | )
152 |
153 | dataset_eval = create_dataset(
154 | dataset,
155 | root=data_dir,
156 | split=val_split,
157 | is_training=False,
158 | batch_size=batch_size,
159 | )
160 | if num_val < len(dataset_eval):
161 | dataset_eval, _ = torch.utils.data.random_split(
162 | dataset_eval,
163 | [num_val, len(dataset_eval) - num_val],
164 | generator=torch.Generator().manual_seed(42),
165 | )
166 |
167 | data_config = resolve_data_config(vars(args), model=model, verbose=True)
168 |
169 | print(data_config)
170 |
171 | train_loader = create_loader(
172 | dataset_train,
173 | input_size=data_config["input_size"],
174 | batch_size=batch_size,
175 | is_training=False,
176 | interpolation=data_config["interpolation"],
177 | mean=data_config["mean"],
178 | std=data_config["std"],
179 | num_workers=1,
180 | crop_pct=data_config["crop_pct"],
181 | pin_memory=False,
182 | no_aug=True,
183 | hflip=0.0,
184 | color_jitter=0.0,
185 | )
186 |
187 | eval_loader = create_loader(
188 | dataset_eval,
189 | input_size=data_config["input_size"],
190 | batch_size=batch_size,
191 | is_training=False,
192 | interpolation=data_config["interpolation"],
193 | mean=data_config["mean"],
194 | std=data_config["std"],
195 | num_workers=1,
196 | crop_pct=data_config["crop_pct"],
197 | pin_memory=False,
198 | no_aug=True,
199 | )
200 |
201 | return train_loader, eval_loader
--------------------------------------------------------------------------------
/sweep_configs/e2e_configs.py:
--------------------------------------------------------------------------------
1 | core_models = {
2 | "name": "core_models",
3 | "method": "grid",
4 | "parameters": {
5 | "modelname": {"values": [
6 | 'vit_tiny_patch16_224',
7 | 'vit_tiny_patch16_384',
8 | 'vit_small_patch32_224',
9 | 'vit_small_patch32_384',
10 | 'vit_small_patch16_224',
11 | 'vit_small_patch16_384',
12 | 'vit_base_patch32_224',
13 | 'vit_base_patch32_384',
14 | 'vit_base_patch16_224',
15 | 'vit_base_patch16_384',
16 | 'vit_large_patch16_224',
17 | 'vit_large_patch32_384',
18 | 'vit_large_patch16_384',
19 | 'vit_base_patch16_224_miil',
20 | 'mixer_b16_224',
21 | 'mixer_l16_224',
22 | 'mixer_b16_224_miil',
23 | 'mixer_b16_224_in21k',
24 | 'mixer_l16_224_in21k',
25 | 'resmlp_12_224',
26 | 'resmlp_24_224',
27 | 'resmlp_36_224',
28 | 'resmlp_big_24_224',
29 | 'resmlp_12_distilled_224',
30 | 'resmlp_24_distilled_224',
31 | 'resmlp_36_distilled_224',
32 | 'resmlp_big_24_distilled_224',
33 | 'resmlp_big_24_224_in22ft1k',
34 | 'vgg11',
35 | 'vgg13',
36 | 'vgg16',
37 | 'vgg19',
38 | 'vgg11_bn',
39 | 'vgg13_bn',
40 | 'vgg16_bn',
41 | 'vgg19_bn',
42 | 'inception_resnet_v2',
43 | 'inception_v3',
44 | 'inception_v4',
45 | 'densenet121',
46 | 'densenet161',
47 | 'densenet169',
48 | 'densenet201',
49 | 'densenetblur121d',
50 | 'tv_densenet121',
51 | 'tv_resnet34',
52 | 'tv_resnet50',
53 | 'tv_resnet101',
54 | 'tv_resnet152',
55 | 'resnet34',
56 | 'resnet50',
57 | 'resnet50d',
58 | 'resnet101',
59 | 'resnet101d',
60 | 'resnet152',
61 | 'resnet152d',
62 | 'resnetrs101',
63 | 'resnetrs152',
64 | 'wide_resnet50_2',
65 | 'wide_resnet101_2',
66 | 'resnetblur50',
67 | 'convmixer_768_32',
68 | 'convmixer_1536_20',
69 | 'mobilenetv2_100',
70 | 'mobilenetv2_110d',
71 | 'mobilenetv2_120d',
72 | 'mobilenetv2_140',
73 | 'efficientnet_b0',
74 | 'efficientnet_b1',
75 | 'efficientnet_b2',
76 | 'efficientnet_b3',
77 | 'efficientnet_b4',
78 | "convnext_tiny",
79 | "convnext_small",
80 | "convnext_base",
81 | "convnext_large",
82 | "convnext_tiny_in22ft1k",
83 | "convnext_small_in22ft1k",
84 | "convnext_base_in22ft1k",
85 | "convnext_large_in22ft1k",
86 | "convnext_xlarge_in22ft1k",
87 | 'convit_small',
88 | 'convit_base',
89 | ]},
90 | }
91 | }
92 |
93 | all_models = {
94 | "name": "all_models",
95 | "method": "grid",
96 | "parameters": {
97 | "modelname": {"values": [
98 | 'vit_tiny_patch16_224',
99 | 'vit_tiny_patch16_384',
100 | 'vit_small_patch32_224',
101 | 'vit_small_patch32_384',
102 | 'vit_small_patch16_224',
103 | 'vit_small_patch16_384',
104 | 'vit_base_patch32_224',
105 | 'vit_base_patch32_384',
106 | 'vit_base_patch16_224',
107 | 'vit_base_patch16_384',
108 | 'vit_large_patch16_224',
109 | 'vit_large_patch32_384',
110 | 'vit_large_patch16_384',
111 | 'vit_base_patch16_224_miil',
112 | 'swin_base_patch4_window12_384',
113 | 'swin_base_patch4_window7_224',
114 | 'swin_large_patch4_window12_384',
115 | 'swin_large_patch4_window7_224',
116 | 'swin_small_patch4_window7_224',
117 | 'swin_tiny_patch4_window7_224',
118 | 'mixer_b16_224',
119 | 'mixer_l16_224',
120 | 'mixer_b16_224_miil',
121 | 'mixer_b16_224_in21k',
122 | 'mixer_l16_224_in21k',
123 | 'resmlp_12_224',
124 | 'resmlp_24_224',
125 | 'resmlp_36_224',
126 | 'resmlp_big_24_224',
127 | 'resmlp_12_distilled_224',
128 | 'resmlp_24_distilled_224',
129 | 'resmlp_36_distilled_224',
130 | 'resmlp_big_24_distilled_224',
131 | 'resmlp_big_24_224_in22ft1k',
132 | 'vgg11',
133 | 'vgg13',
134 | 'vgg16',
135 | 'vgg19',
136 | 'vgg11_bn',
137 | 'vgg13_bn',
138 | 'vgg16_bn',
139 | 'vgg19_bn',
140 | 'inception_resnet_v2',
141 | 'inception_v3',
142 | 'inception_v4',
143 | 'densenet121',
144 | 'densenet161',
145 | 'densenet169',
146 | 'densenet201',
147 | 'densenetblur121d',
148 | 'tv_densenet121',
149 | 'tv_resnet34',
150 | 'tv_resnet50',
151 | 'tv_resnet101',
152 | 'tv_resnet152',
153 | 'resnet34',
154 | 'resnet50',
155 | 'resnet50d',
156 | 'resnet101',
157 | 'resnet101d',
158 | 'resnet152',
159 | 'resnet152d',
160 | 'resnetrs101',
161 | 'resnetrs152',
162 | 'wide_resnet50_2',
163 | 'wide_resnet101_2',
164 | 'resnetblur50',
165 | 'ig_resnext101_32x16d',
166 | 'ig_resnext101_32x32d',
167 | 'ssl_resnext101_32x16d',
168 | 'convmixer_768_32',
169 | 'convmixer_1536_20',
170 | 'mobilenetv2_100',
171 | 'mobilenetv2_110d',
172 | 'mobilenetv2_120d',
173 | 'mobilenetv2_140',
174 | 'efficientnet_b0',
175 | 'efficientnet_b1',
176 | 'efficientnet_b2',
177 | 'efficientnet_b3',
178 | 'efficientnet_b4',
179 | "convnext_tiny",
180 | "convnext_small",
181 | "convnext_base",
182 | "convnext_large",
183 | "convnext_tiny_in22ft1k",
184 | "convnext_small_in22ft1k",
185 | "convnext_base_in22ft1k",
186 | "convnext_large_in22ft1k",
187 | "convnext_xlarge_in22ft1k",
188 | 'beit_base_patch16_224',
189 | 'beit_base_patch16_384',
190 | 'beit_large_patch16_224',
191 | 'beit_large_patch16_384',
192 | 'beit_large_patch16_512',
193 | 'convit_small',
194 | 'convit_base',
195 | 'resnet18','tv_resnet34','tv_resnet50','tv_resnet101','tv_resnet152',
196 | 'resnet18d','resnet34','resnet34d','resnet26','resnet26d','resnet26t',
197 | 'resnet50','resnet50d','resnet101','resnet101d','resnet152d',
198 | 'wide_resnet50_2','wide_resnet101_2',
199 | 'resnetblur50',
200 | 'resnext50_32x4d','resnext50d_32x4d',
201 | 'resnext101_32x8d', 'tv_resnext50_32x4d',
202 | 'ig_resnext101_32x8d','ig_resnext101_32x16d','ig_resnext101_32x32d','ig_resnext101_32x48d',
203 | 'ssl_resnet18','ssl_resnet50','ssl_resnext50_32x4d',
204 | 'ssl_resnext101_32x4d','ssl_resnext101_32x8d','ssl_resnext101_32x16d',
205 | 'resnetrs50','resnetrs101','resnetrs152','resnetrs200',
206 | 'resnetrs270','resnetrs350','resnetrs420',
207 | 'vit_tiny_patch16_224','vit_tiny_patch16_384',
208 | 'vit_small_patch32_224','vit_small_patch32_384','vit_small_patch16_224','vit_small_patch16_384',
209 | 'vit_base_patch32_224','vit_base_patch32_384','vit_base_patch16_224','vit_base_patch16_384',
210 | 'vit_large_patch16_224','vit_large_patch32_384','vit_large_patch16_384',
211 | 'vit_base_patch32_sam_224','vit_base_patch16_sam_224',
212 | 'deit_tiny_patch16_224','deit_small_patch16_224',
213 | 'deit_base_patch16_224','deit_base_patch16_384',
214 | 'deit_tiny_distilled_patch16_224','deit_small_distilled_patch16_224',
215 | 'deit_base_distilled_patch16_224','deit_base_distilled_patch16_384',
216 | 'vit_base_patch16_224_miil',
217 | 'vit_tiny_r_s16_p8_224','vit_tiny_r_s16_p8_384',
218 | 'vit_small_r26_s32_224','vit_small_r26_s32_384',
219 | 'vit_base_r50_s16_384',
220 | 'vit_large_r50_s32_224','vit_large_r50_s32_384',
221 | 'crossvit_15_240','crossvit_15_dagger_240','crossvit_15_dagger_408',
222 | 'crossvit_18_240','crossvit_18_dagger_240','crossvit_18_dagger_408',
223 | 'crossvit_9_240','crossvit_9_dagger_240',
224 | 'crossvit_base_240','crossvit_small_240','crossvit_tiny_240',
225 | 'beit_base_patch16_224','beit_base_patch16_384',
226 | 'beit_large_patch16_224','beit_large_patch16_384','beit_large_patch16_512',
227 | 'coat_tiny','coat_mini','coat_lite_tiny','coat_lite_mini','coat_lite_small',
228 | 'cait_xxs24_224','cait_xxs24_384',
229 | 'cait_xxs36_224','cait_xxs36_384','cait_xs24_384',
230 | 'cait_s24_224','cait_s24_384','cait_s36_384','cait_m36_384','cait_m48_448',
231 | 'convit_tiny','convit_small','convit_base',
232 | 'levit_128s','levit_128','levit_192','levit_256','levit_384',
233 | 'mixer_b16_224','mixer_l16_224','mixer_b16_224_miil',
234 | 'gmixer_24_224',
235 | 'resmlp_12_224','resmlp_24_224','resmlp_36_224','resmlp_big_24_224',
236 | 'resmlp_12_distilled_224','resmlp_24_distilled_224',
237 | 'resmlp_36_distilled_224','resmlp_big_24_distilled_224',
238 | 'resmlp_big_24_224_in22ft1k',
239 | 'gmlp_s16_224',
240 | 'pit_ti_224','pit_xs_224','pit_s_224','pit_b_224',
241 | 'pit_ti_distilled_224','pit_xs_distilled_224','pit_s_distilled_224','pit_b_distilled_224',
242 | 'swin_base_patch4_window12_384','swin_base_patch4_window7_224',
243 | 'swin_large_patch4_window12_384','swin_large_patch4_window7_224',
244 | 'swin_small_patch4_window7_224','swin_tiny_patch4_window7_224',
245 | 'xcit_nano_12_p16_224','xcit_nano_12_p16_224_dist','xcit_nano_12_p16_384_dist',
246 | 'xcit_tiny_12_p16_224','xcit_tiny_12_p16_224_dist','xcit_tiny_12_p16_384_dist',
247 | 'xcit_tiny_24_p16_224','xcit_tiny_24_p16_224_dist','xcit_tiny_24_p16_384_dist',
248 | 'xcit_small_12_p16_224','xcit_small_12_p16_224_dist','xcit_small_12_p16_384_dist',
249 | 'xcit_small_24_p16_224','xcit_small_24_p16_224_dist','xcit_small_24_p16_384_dist',
250 | 'xcit_medium_24_p16_224','xcit_medium_24_p16_224_dist','xcit_medium_24_p16_384_dist',
251 | 'xcit_large_24_p16_224','xcit_large_24_p16_224_dist','xcit_large_24_p16_384_dist',
252 | 'xcit_nano_12_p8_224','xcit_nano_12_p8_224_dist','xcit_nano_12_p8_384_dist',
253 | 'xcit_tiny_12_p8_224','xcit_tiny_12_p8_224_dist','xcit_tiny_12_p8_384_dist',
254 | 'xcit_tiny_24_p8_224','xcit_tiny_24_p8_224_dist','xcit_tiny_24_p8_384_dist',
255 | 'xcit_small_12_p8_224','xcit_small_12_p8_224_dist','xcit_small_12_p8_384_dist',
256 | 'xcit_small_24_p8_224','xcit_small_24_p8_224_dist','xcit_small_24_p8_384_dist',
257 | 'xcit_medium_24_p8_224','xcit_medium_24_p8_224_dist','xcit_medium_24_p8_384_dist',
258 | 'xcit_large_24_p8_224','xcit_large_24_p8_224_dist','xcit_large_24_p8_384_dist',
259 | 'vgg11_bn','vgg13_bn','vgg16_bn','vgg19_bn',
260 | 'inception_resnet_v2','inception_v3','inception_v4',
261 | 'densenet121','densenetblur121d','tv_densenet121',
262 | 'densenet169','densenet201','densenet161',
263 | 'mobilenetv2_110d','mobilenetv2_120d',
264 | 'efficientnet_b1','efficientnet_b2','efficientnet_b3','efficientnet_b4',
265 | 'gernet_s','gernet_m','gernet_l',
266 | 'repvgg_a2','repvgg_b0','repvgg_b1','repvgg_b1g4',
267 | 'repvgg_b2','repvgg_b2g4','repvgg_b3','repvgg_b3g4',
268 | 'resnet51q','resnet61q',
269 | 'resnext26ts','gcresnext26ts','seresnext26ts','eca_resnext26ts','bat_resnext26ts',
270 | 'resnet32ts','resnet33ts','gcresnet33ts','seresnet33ts','eca_resnet33ts',
271 | 'gcresnet50t','gcresnext50ts',
272 | 'cspresnet50','cspresnext50','cspdarknet53',
273 | 'dla34','dla46_c','dla46x_c','dla60x_c','dla60','dla60x',
274 | 'dla102','dla102x','dla102x2','dla169',
275 | 'dla60_res2net','dla60_res2next',
276 | 'dpn68','dpn68b','dpn92','dpn98','dpn131','dpn107',
277 | 'mnasnet_100','semnasnet_100',
278 | 'mobilenetv2_100','mobilenetv2_110d','mobilenetv2_120d','mobilenetv2_140',
279 | 'fbnetc_100','spnasnet_100',
280 | 'efficientnet_b0','efficientnet_b1','efficientnet_b2','efficientnet_b3','efficientnet_b4',
281 | 'efficientnet_es','efficientnet_em','efficientnet_el',
282 | 'efficientnet_es_pruned','efficientnet_el_pruned','efficientnet_lite0',
283 | 'efficientnet_b1_pruned','efficientnet_b2_pruned','efficientnet_b3_pruned',
284 | 'efficientnetv2_rw_t','gc_efficientnetv2_rw_t','efficientnetv2_rw_s','efficientnetv2_rw_m',
285 | 'tf_efficientnet_b0','tf_efficientnet_b1','tf_efficientnet_b2','tf_efficientnet_b3',
286 | 'tf_efficientnet_b4','tf_efficientnet_b5','tf_efficientnet_b6','tf_efficientnet_b7','tf_efficientnet_b8',
287 | 'tf_efficientnet_b0_ap','tf_efficientnet_b1_ap','tf_efficientnet_b2_ap',
288 | 'tf_efficientnet_b3_ap','tf_efficientnet_b4_ap','tf_efficientnet_b5_ap',
289 | 'tf_efficientnet_b6_ap','tf_efficientnet_b7_ap','tf_efficientnet_b8_ap',
290 | 'tf_efficientnet_b0_ns','tf_efficientnet_b1_ns','tf_efficientnet_b2_ns',
291 | 'tf_efficientnet_b3_ns','tf_efficientnet_b4_ns','tf_efficientnet_b5_ns',
292 | 'tf_efficientnet_b6_ns','tf_efficientnet_b7_ns',
293 | 'tf_efficientnet_l2_ns_475','tf_efficientnet_l2_ns',
294 | 'tf_efficientnet_es','tf_efficientnet_em','tf_efficientnet_el',
295 | 'tf_efficientnet_cc_b0_4e','tf_efficientnet_cc_b0_8e','tf_efficientnet_cc_b1_8e',
296 | 'tf_efficientnet_lite0','tf_efficientnet_lite1','tf_efficientnet_lite2',
297 | 'tf_efficientnet_lite3','tf_efficientnet_lite4',
298 | 'tf_efficientnetv2_s','tf_efficientnetv2_m','tf_efficientnetv2_l',
299 | 'tf_efficientnetv2_s_in21ft1k','tf_efficientnetv2_m_in21ft1k',
300 | 'tf_efficientnetv2_l_in21ft1k','tf_efficientnetv2_xl_in21ft1k',
301 | 'tf_efficientnetv2_b0','tf_efficientnetv2_b1',
302 | 'tf_efficientnetv2_b2','tf_efficientnetv2_b3',
303 | 'mixnet_s','mixnet_m','mixnet_l','mixnet_xl','mixnet_xxl',
304 | 'tf_mixnet_s','tf_mixnet_m','tf_mixnet_l',
305 | 'ghostnet_100',
306 | 'gluon_resnet18_v1b','gluon_resnet34_v1b','gluon_resnet50_v1b',
307 | 'gluon_resnet101_v1b','gluon_resnet152_v1b',
308 | 'gluon_resnet50_v1c','gluon_resnet101_v1c','gluon_resnet152_v1c',
309 | 'gluon_resnet50_v1d','gluon_resnet101_v1d','gluon_resnet152_v1d',
310 | 'gluon_resnet50_v1s','gluon_resnet101_v1s','gluon_resnet152_v1s',
311 | 'gluon_resnext50_32x4d','gluon_resnext101_32x4d','gluon_resnext101_64x4d',
312 | 'gluon_seresnext50_32x4d','gluon_seresnext101_32x4d','gluon_seresnext101_64x4d',
313 | 'gluon_senet154',
314 | 'gluon_xception65',
315 | 'hrnet_w18_small','hrnet_w18_small_v2','hrnet_w18','hrnet_w30',
316 | 'hrnet_w32','hrnet_w40','hrnet_w44','hrnet_w48','hrnet_w64',
317 | 'inception_resnet_v2','ens_adv_inception_resnet_v2',
318 | 'inception_v3','tf_inception_v3', 'adv_inception_v3','gluon_inception_v3',
319 | 'inception_v4',
320 | 'mobilenetv3_large_100','mobilenetv3_large_100_miil','mobilenetv3_small_075','mobilenetv3_small_100',
321 | 'mobilenetv3_rw',
322 | 'tf_mobilenetv3_large_075','tf_mobilenetv3_large_100','tf_mobilenetv3_large_minimal_100',
323 | 'tf_mobilenetv3_small_075','tf_mobilenetv3_small_100','tf_mobilenetv3_small_minimal_100',
324 | 'nasnetalarge',
325 | 'jx_nest_base','jx_nest_small','jx_nest_tiny',
326 | 'dm_nfnet_f0','dm_nfnet_f1','dm_nfnet_f2','dm_nfnet_f3',
327 | 'dm_nfnet_f4','dm_nfnet_f5','dm_nfnet_f6',
328 | 'nfnet_l0','eca_nfnet_l0','eca_nfnet_l1','eca_nfnet_l2',
329 | 'nf_regnet_b1','nf_resnet50',
330 | 'pnasnet5large',
331 | 'regnetx_002','regnetx_004','regnetx_006','regnetx_008','regnetx_016',
332 | 'regnetx_032','regnetx_040','regnetx_064','regnetx_080','regnetx_120',
333 | 'regnetx_160','regnetx_320','regnety_002','regnety_004','regnety_006',
334 | 'regnety_008','regnety_016','regnety_032','regnety_040','regnety_064',
335 | 'regnety_080','regnety_120','regnety_160','regnety_320',
336 | 'res2net50_26w_4s','res2net50_48w_2s','res2net50_14w_8s',
337 | 'res2net50_26w_6s','res2net50_26w_8s','res2net101_26w_4s','res2next50',
338 | 'resnest14d','resnest26d','resnest50d','resnest101e',
339 | 'resnest200e','resnest269e','resnest50d_4s2x40d','resnest50d_1s4x24d',
340 | 'rexnet_100', 'rexnet_130', 'rexnet_150', 'rexnet_200',
341 | 'selecsls42b','selecsls60','selecsls60b',
342 | 'legacy_senet154','legacy_seresnet18','legacy_seresnet34','legacy_seresnet50',
343 | 'legacy_seresnet101','legacy_seresnet152','legacy_seresnext26_32x4d',
344 | 'legacy_seresnext50_32x4d','legacy_seresnext101_32x4d',
345 | 'skresnet18','skresnet34','skresnext50_32x4d',
346 | 'tnt_s_patch16_224',
347 | 'tresnet_m','tresnet_l','tresnet_xl',
348 | 'tresnet_m_448','tresnet_l_448','tresnet_xl_448',
349 | 'twins_pcpvt_small','twins_pcpvt_base','twins_pcpvt_large',
350 | 'twins_svt_small','twins_svt_base','twins_svt_large',
351 | 'vgg11','vgg13','vgg16','vgg19','vgg11_bn','vgg13_bn','vgg16_bn','vgg19_bn',
352 | 'ese_vovnet19b_dw','ese_vovnet39b',
353 | 'xception','xception41','xception65','xception71',
354 | ]},
355 | }
356 | }
357 |
--------------------------------------------------------------------------------
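The two dictionaries above follow the Weights & Biases sweep-config schema: a grid search over a single parameter, modelname, whose values enumerate pretrained timm models. A minimal sketch of registering and consuming one of these sweeps is given below; the project name, the train() body, and the call to timm.create_model are illustrative assumptions, not the repository's actual runner script.

    # Hypothetical usage sketch (assumed project name and train() body).
    import wandb
    import timm

    from sweep_configs.e2e_configs import core_models

    def train():
        run = wandb.init()
        # Each agent run draws one model name from the grid defined in core_models.
        model = timm.create_model(run.config.modelname, pretrained=True)
        model.eval()
        # ... evaluate the model's equivariance metrics here and log them with run.log(...) ...
        run.finish()

    sweep_id = wandb.sweep(core_models, project="lie-deriv-e2e")  # project name assumed
    wandb.agent(sweep_id, function=train)

The same pattern applies to all_models; only the sweep dict passed to wandb.sweep changes.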