├── .gitignore ├── CODE_OF_CONDUCT.md ├── LICENSE ├── README.md ├── SECURITY.md ├── SUPPORT.md ├── azure-pipelines.yml ├── docs ├── G0mjFqytJt4_000152_000162.mp4 ├── m2_qmRnjICE_000017_000027.mp4 └── swinbert-overview.png ├── launch_container.sh ├── models └── captioning │ └── bert-base-uncased │ ├── added_tokens.json │ ├── config.json │ ├── special_tokens_map.json │ └── vocab.txt ├── prepro ├── create_image_frame_tsv.py ├── extract_frames.py ├── extract_youcook2_frms.sh ├── tsv_preproc_msrvtt.py ├── tsv_preproc_msvd.py ├── tsv_preproc_tvc.py ├── tsv_preproc_vatex.py └── tsv_preproc_youcook2.py ├── scripts ├── download_annotations.sh ├── download_models.sh └── download_value_preds.sh ├── setup.sh └── src ├── configs ├── VidSwinBert │ ├── local_msrvtt_debug.json │ ├── msrvtt_8frm_default.json │ ├── msvd_8frm_default.json │ ├── tvc_8frm_default.json │ ├── vatex_8frm_default.json │ └── youcook2_8frm_default.json └── config.py ├── datasets ├── caption_tensorizer.py ├── data_sampler.py ├── data_utils │ ├── image_ops.py │ ├── video_decoder.py │ ├── video_functional.py │ ├── video_ops.py │ ├── video_transforms.py │ └── volume_transforms.py ├── sampler_utils.py ├── vision_language_tsv.py └── vl_dataloader.py ├── evalcap └── utils_caption_evaluate.py ├── layers └── bert │ ├── __init__.py │ ├── file_utils.py │ ├── modeling_bert.py │ ├── modeling_utils.py │ ├── tokenization_bert.py │ └── tokenization_utils.py ├── modeling ├── load_bert.py ├── load_swin.py ├── swin │ ├── __init__.py │ ├── build.py │ ├── config.py │ ├── swin_base_patch4_window12_384.yaml │ ├── swin_base_patch4_window7_224.yaml │ ├── swin_base_patch4_window7_224_22k.yaml │ ├── swin_large_patch4_window12_384.yaml │ ├── swin_large_patch4_window7_224.yaml │ ├── swin_small_patch4_window7_224.yaml │ ├── swin_tiny_patch4_window7_224.yaml │ └── swin_transformer.py ├── video_captioning_e2e_vid_swin_bert.py └── video_swin │ ├── config.py │ ├── default_runtime.py │ ├── swin_base.py │ ├── swin_base_patch244_window1677_sthv2.py │ ├── swin_base_patch244_window877_kinetics400_1k.py │ ├── swin_base_patch244_window877_kinetics400_22k.py │ ├── swin_base_patch244_window877_kinetics600_22k.py │ ├── swin_large.py │ ├── swin_large_384_patch244_window81212_kinetics400_22k.py │ ├── swin_large_384_patch244_window81212_kinetics600_22k.py │ ├── swin_large_patch244_window877_kinetics400_22k.py │ ├── swin_small_patch244_window877_kinetics400_1k.py │ ├── swin_tiny.py │ ├── swin_tiny_patch244_window877_kinetics400_1k.py │ └── swin_transformer.py ├── solver ├── LARC.py ├── __init__.py ├── bertadam.py ├── build.py ├── get_solver.py ├── lr_scheduler.py └── optimization.py ├── tasks ├── run_caption_VidSwinBert.py └── run_caption_VidSwinBert_inference.py ├── timm ├── __init__.py ├── data │ ├── __init__.py │ ├── auto_augment.py │ ├── config.py │ ├── constants.py │ ├── dataset.py │ ├── dataset_factory.py │ ├── distributed_sampler.py │ ├── loader.py │ ├── mixup.py │ ├── parsers │ │ ├── __init__.py │ │ ├── class_map.py │ │ ├── constants.py │ │ ├── parser.py │ │ ├── parser_factory.py │ │ ├── parser_image_folder.py │ │ ├── parser_image_in_tar.py │ │ ├── parser_image_tar.py │ │ └── parser_tfds.py │ ├── random_erasing.py │ ├── real_labels.py │ ├── tf_preprocessing.py │ ├── transforms.py │ └── transforms_factory.py ├── models │ ├── __init__.py │ ├── byoanet.py │ ├── byobnet.py │ ├── cspnet.py │ ├── densenet.py │ ├── dla.py │ ├── dpn.py │ ├── efficientnet.py │ ├── efficientnet_blocks.py │ ├── efficientnet_builder.py │ ├── factory.py │ ├── features.py │ ├── ghostnet.py │ ├── 
gluon_resnet.py │ ├── gluon_xception.py │ ├── hardcorenas.py │ ├── helpers.py │ ├── hrnet.py │ ├── hub.py │ ├── inception_resnet_v2.py │ ├── inception_v3.py │ ├── inception_v4.py │ ├── layers │ │ ├── __init__.py │ │ ├── activations.py │ │ ├── activations_jit.py │ │ ├── activations_me.py │ │ ├── adaptive_avgmax_pool.py │ │ ├── anti_aliasing.py │ │ ├── blur_pool.py │ │ ├── bottleneck_attn.py │ │ ├── cbam.py │ │ ├── classifier.py │ │ ├── cond_conv2d.py │ │ ├── config.py │ │ ├── conv2d_same.py │ │ ├── conv_bn_act.py │ │ ├── create_act.py │ │ ├── create_attn.py │ │ ├── create_conv2d.py │ │ ├── create_norm_act.py │ │ ├── create_self_attn.py │ │ ├── drop.py │ │ ├── eca.py │ │ ├── evo_norm.py │ │ ├── halo_attn.py │ │ ├── helpers.py │ │ ├── inplace_abn.py │ │ ├── lambda_layer.py │ │ ├── linear.py │ │ ├── median_pool.py │ │ ├── mixed_conv2d.py │ │ ├── norm.py │ │ ├── norm_act.py │ │ ├── padding.py │ │ ├── pool2d_same.py │ │ ├── se.py │ │ ├── selective_kernel.py │ │ ├── separable_conv.py │ │ ├── space_to_depth.py │ │ ├── split_attn.py │ │ ├── split_batchnorm.py │ │ ├── std_conv.py │ │ ├── test_time_pool.py │ │ └── weight_init.py │ ├── mobilenetv3.py │ ├── nasnet.py │ ├── nfnet.py │ ├── pit.py │ ├── pnasnet.py │ ├── registry.py │ ├── regnet.py │ ├── res2net.py │ ├── resnest.py │ ├── resnet.py │ ├── resnetv2.py │ ├── rexnet.py │ ├── selecsls.py │ ├── senet.py │ ├── sknet.py │ ├── swin_transformer.py │ ├── tnt.py │ ├── tresnet.py │ ├── vgg.py │ ├── vision_transformer.py │ ├── vision_transformer_hybrid.py │ ├── vovnet.py │ ├── xception.py │ └── xception_aligned.py ├── utils │ ├── __init__.py │ ├── checkpoint_saver.py │ ├── cuda.py │ ├── distributed.py │ ├── jit.py │ ├── log.py │ ├── metrics.py │ ├── misc.py │ ├── model.py │ ├── model_ema.py │ └── summary.py └── version.py └── utils ├── __init__.py ├── basic_utils.py ├── cloud_storage.py ├── comm.py ├── deepspeed.py ├── latex_writer.py ├── load_files.py ├── load_save.py ├── logger.py ├── metric_logger.py ├── miscellaneous.py ├── qd_common.py ├── tsv_file.py ├── tsv_file_ops.py └── tsv_io.py /.gitignore: -------------------------------------------------------------------------------- 1 | # @linjli 2 | # compilation and distribution 3 | __pycache__ 4 | _ext 5 | *.pyc 6 | *.so 7 | # maskrcnn_benchmark.egg-info/ 8 | cpu_soft_nms.egg-info/ 9 | build/ 10 | dist/ 11 | src/evalcap/coco_caption/ 12 | src/evalcap/cider/ 13 | src/evalcap/coco_caption 14 | src/evalcap/cider 15 | # amulet/longer_seq_exp_linjie 16 | 17 | .vscode 18 | 19 | # script 20 | tmp_all/script/ 21 | 22 | # Philly-realted # 23 | pt/ 24 | .ptconfig 25 | .azureml 26 | .amltconfig 27 | 28 | 29 | 30 | # Project-related # 31 | */*results*/ 32 | *results*/ 33 | tmp*/ 34 | cache/* 35 | */cache*/ 36 | tmp*.py 37 | *pickle 38 | *output*/ 39 | */*output*/ 40 | linjli 41 | ./datasets 42 | # ./models 43 | 44 | 45 | # compiled files # 46 | *.pyc 47 | 48 | # Packages # 49 | ############ 50 | # it's better to unpack these files and commit the raw source 51 | # git has its own built in compression methods 52 | *.7z 53 | *.dmg 54 | *.gz 55 | *.iso 56 | *.jar 57 | *.rar 58 | *.tar 59 | *.zip 60 | 61 | # Logs and databases # 62 | ###################### 63 | *.log 64 | *.sql 65 | *.sqlite 66 | .ipynb_checkpoints/ 67 | *.swp 68 | *.vscode/ 69 | *.idea/ 70 | 71 | # OS generated files # 72 | ###################### 73 | .DS_Store 74 | .DS_Store? 
75 | ._* 76 | .Spotlight-V100 77 | .Trashes 78 | ehthumbs.db 79 | Thumbs.db 80 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Microsoft Open Source Code of Conduct 2 | 3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 4 | 5 | Resources: 6 | 7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) 8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE 22 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://docs.microsoft.com/en-us/previous-versions/tn-archive/cc751383(v=technet.10)), please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://msrc.microsoft.com/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). 
If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://www.microsoft.com/en-us/msrc/pgp-key-msrc). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://microsoft.com/msrc/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd). 40 | 41 | -------------------------------------------------------------------------------- /SUPPORT.md: -------------------------------------------------------------------------------- 1 | # TODO: The maintainer of this repo has not yet edited this file 2 | 3 | **REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project? 4 | 5 | - **No CSS support:** Fill out this template with information about how to file issues and get help. 6 | - **Yes CSS support:** Fill out an intake form at [aka.ms/spot](https://aka.ms/spot). CSS will work with/help you to determine next steps. More details also available at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). 7 | - **Not sure?** Fill out a SPOT intake as though the answer were "Yes". CSS will help you decide. 8 | 9 | *Then remove this first heading from this SUPPORT.MD file before publishing your repo.* 10 | 11 | # Support 12 | 13 | ## How to file issues and get help 14 | 15 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing 16 | issues before filing new issues to avoid duplicates. For new issues, file your bug or 17 | feature request as a new Issue. 18 | 19 | For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE 20 | FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER 21 | CHANNEL. WHERE WILL YOU HELP PEOPLE?**. 22 | 23 | ## Microsoft Support Policy 24 | 25 | Support for this **PROJECT or PRODUCT** is limited to the resources listed above. 26 | -------------------------------------------------------------------------------- /azure-pipelines.yml: -------------------------------------------------------------------------------- 1 | # Starter pipeline 2 | # Start with a minimal pipeline that you can customize to build and deploy your code. 
3 | # Add steps that build, run tests, deploy, and more: 4 | # https://aka.ms/yaml 5 | 6 | trigger: 7 | - main 8 | 9 | pool: 10 | vmImage: ubuntu-latest 11 | 12 | steps: 13 | - script: echo Hello, world! 14 | displayName: 'Run a one-line script' 15 | 16 | - script: | 17 | python -m pip install --upgrade pip 18 | pip install --upgrade azureml-core 19 | displayName: 'Run a minimal setup' 20 | -------------------------------------------------------------------------------- /docs/G0mjFqytJt4_000152_000162.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/SwinBERT/03116f1f3fd7e42d4700a25090f13aa2aa253011/docs/G0mjFqytJt4_000152_000162.mp4 -------------------------------------------------------------------------------- /docs/m2_qmRnjICE_000017_000027.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/SwinBERT/03116f1f3fd7e42d4700a25090f13aa2aa253011/docs/m2_qmRnjICE_000017_000027.mp4 -------------------------------------------------------------------------------- /docs/swinbert-overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/SwinBERT/03116f1f3fd7e42d4700a25090f13aa2aa253011/docs/swinbert-overview.png -------------------------------------------------------------------------------- /launch_container.sh: -------------------------------------------------------------------------------- 1 | DATA_DIR=$1 2 | MODEL_DIR=$2 3 | OUTPUT=$3 4 | 5 | if [ -z $CUDA_VISIBLE_DEVICES ]; then 6 | CUDA_VISIBLE_DEVICES='all' 7 | fi 8 | 9 | if [ "$4" = "--prepro" ]; then 10 | RO="" 11 | else 12 | RO=",readonly" 13 | fi 14 | 15 | docker run --gpus '"'device=$CUDA_VISIBLE_DEVICES'"' --ipc=host --rm -it \ 16 | --mount src=$(pwd),dst=/videocap,type=bind \ 17 | --mount src=$DATA_DIR,dst=/videocap/datasets,type=bind$RO \ 18 | --mount src=$MODEL_DIR,dst=/videocap/models,type=bind,readonly \ 19 | --mount src=$OUTPUT,dst=/videocap/output,type=bind \ 20 | -e NVIDIA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES \ 21 | -w /videocap linjieli222/videocap_torch1.7:fairscale \ 22 | bash -c "source /videocap/setup.sh && bash" 23 | -------------------------------------------------------------------------------- /models/captioning/bert-base-uncased/added_tokens.json: -------------------------------------------------------------------------------- 1 | {} -------------------------------------------------------------------------------- /models/captioning/bert-base-uncased/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "BertForMaskedLM" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "hidden_act": "gelu", 7 | "hidden_dropout_prob": 0.1, 8 | "hidden_size": 768, 9 | "initializer_range": 0.02, 10 | "intermediate_size": 3072, 11 | "max_position_embeddings": 512, 12 | "num_attention_heads": 12, 13 | "num_hidden_layers": 12, 14 | "type_vocab_size": 2, 15 | "vocab_size": 30522 16 | } 17 | -------------------------------------------------------------------------------- /models/captioning/bert-base-uncased/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | {"unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]"} -------------------------------------------------------------------------------- 
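For orientation only (this block is not a file in the repository): the bert-base-uncased config and tokenizer files listed just above are consumed through the BERT classes under src/layers/bert, roughly as in the hypothetical sketch below. The path, num_labels, finetuning_task, and do_lower_case arguments mirror get_bert_model() in src/modeling/load_bert.py and the default JSON configs; running it assumes the repository root is on PYTHONPATH and the model files have been downloaded.

from src.layers.bert import BertConfig, BertTokenizer

# Path shown in the tree above; hypothetical standalone usage, not repo code.
model_dir = "models/captioning/bert-base-uncased/"

# Mirrors the from_pretrained calls in src/modeling/load_bert.py.
config = BertConfig.from_pretrained(model_dir, num_labels=2,
                                    finetuning_task="image_captioning")
tokenizer = BertTokenizer.from_pretrained(model_dir, do_lower_case=True)

print(config.hidden_size)   # 768, per models/captioning/bert-base-uncased/config.json
print(tokenizer.tokenize("a person is cooking a dish"))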
/prepro/extract_youcook2_frms.sh: -------------------------------------------------------------------------------- 1 | python ./prepro/extract_frames.py \ 2 | --video_root_dir ./datasets/YouCook2/raw_videos/training/ \ 3 | --save_dir ./datasets/YouCook2/ \ 4 | --video_info_tsv ./datasets/YouCook2/training.img.tsv \ 5 | --num_frames 32 \ 6 | # --debug 7 | 8 | 9 | python ./prepro/create_image_frame_tsv.py \ 10 | --dataset YouCook2 \ 11 | --split training \ 12 | --image_size 256 \ 13 | --num_frames 32 \ -------------------------------------------------------------------------------- /scripts/download_annotations.sh: -------------------------------------------------------------------------------- 1 | # -------------------------------- 2 | # Setup 3 | # -------------------------------- 4 | export REPO_DIR=$PWD 5 | if [ ! -d $REPO_DIR/datasets ] ; then 6 | mkdir -p $REPO_DIR/datasets 7 | fi 8 | BLOB='https://datarelease.blob.core.windows.net/swinbert' 9 | 10 | # -------------------------------- 11 | # Download caption annotations pre-parsed in TSV format 12 | # -------------------------------- 13 | 14 | for DATASET in 'VATEX' 'MSRVTT-v2' 'TVC' 'YouCook2' 'MSVD' 15 | do 16 | wget -nc $BLOB/datasets/${DATASET}.zip -O $REPO_DIR/datasets/${DATASET}.zip 17 | unzip $REPO_DIR/datasets/${DATASET}.zip -d $REPO_DIR/datasets/ 18 | rm $REPO_DIR/datasets/${DATASET}.zip 19 | done 20 | 21 | # -------------------------------- 22 | # Note: Due to copyright issue, we are not able to release raw video files 23 | # Please visit each dataset website to download the videos 24 | # -------------------------------- -------------------------------------------------------------------------------- /scripts/download_models.sh: -------------------------------------------------------------------------------- 1 | # -------------------------------- 2 | # Setup 3 | # -------------------------------- 4 | export REPO_DIR=$PWD 5 | if [ ! -d $REPO_DIR/models ] ; then 6 | mkdir -p $REPO_DIR/models 7 | fi 8 | if [ ! -d $REPO_DIR/models/table1 ] ; then 9 | mkdir -p $REPO_DIR/models/table1 10 | fi 11 | if [ ! -d $REPO_DIR/models/32frm ] ; then 12 | mkdir -p $REPO_DIR/models/32frm 13 | fi 14 | BLOB='https://datarelease.blob.core.windows.net/swinbert' 15 | 16 | 17 | # -------------------------------- 18 | # Download our best performing checkpoints for each dataset (corresponding to Table 1 in paper) 19 | # -------------------------------- 20 | 21 | for DATASET in 'vatex' 'msrvtt' 'tvc' 'youcook2' 'msvd' 22 | do 23 | wget -nc $BLOB/models/${DATASET}-table1.zip -O $REPO_DIR/models/table1/${DATASET}-table1.zip 24 | unzip $REPO_DIR/models/table1/${DATASET}-table1.zip -d $REPO_DIR/models/table1/${DATASET}/ 25 | rm $REPO_DIR/models/table1/${DATASET}-table1.zip 26 | done 27 | 28 | 29 | # -------------------------------- 30 | # Download our 32-frame-based model 31 | # -------------------------------- 32 | 33 | for DATASET in 'vatex' 'tvc' 'youcook2' 'msvd' 34 | do 35 | wget -nc $BLOB/models/${DATASET}-32frm.zip -O $REPO_DIR/models/32frm/${DATASET}-32frm.zip 36 | unzip $REPO_DIR/models/32frm/${DATASET}-32frm.zip -d $REPO_DIR/models/32frm/${DATASET}/ 37 | rm $REPO_DIR/models/32frm/${DATASET}-32frm.zip 38 | done 39 | -------------------------------------------------------------------------------- /scripts/download_value_preds.sh: -------------------------------------------------------------------------------- 1 | # -------------------------------- 2 | # Setup 3 | # -------------------------------- 4 | export REPO_DIR=$PWD 5 | if [ ! 
-d $REPO_DIR/value-submit ] ; then 6 | mkdir -p $REPO_DIR/value-submit 7 | fi 8 | 9 | BLOB='https://datarelease.blob.core.windows.net/swinbert' 10 | 11 | 12 | # -------------------------------- 13 | # Download our prediction files that were evaluated on VALUE Leaderboard Evaluation Server 14 | # -------------------------------- 15 | 16 | wget -nc $BLOB/swinbert-value-submit.zip -O $REPO_DIR/value-submit/swinbert-value-submit.zip -------------------------------------------------------------------------------- /setup.sh: -------------------------------------------------------------------------------- 1 | ln -s /evalcap/coco_caption /videocap/src/evalcap/coco_caption 2 | ln -s /evalcap/cider /videocap/src/evalcap/cider 3 | pip install fvcore ete3 transformers 4 | pip install --upgrade azureml-core 5 | df -h 6 | ls -al 7 | export TORCH_HOME=/models 8 | -------------------------------------------------------------------------------- /src/configs/VidSwinBert/local_msrvtt_debug.json: -------------------------------------------------------------------------------- 1 | { 2 | "do_train": true, 3 | "evaluate_during_training": true, 4 | "data_dir": "datasets", 5 | "train_yaml": "MSRVTT-v2/train_32frames.yaml", 6 | "val_yaml": "MSRVTT-v2/val_32frames.yaml", 7 | "do_lower_case": true, 8 | "max_seq_a_length": 50, 9 | "max_seq_length": 50, 10 | "max_img_seq_length": 196, 11 | "img_res": 224, 12 | "max_num_frames": 8, 13 | "patch_size": 32, 14 | "per_gpu_eval_batch_size": 4, 15 | "per_gpu_train_batch_size": 4, 16 | "num_workers": 4, 17 | "model_name_or_path": "models/captioning/bert-base-uncased/", 18 | "pretrained_checkpoint": "", 19 | "output_dir": "output/local_debug", 20 | "img_feature_dim": 512, 21 | "vidswin_size": "base", 22 | "kinetics": "600", 23 | "use_clip_model": true, 24 | "pretrained_2d": false, 25 | "grid_feat": true, 26 | "mask_prob": 0.15, 27 | "max_masked_tokens": 3, 28 | "attn_mask_type": "seq2seq", 29 | "max_gen_length": 20, 30 | "on_memory": false, 31 | "use_checkpoint": true, 32 | "num_train_epochs": 1, 33 | "learning_rate": 0.00015, 34 | "backbone_coef_lr": 0.01, 35 | "scheduler": "warmup_linear", 36 | "warmup_ratio": 0.1, 37 | "weight_decay": 0.05, 38 | "max_grad_norm": 1.0, 39 | "gradient_accumulation_steps": 4, 40 | "mixed_precision_method": "apex", 41 | "amp_opt_level": 2, 42 | "deepspeed_fp16": false, 43 | "fairscale_fp16": false, 44 | "zero_opt_stage": -1, 45 | "debug": false, 46 | "seed": 88 47 | } -------------------------------------------------------------------------------- /src/configs/VidSwinBert/msrvtt_8frm_default.json: -------------------------------------------------------------------------------- 1 | { 2 | "do_train": true, 3 | "evaluate_during_training": true, 4 | "data_dir": "datasets", 5 | "train_yaml": "MSRVTT-v2/train_32frames.yaml", 6 | "val_yaml": "MSRVTT-v2/val_32frames.yaml", 7 | "do_lower_case": true, 8 | "max_seq_a_length": 50, 9 | "max_seq_length": 50, 10 | "max_img_seq_length": 196, 11 | "img_res": 224, 12 | "max_num_frames": 8, 13 | "patch_size": 32, 14 | "per_gpu_eval_batch_size": 36, 15 | "per_gpu_train_batch_size": 36, 16 | "num_workers": 10, 17 | "model_name_or_path": "models/captioning/bert-base-uncased/", 18 | "pretrained_checkpoint": "", 19 | "output_dir": "output/msrvtt_8frm_default", 20 | "img_feature_dim": 512, 21 | "vidswin_size": "base", 22 | "kinetics": "600", 23 | "use_clip_model": true, 24 | "pretrained_2d": false, 25 | "grid_feat": true, 26 | "mask_prob": 0.5, 27 | "max_masked_tokens": 45, 28 | "attn_mask_type": "seq2seq", 29 | 
"max_gen_length": 20, 30 | "on_memory": false, 31 | "use_checkpoint": true, 32 | "num_train_epochs": 1, 33 | "learning_rate": 3e-4, 34 | "backbone_coef_lr": 0.05, 35 | "scheduler": "warmup_linear", 36 | "warmup_ratio": 0.1, 37 | "weight_decay": 0.05, 38 | "max_grad_norm": 1.0, 39 | "gradient_accumulation_steps": 1, 40 | "mixed_precision_method": "deepspeed", 41 | "amp_opt_level": 0, 42 | "deepspeed_fp16": true, 43 | "fairscale_fp16": false, 44 | "zero_opt_stage": 1, 45 | "restore_ratio": -1, 46 | "debug": false, 47 | "debug_speed": false, 48 | "seed": 88 49 | } -------------------------------------------------------------------------------- /src/configs/VidSwinBert/msvd_8frm_default.json: -------------------------------------------------------------------------------- 1 | { 2 | "do_train": true, 3 | "evaluate_during_training": true, 4 | "data_dir": "datasets", 5 | "train_yaml": "MSVD/train_32frames.yaml", 6 | "val_yaml": "MSVD/val_32frames.yaml", 7 | "do_lower_case": true, 8 | "max_seq_a_length": 50, 9 | "max_seq_length": 50, 10 | "max_img_seq_length": 196, 11 | "img_res": 224, 12 | "max_num_frames": 8, 13 | "patch_size": 32, 14 | "per_gpu_eval_batch_size": 36, 15 | "per_gpu_train_batch_size": 36, 16 | "num_workers": 10, 17 | "model_name_or_path": "models/captioning/bert-base-uncased/", 18 | "pretrained_checkpoint": "", 19 | "output_dir": "output/msvd_8frm_default", 20 | "img_feature_dim": 512, 21 | "vidswin_size": "base", 22 | "kinetics": "600", 23 | "use_clip_model": true, 24 | "pretrained_2d": false, 25 | "grid_feat": true, 26 | "mask_prob": 0.5, 27 | "max_masked_tokens": 45, 28 | "attn_mask_type": "seq2seq", 29 | "max_gen_length": 20, 30 | "on_memory": false, 31 | "use_checkpoint": true, 32 | "num_train_epochs": 1, 33 | "learning_rate": 3e-4, 34 | "backbone_coef_lr": 0.05, 35 | "scheduler": "warmup_linear", 36 | "warmup_ratio": 0.1, 37 | "weight_decay": 0.05, 38 | "max_grad_norm": 1.0, 39 | "gradient_accumulation_steps": 1, 40 | "mixed_precision_method": "deepspeed", 41 | "amp_opt_level": 0, 42 | "deepspeed_fp16": true, 43 | "fairscale_fp16": false, 44 | "zero_opt_stage": 1, 45 | "restore_ratio": -1, 46 | "debug": false, 47 | "debug_speed": false, 48 | "seed": 88 49 | } -------------------------------------------------------------------------------- /src/configs/VidSwinBert/tvc_8frm_default.json: -------------------------------------------------------------------------------- 1 | { 2 | "do_train": true, 3 | "evaluate_during_training": true, 4 | "data_dir": "datasets", 5 | "train_yaml": "TVC/train.yaml", 6 | "val_yaml": "TVC/val.yaml", 7 | "do_lower_case": true, 8 | "max_seq_a_length": 50, 9 | "max_seq_length": 50, 10 | "max_img_seq_length": 196, 11 | "img_res": 224, 12 | "max_num_frames": 8, 13 | "patch_size": 32, 14 | "per_gpu_eval_batch_size": 36, 15 | "per_gpu_train_batch_size": 36, 16 | "num_workers": 10, 17 | "model_name_or_path": "models/captioning/bert-base-uncased/", 18 | "pretrained_checkpoint": "", 19 | "output_dir": "output/tvc_8frm_default", 20 | "img_feature_dim": 512, 21 | "vidswin_size": "base", 22 | "kinetics": "600", 23 | "use_clip_model": true, 24 | "pretrained_2d": false, 25 | "grid_feat": true, 26 | "mask_prob": 0.5, 27 | "max_masked_tokens": 45, 28 | "attn_mask_type": "seq2seq", 29 | "max_gen_length": 20, 30 | "on_memory": false, 31 | "use_checkpoint": true, 32 | "num_train_epochs": 1, 33 | "learning_rate": 3e-4, 34 | "backbone_coef_lr": 0.05, 35 | "scheduler": "warmup_linear", 36 | "warmup_ratio": 0.1, 37 | "weight_decay": 0.05, 38 | "max_grad_norm": 1.0, 39 | 
"gradient_accumulation_steps": 1, 40 | "mixed_precision_method": "deepspeed", 41 | "amp_opt_level": 0, 42 | "deepspeed_fp16": true, 43 | "fairscale_fp16": false, 44 | "zero_opt_stage": 1, 45 | "restore_ratio": -1, 46 | "debug": false, 47 | "debug_speed": false, 48 | "seed": 88 49 | } -------------------------------------------------------------------------------- /src/configs/VidSwinBert/vatex_8frm_default.json: -------------------------------------------------------------------------------- 1 | { 2 | "do_train": true, 3 | "evaluate_during_training": true, 4 | "data_dir": "datasets", 5 | "train_yaml": "VATEX/train_32frames.yaml", 6 | "val_yaml": "VATEX/public_test_32frames.yaml", 7 | "do_lower_case": true, 8 | "max_seq_a_length": 50, 9 | "max_seq_length": 50, 10 | "max_img_seq_length": 196, 11 | "img_res": 224, 12 | "max_num_frames": 8, 13 | "patch_size": 32, 14 | "per_gpu_eval_batch_size": 36, 15 | "per_gpu_train_batch_size": 36, 16 | "num_workers": 10, 17 | "model_name_or_path": "models/captioning/bert-base-uncased/", 18 | "pretrained_checkpoint": "", 19 | "output_dir": "output/vatex_8frm_default", 20 | "img_feature_dim": 512, 21 | "vidswin_size": "base", 22 | "kinetics": "600", 23 | "use_clip_model": true, 24 | "pretrained_2d": false, 25 | "grid_feat": true, 26 | "mask_prob": 0.5, 27 | "max_masked_tokens": 45, 28 | "attn_mask_type": "seq2seq", 29 | "max_gen_length": 20, 30 | "on_memory": false, 31 | "use_checkpoint": true, 32 | "num_train_epochs": 1, 33 | "learning_rate": 3e-4, 34 | "backbone_coef_lr": 0.05, 35 | "scheduler": "warmup_linear", 36 | "warmup_ratio": 0.1, 37 | "weight_decay": 0.05, 38 | "max_grad_norm": 1.0, 39 | "gradient_accumulation_steps": 1, 40 | "mixed_precision_method": "deepspeed", 41 | "amp_opt_level": 0, 42 | "deepspeed_fp16": true, 43 | "fairscale_fp16": false, 44 | "zero_opt_stage": 1, 45 | "restore_ratio": -1, 46 | "debug": false, 47 | "debug_speed": false, 48 | "seed": 88 49 | } -------------------------------------------------------------------------------- /src/configs/VidSwinBert/youcook2_8frm_default.json: -------------------------------------------------------------------------------- 1 | { 2 | "do_train": true, 3 | "evaluate_during_training": true, 4 | "data_dir": "datasets", 5 | "train_yaml": "YouCook2/training_32frames.yaml", 6 | "val_yaml": "YouCook2/validation_32frames.yaml", 7 | "do_lower_case": true, 8 | "max_seq_a_length": 50, 9 | "max_seq_length": 50, 10 | "max_img_seq_length": 196, 11 | "img_res": 224, 12 | "max_num_frames": 8, 13 | "patch_size": 32, 14 | "per_gpu_eval_batch_size": 36, 15 | "per_gpu_train_batch_size": 36, 16 | "num_workers": 10, 17 | "model_name_or_path": "models/captioning/bert-base-uncased/", 18 | "pretrained_checkpoint": "", 19 | "output_dir": "output/youcook2_8frm_default", 20 | "img_feature_dim": 512, 21 | "vidswin_size": "base", 22 | "kinetics": "600", 23 | "use_clip_model": true, 24 | "pretrained_2d": false, 25 | "grid_feat": true, 26 | "mask_prob": 0.5, 27 | "max_masked_tokens": 45, 28 | "attn_mask_type": "seq2seq", 29 | "max_gen_length": 20, 30 | "on_memory": false, 31 | "use_checkpoint": true, 32 | "num_train_epochs": 1, 33 | "learning_rate": 3e-4, 34 | "backbone_coef_lr": 0.05, 35 | "scheduler": "warmup_linear", 36 | "warmup_ratio": 0.1, 37 | "weight_decay": 0.05, 38 | "max_grad_norm": 1.0, 39 | "gradient_accumulation_steps": 1, 40 | "mixed_precision_method": "deepspeed", 41 | "amp_opt_level": 0, 42 | "deepspeed_fp16": true, 43 | "fairscale_fp16": false, 44 | "zero_opt_stage": 1, 45 | "restore_ratio": -1, 46 | 
"debug": false, 47 | "debug_speed": false, 48 | "seed": 88 49 | } -------------------------------------------------------------------------------- /src/datasets/data_utils/video_functional.py: -------------------------------------------------------------------------------- 1 | import numbers 2 | import torch 3 | import cv2 4 | import numpy as np 5 | import PIL 6 | from PIL import Image 7 | 8 | def _is_tensor_clip(clip): 9 | return torch.is_tensor(clip) and clip.ndimension() == 4 10 | 11 | 12 | def crop_clip(clip, min_h, min_w, h, w): 13 | if isinstance(clip[0], np.ndarray): 14 | cropped = [img[min_h:min_h + h, min_w:min_w + w, :] for img in clip] 15 | 16 | elif isinstance(clip[0], PIL.Image.Image): 17 | cropped = [ 18 | img.crop((min_w, min_h, min_w + w, min_h + h)) for img in clip 19 | ] 20 | else: 21 | raise TypeError('Expected numpy.ndarray or PIL.Image' + 22 | 'but got list of {0}'.format(type(clip[0]))) 23 | return cropped 24 | 25 | 26 | def to_grayscale(img, num_output_channels=1): 27 | """Convert image to grayscale version of image. 28 | 29 | Args: 30 | img (PIL Image): Image to be converted to grayscale. 31 | 32 | Returns: 33 | PIL Image: Grayscale version of the image. 34 | if num_output_channels = 1 : returned image is single channel 35 | 36 | if num_output_channels = 3 : returned image is 3 channel with r = g = b 37 | """ 38 | if not isinstance(img,PIL.Image.Image): 39 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 40 | 41 | if num_output_channels == 1: 42 | img = img.convert('L') 43 | elif num_output_channels == 3: 44 | img = img.convert('L') 45 | np_img = np.array(img, dtype=np.uint8) 46 | np_img = np.dstack([np_img, np_img, np_img]) 47 | img = Image.fromarray(np_img, 'RGB') 48 | else: 49 | raise ValueError('num_output_channels should be either 1 or 3') 50 | 51 | return img 52 | 53 | def resize_clip(clip, size, interpolation='bilinear'): 54 | if isinstance(clip[0], np.ndarray): 55 | if isinstance(size, numbers.Number): 56 | im_h, im_w, im_c = clip[0].shape 57 | # Min spatial dim already matches minimal size 58 | if (im_w <= im_h and im_w == size) or (im_h <= im_w 59 | and im_h == size): 60 | return clip 61 | new_h, new_w = get_resize_sizes(im_h, im_w, size) 62 | size = (new_w, new_h) 63 | else: 64 | size = size[1], size[0] 65 | if interpolation == 'bilinear': 66 | np_inter = cv2.INTER_LINEAR 67 | else: 68 | np_inter = cv2.INTER_NEAREST 69 | scaled = [ 70 | cv2.resize(img, size, interpolation=np_inter) for img in clip 71 | ] 72 | elif isinstance(clip[0], PIL.Image.Image): 73 | if isinstance(size, numbers.Number): 74 | im_w, im_h = clip[0].size 75 | # Min spatial dim already matches minimal size 76 | if (im_w <= im_h and im_w == size) or (im_h <= im_w 77 | and im_h == size): 78 | return clip 79 | new_h, new_w = get_resize_sizes(im_h, im_w, size) 80 | size = (new_w, new_h) 81 | else: 82 | size = size[1], size[0] 83 | if interpolation == 'bilinear': 84 | pil_inter = PIL.Image.NEAREST 85 | else: 86 | pil_inter = PIL.Image.BILINEAR 87 | scaled = [img.resize(size, pil_inter) for img in clip] 88 | else: 89 | raise TypeError('Expected numpy.ndarray or PIL.Image' + 90 | 'but got list of {0}'.format(type(clip[0]))) 91 | return scaled 92 | 93 | 94 | def get_resize_sizes(im_h, im_w, size): 95 | if im_w < im_h: 96 | ow = size 97 | oh = int(size * im_h / im_w) 98 | else: 99 | oh = size 100 | ow = int(size * im_w / im_h) 101 | return oh, ow 102 | 103 | 104 | def normalize(clip, mean, std, inplace=False): 105 | if not _is_tensor_clip(clip): 106 | raise TypeError('tensor 
is not a torch clip.') 107 | 108 | if not inplace: 109 | clip = clip.clone() 110 | 111 | dtype = clip.dtype 112 | mean = torch.as_tensor(mean, dtype=dtype, device=clip.device) 113 | std = torch.as_tensor(std, dtype=dtype, device=clip.device) 114 | clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None]) 115 | 116 | return clip 117 | -------------------------------------------------------------------------------- /src/datasets/data_utils/volume_transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | import torch 4 | 5 | 6 | def my_convert_img(img): 7 | """Converts (H, W, C) numpy.ndarray to (C, W, H) format 8 | """ 9 | if len(img.shape) == 3: 10 | img = img.transpose(2, 0, 1) 11 | if len(img.shape) == 2: 12 | img = np.expand_dims(img, 0) 13 | return img 14 | 15 | 16 | class ClipToTensor(object): 17 | """Convert a list of m (H x W x C) numpy.ndarrays in the range [0, 255] 18 | to a torch.FloatTensor of shape (C x m x H x W) in the range [0, 1.0] 19 | """ 20 | 21 | def __init__(self, channel_nb=3, div_255=True, numpy=False): 22 | self.channel_nb = channel_nb 23 | self.div_255 = div_255 24 | self.numpy = numpy 25 | 26 | def __call__(self, clip): 27 | """ 28 | Args: clip (list of numpy.ndarray): clip (list of images) 29 | to be converted to tensor. 30 | """ 31 | # Retrieve shape 32 | if isinstance(clip[0], np.ndarray): 33 | h, w, ch = clip[0].shape 34 | assert ch == self.channel_nb, 'Got {0} instead of 3 channels'.format( 35 | ch) 36 | elif isinstance(clip[0], Image.Image): 37 | w, h = clip[0].size 38 | else: 39 | raise TypeError('Expected numpy.ndarray or PIL.Image\ 40 | but got list of {0}'.format(type(clip[0]))) 41 | 42 | np_clip = np.zeros([self.channel_nb, len(clip), int(h), int(w)]) 43 | 44 | # Convert 45 | for img_idx, img in enumerate(clip): 46 | if isinstance(img, np.ndarray): 47 | pass 48 | elif isinstance(img, Image.Image): 49 | img = np.array(img, copy=False) 50 | else: 51 | raise TypeError('Expected numpy.ndarray or PIL.Image\ 52 | but got list of {0}'.format(type(clip[0]))) 53 | img = my_convert_img(img) 54 | np_clip[:, img_idx, :, :] = img 55 | if self.numpy: 56 | if self.div_255: 57 | np_clip = np_clip / 255 58 | return np_clip 59 | 60 | else: 61 | tensor_clip = torch.from_numpy(np_clip) 62 | 63 | if not isinstance(tensor_clip, torch.FloatTensor): 64 | tensor_clip = tensor_clip.float() 65 | if self.div_255: 66 | tensor_clip = tensor_clip.div(255) 67 | return tensor_clip 68 | 69 | 70 | class ToTensor(object): 71 | """Converts numpy array to tensor 72 | """ 73 | 74 | def __call__(self, array): 75 | tensor = torch.from_numpy(array) 76 | return tensor 77 | -------------------------------------------------------------------------------- /src/layers/bert/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "1.0.0" 2 | from .tokenization_bert import BertTokenizer, BasicTokenizer, WordpieceTokenizer 3 | from .tokenization_utils import (PreTrainedTokenizer, clean_up_tokenization) 4 | 5 | from .modeling_bert import (BertConfig, BertModel, BertForPreTraining, 6 | BertForMaskedLM, BertForNextSentencePrediction, 7 | BertForSequenceClassification, BertForMultipleChoice, 8 | BertForTokenClassification, BertForQuestionAnswering, 9 | BertForImageCaptioning, BertImgForPreTraining, 10 | BertForVLGrounding, BertImgForGroundedPreTraining, 11 | load_tf_weights_in_bert, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, 12 | 
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP) 13 | from .modeling_utils import (WEIGHTS_NAME, CONFIG_NAME, TF_WEIGHTS_NAME, 14 | PretrainedConfig, PreTrainedModel, prune_layer, Conv1D) 15 | 16 | from .file_utils import (PYTORCH_PRETRAINED_BERT_CACHE, cached_path) 17 | -------------------------------------------------------------------------------- /src/modeling/load_bert.py: -------------------------------------------------------------------------------- 1 | from src.layers.bert import BertTokenizer, BertConfig, BertForImageCaptioning 2 | from src.utils.logger import LOGGER as logger 3 | 4 | def get_bert_model(args): 5 | # Load pretrained bert and tokenizer based on training configs 6 | config_class, model_class, tokenizer_class = BertConfig, BertForImageCaptioning, BertTokenizer 7 | config = config_class.from_pretrained(args.config_name if args.config_name else \ 8 | args.model_name_or_path, num_labels=2, finetuning_task='image_captioning') 9 | 10 | tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name \ 11 | else args.model_name_or_path, do_lower_case=args.do_lower_case) 12 | config.img_feature_type = 'frcnn' 13 | config.hidden_dropout_prob = args.drop_out 14 | config.loss_type = 'classification' 15 | config.tie_weights = args.tie_weights 16 | config.freeze_embedding = args.freeze_embedding 17 | config.label_smoothing = args.label_smoothing 18 | config.drop_worst_ratio = args.drop_worst_ratio 19 | config.drop_worst_after = args.drop_worst_after 20 | # update model structure if specified in arguments 21 | update_params = ['img_feature_dim', 'num_hidden_layers', 'hidden_size', 'num_attention_heads', 'intermediate_size'] 22 | model_structure_changed = [False] * len(update_params) 23 | # model_structure_changed[0] = True # cclin hack 24 | for idx, param in enumerate(update_params): 25 | arg_param = getattr(args, param) 26 | # bert-base-uncased do not have img_feature_dim 27 | config_param = getattr(config, param) if hasattr(config, param) else -1 28 | if arg_param > 0 and arg_param != config_param: 29 | logger.info(f"Update config parameter {param}: {config_param} -> {arg_param}") 30 | setattr(config, param, arg_param) 31 | model_structure_changed[idx] = True 32 | if any(model_structure_changed): 33 | assert config.hidden_size % config.num_attention_heads == 0 34 | if args.load_partial_weights: 35 | # can load partial weights when changing layer only. 
36 | assert not any(model_structure_changed[2:]), "Cannot load partial weights " \ 37 | "when any of ({}) is changed.".format(', '.join(update_params[2:])) 38 | model = model_class.from_pretrained(args.model_name_or_path, 39 | from_tf=bool('.ckpt' in args.model_name_or_path), config=config) 40 | logger.info("Load partial weights for bert layers.") 41 | else: 42 | model = model_class(config=config) # init from scratch 43 | logger.info("Init model from scratch.") 44 | else: 45 | model = model_class.from_pretrained(args.model_name_or_path, 46 | from_tf=bool('.ckpt' in args.model_name_or_path), config=config) 47 | logger.info(f"Load pretrained model: {args.model_name_or_path}") 48 | 49 | total_params = sum(p.numel() for p in model.parameters()) 50 | logger.info(f'Model total parameters: {total_params}') 51 | return model, config, tokenizer -------------------------------------------------------------------------------- /src/modeling/load_swin.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from src.utils.logger import LOGGER as logger 3 | from src.modeling.video_swin.swin_transformer import SwinTransformer3D 4 | from src.modeling.video_swin.config import Config 5 | 6 | def get_swin_model(args): 7 | if int(args.img_res) == 384: 8 | assert args.vidswin_size == "large" 9 | config_path = 'src/modeling/video_swin/swin_%s_384_patch244_window81212_kinetics%s_22k.py'%(args.vidswin_size, args.kinetics) 10 | model_path = './models/video_swin_transformer/swin_%s_384_patch244_window81212_kinetics%s_22k.pth'%(args.vidswin_size, args.kinetics) 11 | else: 12 | # in the case that args.img_res == '224' 13 | config_path = 'src/modeling/video_swin/swin_%s_patch244_window877_kinetics%s_22k.py'%(args.vidswin_size, args.kinetics) 14 | model_path = './models/video_swin_transformer/swin_%s_patch244_window877_kinetics%s_22k.pth'%(args.vidswin_size, args.kinetics) 15 | if args.pretrained_2d: 16 | config_path = 'src/modeling/video_swin/swin_base_patch244_window877_kinetics400_22k.py' 17 | model_path = './models/swin_transformer/swin_base_patch4_window7_224_22k.pth' 18 | 19 | logger.info(f'video swin (config path): {config_path}') 20 | if args.pretrained_checkpoint == '': 21 | logger.info(f'video swin (model path): {model_path}') 22 | cfg = Config.fromfile(config_path) 23 | pretrained_path = model_path if args.pretrained_2d else None 24 | backbone = SwinTransformer3D( 25 | pretrained=pretrained_path, 26 | pretrained2d=args.pretrained_2d, 27 | patch_size=cfg.model['backbone']['patch_size'], 28 | in_chans=3, 29 | embed_dim=cfg.model['backbone']['embed_dim'], 30 | depths=cfg.model['backbone']['depths'], 31 | num_heads=cfg.model['backbone']['num_heads'], 32 | window_size=cfg.model['backbone']['window_size'], 33 | mlp_ratio=4., 34 | qkv_bias=True, 35 | qk_scale=None, 36 | drop_rate=0., 37 | attn_drop_rate=0., 38 | drop_path_rate=0.2, 39 | norm_layer=torch.nn.LayerNorm, 40 | patch_norm=cfg.model['backbone']['patch_norm'], 41 | frozen_stages=-1, 42 | use_checkpoint=False) 43 | 44 | video_swin = myVideoSwin(args=args, cfg=cfg, backbone=backbone) 45 | 46 | if not args.pretrained_2d: 47 | checkpoint_3d = torch.load(model_path, map_location='cpu') 48 | video_swin.load_state_dict(checkpoint_3d['state_dict'], strict=False) 49 | else: 50 | video_swin.backbone.init_weights() 51 | return video_swin 52 | 53 | def reload_pretrained_swin(video_swin, args): 54 | if not args.reload_pretrained_swin: 55 | return video_swin 56 | if int(args.img_res) == 384: 57 | model_path = 
'./models/video_swin_transformer/swin_%s_384_patch244_window81212_kinetics%s_22k.pth'%(args.vidswin_size, args.kinetics) 58 | else: 59 | # in the case that args.img_res == '224' 60 | model_path = './models/video_swin_transformer/swin_%s_patch244_window877_kinetics%s_22k.pth'%(args.vidswin_size, args.kinetics) 61 | 62 | checkpoint_3d = torch.load(model_path, map_location='cpu') 63 | missing, unexpected = video_swin.load_state_dict(checkpoint_3d['state_dict'], strict=False) 64 | logger.info(f"re-loaded video_swin_transformer from {model_path}") 65 | 66 | logger.info(f"Missing keys in loaded video_swin_transformerr: {missing}") 67 | logger.info(f"Unexpected keys in loaded video_swin_transformer: {unexpected}") 68 | return video_swin 69 | 70 | class myVideoSwin(torch.nn.Module): 71 | def __init__(self, args, cfg, backbone): 72 | super(myVideoSwin, self).__init__() 73 | self.backbone = backbone 74 | self.use_grid_feature = args.grid_feat 75 | 76 | def forward(self, x): 77 | x = self.backbone(x) 78 | return x 79 | -------------------------------------------------------------------------------- /src/modeling/swin/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import build_model -------------------------------------------------------------------------------- /src/modeling/swin/build.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | from .swin_transformer import SwinTransformer 9 | 10 | def build_model(config): 11 | model_type = config.MODEL.TYPE 12 | if model_type == 'swin': 13 | model = SwinTransformer(img_size=config.DATA.IMG_SIZE, 14 | patch_size=config.MODEL.SWIN.PATCH_SIZE, 15 | in_chans=config.MODEL.SWIN.IN_CHANS, 16 | num_classes=config.MODEL.NUM_CLASSES, 17 | embed_dim=config.MODEL.SWIN.EMBED_DIM, 18 | depths=config.MODEL.SWIN.DEPTHS, 19 | num_heads=config.MODEL.SWIN.NUM_HEADS, 20 | window_size=config.MODEL.SWIN.WINDOW_SIZE, 21 | mlp_ratio=config.MODEL.SWIN.MLP_RATIO, 22 | qkv_bias=config.MODEL.SWIN.QKV_BIAS, 23 | qk_scale=config.MODEL.SWIN.QK_SCALE, 24 | drop_rate=config.MODEL.DROP_RATE, 25 | drop_path_rate=config.MODEL.DROP_PATH_RATE, 26 | ape=config.MODEL.SWIN.APE, 27 | patch_norm=config.MODEL.SWIN.PATCH_NORM, 28 | use_checkpoint=config.TRAIN.USE_CHECKPOINT) 29 | else: 30 | raise NotImplementedError(f"Unkown model: {model_type}") 31 | 32 | return model -------------------------------------------------------------------------------- /src/modeling/swin/swin_base_patch4_window12_384.yaml: -------------------------------------------------------------------------------- 1 | # only for evaluation 2 | DATA: 3 | IMG_SIZE: 384 4 | MODEL: 5 | TYPE: swin 6 | NAME: swin_base_patch4_window12_384 7 | SWIN: 8 | EMBED_DIM: 128 9 | DEPTHS: [ 2, 2, 18, 2 ] 10 | NUM_HEADS: [ 4, 8, 16, 32 ] 11 | WINDOW_SIZE: 12 12 | TEST: 13 | CROP: False -------------------------------------------------------------------------------- /src/modeling/swin/swin_base_patch4_window7_224.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: swin 3 | NAME: swin_base_patch4_window7_224 4 | DROP_PATH_RATE: 0.5 5 | SWIN: 6 | EMBED_DIM: 128 7 | DEPTHS: [ 2, 2, 18, 2 ] 8 | NUM_HEADS: [ 4, 8, 16, 32 ] 9 | 
WINDOW_SIZE: 7 10 | NUM_CLASSES: 1000 -------------------------------------------------------------------------------- /src/modeling/swin/swin_base_patch4_window7_224_22k.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: swin 3 | NAME: swin_base_patch4_window7_224 4 | DROP_PATH_RATE: 0.5 5 | SWIN: 6 | EMBED_DIM: 128 7 | DEPTHS: [ 2, 2, 18, 2 ] 8 | NUM_HEADS: [ 4, 8, 16, 32 ] 9 | WINDOW_SIZE: 7 10 | NUM_CLASSES: 21841 -------------------------------------------------------------------------------- /src/modeling/swin/swin_large_patch4_window12_384.yaml: -------------------------------------------------------------------------------- 1 | # only for evaluation 2 | DATA: 3 | IMG_SIZE: 384 4 | MODEL: 5 | TYPE: swin 6 | NAME: swin_large_patch4_window12_384 7 | SWIN: 8 | EMBED_DIM: 192 9 | DEPTHS: [ 2, 2, 18, 2 ] 10 | NUM_HEADS: [ 6, 12, 24, 48 ] 11 | WINDOW_SIZE: 12 12 | TEST: 13 | CROP: False -------------------------------------------------------------------------------- /src/modeling/swin/swin_large_patch4_window7_224.yaml: -------------------------------------------------------------------------------- 1 | # only for evaluation 2 | MODEL: 3 | TYPE: swin 4 | NAME: swin_large_patch4_window7_224 5 | SWIN: 6 | EMBED_DIM: 192 7 | DEPTHS: [ 2, 2, 18, 2 ] 8 | NUM_HEADS: [ 6, 12, 24, 48 ] 9 | WINDOW_SIZE: 7 -------------------------------------------------------------------------------- /src/modeling/swin/swin_small_patch4_window7_224.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: swin 3 | NAME: swin_small_patch4_window7_224 4 | DROP_PATH_RATE: 0.3 5 | SWIN: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 18, 2 ] 8 | NUM_HEADS: [ 3, 6, 12, 24 ] 9 | WINDOW_SIZE: 7 -------------------------------------------------------------------------------- /src/modeling/swin/swin_tiny_patch4_window7_224.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: swin 3 | NAME: swin_tiny_patch4_window7_224 4 | DROP_PATH_RATE: 0.2 5 | SWIN: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 6, 2 ] 8 | NUM_HEADS: [ 3, 6, 12, 24 ] 9 | WINDOW_SIZE: 7 -------------------------------------------------------------------------------- /src/modeling/video_swin/default_runtime.py: -------------------------------------------------------------------------------- 1 | checkpoint_config = dict(interval=1) 2 | log_config = dict( 3 | interval=20, 4 | hooks=[ 5 | dict(type='TextLoggerHook'), 6 | # dict(type='TensorboardLoggerHook'), 7 | ]) 8 | # runtime settings 9 | dist_params = dict(backend='nccl') 10 | log_level = 'INFO' 11 | load_from = None 12 | resume_from = None 13 | workflow = [('train', 1)] -------------------------------------------------------------------------------- /src/modeling/video_swin/swin_base.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | _base_ = "swin_tiny.py" 3 | model = dict(backbone=dict(depths=[2, 2, 18, 2], 4 | embed_dim=128, 5 | num_heads=[4, 8, 16, 32]), 6 | cls_head=dict(in_channels=1024)) -------------------------------------------------------------------------------- /src/modeling/video_swin/swin_base_patch244_window877_kinetics400_1k.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | 'swin_base.py', 'default_runtime.py' 3 | ] 4 | model=dict(backbone=dict(patch_size=(2,4,4), drop_path_rate=0.3), test_cfg=dict(max_testing_views=4)) 5 | 6 | # 
dataset settings 7 | dataset_type = 'VideoDataset' 8 | data_root = 'data/kinetics400/train' 9 | data_root_val = 'data/kinetics400/val' 10 | ann_file_train = 'data/kinetics400/kinetics400_train_list.txt' 11 | ann_file_val = 'data/kinetics400/kinetics400_val_list.txt' 12 | ann_file_test = 'data/kinetics400/kinetics400_val_list.txt' 13 | img_norm_cfg = dict( 14 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_bgr=False) 15 | train_pipeline = [ 16 | dict(type='DecordInit'), 17 | dict(type='SampleFrames', clip_len=32, frame_interval=2, num_clips=1), 18 | dict(type='DecordDecode'), 19 | dict(type='Resize', scale=(-1, 256)), 20 | dict(type='RandomResizedCrop'), 21 | dict(type='Resize', scale=(224, 224), keep_ratio=False), 22 | dict(type='Flip', flip_ratio=0.5), 23 | dict(type='Normalize', **img_norm_cfg), 24 | dict(type='FormatShape', input_format='NCTHW'), 25 | dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]), 26 | dict(type='ToTensor', keys=['imgs', 'label']) 27 | ] 28 | val_pipeline = [ 29 | dict(type='DecordInit'), 30 | dict( 31 | type='SampleFrames', 32 | clip_len=32, 33 | frame_interval=2, 34 | num_clips=1, 35 | test_mode=True), 36 | dict(type='DecordDecode'), 37 | dict(type='Resize', scale=(-1, 256)), 38 | dict(type='CenterCrop', crop_size=224), 39 | dict(type='Flip', flip_ratio=0), 40 | dict(type='Normalize', **img_norm_cfg), 41 | dict(type='FormatShape', input_format='NCTHW'), 42 | dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]), 43 | dict(type='ToTensor', keys=['imgs']) 44 | ] 45 | test_pipeline = [ 46 | dict(type='DecordInit'), 47 | dict( 48 | type='SampleFrames', 49 | clip_len=32, 50 | frame_interval=2, 51 | num_clips=4, 52 | test_mode=True), 53 | dict(type='DecordDecode'), 54 | dict(type='Resize', scale=(-1, 224)), 55 | dict(type='ThreeCrop', crop_size=224), 56 | dict(type='Flip', flip_ratio=0), 57 | dict(type='Normalize', **img_norm_cfg), 58 | dict(type='FormatShape', input_format='NCTHW'), 59 | dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]), 60 | dict(type='ToTensor', keys=['imgs']) 61 | ] 62 | data = dict( 63 | videos_per_gpu=8, 64 | workers_per_gpu=4, 65 | val_dataloader=dict( 66 | videos_per_gpu=1, 67 | workers_per_gpu=1 68 | ), 69 | test_dataloader=dict( 70 | videos_per_gpu=1, 71 | workers_per_gpu=1 72 | ), 73 | train=dict( 74 | type=dataset_type, 75 | ann_file=ann_file_train, 76 | data_prefix=data_root, 77 | pipeline=train_pipeline), 78 | val=dict( 79 | type=dataset_type, 80 | ann_file=ann_file_val, 81 | data_prefix=data_root_val, 82 | pipeline=val_pipeline), 83 | test=dict( 84 | type=dataset_type, 85 | ann_file=ann_file_test, 86 | data_prefix=data_root_val, 87 | pipeline=test_pipeline)) 88 | evaluation = dict( 89 | interval=5, metrics=['top_k_accuracy', 'mean_class_accuracy']) 90 | 91 | # optimizer 92 | optimizer = dict(type='AdamW', lr=1e-3, betas=(0.9, 0.999), weight_decay=0.05, 93 | paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.), 94 | 'relative_position_bias_table': dict(decay_mult=0.), 95 | 'norm': dict(decay_mult=0.), 96 | 'backbone': dict(lr_mult=0.1)})) 97 | # learning policy 98 | lr_config = dict( 99 | policy='CosineAnnealing', 100 | min_lr=0, 101 | warmup='linear', 102 | warmup_by_epoch=True, 103 | warmup_iters=2.5 104 | ) 105 | total_epochs = 30 106 | 107 | # runtime settings 108 | checkpoint_config = dict(interval=1) 109 | work_dir = work_dir = './work_dirs/k400_swin_base_patch244_window877.py' 110 | find_unused_parameters = False 111 | 112 | 113 | # do not use mmdet version fp16 114 | 
fp16 = None 115 | optimizer_config = dict( 116 | type="DistOptimizerHook", 117 | update_interval=8, 118 | grad_clip=None, 119 | coalesce=True, 120 | bucket_size_mb=-1, 121 | use_fp16=True, 122 | ) 123 | -------------------------------------------------------------------------------- /src/modeling/video_swin/swin_base_patch244_window877_kinetics400_22k.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | 'swin_base.py', 'default_runtime.py' 3 | ] 4 | model=dict(backbone=dict(patch_size=(2,4,4), drop_path_rate=0.2), test_cfg=dict(max_testing_views=2)) 5 | 6 | # dataset settings 7 | dataset_type = 'VideoDataset' 8 | data_root = 'data/kinetics400/train' 9 | data_root_val = 'data/kinetics400/val' 10 | ann_file_train = 'data/kinetics400/kinetics400_train_list.txt' 11 | ann_file_val = 'data/kinetics400/kinetics400_val_list.txt' 12 | ann_file_test = 'data/kinetics400/kinetics400_val_list.txt' 13 | img_norm_cfg = dict( 14 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_bgr=False) 15 | train_pipeline = [ 16 | dict(type='DecordInit'), 17 | dict(type='SampleFrames', clip_len=32, frame_interval=2, num_clips=1), 18 | dict(type='DecordDecode'), 19 | dict(type='Resize', scale=(-1, 256)), 20 | dict(type='RandomResizedCrop'), 21 | dict(type='Resize', scale=(224, 224), keep_ratio=False), 22 | dict(type='Flip', flip_ratio=0.5), 23 | dict(type='Normalize', **img_norm_cfg), 24 | dict(type='FormatShape', input_format='NCTHW'), 25 | dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]), 26 | dict(type='ToTensor', keys=['imgs', 'label']) 27 | ] 28 | val_pipeline = [ 29 | dict(type='DecordInit'), 30 | dict( 31 | type='SampleFrames', 32 | clip_len=32, 33 | frame_interval=2, 34 | num_clips=1, 35 | test_mode=True), 36 | dict(type='DecordDecode'), 37 | dict(type='Resize', scale=(-1, 256)), 38 | dict(type='CenterCrop', crop_size=224), 39 | dict(type='Flip', flip_ratio=0), 40 | dict(type='Normalize', **img_norm_cfg), 41 | dict(type='FormatShape', input_format='NCTHW'), 42 | dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]), 43 | dict(type='ToTensor', keys=['imgs']) 44 | ] 45 | test_pipeline = [ 46 | dict(type='DecordInit'), 47 | dict( 48 | type='SampleFrames', 49 | clip_len=32, 50 | frame_interval=2, 51 | num_clips=4, 52 | test_mode=True), 53 | dict(type='DecordDecode'), 54 | dict(type='Resize', scale=(-1, 224)), 55 | dict(type='ThreeCrop', crop_size=224), 56 | dict(type='Flip', flip_ratio=0), 57 | dict(type='Normalize', **img_norm_cfg), 58 | dict(type='FormatShape', input_format='NCTHW'), 59 | dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]), 60 | dict(type='ToTensor', keys=['imgs']) 61 | ] 62 | data = dict( 63 | videos_per_gpu=8, 64 | workers_per_gpu=4, 65 | val_dataloader=dict( 66 | videos_per_gpu=1, 67 | workers_per_gpu=1 68 | ), 69 | test_dataloader=dict( 70 | videos_per_gpu=1, 71 | workers_per_gpu=1 72 | ), 73 | train=dict( 74 | type=dataset_type, 75 | ann_file=ann_file_train, 76 | data_prefix=data_root, 77 | pipeline=train_pipeline), 78 | val=dict( 79 | type=dataset_type, 80 | ann_file=ann_file_val, 81 | data_prefix=data_root_val, 82 | pipeline=val_pipeline), 83 | test=dict( 84 | type=dataset_type, 85 | ann_file=ann_file_test, 86 | data_prefix=data_root_val, 87 | pipeline=test_pipeline)) 88 | evaluation = dict( 89 | interval=5, metrics=['top_k_accuracy', 'mean_class_accuracy']) 90 | 91 | # optimizer 92 | optimizer = dict(type='AdamW', lr=3e-4, betas=(0.9, 0.999), weight_decay=0.05, 93 | 
paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.), 94 | 'relative_position_bias_table': dict(decay_mult=0.), 95 | 'norm': dict(decay_mult=0.), 96 | 'backbone': dict(lr_mult=0.1)})) 97 | # learning policy 98 | lr_config = dict( 99 | policy='CosineAnnealing', 100 | min_lr=0, 101 | warmup='linear', 102 | warmup_by_epoch=True, 103 | warmup_iters=2.5 104 | ) 105 | total_epochs = 30 106 | 107 | # runtime settings 108 | checkpoint_config = dict(interval=1) 109 | work_dir = work_dir = './work_dirs/k400_swin_base_22k_patch244_window877.py' 110 | find_unused_parameters = False 111 | 112 | 113 | # do not use mmdet version fp16 114 | fp16 = None 115 | optimizer_config = dict( 116 | type="DistOptimizerHook", 117 | update_interval=8, 118 | grad_clip=None, 119 | coalesce=True, 120 | bucket_size_mb=-1, 121 | use_fp16=True, 122 | ) 123 | -------------------------------------------------------------------------------- /src/modeling/video_swin/swin_base_patch244_window877_kinetics600_22k.py: -------------------------------------------------------------------------------- 1 | _base_ = "swin_base_patch244_window877_kinetics400_22k.py" 2 | 3 | data_root = 'data/kinetics600/train' 4 | data_root_val = 'data/kinetics600/val' 5 | ann_file_train = 'data/kinetics600/kinetics600_train_list.txt' 6 | ann_file_val = 'data/kinetics600/kinetics600_val_list.txt' 7 | ann_file_test = 'data/kinetics600/kinetics600_val_list.txt' 8 | 9 | data = dict( 10 | train=dict( 11 | ann_file=ann_file_train, 12 | data_prefix=data_root), 13 | val=dict( 14 | ann_file=ann_file_val, 15 | data_prefix=data_root_val), 16 | test=dict( 17 | ann_file=ann_file_test, 18 | data_prefix=data_root_val)) 19 | 20 | model=dict(cls_head=dict(num_classes=600)) 21 | -------------------------------------------------------------------------------- /src/modeling/video_swin/swin_large.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | _base_ = "swin_tiny.py" 3 | model = dict(backbone=dict(depths=[2, 2, 18, 2], 4 | embed_dim=192, 5 | num_heads=[6, 12, 24, 48]), 6 | cls_head=dict(in_channels=1536)) -------------------------------------------------------------------------------- /src/modeling/video_swin/swin_large_384_patch244_window81212_kinetics400_22k.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | 'swin_large.py', 'default_runtime.py' 3 | ] 4 | model=dict(backbone=dict(patch_size=(2,4,4), window_size=(8,12,12), drop_path_rate=0.5), test_cfg=dict(max_testing_views=1), train_cfg=dict(blending=dict(type='LabelSmoothing', num_classes=400, smoothing=0.1))) 5 | 6 | # dataset settings 7 | dataset_type = 'VideoDataset' 8 | data_root = 'data/kinetics400/train' 9 | data_root_val = 'data/kinetics400/val' 10 | ann_file_train = 'data/kinetics400/kinetics400_train_list.txt' 11 | ann_file_val = 'data/kinetics400/kinetics400_val_list.txt' 12 | ann_file_test = 'data/kinetics400/kinetics400_val_list.txt' 13 | img_norm_cfg = dict( 14 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_bgr=False) 15 | train_pipeline = [ 16 | dict(type='DecordInit'), 17 | dict(type='SampleFrames', clip_len=32, frame_interval=2, num_clips=1), 18 | dict(type='DecordDecode'), 19 | dict(type='Resize', scale=(-1, 416)), 20 | dict(type='RandomResizedCrop'), 21 | dict(type='Resize', scale=(384, 384), keep_ratio=False), 22 | dict(type='Flip', flip_ratio=0.5), 23 | dict(type='Imgaug', transforms=[dict(type='RandAugment', n=4, m=7)]), 24 | dict(type='Normalize', 
**img_norm_cfg), 25 | dict(type='RandomErasing', probability=0.25), 26 | dict(type='FormatShape', input_format='NCTHW'), 27 | dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]), 28 | dict(type='ToTensor', keys=['imgs', 'label']) 29 | ] 30 | val_pipeline = [ 31 | dict(type='DecordInit'), 32 | dict( 33 | type='SampleFrames', 34 | clip_len=32, 35 | frame_interval=2, 36 | num_clips=1, 37 | test_mode=True), 38 | dict(type='DecordDecode'), 39 | dict(type='Resize', scale=(-1, 416)), 40 | dict(type='CenterCrop', crop_size=384), 41 | dict(type='Flip', flip_ratio=0), 42 | dict(type='Normalize', **img_norm_cfg), 43 | dict(type='FormatShape', input_format='NCTHW'), 44 | dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]), 45 | dict(type='ToTensor', keys=['imgs']) 46 | ] 47 | test_pipeline = [ 48 | dict(type='DecordInit'), 49 | dict( 50 | type='SampleFrames', 51 | clip_len=32, 52 | frame_interval=2, 53 | num_clips=4, 54 | test_mode=True), 55 | dict(type='DecordDecode'), 56 | dict(type='Resize', scale=(-1, 384)), 57 | dict(type='ThreeCrop', crop_size=384), 58 | dict(type='Flip', flip_ratio=0), 59 | dict(type='Normalize', **img_norm_cfg), 60 | dict(type='FormatShape', input_format='NCTHW'), 61 | dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]), 62 | dict(type='ToTensor', keys=['imgs']) 63 | ] 64 | data = dict( 65 | videos_per_gpu=8, 66 | workers_per_gpu=1, 67 | val_dataloader=dict( 68 | videos_per_gpu=1, 69 | workers_per_gpu=1 70 | ), 71 | test_dataloader=dict( 72 | videos_per_gpu=1, 73 | workers_per_gpu=1 74 | ), 75 | train=dict( 76 | type=dataset_type, 77 | ann_file=ann_file_train, 78 | data_prefix=data_root, 79 | pipeline=train_pipeline), 80 | val=dict( 81 | type=dataset_type, 82 | ann_file=ann_file_val, 83 | data_prefix=data_root_val, 84 | pipeline=val_pipeline), 85 | test=dict( 86 | type=dataset_type, 87 | ann_file=ann_file_test, 88 | data_prefix=data_root_val, 89 | pipeline=test_pipeline)) 90 | evaluation = dict( 91 | interval=5, metrics=['top_k_accuracy', 'mean_class_accuracy']) 92 | 93 | # optimizer 94 | optimizer = dict(type='AdamW', lr=3e-4, betas=(0.9, 0.999), weight_decay=0.05, 95 | paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.), 96 | 'relative_position_bias_table': dict(decay_mult=0.), 97 | 'norm': dict(decay_mult=0.), 98 | 'backbone': dict(lr_mult=0.1)})) 99 | # learning policy 100 | lr_config = dict( 101 | policy='CosineAnnealing', 102 | min_lr=0, 103 | warmup='linear', 104 | warmup_by_epoch=True, 105 | warmup_iters=2.5 106 | ) 107 | total_epochs = 60 108 | 109 | # runtime settings 110 | checkpoint_config = dict(interval=1) 111 | work_dir = work_dir = './work_dirs/swin_large_384_patch244_window81212_kinetics400_22k' 112 | find_unused_parameters = False 113 | 114 | 115 | # do not use mmdet version fp16 116 | fp16 = None 117 | optimizer_config = dict( 118 | type="DistOptimizerHook", 119 | update_interval=8, 120 | grad_clip=None, 121 | coalesce=True, 122 | bucket_size_mb=-1, 123 | use_fp16=True, 124 | ) -------------------------------------------------------------------------------- /src/modeling/video_swin/swin_large_patch244_window877_kinetics400_22k.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | 'swin_large.py', 'default_runtime.py' 3 | ] 4 | model=dict(backbone=dict(patch_size=(2,4,4), drop_path_rate=0.2), test_cfg=dict(max_testing_views=1)) 5 | 6 | # dataset settings 7 | dataset_type = 'VideoDataset' 8 | data_root = 'data/kinetics400/train' 9 | data_root_val = 
'data/kinetics400/val' 10 | ann_file_train = 'data/kinetics400/kinetics400_train_list.txt' 11 | ann_file_val = 'data/kinetics400/kinetics400_val_list.txt' 12 | ann_file_test = 'data/kinetics400/kinetics400_val_list.txt' 13 | img_norm_cfg = dict( 14 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_bgr=False) 15 | train_pipeline = [ 16 | dict(type='DecordInit'), 17 | dict(type='SampleFrames', clip_len=32, frame_interval=2, num_clips=1), 18 | dict(type='DecordDecode'), 19 | dict(type='Resize', scale=(-1, 256)), 20 | dict(type='RandomResizedCrop'), 21 | dict(type='Resize', scale=(224, 224), keep_ratio=False), 22 | dict(type='Flip', flip_ratio=0.5), 23 | dict(type='Normalize', **img_norm_cfg), 24 | dict(type='FormatShape', input_format='NCTHW'), 25 | dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]), 26 | dict(type='ToTensor', keys=['imgs', 'label']) 27 | ] 28 | val_pipeline = [ 29 | dict(type='DecordInit'), 30 | dict( 31 | type='SampleFrames', 32 | clip_len=32, 33 | frame_interval=2, 34 | num_clips=1, 35 | test_mode=True), 36 | dict(type='DecordDecode'), 37 | dict(type='Resize', scale=(-1, 256)), 38 | dict(type='CenterCrop', crop_size=224), 39 | dict(type='Flip', flip_ratio=0), 40 | dict(type='Normalize', **img_norm_cfg), 41 | dict(type='FormatShape', input_format='NCTHW'), 42 | dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]), 43 | dict(type='ToTensor', keys=['imgs']) 44 | ] 45 | test_pipeline = [ 46 | dict(type='DecordInit'), 47 | dict( 48 | type='SampleFrames', 49 | clip_len=32, 50 | frame_interval=2, 51 | num_clips=4, 52 | test_mode=True), 53 | dict(type='DecordDecode'), 54 | dict(type='Resize', scale=(-1, 224)), 55 | dict(type='ThreeCrop', crop_size=224), 56 | dict(type='Flip', flip_ratio=0), 57 | dict(type='Normalize', **img_norm_cfg), 58 | dict(type='FormatShape', input_format='NCTHW'), 59 | dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]), 60 | dict(type='ToTensor', keys=['imgs']) 61 | ] 62 | data = dict( 63 | videos_per_gpu=8, 64 | workers_per_gpu=1, 65 | val_dataloader=dict( 66 | videos_per_gpu=1, 67 | workers_per_gpu=1 68 | ), 69 | test_dataloader=dict( 70 | videos_per_gpu=1, 71 | workers_per_gpu=1 72 | ), 73 | train=dict( 74 | type=dataset_type, 75 | ann_file=ann_file_train, 76 | data_prefix=data_root, 77 | pipeline=train_pipeline), 78 | val=dict( 79 | type=dataset_type, 80 | ann_file=ann_file_val, 81 | data_prefix=data_root_val, 82 | pipeline=val_pipeline), 83 | test=dict( 84 | type=dataset_type, 85 | ann_file=ann_file_test, 86 | data_prefix=data_root_val, 87 | pipeline=test_pipeline)) 88 | evaluation = dict( 89 | interval=5, metrics=['top_k_accuracy', 'mean_class_accuracy']) 90 | 91 | # optimizer 92 | optimizer = dict(type='AdamW', lr=3e-4, betas=(0.9, 0.999), weight_decay=0.05, 93 | paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.), 94 | 'relative_position_bias_table': dict(decay_mult=0.), 95 | 'norm': dict(decay_mult=0.), 96 | 'backbone': dict(lr_mult=0.1)})) 97 | # learning policy 98 | lr_config = dict( 99 | policy='CosineAnnealing', 100 | min_lr=0, 101 | warmup='linear', 102 | warmup_by_epoch=True, 103 | warmup_iters=2.5 104 | ) 105 | total_epochs = 30 106 | 107 | # runtime settings 108 | checkpoint_config = dict(interval=1) 109 | work_dir = work_dir = './work_dirs/swin_large_patch244_window877_kinetics400_22k' 110 | find_unused_parameters = False 111 | 112 | 113 | # do not use mmdet version fp16 114 | fp16 = None 115 | optimizer_config = dict( 116 | type="DistOptimizerHook", 117 | update_interval=8, 118 | 
grad_clip=None, 119 | coalesce=True, 120 | bucket_size_mb=-1, 121 | use_fp16=True, 122 | ) -------------------------------------------------------------------------------- /src/modeling/video_swin/swin_small_patch244_window877_kinetics400_1k.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../../_base_/models/swin/swin_small.py', '../../_base_/default_runtime.py' 3 | ] 4 | model=dict(backbone=dict(patch_size=(2,4,4), drop_path_rate=0.1), test_cfg=dict(max_testing_views=4)) 5 | 6 | # dataset settings 7 | dataset_type = 'VideoDataset' 8 | data_root = 'data/kinetics400/train' 9 | data_root_val = 'data/kinetics400/val' 10 | ann_file_train = 'data/kinetics400/kinetics400_train_list.txt' 11 | ann_file_val = 'data/kinetics400/kinetics400_val_list.txt' 12 | ann_file_test = 'data/kinetics400/kinetics400_val_list.txt' 13 | img_norm_cfg = dict( 14 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_bgr=False) 15 | train_pipeline = [ 16 | dict(type='DecordInit'), 17 | dict(type='SampleFrames', clip_len=32, frame_interval=2, num_clips=1), 18 | dict(type='DecordDecode'), 19 | dict(type='Resize', scale=(-1, 256)), 20 | dict(type='RandomResizedCrop'), 21 | dict(type='Resize', scale=(224, 224), keep_ratio=False), 22 | dict(type='Flip', flip_ratio=0.5), 23 | dict(type='Normalize', **img_norm_cfg), 24 | dict(type='FormatShape', input_format='NCTHW'), 25 | dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]), 26 | dict(type='ToTensor', keys=['imgs', 'label']) 27 | ] 28 | val_pipeline = [ 29 | dict(type='DecordInit'), 30 | dict( 31 | type='SampleFrames', 32 | clip_len=32, 33 | frame_interval=2, 34 | num_clips=1, 35 | test_mode=True), 36 | dict(type='DecordDecode'), 37 | dict(type='Resize', scale=(-1, 256)), 38 | dict(type='CenterCrop', crop_size=224), 39 | dict(type='Flip', flip_ratio=0), 40 | dict(type='Normalize', **img_norm_cfg), 41 | dict(type='FormatShape', input_format='NCTHW'), 42 | dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]), 43 | dict(type='ToTensor', keys=['imgs']) 44 | ] 45 | test_pipeline = [ 46 | dict(type='DecordInit'), 47 | dict( 48 | type='SampleFrames', 49 | clip_len=32, 50 | frame_interval=2, 51 | num_clips=4, 52 | test_mode=True), 53 | dict(type='DecordDecode'), 54 | dict(type='Resize', scale=(-1, 224)), 55 | dict(type='ThreeCrop', crop_size=224), 56 | dict(type='Flip', flip_ratio=0), 57 | dict(type='Normalize', **img_norm_cfg), 58 | dict(type='FormatShape', input_format='NCTHW'), 59 | dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]), 60 | dict(type='ToTensor', keys=['imgs']) 61 | ] 62 | data = dict( 63 | videos_per_gpu=8, 64 | workers_per_gpu=4, 65 | val_dataloader=dict( 66 | videos_per_gpu=1, 67 | workers_per_gpu=1 68 | ), 69 | test_dataloader=dict( 70 | videos_per_gpu=1, 71 | workers_per_gpu=1 72 | ), 73 | train=dict( 74 | type=dataset_type, 75 | ann_file=ann_file_train, 76 | data_prefix=data_root, 77 | pipeline=train_pipeline), 78 | val=dict( 79 | type=dataset_type, 80 | ann_file=ann_file_val, 81 | data_prefix=data_root_val, 82 | pipeline=val_pipeline), 83 | test=dict( 84 | type=dataset_type, 85 | ann_file=ann_file_test, 86 | data_prefix=data_root_val, 87 | pipeline=test_pipeline)) 88 | evaluation = dict( 89 | interval=5, metrics=['top_k_accuracy', 'mean_class_accuracy']) 90 | 91 | # optimizer 92 | optimizer = dict(type='AdamW', lr=1e-3, betas=(0.9, 0.999), weight_decay=0.02, 93 | paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.), 94 | 
'relative_position_bias_table': dict(decay_mult=0.), 95 | 'norm': dict(decay_mult=0.), 96 | 'backbone': dict(lr_mult=0.1)})) 97 | # learning policy 98 | lr_config = dict( 99 | policy='CosineAnnealing', 100 | min_lr=0, 101 | warmup='linear', 102 | warmup_by_epoch=True, 103 | warmup_iters=2.5 104 | ) 105 | total_epochs = 30 106 | 107 | # runtime settings 108 | checkpoint_config = dict(interval=1) 109 | work_dir = work_dir = './work_dirs/k400_swin_small_patch244_window877.py' 110 | find_unused_parameters = False 111 | 112 | 113 | # do not use mmdet version fp16 114 | fp16 = None 115 | optimizer_config = dict( 116 | type="DistOptimizerHook", 117 | update_interval=8, 118 | grad_clip=None, 119 | coalesce=True, 120 | bucket_size_mb=-1, 121 | use_fp16=True, 122 | ) 123 | -------------------------------------------------------------------------------- /src/modeling/video_swin/swin_tiny.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | model = dict( 3 | type='Recognizer3D', 4 | backbone=dict( 5 | type='SwinTransformer3D', 6 | patch_size=(4,4,4), 7 | embed_dim=96, 8 | depths=[2, 2, 6, 2], 9 | num_heads=[3, 6, 12, 24], 10 | window_size=(8,7,7), 11 | mlp_ratio=4., 12 | qkv_bias=True, 13 | qk_scale=None, 14 | drop_rate=0., 15 | attn_drop_rate=0., 16 | drop_path_rate=0.2, 17 | patch_norm=True), 18 | cls_head=dict( 19 | type='I3DHead', 20 | in_channels=768, 21 | num_classes=400, 22 | spatial_type='avg', 23 | dropout_ratio=0.5), 24 | test_cfg = dict(average_clips='prob')) -------------------------------------------------------------------------------- /src/modeling/video_swin/swin_tiny_patch244_window877_kinetics400_1k.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../../_base_/models/swin/swin_tiny.py', '../../_base_/default_runtime.py' 3 | ] 4 | model=dict(backbone=dict(patch_size=(2,4,4), drop_path_rate=0.1), test_cfg=dict(max_testing_views=4)) 5 | 6 | # dataset settings 7 | dataset_type = 'VideoDataset' 8 | data_root = 'data/kinetics400/train' 9 | data_root_val = 'data/kinetics400/val' 10 | ann_file_train = 'data/kinetics400/kinetics400_train_list.txt' 11 | ann_file_val = 'data/kinetics400/kinetics400_val_list.txt' 12 | ann_file_test = 'data/kinetics400/kinetics400_val_list.txt' 13 | img_norm_cfg = dict( 14 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_bgr=False) 15 | train_pipeline = [ 16 | dict(type='DecordInit'), 17 | dict(type='SampleFrames', clip_len=32, frame_interval=2, num_clips=1), 18 | dict(type='DecordDecode'), 19 | dict(type='Resize', scale=(-1, 256)), 20 | dict(type='RandomResizedCrop'), 21 | dict(type='Resize', scale=(224, 224), keep_ratio=False), 22 | dict(type='Flip', flip_ratio=0.5), 23 | dict(type='Normalize', **img_norm_cfg), 24 | dict(type='FormatShape', input_format='NCTHW'), 25 | dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]), 26 | dict(type='ToTensor', keys=['imgs', 'label']) 27 | ] 28 | val_pipeline = [ 29 | dict(type='DecordInit'), 30 | dict( 31 | type='SampleFrames', 32 | clip_len=32, 33 | frame_interval=2, 34 | num_clips=1, 35 | test_mode=True), 36 | dict(type='DecordDecode'), 37 | dict(type='Resize', scale=(-1, 256)), 38 | dict(type='CenterCrop', crop_size=224), 39 | dict(type='Flip', flip_ratio=0), 40 | dict(type='Normalize', **img_norm_cfg), 41 | dict(type='FormatShape', input_format='NCTHW'), 42 | dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]), 43 | dict(type='ToTensor', keys=['imgs']) 44 | ] 45 | test_pipeline = 
[ 46 | dict(type='DecordInit'), 47 | dict( 48 | type='SampleFrames', 49 | clip_len=32, 50 | frame_interval=2, 51 | num_clips=4, 52 | test_mode=True), 53 | dict(type='DecordDecode'), 54 | dict(type='Resize', scale=(-1, 224)), 55 | dict(type='ThreeCrop', crop_size=224), 56 | dict(type='Flip', flip_ratio=0), 57 | dict(type='Normalize', **img_norm_cfg), 58 | dict(type='FormatShape', input_format='NCTHW'), 59 | dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]), 60 | dict(type='ToTensor', keys=['imgs']) 61 | ] 62 | data = dict( 63 | videos_per_gpu=8, 64 | workers_per_gpu=4, 65 | val_dataloader=dict( 66 | videos_per_gpu=1, 67 | workers_per_gpu=1 68 | ), 69 | test_dataloader=dict( 70 | videos_per_gpu=1, 71 | workers_per_gpu=1 72 | ), 73 | train=dict( 74 | type=dataset_type, 75 | ann_file=ann_file_train, 76 | data_prefix=data_root, 77 | pipeline=train_pipeline), 78 | val=dict( 79 | type=dataset_type, 80 | ann_file=ann_file_val, 81 | data_prefix=data_root_val, 82 | pipeline=val_pipeline), 83 | test=dict( 84 | type=dataset_type, 85 | ann_file=ann_file_test, 86 | data_prefix=data_root_val, 87 | pipeline=test_pipeline)) 88 | evaluation = dict( 89 | interval=5, metrics=['top_k_accuracy', 'mean_class_accuracy']) 90 | 91 | # optimizer 92 | optimizer = dict(type='AdamW', lr=1e-3, betas=(0.9, 0.999), weight_decay=0.02, 93 | paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.), 94 | 'relative_position_bias_table': dict(decay_mult=0.), 95 | 'norm': dict(decay_mult=0.), 96 | 'backbone': dict(lr_mult=0.1)})) 97 | # learning policy 98 | lr_config = dict( 99 | policy='CosineAnnealing', 100 | min_lr=0, 101 | warmup='linear', 102 | warmup_by_epoch=True, 103 | warmup_iters=2.5 104 | ) 105 | total_epochs = 30 106 | 107 | # runtime settings 108 | checkpoint_config = dict(interval=1) 109 | work_dir = work_dir = './work_dirs/k400_swin_tiny_patch244_window877.py' 110 | find_unused_parameters = False 111 | 112 | 113 | # do not use mmdet version fp16 114 | fp16 = None 115 | optimizer_config = dict( 116 | type="DistOptimizerHook", 117 | update_interval=4, 118 | grad_clip=None, 119 | coalesce=True, 120 | bucket_size_mb=-1, 121 | use_fp16=True, 122 | ) 123 | -------------------------------------------------------------------------------- /src/solver/LARC.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn.parameter import Parameter 4 | 5 | class LARC(object): 6 | """ 7 | :class:`LARC` is a pytorch implementation of both the scaling and clipping variants of LARC, 8 | in which the ratio between gradient and parameter magnitudes is used to calculate an adaptive 9 | local learning rate for each individual parameter. The algorithm is designed to improve 10 | convergence of large batch training. 11 | 12 | See https://arxiv.org/abs/1708.03888 for calculation of the local learning rate. 13 | 14 | In practice it modifies the gradients of parameters as a proxy for modifying the learning rate 15 | of the parameters. This design allows it to be used as a wrapper around any torch.optim Optimizer. 16 | 17 | ``` 18 | model = ... 19 | optim = torch.optim.Adam(model.parameters(), lr=...) 20 | optim = LARC(optim) 21 | ``` 22 | 23 | It can even be used in conjunction with apex.fp16_utils.FP16_optimizer. 24 | 25 | ``` 26 | model = ... 27 | optim = torch.optim.Adam(model.parameters(), lr=...) 
28 | optim = LARC(optim) 29 | optim = apex.fp16_utils.FP16_Optimizer(optim) 30 | ``` 31 | 32 | Args: 33 | optimizer: Pytorch optimizer to wrap and modify learning rate for. 34 | trust_coefficient: Trust coefficient for calculating the lr. See https://arxiv.org/abs/1708.03888 35 | clip: Decides between clipping or scaling mode of LARC. If `clip=True` the learning rate is set to `min(optimizer_lr, local_lr)` for each parameter. If `clip=False` the learning rate is set to `local_lr*optimizer_lr`. 36 | eps: epsilon kludge to help with numerical stability while calculating adaptive_lr 37 | """ 38 | 39 | def __init__(self, optimizer, trust_coefficient=0.02, clip=True, eps=1e-8): 40 | self.param_groups = optimizer.param_groups 41 | self.optim = optimizer 42 | self.trust_coefficient = trust_coefficient 43 | self.eps = eps 44 | self.clip = clip 45 | 46 | def __getstate__(self): 47 | return self.optim.__getstate__() 48 | 49 | def __setstate__(self, state): 50 | self.optim.__setstate__(state) 51 | 52 | def __repr__(self): 53 | return self.optim.__repr__() 54 | 55 | def state_dict(self): 56 | return self.optim.state_dict() 57 | 58 | def load_state_dict(self, state_dict): 59 | self.optim.load_state_dict(state_dict) 60 | 61 | def zero_grad(self): 62 | self.optim.zero_grad() 63 | 64 | def add_param_group(self, param_group): 65 | self.optim.add_param_group(param_group) 66 | 67 | def step(self): 68 | with torch.no_grad(): 69 | weight_decays = [] 70 | for group in self.optim.param_groups: 71 | # absorb weight decay control from optimizer 72 | weight_decay = group['weight_decay'] if 'weight_decay' in group else 0 73 | weight_decays.append(weight_decay) 74 | group['weight_decay'] = 0 75 | adlrs = [] 76 | for p in group['params']: 77 | if p.grad is None: 78 | continue 79 | param_norm = torch.norm(p.data) 80 | grad_norm = torch.norm(p.grad.data) 81 | 82 | if param_norm != 0 and grad_norm != 0: 83 | # calculate adaptive lr + weight decay 84 | adaptive_lr = self.trust_coefficient * (param_norm) / (grad_norm + param_norm * weight_decay + self.eps) 85 | 86 | # clip learning rate for LARC 87 | if self.clip: 88 | # calculation of adaptive_lr so that when multiplied by lr it equals `min(adaptive_lr, lr)` 89 | adaptive_lr = min(adaptive_lr/group['lr'], 1) 90 | adlrs.append(adaptive_lr) 91 | p.grad.data += weight_decay * p.data 92 | p.grad.data *= adaptive_lr 93 | group['adaptive_lr'] = sum(adlrs) / len(adlrs) if len(adlrs)!=0 else 1 94 | 95 | self.optim.step() 96 | # return weight decay control to optimizer 97 | for i, group in enumerate(self.optim.param_groups): 98 | group['weight_decay'] = weight_decays[i] 99 | -------------------------------------------------------------------------------- /src/solver/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
2 | from .build import make_optimizer 3 | from .build import make_lr_scheduler 4 | from .lr_scheduler import WarmupMultiStepLR 5 | from .lr_scheduler import WarmupCosineAnnealingLR 6 | from .lr_scheduler import WarmupLinearLR 7 | from .optimization import (AdamW, ConstantLRSchedule, WarmupConstantSchedule, WarmupCosineSchedule, 8 | WarmupCosineWithHardRestartsSchedule, WarmupLinearSchedule, 9 | WarmupMultiStepSchedule) 10 | from .get_solver import get_optimizer, get_scheduler 11 | from .bertadam import BertAdam 12 | -------------------------------------------------------------------------------- /src/solver/build.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | import re 3 | import torch 4 | from .LARC import LARC 5 | 6 | from .lr_scheduler import WarmupMultiStepLR 7 | from .lr_scheduler import WarmupCosineAnnealingLR 8 | from .optimization import AdamW 9 | from .optimization import WarmupLinearSchedule 10 | 11 | 12 | def make_optimizer(cfg, model, resume=False): 13 | params = [] 14 | for key, value in model.named_parameters(): 15 | if not value.requires_grad: 16 | continue 17 | lr = cfg.SOLVER.BASE_LR 18 | weight_decay = cfg.SOLVER.WEIGHT_DECAY 19 | 20 | for reg_lr in cfg.SOLVER.REGEXP_LR_FACTOR: 21 | regexp, lr_factor = reg_lr 22 | if re.match(regexp, key): 23 | if lr != cfg.SOLVER.BASE_LR: 24 | print("WARNING: {} matched multiple " 25 | "regular expressions!".format(key)) 26 | lr *= lr_factor 27 | 28 | if "bias" in key: 29 | lr *= cfg.SOLVER.BIAS_LR_FACTOR 30 | weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS 31 | 32 | if resume: 33 | params += [{"params": [value], "initial_lr": lr, "lr": lr, "weight_decay": weight_decay}] 34 | else: 35 | params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}] 36 | 37 | if cfg.SOLVER.OPTIMIZER == 'sgd': 38 | optimizer = torch.optim.SGD(params, momentum=cfg.SOLVER.MOMENTUM) 39 | elif cfg.SOLVER.OPTIMIZER == 'adam': 40 | optimizer = torch.optim.Adam(params) 41 | elif cfg.SOLVER.OPTIMIZER == 'adamw': 42 | # optimizer = torch.optim.AdamW(params) 43 | if hasattr(cfg, 'adam_epsilon'): 44 | optimizer = AdamW(params, eps=cfg.adam_epsilon) 45 | else: 46 | optimizer = AdamW(params) 47 | else: 48 | raise ValueError( 49 | 'Optimizer "{}" is not supported'.format(cfg.SOLVER.OPTIMIZER) 50 | ) 51 | if cfg.SOLVER.USE_LARC: 52 | optimizer = LARC(optimizer, clip=True, trust_coefficient=cfg.SOLVER.LARC_COEFFICIENT) 53 | return optimizer 54 | 55 | 56 | def make_lr_scheduler(cfg, optimizer, last_iter=-1): 57 | lr_policy = cfg.SOLVER.LR_POLICY 58 | if lr_policy not in ("multistep", "cosine", 'linear'): 59 | raise ValueError( 60 | "Only 'multistep' or 'cosine' lr policy is accepted" 61 | "got {}".format(lr_policy) 62 | ) 63 | if lr_policy == "multistep": 64 | return WarmupMultiStepLR( 65 | optimizer, 66 | cfg.SOLVER.STEPS, 67 | cfg.SOLVER.GAMMA, 68 | warmup_factor=cfg.SOLVER.WARMUP_FACTOR, 69 | warmup_iters=cfg.SOLVER.WARMUP_ITERS, 70 | warmup_method=cfg.SOLVER.WARMUP_METHOD, 71 | last_epoch=last_iter 72 | ) 73 | elif lr_policy == "cosine": 74 | return WarmupCosineAnnealingLR( 75 | optimizer, 76 | cfg.SOLVER.MAX_ITER, 77 | cfg.SOLVER.MIN_LR, 78 | warmup_factor=cfg.SOLVER.WARMUP_FACTOR, 79 | warmup_iters=cfg.SOLVER.WARMUP_ITERS, 80 | warmup_method=cfg.SOLVER.WARMUP_METHOD, 81 | last_epoch=last_iter 82 | ) 83 | elif lr_policy == "linear": 84 | return WarmupLinearSchedule( 85 | optimizer, 86 | warmup_steps=cfg.SOLVER.WARMUP_ITERS, 87 | 
t_total=cfg.SOLVER.MAX_ITER, 88 | ) 89 | -------------------------------------------------------------------------------- /src/solver/get_solver.py: -------------------------------------------------------------------------------- 1 | from .optimization import AdamW, WarmupLinearSchedule 2 | from .optimization import WarmupConstantSchedule, WarmupCosineSchedule 3 | 4 | 5 | def get_optimizer(model, weight_decay, learning_rate, adam_epsilon): 6 | no_decay = ['bias', 'LayerNorm.weight'] 7 | grouped_parameters = [ 8 | {'params': [p for n, p in model.named_parameters() if not \ 9 | any(nd in n for nd in no_decay)], 'weight_decay': weight_decay}, 10 | {'params': [p for n, p in model.named_parameters() if \ 11 | any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 12 | ] 13 | return AdamW(grouped_parameters, lr=learning_rate, eps=adam_epsilon) 14 | 15 | 16 | def get_scheduler(optimizer, scheduler_type, warmup_steps, t_total): 17 | if scheduler_type == "constant": 18 | scheduler = WarmupConstantSchedule(optimizer, warmup_steps=warmup_steps) 19 | elif scheduler_type == "linear": 20 | scheduler = WarmupLinearSchedule( 21 | optimizer, warmup_steps=warmup_steps, t_total=t_total 22 | ) 23 | elif scheduler_type == "cosine": 24 | scheduler = WarmupCosineSchedule( 25 | optimizer, warmup_steps=warmup_steps, t_total=t_total 26 | ) 27 | else: 28 | raise ValueError("Unknown scheduler type: {}".format(scheduler_type)) 29 | return scheduler 30 | 31 | -------------------------------------------------------------------------------- /src/timm/__init__.py: -------------------------------------------------------------------------------- 1 | from .version import __version__ 2 | from .models import create_model, list_models, is_model, list_modules, model_entrypoint, \ 3 | is_scriptable, is_exportable, set_scriptable, set_exportable, has_model_default_key, is_model_default_key, \ 4 | get_model_default_value, is_model_pretrained 5 | -------------------------------------------------------------------------------- /src/timm/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .auto_augment import RandAugment, AutoAugment, rand_augment_ops, auto_augment_policy,\ 2 | rand_augment_transform, auto_augment_transform 3 | from .config import resolve_data_config 4 | from .constants import * 5 | from .dataset import ImageDataset, IterableImageDataset, AugMixDataset 6 | from .dataset_factory import create_dataset 7 | from .loader import create_loader 8 | from .mixup import Mixup, FastCollateMixup 9 | from .parsers import create_parser 10 | from .real_labels import RealLabelsImagenet 11 | from .transforms import * 12 | from .transforms_factory import create_transform -------------------------------------------------------------------------------- /src/timm/data/config.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from .constants import * 3 | 4 | 5 | _logger = logging.getLogger(__name__) 6 | 7 | 8 | def resolve_data_config(args, default_cfg={}, model=None, verbose=True): 9 | new_config = {} 10 | default_cfg = default_cfg 11 | if not default_cfg and model is not None and hasattr(model, 'default_cfg'): 12 | default_cfg = model.default_cfg 13 | 14 | # Resolve input/image size 15 | in_chans = 3 16 | if 'chans' in args and args['chans'] is not None: 17 | in_chans = args['chans'] 18 | 19 | input_size = (in_chans, 224, 224) 20 | if 'input_size' in args and args['input_size'] is not None: 21 | assert 
isinstance(args['input_size'], (tuple, list)) 22 | assert len(args['input_size']) == 3 23 | input_size = tuple(args['input_size']) 24 | in_chans = input_size[0] # input_size overrides in_chans 25 | elif 'img_size' in args and args['img_size'] is not None: 26 | assert isinstance(args['img_size'], int) 27 | input_size = (in_chans, args['img_size'], args['img_size']) 28 | elif 'input_size' in default_cfg: 29 | input_size = default_cfg['input_size'] 30 | new_config['input_size'] = input_size 31 | 32 | # resolve interpolation method 33 | new_config['interpolation'] = 'bicubic' 34 | if 'interpolation' in args and args['interpolation']: 35 | new_config['interpolation'] = args['interpolation'] 36 | elif 'interpolation' in default_cfg: 37 | new_config['interpolation'] = default_cfg['interpolation'] 38 | 39 | # resolve dataset + model mean for normalization 40 | new_config['mean'] = IMAGENET_DEFAULT_MEAN 41 | if 'mean' in args and args['mean'] is not None: 42 | mean = tuple(args['mean']) 43 | if len(mean) == 1: 44 | mean = tuple(list(mean) * in_chans) 45 | else: 46 | assert len(mean) == in_chans 47 | new_config['mean'] = mean 48 | elif 'mean' in default_cfg: 49 | new_config['mean'] = default_cfg['mean'] 50 | 51 | # resolve dataset + model std deviation for normalization 52 | new_config['std'] = IMAGENET_DEFAULT_STD 53 | if 'std' in args and args['std'] is not None: 54 | std = tuple(args['std']) 55 | if len(std) == 1: 56 | std = tuple(list(std) * in_chans) 57 | else: 58 | assert len(std) == in_chans 59 | new_config['std'] = std 60 | elif 'std' in default_cfg: 61 | new_config['std'] = default_cfg['std'] 62 | 63 | # resolve default crop percentage 64 | new_config['crop_pct'] = DEFAULT_CROP_PCT 65 | if 'crop_pct' in args and args['crop_pct'] is not None: 66 | new_config['crop_pct'] = args['crop_pct'] 67 | elif 'crop_pct' in default_cfg: 68 | new_config['crop_pct'] = default_cfg['crop_pct'] 69 | 70 | if verbose: 71 | _logger.info('Data processing configuration for current model + dataset:') 72 | for n, v in new_config.items(): 73 | _logger.info('\t%s: %s' % (n, str(v))) 74 | 75 | return new_config 76 | -------------------------------------------------------------------------------- /src/timm/data/constants.py: -------------------------------------------------------------------------------- 1 | DEFAULT_CROP_PCT = 0.875 2 | IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) 3 | IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) 4 | IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5) 5 | IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5) 6 | IMAGENET_DPN_MEAN = (124 / 255, 117 / 255, 104 / 255) 7 | IMAGENET_DPN_STD = tuple([1 / (.0167 * 255)] * 3) 8 | -------------------------------------------------------------------------------- /src/timm/data/dataset_factory.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from .dataset import IterableImageDataset, ImageDataset 4 | 5 | 6 | def _search_split(root, split): 7 | # look for sub-folder with name of split in root and use that if it exists 8 | split_name = split.split('[')[0] 9 | try_root = os.path.join(root, split_name) 10 | if os.path.exists(try_root): 11 | return try_root 12 | if split_name == 'validation': 13 | try_root = os.path.join(root, 'val') 14 | if os.path.exists(try_root): 15 | return try_root 16 | return root 17 | 18 | 19 | def create_dataset(name, root, split='validation', search_split=True, is_training=False, batch_size=None, **kwargs): 20 | name = name.lower() 21 | if name.startswith('tfds'): 22 | ds = 
IterableImageDataset( 23 | root, parser=name, split=split, is_training=is_training, batch_size=batch_size, **kwargs) 24 | else: 25 | # FIXME support more advance split cfg for ImageFolder/Tar datasets in the future 26 | if search_split and os.path.isdir(root): 27 | root = _search_split(root, split) 28 | ds = ImageDataset(root, parser=name, **kwargs) 29 | return ds 30 | -------------------------------------------------------------------------------- /src/timm/data/distributed_sampler.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.utils.data import Sampler 4 | import torch.distributed as dist 5 | 6 | 7 | class OrderedDistributedSampler(Sampler): 8 | """Sampler that restricts data loading to a subset of the dataset. 9 | It is especially useful in conjunction with 10 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each 11 | process can pass a DistributedSampler instance as a DataLoader sampler, 12 | and load a subset of the original dataset that is exclusive to it. 13 | .. note:: 14 | Dataset is assumed to be of constant size. 15 | Arguments: 16 | dataset: Dataset used for sampling. 17 | num_replicas (optional): Number of processes participating in 18 | distributed training. 19 | rank (optional): Rank of the current process within num_replicas. 20 | """ 21 | 22 | def __init__(self, dataset, num_replicas=None, rank=None): 23 | if num_replicas is None: 24 | if not dist.is_available(): 25 | raise RuntimeError("Requires distributed package to be available") 26 | num_replicas = dist.get_world_size() 27 | if rank is None: 28 | if not dist.is_available(): 29 | raise RuntimeError("Requires distributed package to be available") 30 | rank = dist.get_rank() 31 | self.dataset = dataset 32 | self.num_replicas = num_replicas 33 | self.rank = rank 34 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) 35 | self.total_size = self.num_samples * self.num_replicas 36 | 37 | def __iter__(self): 38 | indices = list(range(len(self.dataset))) 39 | 40 | # add extra samples to make it evenly divisible 41 | indices += indices[:(self.total_size - len(indices))] 42 | assert len(indices) == self.total_size 43 | 44 | # subsample 45 | indices = indices[self.rank:self.total_size:self.num_replicas] 46 | assert len(indices) == self.num_samples 47 | 48 | return iter(indices) 49 | 50 | def __len__(self): 51 | return self.num_samples 52 | -------------------------------------------------------------------------------- /src/timm/data/parsers/__init__.py: -------------------------------------------------------------------------------- 1 | from .parser_factory import create_parser 2 | -------------------------------------------------------------------------------- /src/timm/data/parsers/class_map.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | def load_class_map(filename, root=''): 5 | class_map_path = filename 6 | if not os.path.exists(class_map_path): 7 | class_map_path = os.path.join(root, filename) 8 | assert os.path.exists(class_map_path), 'Cannot locate specified class map file (%s)' % filename 9 | class_map_ext = os.path.splitext(filename)[-1].lower() 10 | if class_map_ext == '.txt': 11 | with open(class_map_path) as f: 12 | class_to_idx = {v.strip(): k for k, v in enumerate(f)} 13 | else: 14 | assert False, 'Unsupported class map extension' 15 | return class_to_idx 16 | 17 | 
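A minimal usage sketch for load_class_map above (the file name, root path, and class names are illustrative assumptions, not files that ship with this repository): the loader expects a plain-text map with one class name per line and returns a name-to-index dict, which ParserImageFolder below can consume through its class_map argument so that label indices stay fixed regardless of which class folders happen to be present.

from src.timm.data.parsers.class_map import load_class_map

# assumes a hypothetical class_map.txt containing the two lines "dog" and "cat"
class_to_idx = load_class_map('class_map.txt', root='/data/imagenet')  # hypothetical paths
print(class_to_idx)  # {'dog': 0, 'cat': 1}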
-------------------------------------------------------------------------------- /src/timm/data/parsers/constants.py: -------------------------------------------------------------------------------- 1 | IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg') 2 | -------------------------------------------------------------------------------- /src/timm/data/parsers/parser.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | 3 | 4 | class Parser: 5 | def __init__(self): 6 | pass 7 | 8 | @abstractmethod 9 | def _filename(self, index, basename=False, absolute=False): 10 | pass 11 | 12 | def filename(self, index, basename=False, absolute=False): 13 | return self._filename(index, basename=basename, absolute=absolute) 14 | 15 | def filenames(self, basename=False, absolute=False): 16 | return [self._filename(index, basename=basename, absolute=absolute) for index in range(len(self))] 17 | 18 | -------------------------------------------------------------------------------- /src/timm/data/parsers/parser_factory.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from .parser_image_folder import ParserImageFolder 4 | from .parser_image_tar import ParserImageTar 5 | from .parser_image_in_tar import ParserImageInTar 6 | 7 | 8 | def create_parser(name, root, split='train', **kwargs): 9 | name = name.lower() 10 | name = name.split('/', 2) 11 | prefix = '' 12 | if len(name) > 1: 13 | prefix = name[0] 14 | name = name[-1] 15 | 16 | # FIXME improve the selection right now just tfds prefix or fallback path, will need options to 17 | # explicitly select other options shortly 18 | if prefix == 'tfds': 19 | from .parser_tfds import ParserTfds # defer tensorflow import 20 | parser = ParserTfds(root, name, split=split, shuffle=kwargs.pop('shuffle', False), **kwargs) 21 | else: 22 | assert os.path.exists(root) 23 | # default fallback path (backwards compat), use image tar if root is a .tar file, otherwise image folder 24 | # FIXME support split here, in parser? 25 | if os.path.isfile(root) and os.path.splitext(root)[1] == '.tar': 26 | parser = ParserImageInTar(root, **kwargs) 27 | else: 28 | parser = ParserImageFolder(root, **kwargs) 29 | return parser 30 | -------------------------------------------------------------------------------- /src/timm/data/parsers/parser_image_folder.py: -------------------------------------------------------------------------------- 1 | """ A dataset parser that reads images from folders 2 | 3 | Folders are scannerd recursively to find image files. Labels are based 4 | on the folder hierarchy, just leaf folders by default. 
5 | 6 | Hacked together by / Copyright 2020 Ross Wightman 7 | """ 8 | import os 9 | 10 | from src.timm.utils.misc import natural_key 11 | 12 | from .parser import Parser 13 | from .class_map import load_class_map 14 | from .constants import IMG_EXTENSIONS 15 | 16 | 17 | def find_images_and_targets(folder, types=IMG_EXTENSIONS, class_to_idx=None, leaf_name_only=True, sort=True): 18 | labels = [] 19 | filenames = [] 20 | for root, subdirs, files in os.walk(folder, topdown=False, followlinks=True): 21 | rel_path = os.path.relpath(root, folder) if (root != folder) else '' 22 | label = os.path.basename(rel_path) if leaf_name_only else rel_path.replace(os.path.sep, '_') 23 | for f in files: 24 | base, ext = os.path.splitext(f) 25 | if ext.lower() in types: 26 | filenames.append(os.path.join(root, f)) 27 | labels.append(label) 28 | if class_to_idx is None: 29 | # building class index 30 | unique_labels = set(labels) 31 | sorted_labels = list(sorted(unique_labels, key=natural_key)) 32 | class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)} 33 | images_and_targets = [(f, class_to_idx[l]) for f, l in zip(filenames, labels) if l in class_to_idx] 34 | if sort: 35 | images_and_targets = sorted(images_and_targets, key=lambda k: natural_key(k[0])) 36 | return images_and_targets, class_to_idx 37 | 38 | 39 | class ParserImageFolder(Parser): 40 | 41 | def __init__( 42 | self, 43 | root, 44 | class_map=''): 45 | super().__init__() 46 | 47 | self.root = root 48 | class_to_idx = None 49 | if class_map: 50 | class_to_idx = load_class_map(class_map, root) 51 | self.samples, self.class_to_idx = find_images_and_targets(root, class_to_idx=class_to_idx) 52 | if len(self.samples) == 0: 53 | raise RuntimeError( 54 | f'Found 0 images in subfolders of {root}. Supported image extensions are {", ".join(IMG_EXTENSIONS)}') 55 | 56 | def __getitem__(self, index): 57 | path, target = self.samples[index] 58 | return open(path, 'rb'), target 59 | 60 | def __len__(self): 61 | return len(self.samples) 62 | 63 | def _filename(self, index, basename=False, absolute=False): 64 | filename = self.samples[index][0] 65 | if basename: 66 | filename = os.path.basename(filename) 67 | elif not absolute: 68 | filename = os.path.relpath(filename, self.root) 69 | return filename 70 | -------------------------------------------------------------------------------- /src/timm/data/parsers/parser_image_tar.py: -------------------------------------------------------------------------------- 1 | """ A dataset parser that reads single tarfile based datasets 2 | 3 | This parser can read datasets consisting of a single tarfile containing images. 4 | I am planning to deprecate it in favour of ParserImageInTar.
5 | 6 | Hacked together by / Copyright 2020 Ross Wightman 7 | """ 8 | import os 9 | import tarfile 10 | 11 | from .parser import Parser 12 | from .class_map import load_class_map 13 | from .constants import IMG_EXTENSIONS 14 | from src.timm.utils.misc import natural_key 15 | 16 | 17 | def extract_tarinfo(tarfile, class_to_idx=None, sort=True): 18 | files = [] 19 | labels = [] 20 | for ti in tarfile.getmembers(): 21 | if not ti.isfile(): 22 | continue 23 | dirname, basename = os.path.split(ti.path) 24 | label = os.path.basename(dirname) 25 | ext = os.path.splitext(basename)[1] 26 | if ext.lower() in IMG_EXTENSIONS: 27 | files.append(ti) 28 | labels.append(label) 29 | if class_to_idx is None: 30 | unique_labels = set(labels) 31 | sorted_labels = list(sorted(unique_labels, key=natural_key)) 32 | class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)} 33 | tarinfo_and_targets = [(f, class_to_idx[l]) for f, l in zip(files, labels) if l in class_to_idx] 34 | if sort: 35 | tarinfo_and_targets = sorted(tarinfo_and_targets, key=lambda k: natural_key(k[0].path)) 36 | return tarinfo_and_targets, class_to_idx 37 | 38 | 39 | class ParserImageTar(Parser): 40 | """ Single tarfile dataset where classes are mapped to folders within tar 41 | NOTE: This class is being deprecated in favour of the more capable ParserImageInTar that can 42 | operate on folders of tars or tars in tars. 43 | """ 44 | def __init__(self, root, class_map=''): 45 | super().__init__() 46 | 47 | class_to_idx = None 48 | if class_map: 49 | class_to_idx = load_class_map(class_map, root) 50 | assert os.path.isfile(root) 51 | self.root = root 52 | 53 | with tarfile.open(root) as tf: # cannot keep this open across processes, reopen later 54 | self.samples, self.class_to_idx = extract_tarinfo(tf, class_to_idx) 55 | self.imgs = self.samples 56 | self.tarfile = None # lazy init in __getitem__ 57 | 58 | def __getitem__(self, index): 59 | if self.tarfile is None: 60 | self.tarfile = tarfile.open(self.root) 61 | tarinfo, target = self.samples[index] 62 | fileobj = self.tarfile.extractfile(tarinfo) 63 | return fileobj, target 64 | 65 | def __len__(self): 66 | return len(self.samples) 67 | 68 | def _filename(self, index, basename=False, absolute=False): 69 | filename = self.samples[index][0].name 70 | if basename: 71 | filename = os.path.basename(filename) 72 | return filename 73 | -------------------------------------------------------------------------------- /src/timm/data/random_erasing.py: -------------------------------------------------------------------------------- 1 | """ Random Erasing (Cutout) 2 | 3 | Originally inspired by impl at https://github.com/zhunzhong07/Random-Erasing, Apache 2.0 4 | Copyright Zhun Zhong & Liang Zheng 5 | 6 | Hacked together by / Copyright 2020 Ross Wightman 7 | """ 8 | import random 9 | import math 10 | import torch 11 | 12 | 13 | def _get_pixels(per_pixel, rand_color, patch_size, dtype=torch.float32, device='cuda'): 14 | # NOTE I've seen CUDA illegal memory access errors being caused by the normal_() 15 | # paths, flip the order so normal is run on CPU if this becomes a problem 16 | # Issue has been fixed in master https://github.com/pytorch/pytorch/issues/19508 17 | if per_pixel: 18 | return torch.empty(patch_size, dtype=dtype, device=device).normal_() 19 | elif rand_color: 20 | return torch.empty((patch_size[0], 1, 1), dtype=dtype, device=device).normal_() 21 | else: 22 | return torch.zeros((patch_size[0], 1, 1), dtype=dtype, device=device) 23 | 24 | 25 | class RandomErasing: 26 | """ Randomly 
selects a rectangle region in an image and erases its pixels. 27 | 'Random Erasing Data Augmentation' by Zhong et al. 28 | See https://arxiv.org/pdf/1708.04896.pdf 29 | 30 | This variant of RandomErasing is intended to be applied to either a batch 31 | or single image tensor after it has been normalized by dataset mean and std. 32 | Args: 33 | probability: Probability that the Random Erasing operation will be performed. 34 | min_area: Minimum percentage of erased area wrt input image area. 35 | max_area: Maximum percentage of erased area wrt input image area. 36 | min_aspect: Minimum aspect ratio of erased area. 37 | mode: pixel color mode, one of 'const', 'rand', or 'pixel' 38 | 'const' - erase block is constant color of 0 for all channels 39 | 'rand' - erase block is same per-channel random (normal) color 40 | 'pixel' - erase block is per-pixel random (normal) color 41 | max_count: maximum number of erasing blocks per image, area per box is scaled by count. 42 | per-image count is randomly chosen between 1 and this value. 43 | """ 44 | 45 | def __init__( 46 | self, 47 | probability=0.5, min_area=0.02, max_area=1/3, min_aspect=0.3, max_aspect=None, 48 | mode='const', min_count=1, max_count=None, num_splits=0, device='cuda'): 49 | self.probability = probability 50 | self.min_area = min_area 51 | self.max_area = max_area 52 | max_aspect = max_aspect or 1 / min_aspect 53 | self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect)) 54 | self.min_count = min_count 55 | self.max_count = max_count or min_count 56 | self.num_splits = num_splits 57 | mode = mode.lower() 58 | self.rand_color = False 59 | self.per_pixel = False 60 | if mode == 'rand': 61 | self.rand_color = True # per block random normal 62 | elif mode == 'pixel': 63 | self.per_pixel = True # per pixel random normal 64 | else: 65 | assert not mode or mode == 'const' 66 | self.device = device 67 | 68 | def _erase(self, img, chan, img_h, img_w, dtype): 69 | if random.random() > self.probability: 70 | return 71 | area = img_h * img_w 72 | count = self.min_count if self.min_count == self.max_count else \ 73 | random.randint(self.min_count, self.max_count) 74 | for _ in range(count): 75 | for attempt in range(10): 76 | target_area = random.uniform(self.min_area, self.max_area) * area / count 77 | aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) 78 | h = int(round(math.sqrt(target_area * aspect_ratio))) 79 | w = int(round(math.sqrt(target_area / aspect_ratio))) 80 | if w < img_w and h < img_h: 81 | top = random.randint(0, img_h - h) 82 | left = random.randint(0, img_w - w) 83 | img[:, top:top + h, left:left + w] = _get_pixels( 84 | self.per_pixel, self.rand_color, (chan, h, w), 85 | dtype=dtype, device=self.device) 86 | break 87 | 88 | def __call__(self, input): 89 | if len(input.size()) == 3: 90 | self._erase(input, *input.size(), input.dtype) 91 | else: 92 | batch_size, chan, img_h, img_w = input.size() 93 | # skip first slice of batch if num_splits is set (for clean portion of samples) 94 | batch_start = batch_size // self.num_splits if self.num_splits > 1 else 0 95 | for i in range(batch_start, batch_size): 96 | self._erase(input[i], chan, img_h, img_w, input.dtype) 97 | return input 98 | -------------------------------------------------------------------------------- /src/timm/data/real_labels.py: -------------------------------------------------------------------------------- 1 | """ Real labels evaluator for ImageNet 2 | Paper: `Are we done with ImageNet?` - https://arxiv.org/abs/2006.07159 3 | Based on 
Numpy example at https://github.com/google-research/reassessed-imagenet 4 | 5 | Hacked together by / Copyright 2020 Ross Wightman 6 | """ 7 | import os 8 | import json 9 | import numpy as np 10 | 11 | 12 | class RealLabelsImagenet: 13 | 14 | def __init__(self, filenames, real_json='real.json', topk=(1, 5)): 15 | with open(real_json) as real_labels: 16 | real_labels = json.load(real_labels) 17 | real_labels = {f'ILSVRC2012_val_{i + 1:08d}.JPEG': labels for i, labels in enumerate(real_labels)} 18 | self.real_labels = real_labels 19 | self.filenames = filenames 20 | assert len(self.filenames) == len(self.real_labels) 21 | self.topk = topk 22 | self.is_correct = {k: [] for k in topk} 23 | self.sample_idx = 0 24 | 25 | def add_result(self, output): 26 | maxk = max(self.topk) 27 | _, pred_batch = output.topk(maxk, 1, True, True) 28 | pred_batch = pred_batch.cpu().numpy() 29 | for pred in pred_batch: 30 | filename = self.filenames[self.sample_idx] 31 | filename = os.path.basename(filename) 32 | if self.real_labels[filename]: 33 | for k in self.topk: 34 | self.is_correct[k].append( 35 | any([p in self.real_labels[filename] for p in pred[:k]])) 36 | self.sample_idx += 1 37 | 38 | def get_accuracy(self, k=None): 39 | if k is None: 40 | return {k: float(np.mean(self.is_correct[k])) * 100 for k in self.topk} 41 | else: 42 | return float(np.mean(self.is_correct[k])) * 100 43 | -------------------------------------------------------------------------------- /src/timm/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .byoanet import * 2 | from .byobnet import * 3 | from .cspnet import * 4 | from .densenet import * 5 | from .dla import * 6 | from .dpn import * 7 | from .efficientnet import * 8 | from .ghostnet import * 9 | from .gluon_resnet import * 10 | from .gluon_xception import * 11 | from .hardcorenas import * 12 | from .hrnet import * 13 | from .inception_resnet_v2 import * 14 | from .inception_v3 import * 15 | from .inception_v4 import * 16 | from .mobilenetv3 import * 17 | from .nasnet import * 18 | from .nfnet import * 19 | from .pit import * 20 | from .pnasnet import * 21 | from .regnet import * 22 | from .res2net import * 23 | from .resnest import * 24 | from .resnet import * 25 | from .resnetv2 import * 26 | from .rexnet import * 27 | from .selecsls import * 28 | from .senet import * 29 | from .sknet import * 30 | from .swin_transformer import * 31 | from .tnt import * 32 | from .tresnet import * 33 | from .vgg import * 34 | from .vision_transformer import * 35 | from .vision_transformer_hybrid import * 36 | from .vovnet import * 37 | from .xception import * 38 | from .xception_aligned import * 39 | 40 | from .factory import create_model, split_model_name, safe_model_name 41 | from .helpers import load_checkpoint, resume_checkpoint, model_parameters 42 | from .layers import TestTimePoolHead, apply_test_time_pool 43 | from .layers import convert_splitbn_model 44 | from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit 45 | from .registry import register_model, model_entrypoint, list_models, is_model, list_modules, is_model_in_modules,\ 46 | has_model_default_key, is_model_default_key, get_model_default_value, is_model_pretrained 47 | -------------------------------------------------------------------------------- /src/timm/models/factory.py: -------------------------------------------------------------------------------- 1 | from .registry import is_model, is_model_in_modules, 
model_entrypoint 2 | from .helpers import load_checkpoint 3 | from .layers import set_layer_config 4 | from .hub import load_model_config_from_hf 5 | 6 | 7 | def split_model_name(model_name): 8 | model_split = model_name.split(':', 1) 9 | if len(model_split) == 1: 10 | return '', model_split[0] 11 | else: 12 | source_name, model_name = model_split 13 | assert source_name in ('timm', 'hf_hub') 14 | return source_name, model_name 15 | 16 | 17 | def safe_model_name(model_name, remove_source=True): 18 | def make_safe(name): 19 | return ''.join(c if c.isalnum() else '_' for c in name).rstrip('_') 20 | if remove_source: 21 | model_name = split_model_name(model_name)[-1] 22 | return make_safe(model_name) 23 | 24 | 25 | def create_model( 26 | model_name, 27 | pretrained=False, 28 | checkpoint_path='', 29 | scriptable=None, 30 | exportable=None, 31 | no_jit=None, 32 | **kwargs): 33 | """Create a model 34 | 35 | Args: 36 | model_name (str): name of model to instantiate 37 | pretrained (bool): load pretrained ImageNet-1k weights if true 38 | checkpoint_path (str): path of checkpoint to load after model is initialized 39 | scriptable (bool): set layer config so that model is jit scriptable (not working for all models yet) 40 | exportable (bool): set layer config so that model is traceable / ONNX exportable (not fully impl/obeyed yet) 41 | no_jit (bool): set layer config so that model doesn't utilize jit scripted layers (so far activations only) 42 | 43 | Keyword Args: 44 | drop_rate (float): dropout rate for training (default: 0.0) 45 | global_pool (str): global pool type (default: 'avg') 46 | **: other kwargs are model specific 47 | """ 48 | source_name, model_name = split_model_name(model_name) 49 | 50 | # Only EfficientNet and MobileNetV3 models have support for batchnorm params or drop_connect_rate passed as args 51 | is_efficientnet = is_model_in_modules(model_name, ['efficientnet', 'mobilenetv3']) 52 | if not is_efficientnet: 53 | kwargs.pop('bn_tf', None) 54 | kwargs.pop('bn_momentum', None) 55 | kwargs.pop('bn_eps', None) 56 | 57 | # handle backwards compat with drop_connect -> drop_path change 58 | drop_connect_rate = kwargs.pop('drop_connect_rate', None) 59 | if drop_connect_rate is not None and kwargs.get('drop_path_rate', None) is None: 60 | print("WARNING: 'drop_connect' as an argument is deprecated, please use 'drop_path'." 61 | " Setting drop_path to %f." % drop_connect_rate) 62 | kwargs['drop_path_rate'] = drop_connect_rate 63 | 64 | # Parameters that aren't supported by all models or are intended to only override model defaults if set 65 | # should default to None in command line args/cfg. Remove them if they are present and not set so that 66 | # non-supporting models don't break and default args remain in effect. 67 | kwargs = {k: v for k, v in kwargs.items() if v is not None} 68 | 69 | if source_name == 'hf_hub': 70 | # For model names specified in the form `hf_hub:path/architecture_name#revision`, 71 | # load model weights + default_cfg from Hugging Face hub. 
72 | hf_default_cfg, model_name = load_model_config_from_hf(model_name) 73 | kwargs['external_default_cfg'] = hf_default_cfg # FIXME revamp default_cfg interface someday 74 | 75 | if is_model(model_name): 76 | create_fn = model_entrypoint(model_name) 77 | else: 78 | raise RuntimeError('Unknown model (%s)' % model_name) 79 | 80 | with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit): 81 | model = create_fn(pretrained=pretrained, **kwargs) 82 | 83 | if checkpoint_path: 84 | load_checkpoint(model, checkpoint_path) 85 | 86 | return model 87 | -------------------------------------------------------------------------------- /src/timm/models/hub.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | from functools import partial 5 | from typing import Union, Optional 6 | 7 | import torch 8 | from torch.hub import load_state_dict_from_url, download_url_to_file, urlparse, HASH_REGEX 9 | try: 10 | from torch.hub import get_dir 11 | except ImportError: 12 | from torch.hub import _get_torch_home as get_dir 13 | 14 | from src.timm import __version__ 15 | try: 16 | from huggingface_hub import hf_hub_url 17 | from huggingface_hub import cached_download 18 | cached_download = partial(cached_download, library_name="timm", library_version=__version__) 19 | except ImportError: 20 | hf_hub_url = None 21 | cached_download = None 22 | 23 | _logger = logging.getLogger(__name__) 24 | 25 | 26 | def get_cache_dir(child_dir=''): 27 | """ 28 | Returns the location of the directory where models are cached (and creates it if necessary). 29 | """ 30 | # Issue warning to move data if old env is set 31 | if os.getenv('TORCH_MODEL_ZOO'): 32 | _logger.warning('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead') 33 | 34 | hub_dir = get_dir() 35 | child_dir = () if not child_dir else (child_dir,) 36 | model_dir = os.path.join(hub_dir, 'checkpoints', *child_dir) 37 | os.makedirs(model_dir, exist_ok=True) 38 | return model_dir 39 | 40 | 41 | def download_cached_file(url, check_hash=True, progress=False): 42 | parts = urlparse(url) 43 | filename = os.path.basename(parts.path) 44 | cached_file = os.path.join(get_cache_dir(), filename) 45 | if not os.path.exists(cached_file): 46 | _logger.info('Downloading: "{}" to {}\n'.format(url, cached_file)) 47 | hash_prefix = None 48 | if check_hash: 49 | r = HASH_REGEX.search(filename) # r is Optional[Match[str]] 50 | hash_prefix = r.group(1) if r else None 51 | download_url_to_file(url, cached_file, hash_prefix, progress=progress) 52 | return cached_file 53 | 54 | 55 | def has_hf_hub(necessary=False): 56 | if hf_hub_url is None and necessary: 57 | # if no HF Hub module installed and it is necessary to continue, raise error 58 | raise RuntimeError( 59 | 'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.') 60 | return hf_hub_url is not None 61 | 62 | 63 | def hf_split(hf_id): 64 | rev_split = hf_id.split('@') 65 | assert 0 < len(rev_split) <= 2, 'hf_hub id should only contain one @ character to identify revision.' 
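    # Behaviour sketch with illustrative ids (not part of the original file):
    #   hf_split('org/model')        -> ('org/model', None)
    #   hf_split('org/model@abc123') -> ('org/model', 'abc123')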
66 | hf_model_id = rev_split[0] 67 | hf_revision = rev_split[-1] if len(rev_split) > 1 else None 68 | return hf_model_id, hf_revision 69 | 70 | 71 | def load_cfg_from_json(json_file: Union[str, os.PathLike]): 72 | with open(json_file, "r", encoding="utf-8") as reader: 73 | text = reader.read() 74 | return json.loads(text) 75 | 76 | 77 | def _download_from_hf(model_id: str, filename: str): 78 | hf_model_id, hf_revision = hf_split(model_id) 79 | url = hf_hub_url(hf_model_id, filename, revision=hf_revision) 80 | return cached_download(url, cache_dir=get_cache_dir('hf')) 81 | 82 | 83 | def load_model_config_from_hf(model_id: str): 84 | assert has_hf_hub(True) 85 | cached_file = _download_from_hf(model_id, 'config.json') 86 | default_cfg = load_cfg_from_json(cached_file) 87 | default_cfg['hf_hub'] = model_id # insert hf_hub id for pretrained weight load during model creation 88 | model_name = default_cfg.get('architecture') 89 | return default_cfg, model_name 90 | 91 | 92 | def load_state_dict_from_hf(model_id: str): 93 | assert has_hf_hub(True) 94 | cached_file = _download_from_hf(model_id, 'pytorch_model.bin') 95 | state_dict = torch.load(cached_file, map_location='cpu') 96 | return state_dict 97 | -------------------------------------------------------------------------------- /src/timm/models/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .activations import * 2 | from .adaptive_avgmax_pool import \ 3 | adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d 4 | from .anti_aliasing import AntiAliasDownsampleLayer 5 | from .blur_pool import BlurPool2d 6 | from .classifier import ClassifierHead, create_classifier 7 | from .cond_conv2d import CondConv2d, get_condconv_initializer 8 | from .config import is_exportable, is_scriptable, is_no_jit, set_exportable, set_scriptable, set_no_jit,\ 9 | set_layer_config 10 | from .conv2d_same import Conv2dSame, conv2d_same 11 | from .conv_bn_act import ConvBnAct 12 | from .create_act import create_act_layer, get_act_layer, get_act_fn 13 | from .create_attn import get_attn, create_attn 14 | from .create_conv2d import create_conv2d 15 | from .create_norm_act import get_norm_act_layer, create_norm_act, convert_norm_act 16 | from .create_self_attn import get_self_attn, create_self_attn 17 | from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path 18 | from .eca import EcaModule, CecaModule 19 | from .evo_norm import EvoNormBatch2d, EvoNormSample2d 20 | from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible 21 | from .inplace_abn import InplaceAbn 22 | from .linear import Linear 23 | from .mixed_conv2d import MixedConv2d 24 | from .norm import GroupNorm 25 | from .norm_act import BatchNormAct2d, GroupNormAct 26 | from .padding import get_padding, get_same_padding, pad_same 27 | from .pool2d_same import AvgPool2dSame, create_pool2d 28 | from .se import SEModule 29 | from .selective_kernel import SelectiveKernelConv 30 | from .separable_conv import SeparableConv2d, SeparableConvBnAct 31 | from .space_to_depth import SpaceToDepthModule 32 | from .split_attn import SplitAttnConv2d 33 | from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model 34 | from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame 35 | from .test_time_pool import TestTimePoolHead, apply_test_time_pool 36 | from .weight_init import trunc_normal_, variance_scaling_, lecun_normal_ 37 | 
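# Minimal usage sketch of the re-exported layer factories (argument values and
# shapes are illustrative; assumes torch is installed and this package is
# importable as `src.timm.models.layers`):
#   >>> import torch
#   >>> from src.timm.models.layers import create_conv2d, SelectAdaptivePool2d
#   >>> conv = create_conv2d(16, 32, kernel_size=3, stride=2, padding='same')
#   >>> pool = SelectAdaptivePool2d(pool_type='avg', flatten=True)
#   >>> pool(conv(torch.randn(2, 16, 32, 32))).shape
#   torch.Size([2, 32])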
-------------------------------------------------------------------------------- /src/timm/models/layers/activations.py: -------------------------------------------------------------------------------- 1 | """ Activations 2 | 3 | A collection of activations fn and modules with a common interface so that they can 4 | easily be swapped. All have an `inplace` arg even if not used. 5 | 6 | Hacked together by / Copyright 2020 Ross Wightman 7 | """ 8 | 9 | import torch 10 | from torch import nn as nn 11 | from torch.nn import functional as F 12 | 13 | 14 | def swish(x, inplace: bool = False): 15 | """Swish - Described in: https://arxiv.org/abs/1710.05941 16 | """ 17 | return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid()) 18 | 19 | 20 | class Swish(nn.Module): 21 | def __init__(self, inplace: bool = False): 22 | super(Swish, self).__init__() 23 | self.inplace = inplace 24 | 25 | def forward(self, x): 26 | return swish(x, self.inplace) 27 | 28 | 29 | def mish(x, inplace: bool = False): 30 | """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 31 | NOTE: I don't have a working inplace variant 32 | """ 33 | return x.mul(F.softplus(x).tanh()) 34 | 35 | 36 | class Mish(nn.Module): 37 | """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 38 | """ 39 | def __init__(self, inplace: bool = False): 40 | super(Mish, self).__init__() 41 | 42 | def forward(self, x): 43 | return mish(x) 44 | 45 | 46 | def sigmoid(x, inplace: bool = False): 47 | return x.sigmoid_() if inplace else x.sigmoid() 48 | 49 | 50 | # PyTorch has this, but not with a consistent inplace argmument interface 51 | class Sigmoid(nn.Module): 52 | def __init__(self, inplace: bool = False): 53 | super(Sigmoid, self).__init__() 54 | self.inplace = inplace 55 | 56 | def forward(self, x): 57 | return x.sigmoid_() if self.inplace else x.sigmoid() 58 | 59 | 60 | def tanh(x, inplace: bool = False): 61 | return x.tanh_() if inplace else x.tanh() 62 | 63 | 64 | # PyTorch has this, but not with a consistent inplace argmument interface 65 | class Tanh(nn.Module): 66 | def __init__(self, inplace: bool = False): 67 | super(Tanh, self).__init__() 68 | self.inplace = inplace 69 | 70 | def forward(self, x): 71 | return x.tanh_() if self.inplace else x.tanh() 72 | 73 | 74 | def hard_swish(x, inplace: bool = False): 75 | inner = F.relu6(x + 3.).div_(6.) 76 | return x.mul_(inner) if inplace else x.mul(inner) 77 | 78 | 79 | class HardSwish(nn.Module): 80 | def __init__(self, inplace: bool = False): 81 | super(HardSwish, self).__init__() 82 | self.inplace = inplace 83 | 84 | def forward(self, x): 85 | return hard_swish(x, self.inplace) 86 | 87 | 88 | def hard_sigmoid(x, inplace: bool = False): 89 | if inplace: 90 | return x.add_(3.).clamp_(0., 6.).div_(6.) 91 | else: 92 | return F.relu6(x + 3.) / 6. 
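    # Sketch of the interface shared by the fns above (values are illustrative):
    #   x = torch.tensor([-3., 0., 3.])
    #   hard_sigmoid(x)                 # out-of-place -> [0., 0.5, 1.]
    #   hard_sigmoid(x, inplace=True)   # mutates x in place and returns it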
93 | 94 | 95 | class HardSigmoid(nn.Module): 96 | def __init__(self, inplace: bool = False): 97 | super(HardSigmoid, self).__init__() 98 | self.inplace = inplace 99 | 100 | def forward(self, x): 101 | return hard_sigmoid(x, self.inplace) 102 | 103 | 104 | def hard_mish(x, inplace: bool = False): 105 | """ Hard Mish 106 | Experimental, based on notes by Mish author Diganta Misra at 107 | https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md 108 | """ 109 | if inplace: 110 | return x.mul_(0.5 * (x + 2).clamp(min=0, max=2)) 111 | else: 112 | return 0.5 * x * (x + 2).clamp(min=0, max=2) 113 | 114 | 115 | class HardMish(nn.Module): 116 | def __init__(self, inplace: bool = False): 117 | super(HardMish, self).__init__() 118 | self.inplace = inplace 119 | 120 | def forward(self, x): 121 | return hard_mish(x, self.inplace) 122 | 123 | 124 | class PReLU(nn.PReLU): 125 | """Applies PReLU (w/ dummy inplace arg) 126 | """ 127 | def __init__(self, num_parameters: int = 1, init: float = 0.25, inplace: bool = False) -> None: 128 | super(PReLU, self).__init__(num_parameters=num_parameters, init=init) 129 | 130 | def forward(self, input: torch.Tensor) -> torch.Tensor: 131 | return F.prelu(input, self.weight) 132 | 133 | 134 | def gelu(x: torch.Tensor, inplace: bool = False) -> torch.Tensor: 135 | return F.gelu(x) 136 | 137 | 138 | class GELU(nn.Module): 139 | """Applies the Gaussian Error Linear Units function (w/ dummy inplace arg) 140 | """ 141 | def __init__(self, inplace: bool = False): 142 | super(GELU, self).__init__() 143 | 144 | def forward(self, input: torch.Tensor) -> torch.Tensor: 145 | return F.gelu(input) 146 | -------------------------------------------------------------------------------- /src/timm/models/layers/activations_jit.py: -------------------------------------------------------------------------------- 1 | """ Activations 2 | 3 | A collection of jit-scripted activations fn and modules with a common interface so that they can 4 | easily be swapped. All have an `inplace` arg even if not used. 5 | 6 | All jit scripted activations are lacking in-place variations on purpose, scripted kernel fusion does not 7 | currently work across in-place op boundaries, thus performance is equal to or less than the non-scripted 8 | versions if they contain in-place ops. 9 | 10 | Hacked together by / Copyright 2020 Ross Wightman 11 | """ 12 | 13 | import torch 14 | from torch import nn as nn 15 | from torch.nn import functional as F 16 | 17 | 18 | @torch.jit.script 19 | def swish_jit(x, inplace: bool = False): 20 | """Swish - Described in: https://arxiv.org/abs/1710.05941 21 | """ 22 | return x.mul(x.sigmoid()) 23 | 24 | 25 | @torch.jit.script 26 | def mish_jit(x, _inplace: bool = False): 27 | """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 28 | """ 29 | return x.mul(F.softplus(x).tanh()) 30 | 31 | 32 | class SwishJit(nn.Module): 33 | def __init__(self, inplace: bool = False): 34 | super(SwishJit, self).__init__() 35 | 36 | def forward(self, x): 37 | return swish_jit(x) 38 | 39 | 40 | class MishJit(nn.Module): 41 | def __init__(self, inplace: bool = False): 42 | super(MishJit, self).__init__() 43 | 44 | def forward(self, x): 45 | return mish_jit(x) 46 | 47 | 48 | @torch.jit.script 49 | def hard_sigmoid_jit(x, inplace: bool = False): 50 | # return F.relu6(x + 3.) / 6. 51 | return (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster? 
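    # Same interface as the eager version, but the inplace flag is intentionally
    # ignored (see module docstring); sketch with illustrative values:
    #   hard_sigmoid_jit(torch.tensor([-3., 0., 3.]))   # -> [0., 0.5, 1.]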
52 | 53 | 54 | class HardSigmoidJit(nn.Module): 55 | def __init__(self, inplace: bool = False): 56 | super(HardSigmoidJit, self).__init__() 57 | 58 | def forward(self, x): 59 | return hard_sigmoid_jit(x) 60 | 61 | 62 | @torch.jit.script 63 | def hard_swish_jit(x, inplace: bool = False): 64 | # return x * (F.relu6(x + 3.) / 6) 65 | return x * (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster? 66 | 67 | 68 | class HardSwishJit(nn.Module): 69 | def __init__(self, inplace: bool = False): 70 | super(HardSwishJit, self).__init__() 71 | 72 | def forward(self, x): 73 | return hard_swish_jit(x) 74 | 75 | 76 | @torch.jit.script 77 | def hard_mish_jit(x, inplace: bool = False): 78 | """ Hard Mish 79 | Experimental, based on notes by Mish author Diganta Misra at 80 | https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md 81 | """ 82 | return 0.5 * x * (x + 2).clamp(min=0, max=2) 83 | 84 | 85 | class HardMishJit(nn.Module): 86 | def __init__(self, inplace: bool = False): 87 | super(HardMishJit, self).__init__() 88 | 89 | def forward(self, x): 90 | return hard_mish_jit(x) 91 | -------------------------------------------------------------------------------- /src/timm/models/layers/adaptive_avgmax_pool.py: -------------------------------------------------------------------------------- 1 | """ PyTorch selectable adaptive pooling 2 | Adaptive pooling with the ability to select the type of pooling from: 3 | * 'avg' - Average pooling 4 | * 'max' - Max pooling 5 | * 'avgmax' - Sum of average and max pooling re-scaled by 0.5 6 | * 'avgmaxc' - Concatenation of average and max pooling along feature dim, doubles feature dim 7 | 8 | Both a functional and a nn.Module version of the pooling is provided. 9 | 10 | Hacked together by / Copyright 2020 Ross Wightman 11 | """ 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | 16 | 17 | def adaptive_pool_feat_mult(pool_type='avg'): 18 | if pool_type == 'catavgmax': 19 | return 2 20 | else: 21 | return 1 22 | 23 | 24 | def adaptive_avgmax_pool2d(x, output_size=1): 25 | x_avg = F.adaptive_avg_pool2d(x, output_size) 26 | x_max = F.adaptive_max_pool2d(x, output_size) 27 | return 0.5 * (x_avg + x_max) 28 | 29 | 30 | def adaptive_catavgmax_pool2d(x, output_size=1): 31 | x_avg = F.adaptive_avg_pool2d(x, output_size) 32 | x_max = F.adaptive_max_pool2d(x, output_size) 33 | return torch.cat((x_avg, x_max), 1) 34 | 35 | 36 | def select_adaptive_pool2d(x, pool_type='avg', output_size=1): 37 | """Selectable global pooling function with dynamic input kernel size 38 | """ 39 | if pool_type == 'avg': 40 | x = F.adaptive_avg_pool2d(x, output_size) 41 | elif pool_type == 'avgmax': 42 | x = adaptive_avgmax_pool2d(x, output_size) 43 | elif pool_type == 'catavgmax': 44 | x = adaptive_catavgmax_pool2d(x, output_size) 45 | elif pool_type == 'max': 46 | x = F.adaptive_max_pool2d(x, output_size) 47 | else: 48 | assert False, 'Invalid pool type: %s' % pool_type 49 | return x 50 | 51 | 52 | class FastAdaptiveAvgPool2d(nn.Module): 53 | def __init__(self, flatten=False): 54 | super(FastAdaptiveAvgPool2d, self).__init__() 55 | self.flatten = flatten 56 | 57 | def forward(self, x): 58 | return x.mean((2, 3)) if self.flatten else x.mean((2, 3), keepdim=True) 59 | 60 | 61 | class AdaptiveAvgMaxPool2d(nn.Module): 62 | def __init__(self, output_size=1): 63 | super(AdaptiveAvgMaxPool2d, self).__init__() 64 | self.output_size = output_size 65 | 66 | def forward(self, x): 67 | return adaptive_avgmax_pool2d(x, 
self.output_size) 68 | 69 | 70 | class AdaptiveCatAvgMaxPool2d(nn.Module): 71 | def __init__(self, output_size=1): 72 | super(AdaptiveCatAvgMaxPool2d, self).__init__() 73 | self.output_size = output_size 74 | 75 | def forward(self, x): 76 | return adaptive_catavgmax_pool2d(x, self.output_size) 77 | 78 | 79 | class SelectAdaptivePool2d(nn.Module): 80 | """Selectable global pooling layer with dynamic input kernel size 81 | """ 82 | def __init__(self, output_size=1, pool_type='fast', flatten=False): 83 | super(SelectAdaptivePool2d, self).__init__() 84 | self.pool_type = pool_type or '' # convert other falsy values to empty string for consistent TS typing 85 | self.flatten = flatten 86 | if pool_type == '': 87 | self.pool = nn.Identity() # pass through 88 | elif pool_type == 'fast': 89 | assert output_size == 1 90 | self.pool = FastAdaptiveAvgPool2d(self.flatten) 91 | self.flatten = False 92 | elif pool_type == 'avg': 93 | self.pool = nn.AdaptiveAvgPool2d(output_size) 94 | elif pool_type == 'avgmax': 95 | self.pool = AdaptiveAvgMaxPool2d(output_size) 96 | elif pool_type == 'catavgmax': 97 | self.pool = AdaptiveCatAvgMaxPool2d(output_size) 98 | elif pool_type == 'max': 99 | self.pool = nn.AdaptiveMaxPool2d(output_size) 100 | else: 101 | assert False, 'Invalid pool type: %s' % pool_type 102 | 103 | def is_identity(self): 104 | return self.pool_type == '' 105 | 106 | def forward(self, x): 107 | x = self.pool(x) 108 | if self.flatten: 109 | x = x.flatten(1) 110 | return x 111 | 112 | def feat_mult(self): 113 | return adaptive_pool_feat_mult(self.pool_type) 114 | 115 | def __repr__(self): 116 | return self.__class__.__name__ + ' (' \ 117 | + 'pool_type=' + self.pool_type \ 118 | + ', flatten=' + str(self.flatten) + ')' 119 | 120 | -------------------------------------------------------------------------------- /src/timm/models/layers/anti_aliasing.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.parallel 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class AntiAliasDownsampleLayer(nn.Module): 8 | def __init__(self, channels: int = 0, filt_size: int = 3, stride: int = 2, no_jit: bool = False): 9 | super(AntiAliasDownsampleLayer, self).__init__() 10 | if no_jit: 11 | self.op = Downsample(channels, filt_size, stride) 12 | else: 13 | self.op = DownsampleJIT(channels, filt_size, stride) 14 | 15 | # FIXME I should probably override _apply and clear DownsampleJIT filter cache for .cuda(), .half(), etc calls 16 | 17 | def forward(self, x): 18 | return self.op(x) 19 | 20 | 21 | @torch.jit.script 22 | class DownsampleJIT(object): 23 | def __init__(self, channels: int = 0, filt_size: int = 3, stride: int = 2): 24 | self.channels = channels 25 | self.stride = stride 26 | self.filt_size = filt_size 27 | assert self.filt_size == 3 28 | assert stride == 2 29 | self.filt = {} # lazy init by device for DataParallel compat 30 | 31 | def _create_filter(self, like: torch.Tensor): 32 | filt = torch.tensor([1., 2., 1.], dtype=like.dtype, device=like.device) 33 | filt = filt[:, None] * filt[None, :] 34 | filt = filt / torch.sum(filt) 35 | return filt[None, None, :, :].repeat((self.channels, 1, 1, 1)) 36 | 37 | def __call__(self, input: torch.Tensor): 38 | input_pad = F.pad(input, (1, 1, 1, 1), 'reflect') 39 | filt = self.filt.get(str(input.device), self._create_filter(input)) 40 | return F.conv2d(input_pad, filt, stride=2, padding=0, groups=input.shape[1]) 41 | 42 | 43 | class Downsample(nn.Module): 44 | def __init__(self, 
channels=None, filt_size=3, stride=2): 45 | super(Downsample, self).__init__() 46 | self.channels = channels 47 | self.filt_size = filt_size 48 | self.stride = stride 49 | 50 | assert self.filt_size == 3 51 | filt = torch.tensor([1., 2., 1.]) 52 | filt = filt[:, None] * filt[None, :] 53 | filt = filt / torch.sum(filt) 54 | 55 | # self.filt = filt[None, None, :, :].repeat((self.channels, 1, 1, 1)) 56 | self.register_buffer('filt', filt[None, None, :, :].repeat((self.channels, 1, 1, 1))) 57 | 58 | def forward(self, input): 59 | input_pad = F.pad(input, (1, 1, 1, 1), 'reflect') 60 | return F.conv2d(input_pad, self.filt, stride=self.stride, padding=0, groups=input.shape[1]) 61 | -------------------------------------------------------------------------------- /src/timm/models/layers/blur_pool.py: -------------------------------------------------------------------------------- 1 | """ 2 | BlurPool layer inspired by 3 | - Kornia's Max_BlurPool2d 4 | - Making Convolutional Networks Shift-Invariant Again :cite:`zhang2019shiftinvar` 5 | 6 | FIXME merge this impl with those in `anti_aliasing.py` 7 | 8 | Hacked together by Chris Ha and Ross Wightman 9 | """ 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | import numpy as np 15 | from typing import Dict 16 | from .padding import get_padding 17 | 18 | 19 | class BlurPool2d(nn.Module): 20 | r"""Creates a module that computes blurs and downsample a given feature map. 21 | See :cite:`zhang2019shiftinvar` for more details. 22 | Corresponds to the Downsample class, which does blurring and subsampling 23 | 24 | Args: 25 | channels = Number of input channels 26 | filt_size (int): binomial filter size for blurring. currently supports 3 (default) and 5. 27 | stride (int): downsampling filter stride 28 | 29 | Returns: 30 | torch.Tensor: the transformed tensor. 31 | """ 32 | filt: Dict[str, torch.Tensor] 33 | 34 | def __init__(self, channels, filt_size=3, stride=2) -> None: 35 | super(BlurPool2d, self).__init__() 36 | assert filt_size > 1 37 | self.channels = channels 38 | self.filt_size = filt_size 39 | self.stride = stride 40 | pad_size = [get_padding(filt_size, stride, dilation=1)] * 4 41 | self.padding = nn.ReflectionPad2d(pad_size) 42 | self._coeffs = torch.tensor((np.poly1d((0.5, 0.5)) ** (self.filt_size - 1)).coeffs) # for torchscript compat 43 | self.filt = {} # lazy init by device for DataParallel compat 44 | 45 | def _create_filter(self, like: torch.Tensor): 46 | blur_filter = (self._coeffs[:, None] * self._coeffs[None, :]).to(dtype=like.dtype, device=like.device) 47 | return blur_filter[None, None, :, :].repeat(self.channels, 1, 1, 1) 48 | 49 | def _apply(self, fn): 50 | # override nn.Module _apply, reset filter cache if used 51 | self.filt = {} 52 | super(BlurPool2d, self)._apply(fn) 53 | 54 | def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: 55 | C = input_tensor.shape[1] 56 | blur_filt = self.filt.get(str(input_tensor.device), self._create_filter(input_tensor)) 57 | return F.conv2d( 58 | self.padding(input_tensor), blur_filt, stride=self.stride, groups=C) 59 | -------------------------------------------------------------------------------- /src/timm/models/layers/cbam.py: -------------------------------------------------------------------------------- 1 | """ CBAM (sort-of) Attention 2 | 3 | Experimental impl of CBAM: Convolutional Block Attention Module: https://arxiv.org/abs/1807.06521 4 | 5 | WARNING: Results with these attention layers have been mixed. 
They can significantly reduce performance on 6 | some tasks, especially fine-grained it seems. I may end up removing this impl. 7 | 8 | Hacked together by / Copyright 2020 Ross Wightman 9 | """ 10 | 11 | import torch 12 | from torch import nn as nn 13 | import torch.nn.functional as F 14 | from .conv_bn_act import ConvBnAct 15 | 16 | 17 | class ChannelAttn(nn.Module): 18 | """ Original CBAM channel attention module, currently avg + max pool variant only. 19 | """ 20 | def __init__(self, channels, reduction=16, act_layer=nn.ReLU): 21 | super(ChannelAttn, self).__init__() 22 | self.fc1 = nn.Conv2d(channels, channels // reduction, 1, bias=False) 23 | self.act = act_layer(inplace=True) 24 | self.fc2 = nn.Conv2d(channels // reduction, channels, 1, bias=False) 25 | 26 | def forward(self, x): 27 | x_avg = x.mean((2, 3), keepdim=True) 28 | x_max = F.adaptive_max_pool2d(x, 1) 29 | x_avg = self.fc2(self.act(self.fc1(x_avg))) 30 | x_max = self.fc2(self.act(self.fc1(x_max))) 31 | x_attn = x_avg + x_max 32 | return x * x_attn.sigmoid() 33 | 34 | 35 | class LightChannelAttn(ChannelAttn): 36 | """An experimental 'lightweight' that sums avg + max pool first 37 | """ 38 | def __init__(self, channels, reduction=16): 39 | super(LightChannelAttn, self).__init__(channels, reduction) 40 | 41 | def forward(self, x): 42 | x_pool = 0.5 * x.mean((2, 3), keepdim=True) + 0.5 * F.adaptive_max_pool2d(x, 1) 43 | x_attn = self.fc2(self.act(self.fc1(x_pool))) 44 | return x * x_attn.sigmoid() 45 | 46 | 47 | class SpatialAttn(nn.Module): 48 | """ Original CBAM spatial attention module 49 | """ 50 | def __init__(self, kernel_size=7): 51 | super(SpatialAttn, self).__init__() 52 | self.conv = ConvBnAct(2, 1, kernel_size, act_layer=None) 53 | 54 | def forward(self, x): 55 | x_avg = torch.mean(x, dim=1, keepdim=True) 56 | x_max = torch.max(x, dim=1, keepdim=True)[0] 57 | x_attn = torch.cat([x_avg, x_max], dim=1) 58 | x_attn = self.conv(x_attn) 59 | return x * x_attn.sigmoid() 60 | 61 | 62 | class LightSpatialAttn(nn.Module): 63 | """An experimental 'lightweight' variant that sums avg_pool and max_pool results. 
64 | """ 65 | def __init__(self, kernel_size=7): 66 | super(LightSpatialAttn, self).__init__() 67 | self.conv = ConvBnAct(1, 1, kernel_size, act_layer=None) 68 | 69 | def forward(self, x): 70 | x_avg = torch.mean(x, dim=1, keepdim=True) 71 | x_max = torch.max(x, dim=1, keepdim=True)[0] 72 | x_attn = 0.5 * x_avg + 0.5 * x_max 73 | x_attn = self.conv(x_attn) 74 | return x * x_attn.sigmoid() 75 | 76 | 77 | class CbamModule(nn.Module): 78 | def __init__(self, channels, spatial_kernel_size=7): 79 | super(CbamModule, self).__init__() 80 | self.channel = ChannelAttn(channels) 81 | self.spatial = SpatialAttn(spatial_kernel_size) 82 | 83 | def forward(self, x): 84 | x = self.channel(x) 85 | x = self.spatial(x) 86 | return x 87 | 88 | 89 | class LightCbamModule(nn.Module): 90 | def __init__(self, channels, spatial_kernel_size=7): 91 | super(LightCbamModule, self).__init__() 92 | self.channel = LightChannelAttn(channels) 93 | self.spatial = LightSpatialAttn(spatial_kernel_size) 94 | 95 | def forward(self, x): 96 | x = self.channel(x) 97 | x = self.spatial(x) 98 | return x 99 | 100 | -------------------------------------------------------------------------------- /src/timm/models/layers/classifier.py: -------------------------------------------------------------------------------- 1 | """ Classifier head and layer factory 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | from torch import nn as nn 6 | from torch.nn import functional as F 7 | 8 | from .adaptive_avgmax_pool import SelectAdaptivePool2d 9 | from .linear import Linear 10 | 11 | 12 | def _create_pool(num_features, num_classes, pool_type='avg', use_conv=False): 13 | flatten_in_pool = not use_conv # flatten when we use a Linear layer after pooling 14 | if not pool_type: 15 | assert num_classes == 0 or use_conv,\ 16 | 'Pooling can only be disabled if classifier is also removed or conv classifier is used' 17 | flatten_in_pool = False # disable flattening if pooling is pass-through (no pooling) 18 | global_pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=flatten_in_pool) 19 | num_pooled_features = num_features * global_pool.feat_mult() 20 | return global_pool, num_pooled_features 21 | 22 | 23 | def _create_fc(num_features, num_classes, pool_type='avg', use_conv=False): 24 | if num_classes <= 0: 25 | fc = nn.Identity() # pass-through (no classifier) 26 | elif use_conv: 27 | fc = nn.Conv2d(num_features, num_classes, 1, bias=True) 28 | else: 29 | # NOTE: using my Linear wrapper that fixes AMP + torchscript casting issue 30 | fc = Linear(num_features, num_classes, bias=True) 31 | return fc 32 | 33 | 34 | def create_classifier(num_features, num_classes, pool_type='avg', use_conv=False): 35 | global_pool, num_pooled_features = _create_pool(num_features, num_classes, pool_type, use_conv=use_conv) 36 | fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv) 37 | return global_pool, fc 38 | 39 | 40 | class ClassifierHead(nn.Module): 41 | """Classifier head w/ configurable global pooling and dropout.""" 42 | 43 | def __init__(self, in_chs, num_classes, pool_type='avg', drop_rate=0., use_conv=False): 44 | super(ClassifierHead, self).__init__() 45 | self.drop_rate = drop_rate 46 | self.global_pool, num_pooled_features = _create_pool(in_chs, num_classes, pool_type, use_conv=use_conv) 47 | self.fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv) 48 | self.flatten_after_fc = use_conv and pool_type 49 | 50 | def forward(self, x): 51 | x = self.global_pool(x) 52 | if self.drop_rate: 53 | x = F.dropout(x, 
p=float(self.drop_rate), training=self.training) 54 | x = self.fc(x) 55 | return x 56 | -------------------------------------------------------------------------------- /src/timm/models/layers/config.py: -------------------------------------------------------------------------------- 1 | """ Model / Layer Config singleton state 2 | """ 3 | from typing import Any, Optional 4 | 5 | __all__ = [ 6 | 'is_exportable', 'is_scriptable', 'is_no_jit', 7 | 'set_exportable', 'set_scriptable', 'set_no_jit', 'set_layer_config' 8 | ] 9 | 10 | # Set to True if prefer to have layers with no jit optimization (includes activations) 11 | _NO_JIT = False 12 | 13 | # Set to True if prefer to have activation layers with no jit optimization 14 | # NOTE not currently used as no difference between no_jit and no_activation jit as only layers obeying 15 | # the jit flags so far are activations. This will change as more layers are updated and/or added. 16 | _NO_ACTIVATION_JIT = False 17 | 18 | # Set to True if exporting a model with Same padding via ONNX 19 | _EXPORTABLE = False 20 | 21 | # Set to True if wanting to use torch.jit.script on a model 22 | _SCRIPTABLE = False 23 | 24 | 25 | def is_no_jit(): 26 | return _NO_JIT 27 | 28 | 29 | class set_no_jit: 30 | def __init__(self, mode: bool) -> None: 31 | global _NO_JIT 32 | self.prev = _NO_JIT 33 | _NO_JIT = mode 34 | 35 | def __enter__(self) -> None: 36 | pass 37 | 38 | def __exit__(self, *args: Any) -> bool: 39 | global _NO_JIT 40 | _NO_JIT = self.prev 41 | return False 42 | 43 | 44 | def is_exportable(): 45 | return _EXPORTABLE 46 | 47 | 48 | class set_exportable: 49 | def __init__(self, mode: bool) -> None: 50 | global _EXPORTABLE 51 | self.prev = _EXPORTABLE 52 | _EXPORTABLE = mode 53 | 54 | def __enter__(self) -> None: 55 | pass 56 | 57 | def __exit__(self, *args: Any) -> bool: 58 | global _EXPORTABLE 59 | _EXPORTABLE = self.prev 60 | return False 61 | 62 | 63 | def is_scriptable(): 64 | return _SCRIPTABLE 65 | 66 | 67 | class set_scriptable: 68 | def __init__(self, mode: bool) -> None: 69 | global _SCRIPTABLE 70 | self.prev = _SCRIPTABLE 71 | _SCRIPTABLE = mode 72 | 73 | def __enter__(self) -> None: 74 | pass 75 | 76 | def __exit__(self, *args: Any) -> bool: 77 | global _SCRIPTABLE 78 | _SCRIPTABLE = self.prev 79 | return False 80 | 81 | 82 | class set_layer_config: 83 | """ Layer config context manager that allows setting all layer config flags at once. 84 | If a flag arg is None, it will not change the current value. 
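    Example (illustrative; create_model and the model name are just samples of
    how factory.py wraps model construction in this context manager):
        with set_layer_config(scriptable=True, no_jit=True):
            model = create_model('resnet50')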
85 | """ 86 | def __init__( 87 | self, 88 | scriptable: Optional[bool] = None, 89 | exportable: Optional[bool] = None, 90 | no_jit: Optional[bool] = None, 91 | no_activation_jit: Optional[bool] = None): 92 | global _SCRIPTABLE 93 | global _EXPORTABLE 94 | global _NO_JIT 95 | global _NO_ACTIVATION_JIT 96 | self.prev = _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT 97 | if scriptable is not None: 98 | _SCRIPTABLE = scriptable 99 | if exportable is not None: 100 | _EXPORTABLE = exportable 101 | if no_jit is not None: 102 | _NO_JIT = no_jit 103 | if no_activation_jit is not None: 104 | _NO_ACTIVATION_JIT = no_activation_jit 105 | 106 | def __enter__(self) -> None: 107 | pass 108 | 109 | def __exit__(self, *args: Any) -> bool: 110 | global _SCRIPTABLE 111 | global _EXPORTABLE 112 | global _NO_JIT 113 | global _NO_ACTIVATION_JIT 114 | _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT = self.prev 115 | return False 116 | -------------------------------------------------------------------------------- /src/timm/models/layers/conv2d_same.py: -------------------------------------------------------------------------------- 1 | """ Conv2d w/ Same Padding 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from typing import Tuple, Optional 9 | 10 | from .padding import pad_same, get_padding_value 11 | 12 | 13 | def conv2d_same( 14 | x, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, stride: Tuple[int, int] = (1, 1), 15 | padding: Tuple[int, int] = (0, 0), dilation: Tuple[int, int] = (1, 1), groups: int = 1): 16 | x = pad_same(x, weight.shape[-2:], stride, dilation) 17 | return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups) 18 | 19 | 20 | class Conv2dSame(nn.Conv2d): 21 | """ Tensorflow like 'SAME' convolution wrapper for 2D convolutions 22 | """ 23 | 24 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, 25 | padding=0, dilation=1, groups=1, bias=True): 26 | super(Conv2dSame, self).__init__( 27 | in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias) 28 | 29 | def forward(self, x): 30 | return conv2d_same(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) 31 | 32 | 33 | def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs): 34 | padding = kwargs.pop('padding', '') 35 | kwargs.setdefault('bias', False) 36 | padding, is_dynamic = get_padding_value(padding, kernel_size, **kwargs) 37 | if is_dynamic: 38 | return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs) 39 | else: 40 | return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs) 41 | 42 | 43 | -------------------------------------------------------------------------------- /src/timm/models/layers/conv_bn_act.py: -------------------------------------------------------------------------------- 1 | """ Conv2d + BN + Act 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | from torch import nn as nn 6 | 7 | from .create_conv2d import create_conv2d 8 | from .create_norm_act import convert_norm_act 9 | 10 | 11 | class ConvBnAct(nn.Module): 12 | def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding='', dilation=1, groups=1, 13 | bias=False, apply_act=True, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU, aa_layer=None, 14 | drop_block=None): 15 | super(ConvBnAct, self).__init__() 16 | use_aa = aa_layer is not None 17 | 18 | self.conv = create_conv2d( 19 | in_channels, out_channels, kernel_size, 
stride=1 if use_aa else stride, 20 | padding=padding, dilation=dilation, groups=groups, bias=bias) 21 | 22 | # NOTE for backwards compatibility with models that use separate norm and act layer definitions 23 | norm_act_layer = convert_norm_act(norm_layer, act_layer) 24 | self.bn = norm_act_layer(out_channels, apply_act=apply_act, drop_block=drop_block) 25 | self.aa = aa_layer(channels=out_channels) if stride == 2 and use_aa else None 26 | 27 | @property 28 | def in_channels(self): 29 | return self.conv.in_channels 30 | 31 | @property 32 | def out_channels(self): 33 | return self.conv.out_channels 34 | 35 | def forward(self, x): 36 | x = self.conv(x) 37 | x = self.bn(x) 38 | if self.aa is not None: 39 | x = self.aa(x) 40 | return x 41 | -------------------------------------------------------------------------------- /src/timm/models/layers/create_act.py: -------------------------------------------------------------------------------- 1 | """ Activation Factory 2 | Hacked together by / Copyright 2020 Ross Wightman 3 | """ 4 | from .activations import * 5 | from .activations_jit import * 6 | from .activations_me import * 7 | from .config import is_exportable, is_scriptable, is_no_jit 8 | 9 | # PyTorch has an optimized, native 'silu' (aka 'swish') operator as of PyTorch 1.7. This code 10 | # will use native version if present. Eventually, the custom Swish layers will be removed 11 | # and only native 'silu' will be used. 12 | _has_silu = 'silu' in dir(torch.nn.functional) 13 | 14 | _ACT_FN_DEFAULT = dict( 15 | silu=F.silu if _has_silu else swish, 16 | swish=F.silu if _has_silu else swish, 17 | mish=mish, 18 | relu=F.relu, 19 | relu6=F.relu6, 20 | leaky_relu=F.leaky_relu, 21 | elu=F.elu, 22 | celu=F.celu, 23 | selu=F.selu, 24 | gelu=gelu, 25 | sigmoid=sigmoid, 26 | tanh=tanh, 27 | hard_sigmoid=hard_sigmoid, 28 | hard_swish=hard_swish, 29 | hard_mish=hard_mish, 30 | ) 31 | 32 | _ACT_FN_JIT = dict( 33 | silu=F.silu if _has_silu else swish_jit, 34 | swish=F.silu if _has_silu else swish_jit, 35 | mish=mish_jit, 36 | hard_sigmoid=hard_sigmoid_jit, 37 | hard_swish=hard_swish_jit, 38 | hard_mish=hard_mish_jit 39 | ) 40 | 41 | _ACT_FN_ME = dict( 42 | silu=F.silu if _has_silu else swish_me, 43 | swish=F.silu if _has_silu else swish_me, 44 | mish=mish_me, 45 | hard_sigmoid=hard_sigmoid_me, 46 | hard_swish=hard_swish_me, 47 | hard_mish=hard_mish_me, 48 | ) 49 | 50 | _ACT_LAYER_DEFAULT = dict( 51 | silu=nn.SiLU if _has_silu else Swish, 52 | swish=nn.SiLU if _has_silu else Swish, 53 | mish=Mish, 54 | relu=nn.ReLU, 55 | relu6=nn.ReLU6, 56 | leaky_relu=nn.LeakyReLU, 57 | elu=nn.ELU, 58 | prelu=PReLU, 59 | celu=nn.CELU, 60 | selu=nn.SELU, 61 | gelu=GELU, 62 | sigmoid=Sigmoid, 63 | tanh=Tanh, 64 | hard_sigmoid=HardSigmoid, 65 | hard_swish=HardSwish, 66 | hard_mish=HardMish, 67 | ) 68 | 69 | _ACT_LAYER_JIT = dict( 70 | silu=nn.SiLU if _has_silu else SwishJit, 71 | swish=nn.SiLU if _has_silu else SwishJit, 72 | mish=MishJit, 73 | hard_sigmoid=HardSigmoidJit, 74 | hard_swish=HardSwishJit, 75 | hard_mish=HardMishJit 76 | ) 77 | 78 | _ACT_LAYER_ME = dict( 79 | silu=nn.SiLU if _has_silu else SwishMe, 80 | swish=nn.SiLU if _has_silu else SwishMe, 81 | mish=MishMe, 82 | hard_sigmoid=HardSigmoidMe, 83 | hard_swish=HardSwishMe, 84 | hard_mish=HardMishMe, 85 | ) 86 | 87 | 88 | def get_act_fn(name='relu'): 89 | """ Activation Function Factory 90 | Fetching activation fns by name with this function allows export or torch script friendly 91 | functions to be returned dynamically based on current config. 
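    Example (illustrative; 'hard_swish' is just one of the registered names):
        act = get_act_fn('hard_swish')
        y = act(x)  # x: any torch.Tensor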
92 | """ 93 | if not name: 94 | return None 95 | if not (is_no_jit() or is_exportable() or is_scriptable()): 96 | # If not exporting or scripting the model, first look for a memory-efficient version with 97 | # custom autograd, then fallback 98 | if name in _ACT_FN_ME: 99 | return _ACT_FN_ME[name] 100 | if is_exportable() and name in ('silu', 'swish'): 101 | # FIXME PyTorch SiLU doesn't ONNX export, this is a temp hack 102 | return swish 103 | if not (is_no_jit() or is_exportable()): 104 | if name in _ACT_FN_JIT: 105 | return _ACT_FN_JIT[name] 106 | return _ACT_FN_DEFAULT[name] 107 | 108 | 109 | def get_act_layer(name='relu'): 110 | """ Activation Layer Factory 111 | Fetching activation layers by name with this function allows export or torch script friendly 112 | functions to be returned dynamically based on current config. 113 | """ 114 | if not name: 115 | return None 116 | if not (is_no_jit() or is_exportable() or is_scriptable()): 117 | if name in _ACT_LAYER_ME: 118 | return _ACT_LAYER_ME[name] 119 | if is_exportable() and name in ('silu', 'swish'): 120 | # FIXME PyTorch SiLU doesn't ONNX export, this is a temp hack 121 | return Swish 122 | if not (is_no_jit() or is_exportable()): 123 | if name in _ACT_LAYER_JIT: 124 | return _ACT_LAYER_JIT[name] 125 | return _ACT_LAYER_DEFAULT[name] 126 | 127 | 128 | def create_act_layer(name, inplace=False, **kwargs): 129 | act_layer = get_act_layer(name) 130 | if act_layer is not None: 131 | return act_layer(inplace=inplace, **kwargs) 132 | else: 133 | return None 134 | -------------------------------------------------------------------------------- /src/timm/models/layers/create_attn.py: -------------------------------------------------------------------------------- 1 | """ Select AttentionFactory Method 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import torch 6 | from .se import SEModule, EffectiveSEModule 7 | from .eca import EcaModule, CecaModule 8 | from .cbam import CbamModule, LightCbamModule 9 | 10 | 11 | def get_attn(attn_type): 12 | if isinstance(attn_type, torch.nn.Module): 13 | return attn_type 14 | module_cls = None 15 | if attn_type is not None: 16 | if isinstance(attn_type, str): 17 | attn_type = attn_type.lower() 18 | if attn_type == 'se': 19 | module_cls = SEModule 20 | elif attn_type == 'ese': 21 | module_cls = EffectiveSEModule 22 | elif attn_type == 'eca': 23 | module_cls = EcaModule 24 | elif attn_type == 'ceca': 25 | module_cls = CecaModule 26 | elif attn_type == 'cbam': 27 | module_cls = CbamModule 28 | elif attn_type == 'lcbam': 29 | module_cls = LightCbamModule 30 | else: 31 | assert False, "Invalid attn module (%s)" % attn_type 32 | elif isinstance(attn_type, bool): 33 | if attn_type: 34 | module_cls = SEModule 35 | else: 36 | module_cls = attn_type 37 | return module_cls 38 | 39 | 40 | def create_attn(attn_type, channels, **kwargs): 41 | module_cls = get_attn(attn_type) 42 | if module_cls is not None: 43 | # NOTE: it's expected the first (positional) argument of all attention layers is the # input channels 44 | return module_cls(channels, **kwargs) 45 | return None 46 | -------------------------------------------------------------------------------- /src/timm/models/layers/create_conv2d.py: -------------------------------------------------------------------------------- 1 | """ Create Conv2d Factory Method 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | 6 | from .mixed_conv2d import MixedConv2d 7 | from .cond_conv2d import CondConv2d 8 | from .conv2d_same import 
create_conv2d_pad 9 | 10 | 11 | def create_conv2d(in_channels, out_channels, kernel_size, **kwargs): 12 | """ Select a 2d convolution implementation based on arguments 13 | Creates and returns one of torch.nn.Conv2d, Conv2dSame, MixedConv2d, or CondConv2d. 14 | 15 | Used extensively by EfficientNet, MobileNetv3 and related networks. 16 | """ 17 | if isinstance(kernel_size, list): 18 | assert 'num_experts' not in kwargs # MixNet + CondConv combo not supported currently 19 | assert 'groups' not in kwargs # MixedConv groups are defined by kernel list 20 | # We're going to use only lists for defining the MixedConv2d kernel groups, 21 | # ints, tuples, other iterables will continue to pass to normal conv and specify h, w. 22 | m = MixedConv2d(in_channels, out_channels, kernel_size, **kwargs) 23 | else: 24 | depthwise = kwargs.pop('depthwise', False) 25 | # for DW out_channels must be multiple of in_channels as must have out_channels % groups == 0 26 | groups = in_channels if depthwise else kwargs.pop('groups', 1) 27 | if 'num_experts' in kwargs and kwargs['num_experts'] > 0: 28 | m = CondConv2d(in_channels, out_channels, kernel_size, groups=groups, **kwargs) 29 | else: 30 | m = create_conv2d_pad(in_channels, out_channels, kernel_size, groups=groups, **kwargs) 31 | return m 32 | -------------------------------------------------------------------------------- /src/timm/models/layers/create_norm_act.py: -------------------------------------------------------------------------------- 1 | """ NormAct (Normalizaiton + Activation Layer) Factory 2 | 3 | Create norm + act combo modules that attempt to be backwards compatible with separate norm + act 4 | isntances in models. Where these are used it will be possible to swap separate BN + act layers with 5 | combined modules like IABN or EvoNorms. 6 | 7 | Hacked together by / Copyright 2020 Ross Wightman 8 | """ 9 | import types 10 | import functools 11 | 12 | import torch 13 | import torch.nn as nn 14 | 15 | from .evo_norm import EvoNormBatch2d, EvoNormSample2d 16 | from .norm_act import BatchNormAct2d, GroupNormAct 17 | from .inplace_abn import InplaceAbn 18 | 19 | _NORM_ACT_TYPES = {BatchNormAct2d, GroupNormAct, EvoNormBatch2d, EvoNormSample2d, InplaceAbn} 20 | _NORM_ACT_REQUIRES_ARG = {BatchNormAct2d, GroupNormAct, InplaceAbn} # requires act_layer arg to define act type 21 | 22 | 23 | def get_norm_act_layer(layer_class): 24 | layer_class = layer_class.replace('_', '').lower() 25 | if layer_class.startswith("batchnorm"): 26 | layer = BatchNormAct2d 27 | elif layer_class.startswith("groupnorm"): 28 | layer = GroupNormAct 29 | elif layer_class == "evonormbatch": 30 | layer = EvoNormBatch2d 31 | elif layer_class == "evonormsample": 32 | layer = EvoNormSample2d 33 | elif layer_class == "iabn" or layer_class == "inplaceabn": 34 | layer = InplaceAbn 35 | else: 36 | assert False, "Invalid norm_act layer (%s)" % layer_class 37 | return layer 38 | 39 | 40 | def create_norm_act(layer_type, num_features, apply_act=True, jit=False, **kwargs): 41 | layer_parts = layer_type.split('-') # e.g. batchnorm-leaky_relu 42 | assert len(layer_parts) in (1, 2) 43 | layer = get_norm_act_layer(layer_parts[0]) 44 | #activation_class = layer_parts[1].lower() if len(layer_parts) > 1 else '' # FIXME support string act selection? 
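    # Usage sketch (argument values are illustrative, not from the original file):
    #   norm_act = create_norm_act('batchnorm', 64, act_layer=nn.ReLU)   # BatchNormAct2d(64) w/ ReLU
    # Note: the optional '-<act>' suffix of layer_type is split off above but, per the
    # FIXME, not yet used for activation selection; pass act_layer explicitly instead.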
45 | layer_instance = layer(num_features, apply_act=apply_act, **kwargs) 46 | if jit: 47 | layer_instance = torch.jit.script(layer_instance) 48 | return layer_instance 49 | 50 | 51 | def convert_norm_act(norm_layer, act_layer): 52 | assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial)) 53 | assert act_layer is None or isinstance(act_layer, (type, str, types.FunctionType, functools.partial)) 54 | norm_act_kwargs = {} 55 | 56 | # unbind partial fn, so args can be rebound later 57 | if isinstance(norm_layer, functools.partial): 58 | norm_act_kwargs.update(norm_layer.keywords) 59 | norm_layer = norm_layer.func 60 | 61 | if isinstance(norm_layer, str): 62 | norm_act_layer = get_norm_act_layer(norm_layer) 63 | elif norm_layer in _NORM_ACT_TYPES: 64 | norm_act_layer = norm_layer 65 | elif isinstance(norm_layer, types.FunctionType): 66 | # if function type, must be a lambda/fn that creates a norm_act layer 67 | norm_act_layer = norm_layer 68 | else: 69 | type_name = norm_layer.__name__.lower() 70 | if type_name.startswith('batchnorm'): 71 | norm_act_layer = BatchNormAct2d 72 | elif type_name.startswith('groupnorm'): 73 | norm_act_layer = GroupNormAct 74 | else: 75 | assert False, f"No equivalent norm_act layer for {type_name}" 76 | 77 | if norm_act_layer in _NORM_ACT_REQUIRES_ARG: 78 | # pass `act_layer` through for backwards compat where `act_layer=None` implies no activation. 79 | # In the future, may force use of `apply_act` with `act_layer` arg bound to relevant NormAct types 80 | norm_act_kwargs.setdefault('act_layer', act_layer) 81 | if norm_act_kwargs: 82 | norm_act_layer = functools.partial(norm_act_layer, **norm_act_kwargs) # bind/rebind args 83 | return norm_act_layer 84 | -------------------------------------------------------------------------------- /src/timm/models/layers/create_self_attn.py: -------------------------------------------------------------------------------- 1 | from .bottleneck_attn import BottleneckAttn 2 | from .halo_attn import HaloAttn 3 | from .lambda_layer import LambdaLayer 4 | 5 | 6 | def get_self_attn(attn_type): 7 | if attn_type == 'bottleneck': 8 | return BottleneckAttn 9 | elif attn_type == 'halo': 10 | return HaloAttn 11 | elif attn_type == 'lambda': 12 | return LambdaLayer 13 | 14 | 15 | def create_self_attn(attn_type, dim, stride=1, **kwargs): 16 | attn_fn = get_self_attn(attn_type) 17 | return attn_fn(dim, stride=stride, **kwargs) 18 | -------------------------------------------------------------------------------- /src/timm/models/layers/evo_norm.py: -------------------------------------------------------------------------------- 1 | """EvoNormB0 (Batched) and EvoNormS0 (Sample) in PyTorch 2 | 3 | An attempt at getting decent performing EvoNorms running in PyTorch. 4 | While currently faster than other impl, still quite a ways off the built-in BN 5 | in terms of memory usage and throughput (roughly 5x mem, 1/2 - 1/3x speed). 6 | 7 | Still very much a WIP, fiddling with buffer usage, in-place/jit optimizations, and layouts. 
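Intended as a drop-in swap for a separate BatchNorm2d + activation pair, e.g.
(illustrative channel counts): EvoNormBatch2d(64) or EvoNormSample2d(64, groups=8).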
8 | 9 | Hacked together by / Copyright 2020 Ross Wightman 10 | """ 11 | 12 | import torch 13 | import torch.nn as nn 14 | 15 | 16 | class EvoNormBatch2d(nn.Module): 17 | def __init__(self, num_features, apply_act=True, momentum=0.1, eps=1e-5, drop_block=None): 18 | super(EvoNormBatch2d, self).__init__() 19 | self.apply_act = apply_act # apply activation (non-linearity) 20 | self.momentum = momentum 21 | self.eps = eps 22 | param_shape = (1, num_features, 1, 1) 23 | self.weight = nn.Parameter(torch.ones(param_shape), requires_grad=True) 24 | self.bias = nn.Parameter(torch.zeros(param_shape), requires_grad=True) 25 | if apply_act: 26 | self.v = nn.Parameter(torch.ones(param_shape), requires_grad=True) 27 | self.register_buffer('running_var', torch.ones(1, num_features, 1, 1)) 28 | self.reset_parameters() 29 | 30 | def reset_parameters(self): 31 | nn.init.ones_(self.weight) 32 | nn.init.zeros_(self.bias) 33 | if self.apply_act: 34 | nn.init.ones_(self.v) 35 | 36 | def forward(self, x): 37 | assert x.dim() == 4, 'expected 4D input' 38 | x_type = x.dtype 39 | if self.training: 40 | var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True) 41 | n = x.numel() / x.shape[1] 42 | self.running_var.copy_( 43 | var.detach() * self.momentum * (n / (n - 1)) + self.running_var * (1 - self.momentum)) 44 | else: 45 | var = self.running_var 46 | 47 | if self.apply_act: 48 | v = self.v.to(dtype=x_type) 49 | d = x * v + (x.var(dim=(2, 3), unbiased=False, keepdim=True) + self.eps).sqrt().to(dtype=x_type) 50 | d = d.max((var + self.eps).sqrt().to(dtype=x_type)) 51 | x = x / d 52 | return x * self.weight + self.bias 53 | 54 | 55 | class EvoNormSample2d(nn.Module): 56 | def __init__(self, num_features, apply_act=True, groups=8, eps=1e-5, drop_block=None): 57 | super(EvoNormSample2d, self).__init__() 58 | self.apply_act = apply_act # apply activation (non-linearity) 59 | self.groups = groups 60 | self.eps = eps 61 | param_shape = (1, num_features, 1, 1) 62 | self.weight = nn.Parameter(torch.ones(param_shape), requires_grad=True) 63 | self.bias = nn.Parameter(torch.zeros(param_shape), requires_grad=True) 64 | if apply_act: 65 | self.v = nn.Parameter(torch.ones(param_shape), requires_grad=True) 66 | self.reset_parameters() 67 | 68 | def reset_parameters(self): 69 | nn.init.ones_(self.weight) 70 | nn.init.zeros_(self.bias) 71 | if self.apply_act: 72 | nn.init.ones_(self.v) 73 | 74 | def forward(self, x): 75 | assert x.dim() == 4, 'expected 4D input' 76 | B, C, H, W = x.shape 77 | assert C % self.groups == 0 78 | if self.apply_act: 79 | n = x * (x * self.v).sigmoid() 80 | x = x.reshape(B, self.groups, -1) 81 | x = n.reshape(B, self.groups, -1) / (x.var(dim=-1, unbiased=False, keepdim=True) + self.eps).sqrt() 82 | x = x.reshape(B, C, H, W) 83 | return x * self.weight + self.bias 84 | -------------------------------------------------------------------------------- /src/timm/models/layers/helpers.py: -------------------------------------------------------------------------------- 1 | """ Layer/Module Helpers 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | from itertools import repeat 6 | import collections.abc 7 | 8 | 9 | # From PyTorch internals 10 | def _ntuple(n): 11 | def parse(x): 12 | if isinstance(x, collections.abc.Iterable): 13 | return x 14 | return tuple(repeat(x, n)) 15 | return parse 16 | 17 | 18 | to_1tuple = _ntuple(1) 19 | to_2tuple = _ntuple(2) 20 | to_3tuple = _ntuple(3) 21 | to_4tuple = _ntuple(4) 22 | to_ntuple = _ntuple 23 | 24 | 25 | def make_divisible(v, divisor=8, 
min_value=None): 26 | min_value = min_value or divisor 27 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 28 | # Make sure that round down does not go down by more than 10%. 29 | if new_v < 0.9 * v: 30 | new_v += divisor 31 | return new_v 32 | -------------------------------------------------------------------------------- /src/timm/models/layers/inplace_abn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn as nn 3 | 4 | try: 5 | from inplace_abn.functions import inplace_abn, inplace_abn_sync 6 | has_iabn = True 7 | except ImportError: 8 | has_iabn = False 9 | 10 | def inplace_abn(x, weight, bias, running_mean, running_var, 11 | training=True, momentum=0.1, eps=1e-05, activation="leaky_relu", activation_param=0.01): 12 | raise ImportError( 13 | "Please install InplaceABN:'pip install git+https://github.com/mapillary/inplace_abn.git@v1.0.12'") 14 | 15 | def inplace_abn_sync(**kwargs): 16 | inplace_abn(**kwargs) 17 | 18 | 19 | class InplaceAbn(nn.Module): 20 | """Activated Batch Normalization 21 | 22 | This gathers a BatchNorm and an activation function in a single module 23 | 24 | Parameters 25 | ---------- 26 | num_features : int 27 | Number of feature channels in the input and output. 28 | eps : float 29 | Small constant to prevent numerical issues. 30 | momentum : float 31 | Momentum factor applied to compute running statistics. 32 | affine : bool 33 | If `True` apply learned scale and shift transformation after normalization. 34 | act_layer : str or nn.Module type 35 | Name or type of the activation functions, one of: `leaky_relu`, `elu` 36 | act_param : float 37 | Negative slope for the `leaky_relu` activation. 38 | """ 39 | 40 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, apply_act=True, 41 | act_layer="leaky_relu", act_param=0.01, drop_block=None): 42 | super(InplaceAbn, self).__init__() 43 | self.num_features = num_features 44 | self.affine = affine 45 | self.eps = eps 46 | self.momentum = momentum 47 | if apply_act: 48 | if isinstance(act_layer, str): 49 | assert act_layer in ('leaky_relu', 'elu', 'identity', '') 50 | self.act_name = act_layer if act_layer else 'identity' 51 | else: 52 | # convert act layer passed as type to string 53 | if act_layer == nn.ELU: 54 | self.act_name = 'elu' 55 | elif act_layer == nn.LeakyReLU: 56 | self.act_name = 'leaky_relu' 57 | elif act_layer == nn.Identity: 58 | self.act_name = 'identity' 59 | else: 60 | assert False, f'Invalid act layer {act_layer.__name__} for IABN' 61 | else: 62 | self.act_name = 'identity' 63 | self.act_param = act_param 64 | if self.affine: 65 | self.weight = nn.Parameter(torch.ones(num_features)) 66 | self.bias = nn.Parameter(torch.zeros(num_features)) 67 | else: 68 | self.register_parameter('weight', None) 69 | self.register_parameter('bias', None) 70 | self.register_buffer('running_mean', torch.zeros(num_features)) 71 | self.register_buffer('running_var', torch.ones(num_features)) 72 | self.reset_parameters() 73 | 74 | def reset_parameters(self): 75 | nn.init.constant_(self.running_mean, 0) 76 | nn.init.constant_(self.running_var, 1) 77 | if self.affine: 78 | nn.init.constant_(self.weight, 1) 79 | nn.init.constant_(self.bias, 0) 80 | 81 | def forward(self, x): 82 | output = inplace_abn( 83 | x, self.weight, self.bias, self.running_mean, self.running_var, 84 | self.training, self.momentum, self.eps, self.act_name, self.act_param) 85 | if isinstance(output, tuple): 86 | output = output[0] 87 | return 
output 88 | -------------------------------------------------------------------------------- /src/timm/models/layers/lambda_layer.py: -------------------------------------------------------------------------------- 1 | """ Lambda Layer 2 | 3 | Paper: `LambdaNetworks: Modeling Long-Range Interactions Without Attention` 4 | - https://arxiv.org/abs/2102.08602 5 | 6 | @misc{2102.08602, 7 | Author = {Irwan Bello}, 8 | Title = {LambdaNetworks: Modeling Long-Range Interactions Without Attention}, 9 | Year = {2021}, 10 | } 11 | 12 | Status: 13 | This impl is a WIP. Code snippets in the paper were used as reference but 14 | good chance some details are missing/wrong. 15 | 16 | I've only implemented local lambda conv based pos embeddings. 17 | 18 | For a PyTorch impl that includes other embedding options checkout 19 | https://github.com/lucidrains/lambda-networks 20 | 21 | Hacked together by / Copyright 2021 Ross Wightman 22 | """ 23 | import torch 24 | from torch import nn 25 | import torch.nn.functional as F 26 | 27 | 28 | 29 | class LambdaLayer(nn.Module): 30 | """Lambda Layer w/ lambda conv position embedding 31 | 32 | Paper: `LambdaNetworks: Modeling Long-Range Interactions Without Attention` 33 | - https://arxiv.org/abs/2102.08602 34 | """ 35 | def __init__( 36 | self, 37 | dim, dim_out=None, stride=1, num_heads=4, dim_head=16, r=5, qkv_bias=False): 38 | super().__init__() 39 | self.dim_out = dim_out or dim 40 | self.dim_k = dim_head # query depth 'k' 41 | self.num_heads = num_heads 42 | assert self.dim_out % num_heads == 0, ' should be divided by num_heads' 43 | self.dim_v = self.dim_out // num_heads # value depth 'v' 44 | self.r = r # relative position neighbourhood (lambda conv kernel size) 45 | 46 | self.qkv = nn.Conv2d( 47 | dim, 48 | num_heads * dim_head + dim_head + self.dim_v, 49 | kernel_size=1, bias=qkv_bias) 50 | self.norm_q = nn.BatchNorm2d(num_heads * dim_head) 51 | self.norm_v = nn.BatchNorm2d(self.dim_v) 52 | 53 | # NOTE currently only supporting the local lambda convolutions for positional 54 | self.conv_lambda = nn.Conv3d(1, dim_head, (r, r, 1), padding=(r // 2, r // 2, 0)) 55 | 56 | self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity() 57 | 58 | def forward(self, x): 59 | B, C, H, W = x.shape 60 | M = H * W 61 | 62 | qkv = self.qkv(x) 63 | q, k, v = torch.split(qkv, [ 64 | self.num_heads * self.dim_k, self.dim_k, self.dim_v], dim=1) 65 | q = self.norm_q(q).reshape(B, self.num_heads, self.dim_k, M).transpose(-1, -2) # B, num_heads, M, K 66 | v = self.norm_v(v).reshape(B, self.dim_v, M).transpose(-1, -2) # B, M, V 67 | k = F.softmax(k.reshape(B, self.dim_k, M), dim=-1) # B, K, M 68 | 69 | content_lam = k @ v # B, K, V 70 | content_out = q @ content_lam.unsqueeze(1) # B, num_heads, M, V 71 | 72 | position_lam = self.conv_lambda(v.reshape(B, 1, H, W, self.dim_v)) # B, H, W, V, K 73 | position_lam = position_lam.reshape(B, 1, self.dim_k, H * W, self.dim_v).transpose(2, 3) # B, 1, M, K, V 74 | position_out = (q.unsqueeze(-2) @ position_lam).squeeze(-2) # B, num_heads, M, V 75 | 76 | out = (content_out + position_out).transpose(3, 1).reshape(B, C, H, W) # B, C (num_heads * V), H, W 77 | out = self.pool(out) 78 | return out 79 | -------------------------------------------------------------------------------- /src/timm/models/layers/linear.py: -------------------------------------------------------------------------------- 1 | """ Linear layer (alternate definition) 2 | """ 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn as nn 6 | 7 | 8 | class 
Linear(nn.Linear): 9 | r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b` 10 | 11 | Wraps torch.nn.Linear to support AMP + torchscript usage by manually casting 12 | weight & bias to input.dtype to work around an issue w/ torch.addmm in this use case. 13 | """ 14 | def forward(self, input: torch.Tensor) -> torch.Tensor: 15 | if torch.jit.is_scripting(): 16 | bias = self.bias.to(dtype=input.dtype) if self.bias is not None else None 17 | return F.linear(input, self.weight.to(dtype=input.dtype), bias=bias) 18 | else: 19 | return F.linear(input, self.weight, self.bias) 20 | -------------------------------------------------------------------------------- /src/timm/models/layers/median_pool.py: -------------------------------------------------------------------------------- 1 | """ Median Pool 2 | Hacked together by / Copyright 2020 Ross Wightman 3 | """ 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from .helpers import to_2tuple, to_4tuple 7 | 8 | 9 | class MedianPool2d(nn.Module): 10 | """ Median pool (usable as median filter when stride=1) module. 11 | 12 | Args: 13 | kernel_size: size of pooling kernel, int or 2-tuple 14 | stride: pool stride, int or 2-tuple 15 | padding: pool padding, int or 4-tuple (l, r, t, b) as in pytorch F.pad 16 | same: override padding and enforce same padding, boolean 17 | """ 18 | def __init__(self, kernel_size=3, stride=1, padding=0, same=False): 19 | super(MedianPool2d, self).__init__() 20 | self.k = to_2tuple(kernel_size) 21 | self.stride = to_2tuple(stride) 22 | self.padding = to_4tuple(padding) # convert to l, r, t, b 23 | self.same = same 24 | 25 | def _padding(self, x): 26 | if self.same: 27 | ih, iw = x.size()[2:] 28 | if ih % self.stride[0] == 0: 29 | ph = max(self.k[0] - self.stride[0], 0) 30 | else: 31 | ph = max(self.k[0] - (ih % self.stride[0]), 0) 32 | if iw % self.stride[1] == 0: 33 | pw = max(self.k[1] - self.stride[1], 0) 34 | else: 35 | pw = max(self.k[1] - (iw % self.stride[1]), 0) 36 | pl = pw // 2 37 | pr = pw - pl 38 | pt = ph // 2 39 | pb = ph - pt 40 | padding = (pl, pr, pt, pb) 41 | else: 42 | padding = self.padding 43 | return padding 44 | 45 | def forward(self, x): 46 | x = F.pad(x, self._padding(x), mode='reflect') 47 | x = x.unfold(2, self.k[0], self.stride[0]).unfold(3, self.k[1], self.stride[1]) 48 | x = x.contiguous().view(x.size()[:4] + (-1,)).median(dim=-1)[0] 49 | return x 50 | -------------------------------------------------------------------------------- /src/timm/models/layers/mixed_conv2d.py: -------------------------------------------------------------------------------- 1 | """ PyTorch Mixed Convolution 2 | 3 | Paper: MixConv: Mixed Depthwise Convolutional Kernels (https://arxiv.org/abs/1907.09595) 4 | 5 | Hacked together by / Copyright 2020 Ross Wightman 6 | """ 7 | 8 | import torch 9 | from torch import nn as nn 10 | 11 | from .conv2d_same import create_conv2d_pad 12 | 13 | 14 | def _split_channels(num_chan, num_groups): 15 | split = [num_chan // num_groups for _ in range(num_groups)] 16 | split[0] += num_chan - sum(split) 17 | return split 18 | 19 | 20 | class MixedConv2d(nn.ModuleDict): 21 | """ Mixed Grouped Convolution 22 | 23 | Based on MDConv and GroupedConv in MixNet impl: 24 | https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py 25 | """ 26 | def __init__(self, in_channels, out_channels, kernel_size=3, 27 | stride=1, padding='', dilation=1, depthwise=False, **kwargs): 28 | super(MixedConv2d, self).__init__() 29 | 30 | 
kernel_size = kernel_size if isinstance(kernel_size, list) else [kernel_size] 31 | num_groups = len(kernel_size) 32 | in_splits = _split_channels(in_channels, num_groups) 33 | out_splits = _split_channels(out_channels, num_groups) 34 | self.in_channels = sum(in_splits) 35 | self.out_channels = sum(out_splits) 36 | for idx, (k, in_ch, out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)): 37 | conv_groups = in_ch if depthwise else 1 38 | # use add_module to keep key space clean 39 | self.add_module( 40 | str(idx), 41 | create_conv2d_pad( 42 | in_ch, out_ch, k, stride=stride, 43 | padding=padding, dilation=dilation, groups=conv_groups, **kwargs) 44 | ) 45 | self.splits = in_splits 46 | 47 | def forward(self, x): 48 | x_split = torch.split(x, self.splits, 1) 49 | x_out = [c(x_split[i]) for i, c in enumerate(self.values())] 50 | x = torch.cat(x_out, 1) 51 | return x 52 | -------------------------------------------------------------------------------- /src/timm/models/layers/norm.py: -------------------------------------------------------------------------------- 1 | """ Normalization layers and wrappers 2 | """ 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class GroupNorm(nn.GroupNorm): 9 | def __init__(self, num_channels, num_groups, eps=1e-5, affine=True): 10 | # NOTE num_channels is swapped to first arg for consistency in swapping norm layers with BN 11 | super().__init__(num_groups, num_channels, eps=eps, affine=affine) 12 | 13 | def forward(self, x): 14 | return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps) 15 | -------------------------------------------------------------------------------- /src/timm/models/layers/norm_act.py: -------------------------------------------------------------------------------- 1 | """ Normalization + Activation Layers 2 | """ 3 | import torch 4 | from torch import nn as nn 5 | from torch.nn import functional as F 6 | 7 | from .create_act import get_act_layer 8 | 9 | 10 | class BatchNormAct2d(nn.BatchNorm2d): 11 | """BatchNorm + Activation 12 | 13 | This module performs BatchNorm + Activation in a manner that will remain backwards 14 | compatible with weights trained with separate bn, act. This is why we inherit from BN 15 | instead of composing it as a .bn member. 16 | """ 17 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True, 18 | apply_act=True, act_layer=nn.ReLU, inplace=True, drop_block=None): 19 | super(BatchNormAct2d, self).__init__( 20 | num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats) 21 | if isinstance(act_layer, str): 22 | act_layer = get_act_layer(act_layer) 23 | if act_layer is not None and apply_act: 24 | act_args = dict(inplace=True) if inplace else {} 25 | self.act = act_layer(**act_args) 26 | else: 27 | self.act = nn.Identity() 28 | 29 | def _forward_jit(self, x): 30 | """ A cut & paste of the contents of the PyTorch BatchNorm2d forward function 31 | """ 32 | # exponential_average_factor is self.momentum set to 33 | # (when it is available) only so that if gets updated 34 | # in ONNX graph when this node is exported to ONNX. 
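# The branch below mirrors nn.BatchNorm2d: with momentum=None the running stats use a
# cumulative moving average (factor 1/num_batches_tracked); otherwise the fixed momentum is used.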
35 | if self.momentum is None: 36 | exponential_average_factor = 0.0 37 | else: 38 | exponential_average_factor = self.momentum 39 | 40 | if self.training and self.track_running_stats: 41 | # TODO: if statement only here to tell the jit to skip emitting this when it is None 42 | if self.num_batches_tracked is not None: 43 | self.num_batches_tracked += 1 44 | if self.momentum is None: # use cumulative moving average 45 | exponential_average_factor = 1.0 / float(self.num_batches_tracked) 46 | else: # use exponential moving average 47 | exponential_average_factor = self.momentum 48 | 49 | x = F.batch_norm( 50 | x, self.running_mean, self.running_var, self.weight, self.bias, 51 | self.training or not self.track_running_stats, 52 | exponential_average_factor, self.eps) 53 | return x 54 | 55 | @torch.jit.ignore 56 | def _forward_python(self, x): 57 | return super(BatchNormAct2d, self).forward(x) 58 | 59 | def forward(self, x): 60 | # FIXME cannot call parent forward() and maintain jit.script compatibility? 61 | if torch.jit.is_scripting(): 62 | x = self._forward_jit(x) 63 | else: 64 | x = self._forward_python(x) 65 | x = self.act(x) 66 | return x 67 | 68 | 69 | class GroupNormAct(nn.GroupNorm): 70 | # NOTE num_channel and num_groups order flipped for easier layer swaps / binding of fixed args 71 | def __init__(self, num_channels, num_groups, eps=1e-5, affine=True, 72 | apply_act=True, act_layer=nn.ReLU, inplace=True, drop_block=None): 73 | super(GroupNormAct, self).__init__(num_groups, num_channels, eps=eps, affine=affine) 74 | if isinstance(act_layer, str): 75 | act_layer = get_act_layer(act_layer) 76 | if act_layer is not None and apply_act: 77 | act_args = dict(inplace=True) if inplace else {} 78 | self.act = act_layer(**act_args) 79 | else: 80 | self.act = nn.Identity() 81 | 82 | def forward(self, x): 83 | x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps) 84 | x = self.act(x) 85 | return x 86 | -------------------------------------------------------------------------------- /src/timm/models/layers/padding.py: -------------------------------------------------------------------------------- 1 | """ Padding Helpers 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import math 6 | from typing import List, Tuple 7 | 8 | import torch.nn.functional as F 9 | 10 | 11 | # Calculate symmetric padding for a convolution 12 | def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int: 13 | padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 14 | return padding 15 | 16 | 17 | # Calculate asymmetric TensorFlow-like 'SAME' padding for a convolution 18 | def get_same_padding(x: int, k: int, s: int, d: int): 19 | return max((math.ceil(x / s) - 1) * s + (k - 1) * d + 1 - x, 0) 20 | 21 | 22 | # Can SAME padding for given args be done statically? 
23 | def is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_): 24 | return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0 25 | 26 | 27 | # Dynamically pad input x with 'SAME' padding for conv with specified args 28 | def pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1), value: float = 0): 29 | ih, iw = x.size()[-2:] 30 | pad_h, pad_w = get_same_padding(ih, k[0], s[0], d[0]), get_same_padding(iw, k[1], s[1], d[1]) 31 | if pad_h > 0 or pad_w > 0: 32 | x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2], value=value) 33 | return x 34 | 35 | 36 | def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]: 37 | dynamic = False 38 | if isinstance(padding, str): 39 | # for any string padding, the padding will be calculated for you, one of three ways 40 | padding = padding.lower() 41 | if padding == 'same': 42 | # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact 43 | if is_static_pad(kernel_size, **kwargs): 44 | # static case, no extra overhead 45 | padding = get_padding(kernel_size, **kwargs) 46 | else: 47 | # dynamic 'SAME' padding, has runtime/GPU memory overhead 48 | padding = 0 49 | dynamic = True 50 | elif padding == 'valid': 51 | # 'VALID' padding, same as padding=0 52 | padding = 0 53 | else: 54 | # Default to PyTorch style 'same'-ish symmetric padding 55 | padding = get_padding(kernel_size, **kwargs) 56 | return padding, dynamic 57 | -------------------------------------------------------------------------------- /src/timm/models/layers/pool2d_same.py: -------------------------------------------------------------------------------- 1 | """ AvgPool2d w/ Same Padding 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from typing import List, Tuple, Optional 9 | 10 | from .helpers import to_2tuple 11 | from .padding import pad_same, get_padding_value 12 | 13 | 14 | def avg_pool2d_same(x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0), 15 | ceil_mode: bool = False, count_include_pad: bool = True): 16 | # FIXME how to deal with count_include_pad vs not for external padding? 
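# pad_same below applies asymmetric TF-style 'SAME' padding to x, so F.avg_pool2d can then be called with zero padding.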
17 | x = pad_same(x, kernel_size, stride) 18 | return F.avg_pool2d(x, kernel_size, stride, (0, 0), ceil_mode, count_include_pad) 19 | 20 | 21 | class AvgPool2dSame(nn.AvgPool2d): 22 | """ Tensorflow like 'SAME' wrapper for 2D average pooling 23 | """ 24 | def __init__(self, kernel_size: int, stride=None, padding=0, ceil_mode=False, count_include_pad=True): 25 | kernel_size = to_2tuple(kernel_size) 26 | stride = to_2tuple(stride) 27 | super(AvgPool2dSame, self).__init__(kernel_size, stride, (0, 0), ceil_mode, count_include_pad) 28 | 29 | def forward(self, x): 30 | return avg_pool2d_same( 31 | x, self.kernel_size, self.stride, self.padding, self.ceil_mode, self.count_include_pad) 32 | 33 | 34 | def max_pool2d_same( 35 | x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0), 36 | dilation: List[int] = (1, 1), ceil_mode: bool = False): 37 | x = pad_same(x, kernel_size, stride, value=-float('inf')) 38 | return F.max_pool2d(x, kernel_size, stride, (0, 0), dilation, ceil_mode) 39 | 40 | 41 | class MaxPool2dSame(nn.MaxPool2d): 42 | """ Tensorflow like 'SAME' wrapper for 2D max pooling 43 | """ 44 | def __init__(self, kernel_size: int, stride=None, padding=0, dilation=1, ceil_mode=False, count_include_pad=True): 45 | kernel_size = to_2tuple(kernel_size) 46 | stride = to_2tuple(stride) 47 | dilation = to_2tuple(dilation) 48 | super(MaxPool2dSame, self).__init__(kernel_size, stride, (0, 0), dilation, ceil_mode, count_include_pad) 49 | 50 | def forward(self, x): 51 | return max_pool2d_same(x, self.kernel_size, self.stride, self.padding, self.dilation, self.ceil_mode) 52 | 53 | 54 | def create_pool2d(pool_type, kernel_size, stride=None, **kwargs): 55 | stride = stride or kernel_size 56 | padding = kwargs.pop('padding', '') 57 | padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, **kwargs) 58 | if is_dynamic: 59 | if pool_type == 'avg': 60 | return AvgPool2dSame(kernel_size, stride=stride, **kwargs) 61 | elif pool_type == 'max': 62 | return MaxPool2dSame(kernel_size, stride=stride, **kwargs) 63 | else: 64 | assert False, f'Unsupported pool type {pool_type}' 65 | else: 66 | if pool_type == 'avg': 67 | return nn.AvgPool2d(kernel_size, stride=stride, padding=padding, **kwargs) 68 | elif pool_type == 'max': 69 | return nn.MaxPool2d(kernel_size, stride=stride, padding=padding, **kwargs) 70 | else: 71 | assert False, f'Unsupported pool type {pool_type}' 72 | -------------------------------------------------------------------------------- /src/timm/models/layers/se.py: -------------------------------------------------------------------------------- 1 | from torch import nn as nn 2 | import torch.nn.functional as F 3 | 4 | from .create_act import create_act_layer 5 | from .helpers import make_divisible 6 | 7 | 8 | class SEModule(nn.Module): 9 | """ SE Module as defined in original SE-Nets with a few additions 10 | Additions include: 11 | * min_channels can be specified to keep reduced channel count at a minimum (default: 8) 12 | * divisor can be specified to keep channels rounded to specified values (default: 1) 13 | * reduction channels can be specified directly by arg (if reduction_channels is set) 14 | * reduction channels can be specified by float ratio (if reduction_ratio is set) 15 | """ 16 | def __init__(self, channels, reduction=16, act_layer=nn.ReLU, gate_layer='sigmoid', 17 | reduction_ratio=None, reduction_channels=None, min_channels=8, divisor=1): 18 | super(SEModule, self).__init__() 19 | if reduction_channels is not None: 20 | reduction_channels = 
reduction_channels # direct specification highest priority, no rounding/min done 21 | elif reduction_ratio is not None: 22 | reduction_channels = make_divisible(channels * reduction_ratio, divisor, min_channels) 23 | else: 24 | reduction_channels = make_divisible(channels // reduction, divisor, min_channels) 25 | self.fc1 = nn.Conv2d(channels, reduction_channels, kernel_size=1, bias=True) 26 | self.act = act_layer(inplace=True) 27 | self.fc2 = nn.Conv2d(reduction_channels, channels, kernel_size=1, bias=True) 28 | self.gate = create_act_layer(gate_layer) 29 | 30 | def forward(self, x): 31 | x_se = x.mean((2, 3), keepdim=True) 32 | x_se = self.fc1(x_se) 33 | x_se = self.act(x_se) 34 | x_se = self.fc2(x_se) 35 | return x * self.gate(x_se) 36 | 37 | 38 | class EffectiveSEModule(nn.Module): 39 | """ 'Effective Squeeze-Excitation 40 | From `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667 41 | """ 42 | def __init__(self, channels, gate_layer='hard_sigmoid'): 43 | super(EffectiveSEModule, self).__init__() 44 | self.fc = nn.Conv2d(channels, channels, kernel_size=1, padding=0) 45 | self.gate = create_act_layer(gate_layer, inplace=True) 46 | 47 | def forward(self, x): 48 | x_se = x.mean((2, 3), keepdim=True) 49 | x_se = self.fc(x_se) 50 | return x * self.gate(x_se) 51 | -------------------------------------------------------------------------------- /src/timm/models/layers/separable_conv.py: -------------------------------------------------------------------------------- 1 | """ Depthwise Separable Conv Modules 2 | 3 | Basic DWS convs. Other variations of DWS exist with batch norm or activations between the 4 | DW and PW convs such as the Depthwise modules in MobileNetV2 / EfficientNet and Xception. 5 | 6 | Hacked together by / Copyright 2020 Ross Wightman 7 | """ 8 | from torch import nn as nn 9 | 10 | from .create_conv2d import create_conv2d 11 | from .create_norm_act import convert_norm_act 12 | 13 | 14 | class SeparableConvBnAct(nn.Module): 15 | """ Separable Conv w/ trailing Norm and Activation 16 | """ 17 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, padding='', bias=False, 18 | channel_multiplier=1.0, pw_kernel_size=1, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU, 19 | apply_act=True, drop_block=None): 20 | super(SeparableConvBnAct, self).__init__() 21 | 22 | self.conv_dw = create_conv2d( 23 | in_channels, int(in_channels * channel_multiplier), kernel_size, 24 | stride=stride, dilation=dilation, padding=padding, depthwise=True) 25 | 26 | self.conv_pw = create_conv2d( 27 | int(in_channels * channel_multiplier), out_channels, pw_kernel_size, padding=padding, bias=bias) 28 | 29 | norm_act_layer = convert_norm_act(norm_layer, act_layer) 30 | self.bn = norm_act_layer(out_channels, apply_act=apply_act, drop_block=drop_block) 31 | 32 | @property 33 | def in_channels(self): 34 | return self.conv_dw.in_channels 35 | 36 | @property 37 | def out_channels(self): 38 | return self.conv_pw.out_channels 39 | 40 | def forward(self, x): 41 | x = self.conv_dw(x) 42 | x = self.conv_pw(x) 43 | if self.bn is not None: 44 | x = self.bn(x) 45 | return x 46 | 47 | 48 | class SeparableConv2d(nn.Module): 49 | """ Separable Conv 50 | """ 51 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, padding='', bias=False, 52 | channel_multiplier=1.0, pw_kernel_size=1): 53 | super(SeparableConv2d, self).__init__() 54 | 55 | self.conv_dw = create_conv2d( 56 | in_channels, int(in_channels * channel_multiplier), 
kernel_size, 57 | stride=stride, dilation=dilation, padding=padding, depthwise=True) 58 | 59 | self.conv_pw = create_conv2d( 60 | int(in_channels * channel_multiplier), out_channels, pw_kernel_size, padding=padding, bias=bias) 61 | 62 | @property 63 | def in_channels(self): 64 | return self.conv_dw.in_channels 65 | 66 | @property 67 | def out_channels(self): 68 | return self.conv_pw.out_channels 69 | 70 | def forward(self, x): 71 | x = self.conv_dw(x) 72 | x = self.conv_pw(x) 73 | return x 74 | -------------------------------------------------------------------------------- /src/timm/models/layers/space_to_depth.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class SpaceToDepth(nn.Module): 6 | def __init__(self, block_size=4): 7 | super().__init__() 8 | assert block_size == 4 9 | self.bs = block_size 10 | 11 | def forward(self, x): 12 | N, C, H, W = x.size() 13 | x = x.view(N, C, H // self.bs, self.bs, W // self.bs, self.bs) # (N, C, H//bs, bs, W//bs, bs) 14 | x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs) 15 | x = x.view(N, C * (self.bs ** 2), H // self.bs, W // self.bs) # (N, C*bs^2, H//bs, W//bs) 16 | return x 17 | 18 | 19 | @torch.jit.script 20 | class SpaceToDepthJit(object): 21 | def __call__(self, x: torch.Tensor): 22 | # assuming hard-coded that block_size==4 for acceleration 23 | N, C, H, W = x.size() 24 | x = x.view(N, C, H // 4, 4, W // 4, 4) # (N, C, H//bs, bs, W//bs, bs) 25 | x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs) 26 | x = x.view(N, C * 16, H // 4, W // 4) # (N, C*bs^2, H//bs, W//bs) 27 | return x 28 | 29 | 30 | class SpaceToDepthModule(nn.Module): 31 | def __init__(self, no_jit=False): 32 | super().__init__() 33 | if not no_jit: 34 | self.op = SpaceToDepthJit() 35 | else: 36 | self.op = SpaceToDepth() 37 | 38 | def forward(self, x): 39 | return self.op(x) 40 | 41 | 42 | class DepthToSpace(nn.Module): 43 | 44 | def __init__(self, block_size): 45 | super().__init__() 46 | self.bs = block_size 47 | 48 | def forward(self, x): 49 | N, C, H, W = x.size() 50 | x = x.view(N, self.bs, self.bs, C // (self.bs ** 2), H, W) # (N, bs, bs, C//bs^2, H, W) 51 | x = x.permute(0, 3, 4, 1, 5, 2).contiguous() # (N, C//bs^2, H, bs, W, bs) 52 | x = x.view(N, C // (self.bs ** 2), H * self.bs, W * self.bs) # (N, C//bs^2, H * bs, W * bs) 53 | return x 54 | -------------------------------------------------------------------------------- /src/timm/models/layers/split_attn.py: -------------------------------------------------------------------------------- 1 | """ Split Attention Conv2d (for ResNeSt Models) 2 | 3 | Paper: `ResNeSt: Split-Attention Networks` - /https://arxiv.org/abs/2004.08955 4 | 5 | Adapted from original PyTorch impl at https://github.com/zhanghang1989/ResNeSt 6 | 7 | Modified for torchscript compat, performance, and consistency with timm by Ross Wightman 8 | """ 9 | import torch 10 | import torch.nn.functional as F 11 | from torch import nn 12 | 13 | 14 | class RadixSoftmax(nn.Module): 15 | def __init__(self, radix, cardinality): 16 | super(RadixSoftmax, self).__init__() 17 | self.radix = radix 18 | self.cardinality = cardinality 19 | 20 | def forward(self, x): 21 | batch = x.size(0) 22 | if self.radix > 1: 23 | x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2) 24 | x = F.softmax(x, dim=1) 25 | x = x.reshape(batch, -1) 26 | else: 27 | x = torch.sigmoid(x) 28 | return x 29 | 30 | 31 | class 
SplitAttnConv2d(nn.Module): 32 | """Split-Attention Conv2d 33 | """ 34 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, 35 | dilation=1, groups=1, bias=False, radix=2, reduction_factor=4, 36 | act_layer=nn.ReLU, norm_layer=None, drop_block=None, **kwargs): 37 | super(SplitAttnConv2d, self).__init__() 38 | self.radix = radix 39 | self.drop_block = drop_block 40 | mid_chs = out_channels * radix 41 | attn_chs = max(in_channels * radix // reduction_factor, 32) 42 | 43 | self.conv = nn.Conv2d( 44 | in_channels, mid_chs, kernel_size, stride, padding, dilation, 45 | groups=groups * radix, bias=bias, **kwargs) 46 | self.bn0 = norm_layer(mid_chs) if norm_layer is not None else None 47 | self.act0 = act_layer(inplace=True) 48 | self.fc1 = nn.Conv2d(out_channels, attn_chs, 1, groups=groups) 49 | self.bn1 = norm_layer(attn_chs) if norm_layer is not None else None 50 | self.act1 = act_layer(inplace=True) 51 | self.fc2 = nn.Conv2d(attn_chs, mid_chs, 1, groups=groups) 52 | self.rsoftmax = RadixSoftmax(radix, groups) 53 | 54 | @property 55 | def in_channels(self): 56 | return self.conv.in_channels 57 | 58 | @property 59 | def out_channels(self): 60 | return self.fc1.out_channels 61 | 62 | def forward(self, x): 63 | x = self.conv(x) 64 | if self.bn0 is not None: 65 | x = self.bn0(x) 66 | if self.drop_block is not None: 67 | x = self.drop_block(x) 68 | x = self.act0(x) 69 | 70 | B, RC, H, W = x.shape 71 | if self.radix > 1: 72 | x = x.reshape((B, self.radix, RC // self.radix, H, W)) 73 | x_gap = x.sum(dim=1) 74 | else: 75 | x_gap = x 76 | x_gap = F.adaptive_avg_pool2d(x_gap, 1) 77 | x_gap = self.fc1(x_gap) 78 | if self.bn1 is not None: 79 | x_gap = self.bn1(x_gap) 80 | x_gap = self.act1(x_gap) 81 | x_attn = self.fc2(x_gap) 82 | 83 | x_attn = self.rsoftmax(x_attn).view(B, -1, 1, 1) 84 | if self.radix > 1: 85 | out = (x * x_attn.reshape((B, self.radix, RC // self.radix, 1, 1))).sum(dim=1) 86 | else: 87 | out = x * x_attn 88 | return out.contiguous() 89 | -------------------------------------------------------------------------------- /src/timm/models/layers/split_batchnorm.py: -------------------------------------------------------------------------------- 1 | """ Split BatchNorm 2 | 3 | A PyTorch BatchNorm layer that splits input batch into N equal parts and passes each through 4 | a separate BN layer. The first split is passed through the parent BN layers with weight/bias 5 | keys the same as the original BN. All other splits pass through BN sub-layers under the '.aux_bn' 6 | namespace. 
7 | 8 | This allows easily removing the auxiliary BN layers after training to efficiently 9 | achieve the 'Auxiliary BatchNorm' as described in the AdvProp Paper, section 4.2, 10 | 'Disentangled Learning via An Auxiliary BN' 11 | 12 | Hacked together by / Copyright 2020 Ross Wightman 13 | """ 14 | import torch 15 | import torch.nn as nn 16 | 17 | 18 | class SplitBatchNorm2d(torch.nn.BatchNorm2d): 19 | 20 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, 21 | track_running_stats=True, num_splits=2): 22 | super().__init__(num_features, eps, momentum, affine, track_running_stats) 23 | assert num_splits > 1, 'Should have at least one aux BN layer (num_splits at least 2)' 24 | self.num_splits = num_splits 25 | self.aux_bn = nn.ModuleList([ 26 | nn.BatchNorm2d(num_features, eps, momentum, affine, track_running_stats) for _ in range(num_splits - 1)]) 27 | 28 | def forward(self, input: torch.Tensor): 29 | if self.training: # aux BN only relevant while training 30 | split_size = input.shape[0] // self.num_splits 31 | assert input.shape[0] == split_size * self.num_splits, "batch size must be evenly divisible by num_splits" 32 | split_input = input.split(split_size) 33 | x = [super().forward(split_input[0])] 34 | for i, a in enumerate(self.aux_bn): 35 | x.append(a(split_input[i + 1])) 36 | return torch.cat(x, dim=0) 37 | else: 38 | return super().forward(input) 39 | 40 | 41 | def convert_splitbn_model(module, num_splits=2): 42 | """ 43 | Recursively traverse module and its children to replace all instances of 44 | ``torch.nn.modules.batchnorm._BatchNorm`` with `SplitBatchnorm2d`. 45 | Args: 46 | module (torch.nn.Module): input module 47 | num_splits: number of separate batchnorm layers to split input across 48 | Example:: 49 | >>> # model is an instance of torch.nn.Module 50 | >>> model = timm.models.convert_splitbn_model(model, num_splits=2) 51 | """ 52 | mod = module 53 | if isinstance(module, torch.nn.modules.instancenorm._InstanceNorm): 54 | return module 55 | if isinstance(module, torch.nn.modules.batchnorm._BatchNorm): 56 | mod = SplitBatchNorm2d( 57 | module.num_features, module.eps, module.momentum, module.affine, 58 | module.track_running_stats, num_splits=num_splits) 59 | mod.running_mean = module.running_mean 60 | mod.running_var = module.running_var 61 | mod.num_batches_tracked = module.num_batches_tracked 62 | if module.affine: 63 | mod.weight.data = module.weight.data.clone().detach() 64 | mod.bias.data = module.bias.data.clone().detach() 65 | for aux in mod.aux_bn: 66 | aux.running_mean = module.running_mean.clone() 67 | aux.running_var = module.running_var.clone() 68 | aux.num_batches_tracked = module.num_batches_tracked.clone() 69 | if module.affine: 70 | aux.weight.data = module.weight.data.clone().detach() 71 | aux.bias.data = module.bias.data.clone().detach() 72 | for name, child in module.named_children(): 73 | mod.add_module(name, convert_splitbn_model(child, num_splits=num_splits)) 74 | del module 75 | return mod 76 | -------------------------------------------------------------------------------- /src/timm/models/layers/test_time_pool.py: -------------------------------------------------------------------------------- 1 | """ Test Time Pooling (Average-Max Pool) 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | 6 | import logging 7 | from torch import nn 8 | import torch.nn.functional as F 9 | 10 | from .adaptive_avgmax_pool import adaptive_avgmax_pool2d 11 | 12 | 13 | _logger = logging.getLogger(__name__) 14 | 15 | 16 | class 
TestTimePoolHead(nn.Module): 17 | def __init__(self, base, original_pool=7): 18 | super(TestTimePoolHead, self).__init__() 19 | self.base = base 20 | self.original_pool = original_pool 21 | base_fc = self.base.get_classifier() 22 | if isinstance(base_fc, nn.Conv2d): 23 | self.fc = base_fc 24 | else: 25 | self.fc = nn.Conv2d( 26 | self.base.num_features, self.base.num_classes, kernel_size=1, bias=True) 27 | self.fc.weight.data.copy_(base_fc.weight.data.view(self.fc.weight.size())) 28 | self.fc.bias.data.copy_(base_fc.bias.data.view(self.fc.bias.size())) 29 | self.base.reset_classifier(0) # delete original fc layer 30 | 31 | def forward(self, x): 32 | x = self.base.forward_features(x) 33 | x = F.avg_pool2d(x, kernel_size=self.original_pool, stride=1) 34 | x = self.fc(x) 35 | x = adaptive_avgmax_pool2d(x, 1) 36 | return x.view(x.size(0), -1) 37 | 38 | 39 | def apply_test_time_pool(model, config, use_test_size=True): 40 | test_time_pool = False 41 | if not hasattr(model, 'default_cfg') or not model.default_cfg: 42 | return model, False 43 | if use_test_size and 'test_input_size' in model.default_cfg: 44 | df_input_size = model.default_cfg['test_input_size'] 45 | else: 46 | df_input_size = model.default_cfg['input_size'] 47 | if config['input_size'][-1] > df_input_size[-1] and config['input_size'][-2] > df_input_size[-2]: 48 | _logger.info('Target input size %s > pretrained default %s, using test time pooling' % 49 | (str(config['input_size'][-2:]), str(df_input_size[-2:]))) 50 | model = TestTimePoolHead(model, original_pool=model.default_cfg['pool_size']) 51 | test_time_pool = True 52 | return model, test_time_pool 53 | -------------------------------------------------------------------------------- /src/timm/models/layers/weight_init.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import warnings 4 | 5 | from torch.nn.init import _calculate_fan_in_and_fan_out 6 | 7 | 8 | def _no_grad_trunc_normal_(tensor, mean, std, a, b): 9 | # Cut & paste from PyTorch official master until it's in a few official releases - RW 10 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 11 | def norm_cdf(x): 12 | # Computes standard normal cumulative distribution function 13 | return (1. + math.erf(x / math.sqrt(2.))) / 2. 14 | 15 | if (mean < a - 2 * std) or (mean > b + 2 * std): 16 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " 17 | "The distribution of values may be incorrect.", 18 | stacklevel=2) 19 | 20 | with torch.no_grad(): 21 | # Values are generated by using a truncated uniform distribution and 22 | # then using the inverse CDF for the normal distribution. 23 | # Get upper and lower cdf values 24 | l = norm_cdf((a - mean) / std) 25 | u = norm_cdf((b - mean) / std) 26 | 27 | # Uniformly fill tensor with values from [l, u], then translate to 28 | # [2l-1, 2u-1]. 29 | tensor.uniform_(2 * l - 1, 2 * u - 1) 30 | 31 | # Use inverse cdf transform for normal distribution to get truncated 32 | # standard normal 33 | tensor.erfinv_() 34 | 35 | # Transform to proper mean, std 36 | tensor.mul_(std * math.sqrt(2.)) 37 | tensor.add_(mean) 38 | 39 | # Clamp to ensure it's in the proper range 40 | tensor.clamp_(min=a, max=b) 41 | return tensor 42 | 43 | 44 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): 45 | # type: (Tensor, float, float, float, float) -> Tensor 46 | r"""Fills the input Tensor with values drawn from a truncated 47 | normal distribution. 
The values are effectively drawn from the 48 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` 49 | with values outside :math:`[a, b]` redrawn until they are within 50 | the bounds. The method used for generating the random values works 51 | best when :math:`a \leq \text{mean} \leq b`. 52 | Args: 53 | tensor: an n-dimensional `torch.Tensor` 54 | mean: the mean of the normal distribution 55 | std: the standard deviation of the normal distribution 56 | a: the minimum cutoff value 57 | b: the maximum cutoff value 58 | Examples: 59 | >>> w = torch.empty(3, 5) 60 | >>> nn.init.trunc_normal_(w) 61 | """ 62 | return _no_grad_trunc_normal_(tensor, mean, std, a, b) 63 | 64 | 65 | def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'): 66 | fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) 67 | if mode == 'fan_in': 68 | denom = fan_in 69 | elif mode == 'fan_out': 70 | denom = fan_out 71 | elif mode == 'fan_avg': 72 | denom = (fan_in + fan_out) / 2 73 | 74 | variance = scale / denom 75 | 76 | if distribution == "truncated_normal": 77 | # constant is stddev of standard normal truncated to (-2, 2) 78 | trunc_normal_(tensor, std=math.sqrt(variance) / .87962566103423978) 79 | elif distribution == "normal": 80 | tensor.normal_(std=math.sqrt(variance)) 81 | elif distribution == "uniform": 82 | bound = math.sqrt(3 * variance) 83 | tensor.uniform_(-bound, bound) 84 | else: 85 | raise ValueError(f"invalid distribution {distribution}") 86 | 87 | 88 | def lecun_normal_(tensor): 89 | variance_scaling_(tensor, mode='fan_in', distribution='truncated_normal') 90 | -------------------------------------------------------------------------------- /src/timm/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .checkpoint_saver import CheckpointSaver 2 | from .cuda import ApexScaler, NativeScaler 3 | from .distributed import distribute_bn, reduce_tensor 4 | from .jit import set_jit_legacy 5 | from .log import setup_default_logging, FormatterNoInfo 6 | from .metrics import AverageMeter, accuracy 7 | from .misc import natural_key, add_bool_arg 8 | from .model import unwrap_model, get_state_dict 9 | from .model_ema import ModelEma, ModelEmaV2 10 | from .summary import update_summary, get_outdir 11 | -------------------------------------------------------------------------------- /src/timm/utils/cuda.py: -------------------------------------------------------------------------------- 1 | """ CUDA / AMP utils 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import torch 6 | 7 | try: 8 | from apex import amp 9 | has_apex = True 10 | except ImportError: 11 | amp = None 12 | has_apex = False 13 | 14 | 15 | class ApexScaler: 16 | state_dict_key = "amp" 17 | 18 | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False): 19 | with amp.scale_loss(loss, optimizer) as scaled_loss: 20 | scaled_loss.backward(create_graph=create_graph) 21 | if clip_grad is not None: 22 | torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), clip_grad) 23 | optimizer.step() 24 | 25 | def state_dict(self): 26 | if 'state_dict' in amp.__dict__: 27 | return amp.state_dict() 28 | 29 | def load_state_dict(self, state_dict): 30 | if 'load_state_dict' in amp.__dict__: 31 | amp.load_state_dict(state_dict) 32 | 33 | 34 | class NativeScaler: 35 | state_dict_key = "amp_scaler" 36 | 37 | def __init__(self): 38 | self._scaler = torch.cuda.amp.GradScaler() 39 | 40 | def __call__(self, loss, optimizer, 
clip_grad=None, parameters=None, create_graph=False): 41 | self._scaler.scale(loss).backward(create_graph=create_graph) 42 | if clip_grad is not None: 43 | assert parameters is not None 44 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place 45 | torch.nn.utils.clip_grad_norm_(parameters, clip_grad) 46 | self._scaler.step(optimizer) 47 | self._scaler.update() 48 | 49 | def state_dict(self): 50 | return self._scaler.state_dict() 51 | 52 | def load_state_dict(self, state_dict): 53 | self._scaler.load_state_dict(state_dict) 54 | -------------------------------------------------------------------------------- /src/timm/utils/distributed.py: -------------------------------------------------------------------------------- 1 | """ Distributed training/validation utils 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import torch 6 | from torch import distributed as dist 7 | 8 | from .model import unwrap_model 9 | 10 | 11 | def reduce_tensor(tensor, n): 12 | rt = tensor.clone() 13 | dist.all_reduce(rt, op=dist.ReduceOp.SUM) 14 | rt /= n 15 | return rt 16 | 17 | 18 | def distribute_bn(model, world_size, reduce=False): 19 | # ensure every node has the same running bn stats 20 | for bn_name, bn_buf in unwrap_model(model).named_buffers(recurse=True): 21 | if ('running_mean' in bn_name) or ('running_var' in bn_name): 22 | if reduce: 23 | # average bn stats across whole group 24 | torch.distributed.all_reduce(bn_buf, op=dist.ReduceOp.SUM) 25 | bn_buf /= float(world_size) 26 | else: 27 | # broadcast bn stats from rank 0 to whole group 28 | torch.distributed.broadcast(bn_buf, 0) 29 | -------------------------------------------------------------------------------- /src/timm/utils/jit.py: -------------------------------------------------------------------------------- 1 | """ JIT scripting/tracing utils 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import torch 6 | 7 | 8 | def set_jit_legacy(): 9 | """ Set JIT executor to legacy w/ support for op fusion 10 | This is hopefully a temporary need in 1.5/1.5.1/1.6 to restore performance due to changes 11 | in the JIT exectutor. These API are not supported so could change. 12 | """ 13 | # 14 | assert hasattr(torch._C, '_jit_set_profiling_executor'), "Old JIT behavior doesn't exist!" 
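# The calls below disable the profiling executor and profiling mode and explicitly allow op fusion
# on GPU, restoring the legacy JIT fuser behaviour.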
15 | torch._C._jit_set_profiling_executor(False) 16 | torch._C._jit_set_profiling_mode(False) 17 | torch._C._jit_override_can_fuse_on_gpu(True) 18 | #torch._C._jit_set_texpr_fuser_enabled(True) 19 | -------------------------------------------------------------------------------- /src/timm/utils/log.py: -------------------------------------------------------------------------------- 1 | """ Logging helpers 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import logging 6 | import logging.handlers 7 | 8 | 9 | class FormatterNoInfo(logging.Formatter): 10 | def __init__(self, fmt='%(levelname)s: %(message)s'): 11 | logging.Formatter.__init__(self, fmt) 12 | 13 | def format(self, record): 14 | if record.levelno == logging.INFO: 15 | return str(record.getMessage()) 16 | return logging.Formatter.format(self, record) 17 | 18 | 19 | def setup_default_logging(default_level=logging.INFO, log_path=''): 20 | console_handler = logging.StreamHandler() 21 | console_handler.setFormatter(FormatterNoInfo()) 22 | logging.root.addHandler(console_handler) 23 | logging.root.setLevel(default_level) 24 | if log_path: 25 | file_handler = logging.handlers.RotatingFileHandler(log_path, maxBytes=(1024 ** 2 * 2), backupCount=3) 26 | file_formatter = logging.Formatter("%(asctime)s - %(name)20s: [%(levelname)8s] - %(message)s") 27 | file_handler.setFormatter(file_formatter) 28 | logging.root.addHandler(file_handler) 29 | -------------------------------------------------------------------------------- /src/timm/utils/metrics.py: -------------------------------------------------------------------------------- 1 | """ Eval metrics and related 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | 6 | 7 | class AverageMeter: 8 | """Computes and stores the average and current value""" 9 | def __init__(self): 10 | self.reset() 11 | 12 | def reset(self): 13 | self.val = 0 14 | self.avg = 0 15 | self.sum = 0 16 | self.count = 0 17 | 18 | def update(self, val, n=1): 19 | self.val = val 20 | self.sum += val * n 21 | self.count += n 22 | self.avg = self.sum / self.count 23 | 24 | 25 | def accuracy(output, target, topk=(1,)): 26 | """Computes the accuracy over the k top predictions for the specified values of k""" 27 | maxk = max(topk) 28 | batch_size = target.size(0) 29 | _, pred = output.topk(maxk, 1, True, True) 30 | pred = pred.t() 31 | correct = pred.eq(target.reshape(1, -1).expand_as(pred)) 32 | return [correct[:k].reshape(-1).float().sum(0) * 100. 
/ batch_size for k in topk] 33 | -------------------------------------------------------------------------------- /src/timm/utils/misc.py: -------------------------------------------------------------------------------- 1 | """ Misc utils 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import re 6 | 7 | 8 | def natural_key(string_): 9 | """See http://www.codinghorror.com/blog/archives/001018.html""" 10 | return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())] 11 | 12 | 13 | def add_bool_arg(parser, name, default=False, help=''): 14 | dest_name = name.replace('-', '_') 15 | group = parser.add_mutually_exclusive_group(required=False) 16 | group.add_argument('--' + name, dest=dest_name, action='store_true', help=help) 17 | group.add_argument('--no-' + name, dest=dest_name, action='store_false', help=help) 18 | parser.set_defaults(**{dest_name: default}) 19 | -------------------------------------------------------------------------------- /src/timm/utils/model.py: -------------------------------------------------------------------------------- 1 | """ Model / state_dict utils 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | from .model_ema import ModelEma 6 | 7 | 8 | def unwrap_model(model): 9 | if isinstance(model, ModelEma): 10 | return unwrap_model(model.ema) 11 | else: 12 | return model.module if hasattr(model, 'module') else model 13 | 14 | 15 | def get_state_dict(model, unwrap_fn=unwrap_model): 16 | return unwrap_fn(model).state_dict() 17 | -------------------------------------------------------------------------------- /src/timm/utils/summary.py: -------------------------------------------------------------------------------- 1 | """ Summary utilities 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import csv 6 | import os 7 | from collections import OrderedDict 8 | 9 | 10 | def get_outdir(path, *paths, inc=False): 11 | outdir = os.path.join(path, *paths) 12 | if not os.path.exists(outdir): 13 | os.makedirs(outdir) 14 | elif inc: 15 | count = 1 16 | outdir_inc = outdir + '-' + str(count) 17 | while os.path.exists(outdir_inc): 18 | count = count + 1 19 | outdir_inc = outdir + '-' + str(count) 20 | assert count < 100 21 | outdir = outdir_inc 22 | os.makedirs(outdir) 23 | return outdir 24 | 25 | 26 | def update_summary(epoch, train_metrics, eval_metrics, filename, write_header=False): 27 | rowd = OrderedDict(epoch=epoch) 28 | rowd.update([('train_' + k, v) for k, v in train_metrics.items()]) 29 | rowd.update([('eval_' + k, v) for k, v in eval_metrics.items()]) 30 | with open(filename, mode='a') as cf: 31 | dw = csv.DictWriter(cf, fieldnames=rowd.keys()) 32 | if write_header: # first iteration (epoch == 1 can't be used) 33 | dw.writeheader() 34 | dw.writerow(rowd) 35 | -------------------------------------------------------------------------------- /src/timm/version.py: -------------------------------------------------------------------------------- 1 | __version__ = '0.4.1' 2 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/SwinBERT/03116f1f3fd7e42d4700a25090f13aa2aa253011/src/utils/__init__.py -------------------------------------------------------------------------------- /src/utils/deepspeed.py: -------------------------------------------------------------------------------- 1 | from .logger import LOGGER as logger 2 
| from pprint import pformat 3 | import torch 4 | 5 | 6 | def get_deepspeed_config(args): 7 | config_params = { 8 | 'train_batch_size': args.effective_batch_size, 9 | } 10 | 11 | use_fp16 = args.deepspeed_fp16 12 | use_amp = not args.deepspeed_fp16 # by default, if not use deepspeed fp16, will enable deepspeed amp 13 | 14 | if use_amp: 15 | config_params['amp'] = { 16 | 'enabled': True, 17 | 'opt_level': f'O{args.amp_opt_level}', 18 | } 19 | 20 | if use_fp16: 21 | config_params['fp16'] = { 22 | 'enabled': True, 23 | } 24 | 25 | gradient_clip = args.max_grad_norm 26 | if gradient_clip: 27 | config_params['gradient_clipping'] = gradient_clip 28 | 29 | config_params['flops_profiler'] = { 30 | 'enabled': False, 31 | 'profile_step': 1, 32 | 'module_depth': -1, 33 | 'top_modules': 3, 34 | 'detailed': True, 35 | } 36 | 37 | config_params['logging'] = { 38 | 'steps_per_print': args.logging_steps*10, 39 | } 40 | if hasattr(args, "zero_opt_stage") and args.zero_opt_stage > 0: 41 | config_params['zero_optimization'] = { 42 | 'stage': args.zero_opt_stage, 43 | } 44 | if args.zero_opt_stage > 0: 45 | config_params['fp16'] = { 46 | 'enabled': True 47 | } 48 | config_params['zero_allow_untested_optimizer'] = True 49 | 50 | logger.info(pformat(config_params)) 51 | return config_params 52 | 53 | 54 | 55 | def fp32_to_fp16(inputs): 56 | # deepspeed does not auto cast inputs. 57 | for k, v in inputs.items(): 58 | if isinstance(v, torch.Tensor) and v.dtype == torch.float32: 59 | v = v.to(dtype=torch.half) 60 | inputs[k] = v 61 | return inputs 62 | -------------------------------------------------------------------------------- /src/utils/load_files.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import os.path as op 4 | import errno 5 | import yaml 6 | from collections import OrderedDict 7 | 8 | 9 | def load_labelmap_file(labelmap_file): 10 | label_dict = None 11 | 12 | if labelmap_file.endswith('json'): 13 | label_dict = json.load(open(labelmap_file, 'r'))['label_to_idx'] 14 | label_dict = {key:val-1 for key, val in label_dict.items()} 15 | return label_dict 16 | 17 | if labelmap_file is not None and op.isfile(labelmap_file): 18 | label_dict = OrderedDict() 19 | with open(labelmap_file, 'r') as fp: 20 | for line in fp: 21 | label = line.strip().split('\t')[0] 22 | if label in label_dict: 23 | raise ValueError("Duplicate label " + label + " in labelmap.") 24 | else: 25 | label_dict[label] = len(label_dict) 26 | return label_dict 27 | 28 | 29 | def config_dataset_file(data_dir, dataset_file): 30 | if dataset_file: 31 | if op.isfile(dataset_file): 32 | dataset_file = dataset_file 33 | elif op.isfile(op.join(data_dir, dataset_file)): 34 | dataset_file = op.join(data_dir, dataset_file) 35 | else: 36 | raise ValueError("cannot find file: {}".format(dataset_file)) 37 | return dataset_file 38 | 39 | 40 | def load_linelist_file(linelist_file): 41 | if linelist_file is not None: 42 | line_list = [] 43 | with open(linelist_file, 'r') as fp: 44 | for i in fp: 45 | line_list.append(int(i.strip())) 46 | return line_list 47 | 48 | 49 | def load_box_linelist_file(linelist_file): 50 | if linelist_file is not None: 51 | img_line_list = [] 52 | box_line_list = [] 53 | with open(linelist_file, 'r') as fp: 54 | for i in fp: 55 | idx = [int(_) for _ in i.strip().split('\t')] 56 | img_line_list.append(idx[0]) 57 | box_line_list.append(idx[1]) 58 | return [img_line_list, box_line_list] 59 | 60 | 61 | def load_from_yaml_file(yaml_file): 62 | with open(yaml_file, 
'r') as fp: 63 | return yaml.load(fp, Loader=yaml.CLoader) 64 | 65 | 66 | def find_file_path_in_yaml(fname, root): 67 | if fname is not None: 68 | if op.isfile(fname): 69 | return fname 70 | elif op.isfile(op.join(root, fname)): 71 | return op.join(root, fname) 72 | else: 73 | raise FileNotFoundError( 74 | errno.ENOENT, os.strerror(errno.ENOENT), op.join(root, fname) 75 | ) 76 | --------------------------------------------------------------------------------
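A minimal sketch of how the file-loading helpers above can be combined, assuming a hypothetical data root and YAML keys (none of the paths or keys below are taken from the repository's configs):

    import os.path as op

    from src.utils.load_files import (
        find_file_path_in_yaml, load_from_yaml_file, load_linelist_file)

    # Hypothetical dataset layout, for illustration only.
    data_dir = 'datasets/MSRVTT-v2'
    cfg = load_from_yaml_file(op.join(data_dir, 'train.yaml'))

    # Entries in the YAML are resolved relative to the data root;
    # find_file_path_in_yaml raises FileNotFoundError if a named file is missing,
    # and returns None when the key is absent from the YAML.
    caption_file = find_file_path_in_yaml(cfg.get('caption'), data_dir)
    line_list = load_linelist_file(find_file_path_in_yaml(cfg.get('caption_linelist'), data_dir))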