├── .gitattributes ├── .gitignore ├── LICENSE ├── README.md ├── README_MUSE.md ├── assets ├── contexts │ └── empty_context.npy └── pipeline.png ├── cog.yaml ├── configs ├── cc3m_xl_vqf16_jax_2048bs_featset_CLIP_G.py ├── custom.py ├── imagenet256_base_vq_jax.py └── vae_configs │ └── vq-f16-jax.yaml ├── custom └── custom_dataset.py ├── data ├── data.json ├── image_01_01.jpg ├── image_01_02.jpg ├── image_01_03.jpg ├── image_01_04.jpg ├── image_01_05.jpg ├── image_01_06.jpg ├── image_01_07.jpg ├── image_01_08.jpg ├── image_02_01.jpg ├── image_02_02.jpg ├── image_02_03.jpg ├── image_02_04.jpg ├── image_02_05.jpg ├── image_02_06.jpg ├── image_03_01.jpg ├── image_03_03.jpg ├── image_03_04.jpg ├── image_03_05.jpg ├── image_03_07.jpg ├── image_03_08.jpg └── one_style.json ├── datasets.py ├── extract_empty_feature.py ├── extract_imagenet_feature.py ├── extract_test_prompt_feature.py ├── feature2webdataset.py ├── gradio_demo.py ├── img ├── 1.png ├── 2.png ├── 3.png ├── 4.png ├── 5.png ├── result.png └── split.py ├── libs ├── __init__.py ├── muse.py ├── uvit_t2i_vq.py └── uvit_vq.py ├── open_clip ├── __init__.py ├── bpe_simple_vocab_16e6.txt.gz ├── coca_model.py ├── constants.py ├── factory.py ├── generation_utils.py ├── hf_configs.py ├── hf_model.py ├── loss.py ├── model.py ├── model_configs │ ├── RN101-quickgelu.json │ ├── RN101.json │ ├── RN50-quickgelu.json │ ├── RN50.json │ ├── RN50x16.json │ ├── RN50x4.json │ ├── RN50x64.json │ ├── ViT-B-16-plus-240.json │ ├── ViT-B-16-plus.json │ ├── ViT-B-16.json │ ├── ViT-B-32-plus-256.json │ ├── ViT-B-32-quickgelu.json │ ├── ViT-B-32.json │ ├── ViT-H-14.json │ ├── ViT-H-16.json │ ├── ViT-L-14-280.json │ ├── ViT-L-14-336.json │ ├── ViT-L-14.json │ ├── ViT-L-16-320.json │ ├── ViT-L-16.json │ ├── ViT-M-16-alt.json │ ├── ViT-M-16.json │ ├── ViT-M-32-alt.json │ ├── ViT-M-32.json │ ├── ViT-S-16-alt.json │ ├── ViT-S-16.json │ ├── ViT-S-32-alt.json │ ├── ViT-S-32.json │ ├── ViT-bigG-14.json │ ├── ViT-e-14.json │ ├── ViT-g-14.json │ ├── coca_ViT-B-32.json │ ├── coca_ViT-L-14.json │ ├── coca_base.json │ ├── coca_roberta-ViT-B-32.json │ ├── convnext_base.json │ ├── convnext_base_w.json │ ├── convnext_base_w_320.json │ ├── convnext_large.json │ ├── convnext_large_d.json │ ├── convnext_large_d_320.json │ ├── convnext_small.json │ ├── convnext_tiny.json │ ├── convnext_xlarge.json │ ├── convnext_xxlarge.json │ ├── convnext_xxlarge_320.json │ ├── mt5-base-ViT-B-32.json │ ├── mt5-xl-ViT-H-14.json │ ├── roberta-ViT-B-32.json │ ├── swin_base_patch4_window7_224.json │ ├── vit_medium_patch16_gap_256.json │ ├── vit_relpos_medium_patch16_cls_224.json │ ├── xlm-roberta-base-ViT-B-32.json │ └── xlm-roberta-large-ViT-H-14.json ├── modified_resnet.py ├── openai.py ├── pretrained.py ├── push_to_hf_hub.py ├── timm_model.py ├── tokenizer.py ├── transform.py ├── transformer.py ├── utils.py └── version.py ├── predict.py ├── styledrop_colab.ipynb ├── taming ├── models │ └── vqgan.py ├── modules │ ├── diffusionmodules │ │ └── model.py │ ├── util.py │ └── vqvae │ │ └── quantize.py └── util.py ├── timm ├── __init__.py ├── data │ ├── __init__.py │ ├── auto_augment.py │ ├── config.py │ ├── constants.py │ ├── dataset.py │ ├── distributed_sampler.py │ ├── loader.py │ ├── mixup.py │ ├── random_erasing.py │ ├── real_labels.py │ ├── tf_preprocessing.py │ ├── transforms.py │ └── transforms_factory.py ├── loss │ ├── __init__.py │ ├── asymmetric_loss.py │ ├── cross_entropy.py │ └── jsd.py ├── models │ ├── __init__.py │ ├── cspnet.py │ ├── densenet.py │ ├── dla.py │ ├── dpn.py │ ├── efficientnet.py │ ├── 
efficientnet_blocks.py │ ├── efficientnet_builder.py │ ├── factory.py │ ├── features.py │ ├── gluon_resnet.py │ ├── gluon_xception.py │ ├── helpers.py │ ├── hrnet.py │ ├── inception_resnet_v2.py │ ├── inception_v3.py │ ├── inception_v4.py │ ├── layers │ │ ├── __init__.py │ │ ├── activations.py │ │ ├── activations_jit.py │ │ ├── activations_me.py │ │ ├── adaptive_avgmax_pool.py │ │ ├── anti_aliasing.py │ │ ├── blur_pool.py │ │ ├── cbam.py │ │ ├── classifier.py │ │ ├── cond_conv2d.py │ │ ├── config.py │ │ ├── conv2d_same.py │ │ ├── conv_bn_act.py │ │ ├── create_act.py │ │ ├── create_attn.py │ │ ├── create_conv2d.py │ │ ├── create_norm_act.py │ │ ├── drop.py │ │ ├── eca.py │ │ ├── evo_norm.py │ │ ├── helpers.py │ │ ├── inplace_abn.py │ │ ├── linear.py │ │ ├── median_pool.py │ │ ├── mixed_conv2d.py │ │ ├── norm_act.py │ │ ├── padding.py │ │ ├── pool2d_same.py │ │ ├── se.py │ │ ├── selective_kernel.py │ │ ├── separable_conv.py │ │ ├── space_to_depth.py │ │ ├── split_attn.py │ │ ├── split_batchnorm.py │ │ ├── test_time_pool.py │ │ └── weight_init.py │ ├── mobilenetv3.py │ ├── nasnet.py │ ├── pnasnet.py │ ├── pruned │ │ ├── ecaresnet101d_pruned.txt │ │ ├── ecaresnet50d_pruned.txt │ │ ├── efficientnet_b1_pruned.txt │ │ ├── efficientnet_b2_pruned.txt │ │ └── efficientnet_b3_pruned.txt │ ├── registry.py │ ├── regnet.py │ ├── res2net.py │ ├── resnest.py │ ├── resnet.py │ ├── rexnet.py │ ├── selecsls.py │ ├── senet.py │ ├── sknet.py │ ├── tresnet.py │ ├── vision_transformer.py │ ├── vovnet.py │ ├── xception.py │ └── xception_aligned.py ├── optim │ ├── __init__.py │ ├── adafactor.py │ ├── adahessian.py │ ├── adamp.py │ ├── adamw.py │ ├── lookahead.py │ ├── nadam.py │ ├── novograd.py │ ├── nvnovograd.py │ ├── optim_factory.py │ ├── radam.py │ ├── rmsprop_tf.py │ └── sgdp.py ├── scheduler │ ├── __init__.py │ ├── cosine_lr.py │ ├── plateau_lr.py │ ├── scheduler.py │ ├── scheduler_factory.py │ ├── step_lr.py │ └── tanh_lr.py ├── utils │ ├── __init__.py │ ├── checkpoint_saver.py │ ├── cuda.py │ ├── distributed.py │ ├── jit.py │ ├── log.py │ ├── metrics.py │ ├── misc.py │ ├── model.py │ ├── model_ema.py │ └── summary.py └── version.py ├── tools ├── __init__.py ├── fid_score.py └── inception.py ├── train_t2i_colab_v2.py ├── train_t2i_custom_v2.py ├── train_t2i_discrete_muse.py ├── train_t2i_discrete_wds.py ├── utils.py └── webdata.py /.gitattributes: -------------------------------------------------------------------------------- 1 | *.pth filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | *.ckpt 3 | *.npz 4 | wandb/ 5 | *.ipynb_checkpoints 6 | .vscode/ 7 | *style/ 8 | *.sh 9 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Fan Bao 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be 
included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /assets/contexts/empty_context.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zideliu/StyleDrop-PyTorch/3538a60ca79ac57d5ca04be59dcbd13d48f68921/assets/contexts/empty_context.npy -------------------------------------------------------------------------------- /assets/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zideliu/StyleDrop-PyTorch/3538a60ca79ac57d5ca04be59dcbd13d48f68921/assets/pipeline.png -------------------------------------------------------------------------------- /cog.yaml: -------------------------------------------------------------------------------- 1 | # Configuration for Cog ⚙️ 2 | # Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md 3 | 4 | build: 5 | gpu: true 6 | cuda: "11.8" 7 | python_version: "3.9" 8 | python_packages: 9 | - "torch==2.0.1" 10 | - "torchvision==0.15.2" 11 | - "loguru==0.7.0" 12 | - "xformers==0.0.20" 13 | - "ftfy==6.1.1" 14 | - "einops==0.6.1" 15 | - "ml-collections==0.1.1" 16 | - "omegaconf==2.3.0" 17 | - "transformers==4.23.1" 18 | - "webdataset==0.2.5" 19 | predict: "predict.py:Predictor" 20 | -------------------------------------------------------------------------------- /configs/cc3m_xl_vqf16_jax_2048bs_featset_CLIP_G.py: -------------------------------------------------------------------------------- 1 | import ml_collections 2 | 3 | 4 | def d(**kwargs): 5 | """Helper of creating a config dict.""" 6 | return ml_collections.ConfigDict(initial_dictionary=kwargs) 7 | 8 | 9 | def get_config(): 10 | config = ml_collections.ConfigDict() 11 | 12 | config.seed = 1234 13 | config.z_shape = (8, 16, 16) 14 | 15 | config.autoencoder = d( 16 | config_file='vq-f16-jax.yaml', 17 | ) 18 | 19 | config.train = d( 20 | n_steps=999999999, 21 | batch_size=2048, 22 | log_interval=10, 23 | eval_interval=5000, 24 | save_interval=5000, 25 | fid_interval=50000, 26 | num_workers=8, 27 | resampled=False, 28 | ) 29 | 30 | config.eval = d( 31 | n_samples=10000, 32 | sample_steps=18, 33 | ) 34 | 35 | config.optimizer = d( 36 | name='adamw', 37 | lr=0.0002, 38 | weight_decay=0.03, 39 | betas=(0.99, 0.99), 40 | ) 41 | 42 | config.lr_scheduler = d( 43 | name='customized', 44 | warmup_steps=5000 45 | ) 46 | 47 | config.nnet = d( 48 | name='uvit_t2i_vq', 49 | img_size=16, 50 | codebook_size=1024, 51 | in_chans=4, 52 | embed_dim=1152, 53 | depth=28, 54 | num_heads=16, 55 | mlp_ratio=4, 56 | qkv_bias=False, 57 | clip_dim=1280, 58 | num_clip_token=77, 59 | use_checkpoint=True, 60 | skip=True, 61 | ) 62 | 63 | config.muse = d( 64 | ignore_ind=-1, 65 | smoothing=0.1, 66 | gen_temp=4.5 67 | ) 68 | 69 | config.dataset = d( 70 | name='cc3m_web', 71 | cfg=True, 72 | p_uncond=0.15, 73 | ) 74 | 75 | config.wds = d( 76 | 
train_data='assets/datasets/cc3m/vq_f16_jax_clipG_cc3m_train_emb/{00000..03044}.tar', 77 | val_data='assets/datasets/cc3m/vq_f16_jax_clipG_cc3m_val_emb/{00000..00012}.tar', 78 | ctx_path='assets/contexts', 79 | dist_eval=True, 80 | ) 81 | 82 | config.sample = d( 83 | sample_steps=18, 84 | n_samples=30000, 85 | mini_batch_size=2, 86 | cfg=True, 87 | linear_inc_scale=True, 88 | scale=10., 89 | path='', 90 | ) 91 | 92 | return config 93 | -------------------------------------------------------------------------------- /configs/custom.py: -------------------------------------------------------------------------------- 1 | import ml_collections 2 | 3 | 4 | def d(**kwargs): 5 | """Helper of creating a config dict.""" 6 | return ml_collections.ConfigDict(initial_dictionary=kwargs) 7 | 8 | 9 | def get_config(): 10 | config = ml_collections.ConfigDict() 11 | 12 | 13 | config.seed = 1234 14 | config.z_shape = (8, 16, 16) 15 | 16 | config.autoencoder = d( 17 | config_file='vq-f16-jax.yaml', 18 | ) 19 | config.data_path="data/one_style.json" 20 | config.resume_root="assets/ckpts/cc3m-285000.ckpt" 21 | config.adapter_path=None 22 | config.sample_interval=True 23 | config.train = d( 24 | n_steps=1000, 25 | batch_size=8, 26 | log_interval=20, 27 | eval_interval=100, 28 | save_interval=100, 29 | fid_interval=20000, 30 | num_workers=8, 31 | resampled=False, 32 | ) 33 | 34 | config.optimizer = d( 35 | name='adamw', 36 | lr=0.0003, 37 | weight_decay=0.03, 38 | betas=(0.99, 0.99), 39 | ) 40 | 41 | config.lr_scheduler = d( 42 | name='customized', 43 | warmup_steps=-1, # 5000 44 | ) 45 | 46 | config.nnet = d( 47 | name='uvit_t2i_vq', 48 | img_size=16, 49 | codebook_size=1024, 50 | in_chans=4, 51 | embed_dim=1152, 52 | depth=28, 53 | num_heads=16, 54 | mlp_ratio=4, 55 | qkv_bias=False, 56 | clip_dim=1280, 57 | num_clip_token=77, 58 | use_checkpoint=False, 59 | skip=True, 60 | d_prj=32,# Stage I: 32; Stage II: TODO 61 | is_shared=False, # Stage I: False; Stage II: False 62 | ) 63 | 64 | config.muse = d( 65 | ignore_ind=-1, 66 | smoothing=0.1, 67 | gen_temp=4.5 68 | ) 69 | 70 | 71 | config.sample = d( 72 | sample_steps=36, 73 | n_samples=50, 74 | mini_batch_size=8, 75 | cfg=True, 76 | linear_inc_scale=True, 77 | scale=10., 78 | path='', 79 | lambdaA=2.0, # Stage I: 2.0; Stage II: TODO 80 | lambdaB=5.0, # Stage I: 5.0; Stage II: TODO 81 | ) 82 | 83 | return config 84 | -------------------------------------------------------------------------------- /configs/imagenet256_base_vq_jax.py: -------------------------------------------------------------------------------- 1 | import ml_collections 2 | 3 | 4 | def d(**kwargs): 5 | """Helper of creating a config dict.""" 6 | return ml_collections.ConfigDict(initial_dictionary=kwargs) 7 | 8 | 9 | def get_config(): 10 | config = ml_collections.ConfigDict() 11 | 12 | config.seed = 1234 13 | config.z_shape = (8, 16, 16) 14 | 15 | config.autoencoder = d( 16 | config_file='vq-f16-jax.yaml', 17 | ) 18 | 19 | config.train = d( 20 | n_steps=99999999, 21 | batch_size=2048, 22 | log_interval=10, 23 | eval_interval=5000, 24 | save_interval=5000, 25 | fid_interval=50000, 26 | ) 27 | 28 | config.eval = d( 29 | n_samples=10000, 30 | sample_steps=12, 31 | ) 32 | 33 | config.optimizer = d( 34 | name='adamw', 35 | lr=0.0004, 36 | weight_decay=0.03, 37 | betas=(0.99, 0.99), 38 | ) 39 | 40 | config.lr_scheduler = d( 41 | name='customized', 42 | warmup_steps=5000 43 | ) 44 | 45 | config.nnet = d( 46 | name='uvit_vq', 47 | img_size=16, 48 | codebook_size=1024, 49 | in_chans=256, 50 | 
patch_size=1, 51 | embed_dim=768, 52 | depth=12, 53 | num_heads=12, 54 | mlp_ratio=4, 55 | qkv_bias=False, 56 | num_classes=1001, 57 | use_checkpoint=False, 58 | skip=True, 59 | ) 60 | 61 | config.muse = d( 62 | ignore_ind=-1, 63 | smoothing=0.1, 64 | gen_temp=4.5 65 | ) 66 | 67 | config.dataset = d( 68 | name='imagenet256_features', 69 | path='assets/datasets/imagenet256_vq_features/vq-f16-jax', 70 | cfg=True, 71 | p_uncond=0.15, 72 | ) 73 | 74 | config.sample = d( 75 | sample_steps=12, 76 | n_samples=50000, 77 | mini_batch_size=50, 78 | cfg=True, 79 | linear_inc_scale=True, 80 | scale=3., 81 | path='' 82 | ) 83 | 84 | return config 85 | -------------------------------------------------------------------------------- /configs/vae_configs/vq-f16-jax.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: taming.models.vqgan.VQModel 4 | params: 5 | embed_dim: 256 6 | n_embed: 1024 7 | ddconfig: 8 | double_z: False 9 | z_channels: 256 10 | resolution: 256 11 | in_channels: 3 12 | out_ch: 3 13 | ch: 128 14 | ch_mult: [ 1,1,2,2,4] # num_down = len(ch_mult)-1 15 | num_res_blocks: 2 16 | attn_resolutions: [16] 17 | dropout: 0.0 18 | 19 | lossconfig: 20 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 21 | params: 22 | disc_conditional: False 23 | disc_in_channels: 3 24 | disc_start: 250001 25 | disc_weight: 0.8 26 | codebook_weight: 1.0 27 | 28 | data: 29 | target: main.DataModuleFromConfig 30 | params: 31 | batch_size: 8 32 | num_workers: 24 33 | train: 34 | target: taming.data.imagenet.ImageNetTrain 35 | params: 36 | config: 37 | size: 256 38 | validation: 39 | target: taming.data.imagenet.ImageNetValidation 40 | params: 41 | config: 42 | size: 256 43 | -------------------------------------------------------------------------------- /data/data.json: -------------------------------------------------------------------------------- 1 | { 2 | "image_01_01.jpg":["A bay","in watercolor painting style"], 3 | "image_01_02.jpg":["A house", "in watercolor painting style"], 4 | "image_01_03.jpg":["A cat", "in watercolor painting style"], 5 | "image_01_04.jpg":["Flowers", "in watercolor painting style"], 6 | "image_01_05.jpg":["A village", "in oil painting style"], 7 | "image_01_06.jpg":["A village", "in line drawing style"], 8 | "image_01_07.jpg":["A portrait of a person", "in oil painting style"], 9 | "image_01_08.jpg":["A portrait of a person wearing a hat", "in oil painting style"], 10 | "image_02_01.jpg":["A person drowning into the phone", "in cartoon line drawing style"], 11 | "image_02_02.jpg":["A woman walking a dog", "in flat cartoon illustration style"], 12 | "image_02_03.jpg":["A woman working on a laptop", "in flat cartoon illustration style"], 13 | "image_02_04.jpg":["A Christmas tree", "in sticker style"], 14 | "image_02_05.jpg":["A wave", "in abstract rainbow colored flowing smoke wave design"], 15 | "image_02_06.jpg":["A mushroom", "in glowing style"], 16 | "image_03_01.jpg":["Slice of watermelon and clouds in the background", "in 3d rendering style"], 17 | "image_03_03.jpg":["A thumbs up", "in glowing 3d rendering style"], 18 | "image_03_04.jpg":["A woman", "in 3d rendering style"], 19 | "image_03_05.jpg":["A bear", "in kid crayon drawing style"], 20 | "image_03_07.jpg":["A flower", "in melting golden 3d rendering style"], 21 | "image_03_08.jpg":["A Viking face with beard", "in wooden sculpture"] 22 | } --------------------------------------------------------------------------------
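The data/data.json file above pairs each style reference image with a two-element list: a content description and a style descriptor, which together read as a single prompt such as "A house in watercolor painting style". Below is a minimal sketch of parsing it into (image path, prompt) records; the helper name load_style_entries and the simple space-joined prompt are illustrative assumptions, not code taken from the training scripts.

import json
import os

def load_style_entries(json_path="data/data.json", image_root="data"):
    # Each key is an image file name; each value is [content description, style descriptor].
    with open(json_path, "r") as f:
        mapping = json.load(f)
    entries = []
    for filename, (content, style) in mapping.items():
        entries.append({
            "image": os.path.join(image_root, filename),
            "prompt": f"{content} {style}",  # e.g. "A house in watercolor painting style"
        })
    return entries

data/one_style.json, which appears later in this listing, uses the same format with a single entry and is the default config.data_path in configs/custom.py.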
/data/image_01_01.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zideliu/StyleDrop-PyTorch/3538a60ca79ac57d5ca04be59dcbd13d48f68921/data/image_01_01.jpg -------------------------------------------------------------------------------- /data/image_01_02.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zideliu/StyleDrop-PyTorch/3538a60ca79ac57d5ca04be59dcbd13d48f68921/data/image_01_02.jpg -------------------------------------------------------------------------------- /data/image_01_03.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zideliu/StyleDrop-PyTorch/3538a60ca79ac57d5ca04be59dcbd13d48f68921/data/image_01_03.jpg -------------------------------------------------------------------------------- /data/image_01_04.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zideliu/StyleDrop-PyTorch/3538a60ca79ac57d5ca04be59dcbd13d48f68921/data/image_01_04.jpg -------------------------------------------------------------------------------- /data/image_01_05.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zideliu/StyleDrop-PyTorch/3538a60ca79ac57d5ca04be59dcbd13d48f68921/data/image_01_05.jpg -------------------------------------------------------------------------------- /data/image_01_06.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zideliu/StyleDrop-PyTorch/3538a60ca79ac57d5ca04be59dcbd13d48f68921/data/image_01_06.jpg -------------------------------------------------------------------------------- /data/image_01_07.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zideliu/StyleDrop-PyTorch/3538a60ca79ac57d5ca04be59dcbd13d48f68921/data/image_01_07.jpg -------------------------------------------------------------------------------- /data/image_01_08.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zideliu/StyleDrop-PyTorch/3538a60ca79ac57d5ca04be59dcbd13d48f68921/data/image_01_08.jpg -------------------------------------------------------------------------------- /data/image_02_01.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zideliu/StyleDrop-PyTorch/3538a60ca79ac57d5ca04be59dcbd13d48f68921/data/image_02_01.jpg -------------------------------------------------------------------------------- /data/image_02_02.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zideliu/StyleDrop-PyTorch/3538a60ca79ac57d5ca04be59dcbd13d48f68921/data/image_02_02.jpg -------------------------------------------------------------------------------- /data/image_02_03.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zideliu/StyleDrop-PyTorch/3538a60ca79ac57d5ca04be59dcbd13d48f68921/data/image_02_03.jpg -------------------------------------------------------------------------------- /data/image_02_04.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zideliu/StyleDrop-PyTorch/3538a60ca79ac57d5ca04be59dcbd13d48f68921/data/image_02_04.jpg -------------------------------------------------------------------------------- /data/image_02_05.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zideliu/StyleDrop-PyTorch/3538a60ca79ac57d5ca04be59dcbd13d48f68921/data/image_02_05.jpg -------------------------------------------------------------------------------- /data/image_02_06.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zideliu/StyleDrop-PyTorch/3538a60ca79ac57d5ca04be59dcbd13d48f68921/data/image_02_06.jpg -------------------------------------------------------------------------------- /data/image_03_01.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zideliu/StyleDrop-PyTorch/3538a60ca79ac57d5ca04be59dcbd13d48f68921/data/image_03_01.jpg -------------------------------------------------------------------------------- /data/image_03_03.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zideliu/StyleDrop-PyTorch/3538a60ca79ac57d5ca04be59dcbd13d48f68921/data/image_03_03.jpg -------------------------------------------------------------------------------- /data/image_03_04.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zideliu/StyleDrop-PyTorch/3538a60ca79ac57d5ca04be59dcbd13d48f68921/data/image_03_04.jpg -------------------------------------------------------------------------------- /data/image_03_05.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zideliu/StyleDrop-PyTorch/3538a60ca79ac57d5ca04be59dcbd13d48f68921/data/image_03_05.jpg -------------------------------------------------------------------------------- /data/image_03_07.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zideliu/StyleDrop-PyTorch/3538a60ca79ac57d5ca04be59dcbd13d48f68921/data/image_03_07.jpg -------------------------------------------------------------------------------- /data/image_03_08.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zideliu/StyleDrop-PyTorch/3538a60ca79ac57d5ca04be59dcbd13d48f68921/data/image_03_08.jpg -------------------------------------------------------------------------------- /data/one_style.json: -------------------------------------------------------------------------------- 1 | { 2 | "image_01_02.jpg":["A house", "in watercolor painting style"] 3 | } -------------------------------------------------------------------------------- /extract_empty_feature.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import open_clip 4 | 5 | def main(): 6 | prompts = [ 7 | '', 8 | ] 9 | 10 | device = 'cuda' 11 | model, _, _ = open_clip.create_model_and_transforms('ViT-bigG-14', 'laion2b_s39b_b160k') 12 | model = model.to(device) 13 | model.eval() 14 | tokenizer = open_clip.get_tokenizer('ViT-bigG-14') 15 | 16 | text_tokens = tokenizer(prompts).to(device) 17 | latent = model.encode_text(text_tokens) 18 | 19 | print(latent.shape) 20 | c = latent[0].detach().cpu().float().numpy() 21 | del model 
22 | del tokenizer 23 | save_dir = f'assets/contexts' 24 | os.makedirs(save_dir, exist_ok=True) 25 | np.save(os.path.join(save_dir, f'empty_context.npy'), c) 26 | 27 | 28 | if __name__ == '__main__': 29 | main() 30 | -------------------------------------------------------------------------------- /extract_imagenet_feature.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch.nn as nn 3 | import numpy as np 4 | import torch 5 | from datasets import ImageNet 6 | from torch.utils.data import DataLoader 7 | import argparse 8 | from tqdm import tqdm 9 | torch.manual_seed(0) 10 | np.random.seed(0) 11 | 12 | 13 | def main(resolution=256): 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('path') 16 | args = parser.parse_args() 17 | 18 | dataset = ImageNet(path=args.path, resolution=resolution, random_flip=False) 19 | train_dataset = dataset.get_split(split='train', labeled=True) 20 | train_dataset_loader = DataLoader(train_dataset, batch_size=256, shuffle=False, drop_last=False, 21 | num_workers=8, pin_memory=True, persistent_workers=True) 22 | 23 | import taming.models.vqgan 24 | model = taming.models.vqgan.get_model('vq-f16-jax.yaml') 25 | 26 | model = nn.DataParallel(model) 27 | model.eval() 28 | model.requires_grad_(False) 29 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 30 | model.to(device) 31 | 32 | feat_all = [] 33 | with torch.no_grad(): 34 | for batch in tqdm(train_dataset_loader): 35 | img, label = batch 36 | img = img.to(device) 37 | img = torch.cat([img, img.flip(dims=[-1])], dim=0) 38 | label = torch.cat([label, label], dim=0) 39 | label = label.detach().cpu().numpy() 40 | N = len(label) 41 | batch = model(img) 42 | feat = batch[-1][-1].detach().cpu().numpy() 43 | feat_all.append(np.concatenate((label[:, None], feat.reshape(N, -1)), axis=-1)) 44 | feat_all = np.concatenate(feat_all) 45 | 46 | out_dir = f'assets/datasets/imagenet256_vq_features/vq-f16-jax' 47 | os.makedirs(out_dir, exist_ok=True) 48 | np.save(f'{out_dir}/train.npy', feat_all) 49 | 50 | 51 | if __name__ == "__main__": 52 | main() 53 | -------------------------------------------------------------------------------- /extract_test_prompt_feature.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from pathlib import Path 4 | 5 | import open_clip 6 | 7 | def main(): 8 | prompts = [ 9 | 'A green train is coming down the tracks.', 10 | 'A group of skiers are preparing to ski down a mountain.', 11 | 'A small kitchen with a low ceiling.', 12 | 'A group of elephants walking in muddy water.', 13 | 'A living area with a television and a table.', 14 | 'A road with traffic lights, street lights and cars.', 15 | 'A bus driving in a city area with traffic signs.', 16 | 'A bus pulls over to the curb close to an intersection.', 17 | 'A group of people are walking and one is holding an umbrella.', 18 | 'A baseball player taking a swing at an incoming ball.', 19 | 'A city street line with brick buildings and trees.', 20 | 'A close up of a plate of broccoli and sauce.', 21 | ] 22 | 23 | device = 'cuda' 24 | model, _, _ = open_clip.create_model_and_transforms('ViT-bigG-14', 'laion2b_s39b_b160k') 25 | model = model.to(device) 26 | model.eval() 27 | tokenizer = open_clip.get_tokenizer('ViT-bigG-14') 28 | 29 | text_tokens = tokenizer(prompts).to(device) 30 | latent = model.encode_text(text_tokens) 31 | 32 | save_dir = Path(f'assets/contexts/run_vis') 33 
| save_dir.mkdir(exist_ok=True, parents=True) 34 | for i in range(len(latent)): 35 | c = latent[i].detach().cpu().float().numpy() 36 | np.save(os.path.join(save_dir, f'{i}.npy'), c) 37 | 38 | 39 | if __name__ == '__main__': 40 | main() 41 | -------------------------------------------------------------------------------- /feature2webdataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import sys 4 | import shutil 5 | import numpy as np 6 | import datetime 7 | import webdataset as wds 8 | from multiprocessing import Process 9 | from PIL import Image 10 | 11 | def make_wds_shards(pattern, num_shards, num_workers, samples, map_func, **kwargs): 12 | samples_per_shards = [samples[i::num_shards] for i in range(num_shards)] 13 | shard_ids = list(range(num_shards)) 14 | 15 | processes = [ 16 | Process( 17 | target=write_partial_samples, 18 | args=( 19 | pattern, 20 | shard_ids[i::num_workers], 21 | samples_per_shards[i::num_workers], 22 | map_func, 23 | kwargs 24 | ) 25 | ) 26 | for i in range(num_workers)] 27 | for p in processes: 28 | p.start() 29 | for p in processes: 30 | p.join() 31 | 32 | 33 | def write_partial_samples(pattern, shard_ids, samples, map_func, kwargs): 34 | for shard_id, samples in zip(shard_ids, samples): 35 | write_samples_into_single_shard(pattern, shard_id, samples, map_func, kwargs) 36 | 37 | 38 | def write_samples_into_single_shard(pattern, shard_id, samples, map_func, kwargs): 39 | fname = pattern % shard_id 40 | url = '/'.join(fname.split('/')[-2:]) 41 | sink = wds.TarWriter(fname, **kwargs) 42 | for item in samples: 43 | for content in map_func(item, url): 44 | sink.write(content) 45 | sink.close() 46 | 47 | if __name__ == "__main__": 48 | 49 | # all files in img_emb_path and text_raw_path are npy format 50 | # files in img_emb_path and text_raw_path are related by their file names 51 | # i.e., img_emb/000000.npy and text_raw/000000.npy are related, img_emb/000001.npy and text_raw/000001.npy are related 52 | img_emb_path = "/path/to/img_emb"  # placeholder: directory of image token .npy files 53 | text_raw_path = "/path/to/text_raw"  # placeholder: directory of raw caption .npy files 54 | output_path = "/path/to/output_shards"  # placeholder: directory where the .tar shards are written 55 | num_workers = 180 56 | img_filelist = os.listdir(img_emb_path) 57 | text_raw_filelist = os.listdir(text_raw_path) 58 | 59 | img_file_paths = [os.path.join(img_emb_path, fp) for fp in img_filelist] 60 | text_raw_file_paths = [os.path.join(text_raw_path, fp) for fp in text_raw_filelist] 61 | 62 | text_raw_file_paths = sorted(text_raw_file_paths) 63 | img_file_paths = sorted(img_file_paths) 64 | 65 | print("Num of img_file_paths: ", len(img_file_paths)) 66 | print("Num of text_raw_file_paths: ", len(text_raw_file_paths)) 67 | 68 | file_paths = [] 69 | for fi, ftr in zip(img_file_paths, text_raw_file_paths): 70 | file_paths.append([fi, ftr]) 71 | 72 | num_shards = len(file_paths) 73 | 74 | print(file_paths) 75 | 76 | def sampler(fp, url): 77 | image_path, text_raw_path = fp 78 | text_raw = np.load(text_raw_path, allow_pickle=True) 79 | images = np.load(image_path, allow_pickle=True).reshape(-1, 256) 80 | 81 | print(f"shape of input raw text: {text_raw.shape} with dtype: {text_raw.dtype} and shape of images: {images.shape} with dtype: {images.dtype}") 82 | 83 | for i, (img_emb, text) in enumerate(zip(images, text_raw)): 84 | try: 85 | 86 | sample = { 87 | "__key__": f"%08d"%i, 88 | "__url__": url, # path/to/xxx.tar 89 | "image.npy": img_emb.tobytes(), 90 | "text.npy": str(text), 91 | } 92 | 93 | yield sample 94 | except: 95 | continue 96 | 97 | make_wds_shards( 98 | pattern=f"{output_path}/%08d.tar", 99 |
num_shards=num_shards, 100 | num_workers=num_workers, 101 | samples=file_paths, 102 | map_func=sampler, 103 | ) 104 | 105 | -------------------------------------------------------------------------------- /img/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zideliu/StyleDrop-PyTorch/3538a60ca79ac57d5ca04be59dcbd13d48f68921/img/1.png -------------------------------------------------------------------------------- /img/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zideliu/StyleDrop-PyTorch/3538a60ca79ac57d5ca04be59dcbd13d48f68921/img/2.png -------------------------------------------------------------------------------- /img/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zideliu/StyleDrop-PyTorch/3538a60ca79ac57d5ca04be59dcbd13d48f68921/img/3.png -------------------------------------------------------------------------------- /img/4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zideliu/StyleDrop-PyTorch/3538a60ca79ac57d5ca04be59dcbd13d48f68921/img/4.png -------------------------------------------------------------------------------- /img/5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zideliu/StyleDrop-PyTorch/3538a60ca79ac57d5ca04be59dcbd13d48f68921/img/5.png -------------------------------------------------------------------------------- /img/result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zideliu/StyleDrop-PyTorch/3538a60ca79ac57d5ca04be59dcbd13d48f68921/img/result.png -------------------------------------------------------------------------------- /img/split.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from PIL import Image 4 | import numpy as np 5 | 6 | def split(name): 7 | 8 | im_name = f"{name}.png" 9 | im = np.array(Image.open(im_name)) 10 | 11 | print(im.shape) 12 | 13 | im2 = im[258:,:,:] 14 | print(im2.shape) 15 | 16 | im2 = Image.fromarray(im2) 17 | im2.save(f"{name}_split.png") 18 | 19 | split("285300_0") -------------------------------------------------------------------------------- /libs/__init__.py: -------------------------------------------------------------------------------- 1 | # codes from third party 2 | -------------------------------------------------------------------------------- /open_clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .coca_model import CoCa 2 | from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD 3 | from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_loss 4 | from .factory import list_models, add_model_config, get_model_config, load_checkpoint 5 | from .loss import ClipLoss, DistillClipLoss, CoCaLoss 6 | from .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg, \ 7 | convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype 8 | from .openai import load_openai_model, list_openai_models 9 | from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, \ 10 | get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, 
download_pretrained 11 | from .push_to_hf_hub import push_pretrained_to_hf_hub, push_to_hf_hub 12 | from .tokenizer import SimpleTokenizer, tokenize, decode 13 | from .transform import image_transform, AugmentationCfg 14 | -------------------------------------------------------------------------------- /open_clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zideliu/StyleDrop-PyTorch/3538a60ca79ac57d5ca04be59dcbd13d48f68921/open_clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /open_clip/constants.py: -------------------------------------------------------------------------------- 1 | OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073) 2 | OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711) 3 | -------------------------------------------------------------------------------- /open_clip/generation_utils.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zideliu/StyleDrop-PyTorch/3538a60ca79ac57d5ca04be59dcbd13d48f68921/open_clip/generation_utils.py -------------------------------------------------------------------------------- /open_clip/hf_configs.py: -------------------------------------------------------------------------------- 1 | # HF architecture dict: 2 | arch_dict = { 3 | # https://huggingface.co/docs/transformers/model_doc/roberta#roberta 4 | "roberta": { 5 | "config_names": { 6 | "context_length": "max_position_embeddings", 7 | "vocab_size": "vocab_size", 8 | "width": "hidden_size", 9 | "heads": "num_attention_heads", 10 | "layers": "num_hidden_layers", 11 | "layer_attr": "layer", 12 | "token_embeddings_attr": "embeddings" 13 | }, 14 | "pooler": "mean_pooler", 15 | }, 16 | # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig 17 | "xlm-roberta": { 18 | "config_names": { 19 | "context_length": "max_position_embeddings", 20 | "vocab_size": "vocab_size", 21 | "width": "hidden_size", 22 | "heads": "num_attention_heads", 23 | "layers": "num_hidden_layers", 24 | "layer_attr": "layer", 25 | "token_embeddings_attr": "embeddings" 26 | }, 27 | "pooler": "mean_pooler", 28 | }, 29 | # https://huggingface.co/docs/transformers/model_doc/mt5#mt5 30 | "mt5": { 31 | "config_names": { 32 | # unlimited seqlen 33 | # https://github.com/google-research/text-to-text-transfer-transformer/issues/273 34 | # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374 35 | "context_length": "", 36 | "vocab_size": "vocab_size", 37 | "width": "d_model", 38 | "heads": "num_heads", 39 | "layers": "num_layers", 40 | "layer_attr": "block", 41 | "token_embeddings_attr": "embed_tokens" 42 | }, 43 | "pooler": "mean_pooler", 44 | }, 45 | } 46 | -------------------------------------------------------------------------------- /open_clip/model_configs/RN101-quickgelu.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": [ 7 | 3, 8 | 4, 9 | 23, 10 | 3 11 | ], 12 | "width": 64, 13 | "patch_size": null 14 | }, 15 | "text_cfg": { 16 | "context_length": 77, 17 | "vocab_size": 49408, 18 | "width": 512, 19 | "heads": 8, 20 | "layers": 12 21 | } 22 | } -------------------------------------------------------------------------------- 
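Each JSON file under open_clip/model_configs/ defines one model architecture; the file name (without .json) is the model name passed to the open_clip factory, which is how extract_empty_feature.py and extract_test_prompt_feature.py above build the ViT-bigG-14 text encoder. A brief sketch of that usage follows; the prompt string is illustrative, and the 'laion2b_s39b_b160k' pretrained tag is the same one used in the extract scripts.

import torch
import open_clip

# 'ViT-bigG-14' resolves to open_clip/model_configs/ViT-bigG-14.json (embed_dim 1280).
model, _, preprocess = open_clip.create_model_and_transforms('ViT-bigG-14', 'laion2b_s39b_b160k')
tokenizer = open_clip.get_tokenizer('ViT-bigG-14')
model.eval()

with torch.no_grad():
    tokens = tokenizer(["A house in watercolor painting style"])  # illustrative prompt
    text_features = model.encode_text(tokens)
print(text_features.shape)  # last dimension is 1280, matching clip_dim in the U-ViT configs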
/open_clip/model_configs/RN101.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": [ 6 | 3, 7 | 4, 8 | 23, 9 | 3 10 | ], 11 | "width": 64, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 512, 18 | "heads": 8, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /open_clip/model_configs/RN50-quickgelu.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": [ 7 | 3, 8 | 4, 9 | 6, 10 | 3 11 | ], 12 | "width": 64, 13 | "patch_size": null 14 | }, 15 | "text_cfg": { 16 | "context_length": 77, 17 | "vocab_size": 49408, 18 | "width": 512, 19 | "heads": 8, 20 | "layers": 12 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /open_clip/model_configs/RN50.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": [ 6 | 3, 7 | 4, 8 | 6, 9 | 3 10 | ], 11 | "width": 64, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 512, 18 | "heads": 8, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /open_clip/model_configs/RN50x16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 384, 5 | "layers": [ 6 | 6, 7 | 8, 8 | 18, 9 | 8 10 | ], 11 | "width": 96, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 768, 18 | "heads": 12, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /open_clip/model_configs/RN50x4.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "image_size": 288, 5 | "layers": [ 6 | 4, 7 | 6, 8 | 10, 9 | 6 10 | ], 11 | "width": 80, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 640, 18 | "heads": 10, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /open_clip/model_configs/RN50x64.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 448, 5 | "layers": [ 6 | 3, 7 | 15, 8 | 36, 9 | 10 10 | ], 11 | "width": 128, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 1024, 18 | "heads": 16, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /open_clip/model_configs/ViT-B-16-plus-240.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "image_size": 240, 5 | "layers": 12, 6 | "width": 896, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 640, 13 | "heads": 10, 14 | "layers": 12 15 | } 16 | } 
-------------------------------------------------------------------------------- /open_clip/model_configs/ViT-B-16-plus.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 896, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 640, 13 | "heads": 10, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip/model_configs/ViT-B-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 512, 13 | "heads": 8, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip/model_configs/ViT-B-32-plus-256.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "image_size": 256, 5 | "layers": 12, 6 | "width": 896, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 640, 13 | "heads": 10, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip/model_configs/ViT-B-32-quickgelu.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": 12, 7 | "width": 768, 8 | "patch_size": 32 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /open_clip/model_configs/ViT-B-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 512, 13 | "heads": 8, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip/model_configs/ViT-H-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 32, 6 | "width": 1280, 7 | "head_width": 80, 8 | "patch_size": 14 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 1024, 14 | "heads": 16, 15 | "layers": 24 16 | } 17 | } -------------------------------------------------------------------------------- /open_clip/model_configs/ViT-H-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 32, 6 | "width": 1280, 7 | "head_width": 80, 8 | "patch_size": 16 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 1024, 14 | "heads": 16, 15 | "layers": 24 16 | } 17 | } -------------------------------------------------------------------------------- 
/open_clip/model_configs/ViT-L-14-280.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 280, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 14 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip/model_configs/ViT-L-14-336.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 336, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 14 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip/model_configs/ViT-L-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 14 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip/model_configs/ViT-L-16-320.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 320, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip/model_configs/ViT-L-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip/model_configs/ViT-M-16-alt.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 384, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 512, 7 | "patch_size": 16, 8 | "ls_init_value": 1e-4 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 384, 14 | "heads": 6, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /open_clip/model_configs/ViT-M-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 512, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 512, 13 | "heads": 8, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip/model_configs/ViT-M-32-alt.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 384, 3 | 
"vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 512, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 384, 13 | "heads": 6, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip/model_configs/ViT-M-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 512, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 512, 13 | "heads": 8, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip/model_configs/ViT-S-16-alt.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 256, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 384, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 256, 13 | "heads": 4, 14 | "layers": 10 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip/model_configs/ViT-S-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 384, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 384, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 384, 13 | "heads": 6, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip/model_configs/ViT-S-32-alt.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 256, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 384, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 256, 13 | "heads": 4, 14 | "layers": 10 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip/model_configs/ViT-S-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 384, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 384, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 384, 13 | "heads": 6, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip/model_configs/ViT-bigG-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1280, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 48, 6 | "width": 1664, 7 | "head_width": 104, 8 | "mlp_ratio": 4.9231, 9 | "patch_size": 14 10 | }, 11 | "text_cfg": { 12 | "context_length": 77, 13 | "vocab_size": 49408, 14 | "width": 1280, 15 | "heads": 20, 16 | "layers": 32 17 | } 18 | } -------------------------------------------------------------------------------- /open_clip/model_configs/ViT-e-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1280, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 56, 6 | "width": 1792, 7 | "head_width": 112, 8 | "mlp_ratio": 8.5715, 9 | "patch_size": 14 10 | }, 11 
| "text_cfg": { 12 | "context_length": 77, 13 | "vocab_size": 49408, 14 | "width": 1280, 15 | "heads": 20, 16 | "layers": 36 17 | } 18 | } -------------------------------------------------------------------------------- /open_clip/model_configs/ViT-g-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 40, 6 | "width": 1408, 7 | "head_width": 88, 8 | "mlp_ratio": 4.3637, 9 | "patch_size": 14 10 | }, 11 | "text_cfg": { 12 | "context_length": 77, 13 | "vocab_size": 49408, 14 | "width": 1024, 15 | "heads": 16, 16 | "layers": 24 17 | } 18 | } -------------------------------------------------------------------------------- /open_clip/model_configs/coca_ViT-B-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 32, 8 | "attentional_pool": true, 9 | "attn_pooler_heads": 8, 10 | "output_tokens": true 11 | }, 12 | "text_cfg": { 13 | "context_length": 76, 14 | "vocab_size": 49408, 15 | "width": 512, 16 | "heads": 8, 17 | "layers": 12, 18 | "embed_cls": true, 19 | "output_tokens": true 20 | }, 21 | "multimodal_cfg": { 22 | "context_length": 76, 23 | "vocab_size": 49408, 24 | "width": 512, 25 | "heads": 8, 26 | "layers": 12, 27 | "attn_pooler_heads": 8 28 | }, 29 | "custom_text": true 30 | } -------------------------------------------------------------------------------- /open_clip/model_configs/coca_ViT-L-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 14, 8 | "attentional_pool": true, 9 | "attn_pooler_heads": 8, 10 | "output_tokens": true 11 | }, 12 | "text_cfg": { 13 | "context_length": 76, 14 | "vocab_size": 49408, 15 | "width": 768, 16 | "heads": 12, 17 | "layers": 12, 18 | "embed_cls": true, 19 | "output_tokens": true 20 | }, 21 | "multimodal_cfg": { 22 | "context_length": 76, 23 | "vocab_size": 49408, 24 | "width": 768, 25 | "heads": 12, 26 | "layers": 12, 27 | "attn_pooler_heads": 12 28 | }, 29 | "custom_text": true 30 | } 31 | -------------------------------------------------------------------------------- /open_clip/model_configs/coca_base.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "multimodal_cfg": { 4 | "width": 768, 5 | "context_length": 76, 6 | "vocab_size": 64000, 7 | "mlp_ratio": 4, 8 | "layers": 12, 9 | "dim_head": 64, 10 | "heads": 12, 11 | "n_queries": 256, 12 | "attn_pooler_heads": 8 13 | }, 14 | "vision_cfg": { 15 | "image_size": 288, 16 | "layers": 12, 17 | "width": 768, 18 | "patch_size": 18, 19 | "output_tokens": true 20 | }, 21 | "text_cfg": { 22 | "context_length": 76, 23 | "vocab_size": 64000, 24 | "layers": 12, 25 | "heads": 12, 26 | "width": 768, 27 | "embed_cls": true, 28 | "output_tokens": true 29 | }, 30 | "custom_text": true 31 | } -------------------------------------------------------------------------------- /open_clip/model_configs/coca_roberta-ViT-B-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 32, 8 | "output_tokens": true 9 | }, 10 | "text_cfg": { 11 | "hf_model_name": 
"roberta-base", 12 | "hf_tokenizer_name": "roberta-base", 13 | "proj": "linear", 14 | "width": 768, 15 | "output_tokens": true 16 | }, 17 | "multimodal_cfg": { 18 | "context_length": 76, 19 | "width": 768, 20 | "heads": 8, 21 | "layers": 12 22 | }, 23 | "custom_text": true 24 | } 25 | -------------------------------------------------------------------------------- /open_clip/model_configs/convnext_base.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_base", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 224 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 512, 16 | "heads": 8, 17 | "layers": 12 18 | } 19 | } -------------------------------------------------------------------------------- /open_clip/model_configs/convnext_base_w.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_base", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 256 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 640, 16 | "heads": 10, 17 | "layers": 12 18 | } 19 | } -------------------------------------------------------------------------------- /open_clip/model_configs/convnext_base_w_320.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_base", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 320 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 640, 16 | "heads": 10, 17 | "layers": 12 18 | } 19 | } -------------------------------------------------------------------------------- /open_clip/model_configs/convnext_large.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_large", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 224 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 768, 16 | "heads": 12, 17 | "layers": 12 18 | } 19 | } -------------------------------------------------------------------------------- /open_clip/model_configs/convnext_large_d.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_large", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "mlp", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 256 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 768, 16 | "heads": 12, 17 | "layers": 16 18 | } 19 | } -------------------------------------------------------------------------------- /open_clip/model_configs/convnext_large_d_320.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | 
"vision_cfg": { 4 | "timm_model_name": "convnext_large", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "mlp", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 320 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 768, 16 | "heads": 12, 17 | "layers": 16 18 | } 19 | } -------------------------------------------------------------------------------- /open_clip/model_configs/convnext_small.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_small", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 224 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 512, 16 | "heads": 8, 17 | "layers": 12 18 | } 19 | } -------------------------------------------------------------------------------- /open_clip/model_configs/convnext_tiny.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_tiny", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 224 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 512, 16 | "heads": 8, 17 | "layers": 12 18 | } 19 | } -------------------------------------------------------------------------------- /open_clip/model_configs/convnext_xlarge.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_xlarge", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 256 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 1024, 16 | "heads": 16, 17 | "layers": 20 18 | } 19 | } -------------------------------------------------------------------------------- /open_clip/model_configs/convnext_xxlarge.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_xxlarge", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 256 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 1024, 16 | "heads": 16, 17 | "layers": 24 18 | } 19 | } -------------------------------------------------------------------------------- /open_clip/model_configs/convnext_xxlarge_320.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_xxlarge", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 320 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 1024, 16 | "heads": 16, 17 | "layers": 24 18 | } 19 | } -------------------------------------------------------------------------------- /open_clip/model_configs/mt5-base-ViT-B-32.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "hf_model_name": "google/mt5-base", 11 | "hf_tokenizer_name": "google/mt5-base", 12 | "proj": "mlp", 13 | "pooler_type": "mean_pooler" 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /open_clip/model_configs/mt5-xl-ViT-H-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 32, 6 | "width": 1280, 7 | "head_width": 80, 8 | "patch_size": 14 9 | }, 10 | "text_cfg": { 11 | "hf_model_name": "google/mt5-xl", 12 | "hf_tokenizer_name": "google/mt5-xl", 13 | "proj": "mlp", 14 | "pooler_type": "mean_pooler" 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /open_clip/model_configs/roberta-ViT-B-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": 12, 7 | "width": 768, 8 | "patch_size": 32 9 | }, 10 | "text_cfg": { 11 | "hf_model_name": "roberta-base", 12 | "hf_tokenizer_name": "roberta-base", 13 | "proj": "mlp", 14 | "pooler_type": "mean_pooler" 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /open_clip/model_configs/swin_base_patch4_window7_224.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "timm_model_name": "swin_base_patch4_window7_224", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "image_size": 224 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 640, 14 | "heads": 10, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /open_clip/model_configs/vit_medium_patch16_gap_256.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "timm_model_name": "vit_medium_patch16_gap_256", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "image_size": 256 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /open_clip/model_configs/vit_relpos_medium_patch16_cls_224.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "timm_model_name": "vit_relpos_medium_patch16_cls_224", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "image_size": 224 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /open_clip/model_configs/xlm-roberta-base-ViT-B-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | 
"patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "hf_model_name": "xlm-roberta-base", 11 | "hf_tokenizer_name": "xlm-roberta-base", 12 | "proj": "mlp", 13 | "pooler_type": "mean_pooler" 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /open_clip/model_configs/xlm-roberta-large-ViT-H-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 32, 6 | "width": 1280, 7 | "head_width": 80, 8 | "patch_size": 14 9 | }, 10 | "text_cfg": { 11 | "hf_model_name": "xlm-roberta-large", 12 | "hf_tokenizer_name": "xlm-roberta-large", 13 | "proj": "mlp", 14 | "pooler_type": "mean_pooler" 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /open_clip/utils.py: -------------------------------------------------------------------------------- 1 | from itertools import repeat 2 | import collections.abc 3 | 4 | from torch import nn as nn 5 | from torchvision.ops.misc import FrozenBatchNorm2d 6 | 7 | 8 | def freeze_batch_norm_2d(module, module_match={}, name=''): 9 | """ 10 | Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is 11 | itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and 12 | returned. Otherwise, the module is walked recursively and submodules are converted in place. 13 | 14 | Args: 15 | module (torch.nn.Module): Any PyTorch module. 16 | module_match (dict): Dictionary of full module names to freeze (all if empty) 17 | name (str): Full module name (prefix) 18 | 19 | Returns: 20 | torch.nn.Module: Resulting module 21 | 22 | Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762 23 | """ 24 | res = module 25 | is_match = True 26 | if module_match: 27 | is_match = name in module_match 28 | if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)): 29 | res = FrozenBatchNorm2d(module.num_features) 30 | res.num_features = module.num_features 31 | res.affine = module.affine 32 | if module.affine: 33 | res.weight.data = module.weight.data.clone().detach() 34 | res.bias.data = module.bias.data.clone().detach() 35 | res.running_mean.data = module.running_mean.data 36 | res.running_var.data = module.running_var.data 37 | res.eps = module.eps 38 | else: 39 | for child_name, child in module.named_children(): 40 | full_child_name = '.'.join([name, child_name]) if name else child_name 41 | new_child = freeze_batch_norm_2d(child, module_match, full_child_name) 42 | if new_child is not child: 43 | res.add_module(child_name, new_child) 44 | return res 45 | 46 | 47 | # From PyTorch internals 48 | def _ntuple(n): 49 | def parse(x): 50 | if isinstance(x, collections.abc.Iterable): 51 | return x 52 | return tuple(repeat(x, n)) 53 | return parse 54 | 55 | 56 | to_1tuple = _ntuple(1) 57 | to_2tuple = _ntuple(2) 58 | to_3tuple = _ntuple(3) 59 | to_4tuple = _ntuple(4) 60 | to_ntuple = lambda n, x: _ntuple(n)(x) 61 | -------------------------------------------------------------------------------- /open_clip/version.py: -------------------------------------------------------------------------------- 1 | __version__ = '2.15.0' 2 | -------------------------------------------------------------------------------- /taming/models/vqgan.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | 5 | from taming.modules.diffusionmodules.model import Encoder, Decoder 6 | from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer 7 | 8 | 9 | class VQModel(nn.Module): 10 | def __init__(self, 11 | ddconfig, 12 | lossconfig, 13 | n_embed, 14 | embed_dim, 15 | ckpt_path=None, 16 | ignore_keys=[], 17 | image_key="image", 18 | colorize_nlabels=None, 19 | monitor=None, 20 | remap=None, 21 | sane_index_shape=False, # tell vector quantizer to return indices as bhw 22 | ): 23 | super().__init__() 24 | self.n_embed = n_embed 25 | self.embed_dim = embed_dim 26 | self.image_key = image_key 27 | self.encoder = Encoder(**ddconfig) 28 | self.decoder = Decoder(**ddconfig) 29 | self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25, 30 | remap=remap, sane_index_shape=sane_index_shape) 31 | if ckpt_path is not None: 32 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) 33 | self.image_key = image_key 34 | if colorize_nlabels is not None: 35 | assert type(colorize_nlabels) == int 36 | self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) 37 | if monitor is not None: 38 | self.monitor = monitor 39 | self.eval() 40 | self.requires_grad_(False) 41 | 42 | def init_from_ckpt(self, path, ignore_keys=list()): 43 | sd = torch.load(path, map_location="cpu") 44 | if "state_dict" in sd.keys(): 45 | sd = sd["state_dict"] 46 | keys = list(sd.keys()) 47 | for k in keys: 48 | for ik in ignore_keys: 49 | if k.startswith(ik): 50 | print("Deleting key {} from state_dict.".format(k)) 51 | del sd[k] 52 | print("Strict load") 53 | self.load_state_dict(sd, strict=True) 54 | print(f"Restored from {path}") 55 | 56 | def encode(self, x): 57 | h = self.encoder(x) 58 | quant, emb_loss, info = self.quantize(h) 59 | return quant, emb_loss, info 60 | 61 | def decode(self, quant): 62 | dec = self.decoder(quant) 63 | return dec 64 | 65 | def decode_code(self, code_b): 66 | quant_b = self.quantize.get_codebook_entry(code_b, [*code_b.shape, self.embed_dim]) 67 | dec = self.decode(quant_b) 68 | return dec 69 | 70 | def forward(self, input): 71 | quant, diff, info = self.encode(input) 72 | return quant, diff, info 73 | 74 | def get_input(self, batch, k): 75 | x = batch[k] 76 | if len(x.shape) == 3: 77 | x = x[..., None] 78 | x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format) 79 | return x.float() 80 | 81 | def get_last_layer(self): 82 | return self.decoder.conv_out.weight 83 | 84 | def log_images(self, batch, **kwargs): 85 | log = dict() 86 | x = self.get_input(batch, self.image_key) 87 | x = x.to(self.device) 88 | xrec, _ = self(x) 89 | if x.shape[1] > 3: 90 | # colorize with random projection 91 | assert xrec.shape[1] > 3 92 | x = self.to_rgb(x) 93 | xrec = self.to_rgb(xrec) 94 | log["inputs"] = x 95 | log["reconstructions"] = xrec 96 | return log 97 | 98 | def to_rgb(self, x): 99 | assert self.image_key == "segmentation" 100 | if not hasattr(self, "colorize"): 101 | self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) 102 | x = F.conv2d(x, weight=self.colorize) 103 | x = 2. * (x - x.min()) / (x.max() - x.min()) - 1. 
104 | return x 105 | 106 | 107 | def get_model(config_file='vq-f16-jax.yaml'): 108 | from omegaconf import OmegaConf 109 | config = OmegaConf.load(f'configs/vae_configs/{config_file}').model 110 | return VQModel(ddconfig=config.params.ddconfig, 111 | lossconfig=config.params.lossconfig, 112 | n_embed=config.params.n_embed, 113 | embed_dim=config.params.embed_dim, 114 | ckpt_path='assets/vqgan_jax_strongaug.ckpt') 115 | -------------------------------------------------------------------------------- /taming/modules/util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def count_params(model): 6 | total_params = sum(p.numel() for p in model.parameters()) 7 | return total_params 8 | 9 | 10 | class ActNorm(nn.Module): 11 | def __init__(self, num_features, logdet=False, affine=True, 12 | allow_reverse_init=False): 13 | assert affine 14 | super().__init__() 15 | self.logdet = logdet 16 | self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1)) 17 | self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1)) 18 | self.allow_reverse_init = allow_reverse_init 19 | 20 | self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8)) 21 | 22 | def initialize(self, input): 23 | with torch.no_grad(): 24 | flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1) 25 | mean = ( 26 | flatten.mean(1) 27 | .unsqueeze(1) 28 | .unsqueeze(2) 29 | .unsqueeze(3) 30 | .permute(1, 0, 2, 3) 31 | ) 32 | std = ( 33 | flatten.std(1) 34 | .unsqueeze(1) 35 | .unsqueeze(2) 36 | .unsqueeze(3) 37 | .permute(1, 0, 2, 3) 38 | ) 39 | 40 | self.loc.data.copy_(-mean) 41 | self.scale.data.copy_(1 / (std + 1e-6)) 42 | 43 | def forward(self, input, reverse=False): 44 | if reverse: 45 | return self.reverse(input) 46 | if len(input.shape) == 2: 47 | input = input[:,:,None,None] 48 | squeeze = True 49 | else: 50 | squeeze = False 51 | 52 | _, _, height, width = input.shape 53 | 54 | if self.training and self.initialized.item() == 0: 55 | self.initialize(input) 56 | self.initialized.fill_(1) 57 | 58 | h = self.scale * (input + self.loc) 59 | 60 | if squeeze: 61 | h = h.squeeze(-1).squeeze(-1) 62 | 63 | if self.logdet: 64 | log_abs = torch.log(torch.abs(self.scale)) 65 | logdet = height*width*torch.sum(log_abs) 66 | logdet = logdet * torch.ones(input.shape[0]).to(input) 67 | return h, logdet 68 | 69 | return h 70 | 71 | def reverse(self, output): 72 | if self.training and self.initialized.item() == 0: 73 | if not self.allow_reverse_init: 74 | raise RuntimeError( 75 | "Initializing ActNorm in reverse direction is " 76 | "disabled by default. Use allow_reverse_init=True to enable." 
77 | ) 78 | else: 79 | self.initialize(output) 80 | self.initialized.fill_(1) 81 | 82 | if len(output.shape) == 2: 83 | output = output[:,:,None,None] 84 | squeeze = True 85 | else: 86 | squeeze = False 87 | 88 | h = output / self.scale - self.loc 89 | 90 | if squeeze: 91 | h = h.squeeze(-1).squeeze(-1) 92 | return h 93 | 94 | 95 | class AbstractEncoder(nn.Module): 96 | def __init__(self): 97 | super().__init__() 98 | 99 | def encode(self, *args, **kwargs): 100 | raise NotImplementedError 101 | 102 | 103 | class Labelator(AbstractEncoder): 104 | """Net2Net Interface for Class-Conditional Model""" 105 | def __init__(self, n_classes, quantize_interface=True): 106 | super().__init__() 107 | self.n_classes = n_classes 108 | self.quantize_interface = quantize_interface 109 | 110 | def encode(self, c): 111 | c = c[:,None] 112 | if self.quantize_interface: 113 | return c, None, [None, None, c.long()] 114 | return c 115 | 116 | 117 | class SOSProvider(AbstractEncoder): 118 | # for unconditional training 119 | def __init__(self, sos_token, quantize_interface=True): 120 | super().__init__() 121 | self.sos_token = sos_token 122 | self.quantize_interface = quantize_interface 123 | 124 | def encode(self, x): 125 | # get batch size from data and replicate sos_token 126 | c = torch.ones(x.shape[0], 1)*self.sos_token 127 | c = c.long().to(x.device) 128 | if self.quantize_interface: 129 | return c, None, [None, None, c] 130 | return c 131 | -------------------------------------------------------------------------------- /timm/__init__.py: -------------------------------------------------------------------------------- 1 | from .version import __version__ 2 | from .models import create_model, list_models, is_model, list_modules, model_entrypoint, \ 3 | is_scriptable, is_exportable, set_scriptable, set_exportable 4 | -------------------------------------------------------------------------------- /timm/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .constants import * 2 | from .config import resolve_data_config 3 | from .dataset import Dataset, DatasetTar, AugMixDataset 4 | from .transforms import * 5 | from .loader import create_loader 6 | from .transforms_factory import create_transform 7 | from .mixup import Mixup, FastCollateMixup 8 | from .auto_augment import RandAugment, AutoAugment, rand_augment_ops, auto_augment_policy,\ 9 | rand_augment_transform, auto_augment_transform 10 | from .real_labels import RealLabelsImagenet 11 | -------------------------------------------------------------------------------- /timm/data/config.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from .constants import * 3 | 4 | 5 | _logger = logging.getLogger(__name__) 6 | 7 | 8 | def resolve_data_config(args, default_cfg={}, model=None, verbose=True): 9 | new_config = {} 10 | default_cfg = default_cfg 11 | if not default_cfg and model is not None and hasattr(model, 'default_cfg'): 12 | default_cfg = model.default_cfg 13 | 14 | # Resolve input/image size 15 | in_chans = 3 16 | if 'chans' in args and args['chans'] is not None: 17 | in_chans = args['chans'] 18 | 19 | input_size = (in_chans, 224, 224) 20 | if 'input_size' in args and args['input_size'] is not None: 21 | assert isinstance(args['input_size'], (tuple, list)) 22 | assert len(args['input_size']) == 3 23 | input_size = tuple(args['input_size']) 24 | in_chans = input_size[0] # input_size overrides in_chans 25 | elif 'img_size' in args and 
args['img_size'] is not None: 26 | assert isinstance(args['img_size'], int) 27 | input_size = (in_chans, args['img_size'], args['img_size']) 28 | elif 'input_size' in default_cfg: 29 | input_size = default_cfg['input_size'] 30 | new_config['input_size'] = input_size 31 | 32 | # resolve interpolation method 33 | new_config['interpolation'] = 'bicubic' 34 | if 'interpolation' in args and args['interpolation']: 35 | new_config['interpolation'] = args['interpolation'] 36 | elif 'interpolation' in default_cfg: 37 | new_config['interpolation'] = default_cfg['interpolation'] 38 | 39 | # resolve dataset + model mean for normalization 40 | new_config['mean'] = IMAGENET_DEFAULT_MEAN 41 | if 'mean' in args and args['mean'] is not None: 42 | mean = tuple(args['mean']) 43 | if len(mean) == 1: 44 | mean = tuple(list(mean) * in_chans) 45 | else: 46 | assert len(mean) == in_chans 47 | new_config['mean'] = mean 48 | elif 'mean' in default_cfg: 49 | new_config['mean'] = default_cfg['mean'] 50 | 51 | # resolve dataset + model std deviation for normalization 52 | new_config['std'] = IMAGENET_DEFAULT_STD 53 | if 'std' in args and args['std'] is not None: 54 | std = tuple(args['std']) 55 | if len(std) == 1: 56 | std = tuple(list(std) * in_chans) 57 | else: 58 | assert len(std) == in_chans 59 | new_config['std'] = std 60 | elif 'std' in default_cfg: 61 | new_config['std'] = default_cfg['std'] 62 | 63 | # resolve default crop percentage 64 | new_config['crop_pct'] = DEFAULT_CROP_PCT 65 | if 'crop_pct' in args and args['crop_pct'] is not None: 66 | new_config['crop_pct'] = args['crop_pct'] 67 | elif 'crop_pct' in default_cfg: 68 | new_config['crop_pct'] = default_cfg['crop_pct'] 69 | 70 | if verbose: 71 | _logger.info('Data processing configuration for current model + dataset:') 72 | for n, v in new_config.items(): 73 | _logger.info('\t%s: %s' % (n, str(v))) 74 | 75 | return new_config 76 | -------------------------------------------------------------------------------- /timm/data/constants.py: -------------------------------------------------------------------------------- 1 | DEFAULT_CROP_PCT = 0.875 2 | IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) 3 | IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) 4 | IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5) 5 | IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5) 6 | IMAGENET_DPN_MEAN = (124 / 255, 117 / 255, 104 / 255) 7 | IMAGENET_DPN_STD = tuple([1 / (.0167 * 255)] * 3) 8 | -------------------------------------------------------------------------------- /timm/data/distributed_sampler.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.utils.data import Sampler 4 | import torch.distributed as dist 5 | 6 | 7 | class OrderedDistributedSampler(Sampler): 8 | """Sampler that restricts data loading to a subset of the dataset. 9 | It is especially useful in conjunction with 10 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each 11 | process can pass a DistributedSampler instance as a DataLoader sampler, 12 | and load a subset of the original dataset that is exclusive to it. 13 | .. note:: 14 | Dataset is assumed to be of constant size. 15 | Arguments: 16 | dataset: Dataset used for sampling. 17 | num_replicas (optional): Number of processes participating in 18 | distributed training. 19 | rank (optional): Rank of the current process within num_replicas. 
20 | """ 21 | 22 | def __init__(self, dataset, num_replicas=None, rank=None): 23 | if num_replicas is None: 24 | if not dist.is_available(): 25 | raise RuntimeError("Requires distributed package to be available") 26 | num_replicas = dist.get_world_size() 27 | if rank is None: 28 | if not dist.is_available(): 29 | raise RuntimeError("Requires distributed package to be available") 30 | rank = dist.get_rank() 31 | self.dataset = dataset 32 | self.num_replicas = num_replicas 33 | self.rank = rank 34 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) 35 | self.total_size = self.num_samples * self.num_replicas 36 | 37 | def __iter__(self): 38 | indices = list(range(len(self.dataset))) 39 | 40 | # add extra samples to make it evenly divisible 41 | indices += indices[:(self.total_size - len(indices))] 42 | assert len(indices) == self.total_size 43 | 44 | # subsample 45 | indices = indices[self.rank:self.total_size:self.num_replicas] 46 | assert len(indices) == self.num_samples 47 | 48 | return iter(indices) 49 | 50 | def __len__(self): 51 | return self.num_samples 52 | -------------------------------------------------------------------------------- /timm/data/real_labels.py: -------------------------------------------------------------------------------- 1 | """ Real labels evaluator for ImageNet 2 | Paper: `Are we done with ImageNet?` - https://arxiv.org/abs/2006.07159 3 | Based on Numpy example at https://github.com/google-research/reassessed-imagenet 4 | 5 | Hacked together by / Copyright 2020 Ross Wightman 6 | """ 7 | import os 8 | import json 9 | import numpy as np 10 | 11 | 12 | class RealLabelsImagenet: 13 | 14 | def __init__(self, filenames, real_json='real.json', topk=(1, 5)): 15 | with open(real_json) as real_labels: 16 | real_labels = json.load(real_labels) 17 | real_labels = {f'ILSVRC2012_val_{i + 1:08d}.JPEG': labels for i, labels in enumerate(real_labels)} 18 | self.real_labels = real_labels 19 | self.filenames = filenames 20 | assert len(self.filenames) == len(self.real_labels) 21 | self.topk = topk 22 | self.is_correct = {k: [] for k in topk} 23 | self.sample_idx = 0 24 | 25 | def add_result(self, output): 26 | maxk = max(self.topk) 27 | _, pred_batch = output.topk(maxk, 1, True, True) 28 | pred_batch = pred_batch.cpu().numpy() 29 | for pred in pred_batch: 30 | filename = self.filenames[self.sample_idx] 31 | filename = os.path.basename(filename) 32 | if self.real_labels[filename]: 33 | for k in self.topk: 34 | self.is_correct[k].append( 35 | any([p in self.real_labels[filename] for p in pred[:k]])) 36 | self.sample_idx += 1 37 | 38 | def get_accuracy(self, k=None): 39 | if k is None: 40 | return {k: float(np.mean(self.is_correct[k])) * 100 for k in self.topk} 41 | else: 42 | return float(np.mean(self.is_correct[k])) * 100 43 | -------------------------------------------------------------------------------- /timm/loss/__init__.py: -------------------------------------------------------------------------------- 1 | from .cross_entropy import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy 2 | from .jsd import JsdCrossEntropy 3 | from .asymmetric_loss import AsymmetricLossMultiLabel, AsymmetricLossSingleLabel -------------------------------------------------------------------------------- /timm/loss/asymmetric_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class AsymmetricLossMultiLabel(nn.Module): 6 | def __init__(self, gamma_neg=4, gamma_pos=1, 
clip=0.05, eps=1e-8, disable_torch_grad_focal_loss=False): 7 | super(AsymmetricLossMultiLabel, self).__init__() 8 | 9 | self.gamma_neg = gamma_neg 10 | self.gamma_pos = gamma_pos 11 | self.clip = clip 12 | self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss 13 | self.eps = eps 14 | 15 | def forward(self, x, y): 16 | """" 17 | Parameters 18 | ---------- 19 | x: input logits 20 | y: targets (multi-label binarized vector) 21 | """ 22 | 23 | # Calculating Probabilities 24 | x_sigmoid = torch.sigmoid(x) 25 | xs_pos = x_sigmoid 26 | xs_neg = 1 - x_sigmoid 27 | 28 | # Asymmetric Clipping 29 | if self.clip is not None and self.clip > 0: 30 | xs_neg = (xs_neg + self.clip).clamp(max=1) 31 | 32 | # Basic CE calculation 33 | los_pos = y * torch.log(xs_pos.clamp(min=self.eps)) 34 | los_neg = (1 - y) * torch.log(xs_neg.clamp(min=self.eps)) 35 | loss = los_pos + los_neg 36 | 37 | # Asymmetric Focusing 38 | if self.gamma_neg > 0 or self.gamma_pos > 0: 39 | if self.disable_torch_grad_focal_loss: 40 | torch._C.set_grad_enabled(False) 41 | pt0 = xs_pos * y 42 | pt1 = xs_neg * (1 - y) # pt = p if t > 0 else 1-p 43 | pt = pt0 + pt1 44 | one_sided_gamma = self.gamma_pos * y + self.gamma_neg * (1 - y) 45 | one_sided_w = torch.pow(1 - pt, one_sided_gamma) 46 | if self.disable_torch_grad_focal_loss: 47 | torch._C.set_grad_enabled(True) 48 | loss *= one_sided_w 49 | 50 | return -loss.sum() 51 | 52 | 53 | class AsymmetricLossSingleLabel(nn.Module): 54 | def __init__(self, gamma_pos=1, gamma_neg=4, eps: float = 0.1, reduction='mean'): 55 | super(AsymmetricLossSingleLabel, self).__init__() 56 | 57 | self.eps = eps 58 | self.logsoftmax = nn.LogSoftmax(dim=-1) 59 | self.targets_classes = [] # prevent gpu repeated memory allocation 60 | self.gamma_pos = gamma_pos 61 | self.gamma_neg = gamma_neg 62 | self.reduction = reduction 63 | 64 | def forward(self, inputs, target, reduction=None): 65 | """" 66 | Parameters 67 | ---------- 68 | x: input logits 69 | y: targets (1-hot vector) 70 | """ 71 | 72 | num_classes = inputs.size()[-1] 73 | log_preds = self.logsoftmax(inputs) 74 | self.targets_classes = torch.zeros_like(inputs).scatter_(1, target.long().unsqueeze(1), 1) 75 | 76 | # ASL weights 77 | targets = self.targets_classes 78 | anti_targets = 1 - targets 79 | xs_pos = torch.exp(log_preds) 80 | xs_neg = 1 - xs_pos 81 | xs_pos = xs_pos * targets 82 | xs_neg = xs_neg * anti_targets 83 | asymmetric_w = torch.pow(1 - xs_pos - xs_neg, 84 | self.gamma_pos * targets + self.gamma_neg * anti_targets) 85 | log_preds = log_preds * asymmetric_w 86 | 87 | if self.eps > 0: # label smoothing 88 | self.targets_classes.mul_(1 - self.eps).add_(self.eps / num_classes) 89 | 90 | # loss calculation 91 | loss = - self.targets_classes.mul(log_preds) 92 | 93 | loss = loss.sum(dim=-1) 94 | if self.reduction == 'mean': 95 | loss = loss.mean() 96 | 97 | return loss 98 | -------------------------------------------------------------------------------- /timm/loss/cross_entropy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class LabelSmoothingCrossEntropy(nn.Module): 7 | """ 8 | NLL loss with label smoothing. 9 | """ 10 | def __init__(self, smoothing=0.1): 11 | """ 12 | Constructor for the LabelSmoothing module. 13 | :param smoothing: label smoothing factor 14 | """ 15 | super(LabelSmoothingCrossEntropy, self).__init__() 16 | assert smoothing < 1.0 17 | self.smoothing = smoothing 18 | self.confidence = 1. 
- smoothing 19 | 20 | def forward(self, x, target): 21 | logprobs = F.log_softmax(x, dim=-1) 22 | nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1)) 23 | nll_loss = nll_loss.squeeze(1) 24 | smooth_loss = -logprobs.mean(dim=-1) 25 | loss = self.confidence * nll_loss + self.smoothing * smooth_loss 26 | return loss.mean() 27 | 28 | 29 | class SoftTargetCrossEntropy(nn.Module): 30 | 31 | def __init__(self): 32 | super(SoftTargetCrossEntropy, self).__init__() 33 | 34 | def forward(self, x, target): 35 | loss = torch.sum(-target * F.log_softmax(x, dim=-1), dim=-1) 36 | return loss.mean() 37 | -------------------------------------------------------------------------------- /timm/loss/jsd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .cross_entropy import LabelSmoothingCrossEntropy 6 | 7 | 8 | class JsdCrossEntropy(nn.Module): 9 | """ Jensen-Shannon Divergence + Cross-Entropy Loss 10 | 11 | Based on impl here: https://github.com/google-research/augmix/blob/master/imagenet.py 12 | From paper: 'AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty - 13 | https://arxiv.org/abs/1912.02781 14 | 15 | Hacked together by / Copyright 2020 Ross Wightman 16 | """ 17 | def __init__(self, num_splits=3, alpha=12, smoothing=0.1): 18 | super().__init__() 19 | self.num_splits = num_splits 20 | self.alpha = alpha 21 | if smoothing is not None and smoothing > 0: 22 | self.cross_entropy_loss = LabelSmoothingCrossEntropy(smoothing) 23 | else: 24 | self.cross_entropy_loss = torch.nn.CrossEntropyLoss() 25 | 26 | def __call__(self, output, target): 27 | split_size = output.shape[0] // self.num_splits 28 | assert split_size * self.num_splits == output.shape[0] 29 | logits_split = torch.split(output, split_size) 30 | 31 | # Cross-entropy is only computed on clean images 32 | loss = self.cross_entropy_loss(logits_split[0], target[:split_size]) 33 | probs = [F.softmax(logits, dim=1) for logits in logits_split] 34 | 35 | # Clamp mixture distribution to avoid exploding KL divergence 36 | logp_mixture = torch.clamp(torch.stack(probs).mean(axis=0), 1e-7, 1).log() 37 | loss += self.alpha * sum([F.kl_div( 38 | logp_mixture, p_split, reduction='batchmean') for p_split in probs]) / len(probs) 39 | return loss 40 | -------------------------------------------------------------------------------- /timm/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .cspnet import * 2 | from .densenet import * 3 | from .dla import * 4 | from .dpn import * 5 | from .efficientnet import * 6 | from .gluon_resnet import * 7 | from .gluon_xception import * 8 | from .hrnet import * 9 | from .inception_resnet_v2 import * 10 | from .inception_v3 import * 11 | from .inception_v4 import * 12 | from .mobilenetv3 import * 13 | from .nasnet import * 14 | from .pnasnet import * 15 | from .regnet import * 16 | from .res2net import * 17 | from .resnest import * 18 | from .resnet import * 19 | from .rexnet import * 20 | from .selecsls import * 21 | from .senet import * 22 | from .sknet import * 23 | from .tresnet import * 24 | from .vision_transformer import * 25 | from .vovnet import * 26 | from .xception import * 27 | from .xception_aligned import * 28 | 29 | from .factory import create_model 30 | from .helpers import load_checkpoint, resume_checkpoint 31 | from .layers import TestTimePoolHead, apply_test_time_pool 32 | from .layers import 
convert_splitbn_model 33 | from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit 34 | from .registry import * 35 | -------------------------------------------------------------------------------- /timm/models/factory.py: -------------------------------------------------------------------------------- 1 | from .registry import is_model, is_model_in_modules, model_entrypoint 2 | from .helpers import load_checkpoint 3 | from .layers import set_layer_config 4 | 5 | 6 | def create_model( 7 | model_name, 8 | pretrained=False, 9 | num_classes=1000, 10 | in_chans=3, 11 | checkpoint_path='', 12 | scriptable=None, 13 | exportable=None, 14 | no_jit=None, 15 | **kwargs): 16 | """Create a model 17 | 18 | Args: 19 | model_name (str): name of model to instantiate 20 | pretrained (bool): load pretrained ImageNet-1k weights if true 21 | num_classes (int): number of classes for final fully connected layer (default: 1000) 22 | in_chans (int): number of input channels / colors (default: 3) 23 | checkpoint_path (str): path of checkpoint to load after model is initialized 24 | scriptable (bool): set layer config so that model is jit scriptable (not working for all models yet) 25 | exportable (bool): set layer config so that model is traceable / ONNX exportable (not fully impl/obeyed yet) 26 | no_jit (bool): set layer config so that model doesn't utilize jit scripted layers (so far activations only) 27 | 28 | Keyword Args: 29 | drop_rate (float): dropout rate for training (default: 0.0) 30 | global_pool (str): global pool type (default: 'avg') 31 | **: other kwargs are model specific 32 | """ 33 | model_args = dict(pretrained=pretrained, num_classes=num_classes, in_chans=in_chans) 34 | 35 | # Only EfficientNet and MobileNetV3 models have support for batchnorm params or drop_connect_rate passed as args 36 | is_efficientnet = is_model_in_modules(model_name, ['efficientnet', 'mobilenetv3']) 37 | if not is_efficientnet: 38 | kwargs.pop('bn_tf', None) 39 | kwargs.pop('bn_momentum', None) 40 | kwargs.pop('bn_eps', None) 41 | 42 | # handle backwards compat with drop_connect -> drop_path change 43 | drop_connect_rate = kwargs.pop('drop_connect_rate', None) 44 | if drop_connect_rate is not None and kwargs.get('drop_path_rate', None) is None: 45 | print("WARNING: 'drop_connect' as an argument is deprecated, please use 'drop_path'." 46 | " Setting drop_path to %f." % drop_connect_rate) 47 | kwargs['drop_path_rate'] = drop_connect_rate 48 | 49 | # Parameters that aren't supported by all models or are intended to only override model defaults if set 50 | # should default to None in command line args/cfg. Remove them if they are present and not set so that 51 | # non-supporting models don't break and default args remain in effect. 
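    # (Annotation added for this dump, not in the upstream timm source.) Illustrative example of
    # the filtering on the next line: kwargs = {'drop_rate': 0.2, 'global_pool': None} becomes
    # {'drop_rate': 0.2}, so the unset argument falls back to the model's own default rather than
    # overriding it with None.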
52 | kwargs = {k: v for k, v in kwargs.items() if v is not None} 53 | 54 | with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit): 55 | if is_model(model_name): 56 | create_fn = model_entrypoint(model_name) 57 | model = create_fn(**model_args, **kwargs) 58 | else: 59 | raise RuntimeError('Unknown model (%s)' % model_name) 60 | 61 | if checkpoint_path: 62 | load_checkpoint(model, checkpoint_path) 63 | 64 | return model 65 | -------------------------------------------------------------------------------- /timm/models/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .activations import * 2 | from .adaptive_avgmax_pool import \ 3 | adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d 4 | from .anti_aliasing import AntiAliasDownsampleLayer 5 | from .blur_pool import BlurPool2d 6 | from .classifier import ClassifierHead, create_classifier 7 | from .cond_conv2d import CondConv2d, get_condconv_initializer 8 | from .config import is_exportable, is_scriptable, is_no_jit, set_exportable, set_scriptable, set_no_jit,\ 9 | set_layer_config 10 | from .conv2d_same import Conv2dSame 11 | from .conv_bn_act import ConvBnAct 12 | from .create_act import create_act_layer, get_act_layer, get_act_fn 13 | from .create_attn import create_attn 14 | from .create_conv2d import create_conv2d 15 | from .create_norm_act import create_norm_act, get_norm_act_layer 16 | from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path 17 | from .eca import EcaModule, CecaModule 18 | from .evo_norm import EvoNormBatch2d, EvoNormSample2d 19 | from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple 20 | from .inplace_abn import InplaceAbn 21 | from .linear import Linear 22 | from .mixed_conv2d import MixedConv2d 23 | from .norm_act import BatchNormAct2d 24 | from .padding import get_padding 25 | from .pool2d_same import AvgPool2dSame, create_pool2d 26 | from .se import SEModule 27 | from .selective_kernel import SelectiveKernelConv 28 | from .separable_conv import SeparableConv2d, SeparableConvBnAct 29 | from .space_to_depth import SpaceToDepthModule 30 | from .split_attn import SplitAttnConv2d 31 | from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model 32 | from .test_time_pool import TestTimePoolHead, apply_test_time_pool 33 | from .weight_init import trunc_normal_ 34 | -------------------------------------------------------------------------------- /timm/models/layers/activations_jit.py: -------------------------------------------------------------------------------- 1 | """ Activations 2 | 3 | A collection of jit-scripted activations fn and modules with a common interface so that they can 4 | easily be swapped. All have an `inplace` arg even if not used. 5 | 6 | All jit scripted activations are lacking in-place variations on purpose, scripted kernel fusion does not 7 | currently work across in-place op boundaries, thus performance is equal to or less than the non-scripted 8 | versions if they contain in-place ops. 
9 | 10 | Hacked together by / Copyright 2020 Ross Wightman 11 | """ 12 | 13 | import torch 14 | from torch import nn as nn 15 | from torch.nn import functional as F 16 | 17 | 18 | @torch.jit.script 19 | def swish_jit(x, inplace: bool = False): 20 | """Swish - Described in: https://arxiv.org/abs/1710.05941 21 | """ 22 | return x.mul(x.sigmoid()) 23 | 24 | 25 | @torch.jit.script 26 | def mish_jit(x, _inplace: bool = False): 27 | """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 28 | """ 29 | return x.mul(F.softplus(x).tanh()) 30 | 31 | 32 | class SwishJit(nn.Module): 33 | def __init__(self, inplace: bool = False): 34 | super(SwishJit, self).__init__() 35 | 36 | def forward(self, x): 37 | return swish_jit(x) 38 | 39 | 40 | class MishJit(nn.Module): 41 | def __init__(self, inplace: bool = False): 42 | super(MishJit, self).__init__() 43 | 44 | def forward(self, x): 45 | return mish_jit(x) 46 | 47 | 48 | @torch.jit.script 49 | def hard_sigmoid_jit(x, inplace: bool = False): 50 | # return F.relu6(x + 3.) / 6. 51 | return (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster? 52 | 53 | 54 | class HardSigmoidJit(nn.Module): 55 | def __init__(self, inplace: bool = False): 56 | super(HardSigmoidJit, self).__init__() 57 | 58 | def forward(self, x): 59 | return hard_sigmoid_jit(x) 60 | 61 | 62 | @torch.jit.script 63 | def hard_swish_jit(x, inplace: bool = False): 64 | # return x * (F.relu6(x + 3.) / 6) 65 | return x * (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster? 66 | 67 | 68 | class HardSwishJit(nn.Module): 69 | def __init__(self, inplace: bool = False): 70 | super(HardSwishJit, self).__init__() 71 | 72 | def forward(self, x): 73 | return hard_swish_jit(x) 74 | 75 | 76 | @torch.jit.script 77 | def hard_mish_jit(x, inplace: bool = False): 78 | """ Hard Mish 79 | Experimental, based on notes by Mish author Diganta Misra at 80 | https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md 81 | """ 82 | return 0.5 * x * (x + 2).clamp(min=0, max=2) 83 | 84 | 85 | class HardMishJit(nn.Module): 86 | def __init__(self, inplace: bool = False): 87 | super(HardMishJit, self).__init__() 88 | 89 | def forward(self, x): 90 | return hard_mish_jit(x) 91 | -------------------------------------------------------------------------------- /timm/models/layers/adaptive_avgmax_pool.py: -------------------------------------------------------------------------------- 1 | """ PyTorch selectable adaptive pooling 2 | Adaptive pooling with the ability to select the type of pooling from: 3 | * 'avg' - Average pooling 4 | * 'max' - Max pooling 5 | * 'avgmax' - Sum of average and max pooling re-scaled by 0.5 6 | * 'avgmaxc' - Concatenation of average and max pooling along feature dim, doubles feature dim 7 | 8 | Both a functional and a nn.Module version of the pooling is provided. 
9 | 10 | Hacked together by / Copyright 2020 Ross Wightman 11 | """ 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | 16 | 17 | def adaptive_pool_feat_mult(pool_type='avg'): 18 | if pool_type == 'catavgmax': 19 | return 2 20 | else: 21 | return 1 22 | 23 | 24 | def adaptive_avgmax_pool2d(x, output_size=1): 25 | x_avg = F.adaptive_avg_pool2d(x, output_size) 26 | x_max = F.adaptive_max_pool2d(x, output_size) 27 | return 0.5 * (x_avg + x_max) 28 | 29 | 30 | def adaptive_catavgmax_pool2d(x, output_size=1): 31 | x_avg = F.adaptive_avg_pool2d(x, output_size) 32 | x_max = F.adaptive_max_pool2d(x, output_size) 33 | return torch.cat((x_avg, x_max), 1) 34 | 35 | 36 | def select_adaptive_pool2d(x, pool_type='avg', output_size=1): 37 | """Selectable global pooling function with dynamic input kernel size 38 | """ 39 | if pool_type == 'avg': 40 | x = F.adaptive_avg_pool2d(x, output_size) 41 | elif pool_type == 'avgmax': 42 | x = adaptive_avgmax_pool2d(x, output_size) 43 | elif pool_type == 'catavgmax': 44 | x = adaptive_catavgmax_pool2d(x, output_size) 45 | elif pool_type == 'max': 46 | x = F.adaptive_max_pool2d(x, output_size) 47 | else: 48 | assert False, 'Invalid pool type: %s' % pool_type 49 | return x 50 | 51 | 52 | class FastAdaptiveAvgPool2d(nn.Module): 53 | def __init__(self, flatten=False): 54 | super(FastAdaptiveAvgPool2d, self).__init__() 55 | self.flatten = flatten 56 | 57 | def forward(self, x): 58 | return x.mean((2, 3)) if self.flatten else x.mean((2, 3), keepdim=True) 59 | 60 | 61 | class AdaptiveAvgMaxPool2d(nn.Module): 62 | def __init__(self, output_size=1): 63 | super(AdaptiveAvgMaxPool2d, self).__init__() 64 | self.output_size = output_size 65 | 66 | def forward(self, x): 67 | return adaptive_avgmax_pool2d(x, self.output_size) 68 | 69 | 70 | class AdaptiveCatAvgMaxPool2d(nn.Module): 71 | def __init__(self, output_size=1): 72 | super(AdaptiveCatAvgMaxPool2d, self).__init__() 73 | self.output_size = output_size 74 | 75 | def forward(self, x): 76 | return adaptive_catavgmax_pool2d(x, self.output_size) 77 | 78 | 79 | class SelectAdaptivePool2d(nn.Module): 80 | """Selectable global pooling layer with dynamic input kernel size 81 | """ 82 | def __init__(self, output_size=1, pool_type='fast', flatten=False): 83 | super(SelectAdaptivePool2d, self).__init__() 84 | self.pool_type = pool_type or '' # convert other falsy values to empty string for consistent TS typing 85 | self.flatten = flatten 86 | if pool_type == '': 87 | self.pool = nn.Identity() # pass through 88 | elif pool_type == 'fast': 89 | assert output_size == 1 90 | self.pool = FastAdaptiveAvgPool2d(self.flatten) 91 | self.flatten = False 92 | elif pool_type == 'avg': 93 | self.pool = nn.AdaptiveAvgPool2d(output_size) 94 | elif pool_type == 'avgmax': 95 | self.pool = AdaptiveAvgMaxPool2d(output_size) 96 | elif pool_type == 'catavgmax': 97 | self.pool = AdaptiveCatAvgMaxPool2d(output_size) 98 | elif pool_type == 'max': 99 | self.pool = nn.AdaptiveMaxPool2d(output_size) 100 | else: 101 | assert False, 'Invalid pool type: %s' % pool_type 102 | 103 | def is_identity(self): 104 | return self.pool_type == '' 105 | 106 | def forward(self, x): 107 | x = self.pool(x) 108 | if self.flatten: 109 | x = x.flatten(1) 110 | return x 111 | 112 | def feat_mult(self): 113 | return adaptive_pool_feat_mult(self.pool_type) 114 | 115 | def __repr__(self): 116 | return self.__class__.__name__ + ' (' \ 117 | + 'pool_type=' + self.pool_type \ 118 | + ', flatten=' + str(self.flatten) + ')' 119 | 120 | 
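# --- Usage sketch (annotation added for this dump, not part of upstream timm) ---
# Minimal demonstration of the selectable pooling defined above; the input shape is an
# illustrative assumption. Note that 'catavgmax' doubles the feature dim (feat_mult() == 2).
if __name__ == '__main__':
    demo = torch.randn(2, 512, 7, 7)                     # NCHW feature map
    avgmax = SelectAdaptivePool2d(pool_type='avgmax', flatten=True)
    print(avgmax(demo).shape)                            # torch.Size([2, 512])
    catavgmax = SelectAdaptivePool2d(pool_type='catavgmax', flatten=True)
    print(catavgmax(demo).shape)                         # torch.Size([2, 1024])
    print(adaptive_avgmax_pool2d(demo).shape)            # torch.Size([2, 512, 1, 1])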
-------------------------------------------------------------------------------- /timm/models/layers/anti_aliasing.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.parallel 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class AntiAliasDownsampleLayer(nn.Module): 8 | def __init__(self, channels: int = 0, filt_size: int = 3, stride: int = 2, no_jit: bool = False): 9 | super(AntiAliasDownsampleLayer, self).__init__() 10 | if no_jit: 11 | self.op = Downsample(channels, filt_size, stride) 12 | else: 13 | self.op = DownsampleJIT(channels, filt_size, stride) 14 | 15 | # FIXME I should probably override _apply and clear DownsampleJIT filter cache for .cuda(), .half(), etc calls 16 | 17 | def forward(self, x): 18 | return self.op(x) 19 | 20 | 21 | @torch.jit.script 22 | class DownsampleJIT(object): 23 | def __init__(self, channels: int = 0, filt_size: int = 3, stride: int = 2): 24 | self.channels = channels 25 | self.stride = stride 26 | self.filt_size = filt_size 27 | assert self.filt_size == 3 28 | assert stride == 2 29 | self.filt = {} # lazy init by device for DataParallel compat 30 | 31 | def _create_filter(self, like: torch.Tensor): 32 | filt = torch.tensor([1., 2., 1.], dtype=like.dtype, device=like.device) 33 | filt = filt[:, None] * filt[None, :] 34 | filt = filt / torch.sum(filt) 35 | return filt[None, None, :, :].repeat((self.channels, 1, 1, 1)) 36 | 37 | def __call__(self, input: torch.Tensor): 38 | input_pad = F.pad(input, (1, 1, 1, 1), 'reflect') 39 | filt = self.filt.get(str(input.device), self._create_filter(input)) 40 | return F.conv2d(input_pad, filt, stride=2, padding=0, groups=input.shape[1]) 41 | 42 | 43 | class Downsample(nn.Module): 44 | def __init__(self, channels=None, filt_size=3, stride=2): 45 | super(Downsample, self).__init__() 46 | self.channels = channels 47 | self.filt_size = filt_size 48 | self.stride = stride 49 | 50 | assert self.filt_size == 3 51 | filt = torch.tensor([1., 2., 1.]) 52 | filt = filt[:, None] * filt[None, :] 53 | filt = filt / torch.sum(filt) 54 | 55 | # self.filt = filt[None, None, :, :].repeat((self.channels, 1, 1, 1)) 56 | self.register_buffer('filt', filt[None, None, :, :].repeat((self.channels, 1, 1, 1))) 57 | 58 | def forward(self, input): 59 | input_pad = F.pad(input, (1, 1, 1, 1), 'reflect') 60 | return F.conv2d(input_pad, self.filt, stride=self.stride, padding=0, groups=input.shape[1]) 61 | -------------------------------------------------------------------------------- /timm/models/layers/blur_pool.py: -------------------------------------------------------------------------------- 1 | """ 2 | BlurPool layer inspired by 3 | - Kornia's Max_BlurPool2d 4 | - Making Convolutional Networks Shift-Invariant Again :cite:`zhang2019shiftinvar` 5 | 6 | FIXME merge this impl with those in `anti_aliasing.py` 7 | 8 | Hacked together by Chris Ha and Ross Wightman 9 | """ 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | import numpy as np 15 | from typing import Dict 16 | from .padding import get_padding 17 | 18 | 19 | class BlurPool2d(nn.Module): 20 | r"""Creates a module that computes blurs and downsample a given feature map. 21 | See :cite:`zhang2019shiftinvar` for more details. 22 | Corresponds to the Downsample class, which does blurring and subsampling 23 | 24 | Args: 25 | channels = Number of input channels 26 | filt_size (int): binomial filter size for blurring. currently supports 3 (default) and 5. 
27 | stride (int): downsampling filter stride 28 | 29 | Returns: 30 | torch.Tensor: the transformed tensor. 31 | """ 32 | filt: Dict[str, torch.Tensor] 33 | 34 | def __init__(self, channels, filt_size=3, stride=2) -> None: 35 | super(BlurPool2d, self).__init__() 36 | assert filt_size > 1 37 | self.channels = channels 38 | self.filt_size = filt_size 39 | self.stride = stride 40 | pad_size = [get_padding(filt_size, stride, dilation=1)] * 4 41 | self.padding = nn.ReflectionPad2d(pad_size) 42 | self._coeffs = torch.tensor((np.poly1d((0.5, 0.5)) ** (self.filt_size - 1)).coeffs) # for torchscript compat 43 | self.filt = {} # lazy init by device for DataParallel compat 44 | 45 | def _create_filter(self, like: torch.Tensor): 46 | blur_filter = (self._coeffs[:, None] * self._coeffs[None, :]).to(dtype=like.dtype, device=like.device) 47 | return blur_filter[None, None, :, :].repeat(self.channels, 1, 1, 1) 48 | 49 | def _apply(self, fn): 50 | # override nn.Module _apply, reset filter cache if used 51 | self.filt = {} 52 | super(BlurPool2d, self)._apply(fn) 53 | 54 | def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: 55 | C = input_tensor.shape[1] 56 | blur_filt = self.filt.get(str(input_tensor.device), self._create_filter(input_tensor)) 57 | return F.conv2d( 58 | self.padding(input_tensor), blur_filt, stride=self.stride, groups=C) 59 | -------------------------------------------------------------------------------- /timm/models/layers/cbam.py: -------------------------------------------------------------------------------- 1 | """ CBAM (sort-of) Attention 2 | 3 | Experimental impl of CBAM: Convolutional Block Attention Module: https://arxiv.org/abs/1807.06521 4 | 5 | WARNING: Results with these attention layers have been mixed. They can significantly reduce performance on 6 | some tasks, especially fine-grained it seems. I may end up removing this impl. 7 | 8 | Hacked together by / Copyright 2020 Ross Wightman 9 | """ 10 | 11 | import torch 12 | from torch import nn as nn 13 | import torch.nn.functional as F 14 | from .conv_bn_act import ConvBnAct 15 | 16 | 17 | class ChannelAttn(nn.Module): 18 | """ Original CBAM channel attention module, currently avg + max pool variant only. 
19 | """ 20 | def __init__(self, channels, reduction=16, act_layer=nn.ReLU): 21 | super(ChannelAttn, self).__init__() 22 | self.fc1 = nn.Conv2d(channels, channels // reduction, 1, bias=False) 23 | self.act = act_layer(inplace=True) 24 | self.fc2 = nn.Conv2d(channels // reduction, channels, 1, bias=False) 25 | 26 | def forward(self, x): 27 | x_avg = x.mean((2, 3), keepdim=True) 28 | x_max = F.adaptive_max_pool2d(x, 1) 29 | x_avg = self.fc2(self.act(self.fc1(x_avg))) 30 | x_max = self.fc2(self.act(self.fc1(x_max))) 31 | x_attn = x_avg + x_max 32 | return x * x_attn.sigmoid() 33 | 34 | 35 | class LightChannelAttn(ChannelAttn): 36 | """An experimental 'lightweight' that sums avg + max pool first 37 | """ 38 | def __init__(self, channels, reduction=16): 39 | super(LightChannelAttn, self).__init__(channels, reduction) 40 | 41 | def forward(self, x): 42 | x_pool = 0.5 * x.mean((2, 3), keepdim=True) + 0.5 * F.adaptive_max_pool2d(x, 1) 43 | x_attn = self.fc2(self.act(self.fc1(x_pool))) 44 | return x * x_attn.sigmoid() 45 | 46 | 47 | class SpatialAttn(nn.Module): 48 | """ Original CBAM spatial attention module 49 | """ 50 | def __init__(self, kernel_size=7): 51 | super(SpatialAttn, self).__init__() 52 | self.conv = ConvBnAct(2, 1, kernel_size, act_layer=None) 53 | 54 | def forward(self, x): 55 | x_avg = torch.mean(x, dim=1, keepdim=True) 56 | x_max = torch.max(x, dim=1, keepdim=True)[0] 57 | x_attn = torch.cat([x_avg, x_max], dim=1) 58 | x_attn = self.conv(x_attn) 59 | return x * x_attn.sigmoid() 60 | 61 | 62 | class LightSpatialAttn(nn.Module): 63 | """An experimental 'lightweight' variant that sums avg_pool and max_pool results. 64 | """ 65 | def __init__(self, kernel_size=7): 66 | super(LightSpatialAttn, self).__init__() 67 | self.conv = ConvBnAct(1, 1, kernel_size, act_layer=None) 68 | 69 | def forward(self, x): 70 | x_avg = torch.mean(x, dim=1, keepdim=True) 71 | x_max = torch.max(x, dim=1, keepdim=True)[0] 72 | x_attn = 0.5 * x_avg + 0.5 * x_max 73 | x_attn = self.conv(x_attn) 74 | return x * x_attn.sigmoid() 75 | 76 | 77 | class CbamModule(nn.Module): 78 | def __init__(self, channels, spatial_kernel_size=7): 79 | super(CbamModule, self).__init__() 80 | self.channel = ChannelAttn(channels) 81 | self.spatial = SpatialAttn(spatial_kernel_size) 82 | 83 | def forward(self, x): 84 | x = self.channel(x) 85 | x = self.spatial(x) 86 | return x 87 | 88 | 89 | class LightCbamModule(nn.Module): 90 | def __init__(self, channels, spatial_kernel_size=7): 91 | super(LightCbamModule, self).__init__() 92 | self.channel = LightChannelAttn(channels) 93 | self.spatial = LightSpatialAttn(spatial_kernel_size) 94 | 95 | def forward(self, x): 96 | x = self.channel(x) 97 | x = self.spatial(x) 98 | return x 99 | 100 | -------------------------------------------------------------------------------- /timm/models/layers/classifier.py: -------------------------------------------------------------------------------- 1 | """ Classifier head and layer factory 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | from torch import nn as nn 6 | from torch.nn import functional as F 7 | 8 | from .adaptive_avgmax_pool import SelectAdaptivePool2d 9 | from .linear import Linear 10 | 11 | 12 | def create_classifier(num_features, num_classes, pool_type='avg', use_conv=False): 13 | flatten = not use_conv # flatten when we use a Linear layer after pooling 14 | if not pool_type: 15 | assert num_classes == 0 or use_conv,\ 16 | 'Pooling can only be disabled if classifier is also removed or conv classifier is used' 17 | 
flatten = False # disable flattening if pooling is pass-through (no pooling) 18 | global_pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=flatten) 19 | num_pooled_features = num_features * global_pool.feat_mult() 20 | if num_classes <= 0: 21 | fc = nn.Identity() # pass-through (no classifier) 22 | elif use_conv: 23 | fc = nn.Conv2d(num_pooled_features, num_classes, 1, bias=True) 24 | else: 25 | # NOTE: using my Linear wrapper that fixes AMP + torchscript casting issue 26 | fc = Linear(num_pooled_features, num_classes, bias=True) 27 | return global_pool, fc 28 | 29 | 30 | class ClassifierHead(nn.Module): 31 | """Classifier head w/ configurable global pooling and dropout.""" 32 | 33 | def __init__(self, in_chs, num_classes, pool_type='avg', drop_rate=0.): 34 | super(ClassifierHead, self).__init__() 35 | self.drop_rate = drop_rate 36 | self.global_pool, self.fc = create_classifier(in_chs, num_classes, pool_type=pool_type) 37 | 38 | def forward(self, x): 39 | x = self.global_pool(x) 40 | if self.drop_rate: 41 | x = F.dropout(x, p=float(self.drop_rate), training=self.training) 42 | x = self.fc(x) 43 | return x 44 | -------------------------------------------------------------------------------- /timm/models/layers/config.py: -------------------------------------------------------------------------------- 1 | """ Model / Layer Config singleton state 2 | """ 3 | from typing import Any, Optional 4 | 5 | __all__ = [ 6 | 'is_exportable', 'is_scriptable', 'is_no_jit', 7 | 'set_exportable', 'set_scriptable', 'set_no_jit', 'set_layer_config' 8 | ] 9 | 10 | # Set to True if prefer to have layers with no jit optimization (includes activations) 11 | _NO_JIT = False 12 | 13 | # Set to True if prefer to have activation layers with no jit optimization 14 | # NOTE not currently used as no difference between no_jit and no_activation jit as only layers obeying 15 | # the jit flags so far are activations. This will change as more layers are updated and/or added. 16 | _NO_ACTIVATION_JIT = False 17 | 18 | # Set to True if exporting a model with Same padding via ONNX 19 | _EXPORTABLE = False 20 | 21 | # Set to True if wanting to use torch.jit.script on a model 22 | _SCRIPTABLE = False 23 | 24 | 25 | def is_no_jit(): 26 | return _NO_JIT 27 | 28 | 29 | class set_no_jit: 30 | def __init__(self, mode: bool) -> None: 31 | global _NO_JIT 32 | self.prev = _NO_JIT 33 | _NO_JIT = mode 34 | 35 | def __enter__(self) -> None: 36 | pass 37 | 38 | def __exit__(self, *args: Any) -> bool: 39 | global _NO_JIT 40 | _NO_JIT = self.prev 41 | return False 42 | 43 | 44 | def is_exportable(): 45 | return _EXPORTABLE 46 | 47 | 48 | class set_exportable: 49 | def __init__(self, mode: bool) -> None: 50 | global _EXPORTABLE 51 | self.prev = _EXPORTABLE 52 | _EXPORTABLE = mode 53 | 54 | def __enter__(self) -> None: 55 | pass 56 | 57 | def __exit__(self, *args: Any) -> bool: 58 | global _EXPORTABLE 59 | _EXPORTABLE = self.prev 60 | return False 61 | 62 | 63 | def is_scriptable(): 64 | return _SCRIPTABLE 65 | 66 | 67 | class set_scriptable: 68 | def __init__(self, mode: bool) -> None: 69 | global _SCRIPTABLE 70 | self.prev = _SCRIPTABLE 71 | _SCRIPTABLE = mode 72 | 73 | def __enter__(self) -> None: 74 | pass 75 | 76 | def __exit__(self, *args: Any) -> bool: 77 | global _SCRIPTABLE 78 | _SCRIPTABLE = self.prev 79 | return False 80 | 81 | 82 | class set_layer_config: 83 | """ Layer config context manager that allows setting all layer config flags at once. 84 | If a flag arg is None, it will not change the current value. 
85 | """ 86 | def __init__( 87 | self, 88 | scriptable: Optional[bool] = None, 89 | exportable: Optional[bool] = None, 90 | no_jit: Optional[bool] = None, 91 | no_activation_jit: Optional[bool] = None): 92 | global _SCRIPTABLE 93 | global _EXPORTABLE 94 | global _NO_JIT 95 | global _NO_ACTIVATION_JIT 96 | self.prev = _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT 97 | if scriptable is not None: 98 | _SCRIPTABLE = scriptable 99 | if exportable is not None: 100 | _EXPORTABLE = exportable 101 | if no_jit is not None: 102 | _NO_JIT = no_jit 103 | if no_activation_jit is not None: 104 | _NO_ACTIVATION_JIT = no_activation_jit 105 | 106 | def __enter__(self) -> None: 107 | pass 108 | 109 | def __exit__(self, *args: Any) -> bool: 110 | global _SCRIPTABLE 111 | global _EXPORTABLE 112 | global _NO_JIT 113 | global _NO_ACTIVATION_JIT 114 | _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT = self.prev 115 | return False 116 | -------------------------------------------------------------------------------- /timm/models/layers/conv2d_same.py: -------------------------------------------------------------------------------- 1 | """ Conv2d w/ Same Padding 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from typing import Tuple, Optional 9 | 10 | from .padding import pad_same, get_padding_value 11 | 12 | 13 | def conv2d_same( 14 | x, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, stride: Tuple[int, int] = (1, 1), 15 | padding: Tuple[int, int] = (0, 0), dilation: Tuple[int, int] = (1, 1), groups: int = 1): 16 | x = pad_same(x, weight.shape[-2:], stride, dilation) 17 | return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups) 18 | 19 | 20 | class Conv2dSame(nn.Conv2d): 21 | """ Tensorflow like 'SAME' convolution wrapper for 2D convolutions 22 | """ 23 | 24 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, 25 | padding=0, dilation=1, groups=1, bias=True): 26 | super(Conv2dSame, self).__init__( 27 | in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias) 28 | 29 | def forward(self, x): 30 | return conv2d_same(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) 31 | 32 | 33 | def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs): 34 | padding = kwargs.pop('padding', '') 35 | kwargs.setdefault('bias', False) 36 | padding, is_dynamic = get_padding_value(padding, kernel_size, **kwargs) 37 | if is_dynamic: 38 | return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs) 39 | else: 40 | return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs) 41 | 42 | 43 | -------------------------------------------------------------------------------- /timm/models/layers/conv_bn_act.py: -------------------------------------------------------------------------------- 1 | """ Conv2d + BN + Act 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | from torch import nn as nn 6 | 7 | from .create_conv2d import create_conv2d 8 | from .create_norm_act import convert_norm_act_type 9 | 10 | 11 | class ConvBnAct(nn.Module): 12 | def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding='', dilation=1, groups=1, 13 | norm_layer=nn.BatchNorm2d, norm_kwargs=None, act_layer=nn.ReLU, apply_act=True, 14 | drop_block=None, aa_layer=None): 15 | super(ConvBnAct, self).__init__() 16 | use_aa = aa_layer is not None 17 | 18 | self.conv = create_conv2d( 19 | in_channels, out_channels, 
kernel_size, stride=1 if use_aa else stride, 20 | padding=padding, dilation=dilation, groups=groups, bias=False) 21 | 22 | # NOTE for backwards compatibility with models that use separate norm and act layer definitions 23 | norm_act_layer, norm_act_args = convert_norm_act_type(norm_layer, act_layer, norm_kwargs) 24 | self.bn = norm_act_layer(out_channels, apply_act=apply_act, drop_block=drop_block, **norm_act_args) 25 | self.aa = aa_layer(channels=out_channels) if stride == 2 and use_aa else None 26 | 27 | @property 28 | def in_channels(self): 29 | return self.conv.in_channels 30 | 31 | @property 32 | def out_channels(self): 33 | return self.conv.out_channels 34 | 35 | def forward(self, x): 36 | x = self.conv(x) 37 | x = self.bn(x) 38 | if self.aa is not None: 39 | x = self.aa(x) 40 | return x 41 | -------------------------------------------------------------------------------- /timm/models/layers/create_act.py: -------------------------------------------------------------------------------- 1 | """ Activation Factory 2 | Hacked together by / Copyright 2020 Ross Wightman 3 | """ 4 | from .activations import * 5 | from .activations_jit import * 6 | from .activations_me import * 7 | from .config import is_exportable, is_scriptable, is_no_jit 8 | 9 | # PyTorch has an optimized, native 'silu' (aka 'swish') operator as of PyTorch 1.7. This code 10 | # will use native version if present. Eventually, the custom Swish layers will be removed 11 | # and only native 'silu' will be used. 12 | _has_silu = 'silu' in dir(torch.nn.functional) 13 | 14 | _ACT_FN_DEFAULT = dict( 15 | silu=F.silu if _has_silu else swish, 16 | swish=F.silu if _has_silu else swish, 17 | mish=mish, 18 | relu=F.relu, 19 | relu6=F.relu6, 20 | leaky_relu=F.leaky_relu, 21 | elu=F.elu, 22 | celu=F.celu, 23 | selu=F.selu, 24 | gelu=gelu, 25 | sigmoid=sigmoid, 26 | tanh=tanh, 27 | hard_sigmoid=hard_sigmoid, 28 | hard_swish=hard_swish, 29 | hard_mish=hard_mish, 30 | ) 31 | 32 | _ACT_FN_JIT = dict( 33 | silu=F.silu if _has_silu else swish_jit, 34 | swish=F.silu if _has_silu else swish_jit, 35 | mish=mish_jit, 36 | hard_sigmoid=hard_sigmoid_jit, 37 | hard_swish=hard_swish_jit, 38 | hard_mish=hard_mish_jit 39 | ) 40 | 41 | _ACT_FN_ME = dict( 42 | silu=F.silu if _has_silu else swish_me, 43 | swish=F.silu if _has_silu else swish_me, 44 | mish=mish_me, 45 | hard_sigmoid=hard_sigmoid_me, 46 | hard_swish=hard_swish_me, 47 | hard_mish=hard_mish_me, 48 | ) 49 | 50 | _ACT_LAYER_DEFAULT = dict( 51 | silu=nn.SiLU if _has_silu else Swish, 52 | swish=nn.SiLU if _has_silu else Swish, 53 | mish=Mish, 54 | relu=nn.ReLU, 55 | relu6=nn.ReLU6, 56 | leaky_relu=nn.LeakyReLU, 57 | elu=nn.ELU, 58 | prelu=PReLU, 59 | celu=nn.CELU, 60 | selu=nn.SELU, 61 | gelu=GELU, 62 | sigmoid=Sigmoid, 63 | tanh=Tanh, 64 | hard_sigmoid=HardSigmoid, 65 | hard_swish=HardSwish, 66 | hard_mish=HardMish, 67 | ) 68 | 69 | _ACT_LAYER_JIT = dict( 70 | silu=nn.SiLU if _has_silu else SwishJit, 71 | swish=nn.SiLU if _has_silu else SwishJit, 72 | mish=MishJit, 73 | hard_sigmoid=HardSigmoidJit, 74 | hard_swish=HardSwishJit, 75 | hard_mish=HardMishJit 76 | ) 77 | 78 | _ACT_LAYER_ME = dict( 79 | silu=nn.SiLU if _has_silu else SwishMe, 80 | swish=nn.SiLU if _has_silu else SwishMe, 81 | mish=MishMe, 82 | hard_sigmoid=HardSigmoidMe, 83 | hard_swish=HardSwishMe, 84 | hard_mish=HardMishMe, 85 | ) 86 | 87 | 88 | def get_act_fn(name='relu'): 89 | """ Activation Function Factory 90 | Fetching activation fns by name with this function allows export or torch script friendly 91 | functions to be 
returned dynamically based on current config. 92 | """ 93 | if not name: 94 | return None 95 | if not (is_no_jit() or is_exportable() or is_scriptable()): 96 | # If not exporting or scripting the model, first look for a memory-efficient version with 97 | # custom autograd, then fallback 98 | if name in _ACT_FN_ME: 99 | return _ACT_FN_ME[name] 100 | if is_exportable() and name in ('silu', 'swish'): 101 | # FIXME PyTorch SiLU doesn't ONNX export, this is a temp hack 102 | return swish 103 | if not (is_no_jit() or is_exportable()): 104 | if name in _ACT_FN_JIT: 105 | return _ACT_FN_JIT[name] 106 | return _ACT_FN_DEFAULT[name] 107 | 108 | 109 | def get_act_layer(name='relu'): 110 | """ Activation Layer Factory 111 | Fetching activation layers by name with this function allows export or torch script friendly 112 | functions to be returned dynamically based on current config. 113 | """ 114 | if not name: 115 | return None 116 | if not (is_no_jit() or is_exportable() or is_scriptable()): 117 | if name in _ACT_LAYER_ME: 118 | return _ACT_LAYER_ME[name] 119 | if is_exportable() and name in ('silu', 'swish'): 120 | # FIXME PyTorch SiLU doesn't ONNX export, this is a temp hack 121 | return Swish 122 | if not (is_no_jit() or is_exportable()): 123 | if name in _ACT_LAYER_JIT: 124 | return _ACT_LAYER_JIT[name] 125 | return _ACT_LAYER_DEFAULT[name] 126 | 127 | 128 | def create_act_layer(name, inplace=False, **kwargs): 129 | act_layer = get_act_layer(name) 130 | if act_layer is not None: 131 | return act_layer(inplace=inplace, **kwargs) 132 | else: 133 | return None 134 | -------------------------------------------------------------------------------- /timm/models/layers/create_attn.py: -------------------------------------------------------------------------------- 1 | """ Select AttentionFactory Method 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import torch 6 | from .se import SEModule, EffectiveSEModule 7 | from .eca import EcaModule, CecaModule 8 | from .cbam import CbamModule, LightCbamModule 9 | 10 | 11 | def create_attn(attn_type, channels, **kwargs): 12 | module_cls = None 13 | if attn_type is not None: 14 | if isinstance(attn_type, str): 15 | attn_type = attn_type.lower() 16 | if attn_type == 'se': 17 | module_cls = SEModule 18 | elif attn_type == 'ese': 19 | module_cls = EffectiveSEModule 20 | elif attn_type == 'eca': 21 | module_cls = EcaModule 22 | elif attn_type == 'ceca': 23 | module_cls = CecaModule 24 | elif attn_type == 'cbam': 25 | module_cls = CbamModule 26 | elif attn_type == 'lcbam': 27 | module_cls = LightCbamModule 28 | else: 29 | assert False, "Invalid attn module (%s)" % attn_type 30 | elif isinstance(attn_type, bool): 31 | if attn_type: 32 | module_cls = SEModule 33 | else: 34 | module_cls = attn_type 35 | if module_cls is not None: 36 | return module_cls(channels, **kwargs) 37 | return None 38 | -------------------------------------------------------------------------------- /timm/models/layers/create_conv2d.py: -------------------------------------------------------------------------------- 1 | """ Create Conv2d Factory Method 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | 6 | from .mixed_conv2d import MixedConv2d 7 | from .cond_conv2d import CondConv2d 8 | from .conv2d_same import create_conv2d_pad 9 | 10 | 11 | def create_conv2d(in_channels, out_channels, kernel_size, **kwargs): 12 | """ Select a 2d convolution implementation based on arguments 13 | Creates and returns one of torch.nn.Conv2d, Conv2dSame, MixedConv2d, or 
CondConv2d. 14 | 15 | Used extensively by EfficientNet, MobileNetv3 and related networks. 16 | """ 17 | if isinstance(kernel_size, list): 18 | assert 'num_experts' not in kwargs # MixNet + CondConv combo not supported currently 19 | assert 'groups' not in kwargs # MixedConv groups are defined by kernel list 20 | # We're going to use only lists for defining the MixedConv2d kernel groups, 21 | # ints, tuples, other iterables will continue to pass to normal conv and specify h, w. 22 | m = MixedConv2d(in_channels, out_channels, kernel_size, **kwargs) 23 | else: 24 | depthwise = kwargs.pop('depthwise', False) 25 | groups = out_channels if depthwise else kwargs.pop('groups', 1) 26 | if 'num_experts' in kwargs and kwargs['num_experts'] > 0: 27 | m = CondConv2d(in_channels, out_channels, kernel_size, groups=groups, **kwargs) 28 | else: 29 | m = create_conv2d_pad(in_channels, out_channels, kernel_size, groups=groups, **kwargs) 30 | return m 31 | -------------------------------------------------------------------------------- /timm/models/layers/create_norm_act.py: -------------------------------------------------------------------------------- 1 | """ NormAct (Normalization + Activation Layer) Factory 2 | 3 | Create norm + act combo modules that attempt to be backwards compatible with separate norm + act 4 | instances in models. Where these are used it will be possible to swap separate BN + act layers with 5 | combined modules like IABN or EvoNorms. 6 | 7 | Hacked together by / Copyright 2020 Ross Wightman 8 | """ 9 | import types 10 | import functools 11 | 12 | import torch 13 | import torch.nn as nn 14 | 15 | from .evo_norm import EvoNormBatch2d, EvoNormSample2d 16 | from .norm_act import BatchNormAct2d, GroupNormAct 17 | from .inplace_abn import InplaceAbn 18 | 19 | _NORM_ACT_TYPES = {BatchNormAct2d, GroupNormAct, EvoNormBatch2d, EvoNormSample2d, InplaceAbn} 20 | _NORM_ACT_REQUIRES_ARG = {BatchNormAct2d, GroupNormAct, InplaceAbn} # requires act_layer arg to define act type 21 | 22 | def get_norm_act_layer(layer_class): 23 | layer_class = layer_class.replace('_', '').lower() 24 | if layer_class.startswith("batchnorm"): 25 | layer = BatchNormAct2d 26 | elif layer_class.startswith("groupnorm"): 27 | layer = GroupNormAct 28 | elif layer_class == "evonormbatch": 29 | layer = EvoNormBatch2d 30 | elif layer_class == "evonormsample": 31 | layer = EvoNormSample2d 32 | elif layer_class == "iabn" or layer_class == "inplaceabn": 33 | layer = InplaceAbn 34 | else: 35 | assert False, "Invalid norm_act layer (%s)" % layer_class 36 | return layer 37 | 38 | 39 | def create_norm_act(layer_type, num_features, apply_act=True, jit=False, **kwargs): 40 | layer_parts = layer_type.split('-') # e.g. batchnorm-leaky_relu 41 | assert len(layer_parts) in (1, 2) 42 | layer = get_norm_act_layer(layer_parts[0]) 43 | #activation_class = layer_parts[1].lower() if len(layer_parts) > 1 else '' # FIXME support string act selection?
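# NOTE: only layer_parts[0] is used above; an activation named in the suffix (e.g. the 'leaky_relu'
# in 'batchnorm-leaky_relu') is currently ignored, so the activation comes from the chosen norm_act
# class default or from an act_layer passed through **kwargs.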
44 | layer_instance = layer(num_features, apply_act=apply_act, **kwargs) 45 | if jit: 46 | layer_instance = torch.jit.script(layer_instance) 47 | return layer_instance 48 | 49 | 50 | def convert_norm_act_type(norm_layer, act_layer, norm_kwargs=None): 51 | assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial)) 52 | assert act_layer is None or isinstance(act_layer, (type, str, types.FunctionType, functools.partial)) 53 | norm_act_args = norm_kwargs.copy() if norm_kwargs else {} 54 | if isinstance(norm_layer, str): 55 | norm_act_layer = get_norm_act_layer(norm_layer) 56 | elif norm_layer in _NORM_ACT_TYPES: 57 | norm_act_layer = norm_layer 58 | elif isinstance(norm_layer, (types.FunctionType, functools.partial)): 59 | # assuming this is a lambda/fn/bound partial that creates norm_act layer 60 | norm_act_layer = norm_layer 61 | else: 62 | type_name = norm_layer.__name__.lower() 63 | if type_name.startswith('batchnorm'): 64 | norm_act_layer = BatchNormAct2d 65 | elif type_name.startswith('groupnorm'): 66 | norm_act_layer = GroupNormAct 67 | else: 68 | assert False, f"No equivalent norm_act layer for {type_name}" 69 | if norm_act_layer in _NORM_ACT_REQUIRES_ARG: 70 | # Must pass `act_layer` through for backwards compat where `act_layer=None` implies no activation. 71 | # In the future, may force use of `apply_act` with `act_layer` arg bound to relevant NormAct types 72 | # It is intended that functions/partial does not trigger this, they should define act. 73 | norm_act_args.update(dict(act_layer=act_layer)) 74 | return norm_act_layer, norm_act_args 75 | -------------------------------------------------------------------------------- /timm/models/layers/evo_norm.py: -------------------------------------------------------------------------------- 1 | """EvoNormB0 (Batched) and EvoNormS0 (Sample) in PyTorch 2 | 3 | An attempt at getting decent performing EvoNorms running in PyTorch. 4 | While currently faster than other impl, still quite a ways off the built-in BN 5 | in terms of memory usage and throughput (roughly 5x mem, 1/2 - 1/3x speed). 6 | 7 | Still very much a WIP, fiddling with buffer usage, in-place/jit optimizations, and layouts. 
8 | 9 | Hacked together by / Copyright 2020 Ross Wightman 10 | """ 11 | 12 | import torch 13 | import torch.nn as nn 14 | 15 | 16 | class EvoNormBatch2d(nn.Module): 17 | def __init__(self, num_features, apply_act=True, momentum=0.1, eps=1e-5, drop_block=None): 18 | super(EvoNormBatch2d, self).__init__() 19 | self.apply_act = apply_act # apply activation (non-linearity) 20 | self.momentum = momentum 21 | self.eps = eps 22 | param_shape = (1, num_features, 1, 1) 23 | self.weight = nn.Parameter(torch.ones(param_shape), requires_grad=True) 24 | self.bias = nn.Parameter(torch.zeros(param_shape), requires_grad=True) 25 | if apply_act: 26 | self.v = nn.Parameter(torch.ones(param_shape), requires_grad=True) 27 | self.register_buffer('running_var', torch.ones(1, num_features, 1, 1)) 28 | self.reset_parameters() 29 | 30 | def reset_parameters(self): 31 | nn.init.ones_(self.weight) 32 | nn.init.zeros_(self.bias) 33 | if self.apply_act: 34 | nn.init.ones_(self.v) 35 | 36 | def forward(self, x): 37 | assert x.dim() == 4, 'expected 4D input' 38 | x_type = x.dtype 39 | if self.training: 40 | var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True) 41 | n = x.numel() / x.shape[1] 42 | self.running_var.copy_( 43 | var.detach() * self.momentum * (n / (n - 1)) + self.running_var * (1 - self.momentum)) 44 | else: 45 | var = self.running_var 46 | 47 | if self.apply_act: 48 | v = self.v.to(dtype=x_type) 49 | d = x * v + (x.var(dim=(2, 3), unbiased=False, keepdim=True) + self.eps).sqrt().to(dtype=x_type) 50 | d = d.max((var + self.eps).sqrt().to(dtype=x_type)) 51 | x = x / d 52 | return x * self.weight + self.bias 53 | 54 | 55 | class EvoNormSample2d(nn.Module): 56 | def __init__(self, num_features, apply_act=True, groups=8, eps=1e-5, drop_block=None): 57 | super(EvoNormSample2d, self).__init__() 58 | self.apply_act = apply_act # apply activation (non-linearity) 59 | self.groups = groups 60 | self.eps = eps 61 | param_shape = (1, num_features, 1, 1) 62 | self.weight = nn.Parameter(torch.ones(param_shape), requires_grad=True) 63 | self.bias = nn.Parameter(torch.zeros(param_shape), requires_grad=True) 64 | if apply_act: 65 | self.v = nn.Parameter(torch.ones(param_shape), requires_grad=True) 66 | self.reset_parameters() 67 | 68 | def reset_parameters(self): 69 | nn.init.ones_(self.weight) 70 | nn.init.zeros_(self.bias) 71 | if self.apply_act: 72 | nn.init.ones_(self.v) 73 | 74 | def forward(self, x): 75 | assert x.dim() == 4, 'expected 4D input' 76 | B, C, H, W = x.shape 77 | assert C % self.groups == 0 78 | if self.apply_act: 79 | n = x * (x * self.v).sigmoid() 80 | x = x.reshape(B, self.groups, -1) 81 | x = n.reshape(B, self.groups, -1) / (x.var(dim=-1, unbiased=False, keepdim=True) + self.eps).sqrt() 82 | x = x.reshape(B, C, H, W) 83 | return x * self.weight + self.bias 84 | -------------------------------------------------------------------------------- /timm/models/layers/helpers.py: -------------------------------------------------------------------------------- 1 | """ Layer/Module Helpers 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | from itertools import repeat 6 | import torch 7 | TORCH_MAJOR = int(torch.__version__.split('.')[0]) 8 | TORCH_MINOR = int(torch.__version__.split('.')[1]) 9 | 10 | if TORCH_MAJOR == 1 and TORCH_MINOR < 8: 11 | from torch._six import container_abcs 12 | else: 13 | import collections.abc as container_abcs 14 | 15 | 16 | # From PyTorch internals 17 | def _ntuple(n): 18 | def parse(x): 19 | if isinstance(x, container_abcs.Iterable): 20 | return x 21 | 
return tuple(repeat(x, n)) 22 | return parse 23 | 24 | 25 | to_1tuple = _ntuple(1) 26 | to_2tuple = _ntuple(2) 27 | to_3tuple = _ntuple(3) 28 | to_4tuple = _ntuple(4) 29 | to_ntuple = _ntuple 30 | 31 | 32 | 33 | 34 | 35 | -------------------------------------------------------------------------------- /timm/models/layers/inplace_abn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn as nn 3 | 4 | try: 5 | from inplace_abn.functions import inplace_abn, inplace_abn_sync 6 | has_iabn = True 7 | except ImportError: 8 | has_iabn = False 9 | 10 | def inplace_abn(x, weight, bias, running_mean, running_var, 11 | training=True, momentum=0.1, eps=1e-05, activation="leaky_relu", activation_param=0.01): 12 | raise ImportError( 13 | "Please install InplaceABN:'pip install git+https://github.com/mapillary/inplace_abn.git@v1.0.11'") 14 | 15 | def inplace_abn_sync(**kwargs): 16 | inplace_abn(**kwargs) 17 | 18 | 19 | class InplaceAbn(nn.Module): 20 | """Activated Batch Normalization 21 | 22 | This gathers a BatchNorm and an activation function in a single module 23 | 24 | Parameters 25 | ---------- 26 | num_features : int 27 | Number of feature channels in the input and output. 28 | eps : float 29 | Small constant to prevent numerical issues. 30 | momentum : float 31 | Momentum factor applied to compute running statistics. 32 | affine : bool 33 | If `True` apply learned scale and shift transformation after normalization. 34 | act_layer : str or nn.Module type 35 | Name or type of the activation functions, one of: `leaky_relu`, `elu` 36 | act_param : float 37 | Negative slope for the `leaky_relu` activation. 38 | """ 39 | 40 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, apply_act=True, 41 | act_layer="leaky_relu", act_param=0.01, drop_block=None): 42 | super(InplaceAbn, self).__init__() 43 | self.num_features = num_features 44 | self.affine = affine 45 | self.eps = eps 46 | self.momentum = momentum 47 | if apply_act: 48 | if isinstance(act_layer, str): 49 | assert act_layer in ('leaky_relu', 'elu', 'identity', '') 50 | self.act_name = act_layer if act_layer else 'identity' 51 | else: 52 | # convert act layer passed as type to string 53 | if act_layer == nn.ELU: 54 | self.act_name = 'elu' 55 | elif act_layer == nn.LeakyReLU: 56 | self.act_name = 'leaky_relu' 57 | elif act_layer == nn.Identity: 58 | self.act_name = 'identity' 59 | else: 60 | assert False, f'Invalid act layer {act_layer.__name__} for IABN' 61 | else: 62 | self.act_name = 'identity' 63 | self.act_param = act_param 64 | if self.affine: 65 | self.weight = nn.Parameter(torch.ones(num_features)) 66 | self.bias = nn.Parameter(torch.zeros(num_features)) 67 | else: 68 | self.register_parameter('weight', None) 69 | self.register_parameter('bias', None) 70 | self.register_buffer('running_mean', torch.zeros(num_features)) 71 | self.register_buffer('running_var', torch.ones(num_features)) 72 | self.reset_parameters() 73 | 74 | def reset_parameters(self): 75 | nn.init.constant_(self.running_mean, 0) 76 | nn.init.constant_(self.running_var, 1) 77 | if self.affine: 78 | nn.init.constant_(self.weight, 1) 79 | nn.init.constant_(self.bias, 0) 80 | 81 | def forward(self, x): 82 | output = inplace_abn( 83 | x, self.weight, self.bias, self.running_mean, self.running_var, 84 | self.training, self.momentum, self.eps, self.act_name, self.act_param) 85 | if isinstance(output, tuple): 86 | output = output[0] 87 | return output 88 | 
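The create_norm_act.py factory above resolves a string such as 'batchnorm' or 'iabn' to one of these combined normalization + activation modules (BatchNormAct2d, the EvoNorm variants, InplaceAbn). A minimal usage sketch, assuming this repo's timm package (including its __init__ files, which are not shown in this excerpt) is importable:

import torch
from timm.models.layers.create_norm_act import create_norm_act

# 'batchnorm' resolves to BatchNormAct2d (BN followed by ReLU by default); 'iabn' would resolve
# to InplaceAbn, whose forward raises an ImportError unless the optional inplace_abn package is
# installed (see the fallback stub above).
norm_act = create_norm_act('batchnorm', num_features=32)
x = torch.randn(2, 32, 8, 8)
print(norm_act(x).shape)  # torch.Size([2, 32, 8, 8])

Modules such as ConvBnAct go through convert_norm_act_type() instead, so passing a different norm_layer (by class or by name) swaps the whole norm + act pair without touching the conv definition.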
-------------------------------------------------------------------------------- /timm/models/layers/linear.py: -------------------------------------------------------------------------------- 1 | """ Linear layer (alternate definition) 2 | """ 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn as nn 6 | 7 | 8 | class Linear(nn.Linear): 9 | r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b` 10 | 11 | Wraps torch.nn.Linear to support AMP + torchscript usage by manually casting 12 | weight & bias to input.dtype to work around an issue w/ torch.addmm in this use case. 13 | """ 14 | def forward(self, input: torch.Tensor) -> torch.Tensor: 15 | if torch.jit.is_scripting(): 16 | bias = self.bias.to(dtype=input.dtype) if self.bias is not None else None 17 | return F.linear(input, self.weight.to(dtype=input.dtype), bias=bias) 18 | else: 19 | return F.linear(input, self.weight, self.bias) 20 | -------------------------------------------------------------------------------- /timm/models/layers/median_pool.py: -------------------------------------------------------------------------------- 1 | """ Median Pool 2 | Hacked together by / Copyright 2020 Ross Wightman 3 | """ 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from .helpers import to_2tuple, to_4tuple 7 | 8 | 9 | class MedianPool2d(nn.Module): 10 | """ Median pool (usable as median filter when stride=1) module. 11 | 12 | Args: 13 | kernel_size: size of pooling kernel, int or 2-tuple 14 | stride: pool stride, int or 2-tuple 15 | padding: pool padding, int or 4-tuple (l, r, t, b) as in pytorch F.pad 16 | same: override padding and enforce same padding, boolean 17 | """ 18 | def __init__(self, kernel_size=3, stride=1, padding=0, same=False): 19 | super(MedianPool2d, self).__init__() 20 | self.k = to_2tuple(kernel_size) 21 | self.stride = to_2tuple(stride) 22 | self.padding = to_4tuple(padding) # convert to l, r, t, b 23 | self.same = same 24 | 25 | def _padding(self, x): 26 | if self.same: 27 | ih, iw = x.size()[2:] 28 | if ih % self.stride[0] == 0: 29 | ph = max(self.k[0] - self.stride[0], 0) 30 | else: 31 | ph = max(self.k[0] - (ih % self.stride[0]), 0) 32 | if iw % self.stride[1] == 0: 33 | pw = max(self.k[1] - self.stride[1], 0) 34 | else: 35 | pw = max(self.k[1] - (iw % self.stride[1]), 0) 36 | pl = pw // 2 37 | pr = pw - pl 38 | pt = ph // 2 39 | pb = ph - pt 40 | padding = (pl, pr, pt, pb) 41 | else: 42 | padding = self.padding 43 | return padding 44 | 45 | def forward(self, x): 46 | x = F.pad(x, self._padding(x), mode='reflect') 47 | x = x.unfold(2, self.k[0], self.stride[0]).unfold(3, self.k[1], self.stride[1]) 48 | x = x.contiguous().view(x.size()[:4] + (-1,)).median(dim=-1)[0] 49 | return x 50 | -------------------------------------------------------------------------------- /timm/models/layers/mixed_conv2d.py: -------------------------------------------------------------------------------- 1 | """ PyTorch Mixed Convolution 2 | 3 | Paper: MixConv: Mixed Depthwise Convolutional Kernels (https://arxiv.org/abs/1907.09595) 4 | 5 | Hacked together by / Copyright 2020 Ross Wightman 6 | """ 7 | 8 | import torch 9 | from torch import nn as nn 10 | 11 | from .conv2d_same import create_conv2d_pad 12 | 13 | 14 | def _split_channels(num_chan, num_groups): 15 | split = [num_chan // num_groups for _ in range(num_groups)] 16 | split[0] += num_chan - sum(split) 17 | return split 18 | 19 | 20 | class MixedConv2d(nn.ModuleDict): 21 | """ Mixed Grouped Convolution 22 | 23 | Based on 
MDConv and GroupedConv in MixNet impl: 24 | https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py 25 | """ 26 | def __init__(self, in_channels, out_channels, kernel_size=3, 27 | stride=1, padding='', dilation=1, depthwise=False, **kwargs): 28 | super(MixedConv2d, self).__init__() 29 | 30 | kernel_size = kernel_size if isinstance(kernel_size, list) else [kernel_size] 31 | num_groups = len(kernel_size) 32 | in_splits = _split_channels(in_channels, num_groups) 33 | out_splits = _split_channels(out_channels, num_groups) 34 | self.in_channels = sum(in_splits) 35 | self.out_channels = sum(out_splits) 36 | for idx, (k, in_ch, out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)): 37 | conv_groups = out_ch if depthwise else 1 38 | # use add_module to keep key space clean 39 | self.add_module( 40 | str(idx), 41 | create_conv2d_pad( 42 | in_ch, out_ch, k, stride=stride, 43 | padding=padding, dilation=dilation, groups=conv_groups, **kwargs) 44 | ) 45 | self.splits = in_splits 46 | 47 | def forward(self, x): 48 | x_split = torch.split(x, self.splits, 1) 49 | x_out = [c(x_split[i]) for i, c in enumerate(self.values())] 50 | x = torch.cat(x_out, 1) 51 | return x 52 | -------------------------------------------------------------------------------- /timm/models/layers/norm_act.py: -------------------------------------------------------------------------------- 1 | """ Normalization + Activation Layers 2 | """ 3 | import torch 4 | from torch import nn as nn 5 | from torch.nn import functional as F 6 | 7 | from .create_act import get_act_layer 8 | 9 | 10 | class BatchNormAct2d(nn.BatchNorm2d): 11 | """BatchNorm + Activation 12 | 13 | This module performs BatchNorm + Activation in a manner that will remain backwards 14 | compatible with weights trained with separate bn, act. This is why we inherit from BN 15 | instead of composing it as a .bn member. 16 | """ 17 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True, 18 | apply_act=True, act_layer=nn.ReLU, inplace=True, drop_block=None): 19 | super(BatchNormAct2d, self).__init__( 20 | num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats) 21 | if isinstance(act_layer, str): 22 | act_layer = get_act_layer(act_layer) 23 | if act_layer is not None and apply_act: 24 | act_args = dict(inplace=True) if inplace else {} 25 | self.act = act_layer(**act_args) 26 | else: 27 | self.act = None 28 | 29 | def _forward_jit(self, x): 30 | """ A cut & paste of the contents of the PyTorch BatchNorm2d forward function 31 | """ 32 | # exponential_average_factor is self.momentum set to 33 | # (when it is available) only so that if gets updated 34 | # in ONNX graph when this node is exported to ONNX. 
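# (If momentum is None, nn.BatchNorm2d semantics call for a cumulative moving average, i.e. a
# factor of 1 / num_batches_tracked; otherwise the fixed momentum is used as the exponential
# moving average factor. The branches below mirror that behaviour so the scripted path matches.)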
35 | if self.momentum is None: 36 | exponential_average_factor = 0.0 37 | else: 38 | exponential_average_factor = self.momentum 39 | 40 | if self.training and self.track_running_stats: 41 | # TODO: if statement only here to tell the jit to skip emitting this when it is None 42 | if self.num_batches_tracked is not None: 43 | self.num_batches_tracked += 1 44 | if self.momentum is None: # use cumulative moving average 45 | exponential_average_factor = 1.0 / float(self.num_batches_tracked) 46 | else: # use exponential moving average 47 | exponential_average_factor = self.momentum 48 | 49 | x = F.batch_norm( 50 | x, self.running_mean, self.running_var, self.weight, self.bias, 51 | self.training or not self.track_running_stats, 52 | exponential_average_factor, self.eps) 53 | return x 54 | 55 | @torch.jit.ignore 56 | def _forward_python(self, x): 57 | return super(BatchNormAct2d, self).forward(x) 58 | 59 | def forward(self, x): 60 | # FIXME cannot call parent forward() and maintain jit.script compatibility? 61 | if torch.jit.is_scripting(): 62 | x = self._forward_jit(x) 63 | else: 64 | x = self._forward_python(x) 65 | if self.act is not None: 66 | x = self.act(x) 67 | return x 68 | 69 | 70 | class GroupNormAct(nn.GroupNorm): 71 | 72 | def __init__(self, num_groups, num_channels, eps=1e-5, affine=True, 73 | apply_act=True, act_layer=nn.ReLU, inplace=True, drop_block=None): 74 | super(GroupNormAct, self).__init__(num_groups, num_channels, eps=eps, affine=affine) 75 | if isinstance(act_layer, str): 76 | act_layer = get_act_layer(act_layer) 77 | if act_layer is not None and apply_act: 78 | self.act = act_layer(inplace=inplace) 79 | else: 80 | self.act = None 81 | 82 | def forward(self, x): 83 | x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps) 84 | if self.act is not None: 85 | x = self.act(x) 86 | return x 87 | -------------------------------------------------------------------------------- /timm/models/layers/padding.py: -------------------------------------------------------------------------------- 1 | """ Padding Helpers 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import math 6 | from typing import List, Tuple 7 | 8 | import torch.nn.functional as F 9 | 10 | 11 | # Calculate symmetric padding for a convolution 12 | def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int: 13 | padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 14 | return padding 15 | 16 | 17 | # Calculate asymmetric TensorFlow-like 'SAME' padding for a convolution 18 | def get_same_padding(x: int, k: int, s: int, d: int): 19 | return max((math.ceil(x / s) - 1) * s + (k - 1) * d + 1 - x, 0) 20 | 21 | 22 | # Can SAME padding for given args be done statically? 
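# With stride == 1 the total 'SAME' padding equals dilation * (kernel_size - 1) regardless of the
# input size; when that amount is also even it can be split symmetrically and baked into a plain
# nn.Conv2d at construction time. In every other case pad_same() below must compute the padding
# from the actual input shape on each forward pass (the dynamic Conv2dSame / string-'same' case).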
23 | def is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_): 24 | return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0 25 | 26 | 27 | # Dynamically pad input x with 'SAME' padding for conv with specified args 28 | def pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1), value: float = 0): 29 | ih, iw = x.size()[-2:] 30 | pad_h, pad_w = get_same_padding(ih, k[0], s[0], d[0]), get_same_padding(iw, k[1], s[1], d[1]) 31 | if pad_h > 0 or pad_w > 0: 32 | x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2], value=value) 33 | return x 34 | 35 | 36 | def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]: 37 | dynamic = False 38 | if isinstance(padding, str): 39 | # for any string padding, the padding will be calculated for you, one of three ways 40 | padding = padding.lower() 41 | if padding == 'same': 42 | # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact 43 | if is_static_pad(kernel_size, **kwargs): 44 | # static case, no extra overhead 45 | padding = get_padding(kernel_size, **kwargs) 46 | else: 47 | # dynamic 'SAME' padding, has runtime/GPU memory overhead 48 | padding = 0 49 | dynamic = True 50 | elif padding == 'valid': 51 | # 'VALID' padding, same as padding=0 52 | padding = 0 53 | else: 54 | # Default to PyTorch style 'same'-ish symmetric padding 55 | padding = get_padding(kernel_size, **kwargs) 56 | return padding, dynamic 57 | -------------------------------------------------------------------------------- /timm/models/layers/pool2d_same.py: -------------------------------------------------------------------------------- 1 | """ AvgPool2d w/ Same Padding 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from typing import List, Tuple, Optional 9 | 10 | from .helpers import to_2tuple 11 | from .padding import pad_same, get_padding_value 12 | 13 | 14 | def avg_pool2d_same(x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0), 15 | ceil_mode: bool = False, count_include_pad: bool = True): 16 | # FIXME how to deal with count_include_pad vs not for external padding? 
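# The zeros added by the pad_same() call below are ordinary values as far as F.avg_pool2d is
# concerned, so they are always included in the average; count_include_pad only governs the
# (unused) internal padding. TensorFlow's 'SAME' average pooling, by contrast, excludes padded
# positions from the divisor, hence the FIXME above.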
17 | x = pad_same(x, kernel_size, stride) 18 | return F.avg_pool2d(x, kernel_size, stride, (0, 0), ceil_mode, count_include_pad) 19 | 20 | 21 | class AvgPool2dSame(nn.AvgPool2d): 22 | """ Tensorflow like 'SAME' wrapper for 2D average pooling 23 | """ 24 | def __init__(self, kernel_size: int, stride=None, padding=0, ceil_mode=False, count_include_pad=True): 25 | kernel_size = to_2tuple(kernel_size) 26 | stride = to_2tuple(stride) 27 | super(AvgPool2dSame, self).__init__(kernel_size, stride, (0, 0), ceil_mode, count_include_pad) 28 | 29 | def forward(self, x): 30 | return avg_pool2d_same( 31 | x, self.kernel_size, self.stride, self.padding, self.ceil_mode, self.count_include_pad) 32 | 33 | 34 | def max_pool2d_same( 35 | x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0), 36 | dilation: List[int] = (1, 1), ceil_mode: bool = False): 37 | x = pad_same(x, kernel_size, stride, value=-float('inf')) 38 | return F.max_pool2d(x, kernel_size, stride, (0, 0), dilation, ceil_mode) 39 | 40 | 41 | class MaxPool2dSame(nn.MaxPool2d): 42 | """ Tensorflow like 'SAME' wrapper for 2D max pooling 43 | """ 44 | def __init__(self, kernel_size: int, stride=None, padding=0, dilation=1, ceil_mode=False, count_include_pad=True): 45 | kernel_size = to_2tuple(kernel_size) 46 | stride = to_2tuple(stride) 47 | dilation = to_2tuple(dilation) 48 | super(MaxPool2dSame, self).__init__(kernel_size, stride, (0, 0), dilation, ceil_mode=ceil_mode) # nn.MaxPool2d has no count_include_pad arg and its 5th positional arg is return_indices, so pass ceil_mode by keyword 49 | 50 | def forward(self, x): 51 | return max_pool2d_same(x, self.kernel_size, self.stride, self.padding, self.dilation, self.ceil_mode) 52 | 53 | 54 | def create_pool2d(pool_type, kernel_size, stride=None, **kwargs): 55 | stride = stride or kernel_size 56 | padding = kwargs.pop('padding', '') 57 | padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, **kwargs) 58 | if is_dynamic: 59 | if pool_type == 'avg': 60 | return AvgPool2dSame(kernel_size, stride=stride, **kwargs) 61 | elif pool_type == 'max': 62 | return MaxPool2dSame(kernel_size, stride=stride, **kwargs) 63 | else: 64 | assert False, f'Unsupported pool type {pool_type}' 65 | else: 66 | if pool_type == 'avg': 67 | return nn.AvgPool2d(kernel_size, stride=stride, padding=padding, **kwargs) 68 | elif pool_type == 'max': 69 | return nn.MaxPool2d(kernel_size, stride=stride, padding=padding, **kwargs) 70 | else: 71 | assert False, f'Unsupported pool type {pool_type}' 72 | -------------------------------------------------------------------------------- /timm/models/layers/se.py: -------------------------------------------------------------------------------- 1 | from torch import nn as nn 2 | from .create_act import create_act_layer 3 | 4 | 5 | class SEModule(nn.Module): 6 | 7 | def __init__(self, channels, reduction=16, act_layer=nn.ReLU, min_channels=8, reduction_channels=None, 8 | gate_layer='sigmoid'): 9 | super(SEModule, self).__init__() 10 | reduction_channels = reduction_channels or max(channels // reduction, min_channels) 11 | self.fc1 = nn.Conv2d(channels, reduction_channels, kernel_size=1, bias=True) 12 | self.act = act_layer(inplace=True) 13 | self.fc2 = nn.Conv2d(reduction_channels, channels, kernel_size=1, bias=True) 14 | self.gate = create_act_layer(gate_layer) 15 | 16 | def forward(self, x): 17 | x_se = x.mean((2, 3), keepdim=True) 18 | x_se = self.fc1(x_se) 19 | x_se = self.act(x_se) 20 | x_se = self.fc2(x_se) 21 | return x * self.gate(x_se) 22 | 23 | 24 | class EffectiveSEModule(nn.Module): 25 | """ 'Effective Squeeze-Excitation 26 | From `CenterMask : Real-Time
Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667 27 | """ 28 | def __init__(self, channels, gate_layer='hard_sigmoid'): 29 | super(EffectiveSEModule, self).__init__() 30 | self.fc = nn.Conv2d(channels, channels, kernel_size=1, padding=0) 31 | self.gate = create_act_layer(gate_layer, inplace=True) 32 | 33 | def forward(self, x): 34 | x_se = x.mean((2, 3), keepdim=True) 35 | x_se = self.fc(x_se) 36 | return x * self.gate(x_se) 37 | -------------------------------------------------------------------------------- /timm/models/layers/separable_conv.py: -------------------------------------------------------------------------------- 1 | """ Depthwise Separable Conv Modules 2 | 3 | Basic DWS convs. Other variations of DWS exist with batch norm or activations between the 4 | DW and PW convs such as the Depthwise modules in MobileNetV2 / EfficientNet and Xception. 5 | 6 | Hacked together by / Copyright 2020 Ross Wightman 7 | """ 8 | from torch import nn as nn 9 | 10 | from .create_conv2d import create_conv2d 11 | from .create_norm_act import convert_norm_act_type 12 | 13 | 14 | class SeparableConvBnAct(nn.Module): 15 | """ Separable Conv w/ trailing Norm and Activation 16 | """ 17 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, padding='', bias=False, 18 | channel_multiplier=1.0, pw_kernel_size=1, norm_layer=nn.BatchNorm2d, norm_kwargs=None, 19 | act_layer=nn.ReLU, apply_act=True, drop_block=None): 20 | super(SeparableConvBnAct, self).__init__() 21 | norm_kwargs = norm_kwargs or {} 22 | 23 | self.conv_dw = create_conv2d( 24 | in_channels, int(in_channels * channel_multiplier), kernel_size, 25 | stride=stride, dilation=dilation, padding=padding, depthwise=True) 26 | 27 | self.conv_pw = create_conv2d( 28 | int(in_channels * channel_multiplier), out_channels, pw_kernel_size, padding=padding, bias=bias) 29 | 30 | norm_act_layer, norm_act_args = convert_norm_act_type(norm_layer, act_layer, norm_kwargs) 31 | self.bn = norm_act_layer(out_channels, apply_act=apply_act, drop_block=drop_block, **norm_act_args) 32 | 33 | @property 34 | def in_channels(self): 35 | return self.conv_dw.in_channels 36 | 37 | @property 38 | def out_channels(self): 39 | return self.conv_pw.out_channels 40 | 41 | def forward(self, x): 42 | x = self.conv_dw(x) 43 | x = self.conv_pw(x) 44 | if self.bn is not None: 45 | x = self.bn(x) 46 | return x 47 | 48 | 49 | class SeparableConv2d(nn.Module): 50 | """ Separable Conv 51 | """ 52 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, padding='', bias=False, 53 | channel_multiplier=1.0, pw_kernel_size=1): 54 | super(SeparableConv2d, self).__init__() 55 | 56 | self.conv_dw = create_conv2d( 57 | in_channels, int(in_channels * channel_multiplier), kernel_size, 58 | stride=stride, dilation=dilation, padding=padding, depthwise=True) 59 | 60 | self.conv_pw = create_conv2d( 61 | int(in_channels * channel_multiplier), out_channels, pw_kernel_size, padding=padding, bias=bias) 62 | 63 | @property 64 | def in_channels(self): 65 | return self.conv_dw.in_channels 66 | 67 | @property 68 | def out_channels(self): 69 | return self.conv_pw.out_channels 70 | 71 | def forward(self, x): 72 | x = self.conv_dw(x) 73 | x = self.conv_pw(x) 74 | return x 75 | -------------------------------------------------------------------------------- /timm/models/layers/space_to_depth.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class 
SpaceToDepth(nn.Module): 6 | def __init__(self, block_size=4): 7 | super().__init__() 8 | assert block_size == 4 9 | self.bs = block_size 10 | 11 | def forward(self, x): 12 | N, C, H, W = x.size() 13 | x = x.view(N, C, H // self.bs, self.bs, W // self.bs, self.bs) # (N, C, H//bs, bs, W//bs, bs) 14 | x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs) 15 | x = x.view(N, C * (self.bs ** 2), H // self.bs, W // self.bs) # (N, C*bs^2, H//bs, W//bs) 16 | return x 17 | 18 | 19 | @torch.jit.script 20 | class SpaceToDepthJit(object): 21 | def __call__(self, x: torch.Tensor): 22 | # assuming hard-coded that block_size==4 for acceleration 23 | N, C, H, W = x.size() 24 | x = x.view(N, C, H // 4, 4, W // 4, 4) # (N, C, H//bs, bs, W//bs, bs) 25 | x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs) 26 | x = x.view(N, C * 16, H // 4, W // 4) # (N, C*bs^2, H//bs, W//bs) 27 | return x 28 | 29 | 30 | class SpaceToDepthModule(nn.Module): 31 | def __init__(self, no_jit=False): 32 | super().__init__() 33 | if not no_jit: 34 | self.op = SpaceToDepthJit() 35 | else: 36 | self.op = SpaceToDepth() 37 | 38 | def forward(self, x): 39 | return self.op(x) 40 | 41 | 42 | class DepthToSpace(nn.Module): 43 | 44 | def __init__(self, block_size): 45 | super().__init__() 46 | self.bs = block_size 47 | 48 | def forward(self, x): 49 | N, C, H, W = x.size() 50 | x = x.view(N, self.bs, self.bs, C // (self.bs ** 2), H, W) # (N, bs, bs, C//bs^2, H, W) 51 | x = x.permute(0, 3, 4, 1, 5, 2).contiguous() # (N, C//bs^2, H, bs, W, bs) 52 | x = x.view(N, C // (self.bs ** 2), H * self.bs, W * self.bs) # (N, C//bs^2, H * bs, W * bs) 53 | return x 54 | -------------------------------------------------------------------------------- /timm/models/layers/split_attn.py: -------------------------------------------------------------------------------- 1 | """ Split Attention Conv2d (for ResNeSt Models) 2 | 3 | Paper: `ResNeSt: Split-Attention Networks` - /https://arxiv.org/abs/2004.08955 4 | 5 | Adapted from original PyTorch impl at https://github.com/zhanghang1989/ResNeSt 6 | 7 | Modified for torchscript compat, performance, and consistency with timm by Ross Wightman 8 | """ 9 | import torch 10 | import torch.nn.functional as F 11 | from torch import nn 12 | 13 | 14 | class RadixSoftmax(nn.Module): 15 | def __init__(self, radix, cardinality): 16 | super(RadixSoftmax, self).__init__() 17 | self.radix = radix 18 | self.cardinality = cardinality 19 | 20 | def forward(self, x): 21 | batch = x.size(0) 22 | if self.radix > 1: 23 | x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2) 24 | x = F.softmax(x, dim=1) 25 | x = x.reshape(batch, -1) 26 | else: 27 | x = torch.sigmoid(x) 28 | return x 29 | 30 | 31 | class SplitAttnConv2d(nn.Module): 32 | """Split-Attention Conv2d 33 | """ 34 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, 35 | dilation=1, groups=1, bias=False, radix=2, reduction_factor=4, 36 | act_layer=nn.ReLU, norm_layer=None, drop_block=None, **kwargs): 37 | super(SplitAttnConv2d, self).__init__() 38 | self.radix = radix 39 | self.drop_block = drop_block 40 | mid_chs = out_channels * radix 41 | attn_chs = max(in_channels * radix // reduction_factor, 32) 42 | 43 | self.conv = nn.Conv2d( 44 | in_channels, mid_chs, kernel_size, stride, padding, dilation, 45 | groups=groups * radix, bias=bias, **kwargs) 46 | self.bn0 = norm_layer(mid_chs) if norm_layer is not None else None 47 | self.act0 = act_layer(inplace=True) 48 | self.fc1 = 
nn.Conv2d(out_channels, attn_chs, 1, groups=groups) 49 | self.bn1 = norm_layer(attn_chs) if norm_layer is not None else None 50 | self.act1 = act_layer(inplace=True) 51 | self.fc2 = nn.Conv2d(attn_chs, mid_chs, 1, groups=groups) 52 | self.rsoftmax = RadixSoftmax(radix, groups) 53 | 54 | @property 55 | def in_channels(self): 56 | return self.conv.in_channels 57 | 58 | @property 59 | def out_channels(self): 60 | return self.fc1.out_channels 61 | 62 | def forward(self, x): 63 | x = self.conv(x) 64 | if self.bn0 is not None: 65 | x = self.bn0(x) 66 | if self.drop_block is not None: 67 | x = self.drop_block(x) 68 | x = self.act0(x) 69 | 70 | B, RC, H, W = x.shape 71 | if self.radix > 1: 72 | x = x.reshape((B, self.radix, RC // self.radix, H, W)) 73 | x_gap = x.sum(dim=1) 74 | else: 75 | x_gap = x 76 | x_gap = F.adaptive_avg_pool2d(x_gap, 1) 77 | x_gap = self.fc1(x_gap) 78 | if self.bn1 is not None: 79 | x_gap = self.bn1(x_gap) 80 | x_gap = self.act1(x_gap) 81 | x_attn = self.fc2(x_gap) 82 | 83 | x_attn = self.rsoftmax(x_attn).view(B, -1, 1, 1) 84 | if self.radix > 1: 85 | out = (x * x_attn.reshape((B, self.radix, RC // self.radix, 1, 1))).sum(dim=1) 86 | else: 87 | out = x * x_attn 88 | return out.contiguous() 89 | -------------------------------------------------------------------------------- /timm/models/layers/split_batchnorm.py: -------------------------------------------------------------------------------- 1 | """ Split BatchNorm 2 | 3 | A PyTorch BatchNorm layer that splits input batch into N equal parts and passes each through 4 | a separate BN layer. The first split is passed through the parent BN layers with weight/bias 5 | keys the same as the original BN. All other splits pass through BN sub-layers under the '.aux_bn' 6 | namespace. 7 | 8 | This allows easily removing the auxiliary BN layers after training to efficiently 9 | achieve the 'Auxiliary BatchNorm' as described in the AdvProp Paper, section 4.2, 10 | 'Disentangled Learning via An Auxiliary BN' 11 | 12 | Hacked together by / Copyright 2020 Ross Wightman 13 | """ 14 | import torch 15 | import torch.nn as nn 16 | 17 | 18 | class SplitBatchNorm2d(torch.nn.BatchNorm2d): 19 | 20 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, 21 | track_running_stats=True, num_splits=2): 22 | super().__init__(num_features, eps, momentum, affine, track_running_stats) 23 | assert num_splits > 1, 'Should have at least one aux BN layer (num_splits at least 2)' 24 | self.num_splits = num_splits 25 | self.aux_bn = nn.ModuleList([ 26 | nn.BatchNorm2d(num_features, eps, momentum, affine, track_running_stats) for _ in range(num_splits - 1)]) 27 | 28 | def forward(self, input: torch.Tensor): 29 | if self.training: # aux BN only relevant while training 30 | split_size = input.shape[0] // self.num_splits 31 | assert input.shape[0] == split_size * self.num_splits, "batch size must be evenly divisible by num_splits" 32 | split_input = input.split(split_size) 33 | x = [super().forward(split_input[0])] 34 | for i, a in enumerate(self.aux_bn): 35 | x.append(a(split_input[i + 1])) 36 | return torch.cat(x, dim=0) 37 | else: 38 | return super().forward(input) 39 | 40 | 41 | def convert_splitbn_model(module, num_splits=2): 42 | """ 43 | Recursively traverse module and its children to replace all instances of 44 | ``torch.nn.modules.batchnorm._BatchNorm`` with `SplitBatchnorm2d`. 
45 | Args: 46 | module (torch.nn.Module): input module 47 | num_splits: number of separate batchnorm layers to split input across 48 | Example:: 49 | >>> # model is an instance of torch.nn.Module 50 | >>> model = timm.models.convert_splitbn_model(model, num_splits=2) 51 | """ 52 | mod = module 53 | if isinstance(module, torch.nn.modules.instancenorm._InstanceNorm): 54 | return module 55 | if isinstance(module, torch.nn.modules.batchnorm._BatchNorm): 56 | mod = SplitBatchNorm2d( 57 | module.num_features, module.eps, module.momentum, module.affine, 58 | module.track_running_stats, num_splits=num_splits) 59 | mod.running_mean = module.running_mean 60 | mod.running_var = module.running_var 61 | mod.num_batches_tracked = module.num_batches_tracked 62 | if module.affine: 63 | mod.weight.data = module.weight.data.clone().detach() 64 | mod.bias.data = module.bias.data.clone().detach() 65 | for aux in mod.aux_bn: 66 | aux.running_mean = module.running_mean.clone() 67 | aux.running_var = module.running_var.clone() 68 | aux.num_batches_tracked = module.num_batches_tracked.clone() 69 | if module.affine: 70 | aux.weight.data = module.weight.data.clone().detach() 71 | aux.bias.data = module.bias.data.clone().detach() 72 | for name, child in module.named_children(): 73 | mod.add_module(name, convert_splitbn_model(child, num_splits=num_splits)) 74 | del module 75 | return mod 76 | -------------------------------------------------------------------------------- /timm/models/layers/test_time_pool.py: -------------------------------------------------------------------------------- 1 | """ Test Time Pooling (Average-Max Pool) 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | 6 | import logging 7 | from torch import nn 8 | import torch.nn.functional as F 9 | 10 | from .adaptive_avgmax_pool import adaptive_avgmax_pool2d 11 | 12 | 13 | _logger = logging.getLogger(__name__) 14 | 15 | 16 | class TestTimePoolHead(nn.Module): 17 | def __init__(self, base, original_pool=7): 18 | super(TestTimePoolHead, self).__init__() 19 | self.base = base 20 | self.original_pool = original_pool 21 | base_fc = self.base.get_classifier() 22 | if isinstance(base_fc, nn.Conv2d): 23 | self.fc = base_fc 24 | else: 25 | self.fc = nn.Conv2d( 26 | self.base.num_features, self.base.num_classes, kernel_size=1, bias=True) 27 | self.fc.weight.data.copy_(base_fc.weight.data.view(self.fc.weight.size())) 28 | self.fc.bias.data.copy_(base_fc.bias.data.view(self.fc.bias.size())) 29 | self.base.reset_classifier(0) # delete original fc layer 30 | 31 | def forward(self, x): 32 | x = self.base.forward_features(x) 33 | x = F.avg_pool2d(x, kernel_size=self.original_pool, stride=1) 34 | x = self.fc(x) 35 | x = adaptive_avgmax_pool2d(x, 1) 36 | return x.view(x.size(0), -1) 37 | 38 | 39 | def apply_test_time_pool(model, config): 40 | test_time_pool = False 41 | if not hasattr(model, 'default_cfg') or not model.default_cfg: 42 | return model, False 43 | if (config['input_size'][-1] > model.default_cfg['input_size'][-1] and 44 | config['input_size'][-2] > model.default_cfg['input_size'][-2]): 45 | _logger.info('Target input size %s > pretrained default %s, using test time pooling' % 46 | (str(config['input_size'][-2:]), str(model.default_cfg['input_size'][-2:]))) 47 | model = TestTimePoolHead(model, original_pool=model.default_cfg['pool_size']) 48 | test_time_pool = True 49 | return model, test_time_pool 50 | -------------------------------------------------------------------------------- /timm/models/layers/weight_init.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import warnings 4 | 5 | 6 | def _no_grad_trunc_normal_(tensor, mean, std, a, b): 7 | # Cut & paste from PyTorch official master until it's in a few official releases - RW 8 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 9 | def norm_cdf(x): 10 | # Computes standard normal cumulative distribution function 11 | return (1. + math.erf(x / math.sqrt(2.))) / 2. 12 | 13 | if (mean < a - 2 * std) or (mean > b + 2 * std): 14 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " 15 | "The distribution of values may be incorrect.", 16 | stacklevel=2) 17 | 18 | with torch.no_grad(): 19 | # Values are generated by using a truncated uniform distribution and 20 | # then using the inverse CDF for the normal distribution. 21 | # Get upper and lower cdf values 22 | l = norm_cdf((a - mean) / std) 23 | u = norm_cdf((b - mean) / std) 24 | 25 | # Uniformly fill tensor with values from [l, u], then translate to 26 | # [2l-1, 2u-1]. 27 | tensor.uniform_(2 * l - 1, 2 * u - 1) 28 | 29 | # Use inverse cdf transform for normal distribution to get truncated 30 | # standard normal 31 | tensor.erfinv_() 32 | 33 | # Transform to proper mean, std 34 | tensor.mul_(std * math.sqrt(2.)) 35 | tensor.add_(mean) 36 | 37 | # Clamp to ensure it's in the proper range 38 | tensor.clamp_(min=a, max=b) 39 | return tensor 40 | 41 | 42 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): 43 | # type: (Tensor, float, float, float, float) -> Tensor 44 | r"""Fills the input Tensor with values drawn from a truncated 45 | normal distribution. The values are effectively drawn from the 46 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` 47 | with values outside :math:`[a, b]` redrawn until they are within 48 | the bounds. The method used for generating the random values works 49 | best when :math:`a \leq \text{mean} \leq b`. 
50 | Args: 51 | tensor: an n-dimensional `torch.Tensor` 52 | mean: the mean of the normal distribution 53 | std: the standard deviation of the normal distribution 54 | a: the minimum cutoff value 55 | b: the maximum cutoff value 56 | Examples: 57 | >>> w = torch.empty(3, 5) 58 | >>> nn.init.trunc_normal_(w) 59 | """ 60 | return _no_grad_trunc_normal_(tensor, mean, std, a, b) 61 | -------------------------------------------------------------------------------- /timm/models/pruned/ecaresnet50d_pruned.txt: -------------------------------------------------------------------------------- 1 | conv1.0.weight:[32, 3, 3, 3]***conv1.1.weight:[32]***conv1.3.weight:[32, 32, 3, 3]***conv1.4.weight:[32]***conv1.6.weight:[64, 32, 3, 3]***bn1.weight:[64]***layer1.0.conv1.weight:[47, 64, 1, 1]***layer1.0.bn1.weight:[47]***layer1.0.conv2.weight:[18, 47, 3, 3]***layer1.0.bn2.weight:[18]***layer1.0.conv3.weight:[19, 18, 1, 1]***layer1.0.bn3.weight:[19]***layer1.0.se.conv.weight:[1, 1, 5]***layer1.0.downsample.1.weight:[19, 64, 1, 1]***layer1.0.downsample.2.weight:[19]***layer1.1.conv1.weight:[52, 19, 1, 1]***layer1.1.bn1.weight:[52]***layer1.1.conv2.weight:[22, 52, 3, 3]***layer1.1.bn2.weight:[22]***layer1.1.conv3.weight:[19, 22, 1, 1]***layer1.1.bn3.weight:[19]***layer1.1.se.conv.weight:[1, 1, 5]***layer1.2.conv1.weight:[64, 19, 1, 1]***layer1.2.bn1.weight:[64]***layer1.2.conv2.weight:[35, 64, 3, 3]***layer1.2.bn2.weight:[35]***layer1.2.conv3.weight:[19, 35, 1, 1]***layer1.2.bn3.weight:[19]***layer1.2.se.conv.weight:[1, 1, 5]***layer2.0.conv1.weight:[85, 19, 1, 1]***layer2.0.bn1.weight:[85]***layer2.0.conv2.weight:[37, 85, 3, 3]***layer2.0.bn2.weight:[37]***layer2.0.conv3.weight:[171, 37, 1, 1]***layer2.0.bn3.weight:[171]***layer2.0.se.conv.weight:[1, 1, 5]***layer2.0.downsample.1.weight:[171, 19, 1, 1]***layer2.0.downsample.2.weight:[171]***layer2.1.conv1.weight:[107, 171, 1, 1]***layer2.1.bn1.weight:[107]***layer2.1.conv2.weight:[80, 107, 3, 3]***layer2.1.bn2.weight:[80]***layer2.1.conv3.weight:[171, 80, 1, 1]***layer2.1.bn3.weight:[171]***layer2.1.se.conv.weight:[1, 1, 5]***layer2.2.conv1.weight:[120, 171, 1, 1]***layer2.2.bn1.weight:[120]***layer2.2.conv2.weight:[85, 120, 3, 3]***layer2.2.bn2.weight:[85]***layer2.2.conv3.weight:[171, 85, 1, 1]***layer2.2.bn3.weight:[171]***layer2.2.se.conv.weight:[1, 1, 5]***layer2.3.conv1.weight:[125, 171, 1, 1]***layer2.3.bn1.weight:[125]***layer2.3.conv2.weight:[87, 125, 3, 3]***layer2.3.bn2.weight:[87]***layer2.3.conv3.weight:[171, 87, 1, 1]***layer2.3.bn3.weight:[171]***layer2.3.se.conv.weight:[1, 1, 5]***layer3.0.conv1.weight:[198, 171, 1, 1]***layer3.0.bn1.weight:[198]***layer3.0.conv2.weight:[126, 198, 3, 3]***layer3.0.bn2.weight:[126]***layer3.0.conv3.weight:[818, 126, 1, 1]***layer3.0.bn3.weight:[818]***layer3.0.se.conv.weight:[1, 1, 5]***layer3.0.downsample.1.weight:[818, 171, 1, 1]***layer3.0.downsample.2.weight:[818]***layer3.1.conv1.weight:[255, 818, 1, 1]***layer3.1.bn1.weight:[255]***layer3.1.conv2.weight:[232, 255, 3, 3]***layer3.1.bn2.weight:[232]***layer3.1.conv3.weight:[818, 232, 1, 1]***layer3.1.bn3.weight:[818]***layer3.1.se.conv.weight:[1, 1, 5]***layer3.2.conv1.weight:[256, 818, 1, 1]***layer3.2.bn1.weight:[256]***layer3.2.conv2.weight:[233, 256, 3, 3]***layer3.2.bn2.weight:[233]***layer3.2.conv3.weight:[818, 233, 1, 1]***layer3.2.bn3.weight:[818]***layer3.2.se.conv.weight:[1, 1, 5]***layer3.3.conv1.weight:[253, 818, 1, 1]***layer3.3.bn1.weight:[253]***layer3.3.conv2.weight:[235, 253, 3, 3]***layer3.3.bn2.weight:[235]***layer3.3.conv3.weight:[818, 
235, 1, 1]***layer3.3.bn3.weight:[818]***layer3.3.se.conv.weight:[1, 1, 5]***layer3.4.conv1.weight:[256, 818, 1, 1]***layer3.4.bn1.weight:[256]***layer3.4.conv2.weight:[225, 256, 3, 3]***layer3.4.bn2.weight:[225]***layer3.4.conv3.weight:[818, 225, 1, 1]***layer3.4.bn3.weight:[818]***layer3.4.se.conv.weight:[1, 1, 5]***layer3.5.conv1.weight:[256, 818, 1, 1]***layer3.5.bn1.weight:[256]***layer3.5.conv2.weight:[239, 256, 3, 3]***layer3.5.bn2.weight:[239]***layer3.5.conv3.weight:[818, 239, 1, 1]***layer3.5.bn3.weight:[818]***layer3.5.se.conv.weight:[1, 1, 5]***layer4.0.conv1.weight:[492, 818, 1, 1]***layer4.0.bn1.weight:[492]***layer4.0.conv2.weight:[237, 492, 3, 3]***layer4.0.bn2.weight:[237]***layer4.0.conv3.weight:[2022, 237, 1, 1]***layer4.0.bn3.weight:[2022]***layer4.0.se.conv.weight:[1, 1, 7]***layer4.0.downsample.1.weight:[2022, 818, 1, 1]***layer4.0.downsample.2.weight:[2022]***layer4.1.conv1.weight:[512, 2022, 1, 1]***layer4.1.bn1.weight:[512]***layer4.1.conv2.weight:[500, 512, 3, 3]***layer4.1.bn2.weight:[500]***layer4.1.conv3.weight:[2022, 500, 1, 1]***layer4.1.bn3.weight:[2022]***layer4.1.se.conv.weight:[1, 1, 7]***layer4.2.conv1.weight:[512, 2022, 1, 1]***layer4.2.bn1.weight:[512]***layer4.2.conv2.weight:[490, 512, 3, 3]***layer4.2.bn2.weight:[490]***layer4.2.conv3.weight:[2022, 490, 1, 1]***layer4.2.bn3.weight:[2022]***layer4.2.se.conv.weight:[1, 1, 7]***fc.weight:[1000, 2022]***layer1_2_conv3_M.weight:[256, 19]***layer2_3_conv3_M.weight:[512, 171]***layer3_5_conv3_M.weight:[1024, 818]***layer4_2_conv3_M.weight:[2048, 2022] -------------------------------------------------------------------------------- /timm/models/registry.py: -------------------------------------------------------------------------------- 1 | """ Model Registry 2 | Hacked together by / Copyright 2020 Ross Wightman 3 | """ 4 | 5 | import sys 6 | import re 7 | import fnmatch 8 | from collections import defaultdict 9 | 10 | __all__ = ['list_models', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules'] 11 | 12 | _module_to_models = defaultdict(set) # dict of sets to check membership of model in module 13 | _model_to_module = {} # mapping of model names to module names 14 | _model_entrypoints = {} # mapping of model names to entrypoint fns 15 | _model_has_pretrained = set() # set of model names that have pretrained weight url present 16 | 17 | 18 | def register_model(fn): 19 | # lookup containing module 20 | mod = sys.modules[fn.__module__] 21 | module_name_split = fn.__module__.split('.') 22 | module_name = module_name_split[-1] if len(module_name_split) else '' 23 | 24 | # add model to __all__ in module 25 | model_name = fn.__name__ 26 | if hasattr(mod, '__all__'): 27 | mod.__all__.append(model_name) 28 | else: 29 | mod.__all__ = [model_name] 30 | 31 | # add entries to registry dict/sets 32 | _model_entrypoints[model_name] = fn 33 | _model_to_module[model_name] = module_name 34 | _module_to_models[module_name].add(model_name) 35 | has_pretrained = False # check if model has a pretrained url to allow filtering on this 36 | if hasattr(mod, 'default_cfgs') and model_name in mod.default_cfgs: 37 | # this will catch all models that have entrypoint matching cfg key, but miss any aliasing 38 | # entrypoints or non-matching combos 39 | has_pretrained = 'url' in mod.default_cfgs[model_name] and 'http' in mod.default_cfgs[model_name]['url'] 40 | if has_pretrained: 41 | _model_has_pretrained.add(model_name) 42 | return fn 43 | 44 | 45 | def _natural_key(string_): 46 | return [int(s) if s.isdigit() else s 
for s in re.split(r'(\d+)', string_.lower())] 47 | 48 | 49 | def list_models(filter='', module='', pretrained=False, exclude_filters=''): 50 | """ Return list of available model names, sorted alphabetically 51 | 52 | Args: 53 | filter (str) - Wildcard filter string that works with fnmatch 54 | module (str) - Limit model selection to a specific sub-module (ie 'gen_efficientnet') 55 | pretrained (bool) - Include only models with pretrained weights if True 56 | exclude_filters (str or list[str]) - Wildcard filters to exclude models after including them with filter 57 | 58 | Example: 59 | list_models('gluon_resnet*') -- returns all models starting with 'gluon_resnet' 60 | list_models('*resnext*', 'resnet') -- returns all models with 'resnext' in module 'resnet' 61 | """ 62 | if module: 63 | models = list(_module_to_models[module]) 64 | else: 65 | models = _model_entrypoints.keys() 66 | if filter: 67 | models = fnmatch.filter(models, filter) # include these models 68 | if exclude_filters: 69 | if not isinstance(exclude_filters, list): 70 | exclude_filters = [exclude_filters] 71 | for xf in exclude_filters: 72 | exclude_models = fnmatch.filter(models, xf) # exclude these models 73 | if len(exclude_models): 74 | models = set(models).difference(exclude_models) 75 | if pretrained: 76 | models = _model_has_pretrained.intersection(models) 77 | return list(sorted(models, key=_natural_key)) 78 | 79 | 80 | def is_model(model_name): 81 | """ Check if a model name exists 82 | """ 83 | return model_name in _model_entrypoints 84 | 85 | 86 | def model_entrypoint(model_name): 87 | """Fetch a model entrypoint for specified model name 88 | """ 89 | return _model_entrypoints[model_name] 90 | 91 | 92 | def list_modules(): 93 | """ Return list of module names that contain models / model entrypoints 94 | """ 95 | modules = _module_to_models.keys() 96 | return list(sorted(modules)) 97 | 98 | 99 | def is_model_in_modules(model_name, module_names): 100 | """Check if a model exists within a subset of modules 101 | Args: 102 | model_name (str) - name of model to check 103 | module_names (tuple, list, set) - names of modules to search in 104 | """ 105 | assert isinstance(module_names, (tuple, list, set)) 106 | return any(model_name in _module_to_models[n] for n in module_names) 107 | 108 | -------------------------------------------------------------------------------- /timm/optim/__init__.py: -------------------------------------------------------------------------------- 1 | from .adamp import AdamP 2 | from .adamw import AdamW 3 | from .adafactor import Adafactor 4 | from .adahessian import Adahessian 5 | from .lookahead import Lookahead 6 | from .nadam import Nadam 7 | from .novograd import NovoGrad 8 | from .nvnovograd import NvNovoGrad 9 | from .radam import RAdam 10 | from .rmsprop_tf import RMSpropTF 11 | from .sgdp import SGDP 12 | 13 | from .optim_factory import create_optimizer -------------------------------------------------------------------------------- /timm/optim/adamp.py: -------------------------------------------------------------------------------- 1 | """ 2 | AdamP Optimizer Implementation copied from https://github.com/clovaai/AdamP/blob/master/adamp/adamp.py 3 | 4 | Paper: `Slowing Down the Weight Norm Increase in Momentum-based Optimizers` - https://arxiv.org/abs/2006.08217 5 | Code: https://github.com/clovaai/AdamP 6 | 7 | Copyright (c) 2020-present NAVER Corp.
8 | MIT license 9 | """ 10 | 11 | import torch 12 | import torch.nn as nn 13 | from torch.optim.optimizer import Optimizer, required 14 | import math 15 | 16 | class AdamP(Optimizer): 17 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 18 | weight_decay=0, delta=0.1, wd_ratio=0.1, nesterov=False): 19 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, 20 | delta=delta, wd_ratio=wd_ratio, nesterov=nesterov) 21 | super(AdamP, self).__init__(params, defaults) 22 | 23 | def _channel_view(self, x): 24 | return x.view(x.size(0), -1) 25 | 26 | def _layer_view(self, x): 27 | return x.view(1, -1) 28 | 29 | def _cosine_similarity(self, x, y, eps, view_func): 30 | x = view_func(x) 31 | y = view_func(y) 32 | 33 | x_norm = x.norm(dim=1).add_(eps) 34 | y_norm = y.norm(dim=1).add_(eps) 35 | dot = (x * y).sum(dim=1) 36 | 37 | return dot.abs() / x_norm / y_norm 38 | 39 | def _projection(self, p, grad, perturb, delta, wd_ratio, eps): 40 | wd = 1 41 | expand_size = [-1] + [1] * (len(p.shape) - 1) 42 | for view_func in [self._channel_view, self._layer_view]: 43 | 44 | cosine_sim = self._cosine_similarity(grad, p.data, eps, view_func) 45 | 46 | if cosine_sim.max() < delta / math.sqrt(view_func(p.data).size(1)): 47 | p_n = p.data / view_func(p.data).norm(dim=1).view(expand_size).add_(eps) 48 | perturb -= p_n * view_func(p_n * perturb).sum(dim=1).view(expand_size) 49 | wd = wd_ratio 50 | 51 | return perturb, wd 52 | 53 | return perturb, wd 54 | 55 | def step(self, closure=None): 56 | loss = None 57 | if closure is not None: 58 | loss = closure() 59 | 60 | for group in self.param_groups: 61 | for p in group['params']: 62 | if p.grad is None: 63 | continue 64 | 65 | grad = p.grad.data 66 | beta1, beta2 = group['betas'] 67 | nesterov = group['nesterov'] 68 | 69 | state = self.state[p] 70 | 71 | # State initialization 72 | if len(state) == 0: 73 | state['step'] = 0 74 | state['exp_avg'] = torch.zeros_like(p.data) 75 | state['exp_avg_sq'] = torch.zeros_like(p.data) 76 | 77 | # Adam 78 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 79 | 80 | state['step'] += 1 81 | bias_correction1 = 1 - beta1 ** state['step'] 82 | bias_correction2 = 1 - beta2 ** state['step'] 83 | 84 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 85 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 86 | 87 | denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) 88 | step_size = group['lr'] / bias_correction1 89 | 90 | if nesterov: 91 | perturb = (beta1 * exp_avg + (1 - beta1) * grad) / denom 92 | else: 93 | perturb = exp_avg / denom 94 | 95 | # Projection 96 | wd_ratio = 1 97 | if len(p.shape) > 1: 98 | perturb, wd_ratio = self._projection(p, grad, perturb, group['delta'], group['wd_ratio'], group['eps']) 99 | 100 | # Weight decay 101 | if group['weight_decay'] > 0: 102 | p.data.mul_(1 - group['lr'] * group['weight_decay'] * wd_ratio) 103 | 104 | # Step 105 | p.data.add_(-step_size, perturb) 106 | 107 | return loss 108 | -------------------------------------------------------------------------------- /timm/optim/lookahead.py: -------------------------------------------------------------------------------- 1 | """ Lookahead Optimizer Wrapper. 
2 | Implementation modified from: https://github.com/alphadl/lookahead.pytorch 3 | Paper: `Lookahead Optimizer: k steps forward, 1 step back` - https://arxiv.org/abs/1907.08610 4 | 5 | Hacked together by / Copyright 2020 Ross Wightman 6 | """ 7 | import torch 8 | from torch.optim.optimizer import Optimizer 9 | from collections import defaultdict 10 | 11 | 12 | class Lookahead(Optimizer): 13 | def __init__(self, base_optimizer, alpha=0.5, k=6): 14 | if not 0.0 <= alpha <= 1.0: 15 | raise ValueError(f'Invalid slow update rate: {alpha}') 16 | if not 1 <= k: 17 | raise ValueError(f'Invalid lookahead steps: {k}') 18 | defaults = dict(lookahead_alpha=alpha, lookahead_k=k, lookahead_step=0) 19 | self.base_optimizer = base_optimizer 20 | self.param_groups = self.base_optimizer.param_groups 21 | self.defaults = base_optimizer.defaults 22 | self.defaults.update(defaults) 23 | self.state = defaultdict(dict) 24 | # manually add our defaults to the param groups 25 | for name, default in defaults.items(): 26 | for group in self.param_groups: 27 | group.setdefault(name, default) 28 | 29 | def update_slow(self, group): 30 | for fast_p in group["params"]: 31 | if fast_p.grad is None: 32 | continue 33 | param_state = self.state[fast_p] 34 | if 'slow_buffer' not in param_state: 35 | param_state['slow_buffer'] = torch.empty_like(fast_p.data) 36 | param_state['slow_buffer'].copy_(fast_p.data) 37 | slow = param_state['slow_buffer'] 38 | slow.add_(group['lookahead_alpha'], fast_p.data - slow) 39 | fast_p.data.copy_(slow) 40 | 41 | def sync_lookahead(self): 42 | for group in self.param_groups: 43 | self.update_slow(group) 44 | 45 | def step(self, closure=None): 46 | #assert id(self.param_groups) == id(self.base_optimizer.param_groups) 47 | loss = self.base_optimizer.step(closure) 48 | for group in self.param_groups: 49 | group['lookahead_step'] += 1 50 | if group['lookahead_step'] % group['lookahead_k'] == 0: 51 | self.update_slow(group) 52 | return loss 53 | 54 | def state_dict(self): 55 | fast_state_dict = self.base_optimizer.state_dict() 56 | slow_state = { 57 | (id(k) if isinstance(k, torch.Tensor) else k): v 58 | for k, v in self.state.items() 59 | } 60 | fast_state = fast_state_dict['state'] 61 | param_groups = fast_state_dict['param_groups'] 62 | return { 63 | 'state': fast_state, 64 | 'slow_state': slow_state, 65 | 'param_groups': param_groups, 66 | } 67 | 68 | def load_state_dict(self, state_dict): 69 | fast_state_dict = { 70 | 'state': state_dict['state'], 71 | 'param_groups': state_dict['param_groups'], 72 | } 73 | self.base_optimizer.load_state_dict(fast_state_dict) 74 | 75 | # We want to restore the slow state, but share param_groups reference 76 | # with base_optimizer. 
This is a bit redundant but least code 77 | slow_state_new = False 78 | if 'slow_state' not in state_dict: 79 | print('Loading state_dict from optimizer without Lookahead applied.') 80 | state_dict['slow_state'] = defaultdict(dict) 81 | slow_state_new = True 82 | slow_state_dict = { 83 | 'state': state_dict['slow_state'], 84 | 'param_groups': state_dict['param_groups'], # this is pointless but saves code 85 | } 86 | super(Lookahead, self).load_state_dict(slow_state_dict) 87 | self.param_groups = self.base_optimizer.param_groups # make both ref same container 88 | if slow_state_new: 89 | # reapply defaults to catch missing lookahead specific ones 90 | for name, default in self.defaults.items(): 91 | for group in self.param_groups: 92 | group.setdefault(name, default) 93 | -------------------------------------------------------------------------------- /timm/optim/nadam.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim import Optimizer 3 | 4 | 5 | class Nadam(Optimizer): 6 | """Implements Nadam algorithm (a variant of Adam based on Nesterov momentum). 7 | 8 | It has been proposed in `Incorporating Nesterov Momentum into Adam`__. 9 | 10 | Arguments: 11 | params (iterable): iterable of parameters to optimize or dicts defining 12 | parameter groups 13 | lr (float, optional): learning rate (default: 2e-3) 14 | betas (Tuple[float, float], optional): coefficients used for computing 15 | running averages of gradient and its square 16 | eps (float, optional): term added to the denominator to improve 17 | numerical stability (default: 1e-8) 18 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 19 | schedule_decay (float, optional): momentum schedule decay (default: 4e-3) 20 | 21 | __ http://cs229.stanford.edu/proj2015/054_report.pdf 22 | __ http://www.cs.toronto.edu/~fritz/absps/momentum.pdf 23 | 24 | Originally taken from: https://github.com/pytorch/pytorch/pull/1408 25 | NOTE: Has potential issues but does work well on some problems. 26 | """ 27 | 28 | def __init__(self, params, lr=2e-3, betas=(0.9, 0.999), eps=1e-8, 29 | weight_decay=0, schedule_decay=4e-3): 30 | defaults = dict(lr=lr, betas=betas, eps=eps, 31 | weight_decay=weight_decay, schedule_decay=schedule_decay) 32 | super(Nadam, self).__init__(params, defaults) 33 | 34 | def step(self, closure=None): 35 | """Performs a single optimization step. 36 | 37 | Arguments: 38 | closure (callable, optional): A closure that reevaluates the model 39 | and returns the loss. 40 | """ 41 | loss = None 42 | if closure is not None: 43 | loss = closure() 44 | 45 | for group in self.param_groups: 46 | for p in group['params']: 47 | if p.grad is None: 48 | continue 49 | grad = p.grad.data 50 | state = self.state[p] 51 | 52 | # State initialization 53 | if len(state) == 0: 54 | state['step'] = 0 55 | state['m_schedule'] = 1. 56 | state['exp_avg'] = grad.new().resize_as_(grad).zero_() 57 | state['exp_avg_sq'] = grad.new().resize_as_(grad).zero_() 58 | 59 | # Warming momentum schedule 60 | m_schedule = state['m_schedule'] 61 | schedule_decay = group['schedule_decay'] 62 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 63 | beta1, beta2 = group['betas'] 64 | eps = group['eps'] 65 | state['step'] += 1 66 | t = state['step'] 67 | 68 | if group['weight_decay'] != 0: 69 | grad = grad.add(group['weight_decay'], p.data) 70 | 71 | momentum_cache_t = beta1 * \ 72 | (1. - 0.5 * (0.96 ** (t * schedule_decay))) 73 | momentum_cache_t_1 = beta1 * \ 74 | (1. 
- 0.5 * (0.96 ** ((t + 1) * schedule_decay))) 75 | m_schedule_new = m_schedule * momentum_cache_t 76 | m_schedule_next = m_schedule * momentum_cache_t * momentum_cache_t_1 77 | state['m_schedule'] = m_schedule_new 78 | 79 | # Decay the first and second moment running average coefficient 80 | exp_avg.mul_(beta1).add_(1. - beta1, grad) 81 | exp_avg_sq.mul_(beta2).addcmul_(1. - beta2, grad, grad) 82 | exp_avg_sq_prime = exp_avg_sq / (1. - beta2 ** t) 83 | denom = exp_avg_sq_prime.sqrt_().add_(eps) 84 | 85 | p.data.addcdiv_(-group['lr'] * (1. - momentum_cache_t) / (1. - m_schedule_new), grad, denom) 86 | p.data.addcdiv_(-group['lr'] * momentum_cache_t_1 / (1. - m_schedule_next), exp_avg, denom) 87 | 88 | return loss 89 | -------------------------------------------------------------------------------- /timm/optim/novograd.py: -------------------------------------------------------------------------------- 1 | """NovoGrad Optimizer. 2 | Original impl by Masashi Kimura (Convergence Lab): https://github.com/convergence-lab/novograd 3 | Paper: `Stochastic Gradient Methods with Layer-wise Adaptive Moments for Training of Deep Networks` 4 | - https://arxiv.org/abs/1905.11286 5 | """ 6 | 7 | import torch 8 | from torch.optim.optimizer import Optimizer 9 | import math 10 | 11 | 12 | class NovoGrad(Optimizer): 13 | def __init__(self, params, grad_averaging=False, lr=0.1, betas=(0.95, 0.98), eps=1e-8, weight_decay=0): 14 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 15 | super(NovoGrad, self).__init__(params, defaults) 16 | self._lr = lr 17 | self._beta1 = betas[0] 18 | self._beta2 = betas[1] 19 | self._eps = eps 20 | self._wd = weight_decay 21 | self._grad_averaging = grad_averaging 22 | 23 | self._momentum_initialized = False 24 | 25 | def step(self, closure=None): 26 | loss = None 27 | if closure is not None: 28 | loss = closure() 29 | 30 | if not self._momentum_initialized: 31 | for group in self.param_groups: 32 | for p in group['params']: 33 | if p.grad is None: 34 | continue 35 | state = self.state[p] 36 | grad = p.grad.data 37 | if grad.is_sparse: 38 | raise RuntimeError('NovoGrad does not support sparse gradients') 39 | 40 | v = torch.norm(grad)**2 41 | m = grad/(torch.sqrt(v) + self._eps) + self._wd * p.data 42 | state['step'] = 0 43 | state['v'] = v 44 | state['m'] = m 45 | state['grad_ema'] = None 46 | self._momentum_initialized = True 47 | 48 | for group in self.param_groups: 49 | for p in group['params']: 50 | if p.grad is None: 51 | continue 52 | state = self.state[p] 53 | state['step'] += 1 54 | 55 | step, v, m = state['step'], state['v'], state['m'] 56 | grad_ema = state['grad_ema'] 57 | 58 | grad = p.grad.data 59 | g2 = torch.norm(grad)**2 60 | grad_ema = g2 if grad_ema is None else grad_ema * \ 61 | self._beta2 + g2 * (1. - self._beta2) 62 | grad *= 1.0 / (torch.sqrt(grad_ema) + self._eps) 63 | 64 | if self._grad_averaging: 65 | grad *= (1. - self._beta1) 66 | 67 | g2 = torch.norm(grad)**2 68 | v = self._beta2*v + (1. 
- self._beta2)*g2 69 | m = self._beta1*m + (grad / (torch.sqrt(v) + self._eps) + self._wd * p.data) 70 | bias_correction1 = 1 - self._beta1 ** step 71 | bias_correction2 = 1 - self._beta2 ** step 72 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 73 | 74 | state['v'], state['m'] = v, m 75 | state['grad_ema'] = grad_ema 76 | p.data.add_(-step_size, m) 77 | return loss 78 | -------------------------------------------------------------------------------- /timm/optim/sgdp.py: -------------------------------------------------------------------------------- 1 | """ 2 | SGDP Optimizer Implementation copied from https://github.com/clovaai/AdamP/blob/master/adamp/sgdp.py 3 | 4 | Paper: `Slowing Down the Weight Norm Increase in Momentum-based Optimizers` - https://arxiv.org/abs/2006.08217 5 | Code: https://github.com/clovaai/AdamP 6 | 7 | Copyright (c) 2020-present NAVER Corp. 8 | MIT license 9 | """ 10 | 11 | import torch 12 | import torch.nn as nn 13 | from torch.optim.optimizer import Optimizer, required 14 | import math 15 | 16 | class SGDP(Optimizer): 17 | def __init__(self, params, lr=required, momentum=0, dampening=0, 18 | weight_decay=0, nesterov=False, eps=1e-8, delta=0.1, wd_ratio=0.1): 19 | defaults = dict(lr=lr, momentum=momentum, dampening=dampening, weight_decay=weight_decay, 20 | nesterov=nesterov, eps=eps, delta=delta, wd_ratio=wd_ratio) 21 | super(SGDP, self).__init__(params, defaults) 22 | 23 | def _channel_view(self, x): 24 | return x.view(x.size(0), -1) 25 | 26 | def _layer_view(self, x): 27 | return x.view(1, -1) 28 | 29 | def _cosine_similarity(self, x, y, eps, view_func): 30 | x = view_func(x) 31 | y = view_func(y) 32 | 33 | x_norm = x.norm(dim=1).add_(eps) 34 | y_norm = y.norm(dim=1).add_(eps) 35 | dot = (x * y).sum(dim=1) 36 | 37 | return dot.abs() / x_norm / y_norm 38 | 39 | def _projection(self, p, grad, perturb, delta, wd_ratio, eps): 40 | wd = 1 41 | expand_size = [-1] + [1] * (len(p.shape) - 1) 42 | for view_func in [self._channel_view, self._layer_view]: 43 | 44 | cosine_sim = self._cosine_similarity(grad, p.data, eps, view_func) 45 | 46 | if cosine_sim.max() < delta / math.sqrt(view_func(p.data).size(1)): 47 | p_n = p.data / view_func(p.data).norm(dim=1).view(expand_size).add_(eps) 48 | perturb -= p_n * view_func(p_n * perturb).sum(dim=1).view(expand_size) 49 | wd = wd_ratio 50 | 51 | return perturb, wd 52 | 53 | return perturb, wd 54 | 55 | def step(self, closure=None): 56 | loss = None 57 | if closure is not None: 58 | loss = closure() 59 | 60 | for group in self.param_groups: 61 | weight_decay = group['weight_decay'] 62 | momentum = group['momentum'] 63 | dampening = group['dampening'] 64 | nesterov = group['nesterov'] 65 | 66 | for p in group['params']: 67 | if p.grad is None: 68 | continue 69 | grad = p.grad.data 70 | state = self.state[p] 71 | 72 | # State initialization 73 | if len(state) == 0: 74 | state['momentum'] = torch.zeros_like(p.data) 75 | 76 | # SGD 77 | buf = state['momentum'] 78 | buf.mul_(momentum).add_(1 - dampening, grad) 79 | if nesterov: 80 | d_p = grad + momentum * buf 81 | else: 82 | d_p = buf 83 | 84 | # Projection 85 | wd_ratio = 1 86 | if len(p.shape) > 1: 87 | d_p, wd_ratio = self._projection(p, grad, d_p, group['delta'], group['wd_ratio'], group['eps']) 88 | 89 | # Weight decay 90 | if weight_decay != 0: 91 | p.data.mul_(1 - group['lr'] * group['weight_decay'] * wd_ratio / (1-momentum)) 92 | 93 | # Step 94 | p.data.add_(-group['lr'], d_p) 95 | 96 | return loss 97 | 
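The optimizers above are used like any other torch.optim optimizer. Below is a minimal, illustrative sketch (not a file from this repository) of wrapping AdamP with Lookahead; the toy model and tensors are hypothetical, and an older PyTorch release is assumed, since AdamP/SGDP above rely on the legacy Tensor.add_(scalar, tensor) overload.

# Illustrative usage sketch only -- not part of the repository.
# Assumes an older PyTorch (~1.7.x); the optimizers above use the deprecated
# Tensor.add_(scalar, tensor) overload that newer releases reject.
import torch
import torch.nn as nn

from timm.optim import AdamP, Lookahead

model = nn.Linear(10, 2)  # toy stand-in for a real network
base_opt = AdamP(model.parameters(), lr=1e-3, weight_decay=1e-2, nesterov=True)
optimizer = Lookahead(base_opt, alpha=0.5, k=6)  # slow weights sync every k fast steps

x, y = torch.randn(8, 10), torch.randint(0, 2, (8,))
loss = nn.functional.cross_entropy(model(x), y)

model.zero_grad()
loss.backward()
optimizer.step()            # delegates to AdamP; every k-th call updates the slow weights
optimizer.sync_lookahead()  # force a slow-weight sync, e.g. before evaluation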
-------------------------------------------------------------------------------- /timm/scheduler/__init__.py: -------------------------------------------------------------------------------- 1 | from .cosine_lr import CosineLRScheduler 2 | from .plateau_lr import PlateauLRScheduler 3 | from .step_lr import StepLRScheduler 4 | from .tanh_lr import TanhLRScheduler 5 | from .scheduler_factory import create_scheduler 6 | -------------------------------------------------------------------------------- /timm/scheduler/cosine_lr.py: -------------------------------------------------------------------------------- 1 | """ Cosine Scheduler 2 | 3 | Cosine LR schedule with warmup, cycle/restarts, noise. 4 | 5 | Hacked together by / Copyright 2020 Ross Wightman 6 | """ 7 | import logging 8 | import math 9 | import numpy as np 10 | import torch 11 | 12 | from .scheduler import Scheduler 13 | 14 | 15 | _logger = logging.getLogger(__name__) 16 | 17 | 18 | class CosineLRScheduler(Scheduler): 19 | """ 20 | Cosine decay with restarts. 21 | This is described in the paper https://arxiv.org/abs/1608.03983. 22 | 23 | Inspiration from 24 | https://github.com/allenai/allennlp/blob/master/allennlp/training/learning_rate_schedulers/cosine.py 25 | """ 26 | 27 | def __init__(self, 28 | optimizer: torch.optim.Optimizer, 29 | t_initial: int, 30 | t_mul: float = 1., 31 | lr_min: float = 0., 32 | decay_rate: float = 1., 33 | warmup_t=0, 34 | warmup_lr_init=0, 35 | warmup_prefix=False, 36 | cycle_limit=0, 37 | t_in_epochs=True, 38 | noise_range_t=None, 39 | noise_pct=0.67, 40 | noise_std=1.0, 41 | noise_seed=42, 42 | initialize=True) -> None: 43 | super().__init__( 44 | optimizer, param_group_field="lr", 45 | noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, 46 | initialize=initialize) 47 | 48 | assert t_initial > 0 49 | assert lr_min >= 0 50 | if t_initial == 1 and t_mul == 1 and decay_rate == 1: 51 | _logger.warning("Cosine annealing scheduler will have no effect on the learning " 52 | "rate since t_initial = t_mul = eta_mul = 1.") 53 | self.t_initial = t_initial 54 | self.t_mul = t_mul 55 | self.lr_min = lr_min 56 | self.decay_rate = decay_rate 57 | self.cycle_limit = cycle_limit 58 | self.warmup_t = warmup_t 59 | self.warmup_lr_init = warmup_lr_init 60 | self.warmup_prefix = warmup_prefix 61 | self.t_in_epochs = t_in_epochs 62 | if self.warmup_t: 63 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] 64 | super().update_groups(self.warmup_lr_init) 65 | else: 66 | self.warmup_steps = [1 for _ in self.base_values] 67 | 68 | def _get_lr(self, t): 69 | if t < self.warmup_t: 70 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] 71 | else: 72 | if self.warmup_prefix: 73 | t = t - self.warmup_t 74 | 75 | if self.t_mul != 1: 76 | i = math.floor(math.log(1 - t / self.t_initial * (1 - self.t_mul), self.t_mul)) 77 | t_i = self.t_mul ** i * self.t_initial 78 | t_curr = t - (1 - self.t_mul ** i) / (1 - self.t_mul) * self.t_initial 79 | else: 80 | i = t // self.t_initial 81 | t_i = self.t_initial 82 | t_curr = t - (self.t_initial * i) 83 | 84 | gamma = self.decay_rate ** i 85 | lr_min = self.lr_min * gamma 86 | lr_max_values = [v * gamma for v in self.base_values] 87 | 88 | if self.cycle_limit == 0 or (self.cycle_limit > 0 and i < self.cycle_limit): 89 | lrs = [ 90 | lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * t_curr / t_i)) for lr_max in lr_max_values 91 | ] 92 | else: 93 | lrs = [self.lr_min for _ in self.base_values] 94 | 
95 | return lrs 96 | 97 | def get_epoch_values(self, epoch: int): 98 | if self.t_in_epochs: 99 | return self._get_lr(epoch) 100 | else: 101 | return None 102 | 103 | def get_update_values(self, num_updates: int): 104 | if not self.t_in_epochs: 105 | return self._get_lr(num_updates) 106 | else: 107 | return None 108 | 109 | def get_cycle_length(self, cycles=0): 110 | if not cycles: 111 | cycles = self.cycle_limit 112 | cycles = max(1, cycles) 113 | if self.t_mul == 1.0: 114 | return self.t_initial * cycles 115 | else: 116 | return int(math.floor(-self.t_initial * (self.t_mul ** cycles - 1) / (1 - self.t_mul))) 117 | -------------------------------------------------------------------------------- /timm/scheduler/scheduler_factory.py: -------------------------------------------------------------------------------- 1 | """ Scheduler Factory 2 | Hacked together by / Copyright 2020 Ross Wightman 3 | """ 4 | from .cosine_lr import CosineLRScheduler 5 | from .tanh_lr import TanhLRScheduler 6 | from .step_lr import StepLRScheduler 7 | from .plateau_lr import PlateauLRScheduler 8 | 9 | 10 | def create_scheduler(args, optimizer): 11 | num_epochs = args.epochs 12 | 13 | if getattr(args, 'lr_noise', None) is not None: 14 | lr_noise = getattr(args, 'lr_noise') 15 | if isinstance(lr_noise, (list, tuple)): 16 | noise_range = [n * num_epochs for n in lr_noise] 17 | if len(noise_range) == 1: 18 | noise_range = noise_range[0] 19 | else: 20 | noise_range = lr_noise * num_epochs 21 | else: 22 | noise_range = None 23 | 24 | lr_scheduler = None 25 | if args.sched == 'cosine': 26 | lr_scheduler = CosineLRScheduler( 27 | optimizer, 28 | t_initial=num_epochs, 29 | t_mul=getattr(args, 'lr_cycle_mul', 1.), 30 | lr_min=args.min_lr, 31 | decay_rate=args.decay_rate, 32 | warmup_lr_init=args.warmup_lr, 33 | warmup_t=args.warmup_epochs, 34 | cycle_limit=getattr(args, 'lr_cycle_limit', 1), 35 | t_in_epochs=True, 36 | noise_range_t=noise_range, 37 | noise_pct=getattr(args, 'lr_noise_pct', 0.67), 38 | noise_std=getattr(args, 'lr_noise_std', 1.), 39 | noise_seed=getattr(args, 'seed', 42), 40 | ) 41 | num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs 42 | elif args.sched == 'tanh': 43 | lr_scheduler = TanhLRScheduler( 44 | optimizer, 45 | t_initial=num_epochs, 46 | t_mul=getattr(args, 'lr_cycle_mul', 1.), 47 | lr_min=args.min_lr, 48 | warmup_lr_init=args.warmup_lr, 49 | warmup_t=args.warmup_epochs, 50 | cycle_limit=getattr(args, 'lr_cycle_limit', 1), 51 | t_in_epochs=True, 52 | noise_range_t=noise_range, 53 | noise_pct=getattr(args, 'lr_noise_pct', 0.67), 54 | noise_std=getattr(args, 'lr_noise_std', 1.), 55 | noise_seed=getattr(args, 'seed', 42), 56 | ) 57 | num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs 58 | elif args.sched == 'step': 59 | lr_scheduler = StepLRScheduler( 60 | optimizer, 61 | decay_t=args.decay_epochs, 62 | decay_rate=args.decay_rate, 63 | warmup_lr_init=args.warmup_lr, 64 | warmup_t=args.warmup_epochs, 65 | noise_range_t=noise_range, 66 | noise_pct=getattr(args, 'lr_noise_pct', 0.67), 67 | noise_std=getattr(args, 'lr_noise_std', 1.), 68 | noise_seed=getattr(args, 'seed', 42), 69 | ) 70 | elif args.sched == 'plateau': 71 | mode = 'min' if 'loss' in getattr(args, 'eval_metric', '') else 'max' 72 | lr_scheduler = PlateauLRScheduler( 73 | optimizer, 74 | decay_rate=args.decay_rate, 75 | patience_t=args.patience_epochs, 76 | lr_min=args.min_lr, 77 | mode=mode, 78 | warmup_lr_init=args.warmup_lr, 79 | warmup_t=args.warmup_epochs, 80 | cooldown_t=0, 81 | 
noise_range_t=noise_range, 82 | noise_pct=getattr(args, 'lr_noise_pct', 0.67), 83 | noise_std=getattr(args, 'lr_noise_std', 1.), 84 | noise_seed=getattr(args, 'seed', 42), 85 | ) 86 | 87 | return lr_scheduler, num_epochs 88 | -------------------------------------------------------------------------------- /timm/scheduler/step_lr.py: -------------------------------------------------------------------------------- 1 | """ Step Scheduler 2 | 3 | Basic step LR schedule with warmup, noise. 4 | 5 | Hacked together by / Copyright 2020 Ross Wightman 6 | """ 7 | import math 8 | import torch 9 | 10 | from .scheduler import Scheduler 11 | 12 | 13 | class StepLRScheduler(Scheduler): 14 | """ 15 | """ 16 | 17 | def __init__(self, 18 | optimizer: torch.optim.Optimizer, 19 | decay_t: float, 20 | decay_rate: float = 1., 21 | warmup_t=0, 22 | warmup_lr_init=0, 23 | t_in_epochs=True, 24 | noise_range_t=None, 25 | noise_pct=0.67, 26 | noise_std=1.0, 27 | noise_seed=42, 28 | initialize=True, 29 | ) -> None: 30 | super().__init__( 31 | optimizer, param_group_field="lr", 32 | noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, 33 | initialize=initialize) 34 | 35 | self.decay_t = decay_t 36 | self.decay_rate = decay_rate 37 | self.warmup_t = warmup_t 38 | self.warmup_lr_init = warmup_lr_init 39 | self.t_in_epochs = t_in_epochs 40 | if self.warmup_t: 41 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] 42 | super().update_groups(self.warmup_lr_init) 43 | else: 44 | self.warmup_steps = [1 for _ in self.base_values] 45 | 46 | def _get_lr(self, t): 47 | if t < self.warmup_t: 48 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] 49 | else: 50 | lrs = [v * (self.decay_rate ** (t // self.decay_t)) for v in self.base_values] 51 | return lrs 52 | 53 | def get_epoch_values(self, epoch: int): 54 | if self.t_in_epochs: 55 | return self._get_lr(epoch) 56 | else: 57 | return None 58 | 59 | def get_update_values(self, num_updates: int): 60 | if not self.t_in_epochs: 61 | return self._get_lr(num_updates) 62 | else: 63 | return None 64 | -------------------------------------------------------------------------------- /timm/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .checkpoint_saver import CheckpointSaver 2 | from .cuda import ApexScaler, NativeScaler 3 | from .distributed import distribute_bn, reduce_tensor 4 | from .jit import set_jit_legacy 5 | from .log import setup_default_logging, FormatterNoInfo 6 | from .metrics import AverageMeter, accuracy 7 | from .misc import natural_key, add_bool_arg 8 | from .model import unwrap_model, get_state_dict 9 | from .model_ema import ModelEma, ModelEmaV2 10 | from .summary import update_summary, get_outdir 11 | -------------------------------------------------------------------------------- /timm/utils/cuda.py: -------------------------------------------------------------------------------- 1 | """ CUDA / AMP utils 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import torch 6 | 7 | try: 8 | from apex import amp 9 | has_apex = True 10 | except ImportError: 11 | amp = None 12 | has_apex = False 13 | 14 | 15 | class ApexScaler: 16 | state_dict_key = "amp" 17 | 18 | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False): 19 | with amp.scale_loss(loss, optimizer) as scaled_loss: 20 | scaled_loss.backward(create_graph=create_graph) 21 | if clip_grad is not None: 22 | 
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), clip_grad) 23 | optimizer.step() 24 | 25 | def state_dict(self): 26 | if 'state_dict' in amp.__dict__: 27 | return amp.state_dict() 28 | 29 | def load_state_dict(self, state_dict): 30 | if 'load_state_dict' in amp.__dict__: 31 | amp.load_state_dict(state_dict) 32 | 33 | 34 | class NativeScaler: 35 | state_dict_key = "amp_scaler" 36 | 37 | def __init__(self): 38 | self._scaler = torch.cuda.amp.GradScaler() 39 | 40 | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False): 41 | self._scaler.scale(loss).backward(create_graph=create_graph) 42 | if clip_grad is not None: 43 | assert parameters is not None 44 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place 45 | torch.nn.utils.clip_grad_norm_(parameters, clip_grad) 46 | self._scaler.step(optimizer) 47 | self._scaler.update() 48 | 49 | def state_dict(self): 50 | return self._scaler.state_dict() 51 | 52 | def load_state_dict(self, state_dict): 53 | self._scaler.load_state_dict(state_dict) 54 | -------------------------------------------------------------------------------- /timm/utils/distributed.py: -------------------------------------------------------------------------------- 1 | """ Distributed training/validation utils 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import torch 6 | from torch import distributed as dist 7 | 8 | from .model import unwrap_model 9 | 10 | 11 | def reduce_tensor(tensor, n): 12 | rt = tensor.clone() 13 | dist.all_reduce(rt, op=dist.ReduceOp.SUM) 14 | rt /= n 15 | return rt 16 | 17 | 18 | def distribute_bn(model, world_size, reduce=False): 19 | # ensure every node has the same running bn stats 20 | for bn_name, bn_buf in unwrap_model(model).named_buffers(recurse=True): 21 | if ('running_mean' in bn_name) or ('running_var' in bn_name): 22 | if reduce: 23 | # average bn stats across whole group 24 | torch.distributed.all_reduce(bn_buf, op=dist.ReduceOp.SUM) 25 | bn_buf /= float(world_size) 26 | else: 27 | # broadcast bn stats from rank 0 to whole group 28 | torch.distributed.broadcast(bn_buf, 0) 29 | -------------------------------------------------------------------------------- /timm/utils/jit.py: -------------------------------------------------------------------------------- 1 | """ JIT scripting/tracing utils 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import torch 6 | 7 | 8 | def set_jit_legacy(): 9 | """ Set JIT executor to legacy w/ support for op fusion 10 | This is hopefully a temporary need in 1.5/1.5.1/1.6 to restore performance due to changes 11 | in the JIT executor. These APIs are not supported so could change. 12 | """ 13 | # 14 | assert hasattr(torch._C, '_jit_set_profiling_executor'), "Old JIT behavior doesn't exist!"
15 | torch._C._jit_set_profiling_executor(False) 16 | torch._C._jit_set_profiling_mode(False) 17 | torch._C._jit_override_can_fuse_on_gpu(True) 18 | #torch._C._jit_set_texpr_fuser_enabled(True) 19 | -------------------------------------------------------------------------------- /timm/utils/log.py: -------------------------------------------------------------------------------- 1 | """ Logging helpers 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import logging 6 | import logging.handlers 7 | 8 | 9 | class FormatterNoInfo(logging.Formatter): 10 | def __init__(self, fmt='%(levelname)s: %(message)s'): 11 | logging.Formatter.__init__(self, fmt) 12 | 13 | def format(self, record): 14 | if record.levelno == logging.INFO: 15 | return str(record.getMessage()) 16 | return logging.Formatter.format(self, record) 17 | 18 | 19 | def setup_default_logging(default_level=logging.INFO, log_path=''): 20 | console_handler = logging.StreamHandler() 21 | console_handler.setFormatter(FormatterNoInfo()) 22 | logging.root.addHandler(console_handler) 23 | logging.root.setLevel(default_level) 24 | if log_path: 25 | file_handler = logging.handlers.RotatingFileHandler(log_path, maxBytes=(1024 ** 2 * 2), backupCount=3) 26 | file_formatter = logging.Formatter("%(asctime)s - %(name)20s: [%(levelname)8s] - %(message)s") 27 | file_handler.setFormatter(file_formatter) 28 | logging.root.addHandler(file_handler) 29 | -------------------------------------------------------------------------------- /timm/utils/metrics.py: -------------------------------------------------------------------------------- 1 | """ Eval metrics and related 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | 6 | 7 | class AverageMeter: 8 | """Computes and stores the average and current value""" 9 | def __init__(self): 10 | self.reset() 11 | 12 | def reset(self): 13 | self.val = 0 14 | self.avg = 0 15 | self.sum = 0 16 | self.count = 0 17 | 18 | def update(self, val, n=1): 19 | self.val = val 20 | self.sum += val * n 21 | self.count += n 22 | self.avg = self.sum / self.count 23 | 24 | 25 | def accuracy(output, target, topk=(1,)): 26 | """Computes the accuracy over the k top predictions for the specified values of k""" 27 | maxk = max(topk) 28 | batch_size = target.size(0) 29 | _, pred = output.topk(maxk, 1, True, True) 30 | pred = pred.t() 31 | correct = pred.eq(target.reshape(1, -1).expand_as(pred)) 32 | return [correct[:k].reshape(-1).float().sum(0) * 100. 
/ batch_size for k in topk] 33 | -------------------------------------------------------------------------------- /timm/utils/misc.py: -------------------------------------------------------------------------------- 1 | """ Misc utils 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import re 6 | 7 | 8 | def natural_key(string_): 9 | """See http://www.codinghorror.com/blog/archives/001018.html""" 10 | return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())] 11 | 12 | 13 | def add_bool_arg(parser, name, default=False, help=''): 14 | dest_name = name.replace('-', '_') 15 | group = parser.add_mutually_exclusive_group(required=False) 16 | group.add_argument('--' + name, dest=dest_name, action='store_true', help=help) 17 | group.add_argument('--no-' + name, dest=dest_name, action='store_false', help=help) 18 | parser.set_defaults(**{dest_name: default}) 19 | -------------------------------------------------------------------------------- /timm/utils/model.py: -------------------------------------------------------------------------------- 1 | """ Model / state_dict utils 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | from .model_ema import ModelEma 6 | 7 | 8 | def unwrap_model(model): 9 | if isinstance(model, ModelEma): 10 | return unwrap_model(model.ema) 11 | else: 12 | return model.module if hasattr(model, 'module') else model 13 | 14 | 15 | def get_state_dict(model, unwrap_fn=unwrap_model): 16 | return unwrap_fn(model).state_dict() 17 | -------------------------------------------------------------------------------- /timm/utils/summary.py: -------------------------------------------------------------------------------- 1 | """ Summary utilities 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import csv 6 | import os 7 | from collections import OrderedDict 8 | 9 | 10 | def get_outdir(path, *paths, inc=False): 11 | outdir = os.path.join(path, *paths) 12 | if not os.path.exists(outdir): 13 | os.makedirs(outdir) 14 | elif inc: 15 | count = 1 16 | outdir_inc = outdir + '-' + str(count) 17 | while os.path.exists(outdir_inc): 18 | count = count + 1 19 | outdir_inc = outdir + '-' + str(count) 20 | assert count < 100 21 | outdir = outdir_inc 22 | os.makedirs(outdir) 23 | return outdir 24 | 25 | 26 | def update_summary(epoch, train_metrics, eval_metrics, filename, write_header=False): 27 | rowd = OrderedDict(epoch=epoch) 28 | rowd.update([('train_' + k, v) for k, v in train_metrics.items()]) 29 | rowd.update([('eval_' + k, v) for k, v in eval_metrics.items()]) 30 | with open(filename, mode='a') as cf: 31 | dw = csv.DictWriter(cf, fieldnames=rowd.keys()) 32 | if write_header: # first iteration (epoch == 1 can't be used) 33 | dw.writeheader() 34 | dw.writerow(rowd) 35 | -------------------------------------------------------------------------------- /timm/version.py: -------------------------------------------------------------------------------- 1 | __version__ = '0.3.2' 2 | -------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zideliu/StyleDrop-PyTorch/3538a60ca79ac57d5ca04be59dcbd13d48f68921/tools/__init__.py --------------------------------------------------------------------------------
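To show how the scheduler and utility modules above fit together, here is a short, illustrative training-loop sketch (not a file from this repository). The argparse-style `args` namespace, the toy model, and the synthetic loader are hypothetical placeholders; only the fields read by create_scheduler are filled in, and a CUDA device is assumed for NativeScaler.

# Illustrative sketch only -- not part of the repository.
# `args`, the model, and the loader below are hypothetical; a CUDA device is assumed.
from types import SimpleNamespace

import torch
import torch.nn as nn

from timm.scheduler import create_scheduler
from timm.utils import NativeScaler, AverageMeter, accuracy

model = nn.Linear(10, 2).cuda()  # toy stand-in for a real network
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# Hypothetical args namespace; only the fields create_scheduler reads are set.
args = SimpleNamespace(sched='cosine', epochs=90, min_lr=1e-5, decay_rate=0.1,
                       warmup_lr=1e-6, warmup_epochs=3, cooldown_epochs=0)
lr_scheduler, num_epochs = create_scheduler(args, optimizer)

loss_scaler = NativeScaler()  # wraps torch.cuda.amp.GradScaler
loader = [(torch.randn(8, 10), torch.randint(0, 2, (8,)))]  # synthetic one-batch loader
top1 = AverageMeter()

for epoch in range(num_epochs):
    for x, y in loader:
        x, y = x.cuda(), y.cuda()
        with torch.cuda.amp.autocast():
            output = model(x)
            loss = nn.functional.cross_entropy(output, y)
        optimizer.zero_grad()
        # backward, optional gradient clipping, and optimizer step all run inside the scaler
        loss_scaler(loss, optimizer, clip_grad=1.0, parameters=model.parameters())
        top1.update(accuracy(output, y, topk=(1,))[0].item(), n=x.size(0))
    lr_scheduler.step(epoch + 1)  # timm schedulers are stepped once per epoch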