├── .gitignore ├── README.md ├── alpaca_data.json ├── alpaca_embed.py ├── alpaca_sample.py ├── data ├── alpaca │ ├── alpaca_data_dq_k5_1k.json │ ├── alpaca_data_dq_k5_5k.json │ ├── alpaca_data_random_1k.json │ └── alpaca_data_random_5k.json └── cifar10 │ └── select_indices_CIFAR10_0.125.npy ├── dq ├── __init__.py ├── datasets │ ├── __init__.py │ ├── cifar10.py │ └── imagenet.py ├── methods │ ├── __init__.py │ ├── coresetmethod.py │ ├── earlytrain.py │ ├── methods_utils │ │ ├── __init__.py │ │ ├── cossim.py │ │ ├── euclidean.py │ │ ├── submodular_function.py │ │ └── submodular_optimizer.py │ ├── submodular.py │ └── uniform.py └── nets │ ├── __init__.py │ ├── nets_utils │ ├── __init__.py │ ├── parallel.py │ └── recorder.py │ ├── resnet.py │ └── vit.py ├── figs ├── effects.png └── pipeline-classification.png ├── mae_models.py ├── pretrained └── .gitkeep ├── pytorch_image_models ├── .gitattributes ├── .github │ ├── FUNDING.yml │ ├── ISSUE_TEMPLATE │ │ ├── bug_report.md │ │ ├── config.yml │ │ └── feature_request.md │ └── workflows │ │ └── tests.yml ├── .gitignore ├── LICENSE ├── MANIFEST.in ├── README.md ├── avg_checkpoints.py ├── benchmark.py ├── clean_checkpoint.py ├── convert │ ├── convert_from_mxnet.py │ └── convert_nest_flax.py ├── distributed_train.sh ├── docs │ ├── archived_changes.md │ ├── changes.md │ ├── feature_extraction.md │ ├── index.md │ ├── javascripts │ │ └── tables.js │ ├── models.md │ ├── models │ │ ├── .pages │ │ ├── .templates │ │ │ ├── code_snippets.md │ │ │ ├── generate_readmes.py │ │ │ └── models │ │ │ │ ├── adversarial-inception-v3.md │ │ │ │ ├── advprop.md │ │ │ │ ├── big-transfer.md │ │ │ │ ├── csp-darknet.md │ │ │ │ ├── csp-resnet.md │ │ │ │ ├── csp-resnext.md │ │ │ │ ├── densenet.md │ │ │ │ ├── dla.md │ │ │ │ ├── dpn.md │ │ │ │ ├── ecaresnet.md │ │ │ │ ├── efficientnet-pruned.md │ │ │ │ ├── efficientnet.md │ │ │ │ ├── ensemble-adversarial.md │ │ │ │ ├── ese-vovnet.md │ │ │ │ ├── fbnet.md │ │ │ │ ├── gloun-inception-v3.md │ │ │ │ ├── gloun-resnet.md │ │ │ │ ├── gloun-resnext.md │ │ │ │ ├── gloun-senet.md │ │ │ │ ├── gloun-seresnext.md │ │ │ │ ├── gloun-xception.md │ │ │ │ ├── hrnet.md │ │ │ │ ├── ig-resnext.md │ │ │ │ ├── inception-resnet-v2.md │ │ │ │ ├── inception-v3.md │ │ │ │ ├── inception-v4.md │ │ │ │ ├── legacy-se-resnet.md │ │ │ │ ├── legacy-se-resnext.md │ │ │ │ ├── legacy-senet.md │ │ │ │ ├── mixnet.md │ │ │ │ ├── mnasnet.md │ │ │ │ ├── mobilenet-v2.md │ │ │ │ ├── mobilenet-v3.md │ │ │ │ ├── nasnet.md │ │ │ │ ├── noisy-student.md │ │ │ │ ├── pnasnet.md │ │ │ │ ├── regnetx.md │ │ │ │ ├── regnety.md │ │ │ │ ├── res2net.md │ │ │ │ ├── res2next.md │ │ │ │ ├── resnest.md │ │ │ │ ├── resnet-d.md │ │ │ │ ├── resnet.md │ │ │ │ ├── resnext.md │ │ │ │ ├── rexnet.md │ │ │ │ ├── se-resnet.md │ │ │ │ ├── selecsls.md │ │ │ │ ├── seresnext.md │ │ │ │ ├── skresnet.md │ │ │ │ ├── skresnext.md │ │ │ │ ├── spnasnet.md │ │ │ │ ├── ssl-resnet.md │ │ │ │ ├── ssl-resnext.md │ │ │ │ ├── swsl-resnet.md │ │ │ │ ├── swsl-resnext.md │ │ │ │ ├── tf-efficientnet-condconv.md │ │ │ │ ├── tf-efficientnet-lite.md │ │ │ │ ├── tf-efficientnet.md │ │ │ │ ├── tf-inception-v3.md │ │ │ │ ├── tf-mixnet.md │ │ │ │ ├── tf-mobilenet-v3.md │ │ │ │ ├── tresnet.md │ │ │ │ ├── vision-transformer.md │ │ │ │ ├── wide-resnet.md │ │ │ │ └── xception.md │ │ ├── adversarial-inception-v3.md │ │ ├── advprop.md │ │ ├── big-transfer.md │ │ ├── csp-darknet.md │ │ ├── csp-resnet.md │ │ ├── csp-resnext.md │ │ ├── densenet.md │ │ ├── dla.md │ │ ├── dpn.md │ │ ├── ecaresnet.md │ │ ├── efficientnet-pruned.md │ │ ├── 
efficientnet.md │ │ ├── ensemble-adversarial.md │ │ ├── ese-vovnet.md │ │ ├── fbnet.md │ │ ├── gloun-inception-v3.md │ │ ├── gloun-resnet.md │ │ ├── gloun-resnext.md │ │ ├── gloun-senet.md │ │ ├── gloun-seresnext.md │ │ ├── gloun-xception.md │ │ ├── hrnet.md │ │ ├── ig-resnext.md │ │ ├── inception-resnet-v2.md │ │ ├── inception-v3.md │ │ ├── inception-v4.md │ │ ├── legacy-se-resnet.md │ │ ├── legacy-se-resnext.md │ │ ├── legacy-senet.md │ │ ├── mixnet.md │ │ ├── mnasnet.md │ │ ├── mobilenet-v2.md │ │ ├── mobilenet-v3.md │ │ ├── nasnet.md │ │ ├── noisy-student.md │ │ ├── pnasnet.md │ │ ├── regnetx.md │ │ ├── regnety.md │ │ ├── res2net.md │ │ ├── res2next.md │ │ ├── resnest.md │ │ ├── resnet-d.md │ │ ├── resnet.md │ │ ├── resnext.md │ │ ├── rexnet.md │ │ ├── se-resnet.md │ │ ├── selecsls.md │ │ ├── seresnext.md │ │ ├── skresnet.md │ │ ├── skresnext.md │ │ ├── spnasnet.md │ │ ├── ssl-resnet.md │ │ ├── ssl-resnext.md │ │ ├── swsl-resnet.md │ │ ├── swsl-resnext.md │ │ ├── tf-efficientnet-condconv.md │ │ ├── tf-efficientnet-lite.md │ │ ├── tf-efficientnet.md │ │ ├── tf-inception-v3.md │ │ ├── tf-mixnet.md │ │ ├── tf-mobilenet-v3.md │ │ ├── tresnet.md │ │ ├── vision-transformer.md │ │ ├── wide-resnet.md │ │ └── xception.md │ ├── results.md │ ├── scripts.md │ └── training_hparam_examples.md ├── hubconf.py ├── inference.py ├── mkdocs.yml ├── model-index.yml ├── requirements-docs.txt ├── requirements-modelindex.txt ├── requirements.txt ├── results │ ├── README.md │ ├── benchmark-infer-amp-nchw-pt110-cu113-rtx3090.csv │ ├── benchmark-infer-amp-nchw-pt111-cu113-rtx3090.csv │ ├── benchmark-infer-amp-nhwc-pt110-cu113-rtx3090.csv │ ├── benchmark-infer-amp-nhwc-pt111-cu113-rtx3090.csv │ ├── benchmark-train-amp-nchw-pt110-cu113-rtx3090.csv │ ├── benchmark-train-amp-nchw-pt111-cu113-rtx3090.csv │ ├── benchmark-train-amp-nhwc-pt110-cu113-rtx3090.csv │ ├── benchmark-train-amp-nhwc-pt111-cu113-rtx3090.csv │ ├── generate_csv_results.py │ ├── imagenet21k_goog_synsets.txt │ ├── imagenet_a_indices.txt │ ├── imagenet_a_synsets.txt │ ├── imagenet_r_indices.txt │ ├── imagenet_r_synsets.txt │ ├── imagenet_real_labels.json │ ├── imagenet_synsets.txt │ ├── model_metadata-in1k.csv │ ├── results-imagenet-a-clean.csv │ ├── results-imagenet-a.csv │ ├── results-imagenet-r-clean.csv │ ├── results-imagenet-r.csv │ ├── results-imagenet-real.csv │ ├── results-imagenet.csv │ ├── results-imagenetv2-matched-frequency.csv │ └── results-sketch.csv ├── setup.cfg ├── setup.py ├── tests │ ├── __init__.py │ ├── test_layers.py │ ├── test_models.py │ ├── test_optim.py │ └── test_utils.py ├── timm │ ├── __init__.py │ ├── data │ │ ├── __init__.py │ │ ├── auto_augment.py │ │ ├── config.py │ │ ├── constants.py │ │ ├── dataset.py │ │ ├── dataset_factory.py │ │ ├── distributed_sampler.py │ │ ├── loader.py │ │ ├── mixup.py │ │ ├── parsers │ │ │ ├── __init__.py │ │ │ ├── class_map.py │ │ │ ├── img_extensions.py │ │ │ ├── parser.py │ │ │ ├── parser_factory.py │ │ │ ├── parser_image_folder.py │ │ │ ├── parser_image_in_tar.py │ │ │ ├── parser_image_tar.py │ │ │ └── parser_tfds.py │ │ ├── random_erasing.py │ │ ├── real_labels.py │ │ ├── tf_preprocessing.py │ │ ├── transforms.py │ │ └── transforms_factory.py │ ├── loss │ │ ├── __init__.py │ │ ├── asymmetric_loss.py │ │ ├── binary_cross_entropy.py │ │ ├── cross_entropy.py │ │ └── jsd.py │ ├── models │ │ ├── __init__.py │ │ ├── beit.py │ │ ├── byoanet.py │ │ ├── byobnet.py │ │ ├── cait.py │ │ ├── coat.py │ │ ├── convit.py │ │ ├── convmixer.py │ │ ├── convnext.py │ │ ├── crossvit.py │ │ ├── cspnet.py │ │ 
├── deit.py │ │ ├── densenet.py │ │ ├── dla.py │ │ ├── dpn.py │ │ ├── edgenext.py │ │ ├── efficientnet.py │ │ ├── efficientnet_blocks.py │ │ ├── efficientnet_builder.py │ │ ├── factory.py │ │ ├── features.py │ │ ├── fx_features.py │ │ ├── ghostnet.py │ │ ├── gluon_resnet.py │ │ ├── gluon_xception.py │ │ ├── hardcorenas.py │ │ ├── helpers.py │ │ ├── hrnet.py │ │ ├── hub.py │ │ ├── inception_resnet_v2.py │ │ ├── inception_v3.py │ │ ├── inception_v4.py │ │ ├── layers │ │ │ ├── __init__.py │ │ │ ├── activations.py │ │ │ ├── activations_jit.py │ │ │ ├── activations_me.py │ │ │ ├── adaptive_avgmax_pool.py │ │ │ ├── attention_pool2d.py │ │ │ ├── blur_pool.py │ │ │ ├── bottleneck_attn.py │ │ │ ├── cbam.py │ │ │ ├── classifier.py │ │ │ ├── cond_conv2d.py │ │ │ ├── config.py │ │ │ ├── conv2d_same.py │ │ │ ├── conv_bn_act.py │ │ │ ├── create_act.py │ │ │ ├── create_attn.py │ │ │ ├── create_conv2d.py │ │ │ ├── create_norm_act.py │ │ │ ├── drop.py │ │ │ ├── eca.py │ │ │ ├── evo_norm.py │ │ │ ├── filter_response_norm.py │ │ │ ├── gather_excite.py │ │ │ ├── global_context.py │ │ │ ├── halo_attn.py │ │ │ ├── helpers.py │ │ │ ├── inplace_abn.py │ │ │ ├── lambda_layer.py │ │ │ ├── linear.py │ │ │ ├── median_pool.py │ │ │ ├── mixed_conv2d.py │ │ │ ├── ml_decoder.py │ │ │ ├── mlp.py │ │ │ ├── non_local_attn.py │ │ │ ├── norm.py │ │ │ ├── norm_act.py │ │ │ ├── padding.py │ │ │ ├── patch_embed.py │ │ │ ├── pool2d_same.py │ │ │ ├── pos_embed.py │ │ │ ├── selective_kernel.py │ │ │ ├── separable_conv.py │ │ │ ├── space_to_depth.py │ │ │ ├── split_attn.py │ │ │ ├── split_batchnorm.py │ │ │ ├── squeeze_excite.py │ │ │ ├── std_conv.py │ │ │ ├── test_time_pool.py │ │ │ ├── trace_utils.py │ │ │ └── weight_init.py │ │ ├── levit.py │ │ ├── mlp_mixer.py │ │ ├── mobilenetv3.py │ │ ├── mobilevit.py │ │ ├── nasnet.py │ │ ├── nest.py │ │ ├── nfnet.py │ │ ├── pit.py │ │ ├── pnasnet.py │ │ ├── poolformer.py │ │ ├── pruned │ │ │ ├── ecaresnet101d_pruned.txt │ │ │ ├── ecaresnet50d_pruned.txt │ │ │ ├── efficientnet_b1_pruned.txt │ │ │ ├── efficientnet_b2_pruned.txt │ │ │ └── efficientnet_b3_pruned.txt │ │ ├── registry.py │ │ ├── regnet.py │ │ ├── res2net.py │ │ ├── resnest.py │ │ ├── resnet.py │ │ ├── resnetv2.py │ │ ├── rexnet.py │ │ ├── selecsls.py │ │ ├── senet.py │ │ ├── sequencer.py │ │ ├── sknet.py │ │ ├── swin_transformer.py │ │ ├── swin_transformer_v2.py │ │ ├── swin_transformer_v2_cr.py │ │ ├── tnt.py │ │ ├── tresnet.py │ │ ├── twins.py │ │ ├── vgg.py │ │ ├── visformer.py │ │ ├── vision_transformer.py │ │ ├── vision_transformer_hybrid.py │ │ ├── vision_transformer_relpos.py │ │ ├── volo.py │ │ ├── vovnet.py │ │ ├── xception.py │ │ ├── xception_aligned.py │ │ └── xcit.py │ ├── optim │ │ ├── __init__.py │ │ ├── adabelief.py │ │ ├── adafactor.py │ │ ├── adahessian.py │ │ ├── adamp.py │ │ ├── adamw.py │ │ ├── lamb.py │ │ ├── lars.py │ │ ├── lookahead.py │ │ ├── madgrad.py │ │ ├── nadam.py │ │ ├── nvnovograd.py │ │ ├── optim_factory.py │ │ ├── radam.py │ │ ├── rmsprop_tf.py │ │ └── sgdp.py │ ├── scheduler │ │ ├── __init__.py │ │ ├── cosine_lr.py │ │ ├── multistep_lr.py │ │ ├── plateau_lr.py │ │ ├── poly_lr.py │ │ ├── scheduler.py │ │ ├── scheduler_factory.py │ │ ├── step_lr.py │ │ └── tanh_lr.py │ ├── utils │ │ ├── __init__.py │ │ ├── agc.py │ │ ├── checkpoint_saver.py │ │ ├── clip_grad.py │ │ ├── cuda.py │ │ ├── distributed.py │ │ ├── jit.py │ │ ├── log.py │ │ ├── metrics.py │ │ ├── misc.py │ │ ├── model.py │ │ ├── model_ema.py │ │ ├── random.py │ │ └── summary.py │ └── version.py ├── train.py └── validate.py ├── 
quantize_bin.py ├── quantize_pixel.py ├── quantize_sample.py ├── requirements.txt ├── util ├── __init__.py ├── pos_embed.py └── utils.py └── validate_cifar.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Customize 2 | *.pth 3 | 4 | # OS X 5 | .DS_Store 6 | .Spotlight-V100 7 | .Trashes 8 | ._* 9 | 10 | # Byte-compiled / optimized / DLL files 11 | __pycache__/ 12 | *.py[cod] 13 | *$py.class 14 | 15 | # C extensions 16 | *.so 17 | 18 | # Distribution / packaging 19 | .Python 20 | build/ 21 | develop-eggs/ 22 | dist/ 23 | downloads/ 24 | eggs/ 25 | .eggs/ 26 | lib/ 27 | lib64/ 28 | parts/ 29 | sdist/ 30 | var/ 31 | wheels/ 32 | *.egg-info/ 33 | .installed.cfg 34 | *.egg 35 | MANIFEST 36 | 37 | # PyInstaller 38 | # Usually these files are written by a python script from a template 39 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 40 | *.manifest 41 | *.spec 42 | 43 | # Installer logs 44 | pip-log.txt 45 | pip-delete-this-directory.txt 46 | 47 | # Unit test / coverage reports 48 | htmlcov/ 49 | .tox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | .hypothesis/ 57 | .pytest_cache/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | .static_storage/ 66 | .media/ 67 | local_settings.py 68 | 69 | # Flask stuff: 70 | instance/ 71 | .webassets-cache 72 | 73 | # Scrapy stuff: 74 | .scrapy 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyBuilder 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # pyenv 86 | .python-version 87 | 88 | # celery beat schedule file 89 | celerybeat-schedule 90 | 91 | # SageMath parsed files 92 | *.sage.py 93 | 94 | # Environments 95 | .env 96 | .venv 97 | env/ 98 | venv/ 99 | ENV/ 100 | env.bak/ 101 | venv.bak/ 102 | 103 | # Spyder project settings 104 | .spyderproject 105 | .spyproject 106 | 107 | # Rope project settings 108 | .ropeproject 109 | 110 | # mkdocs documentation 111 | /site 112 | 113 | # mypy 114 | .mypy_cache/ 115 | -------------------------------------------------------------------------------- /alpaca_embed.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import re 3 | import string 4 | 5 | import numpy as np 6 | import openai 7 | import util.utils as utils 8 | from datasketch import MinHash 9 | from nltk import ngrams 10 | from scipy import spatial 11 | from tqdm import tqdm 12 | 13 | # OPENAI embeddings 14 | 15 | 16 | def get_embedding(text, model="text-embedding-ada-002"): 17 | text = text.replace("\n", " ") 18 | return openai.Embedding.create(input=[text], model=model)["data"][0]["embedding"] 19 | 20 | 21 | cos_dist = spatial.distance.cosine 22 | 23 | 24 | # MinHash embeddings 25 | # See: https://github.com/Cerebras/modelzoo/blob/main/modelzoo/transformers/data_processing/slimpajama/dedup/to_hash.py 26 | 27 | 28 | def get_features(s, width): 29 | # lower cased 30 | s = s.lower() 31 | # remove punctuation 32 | s = s.translate(str.maketrans("", "", string.punctuation)) 33 | # remove consecutive spaces, newlines, tabs in the middle and in the beginning / end 34 | s = re.sub(r"\s+", " ", s.strip()) 35 | return map(lambda x: "".join(x), ngrams(s, width)) 36 | 37 | 38 | def get_hash(text, width=6, num_perm=128): 39 | m = MinHash(num_perm) 40 | for x in get_features(text, width): 41 | m.update(x.encode("utf8")) 42 | return m 43 | 44 | 45 | def hash_dist(m1, m2): 46 | return m1.jaccard(m2) 
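# --- Editor's note (illustrative, not part of the original file): a minimal
# usage sketch of the two similarity measures defined above, using made-up
# example strings. Despite its name, hash_dist() returns the estimated
# Jaccard *similarity* of two MinHash signatures (higher = more similar),
# whereas cos_dist() is a cosine *distance* (lower = more similar).
#
#   m1 = get_hash("Install PyTorch with pip.")
#   m2 = get_hash("Install PyTorch using pip.")
#   sim = hash_dist(m1, m2)            # approximate Jaccard similarity in [0, 1]
#
#   e1 = np.array(get_embedding("Install PyTorch with pip."))   # requires openai.api_key to be set
#   e2 = np.array(get_embedding("Install PyTorch using pip."))
#   dist = cos_dist(e1, e2)            # cosine distance in [0, 2]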
47 | 48 | 49 | # Dataset 50 | 51 | 52 | def get_text(x): 53 | if x["input"] == "": 54 | return x["instruction"] + " " + x["output"] 55 | else: 56 | return x["instruction"] + " " + x["input"] + " " + x["output"] 57 | 58 | def merge_data(num_split=6): 59 | data = [] 60 | for i in range(num_split): 61 | npy = np.load(f"alpaca_embeds_{i}.npy") 62 | data.append(npy) 63 | data = np.vstack(data) 64 | np.save("alpaca_embeds.npy", data) 65 | 66 | 67 | if __name__ == "__main__": 68 | parser = argparse.ArgumentParser() 69 | parser.add_argument("--index", type=int, default=0) 70 | parser.add_argument("--nums", type=int, default=10000) 71 | parser.add_argument("--num_split", type=int, default=6) 72 | parser.add_argument("--merge", action="store_true") 73 | args = parser.parse_args() 74 | 75 | if args.merge: 76 | merge_data(args.num_split) 77 | exit() 78 | 79 | data = utils.jload("alpaca_data.json") 80 | index = args.index 81 | nums = args.nums 82 | data = data[nums * index : nums * (index + 1)] 83 | print(f"Processing {len(data)} examples from index {index * nums}") 84 | 85 | embeds = [] 86 | for i in tqdm(range(len(data))): 87 | text = get_text(data[i]) 88 | embed = get_embedding(text) 89 | embeds.append(np.array(embed)) 90 | embeds = np.vstack(embeds) 91 | 92 | np.save(f"alpaca_embeds_{index}.npy", embeds) 93 | -------------------------------------------------------------------------------- /alpaca_sample.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import numpy as np 4 | import util.utils as utils 5 | from dq.methods.methods_utils.submodular_function import GraphCut 6 | from dq.methods.methods_utils.submodular_optimizer import NaiveGreedy 7 | 8 | 9 | def random_sample(data, n=1000): 10 | indices = np.random.choice(len(data), n, replace=False) 11 | return [data[i] for i in indices] 12 | 13 | 14 | def dataset_quantization(data, ratio=0.02, k=50): 15 | n = int(len(data) * ratio) 16 | bins_n = int(n / k) 17 | budget_n = len(data) // k 18 | print(f"total: {len(data)} n: {n}, k: {k}, budget_n: {budget_n}, bins_n: {bins_n}") 19 | 20 | embeddings_original = np.load("alpaca_embeds.npy") 21 | embeddings = embeddings_original.copy() 22 | indices_original = np.arange(len(data)) 23 | indices = indices_original.copy() 24 | 25 | sim_matrix = lambda a, b: embeddings[a] @ embeddings[b].T 26 | 27 | # bin generation 28 | bins = [] 29 | for i in range(k): 30 | print(f"bin {i}/{k}") 31 | submod_f = GraphCut(index=indices, similarity_kernel=sim_matrix) 32 | submod_opt = NaiveGreedy(args=None, index=indices, budget=budget_n) 33 | result_indices = submod_opt.select( 34 | gain_function=submod_f.calc_gain, 35 | update_state=submod_f.update_state, 36 | ) 37 | 38 | bins.append(result_indices) 39 | indices = np.delete(indices_original, np.concatenate(bins)) 40 | embeddings = np.delete(embeddings_original, np.concatenate(bins), axis=0) 41 | 42 | # bin sampling 43 | index = [] 44 | assert len(bins) == k 45 | for i in range(k): 46 | sampled_indices = random_sample(bins[i], n=bins_n) 47 | index.extend(sampled_indices) 48 | data = [data[i] for i in index] 49 | print(f"sampled: {len(data)} examples") 50 | 51 | return data 52 | 53 | 54 | if __name__ == "__main__": 55 | parser = argparse.ArgumentParser() 56 | parser.add_argument("--random", action="store_true") 57 | parser.add_argument("--ratio", type=float, default=0.1) 58 | parser.add_argument("--k", type=int, default=10) 59 | args = parser.parse_args() 60 | 61 | data = utils.jload("alpaca_data.json") 62 | total_len 
= len(data) 63 | 64 | # random sample 65 | if args.random: 66 | n = int(total_len * args.ratio) 67 | data = random_sample(data, n=n) 68 | print(f"Random sample: {len(data)} examples") 69 | utils.jdump(data, f"alpaca_data_random.json") 70 | # DQ 71 | else: 72 | k = args.k 73 | ratio = args.ratio 74 | data = dataset_quantization(data, ratio=ratio, k=k) 75 | print(f"DQ: {len(data)} examples") 76 | utils.jdump(data, f"alpaca_data_dq_k{k}.json") 77 | -------------------------------------------------------------------------------- /data/cifar10/select_indices_CIFAR10_0.125.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/magic-research/Dataset_Quantization/24a2a77dd80c2cc86b53e2b2e204f1e02804d70e/data/cifar10/select_indices_CIFAR10_0.125.npy -------------------------------------------------------------------------------- /dq/__init__.py: -------------------------------------------------------------------------------- 1 | # __init__.py -------------------------------------------------------------------------------- /dq/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .cifar10 import * 2 | from .imagenet import * 3 | -------------------------------------------------------------------------------- /dq/datasets/cifar10.py: -------------------------------------------------------------------------------- 1 | from torchvision import datasets, transforms 2 | from torch import tensor, long 3 | 4 | 5 | def CIFAR10(data_path): 6 | channel = 3 7 | im_size = (32, 32) 8 | num_classes = 10 9 | mean = [0.4914, 0.4822, 0.4465] 10 | std = [0.2470, 0.2435, 0.2616] 11 | 12 | transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)]) 13 | dst_train = datasets.CIFAR10(data_path, train=True, download=True, transform=transform) 14 | dst_test = datasets.CIFAR10(data_path, train=False, download=True, transform=transform) 15 | class_names = dst_train.classes 16 | dst_train.targets = tensor(dst_train.targets, dtype=long) 17 | dst_test.targets = tensor(dst_test.targets, dtype=long) 18 | return channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test 19 | -------------------------------------------------------------------------------- /dq/datasets/imagenet.py: -------------------------------------------------------------------------------- 1 | from torchvision import datasets, transforms 2 | from torch import tensor, long 3 | import os 4 | 5 | 6 | def ImageNet(data_path): 7 | channel = 3 8 | im_size = (224, 224) 9 | num_classes = 1000 10 | mean = [0.485, 0.456, 0.406] 11 | std = [0.229, 0.224, 0.225] 12 | normalize = transforms.Normalize(mean, std) 13 | # dst_train = datasets.ImageNet(data_path, split="train", transform=transforms.Compose([ 14 | # transforms.Resize(256), 15 | # transforms.CenterCrop(224), 16 | # transforms.ToTensor(), 17 | # normalize, 18 | # ])) 19 | # dst_test = datasets.ImageNet(data_path, split="val", transform=transforms.Compose([ 20 | # transforms.Resize(256), 21 | # transforms.CenterCrop(224), 22 | # transforms.ToTensor(), 23 | # normalize, 24 | # ])) 25 | dst_train = datasets.ImageFolder(root=os.path.join(data_path, 'train'), transform=transforms.Compose([ 26 | transforms.Resize(256), 27 | transforms.CenterCrop(224), 28 | transforms.ToTensor(), 29 | normalize, 30 | ])) 31 | dst_test = datasets.ImageFolder(root=os.path.join(data_path, 'val'), transform=transforms.Compose([ 32 | transforms.Resize(256), 33 | 
transforms.CenterCrop(224), 34 | transforms.ToTensor(), 35 | normalize, 36 | ])) 37 | class_names = dst_train.classes 38 | dst_train.targets = tensor(dst_train.targets, dtype=long) 39 | dst_test.targets = tensor(dst_test.targets, dtype=long) 40 | return channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test 41 | -------------------------------------------------------------------------------- /dq/methods/__init__.py: -------------------------------------------------------------------------------- 1 | from .coresetmethod import * 2 | from .earlytrain import * 3 | from .submodular import * 4 | from .uniform import * 5 | 6 | -------------------------------------------------------------------------------- /dq/methods/coresetmethod.py: -------------------------------------------------------------------------------- 1 | class CoresetMethod(object): 2 | def __init__(self, dst_train, args, fraction, random_seed=None, **kwargs): 3 | self.dst_train = dst_train 4 | self.num_classes = len(dst_train.dataset.classes) 5 | self.fraction = fraction 6 | if fraction < 0 or fraction > 1: 7 | raise ValueError("Illegal Coreset Size.") 8 | self.random_seed = random_seed 9 | self.index = [] 10 | self.args = args 11 | 12 | self.n_train = len(dst_train) 13 | self.coreset_size = round(fraction * self.n_train) 14 | 15 | def select(self, **kwargs): 16 | return 17 | 18 | -------------------------------------------------------------------------------- /dq/methods/methods_utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .euclidean import * 2 | from .cossim import * 3 | from .submodular_function import * 4 | from .submodular_optimizer import * 5 | -------------------------------------------------------------------------------- /dq/methods/methods_utils/cossim.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def cossim_np(v1, v2): 6 | num = np.dot(v1, v2.T) 7 | denom = np.linalg.norm(v1, axis=1).reshape(-1, 1) * np.linalg.norm(v2, axis=1) 8 | res = num / denom 9 | res[np.isneginf(res)] = 0. 10 | return 0.5 + 0.5 * res 11 | 12 | def cossim_pair_np(v1): 13 | num = np.dot(v1, v1.T) 14 | norm = np.linalg.norm(v1, axis=1) 15 | denom = norm.reshape(-1, 1) * norm 16 | res = num / denom 17 | res[np.isneginf(res)] = 0. 18 | return 0.5 + 0.5 * res 19 | 20 | def cossim(v1, v2): 21 | num = torch.matmul(v1, v2.T) 22 | denom = torch.norm(v1, dim=1).view(-1, 1) * torch.norm(v2, dim=1) 23 | res = num / denom 24 | res[torch.isneginf(res)] = 0. 25 | return 0.5 + 0.5 * res 26 | 27 | def cossim_pair(v1): 28 | num = torch.matmul(v1, v1.T) 29 | norm = torch.norm(v1, dim=1) 30 | denom = norm.view(-1, 1) * norm 31 | res = num / denom 32 | res[torch.isneginf(res)] = 0. 
33 | return 0.5 + 0.5 * res -------------------------------------------------------------------------------- /dq/methods/methods_utils/euclidean.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def euclidean_dist(x, y): 6 | m, n = x.size(0), y.size(0) 7 | xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n) 8 | yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t() 9 | dist = xx + yy 10 | dist.addmm_(1, -2, x, y.t()) 11 | dist = dist.clamp(min=1e-12).sqrt() 12 | return dist 13 | 14 | 15 | def euclidean_dist_pair(x): 16 | m = x.size(0) 17 | xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, m) 18 | dist = xx + xx.t() 19 | dist.addmm_(1, -2, x, x.t()) 20 | dist = dist.clamp(min=1e-12).sqrt() 21 | return dist 22 | 23 | def euclidean_dist_np(x, y): 24 | (rowx, colx) = x.shape 25 | (rowy, coly) = y.shape 26 | xy = np.dot(x, y.T) 27 | x2 = np.repeat(np.reshape(np.sum(np.multiply(x, x), axis=1), (rowx, 1)), repeats=rowy, axis=1) 28 | y2 = np.repeat(np.reshape(np.sum(np.multiply(y, y), axis=1), (rowy, 1)), repeats=rowx, axis=1).T 29 | return np.sqrt(np.clip(x2 + y2 - 2. * xy, 1e-12, None)) 30 | 31 | def euclidean_dist_pair_np(x): 32 | (rowx, colx) = x.shape 33 | xy = np.dot(x, x.T) 34 | x2 = np.repeat(np.reshape(np.sum(np.multiply(x, x), axis=1), (rowx, 1)), repeats=rowx, axis=1) 35 | return np.sqrt(np.clip(x2 + x2.T - 2. * xy, 1e-12, None)) 36 | -------------------------------------------------------------------------------- /dq/methods/uniform.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from .coresetmethod import CoresetMethod 3 | 4 | 5 | class Uniform(CoresetMethod): 6 | def __init__(self, dst_train, args, fraction=0.5, random_seed=None, balance=False, replace=False, **kwargs): 7 | super().__init__(dst_train, args, fraction, random_seed) 8 | self.balance = balance 9 | self.replace = replace 10 | self.n_train = len(dst_train) 11 | 12 | def select_balance(self): 13 | """The same sampling proportions were used in each class separately.""" 14 | np.random.seed(self.random_seed) 15 | self.index = np.array([], dtype=np.int64) 16 | all_index = np.arange(self.n_train) 17 | for c in range(self.num_classes): 18 | c_index = (self.dst_train.dataset.targets[self.dst_train.indices] == c) 19 | self.index = np.append(self.index, 20 | np.random.choice(all_index[c_index], round(self.fraction * c_index.sum().item()), 21 | replace=self.replace)) 22 | return self.index 23 | 24 | def select_no_balance(self): 25 | np.random.seed(self.random_seed) 26 | self.index = np.random.choice(np.arange(self.n_train), round(self.n_train * self.fraction), 27 | replace=self.replace) 28 | 29 | return self.index 30 | 31 | def select(self, **kwargs): 32 | return {"indices": self.select_balance() if self.balance else self.select_no_balance()} 33 | -------------------------------------------------------------------------------- /dq/nets/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet import * 2 | from .vit import * 3 | -------------------------------------------------------------------------------- /dq/nets/nets_utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .parallel import * 2 | from .recorder import * -------------------------------------------------------------------------------- /dq/nets/nets_utils/parallel.py: 
-------------------------------------------------------------------------------- 1 | from torch.nn import DataParallel 2 | 3 | 4 | class MyDataParallel(DataParallel): 5 | def __getattr__(self, name): 6 | try: 7 | return super().__getattr__(name) 8 | except AttributeError: 9 | return getattr(self.module, name) 10 | def __setattr__(self, name, value): 11 | try: 12 | if name == "no_grad": 13 | return setattr(self.module, name, value) 14 | return super().__setattr__(name, value) 15 | except AttributeError: 16 | return setattr(self.module, name, value) 17 | -------------------------------------------------------------------------------- /dq/nets/nets_utils/recorder.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | class EmbeddingRecorder(nn.Module): 5 | def __init__(self, record_embedding: bool = False): 6 | super().__init__() 7 | self.record_embedding = record_embedding 8 | 9 | def forward(self, x): 10 | if self.record_embedding: 11 | self.embedding = x 12 | return x 13 | 14 | def __enter__(self): 15 | self.record_embedding = True 16 | 17 | def __exit__(self, exc_type, exc_val, exc_tb): 18 | self.record_embedding = False 19 | -------------------------------------------------------------------------------- /figs/effects.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/magic-research/Dataset_Quantization/24a2a77dd80c2cc86b53e2b2e204f1e02804d70e/figs/effects.png -------------------------------------------------------------------------------- /figs/pipeline-classification.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/magic-research/Dataset_Quantization/24a2a77dd80c2cc86b53e2b2e204f1e02804d70e/figs/pipeline-classification.png -------------------------------------------------------------------------------- /pretrained/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/magic-research/Dataset_Quantization/24a2a77dd80c2cc86b53e2b2e204f1e02804d70e/pretrained/.gitkeep -------------------------------------------------------------------------------- /pytorch_image_models/.gitattributes: -------------------------------------------------------------------------------- 1 | *.ipynb linguist-documentation 2 | -------------------------------------------------------------------------------- /pytorch_image_models/.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | github: rwightman 3 | -------------------------------------------------------------------------------- /pytorch_image_models/.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a bug report to help us improve. Issues are for reporting bugs or requesting 4 | features, the discussion forum is available for asking questions or seeking help 5 | from the community. 6 | title: "[BUG] Issue title..." 7 | labels: bug 8 | assignees: rwightman 9 | 10 | --- 11 | 12 | **Describe the bug** 13 | A clear and concise description of what the bug is. 14 | 15 | **To Reproduce** 16 | Steps to reproduce the behavior: 17 | 1. 18 | 2. 19 | 20 | **Expected behavior** 21 | A clear and concise description of what you expected to happen. 
22 | 23 | **Screenshots** 24 | If applicable, add screenshots to help explain your problem. 25 | 26 | **Desktop (please complete the following information):** 27 | - OS: [e.g. Windows 10, Ubuntu 18.04] 28 | - This repository version [e.g. pip 0.3.1 or commit ref] 29 | - PyTorch version w/ CUDA/cuDNN [e.g. from `conda list`, 1.7.0 py3.8_cuda11.0.221_cudnn8.0.3_0] 30 | 31 | **Additional context** 32 | Add any other context about the problem here. 33 | -------------------------------------------------------------------------------- /pytorch_image_models/.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: false 2 | contact_links: 3 | - name: Community Discussions 4 | url: https://github.com/rwightman/pytorch-image-models/discussions 5 | about: Issues are for features and bugs. Questions can be asked in Discussions. 6 | -------------------------------------------------------------------------------- /pytorch_image_models/.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project. Issues are for reporting bugs or requesting 4 | features, the discussion forum is available for asking questions or seeking help 5 | from the community. 6 | title: "[FEATURE] Feature title..." 7 | labels: enhancement 8 | assignees: '' 9 | 10 | --- 11 | 12 | **Is your feature request related to a problem? Please describe.** 13 | A clear and concise description of what the problem is. 14 | 15 | **Describe the solution you'd like** 16 | A clear and concise description of what you want to happen. 17 | 18 | **Describe alternatives you've considered** 19 | A clear and concise description of any alternative solutions or features you've considered. 20 | 21 | **Additional context** 22 | Add any other context or screenshots about the feature request here. 
23 | -------------------------------------------------------------------------------- /pytorch_image_models/.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: Python tests 2 | 3 | on: 4 | push: 5 | branches: [ master ] 6 | pull_request: 7 | branches: [ master ] 8 | 9 | env: 10 | OMP_NUM_THREADS: 2 11 | MKL_NUM_THREADS: 2 12 | 13 | jobs: 14 | test: 15 | name: Run tests on ${{ matrix.os }} with Python ${{ matrix.python }} 16 | strategy: 17 | matrix: 18 | os: [ubuntu-latest, macOS-latest] 19 | python: ['3.9'] 20 | torch: ['1.10.0'] 21 | torchvision: ['0.11.1'] 22 | runs-on: ${{ matrix.os }} 23 | 24 | steps: 25 | - uses: actions/checkout@v2 26 | - name: Set up Python ${{ matrix.python }} 27 | uses: actions/setup-python@v1 28 | with: 29 | python-version: ${{ matrix.python }} 30 | - name: Install testing dependencies 31 | run: | 32 | python -m pip install --upgrade pip 33 | pip install pytest pytest-timeout pytest-xdist expecttest 34 | - name: Install torch on mac 35 | if: startsWith(matrix.os, 'macOS') 36 | run: pip install --no-cache-dir torch==${{ matrix.torch }} torchvision==${{ matrix.torchvision }} 37 | - name: Install torch on ubuntu 38 | if: startsWith(matrix.os, 'ubuntu') 39 | run: | 40 | pip install --no-cache-dir torch==${{ matrix.torch }}+cpu torchvision==${{ matrix.torchvision }}+cpu -f https://download.pytorch.org/whl/torch_stable.html 41 | sudo apt update 42 | sudo apt install -y google-perftools 43 | - name: Install requirements 44 | run: | 45 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi 46 | pip install --no-cache-dir git+https://github.com/mapillary/inplace_abn.git@v1.1.0 47 | - name: Run tests 48 | env: 49 | LD_PRELOAD: /usr/lib/x86_64-linux-gnu/libtcmalloc.so.4 50 | run: | 51 | export PYTHONDONTWRITEBYTECODE=1 52 | pytest -vv --forked --durations=0 ./tests 53 | -------------------------------------------------------------------------------- /pytorch_image_models/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # PyCharm 101 | .idea 102 | 103 | output/ 104 | 105 | # PyTorch weights 106 | *.tar 107 | *.pth 108 | *.pt 109 | *.gz 110 | Untitled.ipynb 111 | Testing notebook.ipynb 112 | -------------------------------------------------------------------------------- /pytorch_image_models/MANIFEST.in: -------------------------------------------------------------------------------- 1 | include timm/models/pruned/*.txt 2 | 3 | -------------------------------------------------------------------------------- /pytorch_image_models/README.md: -------------------------------------------------------------------------------- 1 | # PyTorch Image Models 2 | 3 | Employ `timm` for ImageNet evaluation. 4 | 5 | ## Training 6 | 7 | Different from standard `timm` scripts, we separate the root directory of train and eval data, as the images are reconstructed in the quantization process. 8 | Besides, you can change the `select_indices` parameter to specify the sample-level quantized sample indices. Multiple indices can be specified here. 9 | We use the `ResNet50` model as the template here. For the other models, you can refer to the [timm documentation](https://rwightman.github.io/pytorch-image-models/) and conduct the above modifications. 10 | 11 | ```bash 12 | sh distributed_train.sh 9 [TRAIN_ROOT] [EVAL_ROOT] --select-indices [INDICES1] [INDICES2] --output [OUTPUT_DIR] --model resnet50 --sched cosine --epochs 260 --lr 0.6 --reprob 0.6 --remode pixel --batch-size 128 --amp --aug-splits 3 -aa rand-m9-mstd0.5-inc1 --resplit --split-bn --jsd --dist-bn reduce 13 | ``` 14 | 15 | ## Getting Started (Documentation) 16 | 17 | Current [documentation](https://rwightman.github.io/pytorch-image-models/) for `timm` covers the basics. 18 | 19 | [Getting Started with PyTorch Image Models (timm): A Practitioner’s Guide](https://towardsdatascience.com/getting-started-with-pytorch-image-models-timm-a-practitioners-guide-4e77b4bf9055) by [Chris Hughes](https://github.com/Chris-hughes10) is an extensive blog post covering many aspects of `timm` in detail. 20 | 21 | [timmdocs](http://timm.fast.ai/) is quickly becoming a much more comprehensive set of documentation for `timm`. A big thanks to [Aman Arora](https://github.com/amaarora) for his efforts creating timmdocs. 22 | 23 | [paperswithcode](https://paperswithcode.com/lib/timm) is a good resource for browsing the models within `timm`. 
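The `--select-indices` files referenced above are stored as NumPy arrays; the repository ships one example under `data/cifar10/`. Below is a minimal sketch of inspecting such a file. Two assumptions here: the array is interpreted as the indices of the training samples kept after sample-level quantization (inferred from the file name and the description above), and the path is the bundled CIFAR-10 example relative to the repository root rather than an ImageNet selection.

```python
import numpy as np

# Bundled example selection; presumed to hold the integer indices of the
# training samples kept after sample-level quantization (12.5% of CIFAR-10).
indices = np.load("data/cifar10/select_indices_CIFAR10_0.125.npy")
print(indices.shape, indices.dtype)   # expected: a 1-D integer array
print(indices[:10])                   # first few kept-sample indices
```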
24 | -------------------------------------------------------------------------------- /pytorch_image_models/clean_checkpoint.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ Checkpoint Cleaning Script 3 | 4 | Takes training checkpoints with GPU tensors, optimizer state, extra dict keys, etc. 5 | and outputs a CPU tensor checkpoint with only the `state_dict` along with SHA256 6 | calculation for model zoo compatibility. 7 | 8 | Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman) 9 | """ 10 | import torch 11 | import argparse 12 | import os 13 | import hashlib 14 | import shutil 15 | from collections import OrderedDict 16 | from timm.models.helpers import load_state_dict 17 | 18 | parser = argparse.ArgumentParser(description='PyTorch Checkpoint Cleaner') 19 | parser.add_argument('--checkpoint', default='', type=str, metavar='PATH', 20 | help='path to latest checkpoint (default: none)') 21 | parser.add_argument('--output', default='', type=str, metavar='PATH', 22 | help='output path') 23 | parser.add_argument('--no-use-ema', dest='no_use_ema', action='store_true', 24 | help='use ema version of weights if present') 25 | parser.add_argument('--clean-aux-bn', dest='clean_aux_bn', action='store_true', 26 | help='remove auxiliary batch norm layers (from SplitBN training) from checkpoint') 27 | 28 | _TEMP_NAME = './_checkpoint.pth' 29 | 30 | 31 | def main(): 32 | args = parser.parse_args() 33 | 34 | if os.path.exists(args.output): 35 | print("Error: Output filename ({}) already exists.".format(args.output)) 36 | exit(1) 37 | 38 | clean_checkpoint(args.checkpoint, args.output, not args.no_use_ema, args.clean_aux_bn) 39 | 40 | 41 | def clean_checkpoint(checkpoint, output='', use_ema=True, clean_aux_bn=False): 42 | # Load an existing checkpoint to CPU, strip everything but the state_dict and re-save 43 | if checkpoint and os.path.isfile(checkpoint): 44 | print("=> Loading checkpoint '{}'".format(checkpoint)) 45 | state_dict = load_state_dict(checkpoint, use_ema=use_ema) 46 | new_state_dict = {} 47 | for k, v in state_dict.items(): 48 | if clean_aux_bn and 'aux_bn' in k: 49 | # If all aux_bn keys are removed, the SplitBN layers will end up as normal and 50 | # load with the unmodified model using BatchNorm2d. 
51 | continue 52 | name = k[7:] if k.startswith('module.') else k 53 | new_state_dict[name] = v 54 | print("=> Loaded state_dict from '{}'".format(checkpoint)) 55 | 56 | try: 57 | torch.save(new_state_dict, _TEMP_NAME, _use_new_zipfile_serialization=False) 58 | except: 59 | torch.save(new_state_dict, _TEMP_NAME) 60 | 61 | with open(_TEMP_NAME, 'rb') as f: 62 | sha_hash = hashlib.sha256(f.read()).hexdigest() 63 | 64 | if output: 65 | checkpoint_root, checkpoint_base = os.path.split(output) 66 | checkpoint_base = os.path.splitext(checkpoint_base)[0] 67 | else: 68 | checkpoint_root = '' 69 | checkpoint_base = os.path.splitext(checkpoint)[0] 70 | final_filename = '-'.join([checkpoint_base, sha_hash[:8]]) + '.pth' 71 | shutil.move(_TEMP_NAME, os.path.join(checkpoint_root, final_filename)) 72 | print("=> Saved state_dict to '{}, SHA256: {}'".format(final_filename, sha_hash)) 73 | return final_filename 74 | else: 75 | print("Error: Checkpoint ({}) doesn't exist".format(checkpoint)) 76 | return '' 77 | 78 | 79 | if __name__ == '__main__': 80 | main() 81 | -------------------------------------------------------------------------------- /pytorch_image_models/distributed_train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | NUM_PROC=$1 3 | shift 4 | python3 -m torch.distributed.launch --nproc_per_node=$NUM_PROC train.py "$@" 5 | 6 | -------------------------------------------------------------------------------- /pytorch_image_models/docs/index.md: -------------------------------------------------------------------------------- 1 | # Getting Started 2 | 3 | ## Welcome 4 | 5 | Welcome to the `timm` documentation, a lean set of docs that covers the basics of `timm`. 6 | 7 | For a more comprehensive set of docs (currently under development), please visit [timmdocs](http://timm.fast.ai) by [Aman Arora](https://github.com/amaarora). 8 | 9 | ## Install 10 | 11 | The library can be installed with pip: 12 | 13 | ``` 14 | pip install timm 15 | ``` 16 | 17 | I update the PyPi (pip) packages when I'm confident there are no significant model regressions from previous releases. If you want to pip install the bleeding edge from GitHub, use: 18 | ``` 19 | pip install git+https://github.com/rwightman/pytorch-image-models.git 20 | ``` 21 | 22 | !!! info "Conda Environment" 23 | All development and testing has been done in Conda Python 3 environments on Linux x86-64 systems, specifically 3.7, 3.8, 3.9, 3.10 24 | 25 | Little to no care has been taken to be Python 2.x friendly and will not support it. If you run into any challenges running on Windows, or other OS, I'm definitely open to looking into those issues so long as it's in a reproducible (read Conda) environment. 26 | 27 | PyTorch versions 1.9, 1.10, 1.11 have been tested with the latest versions of this code. 
28 | 29 | I've tried to keep the dependencies minimal, the setup is as per the PyTorch default install instructions for Conda: 30 | ``` 31 | conda create -n torch-env 32 | conda activate torch-env 33 | conda install pytorch torchvision cudatoolkit=11.3 -c pytorch 34 | conda install pyyaml 35 | ``` 36 | 37 | ## Load a Pretrained Model 38 | 39 | Pretrained models can be loaded using `timm.create_model` 40 | 41 | ```python 42 | import timm 43 | 44 | m = timm.create_model('mobilenetv3_large_100', pretrained=True) 45 | m.eval() 46 | ``` 47 | 48 | ## List Models with Pretrained Weights 49 | ```python 50 | import timm 51 | from pprint import pprint 52 | model_names = timm.list_models(pretrained=True) 53 | pprint(model_names) 54 | >>> ['adv_inception_v3', 55 | 'cspdarknet53', 56 | 'cspresnext50', 57 | 'densenet121', 58 | 'densenet161', 59 | 'densenet169', 60 | 'densenet201', 61 | 'densenetblur121d', 62 | 'dla34', 63 | 'dla46_c', 64 | ... 65 | ] 66 | ``` 67 | 68 | ## List Model Architectures by Wildcard 69 | ```python 70 | import timm 71 | from pprint import pprint 72 | model_names = timm.list_models('*resne*t*') 73 | pprint(model_names) 74 | >>> ['cspresnet50', 75 | 'cspresnet50d', 76 | 'cspresnet50w', 77 | 'cspresnext50', 78 | ... 79 | ] 80 | ``` 81 | -------------------------------------------------------------------------------- /pytorch_image_models/docs/javascripts/tables.js: -------------------------------------------------------------------------------- 1 | app.location$.subscribe(function() { 2 | var tables = document.querySelectorAll("article table") 3 | tables.forEach(function(table) { 4 | new Tablesort(table) 5 | }) 6 | }) -------------------------------------------------------------------------------- /pytorch_image_models/docs/models/.pages: -------------------------------------------------------------------------------- 1 | title: Model Pages -------------------------------------------------------------------------------- /pytorch_image_models/docs/models/.templates/code_snippets.md: -------------------------------------------------------------------------------- 1 | ## How do I use this model on an image? 
2 | To load a pretrained model: 3 | 4 | ```python 5 | import timm 6 | model = timm.create_model('{{ model_name }}', pretrained=True) 7 | model.eval() 8 | ``` 9 | 10 | To load and preprocess the image: 11 | ```python 12 | import urllib 13 | from PIL import Image 14 | from timm.data import resolve_data_config 15 | from timm.data.transforms_factory import create_transform 16 | 17 | config = resolve_data_config({}, model=model) 18 | transform = create_transform(**config) 19 | 20 | url, filename = ("https://github.com/pytorch/hub/raw/master/images/dog.jpg", "dog.jpg") 21 | urllib.request.urlretrieve(url, filename) 22 | img = Image.open(filename).convert('RGB') 23 | tensor = transform(img).unsqueeze(0) # transform and add batch dimension 24 | ``` 25 | 26 | To get the model predictions: 27 | ```python 28 | import torch 29 | with torch.no_grad(): 30 | out = model(tensor) 31 | probabilities = torch.nn.functional.softmax(out[0], dim=0) 32 | print(probabilities.shape) 33 | # prints: torch.Size([1000]) 34 | ``` 35 | 36 | To get the top-5 predictions class names: 37 | ```python 38 | # Get imagenet class mappings 39 | url, filename = ("https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt", "imagenet_classes.txt") 40 | urllib.request.urlretrieve(url, filename) 41 | with open("imagenet_classes.txt", "r") as f: 42 | categories = [s.strip() for s in f.readlines()] 43 | 44 | # Print top categories per image 45 | top5_prob, top5_catid = torch.topk(probabilities, 5) 46 | for i in range(top5_prob.size(0)): 47 | print(categories[top5_catid[i]], top5_prob[i].item()) 48 | # prints class names and probabilities like: 49 | # [('Samoyed', 0.6425196528434753), ('Pomeranian', 0.04062102362513542), ('keeshond', 0.03186424449086189), ('white wolf', 0.01739676296710968), ('Eskimo dog', 0.011717947199940681)] 50 | ``` 51 | 52 | Replace the model name with the variant you want to use, e.g. `{{ model_name }}`. You can find the IDs in the model summaries at the top of this page. 53 | 54 | To extract image features with this model, follow the [timm feature extraction examples](https://rwightman.github.io/pytorch-image-models/feature_extraction/), just change the name of the model you want to use. 55 | 56 | ## How do I finetune this model? 57 | You can finetune any of the pre-trained models just by changing the classifier (the last layer). 58 | ```python 59 | model = timm.create_model('{{ model_name }}', pretrained=True, num_classes=NUM_FINETUNE_CLASSES) 60 | ``` 61 | To finetune on your own dataset, you have to write a training loop or adapt [timm's training 62 | script](https://github.com/rwightman/pytorch-image-models/blob/master/train.py) to use your dataset. 63 | -------------------------------------------------------------------------------- /pytorch_image_models/docs/models/.templates/generate_readmes.py: -------------------------------------------------------------------------------- 1 | """ 2 | Run this script to generate the model-index files in `models` from the templates in `.templates/models`. 
3 | """ 4 | 5 | import argparse 6 | from pathlib import Path 7 | 8 | from jinja2 import Environment, FileSystemLoader 9 | 10 | import modelindex 11 | 12 | 13 | def generate_readmes(templates_path: Path, dest_path: Path): 14 | """Add the code snippet template to the readmes""" 15 | readme_templates_path = templates_path / "models" 16 | code_template_path = templates_path / "code_snippets.md" 17 | 18 | env = Environment( 19 | loader=FileSystemLoader([readme_templates_path, readme_templates_path.parent]), 20 | ) 21 | 22 | for readme in readme_templates_path.iterdir(): 23 | if readme.suffix == ".md": 24 | template = env.get_template(readme.name) 25 | 26 | # get the first model_name for this model family 27 | mi = modelindex.load(str(readme)) 28 | model_name = mi.models[0].name 29 | 30 | full_content = template.render(model_name=model_name) 31 | 32 | # generate full_readme 33 | with open(dest_path / readme.name, "w") as f: 34 | f.write(full_content) 35 | 36 | 37 | def main(): 38 | parser = argparse.ArgumentParser(description="Model index generation config") 39 | parser.add_argument( 40 | "-t", 41 | "--templates", 42 | default=Path(__file__).parent / ".templates", 43 | type=str, 44 | help="Location of the markdown templates", 45 | ) 46 | parser.add_argument( 47 | "-d", 48 | "--dest", 49 | default=Path(__file__).parent / "models", 50 | type=str, 51 | help="Destination folder that contains the generated model-index files.", 52 | ) 53 | args = parser.parse_args() 54 | templates_path = Path(args.templates) 55 | dest_readmes_path = Path(args.dest) 56 | 57 | generate_readmes( 58 | templates_path, 59 | dest_readmes_path, 60 | ) 61 | 62 | 63 | if __name__ == "__main__": 64 | main() 65 | -------------------------------------------------------------------------------- /pytorch_image_models/docs/models/.templates/models/csp-darknet.md: -------------------------------------------------------------------------------- 1 | # CSP-DarkNet 2 | 3 | **CSPDarknet53** is a convolutional neural network and backbone for object detection that uses [DarkNet-53](https://paperswithcode.com/method/darknet-53). It employs a CSPNet strategy to partition the feature map of the base layer into two parts and then merges them through a cross-stage hierarchy. The use of a split and merge strategy allows for more gradient flow through the network. 4 | 5 | This CNN is used as the backbone for [YOLOv4](https://paperswithcode.com/method/yolov4). 6 | 7 | {% include 'code_snippets.md' %} 8 | 9 | ## How do I train this model? 10 | 11 | You can follow the [timm recipe scripts](https://rwightman.github.io/pytorch-image-models/scripts/) for training a new model afresh. 12 | 13 | ## Citation 14 | 15 | ```BibTeX 16 | @misc{bochkovskiy2020yolov4, 17 | title={YOLOv4: Optimal Speed and Accuracy of Object Detection}, 18 | author={Alexey Bochkovskiy and Chien-Yao Wang and Hong-Yuan Mark Liao}, 19 | year={2020}, 20 | eprint={2004.10934}, 21 | archivePrefix={arXiv}, 22 | primaryClass={cs.CV} 23 | } 24 | ``` 25 | 26 | 82 | -------------------------------------------------------------------------------- /pytorch_image_models/docs/models/.templates/models/csp-resnet.md: -------------------------------------------------------------------------------- 1 | # CSP-ResNet 2 | 3 | **CSPResNet** is a convolutional neural network where we apply the Cross Stage Partial Network (CSPNet) approach to [ResNet](https://paperswithcode.com/method/resnet). 
The CSPNet partitions the feature map of the base layer into two parts and then merges them through a cross-stage hierarchy. The use of a split and merge strategy allows for more gradient flow through the network. 4 | 5 | {% include 'code_snippets.md' %} 6 | 7 | ## How do I train this model? 8 | 9 | You can follow the [timm recipe scripts](https://rwightman.github.io/pytorch-image-models/scripts/) for training a new model afresh. 10 | 11 | ## Citation 12 | 13 | ```BibTeX 14 | @misc{wang2019cspnet, 15 | title={CSPNet: A New Backbone that can Enhance Learning Capability of CNN}, 16 | author={Chien-Yao Wang and Hong-Yuan Mark Liao and I-Hau Yeh and Yueh-Hua Wu and Ping-Yang Chen and Jun-Wei Hsieh}, 17 | year={2019}, 18 | eprint={1911.11929}, 19 | archivePrefix={arXiv}, 20 | primaryClass={cs.CV} 21 | } 22 | ``` 23 | 24 | 77 | -------------------------------------------------------------------------------- /pytorch_image_models/docs/models/.templates/models/csp-resnext.md: -------------------------------------------------------------------------------- 1 | # CSP-ResNeXt 2 | 3 | **CSPResNeXt** is a convolutional neural network where we apply the Cross Stage Partial Network (CSPNet) approach to [ResNeXt](https://paperswithcode.com/method/resnext). The CSPNet partitions the feature map of the base layer into two parts and then merges them through a cross-stage hierarchy. The use of a split and merge strategy allows for more gradient flow through the network. 4 | 5 | {% include 'code_snippets.md' %} 6 | 7 | ## How do I train this model? 8 | 9 | You can follow the [timm recipe scripts](https://rwightman.github.io/pytorch-image-models/scripts/) for training a new model afresh. 10 | 11 | ## Citation 12 | 13 | ```BibTeX 14 | @misc{wang2019cspnet, 15 | title={CSPNet: A New Backbone that can Enhance Learning Capability of CNN}, 16 | author={Chien-Yao Wang and Hong-Yuan Mark Liao and I-Hau Yeh and Yueh-Hua Wu and Ping-Yang Chen and Jun-Wei Hsieh}, 17 | year={2019}, 18 | eprint={1911.11929}, 19 | archivePrefix={arXiv}, 20 | primaryClass={cs.CV} 21 | } 22 | ``` 23 | 24 | 78 | -------------------------------------------------------------------------------- /pytorch_image_models/docs/models/.templates/models/ese-vovnet.md: -------------------------------------------------------------------------------- 1 | # ESE-VoVNet 2 | 3 | **VoVNet** is a convolutional neural network that seeks to make [DenseNet](https://paperswithcode.com/method/densenet) more efficient by concatenating all features only once in the last feature map, which makes input size constant and enables enlarging new output channel. 4 | 5 | Read about [one-shot aggregation here](https://paperswithcode.com/method/one-shot-aggregation). 6 | 7 | {% include 'code_snippets.md' %} 8 | 9 | ## How do I train this model? 10 | 11 | You can follow the [timm recipe scripts](https://rwightman.github.io/pytorch-image-models/scripts/) for training a new model afresh. 
12 | 13 | ## Citation 14 | 15 | ```BibTeX 16 | @misc{lee2019energy, 17 | title={An Energy and GPU-Computation Efficient Backbone Network for Real-Time Object Detection}, 18 | author={Youngwan Lee and Joong-won Hwang and Sangrok Lee and Yuseok Bae and Jongyoul Park}, 19 | year={2019}, 20 | eprint={1904.09730}, 21 | archivePrefix={arXiv}, 22 | primaryClass={cs.CV} 23 | } 24 | ``` 25 | 26 | 93 | -------------------------------------------------------------------------------- /pytorch_image_models/docs/models/.templates/models/fbnet.md: -------------------------------------------------------------------------------- 1 | # FBNet 2 | 3 | **FBNet** is a type of convolutional neural architectures discovered through [DNAS](https://paperswithcode.com/method/dnas) neural architecture search. It utilises a basic type of image model block inspired by [MobileNetv2](https://paperswithcode.com/method/mobilenetv2) that utilises depthwise convolutions and an inverted residual structure (see components). 4 | 5 | The principal building block is the [FBNet Block](https://paperswithcode.com/method/fbnet-block). 6 | 7 | {% include 'code_snippets.md' %} 8 | 9 | ## How do I train this model? 10 | 11 | You can follow the [timm recipe scripts](https://rwightman.github.io/pytorch-image-models/scripts/) for training a new model afresh. 12 | 13 | ## Citation 14 | 15 | ```BibTeX 16 | @misc{wu2019fbnet, 17 | title={FBNet: Hardware-Aware Efficient ConvNet Design via Differentiable Neural Architecture Search}, 18 | author={Bichen Wu and Xiaoliang Dai and Peizhao Zhang and Yanghan Wang and Fei Sun and Yiming Wu and Yuandong Tian and Peter Vajda and Yangqing Jia and Kurt Keutzer}, 19 | year={2019}, 20 | eprint={1812.03443}, 21 | archivePrefix={arXiv}, 22 | primaryClass={cs.CV} 23 | } 24 | ``` 25 | 26 | 77 | -------------------------------------------------------------------------------- /pytorch_image_models/docs/models/.templates/models/gloun-inception-v3.md: -------------------------------------------------------------------------------- 1 | # (Gluon) Inception v3 2 | 3 | **Inception v3** is a convolutional neural network architecture from the Inception family that makes several improvements including using [Label Smoothing](https://paperswithcode.com/method/label-smoothing), Factorized 7 x 7 convolutions, and the use of an [auxiliary classifer](https://paperswithcode.com/method/auxiliary-classifier) to propagate label information lower down the network (along with the use of batch normalization for layers in the sidehead). The key building block is an [Inception Module](https://paperswithcode.com/method/inception-v3-module). 4 | 5 | The weights from this model were ported from [Gluon](https://cv.gluon.ai/model_zoo/classification.html). 6 | 7 | {% include 'code_snippets.md' %} 8 | 9 | ## How do I train this model? 10 | 11 | You can follow the [timm recipe scripts](https://rwightman.github.io/pytorch-image-models/scripts/) for training a new model afresh. 
12 | 13 | ## Citation 14 | 15 | ```BibTeX 16 | @article{DBLP:journals/corr/SzegedyVISW15, 17 | author = {Christian Szegedy and 18 | Vincent Vanhoucke and 19 | Sergey Ioffe and 20 | Jonathon Shlens and 21 | Zbigniew Wojna}, 22 | title = {Rethinking the Inception Architecture for Computer Vision}, 23 | journal = {CoRR}, 24 | volume = {abs/1512.00567}, 25 | year = {2015}, 26 | url = {http://arxiv.org/abs/1512.00567}, 27 | archivePrefix = {arXiv}, 28 | eprint = {1512.00567}, 29 | timestamp = {Mon, 13 Aug 2018 16:49:07 +0200}, 30 | biburl = {https://dblp.org/rec/journals/corr/SzegedyVISW15.bib}, 31 | bibsource = {dblp computer science bibliography, https://dblp.org} 32 | } 33 | ``` 34 | 35 | 79 | -------------------------------------------------------------------------------- /pytorch_image_models/docs/models/.templates/models/gloun-senet.md: -------------------------------------------------------------------------------- 1 | # (Gluon) SENet 2 | 3 | A **SENet** is a convolutional neural network architecture that employs [squeeze-and-excitation blocks](https://paperswithcode.com/method/squeeze-and-excitation-block) to enable the network to perform dynamic channel-wise feature recalibration. 4 | 5 | The weights from this model were ported from [Gluon](https://cv.gluon.ai/model_zoo/classification.html). 6 | 7 | {% include 'code_snippets.md' %} 8 | 9 | ## How do I train this model? 10 | 11 | You can follow the [timm recipe scripts](https://rwightman.github.io/pytorch-image-models/scripts/) for training a new model afresh. 12 | 13 | ## Citation 14 | 15 | ```BibTeX 16 | @misc{hu2019squeezeandexcitation, 17 | title={Squeeze-and-Excitation Networks}, 18 | author={Jie Hu and Li Shen and Samuel Albanie and Gang Sun and Enhua Wu}, 19 | year={2019}, 20 | eprint={1709.01507}, 21 | archivePrefix={arXiv}, 22 | primaryClass={cs.CV} 23 | } 24 | ``` 25 | 26 | 64 | -------------------------------------------------------------------------------- /pytorch_image_models/docs/models/.templates/models/gloun-xception.md: -------------------------------------------------------------------------------- 1 | # (Gluon) Xception 2 | 3 | **Xception** is a convolutional neural network architecture that relies solely on [depthwise separable convolution](https://paperswithcode.com/method/depthwise-separable-convolution) layers. 4 | 5 | The weights from this model were ported from [Gluon](https://cv.gluon.ai/model_zoo/classification.html). 6 | 7 | {% include 'code_snippets.md' %} 8 | 9 | ## How do I train this model? 10 | 11 | You can follow the [timm recipe scripts](https://rwightman.github.io/pytorch-image-models/scripts/) for training a new model afresh. 12 | 13 | ## Citation 14 | 15 | ```BibTeX 16 | @misc{chollet2017xception, 17 | title={Xception: Deep Learning with Depthwise Separable Convolutions}, 18 | author={François Chollet}, 19 | year={2017}, 20 | eprint={1610.02357}, 21 | archivePrefix={arXiv}, 22 | primaryClass={cs.CV} 23 | } 24 | ``` 25 | 26 | 67 | -------------------------------------------------------------------------------- /pytorch_image_models/docs/models/.templates/models/inception-resnet-v2.md: -------------------------------------------------------------------------------- 1 | # Inception ResNet v2 2 | 3 | **Inception-ResNet-v2** is a convolutional neural architecture that builds on the Inception family of architectures but incorporates [residual connections](https://paperswithcode.com/method/residual-connection) (replacing the filter concatenation stage of the Inception architecture). 
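To make the residual idea concrete, here is a toy sketch, not the actual timm implementation, of wrapping an arbitrary convolutional branch in a scaled residual connection, which is the pattern Inception-ResNet blocks use in place of pure filter concatenation:

```python
import torch
import torch.nn as nn

class ToyResidualBlock(nn.Module):
    """Illustrative only: add a scaled branch output back onto its input."""
    def __init__(self, channels: int, scale: float = 0.2):
        super().__init__()
        self.branch = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
        )
        # Inception-ResNet scales residual branches down to stabilize training.
        self.scale = scale

    def forward(self, x):
        return torch.relu(x + self.scale * self.branch(x))

block = ToyResidualBlock(64)
out = block(torch.randn(1, 64, 32, 32))
print(out.shape)  # torch.Size([1, 64, 32, 32])
```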
4 | 5 | {% include 'code_snippets.md' %} 6 | 7 | ## How do I train this model? 8 | 9 | You can follow the [timm recipe scripts](https://rwightman.github.io/pytorch-image-models/scripts/) for training a new model afresh. 10 | 11 | ## Citation 12 | 13 | ```BibTeX 14 | @misc{szegedy2016inceptionv4, 15 | title={Inception-v4, Inception-ResNet and the Impact of Residual Connections on Learning}, 16 | author={Christian Szegedy and Sergey Ioffe and Vincent Vanhoucke and Alex Alemi}, 17 | year={2016}, 18 | eprint={1602.07261}, 19 | archivePrefix={arXiv}, 20 | primaryClass={cs.CV} 21 | } 22 | ``` 23 | 24 | 73 | -------------------------------------------------------------------------------- /pytorch_image_models/docs/models/.templates/models/inception-v3.md: -------------------------------------------------------------------------------- 1 | # Inception v3 2 | 3 | **Inception v3** is a convolutional neural network architecture from the Inception family that makes several improvements including using [Label Smoothing](https://paperswithcode.com/method/label-smoothing), Factorized 7 x 7 convolutions, and the use of an [auxiliary classifier](https://paperswithcode.com/method/auxiliary-classifier) to propagate label information lower down the network (along with the use of batch normalization for layers in the side head). The key building block is an [Inception Module](https://paperswithcode.com/method/inception-v3-module). 4 | 5 | {% include 'code_snippets.md' %} 6 | 7 | ## How do I train this model? 8 | 9 | You can follow the [timm recipe scripts](https://rwightman.github.io/pytorch-image-models/scripts/) for training a new model afresh. 10 | 11 | ## Citation 12 | 13 | ```BibTeX 14 | @article{DBLP:journals/corr/SzegedyVISW15, 15 | author = {Christian Szegedy and 16 | Vincent Vanhoucke and 17 | Sergey Ioffe and 18 | Jonathon Shlens and 19 | Zbigniew Wojna}, 20 | title = {Rethinking the Inception Architecture for Computer Vision}, 21 | journal = {CoRR}, 22 | volume = {abs/1512.00567}, 23 | year = {2015}, 24 | url = {http://arxiv.org/abs/1512.00567}, 25 | archivePrefix = {arXiv}, 26 | eprint = {1512.00567}, 27 | timestamp = {Mon, 13 Aug 2018 16:49:07 +0200}, 28 | biburl = {https://dblp.org/rec/journals/corr/SzegedyVISW15.bib}, 29 | bibsource = {dblp computer science bibliography, https://dblp.org} 30 | } 31 | ``` 32 | 33 | 86 | -------------------------------------------------------------------------------- /pytorch_image_models/docs/models/.templates/models/inception-v4.md: -------------------------------------------------------------------------------- 1 | # Inception v4 2 | 3 | **Inception-v4** is a convolutional neural network architecture that builds on previous iterations of the Inception family by simplifying the architecture and using more inception modules than [Inception-v3](https://paperswithcode.com/method/inception-v3). 4 | {% include 'code_snippets.md' %} 5 | 6 | ## How do I train this model? 7 | 8 | You can follow the [timm recipe scripts](https://rwightman.github.io/pytorch-image-models/scripts/) for training a new model afresh.
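For inference rather than training, a short sketch like the following, assuming the `inception_v4` entrypoint name in your timm install, reads the expected input resolution from the model's default configuration:

```python
import torch
import timm

# Assumption: 'inception_v4' is the registered timm entrypoint for this architecture.
model = timm.create_model('inception_v4', pretrained=False)  # pretrained=True to load released weights
model.eval()

# The default config records the expected (channels, height, width) input size, 299x299 here.
in_chans, height, width = model.default_cfg['input_size']
with torch.no_grad():
    logits = model(torch.randn(1, in_chans, height, width))
print(logits.shape)
```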
9 | 10 | ## Citation 11 | 12 | ```BibTeX 13 | @misc{szegedy2016inceptionv4, 14 | title={Inception-v4, Inception-ResNet and the Impact of Residual Connections on Learning}, 15 | author={Christian Szegedy and Sergey Ioffe and Vincent Vanhoucke and Alex Alemi}, 16 | year={2016}, 17 | eprint={1602.07261}, 18 | archivePrefix={arXiv}, 19 | primaryClass={cs.CV} 20 | } 21 | ``` 22 | 23 | 72 | -------------------------------------------------------------------------------- /pytorch_image_models/docs/models/.templates/models/legacy-senet.md: -------------------------------------------------------------------------------- 1 | # (Legacy) SENet 2 | 3 | A **SENet** is a convolutional neural network architecture that employs [squeeze-and-excitation blocks](https://paperswithcode.com/method/squeeze-and-excitation-block) to enable the network to perform dynamic channel-wise feature recalibration. 4 | 5 | The weights from this model were ported from Gluon. 6 | 7 | {% include 'code_snippets.md' %} 8 | 9 | ## How do I train this model? 10 | 11 | You can follow the [timm recipe scripts](https://rwightman.github.io/pytorch-image-models/scripts/) for training a new model afresh. 12 | 13 | ## Citation 14 | 15 | ```BibTeX 16 | @misc{hu2019squeezeandexcitation, 17 | title={Squeeze-and-Excitation Networks}, 18 | author={Jie Hu and Li Shen and Samuel Albanie and Gang Sun and Enhua Wu}, 19 | year={2019}, 20 | eprint={1709.01507}, 21 | archivePrefix={arXiv}, 22 | primaryClass={cs.CV} 23 | } 24 | ``` 25 | 26 | 75 | -------------------------------------------------------------------------------- /pytorch_image_models/docs/models/.templates/models/nasnet.md: -------------------------------------------------------------------------------- 1 | # NASNet 2 | 3 | **NASNet** is a type of convolutional neural network discovered through neural architecture search. The building blocks consist of normal and reduction cells. 4 | 5 | {% include 'code_snippets.md' %} 6 | 7 | ## How do I train this model? 8 | 9 | You can follow the [timm recipe scripts](https://rwightman.github.io/pytorch-image-models/scripts/) for training a new model afresh. 10 | 11 | ## Citation 12 | 13 | ```BibTeX 14 | @misc{zoph2018learning, 15 | title={Learning Transferable Architectures for Scalable Image Recognition}, 16 | author={Barret Zoph and Vijay Vasudevan and Jonathon Shlens and Quoc V. Le}, 17 | year={2018}, 18 | eprint={1707.07012}, 19 | archivePrefix={arXiv}, 20 | primaryClass={cs.CV} 21 | } 22 | ``` 23 | 24 | 71 | -------------------------------------------------------------------------------- /pytorch_image_models/docs/models/.templates/models/pnasnet.md: -------------------------------------------------------------------------------- 1 | # PNASNet 2 | 3 | **Progressive Neural Architecture Search**, or **PNAS**, is a method for learning the structure of convolutional neural networks (CNNs). It uses a sequential model-based optimization (SMBO) strategy, where we search the space of cell structures, starting with simple (shallow) models and progressing to complex ones, pruning out unpromising structures as we go. 4 | 5 | {% include 'code_snippets.md' %} 6 | 7 | ## How do I train this model? 8 | 9 | You can follow the [timm recipe scripts](https://rwightman.github.io/pytorch-image-models/scripts/) for training a new model afresh. 
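Before training, it can help to check which PNASNet entrypoints your timm version actually exposes; the `pnasnet*` wildcard below is an assumption about the naming convention:

```python
import timm

# All registered model names matching the pattern, then only those with pretrained weights available.
print(timm.list_models('pnasnet*'))
print(timm.list_models('pnasnet*', pretrained=True))
```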
10 | 11 | ## Citation 12 | 13 | ```BibTeX 14 | @misc{liu2018progressive, 15 | title={Progressive Neural Architecture Search}, 16 | author={Chenxi Liu and Barret Zoph and Maxim Neumann and Jonathon Shlens and Wei Hua and Li-Jia Li and Li Fei-Fei and Alan Yuille and Jonathan Huang and Kevin Murphy}, 17 | year={2018}, 18 | eprint={1712.00559}, 19 | archivePrefix={arXiv}, 20 | primaryClass={cs.CV} 21 | } 22 | ``` 23 | 24 | 72 | -------------------------------------------------------------------------------- /pytorch_image_models/docs/models/.templates/models/res2next.md: -------------------------------------------------------------------------------- 1 | # Res2NeXt 2 | 3 | **Res2NeXt** is an image model that employs a variation on [ResNeXt](https://paperswithcode.com/method/resnext) bottleneck residual blocks. The motivation is to be able to represent features at multiple scales. This is achieved through a novel building block for CNNs that constructs hierarchical residual-like connections within one single residual block. This represents multi-scale features at a granular level and increases the range of receptive fields for each network layer. 4 | 5 | {% include 'code_snippets.md' %} 6 | 7 | ## How do I train this model? 8 | 9 | You can follow the [timm recipe scripts](https://rwightman.github.io/pytorch-image-models/scripts/) for training a new model afresh. 10 | 11 | ## Citation 12 | 13 | ```BibTeX 14 | @article{Gao_2021, 15 | title={Res2Net: A New Multi-Scale Backbone Architecture}, 16 | volume={43}, 17 | ISSN={1939-3539}, 18 | url={http://dx.doi.org/10.1109/TPAMI.2019.2938758}, 19 | DOI={10.1109/tpami.2019.2938758}, 20 | number={2}, 21 | journal={IEEE Transactions on Pattern Analysis and Machine Intelligence}, 22 | publisher={Institute of Electrical and Electronics Engineers (IEEE)}, 23 | author={Gao, Shang-Hua and Cheng, Ming-Ming and Zhao, Kai and Zhang, Xin-Yu and Yang, Ming-Hsuan and Torr, Philip}, 24 | year={2021}, 25 | month={Feb}, 26 | pages={652–662} 27 | } 28 | ``` 29 | 30 | 76 | -------------------------------------------------------------------------------- /pytorch_image_models/docs/models/.templates/models/skresnext.md: -------------------------------------------------------------------------------- 1 | # SK-ResNeXt 2 | 3 | **SK ResNeXt** is a variant of a [ResNeXt](https://www.paperswithcode.com/method/resnext) that employs a [Selective Kernel](https://paperswithcode.com/method/selective-kernel) unit. In general, all the large kernel convolutions in the original bottleneck blocks in ResNext are replaced by the proposed [SK convolutions](https://paperswithcode.com/method/selective-kernel-convolution), enabling the network to choose appropriate receptive field sizes in an adaptive manner. 4 | 5 | {% include 'code_snippets.md' %} 6 | 7 | ## How do I train this model? 8 | 9 | You can follow the [timm recipe scripts](https://rwightman.github.io/pytorch-image-models/scripts/) for training a new model afresh. 
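A common follow-up to the recipe scripts is transfer learning. The sketch below, assuming the `skresnext50_32x4d` entrypoint name, swaps the classification head for a smaller target task and trains only that head:

```python
import torch
import timm

# Assumption: 'skresnext50_32x4d' is the SK-ResNeXt entrypoint in the installed timm version.
model = timm.create_model('skresnext50_32x4d', pretrained=False, num_classes=5)

# Freeze everything, then unfreeze only the (newly sized) classifier head.
for param in model.parameters():
    param.requires_grad = False
for param in model.get_classifier().parameters():
    param.requires_grad = True

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'trainable parameters: {trainable:,}')

# One dummy forward/backward pass; gradients flow only into the head.
loss = model(torch.randn(2, 3, 224, 224)).sum()
loss.backward()
```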
10 | 11 | ## Citation 12 | 13 | ```BibTeX 14 | @misc{li2019selective, 15 | title={Selective Kernel Networks}, 16 | author={Xiang Li and Wenhai Wang and Xiaolin Hu and Jian Yang}, 17 | year={2019}, 18 | eprint={1903.06586}, 19 | archivePrefix={arXiv}, 20 | primaryClass={cs.CV} 21 | } 22 | ``` 23 | 24 | 71 | -------------------------------------------------------------------------------- /pytorch_image_models/docs/models/.templates/models/spnasnet.md: -------------------------------------------------------------------------------- 1 | # SPNASNet 2 | 3 | **Single-Path NAS** is a novel differentiable NAS method for designing hardware-efficient ConvNets in less than 4 hours. 4 | 5 | {% include 'code_snippets.md' %} 6 | 7 | ## How do I train this model? 8 | 9 | You can follow the [timm recipe scripts](https://rwightman.github.io/pytorch-image-models/scripts/) for training a new model afresh. 10 | 11 | ## Citation 12 | 13 | ```BibTeX 14 | @misc{stamoulis2019singlepath, 15 | title={Single-Path NAS: Designing Hardware-Efficient ConvNets in less than 4 Hours}, 16 | author={Dimitrios Stamoulis and Ruizhou Ding and Di Wang and Dimitrios Lymberopoulos and Bodhi Priyantha and Jie Liu and Diana Marculescu}, 17 | year={2019}, 18 | eprint={1904.02877}, 19 | archivePrefix={arXiv}, 20 | primaryClass={cs.LG} 21 | } 22 | ``` 23 | 24 | 63 | -------------------------------------------------------------------------------- /pytorch_image_models/docs/models/.templates/models/tf-inception-v3.md: -------------------------------------------------------------------------------- 1 | # (Tensorflow) Inception v3 2 | 3 | **Inception v3** is a convolutional neural network architecture from the Inception family that makes several improvements including using [Label Smoothing](https://paperswithcode.com/method/label-smoothing), Factorized 7 x 7 convolutions, and the use of an [auxiliary classifier](https://paperswithcode.com/method/auxiliary-classifier) to propagate label information lower down the network (along with the use of batch normalization for layers in the side head). The key building block is an [Inception Module](https://paperswithcode.com/method/inception-v3-module). 4 | 5 | The weights from this model were ported from [Tensorflow/Models](https://github.com/tensorflow/models). 6 | 7 | {% include 'code_snippets.md' %} 8 | 9 | ## How do I train this model? 10 | 11 | You can follow the [timm recipe scripts](https://rwightman.github.io/pytorch-image-models/scripts/) for training a new model afresh.
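Because these weights were ported from TensorFlow, they expect Inception-style preprocessing (299x299 inputs normalized with 0.5 mean/std) rather than the usual ImageNet statistics. A sketch of deriving the matching eval transform from the model's own config, assuming the `tf_inception_v3` entrypoint name:

```python
import timm
from timm.data import resolve_data_config, create_transform

# Assumption: 'tf_inception_v3' is the registered name for the TF-ported weights.
model = timm.create_model('tf_inception_v3', pretrained=False)

# Pull input size, interpolation, crop pct and normalization stats from the model's default config...
config = resolve_data_config({}, model=model)
print(config)

# ...and build the matching evaluation transform for a PIL image pipeline.
transform = create_transform(**config, is_training=False)
print(transform)
```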
12 | 13 | ## Citation 14 | 15 | ```BibTeX 16 | @article{DBLP:journals/corr/SzegedyVISW15, 17 | author = {Christian Szegedy and 18 | Vincent Vanhoucke and 19 | Sergey Ioffe and 20 | Jonathon Shlens and 21 | Zbigniew Wojna}, 22 | title = {Rethinking the Inception Architecture for Computer Vision}, 23 | journal = {CoRR}, 24 | volume = {abs/1512.00567}, 25 | year = {2015}, 26 | url = {http://arxiv.org/abs/1512.00567}, 27 | archivePrefix = {arXiv}, 28 | eprint = {1512.00567}, 29 | timestamp = {Mon, 13 Aug 2018 16:49:07 +0200}, 30 | biburl = {https://dblp.org/rec/journals/corr/SzegedyVISW15.bib}, 31 | bibsource = {dblp computer science bibliography, https://dblp.org} 32 | } 33 | ``` 34 | 35 | 88 | -------------------------------------------------------------------------------- /pytorch_image_models/docs/models/.templates/models/wide-resnet.md: -------------------------------------------------------------------------------- 1 | # Wide ResNet 2 | 3 | **Wide Residual Networks** are a variant on [ResNets](https://paperswithcode.com/method/resnet) where we decrease depth and increase the width of residual networks. This is achieved through the use of [wide residual blocks](https://paperswithcode.com/method/wide-residual-block). 4 | 5 | {% include 'code_snippets.md' %} 6 | 7 | ## How do I train this model? 8 | 9 | You can follow the [timm recipe scripts](https://rwightman.github.io/pytorch-image-models/scripts/) for training a new model afresh. 10 | 11 | ## Citation 12 | 13 | ```BibTeX 14 | @article{DBLP:journals/corr/ZagoruykoK16, 15 | author = {Sergey Zagoruyko and 16 | Nikos Komodakis}, 17 | title = {Wide Residual Networks}, 18 | journal = {CoRR}, 19 | volume = {abs/1605.07146}, 20 | year = {2016}, 21 | url = {http://arxiv.org/abs/1605.07146}, 22 | archivePrefix = {arXiv}, 23 | eprint = {1605.07146}, 24 | timestamp = {Mon, 13 Aug 2018 16:46:42 +0200}, 25 | biburl = {https://dblp.org/rec/journals/corr/ZagoruykoK16.bib}, 26 | bibsource = {dblp computer science bibliography, https://dblp.org} 27 | } 28 | ``` 29 | 30 | 103 | -------------------------------------------------------------------------------- /pytorch_image_models/docs/scripts.md: -------------------------------------------------------------------------------- 1 | # Scripts 2 | Train, validation, inference, and checkpoint cleaning scripts are included in the github root folder. Scripts are not currently packaged in the pip release. 3 | 4 | The training and validation scripts evolved from early versions of the [PyTorch Imagenet Examples](https://github.com/pytorch/examples). I have added significant functionality over time, including CUDA specific performance enhancements based on 5 | [NVIDIA's APEX Examples](https://github.com/NVIDIA/apex/tree/master/examples). 6 | 7 | ## Training Script 8 | 9 | The variety of training args is large and not all combinations of options (or even individual options) have been fully tested. For the training dataset folder, specify the base folder that contains `train` and `validation` subfolders. 10 | 11 | To train an SE-ResNet34 on ImageNet, locally distributed, 4 GPUs, one process per GPU w/ cosine schedule, random-erasing prob of 50% and per-pixel random value: 12 | 13 | `./distributed_train.sh 4 /data/imagenet --model seresnet34 --sched cosine --epochs 150 --warmup-epochs 5 --lr 0.4 --reprob 0.5 --remode pixel --batch-size 256 --amp -j 4` 14 | 15 | NOTE: It is recommended to use PyTorch 1.9+ w/ PyTorch native AMP and DDP instead of APEX AMP. `--amp` defaults to native AMP as of timm ver 0.4.3.
`--apex-amp` will force use of APEX components if they are installed. 16 | 17 | ## Validation / Inference Scripts 18 | 19 | Validation and inference scripts are similar in usage. One outputs metrics on a validation set and the other outputs topk class ids in a csv. Specify the folder containing validation images, not the base folder as in the training script. 20 | 21 | To validate with the model's pretrained weights (if they exist): 22 | 23 | `python validate.py /imagenet/validation/ --model seresnext26_32x4d --pretrained` 24 | 25 | To run inference from a checkpoint: 26 | 27 | `python inference.py /imagenet/validation/ --model mobilenetv3_large_100 --checkpoint ./output/train/model_best.pth.tar` -------------------------------------------------------------------------------- /pytorch_image_models/hubconf.py: -------------------------------------------------------------------------------- 1 | dependencies = ['torch'] 2 | from timm.models import registry 3 | 4 | globals().update(registry._model_entrypoints) 5 | -------------------------------------------------------------------------------- /pytorch_image_models/mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: 'Pytorch Image Models' 2 | site_description: 'Pretrained Image Recognition Models' 3 | repo_name: 'rwightman/pytorch-image-models' 4 | repo_url: 'https://github.com/rwightman/pytorch-image-models' 5 | nav: 6 | - index.md 7 | - models.md 8 | - ... | models/*.md 9 | - results.md 10 | - scripts.md 11 | - training_hparam_examples.md 12 | - feature_extraction.md 13 | - changes.md 14 | - archived_changes.md 15 | theme: 16 | name: 'material' 17 | feature: 18 | tabs: false 19 | extra_javascript: 20 | - 'https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-MML-AM_CHTML' 21 | - https://cdnjs.cloudflare.com/ajax/libs/tablesort/5.2.1/tablesort.min.js 22 | - javascripts/tables.js 23 | markdown_extensions: 24 | - codehilite: 25 | linenums: true 26 | - admonition 27 | - pymdownx.arithmatex 28 | - pymdownx.betterem: 29 | smart_enable: all 30 | - pymdownx.caret 31 | - pymdownx.critic 32 | - pymdownx.details 33 | - pymdownx.emoji: 34 | emoji_generator: !!python/name:pymdownx.emoji.to_svg 35 | - pymdownx.inlinehilite 36 | - pymdownx.magiclink 37 | - pymdownx.mark 38 | - pymdownx.smartsymbols 39 | - pymdownx.superfences 40 | - pymdownx.tasklist: 41 | custom_checkbox: true 42 | - pymdownx.tilde 43 | - mdx_truly_sane_lists 44 | plugins: 45 | - search 46 | - awesome-pages 47 | -------------------------------------------------------------------------------- /pytorch_image_models/model-index.yml: -------------------------------------------------------------------------------- 1 | Import: 2 | - ./docs/models/*.md 3 | Library: 4 | Name: PyTorch Image Models 5 | Headline: PyTorch image models, scripts, pretrained weights 6 | Website: https://rwightman.github.io/pytorch-image-models/ 7 | Repository: https://github.com/rwightman/pytorch-image-models 8 | Docs: https://rwightman.github.io/pytorch-image-models/ 9 | README: "# PyTorch Image Models\r\n\r\nPyTorch Image Models (TIMM) is a library\ 10 | \ for state-of-the-art image classification. With this library you can:\r\n\r\n\ 11 | - Choose from 300+ pre-trained state-of-the-art image classification models.\r\ 12 | \n- Train models afresh on research datasets such as ImageNet using provided scripts.\r\ 13 | \n- Finetune pre-trained models on your own datasets, including the latest cutting\ 14 | \ edge models."
15 | -------------------------------------------------------------------------------- /pytorch_image_models/requirements-docs.txt: -------------------------------------------------------------------------------- 1 | mkdocs 2 | mkdocs-material 3 | mdx_truly_sane_lists 4 | mkdocs-awesome-pages-plugin -------------------------------------------------------------------------------- /pytorch_image_models/requirements-modelindex.txt: -------------------------------------------------------------------------------- 1 | model-index==0.1.10 2 | jinja2==2.11.3 3 | -------------------------------------------------------------------------------- /pytorch_image_models/requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.4.0 2 | torchvision>=0.5.0 3 | pyyaml 4 | -------------------------------------------------------------------------------- /pytorch_image_models/results/generate_csv_results.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | 4 | 5 | results = { 6 | 'results-imagenet.csv': [ 7 | 'results-imagenet-real.csv', 8 | 'results-imagenetv2-matched-frequency.csv', 9 | 'results-sketch.csv' 10 | ], 11 | 'results-imagenet-a-clean.csv': [ 12 | 'results-imagenet-a.csv', 13 | ], 14 | 'results-imagenet-r-clean.csv': [ 15 | 'results-imagenet-r.csv', 16 | ], 17 | } 18 | 19 | 20 | def diff(base_df, test_csv): 21 | base_models = base_df['model'].values 22 | test_df = pd.read_csv(test_csv) 23 | test_models = test_df['model'].values 24 | 25 | rank_diff = np.zeros_like(test_models, dtype='object') 26 | top1_diff = np.zeros_like(test_models, dtype='object') 27 | top5_diff = np.zeros_like(test_models, dtype='object') 28 | 29 | for rank, model in enumerate(test_models): 30 | if model in base_models: 31 | base_rank = int(np.where(base_models == model)[0]) 32 | top1_d = test_df['top1'][rank] - base_df['top1'][base_rank] 33 | top5_d = test_df['top5'][rank] - base_df['top5'][base_rank] 34 | 35 | # rank_diff 36 | if rank == base_rank: 37 | rank_diff[rank] = f'0' 38 | elif rank > base_rank: 39 | rank_diff[rank] = f'-{rank - base_rank}' 40 | else: 41 | rank_diff[rank] = f'+{base_rank - rank}' 42 | 43 | # top1_diff 44 | if top1_d >= .0: 45 | top1_diff[rank] = f'+{top1_d:.3f}' 46 | else: 47 | top1_diff[rank] = f'-{abs(top1_d):.3f}' 48 | 49 | # top5_diff 50 | if top5_d >= .0: 51 | top5_diff[rank] = f'+{top5_d:.3f}' 52 | else: 53 | top5_diff[rank] = f'-{abs(top5_d):.3f}' 54 | 55 | else: 56 | rank_diff[rank] = '' 57 | top1_diff[rank] = '' 58 | top5_diff[rank] = '' 59 | 60 | test_df['top1_diff'] = top1_diff 61 | test_df['top5_diff'] = top5_diff 62 | test_df['rank_diff'] = rank_diff 63 | 64 | test_df['param_count'] = test_df['param_count'].map('{:,.2f}'.format) 65 | test_df.sort_values('top1', ascending=False, inplace=True) 66 | test_df.to_csv(test_csv, index=False, float_format='%.3f') 67 | 68 | 69 | for base_results, test_results in results.items(): 70 | base_df = pd.read_csv(base_results) 71 | base_df.sort_values('top1', ascending=False, inplace=True) 72 | for test_csv in test_results: 73 | diff(base_df, test_csv) 74 | base_df['param_count'] = base_df['param_count'].map('{:,.2f}'.format) 75 | base_df.to_csv(base_results, index=False, float_format='%.3f') 76 | -------------------------------------------------------------------------------- /pytorch_image_models/results/imagenet_a_indices.txt: -------------------------------------------------------------------------------- 1 | 6 2 | 11 
3 | 13 4 | 15 5 | 17 6 | 22 7 | 23 8 | 27 9 | 30 10 | 37 11 | 39 12 | 42 13 | 47 14 | 50 15 | 57 16 | 70 17 | 71 18 | 76 19 | 79 20 | 89 21 | 90 22 | 94 23 | 96 24 | 97 25 | 99 26 | 105 27 | 107 28 | 108 29 | 110 30 | 113 31 | 124 32 | 125 33 | 130 34 | 132 35 | 143 36 | 144 37 | 150 38 | 151 39 | 207 40 | 234 41 | 235 42 | 254 43 | 277 44 | 283 45 | 287 46 | 291 47 | 295 48 | 298 49 | 301 50 | 306 51 | 307 52 | 308 53 | 309 54 | 310 55 | 311 56 | 313 57 | 314 58 | 315 59 | 317 60 | 319 61 | 323 62 | 324 63 | 326 64 | 327 65 | 330 66 | 334 67 | 335 68 | 336 69 | 347 70 | 361 71 | 363 72 | 372 73 | 378 74 | 386 75 | 397 76 | 400 77 | 401 78 | 402 79 | 404 80 | 407 81 | 411 82 | 416 83 | 417 84 | 420 85 | 425 86 | 428 87 | 430 88 | 437 89 | 438 90 | 445 91 | 456 92 | 457 93 | 461 94 | 462 95 | 470 96 | 472 97 | 483 98 | 486 99 | 488 100 | 492 101 | 496 102 | 514 103 | 516 104 | 528 105 | 530 106 | 539 107 | 542 108 | 543 109 | 549 110 | 552 111 | 557 112 | 561 113 | 562 114 | 569 115 | 572 116 | 573 117 | 575 118 | 579 119 | 589 120 | 606 121 | 607 122 | 609 123 | 614 124 | 626 125 | 627 126 | 640 127 | 641 128 | 642 129 | 643 130 | 658 131 | 668 132 | 677 133 | 682 134 | 684 135 | 687 136 | 701 137 | 704 138 | 719 139 | 736 140 | 746 141 | 749 142 | 752 143 | 758 144 | 763 145 | 765 146 | 768 147 | 773 148 | 774 149 | 776 150 | 779 151 | 780 152 | 786 153 | 792 154 | 797 155 | 802 156 | 803 157 | 804 158 | 813 159 | 815 160 | 820 161 | 823 162 | 831 163 | 833 164 | 835 165 | 839 166 | 845 167 | 847 168 | 850 169 | 859 170 | 862 171 | 870 172 | 879 173 | 880 174 | 888 175 | 890 176 | 897 177 | 900 178 | 907 179 | 913 180 | 924 181 | 932 182 | 933 183 | 934 184 | 937 185 | 943 186 | 945 187 | 947 188 | 951 189 | 954 190 | 956 191 | 957 192 | 959 193 | 971 194 | 972 195 | 980 196 | 981 197 | 984 198 | 986 199 | 987 200 | 988 201 | -------------------------------------------------------------------------------- /pytorch_image_models/results/imagenet_a_synsets.txt: -------------------------------------------------------------------------------- 1 | n01498041 2 | n01531178 3 | n01534433 4 | n01558993 5 | n01580077 6 | n01614925 7 | n01616318 8 | n01631663 9 | n01641577 10 | n01669191 11 | n01677366 12 | n01687978 13 | n01694178 14 | n01698640 15 | n01735189 16 | n01770081 17 | n01770393 18 | n01774750 19 | n01784675 20 | n01819313 21 | n01820546 22 | n01833805 23 | n01843383 24 | n01847000 25 | n01855672 26 | n01882714 27 | n01910747 28 | n01914609 29 | n01924916 30 | n01944390 31 | n01985128 32 | n01986214 33 | n02007558 34 | n02009912 35 | n02037110 36 | n02051845 37 | n02077923 38 | n02085620 39 | n02099601 40 | n02106550 41 | n02106662 42 | n02110958 43 | n02119022 44 | n02123394 45 | n02127052 46 | n02129165 47 | n02133161 48 | n02137549 49 | n02165456 50 | n02174001 51 | n02177972 52 | n02190166 53 | n02206856 54 | n02219486 55 | n02226429 56 | n02231487 57 | n02233338 58 | n02236044 59 | n02259212 60 | n02268443 61 | n02279972 62 | n02280649 63 | n02281787 64 | n02317335 65 | n02325366 66 | n02346627 67 | n02356798 68 | n02361337 69 | n02410509 70 | n02445715 71 | n02454379 72 | n02486410 73 | n02492035 74 | n02504458 75 | n02655020 76 | n02669723 77 | n02672831 78 | n02676566 79 | n02690373 80 | n02701002 81 | n02730930 82 | n02777292 83 | n02782093 84 | n02787622 85 | n02793495 86 | n02797295 87 | n02802426 88 | n02814860 89 | n02815834 90 | n02837789 91 | n02879718 92 | n02883205 93 | n02895154 94 | n02906734 95 | n02948072 96 | n02951358 97 | n02980441 98 | n02992211 99 | n02999410 100 
| n03014705 101 | n03026506 102 | n03124043 103 | n03125729 104 | n03187595 105 | n03196217 106 | n03223299 107 | n03250847 108 | n03255030 109 | n03291819 110 | n03325584 111 | n03355925 112 | n03384352 113 | n03388043 114 | n03417042 115 | n03443371 116 | n03444034 117 | n03445924 118 | n03452741 119 | n03483316 120 | n03584829 121 | n03590841 122 | n03594945 123 | n03617480 124 | n03666591 125 | n03670208 126 | n03717622 127 | n03720891 128 | n03721384 129 | n03724870 130 | n03775071 131 | n03788195 132 | n03804744 133 | n03837869 134 | n03840681 135 | n03854065 136 | n03888257 137 | n03891332 138 | n03935335 139 | n03982430 140 | n04019541 141 | n04033901 142 | n04039381 143 | n04067472 144 | n04086273 145 | n04099969 146 | n04118538 147 | n04131690 148 | n04133789 149 | n04141076 150 | n04146614 151 | n04147183 152 | n04179913 153 | n04208210 154 | n04235860 155 | n04252077 156 | n04252225 157 | n04254120 158 | n04270147 159 | n04275548 160 | n04310018 161 | n04317175 162 | n04344873 163 | n04347754 164 | n04355338 165 | n04366367 166 | n04376876 167 | n04389033 168 | n04399382 169 | n04442312 170 | n04456115 171 | n04482393 172 | n04507155 173 | n04509417 174 | n04532670 175 | n04540053 176 | n04554684 177 | n04562935 178 | n04591713 179 | n04606251 180 | n07583066 181 | n07695742 182 | n07697313 183 | n07697537 184 | n07714990 185 | n07718472 186 | n07720875 187 | n07734744 188 | n07749582 189 | n07753592 190 | n07760859 191 | n07768694 192 | n07831146 193 | n09229709 194 | n09246464 195 | n09472597 196 | n09835506 197 | n11879895 198 | n12057211 199 | n12144580 200 | n12267677 201 | -------------------------------------------------------------------------------- /pytorch_image_models/results/imagenet_r_indices.txt: -------------------------------------------------------------------------------- 1 | 1 2 | 2 3 | 4 4 | 6 5 | 8 6 | 9 7 | 11 8 | 13 9 | 22 10 | 23 11 | 26 12 | 29 13 | 31 14 | 39 15 | 47 16 | 63 17 | 71 18 | 76 19 | 79 20 | 84 21 | 90 22 | 94 23 | 96 24 | 97 25 | 99 26 | 100 27 | 105 28 | 107 29 | 113 30 | 122 31 | 125 32 | 130 33 | 132 34 | 144 35 | 145 36 | 147 37 | 148 38 | 150 39 | 151 40 | 155 41 | 160 42 | 161 43 | 162 44 | 163 45 | 171 46 | 172 47 | 178 48 | 187 49 | 195 50 | 199 51 | 203 52 | 207 53 | 208 54 | 219 55 | 231 56 | 232 57 | 234 58 | 235 59 | 242 60 | 245 61 | 247 62 | 250 63 | 251 64 | 254 65 | 259 66 | 260 67 | 263 68 | 265 69 | 267 70 | 269 71 | 276 72 | 277 73 | 281 74 | 288 75 | 289 76 | 291 77 | 292 78 | 293 79 | 296 80 | 299 81 | 301 82 | 308 83 | 309 84 | 310 85 | 311 86 | 314 87 | 315 88 | 319 89 | 323 90 | 327 91 | 330 92 | 334 93 | 335 94 | 337 95 | 338 96 | 340 97 | 341 98 | 344 99 | 347 100 | 353 101 | 355 102 | 361 103 | 362 104 | 365 105 | 366 106 | 367 107 | 368 108 | 372 109 | 388 110 | 390 111 | 393 112 | 397 113 | 401 114 | 407 115 | 413 116 | 414 117 | 425 118 | 428 119 | 430 120 | 435 121 | 437 122 | 441 123 | 447 124 | 448 125 | 457 126 | 462 127 | 463 128 | 469 129 | 470 130 | 471 131 | 472 132 | 476 133 | 483 134 | 487 135 | 515 136 | 546 137 | 555 138 | 558 139 | 570 140 | 579 141 | 583 142 | 587 143 | 593 144 | 594 145 | 596 146 | 609 147 | 613 148 | 617 149 | 621 150 | 629 151 | 637 152 | 657 153 | 658 154 | 701 155 | 717 156 | 724 157 | 763 158 | 768 159 | 774 160 | 776 161 | 779 162 | 780 163 | 787 164 | 805 165 | 812 166 | 815 167 | 820 168 | 824 169 | 833 170 | 847 171 | 852 172 | 866 173 | 875 174 | 883 175 | 889 176 | 895 177 | 907 178 | 928 179 | 931 180 | 932 181 | 933 182 | 934 183 | 936 184 | 937 185 | 943 186 | 945 
187 | 947 188 | 948 189 | 949 190 | 951 191 | 953 192 | 954 193 | 957 194 | 963 195 | 965 196 | 967 197 | 980 198 | 981 199 | 983 200 | 988 201 | -------------------------------------------------------------------------------- /pytorch_image_models/results/imagenet_r_synsets.txt: -------------------------------------------------------------------------------- 1 | n01443537 2 | n01484850 3 | n01494475 4 | n01498041 5 | n01514859 6 | n01518878 7 | n01531178 8 | n01534433 9 | n01614925 10 | n01616318 11 | n01630670 12 | n01632777 13 | n01644373 14 | n01677366 15 | n01694178 16 | n01748264 17 | n01770393 18 | n01774750 19 | n01784675 20 | n01806143 21 | n01820546 22 | n01833805 23 | n01843383 24 | n01847000 25 | n01855672 26 | n01860187 27 | n01882714 28 | n01910747 29 | n01944390 30 | n01983481 31 | n01986214 32 | n02007558 33 | n02009912 34 | n02051845 35 | n02056570 36 | n02066245 37 | n02071294 38 | n02077923 39 | n02085620 40 | n02086240 41 | n02088094 42 | n02088238 43 | n02088364 44 | n02088466 45 | n02091032 46 | n02091134 47 | n02092339 48 | n02094433 49 | n02096585 50 | n02097298 51 | n02098286 52 | n02099601 53 | n02099712 54 | n02102318 55 | n02106030 56 | n02106166 57 | n02106550 58 | n02106662 59 | n02108089 60 | n02108915 61 | n02109525 62 | n02110185 63 | n02110341 64 | n02110958 65 | n02112018 66 | n02112137 67 | n02113023 68 | n02113624 69 | n02113799 70 | n02114367 71 | n02117135 72 | n02119022 73 | n02123045 74 | n02128385 75 | n02128757 76 | n02129165 77 | n02129604 78 | n02130308 79 | n02134084 80 | n02138441 81 | n02165456 82 | n02190166 83 | n02206856 84 | n02219486 85 | n02226429 86 | n02233338 87 | n02236044 88 | n02268443 89 | n02279972 90 | n02317335 91 | n02325366 92 | n02346627 93 | n02356798 94 | n02363005 95 | n02364673 96 | n02391049 97 | n02395406 98 | n02398521 99 | n02410509 100 | n02423022 101 | n02437616 102 | n02445715 103 | n02447366 104 | n02480495 105 | n02480855 106 | n02481823 107 | n02483362 108 | n02486410 109 | n02510455 110 | n02526121 111 | n02607072 112 | n02655020 113 | n02672831 114 | n02701002 115 | n02749479 116 | n02769748 117 | n02793495 118 | n02797295 119 | n02802426 120 | n02808440 121 | n02814860 122 | n02823750 123 | n02841315 124 | n02843684 125 | n02883205 126 | n02906734 127 | n02909870 128 | n02939185 129 | n02948072 130 | n02950826 131 | n02951358 132 | n02966193 133 | n02980441 134 | n02992529 135 | n03124170 136 | n03272010 137 | n03345487 138 | n03372029 139 | n03424325 140 | n03452741 141 | n03467068 142 | n03481172 143 | n03494278 144 | n03495258 145 | n03498962 146 | n03594945 147 | n03602883 148 | n03630383 149 | n03649909 150 | n03676483 151 | n03710193 152 | n03773504 153 | n03775071 154 | n03888257 155 | n03930630 156 | n03947888 157 | n04086273 158 | n04118538 159 | n04133789 160 | n04141076 161 | n04146614 162 | n04147183 163 | n04192698 164 | n04254680 165 | n04266014 166 | n04275548 167 | n04310018 168 | n04325704 169 | n04347754 170 | n04389033 171 | n04409515 172 | n04465501 173 | n04487394 174 | n04522168 175 | n04536866 176 | n04552348 177 | n04591713 178 | n07614500 179 | n07693725 180 | n07695742 181 | n07697313 182 | n07697537 183 | n07714571 184 | n07714990 185 | n07718472 186 | n07720875 187 | n07734744 188 | n07742313 189 | n07745940 190 | n07749582 191 | n07753275 192 | n07753592 193 | n07768694 194 | n07873807 195 | n07880968 196 | n07920052 197 | n09472597 198 | n09835506 199 | n10565667 200 | n12267677 201 | -------------------------------------------------------------------------------- 
/pytorch_image_models/setup.cfg: -------------------------------------------------------------------------------- 1 | [dist_conda] 2 | 3 | conda_name_differences = 'torch:pytorch' 4 | channels = pytorch 5 | noarch = True 6 | -------------------------------------------------------------------------------- /pytorch_image_models/setup.py: -------------------------------------------------------------------------------- 1 | """ Setup 2 | """ 3 | from setuptools import setup, find_packages 4 | from codecs import open 5 | from os import path 6 | 7 | here = path.abspath(path.dirname(__file__)) 8 | 9 | # Get the long description from the README file 10 | with open(path.join(here, 'README.md'), encoding='utf-8') as f: 11 | long_description = f.read() 12 | 13 | exec(open('timm/version.py').read()) 14 | setup( 15 | name='timm', 16 | version=__version__, 17 | description='(Unofficial) PyTorch Image Models', 18 | long_description=long_description, 19 | long_description_content_type='text/markdown', 20 | url='https://github.com/rwightman/pytorch-image-models', 21 | author='Ross Wightman', 22 | author_email='hello@rwightman.com', 23 | classifiers=[ 24 | # How mature is this project? Common values are 25 | # 3 - Alpha 26 | # 4 - Beta 27 | # 5 - Production/Stable 28 | 'Development Status :: 3 - Alpha', 29 | 'Intended Audience :: Education', 30 | 'Intended Audience :: Science/Research', 31 | 'License :: OSI Approved :: Apache Software License', 32 | 'Programming Language :: Python :: 3.6', 33 | 'Programming Language :: Python :: 3.7', 34 | 'Programming Language :: Python :: 3.8', 35 | 'Topic :: Scientific/Engineering', 36 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 37 | 'Topic :: Software Development', 38 | 'Topic :: Software Development :: Libraries', 39 | 'Topic :: Software Development :: Libraries :: Python Modules', 40 | ], 41 | 42 | # Note that this is a string of words separated by whitespace, not a list. 
43 | keywords='pytorch pretrained models efficientnet mobilenetv3 mnasnet', 44 | packages=find_packages(exclude=['convert', 'tests', 'results']), 45 | include_package_data=True, 46 | install_requires=['torch >= 1.4', 'torchvision'], 47 | python_requires='>=3.6', 48 | ) 49 | -------------------------------------------------------------------------------- /pytorch_image_models/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/magic-research/Dataset_Quantization/24a2a77dd80c2cc86b53e2b2e204f1e02804d70e/pytorch_image_models/tests/__init__.py -------------------------------------------------------------------------------- /pytorch_image_models/tests/test_layers.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | import torch.nn as nn 4 | import platform 5 | import os 6 | 7 | from timm.models.layers import create_act_layer, get_act_layer, set_layer_config 8 | 9 | 10 | class MLP(nn.Module): 11 | def __init__(self, act_layer="relu", inplace=True): 12 | super(MLP, self).__init__() 13 | self.fc1 = nn.Linear(1000, 100) 14 | self.act = create_act_layer(act_layer, inplace=inplace) 15 | self.fc2 = nn.Linear(100, 10) 16 | 17 | def forward(self, x): 18 | x = self.fc1(x) 19 | x = self.act(x) 20 | x = self.fc2(x) 21 | return x 22 | 23 | 24 | def _run_act_layer_grad(act_type, inplace=True): 25 | x = torch.rand(10, 1000) * 10 26 | m = MLP(act_layer=act_type, inplace=inplace) 27 | 28 | def _run(x, act_layer=''): 29 | if act_layer: 30 | # replace act layer if set 31 | m.act = create_act_layer(act_layer, inplace=inplace) 32 | out = m(x) 33 | l = (out - 0).pow(2).sum() 34 | return l 35 | 36 | out_me = _run(x) 37 | 38 | with set_layer_config(scriptable=True): 39 | out_jit = _run(x, act_type) 40 | 41 | assert torch.isclose(out_jit, out_me) 42 | 43 | with set_layer_config(no_jit=True): 44 | out_basic = _run(x, act_type) 45 | 46 | assert torch.isclose(out_basic, out_jit) 47 | 48 | 49 | def test_swish_grad(): 50 | for _ in range(100): 51 | _run_act_layer_grad('swish') 52 | 53 | 54 | def test_mish_grad(): 55 | for _ in range(100): 56 | _run_act_layer_grad('mish') 57 | 58 | 59 | def test_hard_sigmoid_grad(): 60 | for _ in range(100): 61 | _run_act_layer_grad('hard_sigmoid', inplace=None) 62 | 63 | 64 | def test_hard_swish_grad(): 65 | for _ in range(100): 66 | _run_act_layer_grad('hard_swish') 67 | 68 | 69 | def test_hard_mish_grad(): 70 | for _ in range(100): 71 | _run_act_layer_grad('hard_mish') 72 | -------------------------------------------------------------------------------- /pytorch_image_models/tests/test_utils.py: -------------------------------------------------------------------------------- 1 | from torch.nn.modules.batchnorm import BatchNorm2d 2 | from torchvision.ops.misc import FrozenBatchNorm2d 3 | 4 | import timm 5 | from timm.utils.model import freeze, unfreeze 6 | 7 | 8 | def test_freeze_unfreeze(): 9 | model = timm.create_model('resnet18') 10 | 11 | # Freeze all 12 | freeze(model) 13 | # Check top level module 14 | assert model.fc.weight.requires_grad == False 15 | # Check submodule 16 | assert model.layer1[0].conv1.weight.requires_grad == False 17 | # Check BN 18 | assert isinstance(model.layer1[0].bn1, FrozenBatchNorm2d) 19 | 20 | # Unfreeze all 21 | unfreeze(model) 22 | # Check top level module 23 | assert model.fc.weight.requires_grad == True 24 | # Check submodule 25 | assert model.layer1[0].conv1.weight.requires_grad == True 26 | # Check BN 
27 | assert isinstance(model.layer1[0].bn1, BatchNorm2d) 28 | 29 | # Freeze some 30 | freeze(model, ['layer1', 'layer2.0']) 31 | # Check frozen 32 | assert model.layer1[0].conv1.weight.requires_grad == False 33 | assert isinstance(model.layer1[0].bn1, FrozenBatchNorm2d) 34 | assert model.layer2[0].conv1.weight.requires_grad == False 35 | # Check not frozen 36 | assert model.layer3[0].conv1.weight.requires_grad == True 37 | assert isinstance(model.layer3[0].bn1, BatchNorm2d) 38 | assert model.layer2[1].conv1.weight.requires_grad == True 39 | 40 | # Unfreeze some 41 | unfreeze(model, ['layer1', 'layer2.0']) 42 | # Check not frozen 43 | assert model.layer1[0].conv1.weight.requires_grad == True 44 | assert isinstance(model.layer1[0].bn1, BatchNorm2d) 45 | assert model.layer2[0].conv1.weight.requires_grad == True 46 | 47 | # Freeze/unfreeze BN 48 | # From root 49 | freeze(model, ['layer1.0.bn1']) 50 | assert isinstance(model.layer1[0].bn1, FrozenBatchNorm2d) 51 | unfreeze(model, ['layer1.0.bn1']) 52 | assert isinstance(model.layer1[0].bn1, BatchNorm2d) 53 | # From direct parent 54 | freeze(model.layer1[0], ['bn1']) 55 | assert isinstance(model.layer1[0].bn1, FrozenBatchNorm2d) 56 | unfreeze(model.layer1[0], ['bn1']) 57 | assert isinstance(model.layer1[0].bn1, BatchNorm2d) -------------------------------------------------------------------------------- /pytorch_image_models/timm/__init__.py: -------------------------------------------------------------------------------- 1 | from .version import __version__ 2 | from .models import create_model, list_models, is_model, list_modules, model_entrypoint, \ 3 | is_scriptable, is_exportable, set_scriptable, set_exportable, has_pretrained_cfg_key, is_pretrained_cfg_key, \ 4 | get_pretrained_cfg_value, is_model_pretrained 5 | -------------------------------------------------------------------------------- /pytorch_image_models/timm/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .auto_augment import RandAugment, AutoAugment, rand_augment_ops, auto_augment_policy,\ 2 | rand_augment_transform, auto_augment_transform 3 | from .config import resolve_data_config 4 | from .constants import * 5 | from .dataset import ImageDataset, IterableImageDataset, AugMixDataset 6 | from .dataset_factory import create_dataset 7 | from .loader import create_loader 8 | from .mixup import Mixup, FastCollateMixup 9 | from .parsers import create_parser,\ 10 | get_img_extensions, is_img_extension, set_img_extensions, add_img_extensions, del_img_extensions 11 | from .real_labels import RealLabelsImagenet 12 | from .transforms import * 13 | from .transforms_factory import create_transform 14 | -------------------------------------------------------------------------------- /pytorch_image_models/timm/data/config.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from .constants import * 3 | 4 | 5 | _logger = logging.getLogger(__name__) 6 | 7 | 8 | def resolve_data_config(args, default_cfg={}, model=None, use_test_size=False, verbose=False): 9 | new_config = {} 10 | default_cfg = default_cfg 11 | if not default_cfg and model is not None and hasattr(model, 'default_cfg'): 12 | default_cfg = model.default_cfg 13 | 14 | # Resolve input/image size 15 | in_chans = 3 16 | if 'chans' in args and args['chans'] is not None: 17 | in_chans = args['chans'] 18 | 19 | input_size = (in_chans, 224, 224) 20 | if 'input_size' in args and args['input_size'] is not None: 21 | assert 
isinstance(args['input_size'], (tuple, list)) 22 | assert len(args['input_size']) == 3 23 | input_size = tuple(args['input_size']) 24 | in_chans = input_size[0] # input_size overrides in_chans 25 | elif 'img_size' in args and args['img_size'] is not None: 26 | assert isinstance(args['img_size'], int) 27 | input_size = (in_chans, args['img_size'], args['img_size']) 28 | else: 29 | if use_test_size and 'test_input_size' in default_cfg: 30 | input_size = default_cfg['test_input_size'] 31 | elif 'input_size' in default_cfg: 32 | input_size = default_cfg['input_size'] 33 | new_config['input_size'] = input_size 34 | 35 | # resolve interpolation method 36 | new_config['interpolation'] = 'bicubic' 37 | if 'interpolation' in args and args['interpolation']: 38 | new_config['interpolation'] = args['interpolation'] 39 | elif 'interpolation' in default_cfg: 40 | new_config['interpolation'] = default_cfg['interpolation'] 41 | 42 | # resolve dataset + model mean for normalization 43 | new_config['mean'] = IMAGENET_DEFAULT_MEAN 44 | if 'mean' in args and args['mean'] is not None: 45 | mean = tuple(args['mean']) 46 | if len(mean) == 1: 47 | mean = tuple(list(mean) * in_chans) 48 | else: 49 | assert len(mean) == in_chans 50 | new_config['mean'] = mean 51 | elif 'mean' in default_cfg: 52 | new_config['mean'] = default_cfg['mean'] 53 | 54 | # resolve dataset + model std deviation for normalization 55 | new_config['std'] = IMAGENET_DEFAULT_STD 56 | if 'std' in args and args['std'] is not None: 57 | std = tuple(args['std']) 58 | if len(std) == 1: 59 | std = tuple(list(std) * in_chans) 60 | else: 61 | assert len(std) == in_chans 62 | new_config['std'] = std 63 | elif 'std' in default_cfg: 64 | new_config['std'] = default_cfg['std'] 65 | 66 | # resolve default crop percentage 67 | crop_pct = DEFAULT_CROP_PCT 68 | if 'crop_pct' in args and args['crop_pct'] is not None: 69 | crop_pct = args['crop_pct'] 70 | else: 71 | if use_test_size and 'test_crop_pct' in default_cfg: 72 | crop_pct = default_cfg['test_crop_pct'] 73 | elif 'crop_pct' in default_cfg: 74 | crop_pct = default_cfg['crop_pct'] 75 | new_config['crop_pct'] = crop_pct 76 | 77 | if verbose: 78 | _logger.info('Data processing configuration for current model + dataset:') 79 | for n, v in new_config.items(): 80 | _logger.info('\t%s: %s' % (n, str(v))) 81 | 82 | return new_config 83 | -------------------------------------------------------------------------------- /pytorch_image_models/timm/data/constants.py: -------------------------------------------------------------------------------- 1 | DEFAULT_CROP_PCT = 0.875 2 | IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) 3 | IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) 4 | IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5) 5 | IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5) 6 | IMAGENET_DPN_MEAN = (124 / 255, 117 / 255, 104 / 255) 7 | IMAGENET_DPN_STD = tuple([1 / (.0167 * 255)] * 3) 8 | -------------------------------------------------------------------------------- /pytorch_image_models/timm/data/parsers/__init__.py: -------------------------------------------------------------------------------- 1 | from .parser_factory import create_parser 2 | from .img_extensions import * 3 | -------------------------------------------------------------------------------- /pytorch_image_models/timm/data/parsers/class_map.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | def load_class_map(map_or_filename, root=''): 5 | if isinstance(map_or_filename, dict): 6 | assert dict, 
'class_map dict must be non-empty' 7 | return map_or_filename 8 | class_map_path = map_or_filename 9 | if not os.path.exists(class_map_path): 10 | class_map_path = os.path.join(root, class_map_path) 11 | assert os.path.exists(class_map_path), 'Cannot locate specified class map file (%s)' % map_or_filename 12 | class_map_ext = os.path.splitext(map_or_filename)[-1].lower() 13 | if class_map_ext == '.txt': 14 | with open(class_map_path) as f: 15 | class_to_idx = {v.strip(): k for k, v in enumerate(f)} 16 | else: 17 | assert False, f'Unsupported class map file extension ({class_map_ext}).' 18 | return class_to_idx 19 | 20 | -------------------------------------------------------------------------------- /pytorch_image_models/timm/data/parsers/img_extensions.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | __all__ = ['get_img_extensions', 'is_img_extension', 'set_img_extensions', 'add_img_extensions', 'del_img_extensions'] 4 | 5 | 6 | IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg') # singleton, kept public for bwd compat use 7 | _IMG_EXTENSIONS_SET = set(IMG_EXTENSIONS) # set version, private, kept in sync 8 | 9 | 10 | def _set_extensions(extensions): 11 | global IMG_EXTENSIONS 12 | global _IMG_EXTENSIONS_SET 13 | dedupe = set() # NOTE de-duping tuple while keeping original order 14 | IMG_EXTENSIONS = tuple(x for x in extensions if x not in dedupe and not dedupe.add(x)) 15 | _IMG_EXTENSIONS_SET = set(extensions) 16 | 17 | 18 | def _valid_extension(x: str): 19 | return x and isinstance(x, str) and len(x) >= 2 and x.startswith('.') 20 | 21 | 22 | def is_img_extension(ext): 23 | return ext in _IMG_EXTENSIONS_SET 24 | 25 | 26 | def get_img_extensions(as_set=False): 27 | return deepcopy(_IMG_EXTENSIONS_SET if as_set else IMG_EXTENSIONS) 28 | 29 | 30 | def set_img_extensions(extensions): 31 | assert len(extensions) 32 | for x in extensions: 33 | assert _valid_extension(x) 34 | _set_extensions(extensions) 35 | 36 | 37 | def add_img_extensions(ext): 38 | if not isinstance(ext, (list, tuple, set)): 39 | ext = (ext,) 40 | for x in ext: 41 | assert _valid_extension(x) 42 | extensions = IMG_EXTENSIONS + tuple(ext) 43 | _set_extensions(extensions) 44 | 45 | 46 | def del_img_extensions(ext): 47 | if not isinstance(ext, (list, tuple, set)): 48 | ext = (ext,) 49 | extensions = tuple(x for x in IMG_EXTENSIONS if x not in ext) 50 | _set_extensions(extensions) 51 | -------------------------------------------------------------------------------- /pytorch_image_models/timm/data/parsers/parser.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | 3 | 4 | class Parser: 5 | def __init__(self): 6 | pass 7 | 8 | @abstractmethod 9 | def _filename(self, index, basename=False, absolute=False): 10 | pass 11 | 12 | def filename(self, index, basename=False, absolute=False): 13 | return self._filename(index, basename=basename, absolute=absolute) 14 | 15 | def filenames(self, basename=False, absolute=False): 16 | return [self._filename(index, basename=basename, absolute=absolute) for index in range(len(self))] 17 | 18 | -------------------------------------------------------------------------------- /pytorch_image_models/timm/data/parsers/parser_factory.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from .parser_image_folder import ParserImageFolder 4 | from .parser_image_in_tar import ParserImageInTar 5 | 6 | 7 | def 
create_parser(name, root, split='train', **kwargs): 8 | name = name.lower() 9 | name = name.split('/', 2) 10 | prefix = '' 11 | if len(name) > 1: 12 | prefix = name[0] 13 | name = name[-1] 14 | 15 | # FIXME improve the selection right now just tfds prefix or fallback path, will need options to 16 | # explicitly select other options shortly 17 | if prefix == 'tfds': 18 | from .parser_tfds import ParserTfds # defer tensorflow import 19 | parser = ParserTfds(root, name, split=split, **kwargs) 20 | else: 21 | assert os.path.exists(root) 22 | # default fallback path (backwards compat), use image tar if root is a .tar file, otherwise image folder 23 | # FIXME support split here, in parser? 24 | if os.path.isfile(root) and os.path.splitext(root)[1] == '.tar': 25 | parser = ParserImageInTar(root, **kwargs) 26 | else: 27 | parser = ParserImageFolder(root, **kwargs) 28 | return parser 29 | -------------------------------------------------------------------------------- /pytorch_image_models/timm/data/parsers/parser_image_tar.py: -------------------------------------------------------------------------------- 1 | """ A dataset parser that reads single tarfile based datasets 2 | 3 | This parser can read datasets consisting if a single tarfile containing images. 4 | I am planning to deprecated it in favour of ParerImageInTar. 5 | 6 | Hacked together by / Copyright 2020 Ross Wightman 7 | """ 8 | import os 9 | import tarfile 10 | 11 | from timm.utils.misc import natural_key 12 | 13 | from .class_map import load_class_map 14 | from .img_extensions import get_img_extensions 15 | from .parser import Parser 16 | 17 | 18 | def extract_tarinfo(tarfile, class_to_idx=None, sort=True): 19 | extensions = get_img_extensions(as_set=True) 20 | files = [] 21 | labels = [] 22 | for ti in tarfile.getmembers(): 23 | if not ti.isfile(): 24 | continue 25 | dirname, basename = os.path.split(ti.path) 26 | label = os.path.basename(dirname) 27 | ext = os.path.splitext(basename)[1] 28 | if ext.lower() in extensions: 29 | files.append(ti) 30 | labels.append(label) 31 | if class_to_idx is None: 32 | unique_labels = set(labels) 33 | sorted_labels = list(sorted(unique_labels, key=natural_key)) 34 | class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)} 35 | tarinfo_and_targets = [(f, class_to_idx[l]) for f, l in zip(files, labels) if l in class_to_idx] 36 | if sort: 37 | tarinfo_and_targets = sorted(tarinfo_and_targets, key=lambda k: natural_key(k[0].path)) 38 | return tarinfo_and_targets, class_to_idx 39 | 40 | 41 | class ParserImageTar(Parser): 42 | """ Single tarfile dataset where classes are mapped to folders within tar 43 | NOTE: This class is being deprecated in favour of the more capable ParserImageInTar that can 44 | operate on folders of tars or tars in tars. 
45 | """ 46 | def __init__(self, root, class_map=''): 47 | super().__init__() 48 | 49 | class_to_idx = None 50 | if class_map: 51 | class_to_idx = load_class_map(class_map, root) 52 | assert os.path.isfile(root) 53 | self.root = root 54 | 55 | with tarfile.open(root) as tf: # cannot keep this open across processes, reopen later 56 | self.samples, self.class_to_idx = extract_tarinfo(tf, class_to_idx) 57 | self.imgs = self.samples 58 | self.tarfile = None # lazy init in __getitem__ 59 | 60 | def __getitem__(self, index): 61 | if self.tarfile is None: 62 | self.tarfile = tarfile.open(self.root) 63 | tarinfo, target = self.samples[index] 64 | fileobj = self.tarfile.extractfile(tarinfo) 65 | return fileobj, target 66 | 67 | def __len__(self): 68 | return len(self.samples) 69 | 70 | def _filename(self, index, basename=False, absolute=False): 71 | filename = self.samples[index][0].name 72 | if basename: 73 | filename = os.path.basename(filename) 74 | return filename 75 | -------------------------------------------------------------------------------- /pytorch_image_models/timm/data/real_labels.py: -------------------------------------------------------------------------------- 1 | """ Real labels evaluator for ImageNet 2 | Paper: `Are we done with ImageNet?` - https://arxiv.org/abs/2006.07159 3 | Based on Numpy example at https://github.com/google-research/reassessed-imagenet 4 | 5 | Hacked together by / Copyright 2020 Ross Wightman 6 | """ 7 | import os 8 | import json 9 | import numpy as np 10 | 11 | 12 | class RealLabelsImagenet: 13 | 14 | def __init__(self, filenames, real_json='real.json', topk=(1, 5)): 15 | with open(real_json) as real_labels: 16 | real_labels = json.load(real_labels) 17 | real_labels = {f'ILSVRC2012_val_{i + 1:08d}.JPEG': labels for i, labels in enumerate(real_labels)} 18 | self.real_labels = real_labels 19 | self.filenames = filenames 20 | assert len(self.filenames) == len(self.real_labels) 21 | self.topk = topk 22 | self.is_correct = {k: [] for k in topk} 23 | self.sample_idx = 0 24 | 25 | def add_result(self, output): 26 | maxk = max(self.topk) 27 | _, pred_batch = output.topk(maxk, 1, True, True) 28 | pred_batch = pred_batch.cpu().numpy() 29 | for pred in pred_batch: 30 | filename = self.filenames[self.sample_idx] 31 | filename = os.path.basename(filename) 32 | if self.real_labels[filename]: 33 | for k in self.topk: 34 | self.is_correct[k].append( 35 | any([p in self.real_labels[filename] for p in pred[:k]])) 36 | self.sample_idx += 1 37 | 38 | def get_accuracy(self, k=None): 39 | if k is None: 40 | return {k: float(np.mean(self.is_correct[k])) * 100 for k in self.topk} 41 | else: 42 | return float(np.mean(self.is_correct[k])) * 100 43 | -------------------------------------------------------------------------------- /pytorch_image_models/timm/loss/__init__.py: -------------------------------------------------------------------------------- 1 | from .asymmetric_loss import AsymmetricLossMultiLabel, AsymmetricLossSingleLabel 2 | from .binary_cross_entropy import BinaryCrossEntropy 3 | from .cross_entropy import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy 4 | from .jsd import JsdCrossEntropy 5 | -------------------------------------------------------------------------------- /pytorch_image_models/timm/loss/binary_cross_entropy.py: -------------------------------------------------------------------------------- 1 | """ Binary Cross Entropy w/ a few extras 2 | 3 | Hacked together by / Copyright 2021 Ross Wightman 4 | """ 5 | from typing import Optional 6 | 7 | 
import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | 12 | class BinaryCrossEntropy(nn.Module): 13 | """ BCE with optional one-hot from dense targets, label smoothing, thresholding 14 | NOTE for experiments comparing CE to BCE /w label smoothing, may remove 15 | """ 16 | def __init__( 17 | self, smoothing=0.1, target_threshold: Optional[float] = None, weight: Optional[torch.Tensor] = None, 18 | reduction: str = 'mean', pos_weight: Optional[torch.Tensor] = None): 19 | super(BinaryCrossEntropy, self).__init__() 20 | assert 0. <= smoothing < 1.0 21 | self.smoothing = smoothing 22 | self.target_threshold = target_threshold 23 | self.reduction = reduction 24 | self.register_buffer('weight', weight) 25 | self.register_buffer('pos_weight', pos_weight) 26 | 27 | def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 28 | assert x.shape[0] == target.shape[0] 29 | if target.shape != x.shape: 30 | # NOTE currently assume smoothing or other label softening is applied upstream if targets are already sparse 31 | num_classes = x.shape[-1] 32 | # FIXME should off/on be different for smoothing w/ BCE? Other impl out there differ 33 | off_value = self.smoothing / num_classes 34 | on_value = 1. - self.smoothing + off_value 35 | target = target.long().view(-1, 1) 36 | target = torch.full( 37 | (target.size()[0], num_classes), 38 | off_value, 39 | device=x.device, dtype=x.dtype).scatter_(1, target, on_value) 40 | if self.target_threshold is not None: 41 | # Make target 0, or 1 if threshold set 42 | target = target.gt(self.target_threshold).to(dtype=target.dtype) 43 | return F.binary_cross_entropy_with_logits( 44 | x, target, 45 | self.weight, 46 | pos_weight=self.pos_weight, 47 | reduction=self.reduction) 48 | -------------------------------------------------------------------------------- /pytorch_image_models/timm/loss/cross_entropy.py: -------------------------------------------------------------------------------- 1 | """ Cross Entropy w/ smoothing or soft targets 2 | 3 | Hacked together by / Copyright 2021 Ross Wightman 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 11 | class LabelSmoothingCrossEntropy(nn.Module): 12 | """ NLL loss with label smoothing. 13 | """ 14 | def __init__(self, smoothing=0.1): 15 | super(LabelSmoothingCrossEntropy, self).__init__() 16 | assert smoothing < 1.0 17 | self.smoothing = smoothing 18 | self.confidence = 1. 
- smoothing 19 | 20 | def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 21 | logprobs = F.log_softmax(x, dim=-1) 22 | nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1)) 23 | nll_loss = nll_loss.squeeze(1) 24 | smooth_loss = -logprobs.mean(dim=-1) 25 | loss = self.confidence * nll_loss + self.smoothing * smooth_loss 26 | return loss.mean() 27 | 28 | 29 | class SoftTargetCrossEntropy(nn.Module): 30 | 31 | def __init__(self): 32 | super(SoftTargetCrossEntropy, self).__init__() 33 | 34 | def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 35 | loss = torch.sum(-target * F.log_softmax(x, dim=-1), dim=-1) 36 | return loss.mean() 37 | -------------------------------------------------------------------------------- /pytorch_image_models/timm/loss/jsd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .cross_entropy import LabelSmoothingCrossEntropy 6 | 7 | 8 | class JsdCrossEntropy(nn.Module): 9 | """ Jensen-Shannon Divergence + Cross-Entropy Loss 10 | 11 | Based on impl here: https://github.com/google-research/augmix/blob/master/imagenet.py 12 | From paper: 'AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty - 13 | https://arxiv.org/abs/1912.02781 14 | 15 | Hacked together by / Copyright 2020 Ross Wightman 16 | """ 17 | def __init__(self, num_splits=3, alpha=12, smoothing=0.1): 18 | super().__init__() 19 | self.num_splits = num_splits 20 | self.alpha = alpha 21 | if smoothing is not None and smoothing > 0: 22 | self.cross_entropy_loss = LabelSmoothingCrossEntropy(smoothing) 23 | else: 24 | self.cross_entropy_loss = torch.nn.CrossEntropyLoss() 25 | 26 | def __call__(self, output, target): 27 | split_size = output.shape[0] // self.num_splits 28 | assert split_size * self.num_splits == output.shape[0] 29 | logits_split = torch.split(output, split_size) 30 | 31 | # Cross-entropy is only computed on clean images 32 | loss = self.cross_entropy_loss(logits_split[0], target[:split_size]) 33 | probs = [F.softmax(logits, dim=1) for logits in logits_split] 34 | 35 | # Clamp mixture distribution to avoid exploding KL divergence 36 | logp_mixture = torch.clamp(torch.stack(probs).mean(axis=0), 1e-7, 1).log() 37 | loss += self.alpha * sum([F.kl_div( 38 | logp_mixture, p_split, reduction='batchmean') for p_split in probs]) / len(probs) 39 | return loss 40 | -------------------------------------------------------------------------------- /pytorch_image_models/timm/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .beit import * 2 | from .byoanet import * 3 | from .byobnet import * 4 | from .cait import * 5 | from .coat import * 6 | from .convit import * 7 | from .convmixer import * 8 | from .convnext import * 9 | from .crossvit import * 10 | from .cspnet import * 11 | from .deit import * 12 | from .densenet import * 13 | from .dla import * 14 | from .dpn import * 15 | from .edgenext import * 16 | from .efficientnet import * 17 | from .ghostnet import * 18 | from .gluon_resnet import * 19 | from .gluon_xception import * 20 | from .hardcorenas import * 21 | from .hrnet import * 22 | from .inception_resnet_v2 import * 23 | from .inception_v3 import * 24 | from .inception_v4 import * 25 | from .levit import * 26 | from .mlp_mixer import * 27 | from .mobilenetv3 import * 28 | from .mobilevit import * 29 | from .nasnet import * 30 | from .nest import * 
31 | from .nfnet import * 32 | from .pit import * 33 | from .pnasnet import * 34 | from .poolformer import * 35 | from .regnet import * 36 | from .res2net import * 37 | from .resnest import * 38 | from .resnet import * 39 | from .resnetv2 import * 40 | from .rexnet import * 41 | from .selecsls import * 42 | from .senet import * 43 | from .sequencer import * 44 | from .sknet import * 45 | from .swin_transformer import * 46 | from .swin_transformer_v2 import * 47 | from .swin_transformer_v2_cr import * 48 | from .tnt import * 49 | from .tresnet import * 50 | from .twins import * 51 | from .vgg import * 52 | from .visformer import * 53 | from .vision_transformer import * 54 | from .vision_transformer_hybrid import * 55 | from .vision_transformer_relpos import * 56 | from .volo import * 57 | from .vovnet import * 58 | from .xception import * 59 | from .xception_aligned import * 60 | from .xcit import * 61 | 62 | from .factory import create_model, parse_model_name, safe_model_name 63 | from .helpers import load_checkpoint, resume_checkpoint, model_parameters 64 | from .layers import TestTimePoolHead, apply_test_time_pool 65 | from .layers import convert_splitbn_model, convert_sync_batchnorm 66 | from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit 67 | from .registry import register_model, model_entrypoint, list_models, is_model, list_modules, is_model_in_modules,\ 68 | is_model_pretrained, get_pretrained_cfg, has_pretrained_cfg_key, is_pretrained_cfg_key, get_pretrained_cfg_value 69 | -------------------------------------------------------------------------------- /pytorch_image_models/timm/models/factory.py: -------------------------------------------------------------------------------- 1 | from urllib.parse import urlsplit, urlunsplit 2 | import os 3 | 4 | from .registry import is_model, is_model_in_modules, model_entrypoint 5 | from .helpers import load_checkpoint 6 | from .layers import set_layer_config 7 | from .hub import load_model_config_from_hf 8 | 9 | 10 | def parse_model_name(model_name): 11 | model_name = model_name.replace('hf_hub', 'hf-hub') # NOTE for backwards compat, to deprecate hf_hub use 12 | parsed = urlsplit(model_name) 13 | assert parsed.scheme in ('', 'timm', 'hf-hub') 14 | if parsed.scheme == 'hf-hub': 15 | # FIXME may use fragment as revision, currently `@` in URI path 16 | return parsed.scheme, parsed.path 17 | else: 18 | model_name = os.path.split(parsed.path)[-1] 19 | return 'timm', model_name 20 | 21 | 22 | def safe_model_name(model_name, remove_source=True): 23 | def make_safe(name): 24 | return ''.join(c if c.isalnum() else '_' for c in name).rstrip('_') 25 | if remove_source: 26 | model_name = parse_model_name(model_name)[-1] 27 | return make_safe(model_name) 28 | 29 | 30 | def create_model( 31 | model_name, 32 | pretrained=False, 33 | pretrained_cfg=None, 34 | checkpoint_path='', 35 | scriptable=None, 36 | exportable=None, 37 | no_jit=None, 38 | **kwargs): 39 | """Create a model 40 | 41 | Args: 42 | model_name (str): name of model to instantiate 43 | pretrained (bool): load pretrained ImageNet-1k weights if true 44 | checkpoint_path (str): path of checkpoint to load after model is initialized 45 | scriptable (bool): set layer config so that model is jit scriptable (not working for all models yet) 46 | exportable (bool): set layer config so that model is traceable / ONNX exportable (not fully impl/obeyed yet) 47 | no_jit (bool): set layer config so that model doesn't utilize jit scripted layers (so far 
activations only) 48 | 49 | Keyword Args: 50 | drop_rate (float): dropout rate for training (default: 0.0) 51 | global_pool (str): global pool type (default: 'avg') 52 | **: other kwargs are model specific 53 | """ 54 | # Parameters that aren't supported by all models or are intended to only override model defaults if set 55 | # should default to None in command line args/cfg. Remove them if they are present and not set so that 56 | # non-supporting models don't break and default args remain in effect. 57 | kwargs = {k: v for k, v in kwargs.items() if v is not None} 58 | 59 | model_source, model_name = parse_model_name(model_name) 60 | if model_source == 'hf-hub': 61 | # FIXME hf-hub source overrides any passed in pretrained_cfg, warn? 62 | # For model names specified in the form `hf-hub:path/architecture_name@revision`, 63 | # load model weights + pretrained_cfg from Hugging Face hub. 64 | pretrained_cfg, model_name = load_model_config_from_hf(model_name) 65 | 66 | if not is_model(model_name): 67 | raise RuntimeError('Unknown model (%s)' % model_name) 68 | 69 | create_fn = model_entrypoint(model_name) 70 | with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit): 71 | model = create_fn(pretrained=pretrained, pretrained_cfg=pretrained_cfg, **kwargs) 72 | 73 | if checkpoint_path: 74 | load_checkpoint(model, checkpoint_path) 75 | 76 | return model 77 | -------------------------------------------------------------------------------- /pytorch_image_models/timm/models/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .activations import * 2 | from .adaptive_avgmax_pool import \ 3 | adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d 4 | from .blur_pool import BlurPool2d 5 | from .classifier import ClassifierHead, create_classifier 6 | from .cond_conv2d import CondConv2d, get_condconv_initializer 7 | from .config import is_exportable, is_scriptable, is_no_jit, set_exportable, set_scriptable, set_no_jit,\ 8 | set_layer_config 9 | from .conv2d_same import Conv2dSame, conv2d_same 10 | from .conv_bn_act import ConvNormAct, ConvNormActAa, ConvBnAct 11 | from .create_act import create_act_layer, get_act_layer, get_act_fn 12 | from .create_attn import get_attn, create_attn 13 | from .create_conv2d import create_conv2d 14 | from .create_norm_act import get_norm_act_layer, create_norm_act_layer, get_norm_act_layer 15 | from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path 16 | from .eca import EcaModule, CecaModule, EfficientChannelAttn, CircularEfficientChannelAttn 17 | from .evo_norm import EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2,\ 18 | EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a 19 | from .filter_response_norm import FilterResponseNormTlu2d, FilterResponseNormAct2d 20 | from .gather_excite import GatherExcite 21 | from .global_context import GlobalContext 22 | from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible 23 | from .inplace_abn import InplaceAbn 24 | from .linear import Linear 25 | from .mixed_conv2d import MixedConv2d 26 | from .mlp import Mlp, GluMlp, GatedMlp, ConvMlp 27 | from .non_local_attn import NonLocalAttn, BatNonLocalAttn 28 | from .norm import GroupNorm, GroupNorm1, LayerNorm2d 29 | from .norm_act import BatchNormAct2d, GroupNormAct, convert_sync_batchnorm 30 | from .padding import get_padding, get_same_padding, pad_same 31 | from .patch_embed import PatchEmbed 32 | from .pool2d_same 
import AvgPool2dSame, create_pool2d 33 | from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite 34 | from .selective_kernel import SelectiveKernel 35 | from .separable_conv import SeparableConv2d, SeparableConvNormAct 36 | from .space_to_depth import SpaceToDepthModule 37 | from .split_attn import SplitAttn 38 | from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model 39 | from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame 40 | from .test_time_pool import TestTimePoolHead, apply_test_time_pool 41 | from .trace_utils import _assert, _float_to_int 42 | from .weight_init import trunc_normal_, trunc_normal_tf_, variance_scaling_, lecun_normal_ 43 | -------------------------------------------------------------------------------- /pytorch_image_models/timm/models/layers/activations_jit.py: -------------------------------------------------------------------------------- 1 | """ Activations 2 | 3 | A collection of jit-scripted activations fn and modules with a common interface so that they can 4 | easily be swapped. All have an `inplace` arg even if not used. 5 | 6 | All jit scripted activations are lacking in-place variations on purpose, scripted kernel fusion does not 7 | currently work across in-place op boundaries, thus performance is equal to or less than the non-scripted 8 | versions if they contain in-place ops. 9 | 10 | Hacked together by / Copyright 2020 Ross Wightman 11 | """ 12 | 13 | import torch 14 | from torch import nn as nn 15 | from torch.nn import functional as F 16 | 17 | 18 | @torch.jit.script 19 | def swish_jit(x, inplace: bool = False): 20 | """Swish - Described in: https://arxiv.org/abs/1710.05941 21 | """ 22 | return x.mul(x.sigmoid()) 23 | 24 | 25 | @torch.jit.script 26 | def mish_jit(x, _inplace: bool = False): 27 | """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 28 | """ 29 | return x.mul(F.softplus(x).tanh()) 30 | 31 | 32 | class SwishJit(nn.Module): 33 | def __init__(self, inplace: bool = False): 34 | super(SwishJit, self).__init__() 35 | 36 | def forward(self, x): 37 | return swish_jit(x) 38 | 39 | 40 | class MishJit(nn.Module): 41 | def __init__(self, inplace: bool = False): 42 | super(MishJit, self).__init__() 43 | 44 | def forward(self, x): 45 | return mish_jit(x) 46 | 47 | 48 | @torch.jit.script 49 | def hard_sigmoid_jit(x, inplace: bool = False): 50 | # return F.relu6(x + 3.) / 6. 51 | return (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster? 52 | 53 | 54 | class HardSigmoidJit(nn.Module): 55 | def __init__(self, inplace: bool = False): 56 | super(HardSigmoidJit, self).__init__() 57 | 58 | def forward(self, x): 59 | return hard_sigmoid_jit(x) 60 | 61 | 62 | @torch.jit.script 63 | def hard_swish_jit(x, inplace: bool = False): 64 | # return x * (F.relu6(x + 3.) / 6) 65 | return x * (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster? 
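# Illustrative values (not from the source): hard_swish_jit maps -4 to 0, 0 to 0, and 4 to 4,
# since the (x + 3) term clamped to [0, 6] and divided by 6 acts as a piecewise-linear gate on x.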
66 | 67 | 68 | class HardSwishJit(nn.Module): 69 | def __init__(self, inplace: bool = False): 70 | super(HardSwishJit, self).__init__() 71 | 72 | def forward(self, x): 73 | return hard_swish_jit(x) 74 | 75 | 76 | @torch.jit.script 77 | def hard_mish_jit(x, inplace: bool = False): 78 | """ Hard Mish 79 | Experimental, based on notes by Mish author Diganta Misra at 80 | https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md 81 | """ 82 | return 0.5 * x * (x + 2).clamp(min=0, max=2) 83 | 84 | 85 | class HardMishJit(nn.Module): 86 | def __init__(self, inplace: bool = False): 87 | super(HardMishJit, self).__init__() 88 | 89 | def forward(self, x): 90 | return hard_mish_jit(x) 91 | -------------------------------------------------------------------------------- /pytorch_image_models/timm/models/layers/blur_pool.py: -------------------------------------------------------------------------------- 1 | """ 2 | BlurPool layer inspired by 3 | - Kornia's Max_BlurPool2d 4 | - Making Convolutional Networks Shift-Invariant Again :cite:`zhang2019shiftinvar` 5 | 6 | Hacked together by Chris Ha and Ross Wightman 7 | """ 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | import numpy as np 13 | from .padding import get_padding 14 | 15 | 16 | class BlurPool2d(nn.Module): 17 | r"""Creates a module that blurs and downsamples a given feature map. 18 | See :cite:`zhang2019shiftinvar` for more details. 19 | Corresponds to the Downsample class, which does blurring and subsampling. 20 | 21 | Args: 22 | channels (int): number of input channels 23 | filt_size (int): binomial filter size for blurring. currently supports 3 (default) and 5. 24 | stride (int): downsampling filter stride 25 | 26 | Returns: 27 | torch.Tensor: the transformed tensor.
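Example (an illustrative sketch, shapes assumed):
    pool = BlurPool2d(channels=64, filt_size=3, stride=2)
    out = pool(torch.randn(2, 64, 56, 56))  # anti-aliased downsample -> (2, 64, 28, 28)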
28 | """ 29 | def __init__(self, channels, filt_size=3, stride=2) -> None: 30 | super(BlurPool2d, self).__init__() 31 | assert filt_size > 1 32 | self.channels = channels 33 | self.filt_size = filt_size 34 | self.stride = stride 35 | self.padding = [get_padding(filt_size, stride, dilation=1)] * 4 36 | coeffs = torch.tensor((np.poly1d((0.5, 0.5)) ** (self.filt_size - 1)).coeffs.astype(np.float32)) 37 | blur_filter = (coeffs[:, None] * coeffs[None, :])[None, None, :, :].repeat(self.channels, 1, 1, 1) 38 | self.register_buffer('filt', blur_filter, persistent=False) 39 | 40 | def forward(self, x: torch.Tensor) -> torch.Tensor: 41 | x = F.pad(x, self.padding, 'reflect') 42 | return F.conv2d(x, self.filt, stride=self.stride, groups=self.channels) 43 | -------------------------------------------------------------------------------- /pytorch_image_models/timm/models/layers/classifier.py: -------------------------------------------------------------------------------- 1 | """ Classifier head and layer factory 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | from torch import nn as nn 6 | from torch.nn import functional as F 7 | 8 | from .adaptive_avgmax_pool import SelectAdaptivePool2d 9 | 10 | 11 | def _create_pool(num_features, num_classes, pool_type='avg', use_conv=False): 12 | flatten_in_pool = not use_conv # flatten when we use a Linear layer after pooling 13 | if not pool_type: 14 | assert num_classes == 0 or use_conv,\ 15 | 'Pooling can only be disabled if classifier is also removed or conv classifier is used' 16 | flatten_in_pool = False # disable flattening if pooling is pass-through (no pooling) 17 | global_pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=flatten_in_pool) 18 | num_pooled_features = num_features * global_pool.feat_mult() 19 | return global_pool, num_pooled_features 20 | 21 | 22 | def _create_fc(num_features, num_classes, use_conv=False): 23 | if num_classes <= 0: 24 | fc = nn.Identity() # pass-through (no classifier) 25 | elif use_conv: 26 | fc = nn.Conv2d(num_features, num_classes, 1, bias=True) 27 | else: 28 | fc = nn.Linear(num_features, num_classes, bias=True) 29 | return fc 30 | 31 | 32 | def create_classifier(num_features, num_classes, pool_type='avg', use_conv=False): 33 | global_pool, num_pooled_features = _create_pool(num_features, num_classes, pool_type, use_conv=use_conv) 34 | fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv) 35 | return global_pool, fc 36 | 37 | 38 | class ClassifierHead(nn.Module): 39 | """Classifier head w/ configurable global pooling and dropout.""" 40 | 41 | def __init__(self, in_chs, num_classes, pool_type='avg', drop_rate=0., use_conv=False): 42 | super(ClassifierHead, self).__init__() 43 | self.drop_rate = drop_rate 44 | self.global_pool, num_pooled_features = _create_pool(in_chs, num_classes, pool_type, use_conv=use_conv) 45 | self.fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv) 46 | self.flatten = nn.Flatten(1) if use_conv and pool_type else nn.Identity() 47 | 48 | def forward(self, x, pre_logits: bool = False): 49 | x = self.global_pool(x) 50 | if self.drop_rate: 51 | x = F.dropout(x, p=float(self.drop_rate), training=self.training) 52 | if pre_logits: 53 | return x.flatten(1) 54 | else: 55 | x = self.fc(x) 56 | return self.flatten(x) 57 | -------------------------------------------------------------------------------- /pytorch_image_models/timm/models/layers/conv2d_same.py: -------------------------------------------------------------------------------- 1 | """ 
Conv2d w/ Same Padding 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from typing import Tuple, Optional 9 | 10 | from .padding import pad_same, get_padding_value 11 | 12 | 13 | def conv2d_same( 14 | x, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, stride: Tuple[int, int] = (1, 1), 15 | padding: Tuple[int, int] = (0, 0), dilation: Tuple[int, int] = (1, 1), groups: int = 1): 16 | x = pad_same(x, weight.shape[-2:], stride, dilation) 17 | return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups) 18 | 19 | 20 | class Conv2dSame(nn.Conv2d): 21 | """ Tensorflow like 'SAME' convolution wrapper for 2D convolutions 22 | """ 23 | 24 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, 25 | padding=0, dilation=1, groups=1, bias=True): 26 | super(Conv2dSame, self).__init__( 27 | in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias) 28 | 29 | def forward(self, x): 30 | return conv2d_same(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) 31 | 32 | 33 | def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs): 34 | padding = kwargs.pop('padding', '') 35 | kwargs.setdefault('bias', False) 36 | padding, is_dynamic = get_padding_value(padding, kernel_size, **kwargs) 37 | if is_dynamic: 38 | return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs) 39 | else: 40 | return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs) 41 | 42 | 43 | -------------------------------------------------------------------------------- /pytorch_image_models/timm/models/layers/conv_bn_act.py: -------------------------------------------------------------------------------- 1 | """ Conv2d + BN + Act 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import functools 6 | from torch import nn as nn 7 | 8 | from .create_conv2d import create_conv2d 9 | from .create_norm_act import get_norm_act_layer 10 | 11 | 12 | class ConvNormAct(nn.Module): 13 | def __init__( 14 | self, in_channels, out_channels, kernel_size=1, stride=1, padding='', dilation=1, groups=1, 15 | bias=False, apply_act=True, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU, drop_layer=None): 16 | super(ConvNormAct, self).__init__() 17 | self.conv = create_conv2d( 18 | in_channels, out_channels, kernel_size, stride=stride, 19 | padding=padding, dilation=dilation, groups=groups, bias=bias) 20 | 21 | # NOTE for backwards compatibility with models that use separate norm and act layer definitions 22 | norm_act_layer = get_norm_act_layer(norm_layer, act_layer) 23 | # NOTE for backwards (weight) compatibility, norm layer name remains `.bn` 24 | norm_kwargs = dict(drop_layer=drop_layer) if drop_layer is not None else {} 25 | self.bn = norm_act_layer(out_channels, apply_act=apply_act, **norm_kwargs) 26 | 27 | @property 28 | def in_channels(self): 29 | return self.conv.in_channels 30 | 31 | @property 32 | def out_channels(self): 33 | return self.conv.out_channels 34 | 35 | def forward(self, x): 36 | x = self.conv(x) 37 | x = self.bn(x) 38 | return x 39 | 40 | 41 | ConvBnAct = ConvNormAct 42 | 43 | 44 | def create_aa(aa_layer, channels, stride=2, enable=True): 45 | if not aa_layer or not enable: 46 | return nn.Identity() 47 | if isinstance(aa_layer, functools.partial): 48 | if issubclass(aa_layer.func, nn.AvgPool2d): 49 | return aa_layer() 50 | else: 51 | return aa_layer(channels) 52 | elif issubclass(aa_layer, nn.AvgPool2d): 53 | return aa_layer(stride) 
54 | else: 55 | return aa_layer(channels=channels, stride=stride) 56 | 57 | 58 | class ConvNormActAa(nn.Module): 59 | def __init__( 60 | self, in_channels, out_channels, kernel_size=1, stride=1, padding='', dilation=1, groups=1, 61 | bias=False, apply_act=True, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU, aa_layer=None, drop_layer=None): 62 | super(ConvNormActAa, self).__init__() 63 | use_aa = aa_layer is not None and stride == 2 64 | 65 | self.conv = create_conv2d( 66 | in_channels, out_channels, kernel_size, stride=1 if use_aa else stride, 67 | padding=padding, dilation=dilation, groups=groups, bias=bias) 68 | 69 | # NOTE for backwards compatibility with models that use separate norm and act layer definitions 70 | norm_act_layer = get_norm_act_layer(norm_layer, act_layer) 71 | # NOTE for backwards (weight) compatibility, norm layer name remains `.bn` 72 | norm_kwargs = dict(drop_layer=drop_layer) if drop_layer is not None else {} 73 | self.bn = norm_act_layer(out_channels, apply_act=apply_act, **norm_kwargs) 74 | self.aa = create_aa(aa_layer, out_channels, stride=stride, enable=use_aa) 75 | 76 | @property 77 | def in_channels(self): 78 | return self.conv.in_channels 79 | 80 | @property 81 | def out_channels(self): 82 | return self.conv.out_channels 83 | 84 | def forward(self, x): 85 | x = self.conv(x) 86 | x = self.bn(x) 87 | x = self.aa(x) 88 | return x 89 | -------------------------------------------------------------------------------- /pytorch_image_models/timm/models/layers/create_conv2d.py: -------------------------------------------------------------------------------- 1 | """ Create Conv2d Factory Method 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | 6 | from .mixed_conv2d import MixedConv2d 7 | from .cond_conv2d import CondConv2d 8 | from .conv2d_same import create_conv2d_pad 9 | 10 | 11 | def create_conv2d(in_channels, out_channels, kernel_size, **kwargs): 12 | """ Select a 2d convolution implementation based on arguments 13 | Creates and returns one of torch.nn.Conv2d, Conv2dSame, MixedConv2d, or CondConv2d. 14 | 15 | Used extensively by EfficientNet, MobileNetv3 and related networks. 16 | """ 17 | if isinstance(kernel_size, list): 18 | assert 'num_experts' not in kwargs # MixNet + CondConv combo not supported currently 19 | if 'groups' in kwargs: 20 | groups = kwargs.pop('groups') 21 | if groups == in_channels: 22 | kwargs['depthwise'] = True 23 | else: 24 | assert groups == 1 25 | # We're going to use only lists for defining the MixedConv2d kernel groups, 26 | # ints, tuples, other iterables will continue to pass to normal conv and specify h, w. 
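# e.g. (hypothetical call) create_conv2d(32, 64, kernel_size=[3, 5, 7], groups=32) takes this branch
# and builds a MixedConv2d that splits the channels across the three kernel sizes as depthwise groups.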
27 | m = MixedConv2d(in_channels, out_channels, kernel_size, **kwargs) 28 | else: 29 | depthwise = kwargs.pop('depthwise', False) 30 | # for DW out_channels must be multiple of in_channels as must have out_channels % groups == 0 31 | groups = in_channels if depthwise else kwargs.pop('groups', 1) 32 | if 'num_experts' in kwargs and kwargs['num_experts'] > 0: 33 | m = CondConv2d(in_channels, out_channels, kernel_size, groups=groups, **kwargs) 34 | else: 35 | m = create_conv2d_pad(in_channels, out_channels, kernel_size, groups=groups, **kwargs) 36 | return m 37 | -------------------------------------------------------------------------------- /pytorch_image_models/timm/models/layers/filter_response_norm.py: -------------------------------------------------------------------------------- 1 | """ Filter Response Norm in PyTorch 2 | 3 | Based on `Filter Response Normalization Layer` - https://arxiv.org/abs/1911.09737 4 | 5 | Hacked together by / Copyright 2021 Ross Wightman 6 | """ 7 | import torch 8 | import torch.nn as nn 9 | 10 | from .create_act import create_act_layer 11 | from .trace_utils import _assert 12 | 13 | 14 | def inv_instance_rms(x, eps: float = 1e-5): 15 | rms = x.square().float().mean(dim=(2, 3), keepdim=True).add(eps).rsqrt().to(x.dtype) 16 | return rms.expand(x.shape) 17 | 18 | 19 | class FilterResponseNormTlu2d(nn.Module): 20 | def __init__(self, num_features, apply_act=True, eps=1e-5, rms=True, **_): 21 | super(FilterResponseNormTlu2d, self).__init__() 22 | self.apply_act = apply_act # apply activation (non-linearity) 23 | self.rms = rms 24 | self.eps = eps 25 | self.weight = nn.Parameter(torch.ones(num_features)) 26 | self.bias = nn.Parameter(torch.zeros(num_features)) 27 | self.tau = nn.Parameter(torch.zeros(num_features)) if apply_act else None 28 | self.reset_parameters() 29 | 30 | def reset_parameters(self): 31 | nn.init.ones_(self.weight) 32 | nn.init.zeros_(self.bias) 33 | if self.tau is not None: 34 | nn.init.zeros_(self.tau) 35 | 36 | def forward(self, x): 37 | _assert(x.dim() == 4, 'expected 4D input') 38 | x_dtype = x.dtype 39 | v_shape = (1, -1, 1, 1) 40 | x = x * inv_instance_rms(x, self.eps) 41 | x = x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype) 42 | return torch.maximum(x, self.tau.reshape(v_shape).to(dtype=x_dtype)) if self.tau is not None else x 43 | 44 | 45 | class FilterResponseNormAct2d(nn.Module): 46 | def __init__(self, num_features, apply_act=True, act_layer=nn.ReLU, inplace=None, rms=True, eps=1e-5, **_): 47 | super(FilterResponseNormAct2d, self).__init__() 48 | if act_layer is not None and apply_act: 49 | self.act = create_act_layer(act_layer, inplace=inplace) 50 | else: 51 | self.act = nn.Identity() 52 | self.rms = rms 53 | self.eps = eps 54 | self.weight = nn.Parameter(torch.ones(num_features)) 55 | self.bias = nn.Parameter(torch.zeros(num_features)) 56 | self.reset_parameters() 57 | 58 | def reset_parameters(self): 59 | nn.init.ones_(self.weight) 60 | nn.init.zeros_(self.bias) 61 | 62 | def forward(self, x): 63 | _assert(x.dim() == 4, 'expected 4D input') 64 | x_dtype = x.dtype 65 | v_shape = (1, -1, 1, 1) 66 | x = x * inv_instance_rms(x, self.eps) 67 | x = x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype) 68 | return self.act(x) 69 | -------------------------------------------------------------------------------- /pytorch_image_models/timm/models/layers/global_context.py: -------------------------------------------------------------------------------- 1 | """ 
Global Context Attention Block 2 | 3 | Paper: `GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond` 4 | - https://arxiv.org/abs/1904.11492 5 | 6 | Official code consulted as reference: https://github.com/xvjiarui/GCNet 7 | 8 | Hacked together by / Copyright 2021 Ross Wightman 9 | """ 10 | from torch import nn as nn 11 | import torch.nn.functional as F 12 | 13 | from .create_act import create_act_layer, get_act_layer 14 | from .helpers import make_divisible 15 | from .mlp import ConvMlp 16 | from .norm import LayerNorm2d 17 | 18 | 19 | class GlobalContext(nn.Module): 20 | 21 | def __init__(self, channels, use_attn=True, fuse_add=False, fuse_scale=True, init_last_zero=False, 22 | rd_ratio=1./8, rd_channels=None, rd_divisor=1, act_layer=nn.ReLU, gate_layer='sigmoid'): 23 | super(GlobalContext, self).__init__() 24 | act_layer = get_act_layer(act_layer) 25 | 26 | self.conv_attn = nn.Conv2d(channels, 1, kernel_size=1, bias=True) if use_attn else None 27 | 28 | if rd_channels is None: 29 | rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.) 30 | if fuse_add: 31 | self.mlp_add = ConvMlp(channels, rd_channels, act_layer=act_layer, norm_layer=LayerNorm2d) 32 | else: 33 | self.mlp_add = None 34 | if fuse_scale: 35 | self.mlp_scale = ConvMlp(channels, rd_channels, act_layer=act_layer, norm_layer=LayerNorm2d) 36 | else: 37 | self.mlp_scale = None 38 | 39 | self.gate = create_act_layer(gate_layer) 40 | self.init_last_zero = init_last_zero 41 | self.reset_parameters() 42 | 43 | def reset_parameters(self): 44 | if self.conv_attn is not None: 45 | nn.init.kaiming_normal_(self.conv_attn.weight, mode='fan_in', nonlinearity='relu') 46 | if self.mlp_add is not None: 47 | nn.init.zeros_(self.mlp_add.fc2.weight) 48 | 49 | def forward(self, x): 50 | B, C, H, W = x.shape 51 | 52 | if self.conv_attn is not None: 53 | attn = self.conv_attn(x).reshape(B, 1, H * W) # (B, 1, H * W) 54 | attn = F.softmax(attn, dim=-1).unsqueeze(3) # (B, 1, H * W, 1) 55 | context = x.reshape(B, C, H * W).unsqueeze(1) @ attn 56 | context = context.view(B, C, 1, 1) 57 | else: 58 | context = x.mean(dim=(2, 3), keepdim=True) 59 | 60 | if self.mlp_scale is not None: 61 | mlp_x = self.mlp_scale(context) 62 | x = x * self.gate(mlp_x) 63 | if self.mlp_add is not None: 64 | mlp_x = self.mlp_add(context) 65 | x = x + mlp_x 66 | 67 | return x 68 | -------------------------------------------------------------------------------- /pytorch_image_models/timm/models/layers/helpers.py: -------------------------------------------------------------------------------- 1 | """ Layer/Module Helpers 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | from itertools import repeat 6 | import collections.abc 7 | 8 | 9 | # From PyTorch internals 10 | def _ntuple(n): 11 | def parse(x): 12 | if isinstance(x, collections.abc.Iterable): 13 | return x 14 | return tuple(repeat(x, n)) 15 | return parse 16 | 17 | 18 | to_1tuple = _ntuple(1) 19 | to_2tuple = _ntuple(2) 20 | to_3tuple = _ntuple(3) 21 | to_4tuple = _ntuple(4) 22 | to_ntuple = _ntuple 23 | 24 | 25 | def make_divisible(v, divisor=8, min_value=None, round_limit=.9): 26 | min_value = min_value or divisor 27 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 28 | # Make sure that round down does not go down by more than 10%. 
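# e.g. (illustrative numbers, not from the source) make_divisible(10, 8) returns 16: rounding to the
# nearest multiple gives 8, which is more than 10% below v=10, so one extra divisor is added below.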
29 | if new_v < round_limit * v: 30 | new_v += divisor 31 | return new_v 32 | -------------------------------------------------------------------------------- /pytorch_image_models/timm/models/layers/linear.py: -------------------------------------------------------------------------------- 1 | """ Linear layer (alternate definition) 2 | """ 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn as nn 6 | 7 | 8 | class Linear(nn.Linear): 9 | r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b` 10 | 11 | Wraps torch.nn.Linear to support AMP + torchscript usage by manually casting 12 | weight & bias to input.dtype to work around an issue w/ torch.addmm in this use case. 13 | """ 14 | def forward(self, input: torch.Tensor) -> torch.Tensor: 15 | if torch.jit.is_scripting(): 16 | bias = self.bias.to(dtype=input.dtype) if self.bias is not None else None 17 | return F.linear(input, self.weight.to(dtype=input.dtype), bias=bias) 18 | else: 19 | return F.linear(input, self.weight, self.bias) 20 | -------------------------------------------------------------------------------- /pytorch_image_models/timm/models/layers/median_pool.py: -------------------------------------------------------------------------------- 1 | """ Median Pool 2 | Hacked together by / Copyright 2020 Ross Wightman 3 | """ 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from .helpers import to_2tuple, to_4tuple 7 | 8 | 9 | class MedianPool2d(nn.Module): 10 | """ Median pool (usable as median filter when stride=1) module. 11 | 12 | Args: 13 | kernel_size: size of pooling kernel, int or 2-tuple 14 | stride: pool stride, int or 2-tuple 15 | padding: pool padding, int or 4-tuple (l, r, t, b) as in pytorch F.pad 16 | same: override padding and enforce same padding, boolean 17 | """ 18 | def __init__(self, kernel_size=3, stride=1, padding=0, same=False): 19 | super(MedianPool2d, self).__init__() 20 | self.k = to_2tuple(kernel_size) 21 | self.stride = to_2tuple(stride) 22 | self.padding = to_4tuple(padding) # convert to l, r, t, b 23 | self.same = same 24 | 25 | def _padding(self, x): 26 | if self.same: 27 | ih, iw = x.size()[2:] 28 | if ih % self.stride[0] == 0: 29 | ph = max(self.k[0] - self.stride[0], 0) 30 | else: 31 | ph = max(self.k[0] - (ih % self.stride[0]), 0) 32 | if iw % self.stride[1] == 0: 33 | pw = max(self.k[1] - self.stride[1], 0) 34 | else: 35 | pw = max(self.k[1] - (iw % self.stride[1]), 0) 36 | pl = pw // 2 37 | pr = pw - pl 38 | pt = ph // 2 39 | pb = ph - pt 40 | padding = (pl, pr, pt, pb) 41 | else: 42 | padding = self.padding 43 | return padding 44 | 45 | def forward(self, x): 46 | x = F.pad(x, self._padding(x), mode='reflect') 47 | x = x.unfold(2, self.k[0], self.stride[0]).unfold(3, self.k[1], self.stride[1]) 48 | x = x.contiguous().view(x.size()[:4] + (-1,)).median(dim=-1)[0] 49 | return x 50 | -------------------------------------------------------------------------------- /pytorch_image_models/timm/models/layers/mixed_conv2d.py: -------------------------------------------------------------------------------- 1 | """ PyTorch Mixed Convolution 2 | 3 | Paper: MixConv: Mixed Depthwise Convolutional Kernels (https://arxiv.org/abs/1907.09595) 4 | 5 | Hacked together by / Copyright 2020 Ross Wightman 6 | """ 7 | 8 | import torch 9 | from torch import nn as nn 10 | 11 | from .conv2d_same import create_conv2d_pad 12 | 13 | 14 | def _split_channels(num_chan, num_groups): 15 | split = [num_chan // num_groups for _ in range(num_groups)] 16 | split[0] += 
num_chan - sum(split) 17 | return split 18 | 19 | 20 | class MixedConv2d(nn.ModuleDict): 21 | """ Mixed Grouped Convolution 22 | 23 | Based on MDConv and GroupedConv in MixNet impl: 24 | https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py 25 | """ 26 | def __init__(self, in_channels, out_channels, kernel_size=3, 27 | stride=1, padding='', dilation=1, depthwise=False, **kwargs): 28 | super(MixedConv2d, self).__init__() 29 | 30 | kernel_size = kernel_size if isinstance(kernel_size, list) else [kernel_size] 31 | num_groups = len(kernel_size) 32 | in_splits = _split_channels(in_channels, num_groups) 33 | out_splits = _split_channels(out_channels, num_groups) 34 | self.in_channels = sum(in_splits) 35 | self.out_channels = sum(out_splits) 36 | for idx, (k, in_ch, out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)): 37 | conv_groups = in_ch if depthwise else 1 38 | # use add_module to keep key space clean 39 | self.add_module( 40 | str(idx), 41 | create_conv2d_pad( 42 | in_ch, out_ch, k, stride=stride, 43 | padding=padding, dilation=dilation, groups=conv_groups, **kwargs) 44 | ) 45 | self.splits = in_splits 46 | 47 | def forward(self, x): 48 | x_split = torch.split(x, self.splits, 1) 49 | x_out = [c(x_split[i]) for i, c in enumerate(self.values())] 50 | x = torch.cat(x_out, 1) 51 | return x 52 | -------------------------------------------------------------------------------- /pytorch_image_models/timm/models/layers/norm.py: -------------------------------------------------------------------------------- 1 | """ Normalization layers and wrappers 2 | """ 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class GroupNorm(nn.GroupNorm): 9 | def __init__(self, num_channels, num_groups=32, eps=1e-5, affine=True): 10 | # NOTE num_channels is swapped to first arg for consistency in swapping norm layers with BN 11 | super().__init__(num_groups, num_channels, eps=eps, affine=affine) 12 | 13 | def forward(self, x): 14 | return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps) 15 | 16 | 17 | class GroupNorm1(nn.GroupNorm): 18 | """ Group Normalization with 1 group. 19 | Input: tensor in shape [B, C, *] 20 | """ 21 | 22 | def __init__(self, num_channels, **kwargs): 23 | super().__init__(1, num_channels, **kwargs) 24 | 25 | 26 | class LayerNorm2d(nn.LayerNorm): 27 | """ LayerNorm for channels of '2D' spatial NCHW tensors """ 28 | def __init__(self, num_channels, eps=1e-6, affine=True): 29 | super().__init__(num_channels, eps=eps, elementwise_affine=affine) 30 | 31 | def forward(self, x: torch.Tensor) -> torch.Tensor: 32 | return F.layer_norm( 33 | x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2) 34 | 35 | 36 | def _is_contiguous(tensor: torch.Tensor) -> bool: 37 | # jit is oh so lovely :/ 38 | # if torch.jit.is_tracing(): 39 | # return True 40 | if torch.jit.is_scripting(): 41 | return tensor.is_contiguous() 42 | else: 43 | return tensor.is_contiguous(memory_format=torch.contiguous_format) 44 | 45 | 46 | @torch.jit.script 47 | def _layer_norm_cf(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float): 48 | s, u = torch.var_mean(x, dim=1, unbiased=False, keepdim=True) 49 | x = (x - u) * torch.rsqrt(s + eps) 50 | x = x * weight[:, None, None] + bias[:, None, None] 51 | return x 52 | 53 | 54 | class LayerNormExp2d(nn.LayerNorm): 55 | """ LayerNorm for channels_first tensors with 2d spatial dimensions (ie N, C, H, W). 
56 | 57 | Experimental implementation w/ manual norm for non-contiguous tensors. 58 | 59 | This improves throughput in some scenarios (tested on Ampere GPU), esp w/ channels_last 60 | layout. However, benefits are not always clear and can perform worse on other GPUs. 61 | """ 62 | 63 | def __init__(self, num_channels, eps=1e-6): 64 | super().__init__(num_channels, eps=eps) 65 | 66 | def forward(self, x) -> torch.Tensor: 67 | if _is_contiguous(x): 68 | x = F.layer_norm( 69 | x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2) 70 | else: 71 | x = _layer_norm_cf(x, self.weight, self.bias, self.eps) 72 | return x 73 | -------------------------------------------------------------------------------- /pytorch_image_models/timm/models/layers/padding.py: -------------------------------------------------------------------------------- 1 | """ Padding Helpers 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import math 6 | from typing import List, Tuple 7 | 8 | import torch.nn.functional as F 9 | 10 | 11 | # Calculate symmetric padding for a convolution 12 | def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int: 13 | padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 14 | return padding 15 | 16 | 17 | # Calculate asymmetric TensorFlow-like 'SAME' padding for a convolution 18 | def get_same_padding(x: int, k: int, s: int, d: int): 19 | return max((math.ceil(x / s) - 1) * s + (k - 1) * d + 1 - x, 0) 20 | 21 | 22 | # Can SAME padding for given args be done statically? 23 | def is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_): 24 | return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0 25 | 26 | 27 | # Dynamically pad input x with 'SAME' padding for conv with specified args 28 | def pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1), value: float = 0): 29 | ih, iw = x.size()[-2:] 30 | pad_h, pad_w = get_same_padding(ih, k[0], s[0], d[0]), get_same_padding(iw, k[1], s[1], d[1]) 31 | if pad_h > 0 or pad_w > 0: 32 | x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2], value=value) 33 | return x 34 | 35 | 36 | def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]: 37 | dynamic = False 38 | if isinstance(padding, str): 39 | # for any string padding, the padding will be calculated for you, one of three ways 40 | padding = padding.lower() 41 | if padding == 'same': 42 | # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact 43 | if is_static_pad(kernel_size, **kwargs): 44 | # static case, no extra overhead 45 | padding = get_padding(kernel_size, **kwargs) 46 | else: 47 | # dynamic 'SAME' padding, has runtime/GPU memory overhead 48 | padding = 0 49 | dynamic = True 50 | elif padding == 'valid': 51 | # 'VALID' padding, same as padding=0 52 | padding = 0 53 | else: 54 | # Default to PyTorch style 'same'-ish symmetric padding 55 | padding = get_padding(kernel_size, **kwargs) 56 | return padding, dynamic 57 | -------------------------------------------------------------------------------- /pytorch_image_models/timm/models/layers/patch_embed.py: -------------------------------------------------------------------------------- 1 | """ Image to Patch Embedding using Conv2d 2 | 3 | A convolution based approach to patchifying a 2D image w/ embedding projection.
4 | 5 | Based on the impl in https://github.com/google-research/vision_transformer 6 | 7 | Hacked together by / Copyright 2020 Ross Wightman 8 | """ 9 | from torch import nn as nn 10 | 11 | from .helpers import to_2tuple 12 | from .trace_utils import _assert 13 | 14 | 15 | class PatchEmbed(nn.Module): 16 | """ 2D Image to Patch Embedding 17 | """ 18 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True): 19 | super().__init__() 20 | img_size = to_2tuple(img_size) 21 | patch_size = to_2tuple(patch_size) 22 | self.img_size = img_size 23 | self.patch_size = patch_size 24 | self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) 25 | self.num_patches = self.grid_size[0] * self.grid_size[1] 26 | self.flatten = flatten 27 | 28 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 29 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 30 | 31 | def forward(self, x): 32 | B, C, H, W = x.shape 33 | _assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).") 34 | _assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).") 35 | x = self.proj(x) 36 | if self.flatten: 37 | x = x.flatten(2).transpose(1, 2) # BCHW -> BNC 38 | x = self.norm(x) 39 | return x 40 | -------------------------------------------------------------------------------- /pytorch_image_models/timm/models/layers/pool2d_same.py: -------------------------------------------------------------------------------- 1 | """ AvgPool2d w/ Same Padding 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from typing import List, Tuple, Optional 9 | 10 | from .helpers import to_2tuple 11 | from .padding import pad_same, get_padding_value 12 | 13 | 14 | def avg_pool2d_same(x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0), 15 | ceil_mode: bool = False, count_include_pad: bool = True): 16 | # FIXME how to deal with count_include_pad vs not for external padding? 
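# pad_same (from padding.py above) adds asymmetric TF-style 'SAME' padding so the pool below can run
# with padding (0, 0); e.g. (illustrative) a 7-wide input with k=3, s=2 gets 1 px of padding on the left
# and 1 px on the right, so the output width becomes ceil(7 / 2) = 4.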
17 | x = pad_same(x, kernel_size, stride) 18 | return F.avg_pool2d(x, kernel_size, stride, (0, 0), ceil_mode, count_include_pad) 19 | 20 | 21 | class AvgPool2dSame(nn.AvgPool2d): 22 | """ Tensorflow like 'SAME' wrapper for 2D average pooling 23 | """ 24 | def __init__(self, kernel_size: int, stride=None, padding=0, ceil_mode=False, count_include_pad=True): 25 | kernel_size = to_2tuple(kernel_size) 26 | stride = to_2tuple(stride) 27 | super(AvgPool2dSame, self).__init__(kernel_size, stride, (0, 0), ceil_mode, count_include_pad) 28 | 29 | def forward(self, x): 30 | x = pad_same(x, self.kernel_size, self.stride) 31 | return F.avg_pool2d( 32 | x, self.kernel_size, self.stride, self.padding, self.ceil_mode, self.count_include_pad) 33 | 34 | 35 | def max_pool2d_same( 36 | x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0), 37 | dilation: List[int] = (1, 1), ceil_mode: bool = False): 38 | x = pad_same(x, kernel_size, stride, value=-float('inf')) 39 | return F.max_pool2d(x, kernel_size, stride, (0, 0), dilation, ceil_mode) 40 | 41 | 42 | class MaxPool2dSame(nn.MaxPool2d): 43 | """ Tensorflow like 'SAME' wrapper for 2D max pooling 44 | """ 45 | def __init__(self, kernel_size: int, stride=None, padding=0, dilation=1, ceil_mode=False): 46 | kernel_size = to_2tuple(kernel_size) 47 | stride = to_2tuple(stride) 48 | dilation = to_2tuple(dilation) 49 | super(MaxPool2dSame, self).__init__(kernel_size, stride, (0, 0), dilation, ceil_mode) 50 | 51 | def forward(self, x): 52 | x = pad_same(x, self.kernel_size, self.stride, value=-float('inf')) 53 | return F.max_pool2d(x, self.kernel_size, self.stride, (0, 0), self.dilation, self.ceil_mode) 54 | 55 | 56 | def create_pool2d(pool_type, kernel_size, stride=None, **kwargs): 57 | stride = stride or kernel_size 58 | padding = kwargs.pop('padding', '') 59 | padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, **kwargs) 60 | if is_dynamic: 61 | if pool_type == 'avg': 62 | return AvgPool2dSame(kernel_size, stride=stride, **kwargs) 63 | elif pool_type == 'max': 64 | return MaxPool2dSame(kernel_size, stride=stride, **kwargs) 65 | else: 66 | assert False, f'Unsupported pool type {pool_type}' 67 | else: 68 | if pool_type == 'avg': 69 | return nn.AvgPool2d(kernel_size, stride=stride, padding=padding, **kwargs) 70 | elif pool_type == 'max': 71 | return nn.MaxPool2d(kernel_size, stride=stride, padding=padding, **kwargs) 72 | else: 73 | assert False, f'Unsupported pool type {pool_type}' 74 | -------------------------------------------------------------------------------- /pytorch_image_models/timm/models/layers/separable_conv.py: -------------------------------------------------------------------------------- 1 | """ Depthwise Separable Conv Modules 2 | 3 | Basic DWS convs. Other variations of DWS exist with batch norm or activations between the 4 | DW and PW convs such as the Depthwise modules in MobileNetV2 / EfficientNet and Xception. 
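Factorizing a standard k x k conv into a k x k depthwise conv followed by a 1 x 1 pointwise conv roughly
reduces the parameter count from k*k*C_in*C_out to k*k*C_in + C_in*C_out (with channel_multiplier=1).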
5 | 6 | Hacked together by / Copyright 2020 Ross Wightman 7 | """ 8 | from torch import nn as nn 9 | 10 | from .create_conv2d import create_conv2d 11 | from .create_norm_act import get_norm_act_layer 12 | 13 | 14 | class SeparableConvNormAct(nn.Module): 15 | """ Separable Conv w/ trailing Norm and Activation 16 | """ 17 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, padding='', bias=False, 18 | channel_multiplier=1.0, pw_kernel_size=1, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU, 19 | apply_act=True, drop_layer=None): 20 | super(SeparableConvNormAct, self).__init__() 21 | 22 | self.conv_dw = create_conv2d( 23 | in_channels, int(in_channels * channel_multiplier), kernel_size, 24 | stride=stride, dilation=dilation, padding=padding, depthwise=True) 25 | 26 | self.conv_pw = create_conv2d( 27 | int(in_channels * channel_multiplier), out_channels, pw_kernel_size, padding=padding, bias=bias) 28 | 29 | norm_act_layer = get_norm_act_layer(norm_layer, act_layer) 30 | norm_kwargs = dict(drop_layer=drop_layer) if drop_layer is not None else {} 31 | self.bn = norm_act_layer(out_channels, apply_act=apply_act, **norm_kwargs) 32 | 33 | @property 34 | def in_channels(self): 35 | return self.conv_dw.in_channels 36 | 37 | @property 38 | def out_channels(self): 39 | return self.conv_pw.out_channels 40 | 41 | def forward(self, x): 42 | x = self.conv_dw(x) 43 | x = self.conv_pw(x) 44 | x = self.bn(x) 45 | return x 46 | 47 | 48 | SeparableConvBnAct = SeparableConvNormAct 49 | 50 | 51 | class SeparableConv2d(nn.Module): 52 | """ Separable Conv 53 | """ 54 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, padding='', bias=False, 55 | channel_multiplier=1.0, pw_kernel_size=1): 56 | super(SeparableConv2d, self).__init__() 57 | 58 | self.conv_dw = create_conv2d( 59 | in_channels, int(in_channels * channel_multiplier), kernel_size, 60 | stride=stride, dilation=dilation, padding=padding, depthwise=True) 61 | 62 | self.conv_pw = create_conv2d( 63 | int(in_channels * channel_multiplier), out_channels, pw_kernel_size, padding=padding, bias=bias) 64 | 65 | @property 66 | def in_channels(self): 67 | return self.conv_dw.in_channels 68 | 69 | @property 70 | def out_channels(self): 71 | return self.conv_pw.out_channels 72 | 73 | def forward(self, x): 74 | x = self.conv_dw(x) 75 | x = self.conv_pw(x) 76 | return x 77 | -------------------------------------------------------------------------------- /pytorch_image_models/timm/models/layers/space_to_depth.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class SpaceToDepth(nn.Module): 6 | def __init__(self, block_size=4): 7 | super().__init__() 8 | assert block_size == 4 9 | self.bs = block_size 10 | 11 | def forward(self, x): 12 | N, C, H, W = x.size() 13 | x = x.view(N, C, H // self.bs, self.bs, W // self.bs, self.bs) # (N, C, H//bs, bs, W//bs, bs) 14 | x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs) 15 | x = x.view(N, C * (self.bs ** 2), H // self.bs, W // self.bs) # (N, C*bs^2, H//bs, W//bs) 16 | return x 17 | 18 | 19 | @torch.jit.script 20 | class SpaceToDepthJit(object): 21 | def __call__(self, x: torch.Tensor): 22 | # assuming hard-coded that block_size==4 for acceleration 23 | N, C, H, W = x.size() 24 | x = x.view(N, C, H // 4, 4, W // 4, 4) # (N, C, H//bs, bs, W//bs, bs) 25 | x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs) 26 | x = x.view(N, C * 16, H // 
4, W // 4) # (N, C*bs^2, H//bs, W//bs) 27 | return x 28 | 29 | 30 | class SpaceToDepthModule(nn.Module): 31 | def __init__(self, no_jit=False): 32 | super().__init__() 33 | if not no_jit: 34 | self.op = SpaceToDepthJit() 35 | else: 36 | self.op = SpaceToDepth() 37 | 38 | def forward(self, x): 39 | return self.op(x) 40 | 41 | 42 | class DepthToSpace(nn.Module): 43 | 44 | def __init__(self, block_size): 45 | super().__init__() 46 | self.bs = block_size 47 | 48 | def forward(self, x): 49 | N, C, H, W = x.size() 50 | x = x.view(N, self.bs, self.bs, C // (self.bs ** 2), H, W) # (N, bs, bs, C//bs^2, H, W) 51 | x = x.permute(0, 3, 4, 1, 5, 2).contiguous() # (N, C//bs^2, H, bs, W, bs) 52 | x = x.view(N, C // (self.bs ** 2), H * self.bs, W * self.bs) # (N, C//bs^2, H * bs, W * bs) 53 | return x 54 | -------------------------------------------------------------------------------- /pytorch_image_models/timm/models/layers/split_attn.py: -------------------------------------------------------------------------------- 1 | """ Split Attention Conv2d (for ResNeSt Models) 2 | 3 | Paper: `ResNeSt: Split-Attention Networks` - /https://arxiv.org/abs/2004.08955 4 | 5 | Adapted from original PyTorch impl at https://github.com/zhanghang1989/ResNeSt 6 | 7 | Modified for torchscript compat, performance, and consistency with timm by Ross Wightman 8 | """ 9 | import torch 10 | import torch.nn.functional as F 11 | from torch import nn 12 | 13 | from .helpers import make_divisible 14 | 15 | 16 | class RadixSoftmax(nn.Module): 17 | def __init__(self, radix, cardinality): 18 | super(RadixSoftmax, self).__init__() 19 | self.radix = radix 20 | self.cardinality = cardinality 21 | 22 | def forward(self, x): 23 | batch = x.size(0) 24 | if self.radix > 1: 25 | x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2) 26 | x = F.softmax(x, dim=1) 27 | x = x.reshape(batch, -1) 28 | else: 29 | x = torch.sigmoid(x) 30 | return x 31 | 32 | 33 | class SplitAttn(nn.Module): 34 | """Split-Attention (aka Splat) 35 | """ 36 | def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, padding=None, 37 | dilation=1, groups=1, bias=False, radix=2, rd_ratio=0.25, rd_channels=None, rd_divisor=8, 38 | act_layer=nn.ReLU, norm_layer=None, drop_layer=None, **kwargs): 39 | super(SplitAttn, self).__init__() 40 | out_channels = out_channels or in_channels 41 | self.radix = radix 42 | mid_chs = out_channels * radix 43 | if rd_channels is None: 44 | attn_chs = make_divisible(in_channels * radix * rd_ratio, min_value=32, divisor=rd_divisor) 45 | else: 46 | attn_chs = rd_channels * radix 47 | 48 | padding = kernel_size // 2 if padding is None else padding 49 | self.conv = nn.Conv2d( 50 | in_channels, mid_chs, kernel_size, stride, padding, dilation, 51 | groups=groups * radix, bias=bias, **kwargs) 52 | self.bn0 = norm_layer(mid_chs) if norm_layer else nn.Identity() 53 | self.drop = drop_layer() if drop_layer is not None else nn.Identity() 54 | self.act0 = act_layer(inplace=True) 55 | self.fc1 = nn.Conv2d(out_channels, attn_chs, 1, groups=groups) 56 | self.bn1 = norm_layer(attn_chs) if norm_layer else nn.Identity() 57 | self.act1 = act_layer(inplace=True) 58 | self.fc2 = nn.Conv2d(attn_chs, mid_chs, 1, groups=groups) 59 | self.rsoftmax = RadixSoftmax(radix, groups) 60 | 61 | def forward(self, x): 62 | x = self.conv(x) 63 | x = self.bn0(x) 64 | x = self.drop(x) 65 | x = self.act0(x) 66 | 67 | B, RC, H, W = x.shape 68 | if self.radix > 1: 69 | x = x.reshape((B, self.radix, RC // self.radix, H, W)) 70 | x_gap = x.sum(dim=1) 71 | 
else: 72 | x_gap = x 73 | x_gap = x_gap.mean((2, 3), keepdim=True) 74 | x_gap = self.fc1(x_gap) 75 | x_gap = self.bn1(x_gap) 76 | x_gap = self.act1(x_gap) 77 | x_attn = self.fc2(x_gap) 78 | 79 | x_attn = self.rsoftmax(x_attn).view(B, -1, 1, 1) 80 | if self.radix > 1: 81 | out = (x * x_attn.reshape((B, self.radix, RC // self.radix, 1, 1))).sum(dim=1) 82 | else: 83 | out = x * x_attn 84 | return out.contiguous() 85 | -------------------------------------------------------------------------------- /pytorch_image_models/timm/models/layers/squeeze_excite.py: -------------------------------------------------------------------------------- 1 | """ Squeeze-and-Excitation Channel Attention 2 | 3 | An SE implementation originally based on PyTorch SE-Net impl. 4 | Has since evolved with additional functionality / configuration. 5 | 6 | Paper: `Squeeze-and-Excitation Networks` - https://arxiv.org/abs/1709.01507 7 | 8 | Also included is Effective Squeeze-Excitation (ESE). 9 | Paper: `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667 10 | 11 | Hacked together by / Copyright 2021 Ross Wightman 12 | """ 13 | from torch import nn as nn 14 | 15 | from .create_act import create_act_layer 16 | from .helpers import make_divisible 17 | 18 | 19 | class SEModule(nn.Module): 20 | """ SE Module as defined in original SE-Nets with a few additions 21 | Additions include: 22 | * divisor can be specified to keep channels % div == 0 (default: 8) 23 | * reduction channels can be specified directly by arg (if rd_channels is set) 24 | * reduction channels can be specified by float rd_ratio (default: 1/16) 25 | * global max pooling can be added to the squeeze aggregation 26 | * customizable activation, normalization, and gate layer 27 | """ 28 | def __init__( 29 | self, channels, rd_ratio=1. / 16, rd_channels=None, rd_divisor=8, add_maxpool=False, 30 | act_layer=nn.ReLU, norm_layer=None, gate_layer='sigmoid'): 31 | super(SEModule, self).__init__() 32 | self.add_maxpool = add_maxpool 33 | if not rd_channels: 34 | rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.) 
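        # Illustrative sizing (assumed values): channels=512 with rd_ratio=1/16 gives
        # rd_channels=32, which make_divisible leaves at 32 since it is already a
        # multiple of rd_divisor=8.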
35 | self.fc1 = nn.Conv2d(channels, rd_channels, kernel_size=1, bias=True) 36 | self.bn = norm_layer(rd_channels) if norm_layer else nn.Identity() 37 | self.act = create_act_layer(act_layer, inplace=True) 38 | self.fc2 = nn.Conv2d(rd_channels, channels, kernel_size=1, bias=True) 39 | self.gate = create_act_layer(gate_layer) 40 | 41 | def forward(self, x): 42 | x_se = x.mean((2, 3), keepdim=True) 43 | if self.add_maxpool: 44 | # experimental codepath, may remove or change 45 | x_se = 0.5 * x_se + 0.5 * x.amax((2, 3), keepdim=True) 46 | x_se = self.fc1(x_se) 47 | x_se = self.act(self.bn(x_se)) 48 | x_se = self.fc2(x_se) 49 | return x * self.gate(x_se) 50 | 51 | 52 | SqueezeExcite = SEModule # alias 53 | 54 | 55 | class EffectiveSEModule(nn.Module): 56 | """ 'Effective Squeeze-Excitation 57 | From `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667 58 | """ 59 | def __init__(self, channels, add_maxpool=False, gate_layer='hard_sigmoid', **_): 60 | super(EffectiveSEModule, self).__init__() 61 | self.add_maxpool = add_maxpool 62 | self.fc = nn.Conv2d(channels, channels, kernel_size=1, padding=0) 63 | self.gate = create_act_layer(gate_layer) 64 | 65 | def forward(self, x): 66 | x_se = x.mean((2, 3), keepdim=True) 67 | if self.add_maxpool: 68 | # experimental codepath, may remove or change 69 | x_se = 0.5 * x_se + 0.5 * x.amax((2, 3), keepdim=True) 70 | x_se = self.fc(x_se) 71 | return x * self.gate(x_se) 72 | 73 | 74 | EffectiveSqueezeExcite = EffectiveSEModule # alias 75 | -------------------------------------------------------------------------------- /pytorch_image_models/timm/models/layers/test_time_pool.py: -------------------------------------------------------------------------------- 1 | """ Test Time Pooling (Average-Max Pool) 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | 6 | import logging 7 | from torch import nn 8 | import torch.nn.functional as F 9 | 10 | from .adaptive_avgmax_pool import adaptive_avgmax_pool2d 11 | 12 | 13 | _logger = logging.getLogger(__name__) 14 | 15 | 16 | class TestTimePoolHead(nn.Module): 17 | def __init__(self, base, original_pool=7): 18 | super(TestTimePoolHead, self).__init__() 19 | self.base = base 20 | self.original_pool = original_pool 21 | base_fc = self.base.get_classifier() 22 | if isinstance(base_fc, nn.Conv2d): 23 | self.fc = base_fc 24 | else: 25 | self.fc = nn.Conv2d( 26 | self.base.num_features, self.base.num_classes, kernel_size=1, bias=True) 27 | self.fc.weight.data.copy_(base_fc.weight.data.view(self.fc.weight.size())) 28 | self.fc.bias.data.copy_(base_fc.bias.data.view(self.fc.bias.size())) 29 | self.base.reset_classifier(0) # delete original fc layer 30 | 31 | def forward(self, x): 32 | x = self.base.forward_features(x) 33 | x = F.avg_pool2d(x, kernel_size=self.original_pool, stride=1) 34 | x = self.fc(x) 35 | x = adaptive_avgmax_pool2d(x, 1) 36 | return x.view(x.size(0), -1) 37 | 38 | 39 | def apply_test_time_pool(model, config, use_test_size=False): 40 | test_time_pool = False 41 | if not hasattr(model, 'default_cfg') or not model.default_cfg: 42 | return model, False 43 | if use_test_size and 'test_input_size' in model.default_cfg: 44 | df_input_size = model.default_cfg['test_input_size'] 45 | else: 46 | df_input_size = model.default_cfg['input_size'] 47 | if config['input_size'][-1] > df_input_size[-1] and config['input_size'][-2] > df_input_size[-2]: 48 | _logger.info('Target input size %s > pretrained default %s, using test time pooling' % 49 | 
(str(config['input_size'][-2:]), str(df_input_size[-2:]))) 50 | model = TestTimePoolHead(model, original_pool=model.default_cfg['pool_size']) 51 | test_time_pool = True 52 | return model, test_time_pool 53 | -------------------------------------------------------------------------------- /pytorch_image_models/timm/models/layers/trace_utils.py: -------------------------------------------------------------------------------- 1 | try: 2 | from torch import _assert 3 | except ImportError: 4 | def _assert(condition: bool, message: str): 5 | assert condition, message 6 | 7 | 8 | def _float_to_int(x: float) -> int: 9 | """ 10 | Symbolic tracing helper to substitute for inbuilt `int`. 11 | Hint: Inbuilt `int` can't accept an argument of type `Proxy` 12 | """ 13 | return int(x) 14 | -------------------------------------------------------------------------------- /pytorch_image_models/timm/optim/__init__.py: -------------------------------------------------------------------------------- 1 | from .adabelief import AdaBelief 2 | from .adafactor import Adafactor 3 | from .adahessian import Adahessian 4 | from .adamp import AdamP 5 | from .adamw import AdamW 6 | from .lamb import Lamb 7 | from .lars import Lars 8 | from .lookahead import Lookahead 9 | from .madgrad import MADGRAD 10 | from .nadam import Nadam 11 | from .nvnovograd import NvNovoGrad 12 | from .radam import RAdam 13 | from .rmsprop_tf import RMSpropTF 14 | from .sgdp import SGDP 15 | from .optim_factory import create_optimizer, create_optimizer_v2, optimizer_kwargs 16 | -------------------------------------------------------------------------------- /pytorch_image_models/timm/optim/lookahead.py: -------------------------------------------------------------------------------- 1 | """ Lookahead Optimizer Wrapper. 
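Usage sketch (illustrative, `model` assumed): wrap an existing optimizer,
    optimizer = Lookahead(torch.optim.SGD(model.parameters(), lr=0.1), alpha=0.5, k=6)
then call optimizer.step() as usual; every k steps the slow weights are blended toward the
fast weights and copied back (see update_slow / sync_lookahead below).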
2 | Implementation modified from: https://github.com/alphadl/lookahead.pytorch 3 | Paper: `Lookahead Optimizer: k steps forward, 1 step back` - https://arxiv.org/abs/1907.08610 4 | 5 | Hacked together by / Copyright 2020 Ross Wightman 6 | """ 7 | import torch 8 | from torch.optim.optimizer import Optimizer 9 | from collections import defaultdict 10 | 11 | 12 | class Lookahead(Optimizer): 13 | def __init__(self, base_optimizer, alpha=0.5, k=6): 14 | # NOTE super().__init__() not called on purpose 15 | if not 0.0 <= alpha <= 1.0: 16 | raise ValueError(f'Invalid slow update rate: {alpha}') 17 | if not 1 <= k: 18 | raise ValueError(f'Invalid lookahead steps: {k}') 19 | defaults = dict(lookahead_alpha=alpha, lookahead_k=k, lookahead_step=0) 20 | self._base_optimizer = base_optimizer 21 | self.param_groups = base_optimizer.param_groups 22 | self.defaults = base_optimizer.defaults 23 | self.defaults.update(defaults) 24 | self.state = defaultdict(dict) 25 | # manually add our defaults to the param groups 26 | for name, default in defaults.items(): 27 | for group in self._base_optimizer.param_groups: 28 | group.setdefault(name, default) 29 | 30 | @torch.no_grad() 31 | def update_slow(self, group): 32 | for fast_p in group["params"]: 33 | if fast_p.grad is None: 34 | continue 35 | param_state = self._base_optimizer.state[fast_p] 36 | if 'lookahead_slow_buff' not in param_state: 37 | param_state['lookahead_slow_buff'] = torch.empty_like(fast_p) 38 | param_state['lookahead_slow_buff'].copy_(fast_p) 39 | slow = param_state['lookahead_slow_buff'] 40 | slow.add_(fast_p - slow, alpha=group['lookahead_alpha']) 41 | fast_p.copy_(slow) 42 | 43 | def sync_lookahead(self): 44 | for group in self._base_optimizer.param_groups: 45 | self.update_slow(group) 46 | 47 | @torch.no_grad() 48 | def step(self, closure=None): 49 | loss = self._base_optimizer.step(closure) 50 | for group in self._base_optimizer.param_groups: 51 | group['lookahead_step'] += 1 52 | if group['lookahead_step'] % group['lookahead_k'] == 0: 53 | self.update_slow(group) 54 | return loss 55 | 56 | def state_dict(self): 57 | return self._base_optimizer.state_dict() 58 | 59 | def load_state_dict(self, state_dict): 60 | self._base_optimizer.load_state_dict(state_dict) 61 | self.param_groups = self._base_optimizer.param_groups 62 | -------------------------------------------------------------------------------- /pytorch_image_models/timm/optim/sgdp.py: -------------------------------------------------------------------------------- 1 | """ 2 | SGDP Optimizer Implementation copied from https://github.com/clovaai/AdamP/blob/master/adamp/sgdp.py 3 | 4 | Paper: `Slowing Down the Weight Norm Increase in Momentum-based Optimizers` - https://arxiv.org/abs/2006.08217 5 | Code: https://github.com/clovaai/AdamP 6 | 7 | Copyright (c) 2020-present NAVER Corp. 
8 | MIT license 9 | """ 10 | 11 | import torch 12 | import torch.nn.functional as F 13 | from torch.optim.optimizer import Optimizer, required 14 | import math 15 | 16 | from .adamp import projection 17 | 18 | 19 | class SGDP(Optimizer): 20 | def __init__(self, params, lr=required, momentum=0, dampening=0, 21 | weight_decay=0, nesterov=False, eps=1e-8, delta=0.1, wd_ratio=0.1): 22 | defaults = dict( 23 | lr=lr, momentum=momentum, dampening=dampening, weight_decay=weight_decay, 24 | nesterov=nesterov, eps=eps, delta=delta, wd_ratio=wd_ratio) 25 | super(SGDP, self).__init__(params, defaults) 26 | 27 | @torch.no_grad() 28 | def step(self, closure=None): 29 | loss = None 30 | if closure is not None: 31 | with torch.enable_grad(): 32 | loss = closure() 33 | 34 | for group in self.param_groups: 35 | weight_decay = group['weight_decay'] 36 | momentum = group['momentum'] 37 | dampening = group['dampening'] 38 | nesterov = group['nesterov'] 39 | 40 | for p in group['params']: 41 | if p.grad is None: 42 | continue 43 | grad = p.grad 44 | state = self.state[p] 45 | 46 | # State initialization 47 | if len(state) == 0: 48 | state['momentum'] = torch.zeros_like(p) 49 | 50 | # SGD 51 | buf = state['momentum'] 52 | buf.mul_(momentum).add_(grad, alpha=1. - dampening) 53 | if nesterov: 54 | d_p = grad + momentum * buf 55 | else: 56 | d_p = buf 57 | 58 | # Projection 59 | wd_ratio = 1. 60 | if len(p.shape) > 1: 61 | d_p, wd_ratio = projection(p, grad, d_p, group['delta'], group['wd_ratio'], group['eps']) 62 | 63 | # Weight decay 64 | if weight_decay != 0: 65 | p.mul_(1. - group['lr'] * group['weight_decay'] * wd_ratio / (1-momentum)) 66 | 67 | # Step 68 | p.add_(d_p, alpha=-group['lr']) 69 | 70 | return loss 71 | -------------------------------------------------------------------------------- /pytorch_image_models/timm/scheduler/__init__.py: -------------------------------------------------------------------------------- 1 | from .cosine_lr import CosineLRScheduler 2 | from .multistep_lr import MultiStepLRScheduler 3 | from .plateau_lr import PlateauLRScheduler 4 | from .poly_lr import PolyLRScheduler 5 | from .step_lr import StepLRScheduler 6 | from .tanh_lr import TanhLRScheduler 7 | 8 | from .scheduler_factory import create_scheduler 9 | -------------------------------------------------------------------------------- /pytorch_image_models/timm/scheduler/multistep_lr.py: -------------------------------------------------------------------------------- 1 | """ MultiStep LR Scheduler 2 | 3 | Basic multi step LR schedule with warmup, noise. 
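Illustrative example: with decay_t=[30, 60] and decay_rate=0.1, each base LR is scaled by
0.1 once the first milestone is passed and by 0.01 after the second, using bisect_right
over the sorted decay_t list.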
4 | """ 5 | import torch 6 | import bisect 7 | from timm.scheduler.scheduler import Scheduler 8 | from typing import List 9 | 10 | class MultiStepLRScheduler(Scheduler): 11 | """ 12 | """ 13 | 14 | def __init__(self, 15 | optimizer: torch.optim.Optimizer, 16 | decay_t: List[int], 17 | decay_rate: float = 1., 18 | warmup_t=0, 19 | warmup_lr_init=0, 20 | t_in_epochs=True, 21 | noise_range_t=None, 22 | noise_pct=0.67, 23 | noise_std=1.0, 24 | noise_seed=42, 25 | initialize=True, 26 | ) -> None: 27 | super().__init__( 28 | optimizer, param_group_field="lr", 29 | noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, 30 | initialize=initialize) 31 | 32 | self.decay_t = decay_t 33 | self.decay_rate = decay_rate 34 | self.warmup_t = warmup_t 35 | self.warmup_lr_init = warmup_lr_init 36 | self.t_in_epochs = t_in_epochs 37 | if self.warmup_t: 38 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] 39 | super().update_groups(self.warmup_lr_init) 40 | else: 41 | self.warmup_steps = [1 for _ in self.base_values] 42 | 43 | def get_curr_decay_steps(self, t): 44 | # find where in the array t goes, 45 | # assumes self.decay_t is sorted 46 | return bisect.bisect_right(self.decay_t, t+1) 47 | 48 | def _get_lr(self, t): 49 | if t < self.warmup_t: 50 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] 51 | else: 52 | lrs = [v * (self.decay_rate ** self.get_curr_decay_steps(t)) for v in self.base_values] 53 | return lrs 54 | 55 | def get_epoch_values(self, epoch: int): 56 | if self.t_in_epochs: 57 | return self._get_lr(epoch) 58 | else: 59 | return None 60 | 61 | def get_update_values(self, num_updates: int): 62 | if not self.t_in_epochs: 63 | return self._get_lr(num_updates) 64 | else: 65 | return None 66 | -------------------------------------------------------------------------------- /pytorch_image_models/timm/scheduler/step_lr.py: -------------------------------------------------------------------------------- 1 | """ Step Scheduler 2 | 3 | Basic step LR schedule with warmup, noise. 
4 | 5 | Hacked together by / Copyright 2020 Ross Wightman 6 | """ 7 | import math 8 | import torch 9 | 10 | from .scheduler import Scheduler 11 | 12 | 13 | class StepLRScheduler(Scheduler): 14 | """ 15 | """ 16 | 17 | def __init__(self, 18 | optimizer: torch.optim.Optimizer, 19 | decay_t: float, 20 | decay_rate: float = 1., 21 | warmup_t=0, 22 | warmup_lr_init=0, 23 | t_in_epochs=True, 24 | noise_range_t=None, 25 | noise_pct=0.67, 26 | noise_std=1.0, 27 | noise_seed=42, 28 | initialize=True, 29 | ) -> None: 30 | super().__init__( 31 | optimizer, param_group_field="lr", 32 | noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, 33 | initialize=initialize) 34 | 35 | self.decay_t = decay_t 36 | self.decay_rate = decay_rate 37 | self.warmup_t = warmup_t 38 | self.warmup_lr_init = warmup_lr_init 39 | self.t_in_epochs = t_in_epochs 40 | if self.warmup_t: 41 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] 42 | super().update_groups(self.warmup_lr_init) 43 | else: 44 | self.warmup_steps = [1 for _ in self.base_values] 45 | 46 | def _get_lr(self, t): 47 | if t < self.warmup_t: 48 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] 49 | else: 50 | lrs = [v * (self.decay_rate ** (t // self.decay_t)) for v in self.base_values] 51 | return lrs 52 | 53 | def get_epoch_values(self, epoch: int): 54 | if self.t_in_epochs: 55 | return self._get_lr(epoch) 56 | else: 57 | return None 58 | 59 | def get_update_values(self, num_updates: int): 60 | if not self.t_in_epochs: 61 | return self._get_lr(num_updates) 62 | else: 63 | return None 64 | -------------------------------------------------------------------------------- /pytorch_image_models/timm/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .agc import adaptive_clip_grad 2 | from .checkpoint_saver import CheckpointSaver 3 | from .clip_grad import dispatch_clip_grad 4 | from .cuda import ApexScaler, NativeScaler 5 | from .distributed import distribute_bn, reduce_tensor 6 | from .jit import set_jit_legacy, set_jit_fuser 7 | from .log import setup_default_logging, FormatterNoInfo 8 | from .metrics import AverageMeter, accuracy 9 | from .misc import natural_key, add_bool_arg 10 | from .model import unwrap_model, get_state_dict, freeze, unfreeze 11 | from .model_ema import ModelEma, ModelEmaV2 12 | from .random import random_seed 13 | from .summary import update_summary, get_outdir 14 | -------------------------------------------------------------------------------- /pytorch_image_models/timm/utils/agc.py: -------------------------------------------------------------------------------- 1 | """ Adaptive Gradient Clipping 2 | 3 | An impl of AGC, as per (https://arxiv.org/abs/2102.06171): 4 | 5 | @article{brock2021high, 6 | author={Andrew Brock and Soham De and Samuel L. 
Smith and Karen Simonyan}, 7 | title={High-Performance Large-Scale Image Recognition Without Normalization}, 8 | journal={arXiv preprint arXiv:}, 9 | year={2021} 10 | } 11 | 12 | Code references: 13 | * Official JAX impl (paper authors): https://github.com/deepmind/deepmind-research/tree/master/nfnets 14 | * Phil Wang's PyTorch gist: https://gist.github.com/lucidrains/0d6560077edac419ab5d3aa29e674d5c 15 | 16 | Hacked together by / Copyright 2021 Ross Wightman 17 | """ 18 | import torch 19 | 20 | 21 | def unitwise_norm(x, norm_type=2.0): 22 | if x.ndim <= 1: 23 | return x.norm(norm_type) 24 | else: 25 | # works for nn.ConvNd and nn,Linear where output dim is first in the kernel/weight tensor 26 | # might need special cases for other weights (possibly MHA) where this may not be true 27 | return x.norm(norm_type, dim=tuple(range(1, x.ndim)), keepdim=True) 28 | 29 | 30 | def adaptive_clip_grad(parameters, clip_factor=0.01, eps=1e-3, norm_type=2.0): 31 | if isinstance(parameters, torch.Tensor): 32 | parameters = [parameters] 33 | for p in parameters: 34 | if p.grad is None: 35 | continue 36 | p_data = p.detach() 37 | g_data = p.grad.detach() 38 | max_norm = unitwise_norm(p_data, norm_type=norm_type).clamp_(min=eps).mul_(clip_factor) 39 | grad_norm = unitwise_norm(g_data, norm_type=norm_type) 40 | clipped_grad = g_data * (max_norm / grad_norm.clamp(min=1e-6)) 41 | new_grads = torch.where(grad_norm < max_norm, g_data, clipped_grad) 42 | p.grad.detach().copy_(new_grads) 43 | -------------------------------------------------------------------------------- /pytorch_image_models/timm/utils/clip_grad.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from timm.utils.agc import adaptive_clip_grad 4 | 5 | 6 | def dispatch_clip_grad(parameters, value: float, mode: str = 'norm', norm_type: float = 2.0): 7 | """ Dispatch to gradient clipping method 8 | 9 | Args: 10 | parameters (Iterable): model parameters to clip 11 | value (float): clipping value/factor/norm, mode dependant 12 | mode (str): clipping mode, one of 'norm', 'value', 'agc' 13 | norm_type (float): p-norm, default 2.0 14 | """ 15 | if mode == 'norm': 16 | torch.nn.utils.clip_grad_norm_(parameters, value, norm_type=norm_type) 17 | elif mode == 'value': 18 | torch.nn.utils.clip_grad_value_(parameters, value) 19 | elif mode == 'agc': 20 | adaptive_clip_grad(parameters, value, norm_type=norm_type) 21 | else: 22 | assert False, f"Unknown clip mode ({mode})." 
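# Usage sketch (illustrative values, `model` assumed): dispatch_clip_grad(model.parameters(), 1.0, mode='norm')
# clips the global gradient norm to 1.0; with mode='agc' the value is instead the AGC clip_factor
# (the adaptive_clip_grad default is 0.01) applied per parameter via the function above.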
23 | 24 | -------------------------------------------------------------------------------- /pytorch_image_models/timm/utils/cuda.py: -------------------------------------------------------------------------------- 1 | """ CUDA / AMP utils 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import torch 6 | 7 | try: 8 | from apex import amp 9 | has_apex = True 10 | except ImportError: 11 | amp = None 12 | has_apex = False 13 | 14 | from .clip_grad import dispatch_clip_grad 15 | 16 | 17 | class ApexScaler: 18 | state_dict_key = "amp" 19 | 20 | def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False): 21 | with amp.scale_loss(loss, optimizer) as scaled_loss: 22 | scaled_loss.backward(create_graph=create_graph) 23 | if clip_grad is not None: 24 | dispatch_clip_grad(amp.master_params(optimizer), clip_grad, mode=clip_mode) 25 | optimizer.step() 26 | 27 | def state_dict(self): 28 | if 'state_dict' in amp.__dict__: 29 | return amp.state_dict() 30 | 31 | def load_state_dict(self, state_dict): 32 | if 'load_state_dict' in amp.__dict__: 33 | amp.load_state_dict(state_dict) 34 | 35 | 36 | class NativeScaler: 37 | state_dict_key = "amp_scaler" 38 | 39 | def __init__(self): 40 | self._scaler = torch.cuda.amp.GradScaler() 41 | 42 | def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False): 43 | self._scaler.scale(loss).backward(create_graph=create_graph) 44 | if clip_grad is not None: 45 | assert parameters is not None 46 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place 47 | dispatch_clip_grad(parameters, clip_grad, mode=clip_mode) 48 | self._scaler.step(optimizer) 49 | self._scaler.update() 50 | 51 | def state_dict(self): 52 | return self._scaler.state_dict() 53 | 54 | def load_state_dict(self, state_dict): 55 | self._scaler.load_state_dict(state_dict) 56 | -------------------------------------------------------------------------------- /pytorch_image_models/timm/utils/distributed.py: -------------------------------------------------------------------------------- 1 | """ Distributed training/validation utils 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import torch 6 | from torch import distributed as dist 7 | 8 | from .model import unwrap_model 9 | 10 | 11 | def reduce_tensor(tensor, n): 12 | rt = tensor.clone() 13 | dist.all_reduce(rt, op=dist.ReduceOp.SUM) 14 | rt /= n 15 | return rt 16 | 17 | 18 | def distribute_bn(model, world_size, reduce=False): 19 | # ensure every node has the same running bn stats 20 | for bn_name, bn_buf in unwrap_model(model).named_buffers(recurse=True): 21 | if ('running_mean' in bn_name) or ('running_var' in bn_name): 22 | if reduce: 23 | # average bn stats across whole group 24 | torch.distributed.all_reduce(bn_buf, op=dist.ReduceOp.SUM) 25 | bn_buf /= float(world_size) 26 | else: 27 | # broadcast bn stats from rank 0 to whole group 28 | torch.distributed.broadcast(bn_buf, 0) 29 | -------------------------------------------------------------------------------- /pytorch_image_models/timm/utils/jit.py: -------------------------------------------------------------------------------- 1 | """ JIT scripting/tracing utils 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import os 6 | 7 | import torch 8 | 9 | 10 | def set_jit_legacy(): 11 | """ Set JIT executor to legacy w/ support for op fusion 12 | This is hopefully a temporary need in 1.5/1.5.1/1.6 to restore 
performance due to changes 13 | in the JIT exectutor. These API are not supported so could change. 14 | """ 15 | # 16 | assert hasattr(torch._C, '_jit_set_profiling_executor'), "Old JIT behavior doesn't exist!" 17 | torch._C._jit_set_profiling_executor(False) 18 | torch._C._jit_set_profiling_mode(False) 19 | torch._C._jit_override_can_fuse_on_gpu(True) 20 | #torch._C._jit_set_texpr_fuser_enabled(True) 21 | 22 | 23 | def set_jit_fuser(fuser): 24 | if fuser == "te": 25 | # default fuser should be == 'te' 26 | torch._C._jit_set_profiling_executor(True) 27 | torch._C._jit_set_profiling_mode(True) 28 | torch._C._jit_override_can_fuse_on_cpu(False) 29 | torch._C._jit_override_can_fuse_on_gpu(True) 30 | torch._C._jit_set_texpr_fuser_enabled(True) 31 | try: 32 | torch._C._jit_set_nvfuser_enabled(False) 33 | except Exception: 34 | pass 35 | elif fuser == "old" or fuser == "legacy": 36 | torch._C._jit_set_profiling_executor(False) 37 | torch._C._jit_set_profiling_mode(False) 38 | torch._C._jit_override_can_fuse_on_gpu(True) 39 | torch._C._jit_set_texpr_fuser_enabled(False) 40 | try: 41 | torch._C._jit_set_nvfuser_enabled(False) 42 | except Exception: 43 | pass 44 | elif fuser == "nvfuser" or fuser == "nvf": 45 | os.environ['PYTORCH_NVFUSER_DISABLE_FALLBACK'] = '1' 46 | #os.environ['PYTORCH_NVFUSER_DISABLE_FMA'] = '1' 47 | #os.environ['PYTORCH_NVFUSER_JIT_OPT_LEVEL'] = '0' 48 | torch._C._jit_set_texpr_fuser_enabled(False) 49 | torch._C._jit_set_profiling_executor(True) 50 | torch._C._jit_set_profiling_mode(True) 51 | torch._C._jit_can_fuse_on_cpu() 52 | torch._C._jit_can_fuse_on_gpu() 53 | torch._C._jit_override_can_fuse_on_cpu(False) 54 | torch._C._jit_override_can_fuse_on_gpu(False) 55 | torch._C._jit_set_nvfuser_guard_mode(True) 56 | torch._C._jit_set_nvfuser_enabled(True) 57 | else: 58 | assert False, f"Invalid jit fuser ({fuser})" 59 | -------------------------------------------------------------------------------- /pytorch_image_models/timm/utils/log.py: -------------------------------------------------------------------------------- 1 | """ Logging helpers 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import logging 6 | import logging.handlers 7 | 8 | 9 | class FormatterNoInfo(logging.Formatter): 10 | def __init__(self, fmt='%(levelname)s: %(message)s'): 11 | logging.Formatter.__init__(self, fmt) 12 | 13 | def format(self, record): 14 | if record.levelno == logging.INFO: 15 | return str(record.getMessage()) 16 | return logging.Formatter.format(self, record) 17 | 18 | 19 | def setup_default_logging(default_level=logging.INFO, log_path=''): 20 | console_handler = logging.StreamHandler() 21 | console_handler.setFormatter(FormatterNoInfo()) 22 | logging.root.addHandler(console_handler) 23 | logging.root.setLevel(default_level) 24 | if log_path: 25 | file_handler = logging.handlers.RotatingFileHandler(log_path, maxBytes=(1024 ** 2 * 2), backupCount=3) 26 | file_formatter = logging.Formatter("%(asctime)s - %(name)20s: [%(levelname)8s] - %(message)s") 27 | file_handler.setFormatter(file_formatter) 28 | logging.root.addHandler(file_handler) 29 | -------------------------------------------------------------------------------- /pytorch_image_models/timm/utils/metrics.py: -------------------------------------------------------------------------------- 1 | """ Eval metrics and related 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | 6 | 7 | class AverageMeter: 8 | """Computes and stores the average and current value""" 9 | def __init__(self): 10 | self.reset() 
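    # Usage sketch (illustrative): meter = AverageMeter(); meter.update(loss.item(), n=inputs.size(0))
    # keeps meter.val as the latest batch value and meter.avg as the count-weighted running mean.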
11 | 12 | def reset(self): 13 | self.val = 0 14 | self.avg = 0 15 | self.sum = 0 16 | self.count = 0 17 | 18 | def update(self, val, n=1): 19 | self.val = val 20 | self.sum += val * n 21 | self.count += n 22 | self.avg = self.sum / self.count 23 | 24 | 25 | def accuracy(output, target, topk=(1,)): 26 | """Computes the accuracy over the k top predictions for the specified values of k""" 27 | maxk = min(max(topk), output.size()[1]) 28 | batch_size = target.size(0) 29 | _, pred = output.topk(maxk, 1, True, True) 30 | pred = pred.t() 31 | correct = pred.eq(target.reshape(1, -1).expand_as(pred)) 32 | return [correct[:min(k, maxk)].reshape(-1).float().sum(0) * 100. / batch_size for k in topk] 33 | -------------------------------------------------------------------------------- /pytorch_image_models/timm/utils/misc.py: -------------------------------------------------------------------------------- 1 | """ Misc utils 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import re 6 | 7 | 8 | def natural_key(string_): 9 | """See http://www.codinghorror.com/blog/archives/001018.html""" 10 | return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())] 11 | 12 | 13 | def add_bool_arg(parser, name, default=False, help=''): 14 | dest_name = name.replace('-', '_') 15 | group = parser.add_mutually_exclusive_group(required=False) 16 | group.add_argument('--' + name, dest=dest_name, action='store_true', help=help) 17 | group.add_argument('--no-' + name, dest=dest_name, action='store_false', help=help) 18 | parser.set_defaults(**{dest_name: default}) 19 | -------------------------------------------------------------------------------- /pytorch_image_models/timm/utils/random.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import torch 4 | 5 | 6 | def random_seed(seed=42, rank=0): 7 | torch.manual_seed(seed + rank) 8 | np.random.seed(seed + rank) 9 | random.seed(seed + rank) 10 | -------------------------------------------------------------------------------- /pytorch_image_models/timm/utils/summary.py: -------------------------------------------------------------------------------- 1 | """ Summary utilities 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import csv 6 | import os 7 | from collections import OrderedDict 8 | try: 9 | import wandb 10 | except ImportError: 11 | pass 12 | 13 | def get_outdir(path, *paths, inc=False): 14 | outdir = os.path.join(path, *paths) 15 | if not os.path.exists(outdir): 16 | os.makedirs(outdir) 17 | elif inc: 18 | count = 1 19 | outdir_inc = outdir + '-' + str(count) 20 | while os.path.exists(outdir_inc): 21 | count = count + 1 22 | outdir_inc = outdir + '-' + str(count) 23 | assert count < 100 24 | outdir = outdir_inc 25 | os.makedirs(outdir) 26 | return outdir 27 | 28 | 29 | def update_summary(epoch, train_metrics, eval_metrics, filename, write_header=False, log_wandb=False): 30 | rowd = OrderedDict(epoch=epoch) 31 | rowd.update([('train_' + k, v) for k, v in train_metrics.items()]) 32 | rowd.update([('eval_' + k, v) for k, v in eval_metrics.items()]) 33 | if log_wandb: 34 | wandb.log(rowd) 35 | with open(filename, mode='a') as cf: 36 | dw = csv.DictWriter(cf, fieldnames=rowd.keys()) 37 | if write_header: # first iteration (epoch == 1 can't be used) 38 | dw.writeheader() 39 | dw.writerow(rowd) 40 | -------------------------------------------------------------------------------- /pytorch_image_models/timm/version.py: 
-------------------------------------------------------------------------------- 1 | __version__ = '0.6.5' 2 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.13.1 2 | torchvision==0.14.1 3 | scipy 4 | datasketch 5 | openai 6 | numpy 7 | nltk 8 | tqdm 9 | prefetch_generator 10 | torchcam 11 | timm -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/magic-research/Dataset_Quantization/24a2a77dd80c2cc86b53e2b2e204f1e02804d70e/util/__init__.py --------------------------------------------------------------------------------