├── .DS_Store
├── .gitignore
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── NOTICE
├── README.md
├── __pycache__
├── data
│   ├── instruct_captions.json
│   └── name_map.json
├── evaluations.py
├── extract_caption.py
├── extract_features.py
├── main.py
├── model.py
├── requirements.txt
├── run_inference.sh
├── run_training.sh
├── timm
│   ├── __init__.py
│   ├── __pycache__
│   ├── data
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   ├── auto_augment.py
│   │   ├── config.py
│   │   ├── constants.py
│   │   ├── dataset.py
│   │   ├── dataset_factory.py
│   │   ├── distributed_sampler.py
│   │   ├── loader.py
│   │   ├── mixup.py
│   │   ├── parsers
│   │   │   ├── __init__.py
│   │   │   ├── class_map.py
│   │   │   ├── constants.py
│   │   │   ├── parser.py
│   │   │   ├── parser_factory.py
│   │   │   ├── parser_image_folder.py
│   │   │   ├── parser_image_in_tar.py
│   │   │   ├── parser_image_tar.py
│   │   │   └── parser_tfds.py
│   │   ├── random_erasing.py
│   │   ├── real_labels.py
│   │   ├── tf_preprocessing.py
│   │   ├── transforms.py
│   │   └── transforms_factory.py
│   ├── loss
│   │   ├── __init__.py
│   │   ├── asymmetric_loss.py
│   │   ├── cross_entropy.py
│   │   └── jsd.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   ├── byoanet.py
│   │   ├── byobnet.py
│   │   ├── cait.py
│   │   ├── coat.py
│   │   ├── convit.py
│   │   ├── cspnet.py
│   │   ├── densenet.py
│   │   ├── dla.py
│   │   ├── dpn.py
│   │   ├── efficientnet.py
│   │   ├── efficientnet_blocks.py
│   │   ├── efficientnet_builder.py
│   │   ├── factory.py
│   │   ├── features.py
│   │   ├── ghostnet.py
│   │   ├── gluon_resnet.py
│   │   ├── gluon_xception.py
│   │   ├── hardcorenas.py
│   │   ├── helpers.py
│   │   ├── hrnet.py
│   │   ├── hub.py
│   │   ├── inception_resnet_v2.py
│   │   ├── inception_v3.py
│   │   ├── inception_v4.py
│   │   ├── layers
│   │   │   └── __pycache__
│   │   ├── levit.py
│   │   ├── mlp_mixer.py
│   │   ├── mobilenetv3.py
│   │   ├── nasnet.py
│   │   ├── nfnet.py
│   │   ├── pit.py
│   │   ├── pnasnet.py
│   │   ├── registry.py
│   │   ├── regnet.py
│   │   ├── res2net.py
│   │   ├── resnest.py
│   │   ├── resnet.py
│   │   ├── resnetv2.py
│   │   ├── rexnet.py
│   │   ├── selecsls.py
│   │   ├── senet.py
│   │   ├── sknet.py
│   │   ├── swin_transformer.py
│   │   ├── tnt.py
│   │   ├── tresnet.py
│   │   ├── twins.py
│   │   ├── vgg.py
│   │   ├── visformer.py
│   │   ├── vision_transformer.py
│   │   ├── vision_transformer_hybrid.py
│   │   ├── vovnet.py
│   │   ├── xception.py
│   │   └── xception_aligned.py
│   ├── optim
│   │   ├── __init__.py
│   │   ├── adabelief.py
│   │   ├── adafactor.py
│   │   ├── adahessian.py
│   │   ├── adamp.py
│   │   ├── adamw.py
│   │   ├── lookahead.py
│   │   ├── nadam.py
│   │   ├── novograd.py
│   │   ├── nvnovograd.py
│   │   ├── optim_factory.py
│   │   ├── radam.py
│   │   ├── rmsprop_tf.py
│   │   └── sgdp.py
│   ├── scheduler
│   │   ├── __init__.py
│   │   ├── cosine_lr.py
│   │   ├── plateau_lr.py
│   │   ├── scheduler.py
│   │   ├── scheduler_factory.py
│   │   ├── step_lr.py
│   │   └── tanh_lr.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── agc.py
│   │   ├── checkpoint_saver.py
│   │   ├── clip_grad.py
│   │   ├── cuda.py
│   │   ├── distributed.py
│   │   ├── jit.py
│   │   ├── log.py
│   │   ├── metrics.py
│   │   ├── misc.py
│   │   ├── model.py
│   │   ├── model_ema.py
│   │   ├── random.py
│   │   └── summary.py
│   └── version.py
├── utils_data.py
├── utils_evaluate.py
├── utils_prompt.py
└── vision_features
    └── mm-cot.png

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | **/.DS_Store
2 |
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | ## Code of Conduct
2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
4 | opensource-codeofconduct@amazon.com with any additional questions or comments.
5 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing Guidelines
2 |
3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional
4 | documentation, we greatly value feedback and contributions from our community.
5 |
6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary
7 | information to effectively respond to your bug report or contribution.
8 |
9 |
10 | ## Reporting Bugs/Feature Requests
11 |
12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features.
13 |
14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already
15 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful:
16 |
17 | * A reproducible test case or series of steps
18 | * The version of our code being used
19 | * Any modifications you've made relevant to the bug
20 | * Anything unusual about your environment or deployment
21 |
22 |
23 | ## Contributing via Pull Requests
24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that:
25 |
26 | 1. You are working against the latest source on the *main* branch.
27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already.
28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted.
29 |
30 | To send us a pull request, please:
31 |
32 | 1. Fork the repository.
33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change.
34 | 3. Ensure local tests pass.
35 | 4. Commit to your fork using clear commit messages.
36 | 5. Send us a pull request, answering any default questions in the pull request interface.
37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation.
38 |
39 | GitHub provides additional documentation on [forking a repository](https://help.github.com/articles/fork-a-repo/) and
40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/).
41 |
42 |
43 | ## Finding contributions to work on
44 | Looking at the existing issues is a great way to find something to contribute to. Our projects use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), so looking at any 'help wanted' issues is a great place to start.
45 |
46 |
47 | ## Code of Conduct
48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
50 | opensource-codeofconduct@amazon.com with any additional questions or comments.
51 |
52 |
53 | ## Security issue notifications
54 | If you discover a potential security issue in this project, we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public GitHub issue.
55 |
56 |
57 | ## Licensing
58 |
59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution.
60 |
--------------------------------------------------------------------------------
/NOTICE:
--------------------------------------------------------------------------------
1 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Multimodal Chain-of-Thought Reasoning in Language Models
2 |
"Imagine learning a textbook without figures or tables."
4 |
5 | Multimodal-CoT incorporates vision features in a decoupled training framework. The framework consists of two training stages: (i) rationale generation and (ii) answer inference. Both stages share the same model architecture but differ in their inputs and outputs.
6 |
7 | ![](vision_features/mm-cot.png)
8 |
9 |
10 | ## Requirements
11 |
12 | Install all required Python dependencies:
13 |
14 | ```
15 | pip install -r requirements.txt
16 | ```
17 |
18 | ## Datasets
19 |
20 | Download the dataset from the following repository:
21 |
22 | ```
23 | https://github.com/lupantech/ScienceQA/tree/main/data
24 | ```
25 | The vision features (detr, resnet, clip, vit) are available at https://huggingface.co/cooelf/vision_features/tree/main
26 |
27 | Alternatively, you may download the extracted vision features (detr, resnet, clip) from [vision_features](https://drive.google.com/file/d/13B0hc_F_45-UlqPLKSgRz-ALtFQ8kIJr/view?usp=share_link) and unzip the files under `vision_features`.
28 |
29 | ## Extract Features (optional)
30 |
31 | The processed vision features for ScienceQA are available at https://huggingface.co/cooelf/vision_features/tree/main.
32 |
33 | The following instructions show how we obtained those features.
34 |
35 | Download the image files from [Google Drive](https://drive.google.com/drive/folders/1w8imCXWYn2LxajmGeGH_g5DaL2rabHev?usp=sharing) and unzip all the images (train, dev, test) into the same folder (`images`). The structure should be:
36 |
37 | ```
38 | images
39 | ├── 1
40 | │   └── image.png
41 | ├── 2
42 | │   └── image.png
43 | ├── 3
44 | │   └── image.png
45 | ├── 5
46 | │   └── image.png
47 | ├── 7
48 | │   └── image.png
49 | ```
50 |
51 | Run ```python extract_features.py --data_root images --output_dir vision_features --img_type vit```
52 |
53 | If you want to use your own images, please structure them as shown above, or modify the script ```extract_features.py```.
54 |
55 | ## Extract Captions (optional)
56 |
57 | The processed captions for ScienceQA are available at ```data/instruct_captions.json```.
58 |
59 | The following instructions show how we obtained those captions.
60 |
61 | Install LAVIS and prepare the Vicuna weights to use InstructBLIP for caption extraction:
62 |
63 | https://github.com/salesforce/LAVIS/tree/f982acc73288408bceda2d35471a8fcf55aa04ca/projects/instructblip
64 |
65 | Assume that the images are stored in the ```images``` folder.
66 |
67 | ```
68 | python extract_caption.py
69 | ```
70 |
71 | ## Instructions
72 |
73 | ### Training
74 |
75 | ```
76 | # rationale generation
77 | CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py \
78 | --data_root data/ScienceQA/data \
79 | --caption_file data/instruct_captions.json \
80 | --model declare-lab/flan-alpaca-large \
81 | --user_msg rationale --img_type vit \
82 | --bs 2 --eval_bs 4 --epoch 50 --lr 5e-5 --output_len 512 \
83 | --use_caption --use_generate --prompt_format QCM-E \
84 | --output_dir experiments
85 |
86 | # answer inference
87 | CUDA_VISIBLE_DEVICES=0,1,2,3 python main_central.py \
88 | --data_root data/ScienceQA/data \
89 | --caption_file data/instruct_captions.json \
90 | --model declare-lab/flan-alpaca-large \
91 | --user_msg answer --img_type vit \
92 | --bs 4 --eval_bs 8 --epoch 50 --lr 5e-5 --output_len 64 \
93 | --use_caption --use_generate --prompt_format QCMG-A \
94 | --output_dir experiments \
95 | --eval_le experiments/rationale_declare-lab-flan-alpaca-large_vit_QCM-E_lr5e-05_bs8_op512_ep50/predictions_ans_eval.json \
96 | --test_le experiments/rationale_declare-lab-flan-alpaca-large_vit_QCM-E_lr5e-05_bs8_op512_ep50/predictions_ans_test.json
97 |
98 | ```
99 |
100 | ### Inference
101 |
102 | Our trained models are available at https://huggingface.co/cooelf/mm-cot/tree/main. To use our trained models, please put them under the ```models``` folder.
103 |
104 | ```
105 | # rationale generation
106 | CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py \
107 | --data_root data/ScienceQA/data \
108 | --caption_file data/instruct_captions.json \
109 | --model declare-lab/flan-alpaca-large \
110 | --user_msg rationale --img_type vit \
111 | --bs 2 --eval_bs 4 --epoch 50 --lr 5e-5 --output_len 512 \
112 | --use_caption --use_generate --prompt_format QCM-E \
113 | --output_dir experiments \
114 | --evaluate_dir models/mm-cot-large-rationale
115 |
116 | # answer inference
117 | CUDA_VISIBLE_DEVICES=0,1,2,3 python main_central.py \
118 | --data_root data/ScienceQA/data \
119 | --caption_file data/instruct_captions.json \
120 | --model declare-lab/flan-alpaca-large \
121 | --user_msg answer --img_type vit \
122 | --bs 4 --eval_bs 8 --epoch 50 --lr 5e-5 --output_len 64 \
123 | --use_caption --use_generate --prompt_format QCMG-A \
124 | --output_dir experiments \
125 | --eval_le experiments/rationale_declare-lab-flan-alpaca-large_vit_QCM-E_lr5e-05_bs8_op512_ep50/predictions_ans_eval.json \
126 | --test_le experiments/rationale_declare-lab-flan-alpaca-large_vit_QCM-E_lr5e-05_bs8_op512_ep50/predictions_ans_test.json \
127 | --evaluate_dir models/mm-cot-large-answer
128 | ```
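To make the two stages concrete, the sketch below shows how they chain at inference time. `generate_rationale` and `infer_answer` are hypothetical helpers standing in for the two trained models above; they are not functions defined in this repository:

```python
# Illustrative pseudocode of the decoupled two-stage framework.
# generate_rationale / infer_answer are hypothetical stand-ins.
def multimodal_cot(question, context, options, vision_features):
    # Stage 1 (QCM-E): question, context, and options, fused with the
    # vision features, produce a rationale.
    rationale = generate_rationale(question, context, options, vision_features)
    # Stage 2 (QCMG-A): the generated rationale is appended to the original
    # input, and the second model predicts the final answer.
    return infer_answer(question, context, options, rationale, vision_features)
```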
129 |
130 | ## Citing MM-CoT
131 |
132 | ```
133 | @article{zhang2023multicot,
134 |   title={Multimodal Chain-of-Thought Reasoning in Language Models},
135 |   author={Zhang, Zhuosheng and Zhang, Aston and Li, Mu and Zhao, Hai and Karypis, George and Smola, Alex},
136 |   journal={arXiv preprint arXiv:2302.00923},
137 |   year={2023}
138 | }
139 | ```
140 |
141 | ## License
142 |
143 | This project is licensed under the Apache-2.0 License.
144 |
145 | ## Acknowledgement
146 |
147 | Parts of our code are adapted from [ScienceQA](https://github.com/lupantech/ScienceQA), [Transformers](https://github.com/huggingface/transformers), and [pytorch-image-models](https://github.com/huggingface/pytorch-image-models).
148 |
149 | We thank [Pan Lu](https://lupantech.github.io/) for providing the parameter sizes of the ScienceQA baselines.
150 |
--------------------------------------------------------------------------------
/evaluations.py:
--------------------------------------------------------------------------------
1 | '''
2 | Adapted from https://github.com/lupantech/ScienceQA
3 | '''
4 |
5 | import re
6 | from rouge import Rouge
7 | from nltk.translate.bleu_score import sentence_bleu
8 | from sentence_transformers import util
9 |
10 | ########################
11 | ## BLEU
12 | ########################
13 | def tokenize(text):
14 |     tokens = re.split(r'\s|\.', text)
15 |     tokens = [t for t in tokens if len(t) > 0]
16 |     return tokens
17 |
18 |
19 | def bleu_score(reference, hypothesis, gram):
20 |     reference_tokens = tokenize(reference)
21 |     hypothesis_tokens = tokenize(hypothesis)
22 |
23 |     if gram == 1:
24 |         bleu = sentence_bleu([reference_tokens], hypothesis_tokens, (1., ))  # BLEU-1
25 |     elif gram == 2:
26 |         bleu = sentence_bleu([reference_tokens], hypothesis_tokens, (1. / 2., 1. / 2.))  # BLEU-2
27 |     elif gram == 3:
28 |         bleu = sentence_bleu([reference_tokens], hypothesis_tokens, (1. / 3., 1. / 3., 1. / 3.))  # BLEU-3
29 |     elif gram == 4:
30 |         bleu = sentence_bleu([reference_tokens], hypothesis_tokens, (1. / 4., 1. / 4., 1. / 4., 1. / 4.))  # BLEU-4
31 |
32 |     return bleu
33 |
34 |
35 | def caculate_bleu(results, data, gram):
36 |     bleus = []
37 |     for qid, output in results.items():
38 |         prediction = output
39 |         target = data[qid]
40 |         target = target.strip()
41 |         if target == "":
42 |             continue
43 |         bleu = bleu_score(target, prediction, gram)
44 |         bleus.append(bleu)
45 |
46 |     avg_bleu = sum(bleus) / len(bleus)
47 |
48 |     return avg_bleu
49 |
50 |
51 | ########################
52 | ## Rouge-L
53 | ########################
54 | def score_rouge(str1, str2):
55 |     rouge = Rouge(metrics=["rouge-l"])
56 |     scores = rouge.get_scores(str1, str2, avg=True)
57 |     rouge_l = scores['rouge-l']['f']
58 |     return rouge_l
59 |
60 |
61 | def caculate_rouge(results, data):
62 |     rouges = []
63 |     for qid, output in results.items():
64 |         prediction = output
65 |         target = data[qid]
66 |         target = target.strip()
67 |         if prediction == "":
68 |             continue
69 |         if target == "":
70 |             continue
71 |         rouge = score_rouge(target, prediction)
72 |         rouges.append(rouge)
73 |
74 |     avg_rouge = sum(rouges) / len(rouges)
75 |     return avg_rouge
76 |
77 |
78 | ########################
79 | ## Sentence Similarity
80 | ########################
81 | def similariry_score(str1, str2, model):
82 |     # compute embedding for both lists
83 |     embedding_1 = model.encode(str1, convert_to_tensor=True)
84 |     embedding_2 = model.encode(str2, convert_to_tensor=True)
85 |     score = util.pytorch_cos_sim(embedding_1, embedding_2).item()
86 |     return score
87 |
88 |
89 | def caculate_similariry(results, data, model):
90 |     scores = []
91 |     for qid, output in results.items():
92 |         prediction = output
93 |         target = data[qid]
94 |         target = target.strip()
95 |
96 |         score = similariry_score(target, prediction, model)
97 |         scores.append(score)
98 |
99 |     avg_score = sum(scores) / len(scores)
100 |     return avg_score
101 |
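A minimal usage sketch for the metric helpers above (keeping their original spellings). The sample strings and the choice of sentence-transformers checkpoint are illustrative; `results` and `data` follow the qid-to-text layout the loops above expect:

```python
from sentence_transformers import SentenceTransformer

from evaluations import caculate_bleu, caculate_rouge, caculate_similariry

results = {"q1": "The magnet attracts the pin."}                  # predictions
data = {"q1": "The magnet attracts the pin because it is iron."}  # references

print(caculate_bleu(results, data, gram=4))  # corpus-average BLEU-4
print(caculate_rouge(results, data))         # corpus-average ROUGE-L F1
model = SentenceTransformer("all-MiniLM-L6-v2")  # any sentence-transformers model
print(caculate_similariry(results, data, model))
```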
--------------------------------------------------------------------------------
/extract_caption.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from PIL import Image
3 | import os
4 | from tqdm import tqdm
5 | from lavis.models import load_model_and_preprocess
6 | import json
7 |
8 | # loads InstructBLIP model
9 | device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
10 | model, vis_processors, _ = load_model_and_preprocess(name="blip2_vicuna_instruct", model_type="vicuna7b", is_eval=True, device=device)
11 |
12 | data_root = "data/images"
13 | output_dir = "data/instruct_captions.json"
14 |
15 | all_images = os.listdir(data_root)
16 | all_images.sort(key=lambda x: int(x))  # folder names are numeric
17 |
18 | name_map = {}
19 |
20 | for image in tqdm(all_images):
21 |     if os.path.exists(os.path.join(data_root, image, "image.png")):
22 |         curr_dir = os.path.join(data_root, image, "image.png")
23 |     else:
24 |         curr_dir = os.path.join(data_root, image, "choice_0.png")
25 |     raw_image = Image.open(curr_dir).convert("RGB")
26 |     # prepare the image
27 |     image_features = vis_processors["eval"](raw_image).unsqueeze(0).to(device)
28 |     output = model.generate({"image": image_features, "prompt": "Write a detailed description."})
29 |     name_map[str(image)] = output
30 |
31 | with open(output_dir, 'w') as outfile:
32 |     json.dump(name_map, outfile, indent=2)
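The resulting JSON maps each numeric image-folder id to its generated caption; as the script is written, the stored values are lists, since `model.generate` returns a list of strings. A quick sanity check under that assumption:

```python
import json

# Load the captions written by extract_caption.py above.
with open("data/instruct_captions.json") as f:
    captions = json.load(f)
print(captions["1"])  # the caption(s) generated for images/1
```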
--------------------------------------------------------------------------------
/extract_features.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from PIL import Image
3 | import torchvision.transforms as T
4 | import timm
5 | from timm.data import resolve_data_config
6 | from timm.data.transforms_factory import create_transform
7 | import os
8 | import argparse
9 | import json
10 | from tqdm import tqdm
11 |
12 | def parse_args():
13 |     parser = argparse.ArgumentParser()
14 |     parser.add_argument('--data_root', type=str, default='images')
15 |     parser.add_argument('--output_dir', type=str, default='vision_features')
16 |     parser.add_argument('--img_type', type=str, default="vit", choices=['detr', 'vit'], help='type of image features')
17 |     args = parser.parse_args()
18 |     return args
19 |
20 | def extract_features(img_type, input_image):
21 |     if img_type == "vit":
22 |         config = resolve_data_config({}, model=vit_model)
23 |         transform = create_transform(**config)
24 |         with torch.no_grad():
25 |             img = Image.open(input_image).convert("RGB")
26 |             input = transform(img).unsqueeze(0)
27 |             feature = vit_model.forward_features(input)  # vit_model is created in __main__ below
28 |         return feature
29 |
30 |     elif img_type == "detr":
31 |         transform = T.Compose([
32 |             T.Resize(224),
33 |             T.ToTensor(),
34 |             T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
35 |         ])
36 |         with torch.no_grad():
37 |             img = Image.open(input_image).convert("RGB")
38 |             input = transform(img).unsqueeze(0)
39 |             feature = detr_model(input)[-1]  # detr_model is created in __main__ below
40 |         return feature
41 |
42 | if __name__ == '__main__':
43 |     args = parse_args()
44 |     print("args", args)
45 |     all_images = os.listdir(args.data_root)
46 |     tmp = []
47 |     name_map = {}
48 |     all_images.sort(key=lambda x: int(x))
49 |     print(len(all_images))
50 |     if args.img_type == "vit":
51 |         vit_model = timm.create_model("vit_large_patch32_384", pretrained=True, num_classes=0)
52 |         vit_model.eval()
53 |     elif args.img_type == "detr":
54 |         detr_model = torch.hub.load('cooelf/detr', 'detr_resnet101_dc5', pretrained=True)
55 |         detr_model.eval()
56 |     for idx, image in enumerate(tqdm(all_images)):
57 |         if idx % 100 == 0: print(idx)
58 |         if os.path.exists(os.path.join(args.data_root, image, "image.png")):
59 |             curr_dir = os.path.join(args.data_root, image, "image.png")
60 |         else:
61 |             curr_dir = os.path.join(args.data_root, image, "choice_0.png")
62 |         feature = extract_features(args.img_type, curr_dir)
63 |         tmp.append(feature.detach().cpu())
64 |         name_map[str(image)] = idx
65 |
66 |     res = torch.cat(tmp).cpu()
67 |     print(res.shape)
68 |     torch.save(res, os.path.join(args.output_dir, args.img_type + '.pth'))
69 |     with open(os.path.join(args.output_dir, 'name_map.json'), 'w') as outfile:
70 |         json.dump(name_map, outfile)
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | huggingface-hub>=0.4.0
2 | numpy==1.23.2
3 | openai==0.23.0
4 | pandas==1.4.3
5 | rouge==1.0.1
6 | sentence-transformers==2.2.2
7 | transformers==4.30.0
8 | nltk==3.6.6
9 | evaluate==0.4.0
10 | rouge_score==0.1.2
11 | rich>=13.3.2
12 |
--------------------------------------------------------------------------------
/run_inference.sh:
--------------------------------------------------------------------------------
1 | # base
2 | CUDA_VISIBLE_DEVICES=0 python main_central.py \
3 | --data_root data/ScienceQA/data \
4 | --caption_file data/instruct_captions.json \
5 | --model declare-lab/flan-alpaca-base \
6 | --user_msg rationale --img_type vit \
7 | --bs 8 --eval_bs 8 --epoch 20 --lr 8e-5 --output_len 512 \
8 | --use_caption --use_generate --final_eval --prompt_format QCM-E \
9 | --output_dir experiments \
10 | --evaluate_dir models/mm-cot-base-rationale
11 |
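# answer inference: --eval_le/--test_le consume the rationales predicted by the command above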
12 | CUDA_VISIBLE_DEVICES=0 python main_central.py \
13 | --data_root data/ScienceQA/data \
14 | --caption_file data/instruct_captions.json \
15 | --model declare-lab/flan-alpaca-base \
16 | --user_msg answer --img_type vit \
17 | --bs 8 --eval_bs 8 --epoch 20 --lr 8e-5 --output_len 64 \
18 | --use_caption --use_generate --prompt_format QCMG-A \
19 | --output_dir experiments \
20 | --eval_le models/mm-cot-base-rationale/predictions_ans_eval.json \
21 | --test_le models/mm-cot-base-rationale/predictions_ans_test.json \
22 | --evaluate_dir models/mm-cot-base-answer
23 |
24 | # large
25 | # rationale generation
26 | CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py \
27 | --data_root data/ScienceQA/data \
28 | --caption_file data/instruct_captions.json \
29 | --model declare-lab/flan-alpaca-large \
30 | --user_msg rationale --img_type vit \
31 | --bs 2 --eval_bs 4 --epoch 50 --lr 5e-5 --output_len 512 \
32 | --use_caption --use_generate --prompt_format QCM-E \
33 | --output_dir experiments \
34 | --evaluate_dir models/mm-cot-large-rationale
35 |
36 | # answer inference
37 | CUDA_VISIBLE_DEVICES=0,1,2,3 python main_central.py \
38 | --data_root data/ScienceQA/data \
39 | --caption_file data/instruct_captions.json \
40 | --model declare-lab/flan-alpaca-large \
41 | --user_msg answer --img_type vit \
42 | --bs 4 --eval_bs 8 --epoch 50 --lr 5e-5 --output_len 64 \
43 | --use_caption --use_generate --prompt_format QCMG-A \
44 | --output_dir experiments \
45 | --eval_le models/mm-cot-large-rationale/predictions_ans_eval.json \
46 | --test_le models/mm-cot-large-rationale/predictions_ans_test.json \
47 | --evaluate_dir models/mm-cot-large-answer
--------------------------------------------------------------------------------
/run_training.sh:
--------------------------------------------------------------------------------
1 | # base
2 | CUDA_VISIBLE_DEVICES=0 python main_central.py \
3 | --data_root data/ScienceQA/data \
4 | --caption_file data/instruct_captions.json \
5 | --model declare-lab/flan-alpaca-base \
6 | --user_msg rationale --img_type vit \
7 | --bs 8 --eval_bs 8 --epoch 20 --lr 8e-5 --output_len 512 \
8 | --use_caption --use_generate --final_eval --prompt_format QCM-E \
9 | --output_dir experiments0620
10 |
11 | CUDA_VISIBLE_DEVICES=0 python main_central.py \
12 | --data_root data/ScienceQA/data \
13 | --caption_file data/instruct_captions.json \
14 | --model declare-lab/flan-alpaca-base \
15 | --user_msg answer --img_type vit \
16 | --bs 8 --eval_bs 8 --epoch 20 --lr 8e-5 --output_len 64 \
17 | --use_caption --use_generate --prompt_format QCMG-A \
18 | --output_dir experiments0620 \
19 | --eval_le experiments/rationale_declare-lab-flan-alpaca-base_vit_QCM-E_lr8e-05_bs8_op512_ep20/predictions_ans_eval.json \
20 | --test_le experiments/rationale_declare-lab-flan-alpaca-base_vit_QCM-E_lr8e-05_bs8_op512_ep20/predictions_ans_test.json
21 |
22 | # large
23 | # rationale generation
24 | CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py \
25 | --data_root data/ScienceQA/data \
26 | --caption_file data/instruct_captions.json \
27 | --model declare-lab/flan-alpaca-large \
28 | --user_msg rationale --img_type vit \
29 | --bs 2 --eval_bs 4 --epoch 50 --lr 5e-5 --output_len 512 \
30 | --use_caption --use_generate --prompt_format QCM-E \
31 | --output_dir experiments
32 |
33 | # answer inference
34 | CUDA_VISIBLE_DEVICES=0,1,2,3 python main_central.py \
35 | --data_root data/ScienceQA/data \
36 | --caption_file data/instruct_captions.json \
37 | --model declare-lab/flan-alpaca-large \
38 | --user_msg answer --img_type vit \
39 | --bs 4 --eval_bs 8 --epoch 50 --lr 5e-5 --output_len 64 \
40 | --use_caption --use_generate --prompt_format QCMG-A \
41 | --output_dir experiments \
42 | --eval_le experiments/rationale_declare-lab-flan-alpaca-large_vit_QCM-E_lr5e-05_bs8_op512_ep50/predictions_ans_eval.json \
43 | --test_le experiments/rationale_declare-lab-flan-alpaca-large_vit_QCM-E_lr5e-05_bs8_op512_ep50/predictions_ans_test.json
--------------------------------------------------------------------------------
/timm/__init__.py:
--------------------------------------------------------------------------------
1 | from .version import __version__
2 | from .models import create_model, list_models, is_model, list_modules, model_entrypoint, \
3 |     is_scriptable, is_exportable, set_scriptable, set_exportable, has_model_default_key, is_model_default_key, \
4 |     get_model_default_value, is_model_pretrained
5 |
--------------------------------------------------------------------------------
/timm/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .auto_augment import RandAugment, AutoAugment, rand_augment_ops, auto_augment_policy,\
2 |     rand_augment_transform, auto_augment_transform
3 | from .config import resolve_data_config
4 | from .constants import *
5 | from .dataset import ImageDataset, IterableImageDataset, AugMixDataset
6 | from .dataset_factory import create_dataset
7 | from .loader import create_loader
8 | from .mixup import Mixup, FastCollateMixup
9 | from .parsers import create_parser
10 | from .real_labels import RealLabelsImagenet
11 | from .transforms import *
12 | from .transforms_factory import create_transform
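A brief usage sketch for the two exports that this repository's `extract_features.py` relies on: `resolve_data_config` derives a preprocessing config from a model's `default_cfg`, and `create_transform` builds the matching input pipeline:

```python
import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform

# Mirrors extract_features.py: read input_size, interpolation, mean, std,
# and crop_pct from the model's default_cfg, then build the transform.
model = timm.create_model("vit_large_patch32_384", pretrained=True, num_classes=0)
config = resolve_data_config({}, model=model)
transform = create_transform(**config)
```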
--------------------------------------------------------------------------------
/timm/data/config.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from .constants import *
3 |
4 |
5 | _logger = logging.getLogger(__name__)
6 |
7 |
8 | def resolve_data_config(args, default_cfg={}, model=None, use_test_size=False, verbose=False):
9 |     new_config = {}
10 |     default_cfg = default_cfg
11 |     if not default_cfg and model is not None and hasattr(model, 'default_cfg'):
12 |         default_cfg = model.default_cfg
13 |
14 |     # Resolve input/image size
15 |     in_chans = 3
16 |     if 'chans' in args and args['chans'] is not None:
17 |         in_chans = args['chans']
18 |
19 |     input_size = (in_chans, 224, 224)
20 |     if 'input_size' in args and args['input_size'] is not None:
21 |         assert isinstance(args['input_size'], (tuple, list))
22 |         assert len(args['input_size']) == 3
23 |         input_size = tuple(args['input_size'])
24 |         in_chans = input_size[0]  # input_size overrides in_chans
25 |     elif 'img_size' in args and args['img_size'] is not None:
26 |         assert isinstance(args['img_size'], int)
27 |         input_size = (in_chans, args['img_size'], args['img_size'])
28 |     else:
29 |         if use_test_size and 'test_input_size' in default_cfg:
30 |             input_size = default_cfg['test_input_size']
31 |         elif 'input_size' in default_cfg:
32 |             input_size = default_cfg['input_size']
33 |     new_config['input_size'] = input_size
34 |
35 |     # resolve interpolation method
36 |     new_config['interpolation'] = 'bicubic'
37 |     if 'interpolation' in args and args['interpolation']:
38 |         new_config['interpolation'] = args['interpolation']
39 |     elif 'interpolation' in default_cfg:
40 |         new_config['interpolation'] = default_cfg['interpolation']
41 |
42 |     # resolve dataset + model mean for normalization
43 |     new_config['mean'] = IMAGENET_DEFAULT_MEAN
44 |     if 'mean' in args and args['mean'] is not None:
45 |         mean = tuple(args['mean'])
46 |         if len(mean) == 1:
47 |             mean = tuple(list(mean) * in_chans)
48 |         else:
49 |             assert len(mean) == in_chans
50 |         new_config['mean'] = mean
51 |     elif 'mean' in default_cfg:
52 |         new_config['mean'] = default_cfg['mean']
53 |
54 |     # resolve dataset + model std deviation for normalization
55 |     new_config['std'] = IMAGENET_DEFAULT_STD
56 |     if 'std' in args and args['std'] is not None:
57 |         std = tuple(args['std'])
58 |         if len(std) == 1:
59 |             std = tuple(list(std) * in_chans)
60 |         else:
61 |             assert len(std) == in_chans
62 |         new_config['std'] = std
63 |     elif 'std' in default_cfg:
64 |         new_config['std'] = default_cfg['std']
65 |
66 |     # resolve default crop percentage
67 |     new_config['crop_pct'] = DEFAULT_CROP_PCT
68 |     if 'crop_pct' in args and args['crop_pct'] is not None:
69 |         new_config['crop_pct'] = args['crop_pct']
70 |     elif 'crop_pct' in default_cfg:
71 |         new_config['crop_pct'] = default_cfg['crop_pct']
72 |
73 |     if verbose:
74 |         _logger.info('Data processing configuration for current model + dataset:')
75 |         for n, v in new_config.items():
76 |             _logger.info('\t%s: %s' % (n, str(v)))
77 |
78 |     return new_config
79 |
--------------------------------------------------------------------------------
/timm/data/constants.py:
--------------------------------------------------------------------------------
1 | DEFAULT_CROP_PCT = 0.875
2 | IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
3 | IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
4 | IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5)
5 | IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5)
6 | IMAGENET_DPN_MEAN = (124 / 255, 117 / 255, 104 / 255)
7 | IMAGENET_DPN_STD = tuple([1 / (.0167 * 255)] * 3)
8 |
--------------------------------------------------------------------------------
/timm/data/dataset.py:
--------------------------------------------------------------------------------
1 | """ Quick n Simple Image Folder, Tarfile based DataSet
2 |
3 | Hacked together by / Copyright 2020 Ross Wightman
4 | """
5 | import torch.utils.data as data
6 | import os
7 | import torch
8 | import logging
9 |
10 | from PIL import Image
11 |
12 | from .parsers import create_parser
13 |
14 | _logger = logging.getLogger(__name__)
15 |
16 |
17 | _ERROR_RETRY = 50
18 |
19 |
20 | class ImageDataset(data.Dataset):
21 |
22 |     def __init__(
23 |             self,
24 |             root,
25 |             parser=None,
26 |             class_map='',
27 |             load_bytes=False,
28 |             transform=None,
29 |     ):
30 |         if parser is None or isinstance(parser, str):
31 |             parser = create_parser(parser or '', root=root, class_map=class_map)
32 |         self.parser = parser
33 |         self.load_bytes = load_bytes
34 |         self.transform = transform
35 |         self._consecutive_errors = 0
36 |
37 |     def __getitem__(self, index):
38 |         img, target = self.parser[index]
39 |         try:
40 |             img = img.read() if self.load_bytes else Image.open(img).convert('RGB')
41 |         except Exception as e:
42 |             _logger.warning(f'Skipped sample (index {index}, file {self.parser.filename(index)}). {str(e)}')
43 |             self._consecutive_errors += 1
44 |             if self._consecutive_errors < _ERROR_RETRY:
45 |                 return self.__getitem__((index + 1) % len(self.parser))
46 |             else:
47 |                 raise e
48 |         self._consecutive_errors = 0
49 |         if self.transform is not None:
50 |             img = self.transform(img)
51 |         if target is None:
52 |             target = torch.tensor(-1, dtype=torch.long)
53 |         return img, target
54 |
55 |     def __len__(self):
56 |         return len(self.parser)
57 |
58 |     def filename(self, index, basename=False, absolute=False):
59 |         return self.parser.filename(index, basename, absolute)
60 |
61 |     def filenames(self, basename=False, absolute=False):
62 |         return self.parser.filenames(basename, absolute)
63 |
64 |
65 | class IterableImageDataset(data.IterableDataset):
66 |
67 |     def __init__(
68 |             self,
69 |             root,
70 |             parser=None,
71 |             split='train',
72 |             is_training=False,
73 |             batch_size=None,
74 |             class_map='',
75 |             load_bytes=False,
76 |             repeats=0,
77 |             transform=None,
78 |     ):
79 |         assert parser is not None
80 |         if isinstance(parser, str):
81 |             self.parser = create_parser(
82 |                 parser, root=root, split=split, is_training=is_training, batch_size=batch_size, repeats=repeats)
83 |         else:
84 |             self.parser = parser
85 |         self.transform = transform
86 |         self._consecutive_errors = 0
87 |
88 |     def __iter__(self):
89 |         for img, target in self.parser:
90 |             if self.transform is not None:
91 |                 img = self.transform(img)
92 |             if target is None:
93 |                 target = torch.tensor(-1, dtype=torch.long)
94 |             yield img, target
95 |
96 |     def __len__(self):
97 |         if hasattr(self.parser, '__len__'):
98 |             return len(self.parser)
99 |         else:
100 |             return 0
101 |
102 |     def filename(self, index, basename=False, absolute=False):
103 |         assert False, 'Filename lookup by index not supported, use filenames().'
104 |
105 |     def filenames(self, basename=False, absolute=False):
106 |         return self.parser.filenames(basename, absolute)
107 |
108 |
109 | class AugMixDataset(torch.utils.data.Dataset):
110 |     """Dataset wrapper to perform AugMix or other clean/augmentation mixes"""
111 |
112 |     def __init__(self, dataset, num_splits=2):
113 |         self.augmentation = None
114 |         self.normalize = None
115 |         self.dataset = dataset
116 |         if self.dataset.transform is not None:
117 |             self._set_transforms(self.dataset.transform)
118 |         self.num_splits = num_splits
119 |
120 |     def _set_transforms(self, x):
121 |         assert isinstance(x, (list, tuple)) and len(x) == 3, 'Expecting a tuple/list of 3 transforms'
122 |         self.dataset.transform = x[0]
123 |         self.augmentation = x[1]
124 |         self.normalize = x[2]
125 |
126 |     @property
127 |     def transform(self):
128 |         return self.dataset.transform
129 |
130 |     @transform.setter
131 |     def transform(self, x):
132 |         self._set_transforms(x)
133 |
134 |     def _normalize(self, x):
135 |         return x if self.normalize is None else self.normalize(x)
136 |
137 |     def __getitem__(self, i):
138 |         x, y = self.dataset[i]  # all splits share the same dataset base transform
139 |         x_list = [self._normalize(x)]  # first split only normalizes (this is the 'clean' split)
140 |         # run the full augmentation on the remaining splits
141 |         for _ in range(self.num_splits - 1):
142 |             x_list.append(self._normalize(self.augmentation(x)))
143 |         return tuple(x_list), y
144 |
145 |     def __len__(self):
146 |         return len(self.dataset)
147 |
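A short sketch of the folder-backed dataset above; it assumes an image layout like `images/<label>/<file>.png`, where the default parser labels each image by its parent folder:

```python
from torchvision import transforms
from timm.data.dataset import ImageDataset

# The default parser ('' -> ParserImageFolder, see parser_factory.py below)
# indexes images recursively and maps leaf folder names to integer targets.
ds = ImageDataset("images", transform=transforms.ToTensor())
img, target = ds[0]
```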
--------------------------------------------------------------------------------
/timm/data/dataset_factory.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from .dataset import IterableImageDataset, ImageDataset
4 |
5 |
6 | def _search_split(root, split):
7 |     # look for sub-folder with name of split in root and use that if it exists
8 |     split_name = split.split('[')[0]
9 |     try_root = os.path.join(root, split_name)
10 |     if os.path.exists(try_root):
11 |         return try_root
12 |     if split_name == 'validation':
13 |         try_root = os.path.join(root, 'val')
14 |         if os.path.exists(try_root):
15 |             return try_root
16 |     return root
17 |
18 |
19 | def create_dataset(name, root, split='validation', search_split=True, is_training=False, batch_size=None, **kwargs):
20 |     name = name.lower()
21 |     if name.startswith('tfds'):
22 |         ds = IterableImageDataset(
23 |             root, parser=name, split=split, is_training=is_training, batch_size=batch_size, **kwargs)
24 |     else:
25 |         # FIXME support more advance split cfg for ImageFolder/Tar datasets in the future
26 |         kwargs.pop('repeats', 0)  # FIXME currently only Iterable dataset support the repeat multiplier
27 |         if search_split and os.path.isdir(root):
28 |             root = _search_split(root, split)
29 |         ds = ImageDataset(root, parser=name, **kwargs)
30 |     return ds
31 |
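Two illustrative calls to `create_dataset`; the roots and the TFDS dataset name are placeholders:

```python
from timm.data.dataset_factory import create_dataset

# Folder-backed dataset; with search_split=True, a 'train' subfolder under
# root is used if it exists, otherwise root itself.
train_ds = create_dataset("", root="images", split="train", is_training=True)

# TFDS-backed iterable dataset (requires tensorflow-datasets):
# tfds_ds = create_dataset("tfds/cifar10", root="~/tfds", split="train",
#                          is_training=True, batch_size=32)
```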
20 | """ 21 | 22 | def __init__(self, dataset, num_replicas=None, rank=None): 23 | if num_replicas is None: 24 | if not dist.is_available(): 25 | raise RuntimeError("Requires distributed package to be available") 26 | num_replicas = dist.get_world_size() 27 | if rank is None: 28 | if not dist.is_available(): 29 | raise RuntimeError("Requires distributed package to be available") 30 | rank = dist.get_rank() 31 | self.dataset = dataset 32 | self.num_replicas = num_replicas 33 | self.rank = rank 34 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) 35 | self.total_size = self.num_samples * self.num_replicas 36 | 37 | def __iter__(self): 38 | indices = list(range(len(self.dataset))) 39 | 40 | # add extra samples to make it evenly divisible 41 | indices += indices[:(self.total_size - len(indices))] 42 | assert len(indices) == self.total_size 43 | 44 | # subsample 45 | indices = indices[self.rank:self.total_size:self.num_replicas] 46 | assert len(indices) == self.num_samples 47 | 48 | return iter(indices) 49 | 50 | def __len__(self): 51 | return self.num_samples 52 | -------------------------------------------------------------------------------- /timm/data/parsers/__init__.py: -------------------------------------------------------------------------------- 1 | from .parser_factory import create_parser 2 | -------------------------------------------------------------------------------- /timm/data/parsers/class_map.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | def load_class_map(filename, root=''): 5 | class_map_path = filename 6 | if not os.path.exists(class_map_path): 7 | class_map_path = os.path.join(root, filename) 8 | assert os.path.exists(class_map_path), 'Cannot locate specified class map file (%s)' % filename 9 | class_map_ext = os.path.splitext(filename)[-1].lower() 10 | if class_map_ext == '.txt': 11 | with open(class_map_path) as f: 12 | class_to_idx = {v.strip(): k for k, v in enumerate(f)} 13 | else: 14 | assert False, 'Unsupported class map extension' 15 | return class_to_idx 16 | 17 | -------------------------------------------------------------------------------- /timm/data/parsers/constants.py: -------------------------------------------------------------------------------- 1 | IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg') 2 | -------------------------------------------------------------------------------- /timm/data/parsers/parser.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | 3 | 4 | class Parser: 5 | def __init__(self): 6 | pass 7 | 8 | @abstractmethod 9 | def _filename(self, index, basename=False, absolute=False): 10 | pass 11 | 12 | def filename(self, index, basename=False, absolute=False): 13 | return self._filename(index, basename=basename, absolute=absolute) 14 | 15 | def filenames(self, basename=False, absolute=False): 16 | return [self._filename(index, basename=basename, absolute=absolute) for index in range(len(self))] 17 | 18 | -------------------------------------------------------------------------------- /timm/data/parsers/parser_factory.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from .parser_image_folder import ParserImageFolder 4 | from .parser_image_tar import ParserImageTar 5 | from .parser_image_in_tar import ParserImageInTar 6 | 7 | 8 | def create_parser(name, root, split='train', **kwargs): 9 | name = name.lower() 10 | name = 
name.split('/', 2) 11 | prefix = '' 12 | if len(name) > 1: 13 | prefix = name[0] 14 | name = name[-1] 15 | 16 | # FIXME improve the selection; right now it is just the tfds prefix or the fallback path, options to 17 | # explicitly select other parsers will be needed shortly 18 | if prefix == 'tfds': 19 | from .parser_tfds import ParserTfds  # defer tensorflow import 20 | parser = ParserTfds(root, name, split=split, shuffle=kwargs.pop('shuffle', False), **kwargs) 21 | else: 22 | assert os.path.exists(root) 23 | # default fallback path (backwards compat), use image tar if root is a .tar file, otherwise image folder 24 | # FIXME support split here, in parser? 25 | if os.path.isfile(root) and os.path.splitext(root)[1] == '.tar': 26 | parser = ParserImageInTar(root, **kwargs) 27 | else: 28 | parser = ParserImageFolder(root, **kwargs) 29 | return parser 30 | -------------------------------------------------------------------------------- /timm/data/parsers/parser_image_folder.py: -------------------------------------------------------------------------------- 1 | """ A dataset parser that reads images from folders 2 | 3 | Folders are scanned recursively to find image files. Labels are based 4 | on the folder hierarchy, just leaf folders by default. 5 | 6 | Hacked together by / Copyright 2020 Ross Wightman 7 | """ 8 | import os 9 | 10 | from timm.utils.misc import natural_key 11 | 12 | from .parser import Parser 13 | from .class_map import load_class_map 14 | from .constants import IMG_EXTENSIONS 15 | 16 | 17 | def find_images_and_targets(folder, types=IMG_EXTENSIONS, class_to_idx=None, leaf_name_only=True, sort=True): 18 | labels = [] 19 | filenames = [] 20 | for root, subdirs, files in os.walk(folder, topdown=False, followlinks=True): 21 | rel_path = os.path.relpath(root, folder) if (root != folder) else '' 22 | label = os.path.basename(rel_path) if leaf_name_only else rel_path.replace(os.path.sep, '_') 23 | for f in files: 24 | base, ext = os.path.splitext(f) 25 | if ext.lower() in types: 26 | filenames.append(os.path.join(root, f)) 27 | labels.append(label) 28 | if class_to_idx is None: 29 | # building class index 30 | unique_labels = set(labels) 31 | sorted_labels = list(sorted(unique_labels, key=natural_key)) 32 | class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)} 33 | images_and_targets = [(f, class_to_idx[l]) for f, l in zip(filenames, labels) if l in class_to_idx] 34 | if sort: 35 | images_and_targets = sorted(images_and_targets, key=lambda k: natural_key(k[0])) 36 | return images_and_targets, class_to_idx 37 | 38 | 39 | class ParserImageFolder(Parser): 40 | 41 | def __init__( 42 | self, 43 | root, 44 | class_map=''): 45 | super().__init__() 46 | 47 | self.root = root 48 | class_to_idx = None 49 | if class_map: 50 | class_to_idx = load_class_map(class_map, root) 51 | self.samples, self.class_to_idx = find_images_and_targets(root, class_to_idx=class_to_idx) 52 | if len(self.samples) == 0: 53 | raise RuntimeError( 54 | f'Found 0 images in subfolders of {root}. 
Supported image extensions are {", ".join(IMG_EXTENSIONS)}') 55 | 56 | def __getitem__(self, index): 57 | path, target = self.samples[index] 58 | return open(path, 'rb'), target 59 | 60 | def __len__(self): 61 | return len(self.samples) 62 | 63 | def _filename(self, index, basename=False, absolute=False): 64 | filename = self.samples[index][0] 65 | if basename: 66 | filename = os.path.basename(filename) 67 | elif not absolute: 68 | filename = os.path.relpath(filename, self.root) 69 | return filename 70 | -------------------------------------------------------------------------------- /timm/data/parsers/parser_image_tar.py: -------------------------------------------------------------------------------- 1 | """ A dataset parser that reads single tarfile based datasets 2 | 3 | This parser can read datasets consisting of a single tarfile containing images. 4 | I am planning to deprecate it in favour of ParserImageInTar. 5 | 6 | Hacked together by / Copyright 2020 Ross Wightman 7 | """ 8 | import os 9 | import tarfile 10 | 11 | from .parser import Parser 12 | from .class_map import load_class_map 13 | from .constants import IMG_EXTENSIONS 14 | from timm.utils.misc import natural_key 15 | 16 | 17 | def extract_tarinfo(tarfile, class_to_idx=None, sort=True): 18 | files = [] 19 | labels = [] 20 | for ti in tarfile.getmembers(): 21 | if not ti.isfile(): 22 | continue 23 | dirname, basename = os.path.split(ti.path) 24 | label = os.path.basename(dirname) 25 | ext = os.path.splitext(basename)[1] 26 | if ext.lower() in IMG_EXTENSIONS: 27 | files.append(ti) 28 | labels.append(label) 29 | if class_to_idx is None: 30 | unique_labels = set(labels) 31 | sorted_labels = list(sorted(unique_labels, key=natural_key)) 32 | class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)} 33 | tarinfo_and_targets = [(f, class_to_idx[l]) for f, l in zip(files, labels) if l in class_to_idx] 34 | if sort: 35 | tarinfo_and_targets = sorted(tarinfo_and_targets, key=lambda k: natural_key(k[0].path)) 36 | return tarinfo_and_targets, class_to_idx 37 | 38 | 39 | class ParserImageTar(Parser): 40 | """ Single tarfile dataset where classes are mapped to folders within tar 41 | NOTE: This class is being deprecated in favour of the more capable ParserImageInTar that can 42 | operate on folders of tars or tars in tars.
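Illustrative usage (a sketch, not from the original source; the path is hypothetical and the tar is assumed to contain class_name/image.jpg members)::
    parser = ParserImageTar('/data/imagenet_val.tar')
    fileobj, target = parser[0]  # file-like object into the tar plus integer class index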
43 | """ 44 | def __init__(self, root, class_map=''): 45 | super().__init__() 46 | 47 | class_to_idx = None 48 | if class_map: 49 | class_to_idx = load_class_map(class_map, root) 50 | assert os.path.isfile(root) 51 | self.root = root 52 | 53 | with tarfile.open(root) as tf: # cannot keep this open across processes, reopen later 54 | self.samples, self.class_to_idx = extract_tarinfo(tf, class_to_idx) 55 | self.imgs = self.samples 56 | self.tarfile = None # lazy init in __getitem__ 57 | 58 | def __getitem__(self, index): 59 | if self.tarfile is None: 60 | self.tarfile = tarfile.open(self.root) 61 | tarinfo, target = self.samples[index] 62 | fileobj = self.tarfile.extractfile(tarinfo) 63 | return fileobj, target 64 | 65 | def __len__(self): 66 | return len(self.samples) 67 | 68 | def _filename(self, index, basename=False, absolute=False): 69 | filename = self.samples[index][0].name 70 | if basename: 71 | filename = os.path.basename(filename) 72 | return filename 73 | -------------------------------------------------------------------------------- /timm/data/random_erasing.py: -------------------------------------------------------------------------------- 1 | """ Random Erasing (Cutout) 2 | 3 | Originally inspired by impl at https://github.com/zhunzhong07/Random-Erasing, Apache 2.0 4 | Copyright Zhun Zhong & Liang Zheng 5 | 6 | Hacked together by / Copyright 2020 Ross Wightman 7 | """ 8 | import random 9 | import math 10 | import torch 11 | 12 | 13 | def _get_pixels(per_pixel, rand_color, patch_size, dtype=torch.float32, device='cuda'): 14 | # NOTE I've seen CUDA illegal memory access errors being caused by the normal_() 15 | # paths, flip the order so normal is run on CPU if this becomes a problem 16 | # Issue has been fixed in master https://github.com/pytorch/pytorch/issues/19508 17 | if per_pixel: 18 | return torch.empty(patch_size, dtype=dtype, device=device).normal_() 19 | elif rand_color: 20 | return torch.empty((patch_size[0], 1, 1), dtype=dtype, device=device).normal_() 21 | else: 22 | return torch.zeros((patch_size[0], 1, 1), dtype=dtype, device=device) 23 | 24 | 25 | class RandomErasing: 26 | """ Randomly selects a rectangle region in an image and erases its pixels. 27 | 'Random Erasing Data Augmentation' by Zhong et al. 28 | See https://arxiv.org/pdf/1708.04896.pdf 29 | 30 | This variant of RandomErasing is intended to be applied to either a batch 31 | or single image tensor after it has been normalized by dataset mean and std. 32 | Args: 33 | probability: Probability that the Random Erasing operation will be performed. 34 | min_area: Minimum percentage of erased area wrt input image area. 35 | max_area: Maximum percentage of erased area wrt input image area. 36 | min_aspect: Minimum aspect ratio of erased area. 37 | mode: pixel color mode, one of 'const', 'rand', or 'pixel' 38 | 'const' - erase block is constant color of 0 for all channels 39 | 'rand' - erase block is same per-channel random (normal) color 40 | 'pixel' - erase block is per-pixel random (normal) color 41 | max_count: maximum number of erasing blocks per image, area per box is scaled by count. 42 | per-image count is randomly chosen between 1 and this value. 
43 | """ 44 | 45 | def __init__( 46 | self, 47 | probability=0.5, min_area=0.02, max_area=1/3, min_aspect=0.3, max_aspect=None, 48 | mode='const', min_count=1, max_count=None, num_splits=0, device='cuda'): 49 | self.probability = probability 50 | self.min_area = min_area 51 | self.max_area = max_area 52 | max_aspect = max_aspect or 1 / min_aspect 53 | self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect)) 54 | self.min_count = min_count 55 | self.max_count = max_count or min_count 56 | self.num_splits = num_splits 57 | mode = mode.lower() 58 | self.rand_color = False 59 | self.per_pixel = False 60 | if mode == 'rand': 61 | self.rand_color = True # per block random normal 62 | elif mode == 'pixel': 63 | self.per_pixel = True # per pixel random normal 64 | else: 65 | assert not mode or mode == 'const' 66 | self.device = device 67 | 68 | def _erase(self, img, chan, img_h, img_w, dtype): 69 | if random.random() > self.probability: 70 | return 71 | area = img_h * img_w 72 | count = self.min_count if self.min_count == self.max_count else \ 73 | random.randint(self.min_count, self.max_count) 74 | for _ in range(count): 75 | for attempt in range(10): 76 | target_area = random.uniform(self.min_area, self.max_area) * area / count 77 | aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) 78 | h = int(round(math.sqrt(target_area * aspect_ratio))) 79 | w = int(round(math.sqrt(target_area / aspect_ratio))) 80 | if w < img_w and h < img_h: 81 | top = random.randint(0, img_h - h) 82 | left = random.randint(0, img_w - w) 83 | img[:, top:top + h, left:left + w] = _get_pixels( 84 | self.per_pixel, self.rand_color, (chan, h, w), 85 | dtype=dtype, device=self.device) 86 | break 87 | 88 | def __call__(self, input): 89 | if len(input.size()) == 3: 90 | self._erase(input, *input.size(), input.dtype) 91 | else: 92 | batch_size, chan, img_h, img_w = input.size() 93 | # skip first slice of batch if num_splits is set (for clean portion of samples) 94 | batch_start = batch_size // self.num_splits if self.num_splits > 1 else 0 95 | for i in range(batch_start, batch_size): 96 | self._erase(input[i], chan, img_h, img_w, input.dtype) 97 | return input 98 | -------------------------------------------------------------------------------- /timm/data/real_labels.py: -------------------------------------------------------------------------------- 1 | """ Real labels evaluator for ImageNet 2 | Paper: `Are we done with ImageNet?` - https://arxiv.org/abs/2006.07159 3 | Based on Numpy example at https://github.com/google-research/reassessed-imagenet 4 | 5 | Hacked together by / Copyright 2020 Ross Wightman 6 | """ 7 | import os 8 | import json 9 | import numpy as np 10 | 11 | 12 | class RealLabelsImagenet: 13 | 14 | def __init__(self, filenames, real_json='real.json', topk=(1, 5)): 15 | with open(real_json) as real_labels: 16 | real_labels = json.load(real_labels) 17 | real_labels = {f'ILSVRC2012_val_{i + 1:08d}.JPEG': labels for i, labels in enumerate(real_labels)} 18 | self.real_labels = real_labels 19 | self.filenames = filenames 20 | assert len(self.filenames) == len(self.real_labels) 21 | self.topk = topk 22 | self.is_correct = {k: [] for k in topk} 23 | self.sample_idx = 0 24 | 25 | def add_result(self, output): 26 | maxk = max(self.topk) 27 | _, pred_batch = output.topk(maxk, 1, True, True) 28 | pred_batch = pred_batch.cpu().numpy() 29 | for pred in pred_batch: 30 | filename = self.filenames[self.sample_idx] 31 | filename = os.path.basename(filename) 32 | if self.real_labels[filename]: 33 | for k in 
self.topk: 34 | self.is_correct[k].append( 35 | any([p in self.real_labels[filename] for p in pred[:k]])) 36 | self.sample_idx += 1 37 | 38 | def get_accuracy(self, k=None): 39 | if k is None: 40 | return {k: float(np.mean(self.is_correct[k])) * 100 for k in self.topk} 41 | else: 42 | return float(np.mean(self.is_correct[k])) * 100 43 | -------------------------------------------------------------------------------- /timm/data/transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.transforms.functional as F 3 | from PIL import Image 4 | import warnings 5 | import math 6 | import random 7 | import numpy as np 8 | 9 | 10 | class ToNumpy: 11 | 12 | def __call__(self, pil_img): 13 | np_img = np.array(pil_img, dtype=np.uint8) 14 | if np_img.ndim < 3: 15 | np_img = np.expand_dims(np_img, axis=-1) 16 | np_img = np.rollaxis(np_img, 2)  # HWC to CHW 17 | return np_img 18 | 19 | 20 | class ToTensor: 21 | 22 | def __init__(self, dtype=torch.float32): 23 | self.dtype = dtype 24 | 25 | def __call__(self, pil_img): 26 | np_img = np.array(pil_img, dtype=np.uint8) 27 | if np_img.ndim < 3: 28 | np_img = np.expand_dims(np_img, axis=-1) 29 | np_img = np.rollaxis(np_img, 2)  # HWC to CHW 30 | return torch.from_numpy(np_img).to(dtype=self.dtype) 31 | 32 | 33 | _pil_interpolation_to_str = { 34 | Image.NEAREST: 'PIL.Image.NEAREST', 35 | Image.BILINEAR: 'PIL.Image.BILINEAR', 36 | Image.BICUBIC: 'PIL.Image.BICUBIC', 37 | Image.LANCZOS: 'PIL.Image.LANCZOS', 38 | Image.HAMMING: 'PIL.Image.HAMMING', 39 | Image.BOX: 'PIL.Image.BOX', 40 | } 41 | 42 | 43 | def _pil_interp(method): 44 | if method == 'bicubic': 45 | return Image.BICUBIC 46 | elif method == 'lanczos': 47 | return Image.LANCZOS 48 | elif method == 'hamming': 49 | return Image.HAMMING 50 | else: 51 | # default bilinear, do we want to allow nearest? 52 | return Image.BILINEAR 53 | 54 | 55 | _RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC) 56 | 57 | 58 | class RandomResizedCropAndInterpolation: 59 | """Crop the given PIL Image to random size and aspect ratio with random interpolation. 60 | 61 | A crop of random size (default: 0.08 to 1.0 of the original area) and a random 62 | aspect ratio (default: 3/4 to 4/3 of the original aspect ratio) is made. This crop 63 | is finally resized to the given size. 64 | This is popularly used to train the Inception networks. 65 | 66 | Args: 67 | size: expected output size of each edge 68 | scale: range of the crop area relative to the original image 69 | ratio: range of the aspect ratio of the crop 70 | interpolation: Default: PIL.Image.BILINEAR 71 | """ 72 | 73 | def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), 74 | interpolation='bilinear'): 75 | if isinstance(size, (list, tuple)): 76 | self.size = tuple(size) 77 | else: 78 | self.size = (size, size) 79 | if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): 80 | warnings.warn("range should be of kind (min, max)") 81 | 82 | if interpolation == 'random': 83 | self.interpolation = _RANDOM_INTERPOLATION 84 | else: 85 | self.interpolation = _pil_interp(interpolation) 86 | self.scale = scale 87 | self.ratio = ratio 88 | 89 | @staticmethod 90 | def get_params(img, scale, ratio): 91 | """Get parameters for ``crop`` for a random sized crop. 92 | 93 | Args: 94 | img (PIL Image): Image to be cropped. 
95 | scale (tuple): range of the crop area relative to the original image 96 | ratio (tuple): range of the aspect ratio of the crop 97 | 98 | Returns: 99 | tuple: params (i, j, h, w) to be passed to ``crop`` for a random 100 | sized crop. 101 | """ 102 | area = img.size[0] * img.size[1] 103 | 104 | for attempt in range(10): 105 | target_area = random.uniform(*scale) * area 106 | log_ratio = (math.log(ratio[0]), math.log(ratio[1])) 107 | aspect_ratio = math.exp(random.uniform(*log_ratio)) 108 | 109 | w = int(round(math.sqrt(target_area * aspect_ratio))) 110 | h = int(round(math.sqrt(target_area / aspect_ratio))) 111 | 112 | if w <= img.size[0] and h <= img.size[1]: 113 | i = random.randint(0, img.size[1] - h) 114 | j = random.randint(0, img.size[0] - w) 115 | return i, j, h, w 116 | 117 | # Fallback to central crop 118 | in_ratio = img.size[0] / img.size[1] 119 | if in_ratio < min(ratio): 120 | w = img.size[0] 121 | h = int(round(w / min(ratio))) 122 | elif in_ratio > max(ratio): 123 | h = img.size[1] 124 | w = int(round(h * max(ratio))) 125 | else:  # whole image 126 | w = img.size[0] 127 | h = img.size[1] 128 | i = (img.size[1] - h) // 2 129 | j = (img.size[0] - w) // 2 130 | return i, j, h, w 131 | 132 | def __call__(self, img): 133 | """ 134 | Args: 135 | img (PIL Image): Image to be cropped and resized. 136 | 137 | Returns: 138 | PIL Image: Randomly cropped and resized image. 139 | """ 140 | i, j, h, w = self.get_params(img, self.scale, self.ratio) 141 | if isinstance(self.interpolation, (tuple, list)): 142 | interpolation = random.choice(self.interpolation) 143 | else: 144 | interpolation = self.interpolation 145 | return F.resized_crop(img, i, j, h, w, self.size, interpolation) 146 | 147 | def __repr__(self): 148 | if isinstance(self.interpolation, (tuple, list)): 149 | interpolate_str = ' '.join([_pil_interpolation_to_str[x] for x in self.interpolation]) 150 | else: 151 | interpolate_str = _pil_interpolation_to_str[self.interpolation] 152 | format_string = self.__class__.__name__ + '(size={0}'.format(self.size) 153 | format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale)) 154 | format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio)) 155 | format_string += ', interpolation={0})'.format(interpolate_str) 156 | return format_string 157 | 158 | 159 | -------------------------------------------------------------------------------- /timm/loss/__init__.py: -------------------------------------------------------------------------------- 1 | from .cross_entropy import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy 2 | from .jsd import JsdCrossEntropy 3 | from .asymmetric_loss import AsymmetricLossMultiLabel, AsymmetricLossSingleLabel -------------------------------------------------------------------------------- /timm/loss/asymmetric_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class AsymmetricLossMultiLabel(nn.Module): 6 | def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8, disable_torch_grad_focal_loss=False): 7 | super(AsymmetricLossMultiLabel, self).__init__() 8 | 9 | self.gamma_neg = gamma_neg 10 | self.gamma_pos = gamma_pos 11 | self.clip = clip 12 | self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss 13 | self.eps = eps 14 | 15 | def forward(self, x, y): 16 | """ 17 | Parameters 18 | ---------- 19 | x: input logits 20 | y: targets (multi-label binarized vector) 21 | """ 22 | 23 | # Calculating 
Probabilities 24 | x_sigmoid = torch.sigmoid(x) 25 | xs_pos = x_sigmoid 26 | xs_neg = 1 - x_sigmoid 27 | 28 | # Asymmetric Clipping 29 | if self.clip is not None and self.clip > 0: 30 | xs_neg = (xs_neg + self.clip).clamp(max=1) 31 | 32 | # Basic CE calculation 33 | los_pos = y * torch.log(xs_pos.clamp(min=self.eps)) 34 | los_neg = (1 - y) * torch.log(xs_neg.clamp(min=self.eps)) 35 | loss = los_pos + los_neg 36 | 37 | # Asymmetric Focusing 38 | if self.gamma_neg > 0 or self.gamma_pos > 0: 39 | if self.disable_torch_grad_focal_loss: 40 | torch._C.set_grad_enabled(False) 41 | pt0 = xs_pos * y 42 | pt1 = xs_neg * (1 - y)  # pt = p if t > 0 else 1-p 43 | pt = pt0 + pt1 44 | one_sided_gamma = self.gamma_pos * y + self.gamma_neg * (1 - y) 45 | one_sided_w = torch.pow(1 - pt, one_sided_gamma) 46 | if self.disable_torch_grad_focal_loss: 47 | torch._C.set_grad_enabled(True) 48 | loss *= one_sided_w 49 | 50 | return -loss.sum() 51 | 52 | 53 | class AsymmetricLossSingleLabel(nn.Module): 54 | def __init__(self, gamma_pos=1, gamma_neg=4, eps: float = 0.1, reduction='mean'): 55 | super(AsymmetricLossSingleLabel, self).__init__() 56 | 57 | self.eps = eps 58 | self.logsoftmax = nn.LogSoftmax(dim=-1) 59 | self.targets_classes = []  # prevent repeated GPU memory allocation 60 | self.gamma_pos = gamma_pos 61 | self.gamma_neg = gamma_neg 62 | self.reduction = reduction 63 | 64 | def forward(self, inputs, target, reduction=None): 65 | """ 66 | Parameters 67 | ---------- 68 | inputs: input logits 69 | target: targets (1-hot vector) 70 | """ 71 | 72 | num_classes = inputs.size()[-1] 73 | log_preds = self.logsoftmax(inputs) 74 | self.targets_classes = torch.zeros_like(inputs).scatter_(1, target.long().unsqueeze(1), 1) 75 | 76 | # ASL weights 77 | targets = self.targets_classes 78 | anti_targets = 1 - targets 79 | xs_pos = torch.exp(log_preds) 80 | xs_neg = 1 - xs_pos 81 | xs_pos = xs_pos * targets 82 | xs_neg = xs_neg * anti_targets 83 | asymmetric_w = torch.pow(1 - xs_pos - xs_neg, 84 | self.gamma_pos * targets + self.gamma_neg * anti_targets) 85 | log_preds = log_preds * asymmetric_w 86 | 87 | if self.eps > 0:  # label smoothing 88 | self.targets_classes.mul_(1 - self.eps).add_(self.eps / num_classes) 89 | 90 | # loss calculation 91 | loss = - self.targets_classes.mul(log_preds) 92 | 93 | loss = loss.sum(dim=-1) 94 | if self.reduction == 'mean': 95 | loss = loss.mean() 96 | 97 | return loss 98 | -------------------------------------------------------------------------------- /timm/loss/cross_entropy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class LabelSmoothingCrossEntropy(nn.Module): 7 | """ 8 | NLL loss with label smoothing. 9 | """ 10 | def __init__(self, smoothing=0.1): 11 | """ 12 | Constructor for the LabelSmoothing module. 13 | :param smoothing: label smoothing factor 14 | """ 15 | super(LabelSmoothingCrossEntropy, self).__init__() 16 | assert smoothing < 1.0 17 | self.smoothing = smoothing 18 | self.confidence = 1. 
- smoothing 19 | 20 | def forward(self, x, target): 21 | logprobs = F.log_softmax(x, dim=-1) 22 | nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1)) 23 | nll_loss = nll_loss.squeeze(1) 24 | smooth_loss = -logprobs.mean(dim=-1) 25 | loss = self.confidence * nll_loss + self.smoothing * smooth_loss 26 | return loss.mean() 27 | 28 | 29 | class SoftTargetCrossEntropy(nn.Module): 30 | 31 | def __init__(self): 32 | super(SoftTargetCrossEntropy, self).__init__() 33 | 34 | def forward(self, x, target): 35 | loss = torch.sum(-target * F.log_softmax(x, dim=-1), dim=-1) 36 | return loss.mean() 37 | -------------------------------------------------------------------------------- /timm/loss/jsd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .cross_entropy import LabelSmoothingCrossEntropy 6 | 7 | 8 | class JsdCrossEntropy(nn.Module): 9 | """ Jensen-Shannon Divergence + Cross-Entropy Loss 10 | 11 | Based on impl here: https://github.com/google-research/augmix/blob/master/imagenet.py 12 | From paper: 'AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty' - 13 | https://arxiv.org/abs/1912.02781 14 | 15 | Hacked together by / Copyright 2020 Ross Wightman 16 | """ 17 | def __init__(self, num_splits=3, alpha=12, smoothing=0.1): 18 | super().__init__() 19 | self.num_splits = num_splits 20 | self.alpha = alpha 21 | if smoothing is not None and smoothing > 0: 22 | self.cross_entropy_loss = LabelSmoothingCrossEntropy(smoothing) 23 | else: 24 | self.cross_entropy_loss = torch.nn.CrossEntropyLoss() 25 | 26 | def __call__(self, output, target): 27 | split_size = output.shape[0] // self.num_splits 28 | assert split_size * self.num_splits == output.shape[0] 29 | logits_split = torch.split(output, split_size) 30 | 31 | # Cross-entropy is only computed on clean images 32 | loss = self.cross_entropy_loss(logits_split[0], target[:split_size]) 33 | probs = [F.softmax(logits, dim=1) for logits in logits_split] 34 | 35 | # Clamp mixture distribution to avoid exploding KL divergence 36 | logp_mixture = torch.clamp(torch.stack(probs).mean(axis=0), 1e-7, 1).log() 37 | loss += self.alpha * sum([F.kl_div( 38 | logp_mixture, p_split, reduction='batchmean') for p_split in probs]) / len(probs) 39 | return loss 40 | -------------------------------------------------------------------------------- /timm/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .byoanet import * 2 | from .byobnet import * 3 | from .cait import * 4 | from .coat import * 5 | from .convit import * 6 | from .cspnet import * 7 | from .densenet import * 8 | from .dla import * 9 | from .dpn import * 10 | from .efficientnet import * 11 | from .ghostnet import * 12 | from .gluon_resnet import * 13 | from .gluon_xception import * 14 | from .hardcorenas import * 15 | from .hrnet import * 16 | from .inception_resnet_v2 import * 17 | from .inception_v3 import * 18 | from .inception_v4 import * 19 | from .levit import * 20 | from .mlp_mixer import * 21 | from .mobilenetv3 import * 22 | from .nasnet import * 23 | from .nfnet import * 24 | from .pit import * 25 | from .pnasnet import * 26 | from .regnet import * 27 | from .res2net import * 28 | from .resnest import * 29 | from .resnet import * 30 | from .resnetv2 import * 31 | from .rexnet import * 32 | from .selecsls import * 33 | from .senet import * 34 | from .sknet import * 35 | from .swin_transformer 
import * 36 | from .tnt import * 37 | from .tresnet import * 38 | from .vgg import * 39 | from .visformer import * 40 | from .vision_transformer import * 41 | from .vision_transformer_hybrid import * 42 | from .vovnet import * 43 | from .xception import * 44 | from .xception_aligned import * 45 | from .twins import * 46 | 47 | from .factory import create_model, split_model_name, safe_model_name 48 | from .helpers import load_checkpoint, resume_checkpoint, model_parameters 49 | from .layers import TestTimePoolHead, apply_test_time_pool 50 | from .layers import convert_splitbn_model 51 | from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit 52 | from .registry import register_model, model_entrypoint, list_models, is_model, list_modules, is_model_in_modules,\ 53 | has_model_default_key, is_model_default_key, get_model_default_value, is_model_pretrained 54 | -------------------------------------------------------------------------------- /timm/models/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /timm/models/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /timm/models/__pycache__/byoanet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/__pycache__/byoanet.cpython-37.pyc -------------------------------------------------------------------------------- /timm/models/__pycache__/byoanet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/__pycache__/byoanet.cpython-38.pyc -------------------------------------------------------------------------------- /timm/models/__pycache__/byobnet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/__pycache__/byobnet.cpython-37.pyc -------------------------------------------------------------------------------- /timm/models/__pycache__/byobnet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/__pycache__/byobnet.cpython-38.pyc -------------------------------------------------------------------------------- /timm/models/__pycache__/cait.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/__pycache__/cait.cpython-37.pyc -------------------------------------------------------------------------------- /timm/models/__pycache__/cait.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/__pycache__/cait.cpython-38.pyc -------------------------------------------------------------------------------- /timm/models/factory.py: -------------------------------------------------------------------------------- 1 | from .registry import is_model, is_model_in_modules, model_entrypoint 2 | from .helpers import load_checkpoint 3 | from .layers import set_layer_config 4 | from .hub import load_model_config_from_hf 5 | 6 | 7 | def split_model_name(model_name): 8 | model_split = model_name.split(':', 1) 9 | if len(model_split) == 1: 10 | return '', model_split[0] 11 | else: 12 | source_name, model_name = model_split 13 | assert source_name in ('timm', 'hf_hub') 14 | return source_name, model_name 15 | 16 | 17 | def safe_model_name(model_name, remove_source=True): 18 | def make_safe(name): 19 | return ''.join(c if c.isalnum() else '_' for c in name).rstrip('_') 20 | if remove_source: 21 | model_name = split_model_name(model_name)[-1] 22 | return make_safe(model_name) 23 | 24 | 25 | def create_model( 26 | model_name, 27 | pretrained=False, 28 | checkpoint_path='', 29 | scriptable=None, 30 | exportable=None, 31 | no_jit=None, 32 | **kwargs): 33 | """Create a model 34 | 35 | Args: 36 | model_name (str): name of model to instantiate 37 | pretrained (bool): load pretrained ImageNet-1k weights if true 38 | checkpoint_path (str): path of checkpoint to load after model is initialized 39 | scriptable (bool): set layer config so that model is jit scriptable (not working for all models yet) 40 | exportable (bool): set layer config so that model is traceable / ONNX exportable (not fully impl/obeyed yet) 41 | no_jit (bool): set layer config so that model doesn't utilize jit scripted layers (so far activations only) 42 | 43 | Keyword Args: 44 | drop_rate (float): dropout rate for training (default: 0.0) 45 | global_pool (str): global pool type (default: 'avg') 46 | **: other kwargs are model specific 47 | """ 48 | source_name, model_name = split_model_name(model_name) 49 | 50 | # Only EfficientNet and MobileNetV3 models have support for batchnorm params or drop_connect_rate passed as args 51 | is_efficientnet = is_model_in_modules(model_name, ['efficientnet', 'mobilenetv3']) 52 | if not is_efficientnet: 53 | kwargs.pop('bn_tf', None) 54 | kwargs.pop('bn_momentum', None) 55 | kwargs.pop('bn_eps', None) 56 | 57 | # handle backwards compat with drop_connect -> drop_path change 58 | drop_connect_rate = kwargs.pop('drop_connect_rate', None) 59 | if drop_connect_rate is not None and kwargs.get('drop_path_rate', None) is None: 60 | print("WARNING: 'drop_connect' as an argument is deprecated, please use 'drop_path'." 61 | " Setting drop_path to %f." % drop_connect_rate) 62 | kwargs['drop_path_rate'] = drop_connect_rate 63 | 64 | # Parameters that aren't supported by all models or are intended to only override model defaults if set 65 | # should default to None in command line args/cfg. Remove them if they are present and not set so that 66 | # non-supporting models don't break and default args remain in effect. 67 | kwargs = {k: v for k, v in kwargs.items() if v is not None} 68 | 69 | if source_name == 'hf_hub': 70 | # For model names specified in the form `hf_hub:path/architecture_name#revision`, 71 | # load model weights + default_cfg from Hugging Face hub. 
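# Illustrative only (the model id below is an assumption, not from the original source):
# create_model('hf_hub:myorg/mymodel') resolves the Hub config.json into external_default_cfg.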
72 | hf_default_cfg, model_name = load_model_config_from_hf(model_name) 73 | kwargs['external_default_cfg'] = hf_default_cfg # FIXME revamp default_cfg interface someday 74 | 75 | if is_model(model_name): 76 | create_fn = model_entrypoint(model_name) 77 | else: 78 | raise RuntimeError('Unknown model (%s)' % model_name) 79 | 80 | with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit): 81 | model = create_fn(pretrained=pretrained, **kwargs) 82 | 83 | if checkpoint_path: 84 | load_checkpoint(model, checkpoint_path) 85 | 86 | return model 87 | -------------------------------------------------------------------------------- /timm/models/hub.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | from functools import partial 5 | from typing import Union, Optional 6 | 7 | import torch 8 | from torch.hub import load_state_dict_from_url, download_url_to_file, urlparse, HASH_REGEX 9 | try: 10 | from torch.hub import get_dir 11 | except ImportError: 12 | from torch.hub import _get_torch_home as get_dir 13 | 14 | from timm import __version__ 15 | try: 16 | from huggingface_hub import hf_hub_url 17 | from huggingface_hub import cached_download 18 | cached_download = partial(cached_download, library_name="timm", library_version=__version__) 19 | except ImportError: 20 | hf_hub_url = None 21 | cached_download = None 22 | 23 | _logger = logging.getLogger(__name__) 24 | 25 | 26 | def get_cache_dir(child_dir=''): 27 | """ 28 | Returns the location of the directory where models are cached (and creates it if necessary). 29 | """ 30 | # Issue warning to move data if old env is set 31 | if os.getenv('TORCH_MODEL_ZOO'): 32 | _logger.warning('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead') 33 | 34 | hub_dir = get_dir() 35 | child_dir = () if not child_dir else (child_dir,) 36 | model_dir = os.path.join(hub_dir, 'checkpoints', *child_dir) 37 | os.makedirs(model_dir, exist_ok=True) 38 | return model_dir 39 | 40 | 41 | def download_cached_file(url, check_hash=True, progress=False): 42 | parts = urlparse(url) 43 | filename = os.path.basename(parts.path) 44 | cached_file = os.path.join(get_cache_dir(), filename) 45 | if not os.path.exists(cached_file): 46 | _logger.info('Downloading: "{}" to {}\n'.format(url, cached_file)) 47 | hash_prefix = None 48 | if check_hash: 49 | r = HASH_REGEX.search(filename) # r is Optional[Match[str]] 50 | hash_prefix = r.group(1) if r else None 51 | download_url_to_file(url, cached_file, hash_prefix, progress=progress) 52 | return cached_file 53 | 54 | 55 | def has_hf_hub(necessary=False): 56 | if hf_hub_url is None and necessary: 57 | # if no HF Hub module installed and it is necessary to continue, raise error 58 | raise RuntimeError( 59 | 'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.') 60 | return hf_hub_url is not None 61 | 62 | 63 | def hf_split(hf_id): 64 | rev_split = hf_id.split('@') 65 | assert 0 < len(rev_split) <= 2, 'hf_hub id should only contain one @ character to identify revision.' 
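# e.g. (illustrative): 'myorg/mymodel@abc123' -> model id 'myorg/mymodel', revision 'abc123'; with no '@' the revision falls through to None below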
66 | hf_model_id = rev_split[0] 67 | hf_revision = rev_split[-1] if len(rev_split) > 1 else None 68 | return hf_model_id, hf_revision 69 | 70 | 71 | def load_cfg_from_json(json_file: Union[str, os.PathLike]): 72 | with open(json_file, "r", encoding="utf-8") as reader: 73 | text = reader.read() 74 | return json.loads(text) 75 | 76 | 77 | def _download_from_hf(model_id: str, filename: str): 78 | hf_model_id, hf_revision = hf_split(model_id) 79 | url = hf_hub_url(hf_model_id, filename, revision=hf_revision) 80 | return cached_download(url, cache_dir=get_cache_dir('hf')) 81 | 82 | 83 | def load_model_config_from_hf(model_id: str): 84 | assert has_hf_hub(True) 85 | cached_file = _download_from_hf(model_id, 'config.json') 86 | default_cfg = load_cfg_from_json(cached_file) 87 | default_cfg['hf_hub'] = model_id # insert hf_hub id for pretrained weight load during model creation 88 | model_name = default_cfg.get('architecture') 89 | return default_cfg, model_name 90 | 91 | 92 | def load_state_dict_from_hf(model_id: str): 93 | assert has_hf_hub(True) 94 | cached_file = _download_from_hf(model_id, 'pytorch_model.bin') 95 | state_dict = torch.load(cached_file, map_location='cpu') 96 | return state_dict 97 | -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/activations.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/activations.cpython-37.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/activations.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/activations.cpython-38.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/activations_jit.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/activations_jit.cpython-37.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/activations_jit.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/activations_jit.cpython-38.pyc 
-------------------------------------------------------------------------------- /timm/models/layers/__pycache__/activations_me.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/activations_me.cpython-37.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/activations_me.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/activations_me.cpython-38.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/adaptive_avgmax_pool.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/adaptive_avgmax_pool.cpython-37.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/adaptive_avgmax_pool.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/adaptive_avgmax_pool.cpython-38.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/blur_pool.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/blur_pool.cpython-37.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/blur_pool.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/blur_pool.cpython-38.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/bottleneck_attn.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/bottleneck_attn.cpython-37.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/bottleneck_attn.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/bottleneck_attn.cpython-38.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/cbam.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/cbam.cpython-37.pyc -------------------------------------------------------------------------------- 
/timm/models/layers/__pycache__/cbam.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/cbam.cpython-38.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/classifier.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/classifier.cpython-37.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/classifier.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/classifier.cpython-38.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/cond_conv2d.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/cond_conv2d.cpython-37.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/cond_conv2d.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/cond_conv2d.cpython-38.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/config.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/config.cpython-37.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/config.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/config.cpython-38.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/conv2d_same.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/conv2d_same.cpython-37.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/conv2d_same.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/conv2d_same.cpython-38.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/conv_bn_act.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/conv_bn_act.cpython-37.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/conv_bn_act.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/conv_bn_act.cpython-38.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/create_act.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/create_act.cpython-37.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/create_act.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/create_act.cpython-38.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/create_attn.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/create_attn.cpython-37.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/create_attn.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/create_attn.cpython-38.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/create_conv2d.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/create_conv2d.cpython-37.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/create_conv2d.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/create_conv2d.cpython-38.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/create_norm_act.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/create_norm_act.cpython-37.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/create_norm_act.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/create_norm_act.cpython-38.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/drop.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/drop.cpython-37.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/drop.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/drop.cpython-38.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/eca.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/eca.cpython-37.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/eca.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/eca.cpython-38.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/evo_norm.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/evo_norm.cpython-37.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/evo_norm.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/evo_norm.cpython-38.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/gather_excite.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/gather_excite.cpython-37.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/gather_excite.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/gather_excite.cpython-38.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/global_context.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/global_context.cpython-37.pyc 
-------------------------------------------------------------------------------- /timm/models/layers/__pycache__/global_context.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/global_context.cpython-38.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/halo_attn.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/halo_attn.cpython-37.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/halo_attn.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/halo_attn.cpython-38.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/helpers.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/helpers.cpython-37.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/helpers.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/helpers.cpython-38.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/inplace_abn.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/inplace_abn.cpython-37.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/inplace_abn.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/inplace_abn.cpython-38.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/involution.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/involution.cpython-37.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/involution.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/involution.cpython-38.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/lambda_layer.cpython-37.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/lambda_layer.cpython-37.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/lambda_layer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/lambda_layer.cpython-38.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/linear.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/linear.cpython-37.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/linear.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/linear.cpython-38.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/mixed_conv2d.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/mixed_conv2d.cpython-37.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/mixed_conv2d.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/mixed_conv2d.cpython-38.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/mlp.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/mlp.cpython-37.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/mlp.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/mlp.cpython-38.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/non_local_attn.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/non_local_attn.cpython-37.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/non_local_attn.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/non_local_attn.cpython-38.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/norm.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/norm.cpython-37.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/norm.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/norm.cpython-38.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/norm_act.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/norm_act.cpython-37.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/norm_act.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/norm_act.cpython-38.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/padding.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/padding.cpython-37.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/padding.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/padding.cpython-38.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/patch_embed.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/patch_embed.cpython-37.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/patch_embed.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/patch_embed.cpython-38.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/pool2d_same.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/pool2d_same.cpython-37.pyc 
-------------------------------------------------------------------------------- /timm/models/layers/__pycache__/pool2d_same.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/pool2d_same.cpython-38.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/selective_kernel.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/selective_kernel.cpython-37.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/selective_kernel.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/selective_kernel.cpython-38.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/separable_conv.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/separable_conv.cpython-37.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/separable_conv.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/separable_conv.cpython-38.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/space_to_depth.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/space_to_depth.cpython-37.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/space_to_depth.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/space_to_depth.cpython-38.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/split_attn.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/split_attn.cpython-37.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/split_attn.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/split_attn.cpython-38.pyc -------------------------------------------------------------------------------- 
/timm/models/layers/__pycache__/split_batchnorm.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/split_batchnorm.cpython-37.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/split_batchnorm.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/split_batchnorm.cpython-38.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/squeeze_excite.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/squeeze_excite.cpython-37.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/squeeze_excite.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/squeeze_excite.cpython-38.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/std_conv.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/std_conv.cpython-37.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/std_conv.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/std_conv.cpython-38.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/swin_attn.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/swin_attn.cpython-37.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/swin_attn.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/swin_attn.cpython-38.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/test_time_pool.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/test_time_pool.cpython-37.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/test_time_pool.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/test_time_pool.cpython-38.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/weight_init.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/weight_init.cpython-37.pyc -------------------------------------------------------------------------------- /timm/models/layers/__pycache__/weight_init.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/weight_init.cpython-38.pyc -------------------------------------------------------------------------------- /timm/models/registry.py: -------------------------------------------------------------------------------- 1 | """ Model Registry 2 | Hacked together by / Copyright 2020 Ross Wightman 3 | """ 4 | 5 | import sys 6 | import re 7 | import fnmatch 8 | from collections import defaultdict 9 | from copy import deepcopy 10 | 11 | __all__ = ['list_models', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules', 12 | 'is_model_default_key', 'has_model_default_key', 'get_model_default_value', 'is_model_pretrained'] 13 | 14 | _module_to_models = defaultdict(set) # dict of sets to check membership of model in module 15 | _model_to_module = {} # mapping of model names to module names 16 | _model_entrypoints = {} # mapping of model names to entrypoint fns 17 | _model_has_pretrained = set() # set of model names that have pretrained weight url present 18 | _model_default_cfgs = dict() # central repo for model default_cfgs 19 | 20 | 21 | def register_model(fn): 22 | # lookup containing module 23 | mod = sys.modules[fn.__module__] 24 | module_name_split = fn.__module__.split('.') 25 | module_name = module_name_split[-1] if len(module_name_split) else '' 26 | 27 | # add model to __all__ in module 28 | model_name = fn.__name__ 29 | if hasattr(mod, '__all__'): 30 | mod.__all__.append(model_name) 31 | else: 32 | mod.__all__ = [model_name] 33 | 34 | # add entries to registry dict/sets 35 | _model_entrypoints[model_name] = fn 36 | _model_to_module[model_name] = module_name 37 | _module_to_models[module_name].add(model_name) 38 | has_pretrained = False # check if model has a pretrained url to allow filtering on this 39 | if hasattr(mod, 'default_cfgs') and model_name in mod.default_cfgs: 40 | # this will catch all models that have entrypoint matching cfg key, but miss any aliasing 41 | # entrypoints or non-matching combos 42 | has_pretrained = 'url' in mod.default_cfgs[model_name] and 'http' in mod.default_cfgs[model_name]['url'] 43 | _model_default_cfgs[model_name] = deepcopy(mod.default_cfgs[model_name]) 44 | if has_pretrained: 45 | _model_has_pretrained.add(model_name) 46 | return fn 47 | 48 | 49 | def _natural_key(string_): 50 | return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())] 51 | 52 | 53 | def list_models(filter='', module='', pretrained=False, exclude_filters='', name_matches_cfg=False): 54 | """ Return list of available model names, sorted alphabetically 55 | 56 | Args: 57 | filter (str) - Wildcard filter string that works with fnmatch 58 | module (str) - Limit model selection to a 
specific sub-module (i.e. 'gen_efficientnet') 59 | pretrained (bool) - Include only models with pretrained weights if True 60 | exclude_filters (str or list[str]) - Wildcard filters to exclude models after including them with filter 61 | name_matches_cfg (bool) - Include only models w/ model_name matching default_cfg name (excludes some aliases) 62 | 63 | Example: 64 | list_models('gluon_resnet*') -- returns all models starting with 'gluon_resnet' 65 | list_models('*resnext*', 'resnet') -- returns all models with 'resnext' in 'resnet' module 66 | """ 67 | if module: 68 | all_models = list(_module_to_models[module]) 69 | else: 70 | all_models = _model_entrypoints.keys() 71 | if filter: 72 | models = [] 73 | include_filters = filter if isinstance(filter, (tuple, list)) else [filter] 74 | for f in include_filters: 75 | include_models = fnmatch.filter(all_models, f) # include these models 76 | if len(include_models): 77 | models = set(models).union(include_models) 78 | else: 79 | models = all_models 80 | if exclude_filters: 81 | if not isinstance(exclude_filters, (tuple, list)): 82 | exclude_filters = [exclude_filters] 83 | for xf in exclude_filters: 84 | exclude_models = fnmatch.filter(models, xf) # exclude these models 85 | if len(exclude_models): 86 | models = set(models).difference(exclude_models) 87 | if pretrained: 88 | models = _model_has_pretrained.intersection(models) 89 | if name_matches_cfg: 90 | models = set(_model_default_cfgs).intersection(models) 91 | return list(sorted(models, key=_natural_key)) 92 | 93 | 94 | def is_model(model_name): 95 | """ Check if a model name exists 96 | """ 97 | return model_name in _model_entrypoints 98 | 99 | 100 | def model_entrypoint(model_name): 101 | """Fetch a model entrypoint for specified model name 102 | """ 103 | return _model_entrypoints[model_name] 104 | 105 | 106 | def list_modules(): 107 | """ Return list of module names that contain models / model entrypoints 108 | """ 109 | modules = _module_to_models.keys() 110 | return list(sorted(modules)) 111 | 112 | 113 | def is_model_in_modules(model_name, module_names): 114 | """Check if a model exists within a subset of modules 115 | Args: 116 | model_name (str) - name of model to check 117 | module_names (tuple, list, set) - names of modules to search in 118 | """ 119 | assert isinstance(module_names, (tuple, list, set)) 120 | return any(model_name in _module_to_models[n] for n in module_names) 121 | 122 | 123 | def has_model_default_key(model_name, cfg_key): 124 | """ Query model default_cfgs for existence of a specific key. 125 | """ 126 | if model_name in _model_default_cfgs and cfg_key in _model_default_cfgs[model_name]: 127 | return True 128 | return False 129 | 130 | 131 | def is_model_default_key(model_name, cfg_key): 132 | """ Return truthy value for specified model default_cfg key, False if does not exist. 133 | """ 134 | if model_name in _model_default_cfgs and _model_default_cfgs[model_name].get(cfg_key, False): 135 | return True 136 | return False 137 | 138 | 139 | def get_model_default_value(model_name, cfg_key): 140 | """ Get a specific model default_cfg value by key. None if it doesn't exist.
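
For orientation, this registry is what `@register_model` populates at import time and what the factory helpers query. A minimal usage sketch (the model-name pattern is illustrative; what is actually returned depends on which model modules have been imported):

from timm.models.registry import list_models, is_model, model_entrypoint

# wildcard query, optionally restricted to entries with pretrained weights
names = list_models('gluon_resnet*', pretrained=True)

if names and is_model(names[0]):
    create_fn = model_entrypoint(names[0])  # the raw fn registered via @register_model
    model = create_fn(pretrained=False)     # timm entrypoints accept a pretrained kwarg
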
141 | """ 142 | if model_name in _model_default_cfgs: 143 | return _model_default_cfgs[model_name].get(cfg_key, None) 144 | else: 145 | return None 146 | 147 | 148 | def is_model_pretrained(model_name): 149 | return model_name in _model_has_pretrained 150 | -------------------------------------------------------------------------------- /timm/optim/__init__.py: -------------------------------------------------------------------------------- 1 | from .adamp import AdamP 2 | from .adamw import AdamW 3 | from .adafactor import Adafactor 4 | from .adahessian import Adahessian 5 | from .lookahead import Lookahead 6 | from .nadam import Nadam 7 | from .novograd import NovoGrad 8 | from .nvnovograd import NvNovoGrad 9 | from .radam import RAdam 10 | from .rmsprop_tf import RMSpropTF 11 | from .sgdp import SGDP 12 | from .adabelief import AdaBelief 13 | from .optim_factory import create_optimizer, create_optimizer_v2, optimizer_kwargs 14 | -------------------------------------------------------------------------------- /timm/optim/adamp.py: -------------------------------------------------------------------------------- 1 | """ 2 | AdamP Optimizer Implementation copied from https://github.com/clovaai/AdamP/blob/master/adamp/adamp.py 3 | 4 | Paper: `Slowing Down the Weight Norm Increase in Momentum-based Optimizers` - https://arxiv.org/abs/2006.08217 5 | Code: https://github.com/clovaai/AdamP 6 | 7 | Copyright (c) 2020-present NAVER Corp. 8 | MIT license 9 | """ 10 | 11 | import torch 12 | import torch.nn as nn 13 | from torch.optim.optimizer import Optimizer, required 14 | import math 15 | 16 | class AdamP(Optimizer): 17 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 18 | weight_decay=0, delta=0.1, wd_ratio=0.1, nesterov=False): 19 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, 20 | delta=delta, wd_ratio=wd_ratio, nesterov=nesterov) 21 | super(AdamP, self).__init__(params, defaults) 22 | 23 | def _channel_view(self, x): 24 | return x.view(x.size(0), -1) 25 | 26 | def _layer_view(self, x): 27 | return x.view(1, -1) 28 | 29 | def _cosine_similarity(self, x, y, eps, view_func): 30 | x = view_func(x) 31 | y = view_func(y) 32 | 33 | x_norm = x.norm(dim=1).add_(eps) 34 | y_norm = y.norm(dim=1).add_(eps) 35 | dot = (x * y).sum(dim=1) 36 | 37 | return dot.abs() / x_norm / y_norm 38 | 39 | def _projection(self, p, grad, perturb, delta, wd_ratio, eps): 40 | wd = 1 41 | expand_size = [-1] + [1] * (len(p.shape) - 1) 42 | for view_func in [self._channel_view, self._layer_view]: 43 | 44 | cosine_sim = self._cosine_similarity(grad, p.data, eps, view_func) 45 | 46 | if cosine_sim.max() < delta / math.sqrt(view_func(p.data).size(1)): 47 | p_n = p.data / view_func(p.data).norm(dim=1).view(expand_size).add_(eps) 48 | perturb -= p_n * view_func(p_n * perturb).sum(dim=1).view(expand_size) 49 | wd = wd_ratio 50 | 51 | return perturb, wd 52 | 53 | return perturb, wd 54 | 55 | def step(self, closure=None): 56 | loss = None 57 | if closure is not None: 58 | loss = closure() 59 | 60 | for group in self.param_groups: 61 | for p in group['params']: 62 | if p.grad is None: 63 | continue 64 | 65 | grad = p.grad.data 66 | beta1, beta2 = group['betas'] 67 | nesterov = group['nesterov'] 68 | 69 | state = self.state[p] 70 | 71 | # State initialization 72 | if len(state) == 0: 73 | state['step'] = 0 74 | state['exp_avg'] = torch.zeros_like(p.data) 75 | state['exp_avg_sq'] = torch.zeros_like(p.data) 76 | 77 | # Adam 78 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 79 | 
80 | state['step'] += 1 81 | bias_correction1 = 1 - beta1 ** state['step'] 82 | bias_correction2 = 1 - beta2 ** state['step'] 83 | 84 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 85 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 86 | 87 | denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) 88 | step_size = group['lr'] / bias_correction1 89 | 90 | if nesterov: 91 | perturb = (beta1 * exp_avg + (1 - beta1) * grad) / denom 92 | else: 93 | perturb = exp_avg / denom 94 | 95 | # Projection 96 | wd_ratio = 1 97 | if len(p.shape) > 1: 98 | perturb, wd_ratio = self._projection(p, grad, perturb, group['delta'], group['wd_ratio'], group['eps']) 99 | 100 | # Weight decay 101 | if group['weight_decay'] > 0: 102 | p.data.mul_(1 - group['lr'] * group['weight_decay'] * wd_ratio) 103 | 104 | # Step 105 | p.data.add_(-step_size, perturb) 106 | 107 | return loss 108 | -------------------------------------------------------------------------------- /timm/optim/adamw.py: -------------------------------------------------------------------------------- 1 | """ AdamW Optimizer 2 | Impl copied from PyTorch master 3 | """ 4 | import math 5 | import torch 6 | from torch.optim.optimizer import Optimizer 7 | 8 | 9 | class AdamW(Optimizer): 10 | r"""Implements AdamW algorithm. 11 | 12 | The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_. 13 | The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_. 14 | 15 | Arguments: 16 | params (iterable): iterable of parameters to optimize or dicts defining 17 | parameter groups 18 | lr (float, optional): learning rate (default: 1e-3) 19 | betas (Tuple[float, float], optional): coefficients used for computing 20 | running averages of gradient and its square (default: (0.9, 0.999)) 21 | eps (float, optional): term added to the denominator to improve 22 | numerical stability (default: 1e-8) 23 | weight_decay (float, optional): weight decay coefficient (default: 1e-2) 24 | amsgrad (boolean, optional): whether to use the AMSGrad variant of this 25 | algorithm from the paper `On the Convergence of Adam and Beyond`_ 26 | (default: False) 27 | 28 | .. _Adam\: A Method for Stochastic Optimization: 29 | https://arxiv.org/abs/1412.6980 30 | .. _Decoupled Weight Decay Regularization: 31 | https://arxiv.org/abs/1711.05101 32 | .. _On the Convergence of Adam and Beyond: 33 | https://openreview.net/forum?id=ryQu7f-RZ 34 | """ 35 | 36 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 37 | weight_decay=1e-2, amsgrad=False): 38 | if not 0.0 <= lr: 39 | raise ValueError("Invalid learning rate: {}".format(lr)) 40 | if not 0.0 <= eps: 41 | raise ValueError("Invalid epsilon value: {}".format(eps)) 42 | if not 0.0 <= betas[0] < 1.0: 43 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 44 | if not 0.0 <= betas[1] < 1.0: 45 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 46 | defaults = dict(lr=lr, betas=betas, eps=eps, 47 | weight_decay=weight_decay, amsgrad=amsgrad) 48 | super(AdamW, self).__init__(params, defaults) 49 | 50 | def __setstate__(self, state): 51 | super(AdamW, self).__setstate__(state) 52 | for group in self.param_groups: 53 | group.setdefault('amsgrad', False) 54 | 55 | def step(self, closure=None): 56 | """Performs a single optimization step. 57 | 58 | Arguments: 59 | closure (callable, optional): A closure that reevaluates the model 60 | and returns the loss. 
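
The distinguishing AdamW step is visible just below: weight decay multiplies the parameters directly (`p.data.mul_(1 - lr * weight_decay)`) before the Adam update, rather than being folded into the gradient. A minimal construction sketch, assuming the older torch tensor-op API this vendored copy targets:

import torch
from timm.optim import AdamW

model = torch.nn.Linear(10, 2)  # placeholder module
opt = AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)

model(torch.randn(4, 10)).sum().backward()  # dummy forward/backward
opt.step()       # decay applied to the weights, not to the gradients
opt.zero_grad()
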
61 | """ 62 | loss = None 63 | if closure is not None: 64 | loss = closure() 65 | 66 | for group in self.param_groups: 67 | for p in group['params']: 68 | if p.grad is None: 69 | continue 70 | 71 | # Perform stepweight decay 72 | p.data.mul_(1 - group['lr'] * group['weight_decay']) 73 | 74 | # Perform optimization step 75 | grad = p.grad.data 76 | if grad.is_sparse: 77 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 78 | amsgrad = group['amsgrad'] 79 | 80 | state = self.state[p] 81 | 82 | # State initialization 83 | if len(state) == 0: 84 | state['step'] = 0 85 | # Exponential moving average of gradient values 86 | state['exp_avg'] = torch.zeros_like(p.data) 87 | # Exponential moving average of squared gradient values 88 | state['exp_avg_sq'] = torch.zeros_like(p.data) 89 | if amsgrad: 90 | # Maintains max of all exp. moving avg. of sq. grad. values 91 | state['max_exp_avg_sq'] = torch.zeros_like(p.data) 92 | 93 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 94 | if amsgrad: 95 | max_exp_avg_sq = state['max_exp_avg_sq'] 96 | beta1, beta2 = group['betas'] 97 | 98 | state['step'] += 1 99 | bias_correction1 = 1 - beta1 ** state['step'] 100 | bias_correction2 = 1 - beta2 ** state['step'] 101 | 102 | # Decay the first and second moment running average coefficient 103 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 104 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 105 | if amsgrad: 106 | # Maintains the maximum of all 2nd moment running avg. till now 107 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) 108 | # Use the max. for normalizing running avg. of gradient 109 | denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) 110 | else: 111 | denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) 112 | 113 | step_size = group['lr'] / bias_correction1 114 | 115 | p.data.addcdiv_(-step_size, exp_avg, denom) 116 | 117 | return loss 118 | -------------------------------------------------------------------------------- /timm/optim/lookahead.py: -------------------------------------------------------------------------------- 1 | """ Lookahead Optimizer Wrapper. 
2 | Implementation modified from: https://github.com/alphadl/lookahead.pytorch 3 | Paper: `Lookahead Optimizer: k steps forward, 1 step back` - https://arxiv.org/abs/1907.08610 4 | 5 | Hacked together by / Copyright 2020 Ross Wightman 6 | """ 7 | import torch 8 | from torch.optim.optimizer import Optimizer 9 | from collections import defaultdict 10 | 11 | 12 | class Lookahead(Optimizer): 13 | def __init__(self, base_optimizer, alpha=0.5, k=6): 14 | if not 0.0 <= alpha <= 1.0: 15 | raise ValueError(f'Invalid slow update rate: {alpha}') 16 | if not 1 <= k: 17 | raise ValueError(f'Invalid lookahead steps: {k}') 18 | defaults = dict(lookahead_alpha=alpha, lookahead_k=k, lookahead_step=0) 19 | self.base_optimizer = base_optimizer 20 | self.param_groups = self.base_optimizer.param_groups 21 | self.defaults = base_optimizer.defaults 22 | self.defaults.update(defaults) 23 | self.state = defaultdict(dict) 24 | # manually add our defaults to the param groups 25 | for name, default in defaults.items(): 26 | for group in self.param_groups: 27 | group.setdefault(name, default) 28 | 29 | def update_slow(self, group): 30 | for fast_p in group["params"]: 31 | if fast_p.grad is None: 32 | continue 33 | param_state = self.state[fast_p] 34 | if 'slow_buffer' not in param_state: 35 | param_state['slow_buffer'] = torch.empty_like(fast_p.data) 36 | param_state['slow_buffer'].copy_(fast_p.data) 37 | slow = param_state['slow_buffer'] 38 | slow.add_(group['lookahead_alpha'], fast_p.data - slow) 39 | fast_p.data.copy_(slow) 40 | 41 | def sync_lookahead(self): 42 | for group in self.param_groups: 43 | self.update_slow(group) 44 | 45 | def step(self, closure=None): 46 | #assert id(self.param_groups) == id(self.base_optimizer.param_groups) 47 | loss = self.base_optimizer.step(closure) 48 | for group in self.param_groups: 49 | group['lookahead_step'] += 1 50 | if group['lookahead_step'] % group['lookahead_k'] == 0: 51 | self.update_slow(group) 52 | return loss 53 | 54 | def state_dict(self): 55 | fast_state_dict = self.base_optimizer.state_dict() 56 | slow_state = { 57 | (id(k) if isinstance(k, torch.Tensor) else k): v 58 | for k, v in self.state.items() 59 | } 60 | fast_state = fast_state_dict['state'] 61 | param_groups = fast_state_dict['param_groups'] 62 | return { 63 | 'state': fast_state, 64 | 'slow_state': slow_state, 65 | 'param_groups': param_groups, 66 | } 67 | 68 | def load_state_dict(self, state_dict): 69 | fast_state_dict = { 70 | 'state': state_dict['state'], 71 | 'param_groups': state_dict['param_groups'], 72 | } 73 | self.base_optimizer.load_state_dict(fast_state_dict) 74 | 75 | # We want to restore the slow state, but share param_groups reference 76 | # with base_optimizer. 
This is a bit redundant but least code 77 | slow_state_new = False 78 | if 'slow_state' not in state_dict: 79 | print('Loading state_dict from optimizer without Lookahead applied.') 80 | state_dict['slow_state'] = defaultdict(dict) 81 | slow_state_new = True 82 | slow_state_dict = { 83 | 'state': state_dict['slow_state'], 84 | 'param_groups': state_dict['param_groups'], # this is pointless but saves code 85 | } 86 | super(Lookahead, self).load_state_dict(slow_state_dict) 87 | self.param_groups = self.base_optimizer.param_groups # make both ref same container 88 | if slow_state_new: 89 | # reapply defaults to catch missing lookahead specific ones 90 | for name, default in self.defaults.items(): 91 | for group in self.param_groups: 92 | group.setdefault(name, default) 93 | -------------------------------------------------------------------------------- /timm/optim/nadam.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim import Optimizer 3 | 4 | 5 | class Nadam(Optimizer): 6 | """Implements Nadam algorithm (a variant of Adam based on Nesterov momentum). 7 | 8 | It has been proposed in `Incorporating Nesterov Momentum into Adam`__. 9 | 10 | Arguments: 11 | params (iterable): iterable of parameters to optimize or dicts defining 12 | parameter groups 13 | lr (float, optional): learning rate (default: 2e-3) 14 | betas (Tuple[float, float], optional): coefficients used for computing 15 | running averages of gradient and its square 16 | eps (float, optional): term added to the denominator to improve 17 | numerical stability (default: 1e-8) 18 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 19 | schedule_decay (float, optional): momentum schedule decay (default: 4e-3) 20 | 21 | __ http://cs229.stanford.edu/proj2015/054_report.pdf 22 | __ http://www.cs.toronto.edu/~fritz/absps/momentum.pdf 23 | 24 | Originally taken from: https://github.com/pytorch/pytorch/pull/1408 25 | NOTE: Has potential issues but does work well on some problems. 26 | """ 27 | 28 | def __init__(self, params, lr=2e-3, betas=(0.9, 0.999), eps=1e-8, 29 | weight_decay=0, schedule_decay=4e-3): 30 | defaults = dict(lr=lr, betas=betas, eps=eps, 31 | weight_decay=weight_decay, schedule_decay=schedule_decay) 32 | super(Nadam, self).__init__(params, defaults) 33 | 34 | def step(self, closure=None): 35 | """Performs a single optimization step. 36 | 37 | Arguments: 38 | closure (callable, optional): A closure that reevaluates the model 39 | and returns the loss. 40 | """ 41 | loss = None 42 | if closure is not None: 43 | loss = closure() 44 | 45 | for group in self.param_groups: 46 | for p in group['params']: 47 | if p.grad is None: 48 | continue 49 | grad = p.grad.data 50 | state = self.state[p] 51 | 52 | # State initialization 53 | if len(state) == 0: 54 | state['step'] = 0 55 | state['m_schedule'] = 1. 56 | state['exp_avg'] = grad.new().resize_as_(grad).zero_() 57 | state['exp_avg_sq'] = grad.new().resize_as_(grad).zero_() 58 | 59 | # Warming momentum schedule 60 | m_schedule = state['m_schedule'] 61 | schedule_decay = group['schedule_decay'] 62 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 63 | beta1, beta2 = group['betas'] 64 | eps = group['eps'] 65 | state['step'] += 1 66 | t = state['step'] 67 | 68 | if group['weight_decay'] != 0: 69 | grad = grad.add(group['weight_decay'], p.data) 70 | 71 | momentum_cache_t = beta1 * \ 72 | (1. - 0.5 * (0.96 ** (t * schedule_decay))) 73 | momentum_cache_t_1 = beta1 * \ 74 | (1. 
- 0.5 * (0.96 ** ((t + 1) * schedule_decay))) 75 | m_schedule_new = m_schedule * momentum_cache_t 76 | m_schedule_next = m_schedule * momentum_cache_t * momentum_cache_t_1 77 | state['m_schedule'] = m_schedule_new 78 | 79 | # Decay the first and second moment running average coefficient 80 | exp_avg.mul_(beta1).add_(1. - beta1, grad) 81 | exp_avg_sq.mul_(beta2).addcmul_(1. - beta2, grad, grad) 82 | exp_avg_sq_prime = exp_avg_sq / (1. - beta2 ** t) 83 | denom = exp_avg_sq_prime.sqrt_().add_(eps) 84 | 85 | p.data.addcdiv_(-group['lr'] * (1. - momentum_cache_t) / (1. - m_schedule_new), grad, denom) 86 | p.data.addcdiv_(-group['lr'] * momentum_cache_t_1 / (1. - m_schedule_next), exp_avg, denom) 87 | 88 | return loss 89 | -------------------------------------------------------------------------------- /timm/optim/novograd.py: -------------------------------------------------------------------------------- 1 | """NovoGrad Optimizer. 2 | Original impl by Masashi Kimura (Convergence Lab): https://github.com/convergence-lab/novograd 3 | Paper: `Stochastic Gradient Methods with Layer-wise Adaptive Moments for Training of Deep Networks` 4 | - https://arxiv.org/abs/1905.11286 5 | """ 6 | 7 | import torch 8 | from torch.optim.optimizer import Optimizer 9 | import math 10 | 11 | 12 | class NovoGrad(Optimizer): 13 | def __init__(self, params, grad_averaging=False, lr=0.1, betas=(0.95, 0.98), eps=1e-8, weight_decay=0): 14 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 15 | super(NovoGrad, self).__init__(params, defaults) 16 | self._lr = lr 17 | self._beta1 = betas[0] 18 | self._beta2 = betas[1] 19 | self._eps = eps 20 | self._wd = weight_decay 21 | self._grad_averaging = grad_averaging 22 | 23 | self._momentum_initialized = False 24 | 25 | def step(self, closure=None): 26 | loss = None 27 | if closure is not None: 28 | loss = closure() 29 | 30 | if not self._momentum_initialized: 31 | for group in self.param_groups: 32 | for p in group['params']: 33 | if p.grad is None: 34 | continue 35 | state = self.state[p] 36 | grad = p.grad.data 37 | if grad.is_sparse: 38 | raise RuntimeError('NovoGrad does not support sparse gradients') 39 | 40 | v = torch.norm(grad)**2 41 | m = grad/(torch.sqrt(v) + self._eps) + self._wd * p.data 42 | state['step'] = 0 43 | state['v'] = v 44 | state['m'] = m 45 | state['grad_ema'] = None 46 | self._momentum_initialized = True 47 | 48 | for group in self.param_groups: 49 | for p in group['params']: 50 | if p.grad is None: 51 | continue 52 | state = self.state[p] 53 | state['step'] += 1 54 | 55 | step, v, m = state['step'], state['v'], state['m'] 56 | grad_ema = state['grad_ema'] 57 | 58 | grad = p.grad.data 59 | g2 = torch.norm(grad)**2 60 | grad_ema = g2 if grad_ema is None else grad_ema * \ 61 | self._beta2 + g2 * (1. - self._beta2) 62 | grad *= 1.0 / (torch.sqrt(grad_ema) + self._eps) 63 | 64 | if self._grad_averaging: 65 | grad *= (1. - self._beta1) 66 | 67 | g2 = torch.norm(grad)**2 68 | v = self._beta2*v + (1. 
- self._beta2)*g2 69 | m = self._beta1*m + (grad / (torch.sqrt(v) + self._eps) + self._wd * p.data) 70 | bias_correction1 = 1 - self._beta1 ** step 71 | bias_correction2 = 1 - self._beta2 ** step 72 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 73 | 74 | state['v'], state['m'] = v, m 75 | state['grad_ema'] = grad_ema 76 | p.data.add_(-step_size, m) 77 | return loss 78 | -------------------------------------------------------------------------------- /timm/optim/nvnovograd.py: -------------------------------------------------------------------------------- 1 | """ Nvidia NovoGrad Optimizer. 2 | Original impl by Nvidia from Jasper example: 3 | - https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechRecognition/Jasper 4 | Paper: `Stochastic Gradient Methods with Layer-wise Adaptive Moments for Training of Deep Networks` 5 | - https://arxiv.org/abs/1905.11286 6 | """ 7 | 8 | import torch 9 | from torch.optim.optimizer import Optimizer 10 | import math 11 | 12 | 13 | class NvNovoGrad(Optimizer): 14 | """ 15 | Implements Novograd algorithm. 16 | 17 | Args: 18 | params (iterable): iterable of parameters to optimize or dicts defining 19 | parameter groups 20 | lr (float, optional): learning rate (default: 1e-3) 21 | betas (Tuple[float, float], optional): coefficients used for computing 22 | running averages of gradient and its square (default: (0.95, 0.98)) 23 | eps (float, optional): term added to the denominator to improve 24 | numerical stability (default: 1e-8) 25 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 26 | grad_averaging: gradient averaging 27 | amsgrad (boolean, optional): whether to use the AMSGrad variant of this 28 | algorithm from the paper `On the Convergence of Adam and Beyond`_ 29 | (default: False) 30 | """ 31 | 32 | def __init__(self, params, lr=1e-3, betas=(0.95, 0.98), eps=1e-8, 33 | weight_decay=0, grad_averaging=False, amsgrad=False): 34 | if not 0.0 <= lr: 35 | raise ValueError("Invalid learning rate: {}".format(lr)) 36 | if not 0.0 <= eps: 37 | raise ValueError("Invalid epsilon value: {}".format(eps)) 38 | if not 0.0 <= betas[0] < 1.0: 39 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 40 | if not 0.0 <= betas[1] < 1.0: 41 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 42 | defaults = dict(lr=lr, betas=betas, eps=eps, 43 | weight_decay=weight_decay, 44 | grad_averaging=grad_averaging, 45 | amsgrad=amsgrad) 46 | 47 | super(NvNovoGrad, self).__init__(params, defaults) 48 | 49 | def __setstate__(self, state): 50 | super(NvNovoGrad, self).__setstate__(state) 51 | for group in self.param_groups: 52 | group.setdefault('amsgrad', False) 53 | 54 | def step(self, closure=None): 55 | """Performs a single optimization step. 56 | 57 | Arguments: 58 | closure (callable, optional): A closure that reevaluates the model 59 | and returns the loss. 
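
Note in the state initialization below that `exp_avg_sq` is a 0-dim tensor: NovoGrad's second moment is a single per-tensor (layer-wise) statistic, not an elementwise one as in Adam. A rough standalone sketch of that normalization, with illustrative names and the first-step special case omitted:

import torch

def novograd_like_normalize(grad, exp_avg_sq, beta2=0.98, eps=1e-8):
    norm = torch.sum(grad ** 2)                  # squared norm of the whole tensor
    exp_avg_sq.mul_(beta2).add_((1 - beta2) * norm)
    return grad / (exp_avg_sq.sqrt() + eps)      # layer-wise, not elementwise, scaling
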
60 | """ 61 | loss = None 62 | if closure is not None: 63 | loss = closure() 64 | 65 | for group in self.param_groups: 66 | for p in group['params']: 67 | if p.grad is None: 68 | continue 69 | grad = p.grad.data 70 | if grad.is_sparse: 71 | raise RuntimeError('Sparse gradients are not supported.') 72 | amsgrad = group['amsgrad'] 73 | 74 | state = self.state[p] 75 | 76 | # State initialization 77 | if len(state) == 0: 78 | state['step'] = 0 79 | # Exponential moving average of gradient values 80 | state['exp_avg'] = torch.zeros_like(p.data) 81 | # Exponential moving average of squared gradient values 82 | state['exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device) 83 | if amsgrad: 84 | # Maintains max of all exp. moving avg. of sq. grad. values 85 | state['max_exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device) 86 | 87 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 88 | if amsgrad: 89 | max_exp_avg_sq = state['max_exp_avg_sq'] 90 | beta1, beta2 = group['betas'] 91 | 92 | state['step'] += 1 93 | 94 | norm = torch.sum(torch.pow(grad, 2)) 95 | 96 | if exp_avg_sq == 0: 97 | exp_avg_sq.copy_(norm) 98 | else: 99 | exp_avg_sq.mul_(beta2).add_(1 - beta2, norm) 100 | 101 | if amsgrad: 102 | # Maintains the maximum of all 2nd moment running avg. till now 103 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) 104 | # Use the max. for normalizing running avg. of gradient 105 | denom = max_exp_avg_sq.sqrt().add_(group['eps']) 106 | else: 107 | denom = exp_avg_sq.sqrt().add_(group['eps']) 108 | 109 | grad.div_(denom) 110 | if group['weight_decay'] != 0: 111 | grad.add_(group['weight_decay'], p.data) 112 | if group['grad_averaging']: 113 | grad.mul_(1 - beta1) 114 | exp_avg.mul_(beta1).add_(grad) 115 | 116 | p.data.add_(-group['lr'], exp_avg) 117 | 118 | return loss 119 | -------------------------------------------------------------------------------- /timm/optim/radam.py: -------------------------------------------------------------------------------- 1 | """RAdam Optimizer. 
2 | Implementation lifted from: https://github.com/LiyuanLucasLiu/RAdam 3 | Paper: `On the Variance of the Adaptive Learning Rate and Beyond` - https://arxiv.org/abs/1908.03265 4 | """ 5 | import math 6 | import torch 7 | from torch.optim.optimizer import Optimizer, required 8 | 9 | 10 | class RAdam(Optimizer): 11 | 12 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0): 13 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 14 | self.buffer = [[None, None, None] for ind in range(10)] 15 | super(RAdam, self).__init__(params, defaults) 16 | 17 | def __setstate__(self, state): 18 | super(RAdam, self).__setstate__(state) 19 | 20 | def step(self, closure=None): 21 | 22 | loss = None 23 | if closure is not None: 24 | loss = closure() 25 | 26 | for group in self.param_groups: 27 | 28 | for p in group['params']: 29 | if p.grad is None: 30 | continue 31 | grad = p.grad.data.float() 32 | if grad.is_sparse: 33 | raise RuntimeError('RAdam does not support sparse gradients') 34 | 35 | p_data_fp32 = p.data.float() 36 | 37 | state = self.state[p] 38 | 39 | if len(state) == 0: 40 | state['step'] = 0 41 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 42 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 43 | else: 44 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 45 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 46 | 47 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 48 | beta1, beta2 = group['betas'] 49 | 50 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 51 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 52 | 53 | state['step'] += 1 54 | buffered = self.buffer[int(state['step'] % 10)] 55 | if state['step'] == buffered[0]: 56 | N_sma, step_size = buffered[1], buffered[2] 57 | else: 58 | buffered[0] = state['step'] 59 | beta2_t = beta2 ** state['step'] 60 | N_sma_max = 2 / (1 - beta2) - 1 61 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 62 | buffered[1] = N_sma 63 | 64 | # more conservative since it's an approximated value 65 | if N_sma >= 5: 66 | step_size = group['lr'] * math.sqrt( 67 | (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / ( 68 | N_sma_max - 2)) / (1 - beta1 ** state['step']) 69 | else: 70 | step_size = group['lr'] / (1 - beta1 ** state['step']) 71 | buffered[2] = step_size 72 | 73 | if group['weight_decay'] != 0: 74 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 75 | 76 | # more conservative since it's an approximated value 77 | if N_sma >= 5: 78 | denom = exp_avg_sq.sqrt().add_(group['eps']) 79 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom) 80 | else: 81 | p_data_fp32.add_(-step_size, exp_avg) 82 | 83 | p.data.copy_(p_data_fp32) 84 | 85 | return loss 86 | 87 | 88 | class PlainRAdam(Optimizer): 89 | 90 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0): 91 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 92 | 93 | super(PlainRAdam, self).__init__(params, defaults) 94 | 95 | def __setstate__(self, state): 96 | super(PlainRAdam, self).__setstate__(state) 97 | 98 | def step(self, closure=None): 99 | 100 | loss = None 101 | if closure is not None: 102 | loss = closure() 103 | 104 | for group in self.param_groups: 105 | 106 | for p in group['params']: 107 | if p.grad is None: 108 | continue 109 | grad = p.grad.data.float() 110 | if grad.is_sparse: 111 | raise RuntimeError('RAdam does not support sparse gradients') 112 | 113 | p_data_fp32 = 
p.data.float() 114 | 115 | state = self.state[p] 116 | 117 | if len(state) == 0: 118 | state['step'] = 0 119 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 120 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 121 | else: 122 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 123 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 124 | 125 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 126 | beta1, beta2 = group['betas'] 127 | 128 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 129 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 130 | 131 | state['step'] += 1 132 | beta2_t = beta2 ** state['step'] 133 | N_sma_max = 2 / (1 - beta2) - 1 134 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 135 | 136 | if group['weight_decay'] != 0: 137 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 138 | 139 | # more conservative since it's an approximated value 140 | if N_sma >= 5: 141 | step_size = group['lr'] * math.sqrt( 142 | (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / ( 143 | N_sma_max - 2)) / (1 - beta1 ** state['step']) 144 | denom = exp_avg_sq.sqrt().add_(group['eps']) 145 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom) 146 | else: 147 | step_size = group['lr'] / (1 - beta1 ** state['step']) 148 | p_data_fp32.add_(-step_size, exp_avg) 149 | 150 | p.data.copy_(p_data_fp32) 151 | 152 | return loss 153 | -------------------------------------------------------------------------------- /timm/optim/rmsprop_tf.py: -------------------------------------------------------------------------------- 1 | """ RMSProp modified to behave like Tensorflow impl 2 | 3 | Originally cut & paste from PyTorch RMSProp 4 | https://github.com/pytorch/pytorch/blob/063946d2b3f3f1e953a2a3b54e0b34f1393de295/torch/optim/rmsprop.py 5 | Licensed under BSD-Clause 3 (ish), https://github.com/pytorch/pytorch/blob/master/LICENSE 6 | 7 | Modifications Copyright 2020 Ross Wightman 8 | """ 9 | 10 | import torch 11 | from torch.optim import Optimizer 12 | 13 | 14 | class RMSpropTF(Optimizer): 15 | """Implements RMSprop algorithm (TensorFlow style epsilon) 16 | 17 | NOTE: This is a direct cut-and-paste of PyTorch RMSprop with eps applied before sqrt 18 | and a few other modifications to closer match Tensorflow for matching hyper-params. 19 | 20 | Noteworthy changes include: 21 | 1. Epsilon applied inside square-root 22 | 2. square_avg initialized to ones 23 | 3. LR scaling of update accumulated in momentum buffer 24 | 25 | Proposed by G. Hinton in his 26 | `course <https://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf>`_. 27 | 28 | The centered version first appears in `Generating Sequences 29 | With Recurrent Neural Networks <https://arxiv.org/abs/1308.0850>`_.
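
The practical consequence of change 1 is where eps enters the denominator; a toy comparison with made-up values:

import torch

square_avg = torch.tensor([1e-12])
eps = 1e-10
denom_pt = square_avg.sqrt().add(eps)  # stock PyTorch: eps outside the sqrt
denom_tf = square_avg.add(eps).sqrt()  # RMSpropTF below: eps inside the sqrt
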
30 | 31 | Arguments: 32 | params (iterable): iterable of parameters to optimize or dicts defining 33 | parameter groups 34 | lr (float, optional): learning rate (default: 1e-2) 35 | momentum (float, optional): momentum factor (default: 0) 36 | alpha (float, optional): smoothing (decay) constant (default: 0.9) 37 | eps (float, optional): term added to the denominator to improve 38 | numerical stability (default: 1e-10) 39 | centered (bool, optional) : if ``True``, compute the centered RMSProp, 40 | the gradient is normalized by an estimation of its variance 41 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 42 | decoupled_decay (bool, optional): decoupled weight decay as per https://arxiv.org/abs/1711.05101 43 | lr_in_momentum (bool, optional): learning rate scaling is included in the momentum buffer 44 | update as per defaults in Tensorflow 45 | 46 | """ 47 | 48 | def __init__(self, params, lr=1e-2, alpha=0.9, eps=1e-10, weight_decay=0, momentum=0., centered=False, 49 | decoupled_decay=False, lr_in_momentum=True): 50 | if not 0.0 <= lr: 51 | raise ValueError("Invalid learning rate: {}".format(lr)) 52 | if not 0.0 <= eps: 53 | raise ValueError("Invalid epsilon value: {}".format(eps)) 54 | if not 0.0 <= momentum: 55 | raise ValueError("Invalid momentum value: {}".format(momentum)) 56 | if not 0.0 <= weight_decay: 57 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 58 | if not 0.0 <= alpha: 59 | raise ValueError("Invalid alpha value: {}".format(alpha)) 60 | 61 | defaults = dict(lr=lr, momentum=momentum, alpha=alpha, eps=eps, centered=centered, weight_decay=weight_decay, 62 | decoupled_decay=decoupled_decay, lr_in_momentum=lr_in_momentum) 63 | super(RMSpropTF, self).__init__(params, defaults) 64 | 65 | def __setstate__(self, state): 66 | super(RMSpropTF, self).__setstate__(state) 67 | for group in self.param_groups: 68 | group.setdefault('momentum', 0) 69 | group.setdefault('centered', False) 70 | 71 | def step(self, closure=None): 72 | """Performs a single optimization step. 73 | 74 | Arguments: 75 | closure (callable, optional): A closure that reevaluates the model 76 | and returns the loss. 77 | """ 78 | loss = None 79 | if closure is not None: 80 | loss = closure() 81 | 82 | for group in self.param_groups: 83 | for p in group['params']: 84 | if p.grad is None: 85 | continue 86 | grad = p.grad.data 87 | if grad.is_sparse: 88 | raise RuntimeError('RMSprop does not support sparse gradients') 89 | state = self.state[p] 90 | 91 | # State initialization 92 | if len(state) == 0: 93 | state['step'] = 0 94 | state['square_avg'] = torch.ones_like(p.data) # PyTorch inits to zero 95 | if group['momentum'] > 0: 96 | state['momentum_buffer'] = torch.zeros_like(p.data) 97 | if group['centered']: 98 | state['grad_avg'] = torch.zeros_like(p.data) 99 | 100 | square_avg = state['square_avg'] 101 | one_minus_alpha = 1. 
- group['alpha'] 102 | 103 | state['step'] += 1 104 | 105 | if group['weight_decay'] != 0: 106 | if 'decoupled_decay' in group and group['decoupled_decay']: 107 | p.data.add_(-group['weight_decay'], p.data) 108 | else: 109 | grad = grad.add(group['weight_decay'], p.data) 110 | 111 | # Tensorflow order of ops for updating squared avg 112 | square_avg.add_(one_minus_alpha, grad.pow(2) - square_avg) 113 | # square_avg.mul_(alpha).addcmul_(1 - alpha, grad, grad) # PyTorch original 114 | 115 | if group['centered']: 116 | grad_avg = state['grad_avg'] 117 | grad_avg.add_(one_minus_alpha, grad - grad_avg) 118 | # grad_avg.mul_(alpha).add_(1 - alpha, grad) # PyTorch original 119 | avg = square_avg.addcmul(-1, grad_avg, grad_avg).add(group['eps']).sqrt_() # eps moved in sqrt 120 | else: 121 | avg = square_avg.add(group['eps']).sqrt_() # eps moved in sqrt 122 | 123 | if group['momentum'] > 0: 124 | buf = state['momentum_buffer'] 125 | # Tensorflow accumulates the LR scaling in the momentum buffer 126 | if 'lr_in_momentum' in group and group['lr_in_momentum']: 127 | buf.mul_(group['momentum']).addcdiv_(group['lr'], grad, avg) 128 | p.data.add_(-buf) 129 | else: 130 | # PyTorch scales the param update by LR 131 | buf.mul_(group['momentum']).addcdiv_(grad, avg) 132 | p.data.add_(-group['lr'], buf) 133 | else: 134 | p.data.addcdiv_(-group['lr'], grad, avg) 135 | 136 | return loss 137 | -------------------------------------------------------------------------------- /timm/optim/sgdp.py: -------------------------------------------------------------------------------- 1 | """ 2 | SGDP Optimizer Implementation copied from https://github.com/clovaai/AdamP/blob/master/adamp/sgdp.py 3 | 4 | Paper: `Slowing Down the Weight Norm Increase in Momentum-based Optimizers` - https://arxiv.org/abs/2006.08217 5 | Code: https://github.com/clovaai/AdamP 6 | 7 | Copyright (c) 2020-present NAVER Corp. 
8 | MIT license 9 | """ 10 | 11 | import torch 12 | import torch.nn as nn 13 | from torch.optim.optimizer import Optimizer, required 14 | import math 15 | 16 | class SGDP(Optimizer): 17 | def __init__(self, params, lr=required, momentum=0, dampening=0, 18 | weight_decay=0, nesterov=False, eps=1e-8, delta=0.1, wd_ratio=0.1): 19 | defaults = dict(lr=lr, momentum=momentum, dampening=dampening, weight_decay=weight_decay, 20 | nesterov=nesterov, eps=eps, delta=delta, wd_ratio=wd_ratio) 21 | super(SGDP, self).__init__(params, defaults) 22 | 23 | def _channel_view(self, x): 24 | return x.view(x.size(0), -1) 25 | 26 | def _layer_view(self, x): 27 | return x.view(1, -1) 28 | 29 | def _cosine_similarity(self, x, y, eps, view_func): 30 | x = view_func(x) 31 | y = view_func(y) 32 | 33 | x_norm = x.norm(dim=1).add_(eps) 34 | y_norm = y.norm(dim=1).add_(eps) 35 | dot = (x * y).sum(dim=1) 36 | 37 | return dot.abs() / x_norm / y_norm 38 | 39 | def _projection(self, p, grad, perturb, delta, wd_ratio, eps): 40 | wd = 1 41 | expand_size = [-1] + [1] * (len(p.shape) - 1) 42 | for view_func in [self._channel_view, self._layer_view]: 43 | 44 | cosine_sim = self._cosine_similarity(grad, p.data, eps, view_func) 45 | 46 | if cosine_sim.max() < delta / math.sqrt(view_func(p.data).size(1)): 47 | p_n = p.data / view_func(p.data).norm(dim=1).view(expand_size).add_(eps) 48 | perturb -= p_n * view_func(p_n * perturb).sum(dim=1).view(expand_size) 49 | wd = wd_ratio 50 | 51 | return perturb, wd 52 | 53 | return perturb, wd 54 | 55 | def step(self, closure=None): 56 | loss = None 57 | if closure is not None: 58 | loss = closure() 59 | 60 | for group in self.param_groups: 61 | weight_decay = group['weight_decay'] 62 | momentum = group['momentum'] 63 | dampening = group['dampening'] 64 | nesterov = group['nesterov'] 65 | 66 | for p in group['params']: 67 | if p.grad is None: 68 | continue 69 | grad = p.grad.data 70 | state = self.state[p] 71 | 72 | # State initialization 73 | if len(state) == 0: 74 | state['momentum'] = torch.zeros_like(p.data) 75 | 76 | # SGD 77 | buf = state['momentum'] 78 | buf.mul_(momentum).add_(1 - dampening, grad) 79 | if nesterov: 80 | d_p = grad + momentum * buf 81 | else: 82 | d_p = buf 83 | 84 | # Projection 85 | wd_ratio = 1 86 | if len(p.shape) > 1: 87 | d_p, wd_ratio = self._projection(p, grad, d_p, group['delta'], group['wd_ratio'], group['eps']) 88 | 89 | # Weight decay 90 | if weight_decay != 0: 91 | p.data.mul_(1 - group['lr'] * group['weight_decay'] * wd_ratio / (1-momentum)) 92 | 93 | # Step 94 | p.data.add_(-group['lr'], d_p) 95 | 96 | return loss 97 | -------------------------------------------------------------------------------- /timm/scheduler/__init__.py: -------------------------------------------------------------------------------- 1 | from .cosine_lr import CosineLRScheduler 2 | from .plateau_lr import PlateauLRScheduler 3 | from .step_lr import StepLRScheduler 4 | from .tanh_lr import TanhLRScheduler 5 | from .scheduler_factory import create_scheduler 6 | -------------------------------------------------------------------------------- /timm/scheduler/cosine_lr.py: -------------------------------------------------------------------------------- 1 | """ Cosine Scheduler 2 | 3 | Cosine LR schedule with warmup, cycle/restarts, noise. 
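
The schedule implemented below is SGDR-style: within cycle i of length t_i, the LR follows lr_min + 0.5 * (lr_max - lr_min) * (1 + cos(pi * t_curr / t_i)), with lr_max decayed by decay_rate at each restart. A standalone sketch of the single-cycle formula (names are illustrative):

import math

def cosine_lr(t_curr, t_i, lr_min, lr_max):
    # cosine interpolation from lr_max at t_curr=0 down to lr_min at t_curr=t_i
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * t_curr / t_i))
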
4 | 5 | Hacked together by / Copyright 2020 Ross Wightman 6 | """ 7 | import logging 8 | import math 9 | import numpy as np 10 | import torch 11 | 12 | from .scheduler import Scheduler 13 | 14 | 15 | _logger = logging.getLogger(__name__) 16 | 17 | 18 | class CosineLRScheduler(Scheduler): 19 | """ 20 | Cosine decay with restarts. 21 | This is described in the paper https://arxiv.org/abs/1608.03983. 22 | 23 | Inspiration from 24 | https://github.com/allenai/allennlp/blob/master/allennlp/training/learning_rate_schedulers/cosine.py 25 | """ 26 | 27 | def __init__(self, 28 | optimizer: torch.optim.Optimizer, 29 | t_initial: int, 30 | t_mul: float = 1., 31 | lr_min: float = 0., 32 | decay_rate: float = 1., 33 | warmup_t=0, 34 | warmup_lr_init=0, 35 | warmup_prefix=False, 36 | cycle_limit=0, 37 | t_in_epochs=True, 38 | noise_range_t=None, 39 | noise_pct=0.67, 40 | noise_std=1.0, 41 | noise_seed=42, 42 | initialize=True) -> None: 43 | super().__init__( 44 | optimizer, param_group_field="lr", 45 | noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, 46 | initialize=initialize) 47 | 48 | assert t_initial > 0 49 | assert lr_min >= 0 50 | if t_initial == 1 and t_mul == 1 and decay_rate == 1: 51 | _logger.warning("Cosine annealing scheduler will have no effect on the learning " 52 | "rate since t_initial = t_mul = eta_mul = 1.") 53 | self.t_initial = t_initial 54 | self.t_mul = t_mul 55 | self.lr_min = lr_min 56 | self.decay_rate = decay_rate 57 | self.cycle_limit = cycle_limit 58 | self.warmup_t = warmup_t 59 | self.warmup_lr_init = warmup_lr_init 60 | self.warmup_prefix = warmup_prefix 61 | self.t_in_epochs = t_in_epochs 62 | if self.warmup_t: 63 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] 64 | super().update_groups(self.warmup_lr_init) 65 | else: 66 | self.warmup_steps = [1 for _ in self.base_values] 67 | 68 | def _get_lr(self, t): 69 | if t < self.warmup_t: 70 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] 71 | else: 72 | if self.warmup_prefix: 73 | t = t - self.warmup_t 74 | 75 | if self.t_mul != 1: 76 | i = math.floor(math.log(1 - t / self.t_initial * (1 - self.t_mul), self.t_mul)) 77 | t_i = self.t_mul ** i * self.t_initial 78 | t_curr = t - (1 - self.t_mul ** i) / (1 - self.t_mul) * self.t_initial 79 | else: 80 | i = t // self.t_initial 81 | t_i = self.t_initial 82 | t_curr = t - (self.t_initial * i) 83 | 84 | gamma = self.decay_rate ** i 85 | lr_min = self.lr_min * gamma 86 | lr_max_values = [v * gamma for v in self.base_values] 87 | 88 | if self.cycle_limit == 0 or (self.cycle_limit > 0 and i < self.cycle_limit): 89 | lrs = [ 90 | lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * t_curr / t_i)) for lr_max in lr_max_values 91 | ] 92 | else: 93 | lrs = [self.lr_min for _ in self.base_values] 94 | 95 | return lrs 96 | 97 | def get_epoch_values(self, epoch: int): 98 | if self.t_in_epochs: 99 | return self._get_lr(epoch) 100 | else: 101 | return None 102 | 103 | def get_update_values(self, num_updates: int): 104 | if not self.t_in_epochs: 105 | return self._get_lr(num_updates) 106 | else: 107 | return None 108 | 109 | def get_cycle_length(self, cycles=0): 110 | if not cycles: 111 | cycles = self.cycle_limit 112 | cycles = max(1, cycles) 113 | if self.t_mul == 1.0: 114 | return self.t_initial * cycles 115 | else: 116 | return int(math.floor(-self.t_initial * (self.t_mul ** cycles - 1) / (1 - self.t_mul))) 117 | 
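
Putting it together, a minimal epoch-based training-loop sketch using the scheduler directly (the model and hyper-parameters are placeholders; the epoch-indexed `step` is assumed to come from the `Scheduler` base class):

import torch
from timm.scheduler.cosine_lr import CosineLRScheduler

model = torch.nn.Linear(10, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
sched = CosineLRScheduler(opt, t_initial=30, lr_min=1e-5, warmup_t=3, warmup_lr_init=1e-4)

for epoch in range(30):
    # ... train one epoch ...
    sched.step(epoch + 1)  # timm schedulers are stepped with an explicit epoch index
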
-------------------------------------------------------------------------------- /timm/scheduler/plateau_lr.py: -------------------------------------------------------------------------------- 1 | """ Plateau Scheduler 2 | 3 | Adapts PyTorch plateau scheduler and allows application of noise, warmup. 4 | 5 | Hacked together by / Copyright 2020 Ross Wightman 6 | """ 7 | import torch 8 | 9 | from .scheduler import Scheduler 10 | 11 | 12 | class PlateauLRScheduler(Scheduler): 13 | """Decay the LR by a factor every time the validation loss plateaus.""" 14 | 15 | def __init__(self, 16 | optimizer, 17 | decay_rate=0.1, 18 | patience_t=10, 19 | verbose=True, 20 | threshold=1e-4, 21 | cooldown_t=0, 22 | warmup_t=0, 23 | warmup_lr_init=0, 24 | lr_min=0, 25 | mode='max', 26 | noise_range_t=None, 27 | noise_type='normal', 28 | noise_pct=0.67, 29 | noise_std=1.0, 30 | noise_seed=None, 31 | initialize=True, 32 | ): 33 | super().__init__(optimizer, 'lr', initialize=initialize) 34 | 35 | self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( 36 | self.optimizer, 37 | patience=patience_t, 38 | factor=decay_rate, 39 | verbose=verbose, 40 | threshold=threshold, 41 | cooldown=cooldown_t, 42 | mode=mode, 43 | min_lr=lr_min 44 | ) 45 | 46 | self.noise_range = noise_range_t 47 | self.noise_pct = noise_pct 48 | self.noise_type = noise_type 49 | self.noise_std = noise_std 50 | self.noise_seed = noise_seed if noise_seed is not None else 42 51 | self.warmup_t = warmup_t 52 | self.warmup_lr_init = warmup_lr_init 53 | if self.warmup_t: 54 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] 55 | super().update_groups(self.warmup_lr_init) 56 | else: 57 | self.warmup_steps = [1 for _ in self.base_values] 58 | self.restore_lr = None 59 | 60 | def state_dict(self): 61 | return { 62 | 'best': self.lr_scheduler.best, 63 | 'last_epoch': self.lr_scheduler.last_epoch, 64 | } 65 | 66 | def load_state_dict(self, state_dict): 67 | self.lr_scheduler.best = state_dict['best'] 68 | if 'last_epoch' in state_dict: 69 | self.lr_scheduler.last_epoch = state_dict['last_epoch'] 70 | 71 | # override the base class step fn completely 72 | def step(self, epoch, metric=None): 73 | if epoch <= self.warmup_t: 74 | lrs = [self.warmup_lr_init + epoch * s for s in self.warmup_steps] 75 | super().update_groups(lrs) 76 | else: 77 | if self.restore_lr is not None: 78 | # restore actual LR from before our last noise perturbation before stepping base 79 | for i, param_group in enumerate(self.optimizer.param_groups): 80 | param_group['lr'] = self.restore_lr[i] 81 | self.restore_lr = None 82 | 83 | self.lr_scheduler.step(metric, epoch) # step the base scheduler 84 | 85 | if self.noise_range is not None: 86 | if isinstance(self.noise_range, (list, tuple)): 87 | apply_noise = self.noise_range[0] <= epoch < self.noise_range[1] 88 | else: 89 | apply_noise = epoch >= self.noise_range 90 | if apply_noise: 91 | self._apply_noise(epoch) 92 | 93 | def _apply_noise(self, epoch): 94 | g = torch.Generator() 95 | g.manual_seed(self.noise_seed + epoch) 96 | if self.noise_type == 'normal': 97 | while True: 98 | # resample if noise out of percent limit, brute force but shouldn't spin much 99 | noise = torch.randn(1, generator=g).item() 100 | if abs(noise) < self.noise_pct: 101 | break 102 | else: 103 | noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct 104 | 105 | # apply the noise on top of previous LR, cache the old value so we can restore for normal 106 | # stepping of base scheduler 107 | restore_lr = [] 
108 | for i, param_group in enumerate(self.optimizer.param_groups): 109 | old_lr = float(param_group['lr']) 110 | restore_lr.append(old_lr) 111 | new_lr = old_lr + old_lr * noise 112 | param_group['lr'] = new_lr 113 | self.restore_lr = restore_lr 114 | -------------------------------------------------------------------------------- /timm/scheduler/scheduler.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Any 2 | 3 | import torch 4 | 5 | 6 | class Scheduler: 7 | """ Parameter Scheduler Base Class 8 | A scheduler base class that can be used to schedule any optimizer parameter groups. 9 | 10 | Unlike the builtin PyTorch schedulers, this is intended to be consistently called 11 | * At the END of each epoch, before incrementing the epoch count, to calculate next epoch's value 12 | * At the END of each optimizer update, after incrementing the update count, to calculate next update's value 13 | 14 | The schedulers built on this should try to remain as stateless as possible (for simplicity). 15 | 16 | This family of schedulers is attempting to avoid the confusion of the meaning of 'last_epoch' 17 | and -1 values for special behaviour. All epoch and update counts must be tracked in the training 18 | code and explicitly passed in to the schedulers on the corresponding step or step_update call. 19 | 20 | Based on ideas from: 21 | * https://github.com/pytorch/fairseq/tree/master/fairseq/optim/lr_scheduler 22 | * https://github.com/allenai/allennlp/tree/master/allennlp/training/learning_rate_schedulers 23 | """ 24 | 25 | def __init__(self, 26 | optimizer: torch.optim.Optimizer, 27 | param_group_field: str, 28 | noise_range_t=None, 29 | noise_type='normal', 30 | noise_pct=0.67, 31 | noise_std=1.0, 32 | noise_seed=None, 33 | initialize: bool = True) -> None: 34 | self.optimizer = optimizer 35 | self.param_group_field = param_group_field 36 | self._initial_param_group_field = f"initial_{param_group_field}" 37 | if initialize: 38 | for i, group in enumerate(self.optimizer.param_groups): 39 | if param_group_field not in group: 40 | raise KeyError(f"{param_group_field} missing from param_groups[{i}]") 41 | group.setdefault(self._initial_param_group_field, group[param_group_field]) 42 | else: 43 | for i, group in enumerate(self.optimizer.param_groups): 44 | if self._initial_param_group_field not in group: 45 | raise KeyError(f"{self._initial_param_group_field} missing from param_groups[{i}]") 46 | self.base_values = [group[self._initial_param_group_field] for group in self.optimizer.param_groups] 47 | self.metric = None # any point to having this for all? 
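        # Noise settings: when noise_range_t is set, scheduled values get randomly
        # perturbed within that step/epoch range (see _add_noise below).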
48 | self.noise_range_t = noise_range_t 49 | self.noise_pct = noise_pct 50 | self.noise_type = noise_type 51 | self.noise_std = noise_std 52 | self.noise_seed = noise_seed if noise_seed is not None else 42 53 | self.update_groups(self.base_values) 54 | 55 | def state_dict(self) -> Dict[str, Any]: 56 | return {key: value for key, value in self.__dict__.items() if key != 'optimizer'} 57 | 58 | def load_state_dict(self, state_dict: Dict[str, Any]) -> None: 59 | self.__dict__.update(state_dict) 60 | 61 | def get_epoch_values(self, epoch: int): 62 | return None 63 | 64 | def get_update_values(self, num_updates: int): 65 | return None 66 | 67 | def step(self, epoch: int, metric: float = None) -> None: 68 | self.metric = metric 69 | values = self.get_epoch_values(epoch) 70 | if values is not None: 71 | values = self._add_noise(values, epoch) 72 | self.update_groups(values) 73 | 74 | def step_update(self, num_updates: int, metric: float = None): 75 | self.metric = metric 76 | values = self.get_update_values(num_updates) 77 | if values is not None: 78 | values = self._add_noise(values, num_updates) 79 | self.update_groups(values) 80 | 81 | def update_groups(self, values): 82 | if not isinstance(values, (list, tuple)): 83 | values = [values] * len(self.optimizer.param_groups) 84 | for param_group, value in zip(self.optimizer.param_groups, values): 85 | param_group[self.param_group_field] = value 86 | 87 | def _add_noise(self, lrs, t): 88 | if self.noise_range_t is not None: 89 | if isinstance(self.noise_range_t, (list, tuple)): 90 | apply_noise = self.noise_range_t[0] <= t < self.noise_range_t[1] 91 | else: 92 | apply_noise = t >= self.noise_range_t 93 | if apply_noise: 94 | g = torch.Generator() 95 | g.manual_seed(self.noise_seed + t) 96 | if self.noise_type == 'normal': 97 | while True: 98 | # resample if noise out of percent limit, brute force but shouldn't spin much 99 | noise = torch.randn(1, generator=g).item() 100 | if abs(noise) < self.noise_pct: 101 | break 102 | else: 103 | noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct 104 | lrs = [v + v * noise for v in lrs] 105 | return lrs 106 | -------------------------------------------------------------------------------- /timm/scheduler/scheduler_factory.py: -------------------------------------------------------------------------------- 1 | """ Scheduler Factory 2 | Hacked together by / Copyright 2020 Ross Wightman 3 | """ 4 | from .cosine_lr import CosineLRScheduler 5 | from .tanh_lr import TanhLRScheduler 6 | from .step_lr import StepLRScheduler 7 | from .plateau_lr import PlateauLRScheduler 8 | 9 | 10 | def create_scheduler(args, optimizer): 11 | num_epochs = args.epochs 12 | 13 | if getattr(args, 'lr_noise', None) is not None: 14 | lr_noise = getattr(args, 'lr_noise') 15 | if isinstance(lr_noise, (list, tuple)): 16 | noise_range = [n * num_epochs for n in lr_noise] 17 | if len(noise_range) == 1: 18 | noise_range = noise_range[0] 19 | else: 20 | noise_range = lr_noise * num_epochs 21 | else: 22 | noise_range = None 23 | 24 | lr_scheduler = None 25 | if args.sched == 'cosine': 26 | lr_scheduler = CosineLRScheduler( 27 | optimizer, 28 | t_initial=num_epochs, 29 | t_mul=getattr(args, 'lr_cycle_mul', 1.), 30 | lr_min=args.min_lr, 31 | decay_rate=args.decay_rate, 32 | warmup_lr_init=args.warmup_lr, 33 | warmup_t=args.warmup_epochs, 34 | cycle_limit=getattr(args, 'lr_cycle_limit', 1), 35 | t_in_epochs=True, 36 | noise_range_t=noise_range, 37 | noise_pct=getattr(args, 'lr_noise_pct', 0.67), 38 | noise_std=getattr(args, 
'lr_noise_std', 1.), 39 | noise_seed=getattr(args, 'seed', 42), 40 | ) 41 | num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs 42 | elif args.sched == 'tanh': 43 | lr_scheduler = TanhLRScheduler( 44 | optimizer, 45 | t_initial=num_epochs, 46 | t_mul=getattr(args, 'lr_cycle_mul', 1.), 47 | lr_min=args.min_lr, 48 | warmup_lr_init=args.warmup_lr, 49 | warmup_t=args.warmup_epochs, 50 | cycle_limit=getattr(args, 'lr_cycle_limit', 1), 51 | t_in_epochs=True, 52 | noise_range_t=noise_range, 53 | noise_pct=getattr(args, 'lr_noise_pct', 0.67), 54 | noise_std=getattr(args, 'lr_noise_std', 1.), 55 | noise_seed=getattr(args, 'seed', 42), 56 | ) 57 | num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs 58 | elif args.sched == 'step': 59 | lr_scheduler = StepLRScheduler( 60 | optimizer, 61 | decay_t=args.decay_epochs, 62 | decay_rate=args.decay_rate, 63 | warmup_lr_init=args.warmup_lr, 64 | warmup_t=args.warmup_epochs, 65 | noise_range_t=noise_range, 66 | noise_pct=getattr(args, 'lr_noise_pct', 0.67), 67 | noise_std=getattr(args, 'lr_noise_std', 1.), 68 | noise_seed=getattr(args, 'seed', 42), 69 | ) 70 | elif args.sched == 'plateau': 71 | mode = 'min' if 'loss' in getattr(args, 'eval_metric', '') else 'max' 72 | lr_scheduler = PlateauLRScheduler( 73 | optimizer, 74 | decay_rate=args.decay_rate, 75 | patience_t=args.patience_epochs, 76 | lr_min=args.min_lr, 77 | mode=mode, 78 | warmup_lr_init=args.warmup_lr, 79 | warmup_t=args.warmup_epochs, 80 | cooldown_t=0, 81 | noise_range_t=noise_range, 82 | noise_pct=getattr(args, 'lr_noise_pct', 0.67), 83 | noise_std=getattr(args, 'lr_noise_std', 1.), 84 | noise_seed=getattr(args, 'seed', 42), 85 | ) 86 | 87 | return lr_scheduler, num_epochs 88 | -------------------------------------------------------------------------------- /timm/scheduler/step_lr.py: -------------------------------------------------------------------------------- 1 | """ Step Scheduler 2 | 3 | Basic step LR schedule with warmup, noise. 
4 | 
5 | Hacked together by / Copyright 2020 Ross Wightman
6 | """
7 | import math
8 | import torch
9 | 
10 | from .scheduler import Scheduler
11 | 
12 | 
13 | class StepLRScheduler(Scheduler):
14 |     """ Step LR schedule with warmup, noise.
15 |     """
16 | 
17 |     def __init__(self,
18 |                  optimizer: torch.optim.Optimizer,
19 |                  decay_t: float,
20 |                  decay_rate: float = 1.,
21 |                  warmup_t=0,
22 |                  warmup_lr_init=0,
23 |                  t_in_epochs=True,
24 |                  noise_range_t=None,
25 |                  noise_pct=0.67,
26 |                  noise_std=1.0,
27 |                  noise_seed=42,
28 |                  initialize=True,
29 |                  ) -> None:
30 |         super().__init__(
31 |             optimizer, param_group_field="lr",
32 |             noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
33 |             initialize=initialize)
34 | 
35 |         self.decay_t = decay_t
36 |         self.decay_rate = decay_rate
37 |         self.warmup_t = warmup_t
38 |         self.warmup_lr_init = warmup_lr_init
39 |         self.t_in_epochs = t_in_epochs
40 |         if self.warmup_t:
41 |             self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
42 |             super().update_groups(self.warmup_lr_init)
43 |         else:
44 |             self.warmup_steps = [1 for _ in self.base_values]
45 | 
46 |     def _get_lr(self, t):
47 |         if t < self.warmup_t:
48 |             lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
49 |         else:
50 |             lrs = [v * (self.decay_rate ** (t // self.decay_t)) for v in self.base_values]
51 |         return lrs
52 | 
53 |     def get_epoch_values(self, epoch: int):
54 |         if self.t_in_epochs:
55 |             return self._get_lr(epoch)
56 |         else:
57 |             return None
58 | 
59 |     def get_update_values(self, num_updates: int):
60 |         if not self.t_in_epochs:
61 |             return self._get_lr(num_updates)
62 |         else:
63 |             return None
64 | 
--------------------------------------------------------------------------------
/timm/scheduler/tanh_lr.py:
--------------------------------------------------------------------------------
1 | """ TanH Scheduler
2 | 
3 | TanH schedule with warmup, cycle/restarts, noise.
4 | 
5 | Hacked together by / Copyright 2020 Ross Wightman
6 | """
7 | import logging
8 | import math
9 | import numpy as np
10 | import torch
11 | 
12 | from .scheduler import Scheduler
13 | 
14 | 
15 | _logger = logging.getLogger(__name__)
16 | 
17 | 
18 | class TanhLRScheduler(Scheduler):
19 |     """
20 |     Hyperbolic-Tangent decay with restarts.
21 | This is described in the paper https://arxiv.org/abs/1806.01593 22 | """ 23 | 24 | def __init__(self, 25 | optimizer: torch.optim.Optimizer, 26 | t_initial: int, 27 | lb: float = -6., 28 | ub: float = 4., 29 | t_mul: float = 1., 30 | lr_min: float = 0., 31 | decay_rate: float = 1., 32 | warmup_t=0, 33 | warmup_lr_init=0, 34 | warmup_prefix=False, 35 | cycle_limit=0, 36 | t_in_epochs=True, 37 | noise_range_t=None, 38 | noise_pct=0.67, 39 | noise_std=1.0, 40 | noise_seed=42, 41 | initialize=True) -> None: 42 | super().__init__( 43 | optimizer, param_group_field="lr", 44 | noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, 45 | initialize=initialize) 46 | 47 | assert t_initial > 0 48 | assert lr_min >= 0 49 | assert lb < ub 50 | assert cycle_limit >= 0 51 | assert warmup_t >= 0 52 | assert warmup_lr_init >= 0 53 | self.lb = lb 54 | self.ub = ub 55 | self.t_initial = t_initial 56 | self.t_mul = t_mul 57 | self.lr_min = lr_min 58 | self.decay_rate = decay_rate 59 | self.cycle_limit = cycle_limit 60 | self.warmup_t = warmup_t 61 | self.warmup_lr_init = warmup_lr_init 62 | self.warmup_prefix = warmup_prefix 63 | self.t_in_epochs = t_in_epochs 64 | if self.warmup_t: 65 | t_v = self.base_values if self.warmup_prefix else self._get_lr(self.warmup_t) 66 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in t_v] 67 | super().update_groups(self.warmup_lr_init) 68 | else: 69 | self.warmup_steps = [1 for _ in self.base_values] 70 | 71 | def _get_lr(self, t): 72 | if t < self.warmup_t: 73 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] 74 | else: 75 | if self.warmup_prefix: 76 | t = t - self.warmup_t 77 | 78 | if self.t_mul != 1: 79 | i = math.floor(math.log(1 - t / self.t_initial * (1 - self.t_mul), self.t_mul)) 80 | t_i = self.t_mul ** i * self.t_initial 81 | t_curr = t - (1 - self.t_mul ** i) / (1 - self.t_mul) * self.t_initial 82 | else: 83 | i = t // self.t_initial 84 | t_i = self.t_initial 85 | t_curr = t - (self.t_initial * i) 86 | 87 | if self.cycle_limit == 0 or (self.cycle_limit > 0 and i < self.cycle_limit): 88 | gamma = self.decay_rate ** i 89 | lr_min = self.lr_min * gamma 90 | lr_max_values = [v * gamma for v in self.base_values] 91 | 92 | tr = t_curr / t_i 93 | lrs = [ 94 | lr_min + 0.5 * (lr_max - lr_min) * (1 - math.tanh(self.lb * (1. 
- tr) + self.ub * tr)) 95 | for lr_max in lr_max_values 96 | ] 97 | else: 98 | lrs = [self.lr_min * (self.decay_rate ** self.cycle_limit) for _ in self.base_values] 99 | return lrs 100 | 101 | def get_epoch_values(self, epoch: int): 102 | if self.t_in_epochs: 103 | return self._get_lr(epoch) 104 | else: 105 | return None 106 | 107 | def get_update_values(self, num_updates: int): 108 | if not self.t_in_epochs: 109 | return self._get_lr(num_updates) 110 | else: 111 | return None 112 | 113 | def get_cycle_length(self, cycles=0): 114 | if not cycles: 115 | cycles = self.cycle_limit 116 | cycles = max(1, cycles) 117 | if self.t_mul == 1.0: 118 | return self.t_initial * cycles 119 | else: 120 | return int(math.floor(-self.t_initial * (self.t_mul ** cycles - 1) / (1 - self.t_mul))) 121 | -------------------------------------------------------------------------------- /timm/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .agc import adaptive_clip_grad 2 | from .checkpoint_saver import CheckpointSaver 3 | from .clip_grad import dispatch_clip_grad 4 | from .cuda import ApexScaler, NativeScaler 5 | from .distributed import distribute_bn, reduce_tensor 6 | from .jit import set_jit_legacy 7 | from .log import setup_default_logging, FormatterNoInfo 8 | from .metrics import AverageMeter, accuracy 9 | from .misc import natural_key, add_bool_arg 10 | from .model import unwrap_model, get_state_dict 11 | from .model_ema import ModelEma, ModelEmaV2 12 | from .random import random_seed 13 | from .summary import update_summary, get_outdir 14 | -------------------------------------------------------------------------------- /timm/utils/agc.py: -------------------------------------------------------------------------------- 1 | """ Adaptive Gradient Clipping 2 | 3 | An impl of AGC, as per (https://arxiv.org/abs/2102.06171): 4 | 5 | @article{brock2021high, 6 | author={Andrew Brock and Soham De and Samuel L. 
Smith and Karen Simonyan},
7 |     title={High-Performance Large-Scale Image Recognition Without Normalization},
8 |     journal={arXiv preprint arXiv:2102.06171},
9 |     year={2021}
10 | }
11 | 
12 | Code references:
13 | * Official JAX impl (paper authors): https://github.com/deepmind/deepmind-research/tree/master/nfnets
14 | * Phil Wang's PyTorch gist: https://gist.github.com/lucidrains/0d6560077edac419ab5d3aa29e674d5c
15 | 
16 | Hacked together by / Copyright 2021 Ross Wightman
17 | """
18 | import torch
19 | 
20 | 
21 | def unitwise_norm(x, norm_type=2.0):
22 |     if x.ndim <= 1:
23 |         return x.norm(norm_type)
24 |     else:
25 |         # works for nn.ConvNd and nn.Linear where output dim is first in the kernel/weight tensor
26 |         # might need special cases for other weights (possibly MHA) where this may not be true
27 |         return x.norm(norm_type, dim=tuple(range(1, x.ndim)), keepdim=True)
28 | 
29 | 
30 | def adaptive_clip_grad(parameters, clip_factor=0.01, eps=1e-3, norm_type=2.0):
31 |     if isinstance(parameters, torch.Tensor):
32 |         parameters = [parameters]
33 |     for p in parameters:
34 |         if p.grad is None:
35 |             continue
36 |         p_data = p.detach()
37 |         g_data = p.grad.detach()
38 |         max_norm = unitwise_norm(p_data, norm_type=norm_type).clamp_(min=eps).mul_(clip_factor)
39 |         grad_norm = unitwise_norm(g_data, norm_type=norm_type)
40 |         clipped_grad = g_data * (max_norm / grad_norm.clamp(min=1e-6))
41 |         new_grads = torch.where(grad_norm < max_norm, g_data, clipped_grad)
42 |         p.grad.detach().copy_(new_grads)
43 | 
--------------------------------------------------------------------------------
/timm/utils/checkpoint_saver.py:
--------------------------------------------------------------------------------
1 | """ Checkpoint Saver
2 | 
3 | Track top-n training checkpoints and maintain recovery checkpoints on specified intervals.
4 | 5 | Hacked together by / Copyright 2020 Ross Wightman 6 | """ 7 | 8 | import glob 9 | import operator 10 | import os 11 | import logging 12 | 13 | import torch 14 | 15 | from .model import unwrap_model, get_state_dict 16 | 17 | 18 | _logger = logging.getLogger(__name__) 19 | 20 | 21 | class CheckpointSaver: 22 | def __init__( 23 | self, 24 | model, 25 | optimizer, 26 | args=None, 27 | model_ema=None, 28 | amp_scaler=None, 29 | checkpoint_prefix='checkpoint', 30 | recovery_prefix='recovery', 31 | checkpoint_dir='', 32 | recovery_dir='', 33 | decreasing=False, 34 | max_history=10, 35 | unwrap_fn=unwrap_model): 36 | 37 | # objects to save state_dicts of 38 | self.model = model 39 | self.optimizer = optimizer 40 | self.args = args 41 | self.model_ema = model_ema 42 | self.amp_scaler = amp_scaler 43 | 44 | # state 45 | self.checkpoint_files = [] # (filename, metric) tuples in order of decreasing betterness 46 | self.best_epoch = None 47 | self.best_metric = None 48 | self.curr_recovery_file = '' 49 | self.last_recovery_file = '' 50 | 51 | # config 52 | self.checkpoint_dir = checkpoint_dir 53 | self.recovery_dir = recovery_dir 54 | self.save_prefix = checkpoint_prefix 55 | self.recovery_prefix = recovery_prefix 56 | self.extension = '.pth.tar' 57 | self.decreasing = decreasing # a lower metric is better if True 58 | self.cmp = operator.lt if decreasing else operator.gt # True if lhs better than rhs 59 | self.max_history = max_history 60 | self.unwrap_fn = unwrap_fn 61 | assert self.max_history >= 1 62 | 63 | def save_checkpoint(self, epoch, metric=None): 64 | assert epoch >= 0 65 | tmp_save_path = os.path.join(self.checkpoint_dir, 'tmp' + self.extension) 66 | last_save_path = os.path.join(self.checkpoint_dir, 'last' + self.extension) 67 | self._save(tmp_save_path, epoch, metric) 68 | if os.path.exists(last_save_path): 69 | os.unlink(last_save_path) # required for Windows support. 
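        # Writing to a tmp file and renaming makes the 'last' checkpoint update safe
        # against partial writes; the unlink above is needed because os.rename() cannot
        # replace an existing file on Windows.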
70 | os.rename(tmp_save_path, last_save_path) 71 | worst_file = self.checkpoint_files[-1] if self.checkpoint_files else None 72 | if (len(self.checkpoint_files) < self.max_history 73 | or metric is None or self.cmp(metric, worst_file[1])): 74 | if len(self.checkpoint_files) >= self.max_history: 75 | self._cleanup_checkpoints(1) 76 | filename = '-'.join([self.save_prefix, str(epoch)]) + self.extension 77 | save_path = os.path.join(self.checkpoint_dir, filename) 78 | os.link(last_save_path, save_path) 79 | self.checkpoint_files.append((save_path, metric)) 80 | self.checkpoint_files = sorted( 81 | self.checkpoint_files, key=lambda x: x[1], 82 | reverse=not self.decreasing) # sort in descending order if a lower metric is not better 83 | 84 | checkpoints_str = "Current checkpoints:\n" 85 | for c in self.checkpoint_files: 86 | checkpoints_str += ' {}\n'.format(c) 87 | _logger.info(checkpoints_str) 88 | 89 | if metric is not None and (self.best_metric is None or self.cmp(metric, self.best_metric)): 90 | self.best_epoch = epoch 91 | self.best_metric = metric 92 | best_save_path = os.path.join(self.checkpoint_dir, 'model_best' + self.extension) 93 | if os.path.exists(best_save_path): 94 | os.unlink(best_save_path) 95 | os.link(last_save_path, best_save_path) 96 | 97 | return (None, None) if self.best_metric is None else (self.best_metric, self.best_epoch) 98 | 99 | def _save(self, save_path, epoch, metric=None): 100 | save_state = { 101 | 'epoch': epoch, 102 | 'arch': type(self.model).__name__.lower(), 103 | 'state_dict': get_state_dict(self.model, self.unwrap_fn), 104 | 'optimizer': self.optimizer.state_dict(), 105 | 'version': 2, # version < 2 increments epoch before save 106 | } 107 | if self.args is not None: 108 | save_state['arch'] = self.args.model 109 | save_state['args'] = self.args 110 | if self.amp_scaler is not None: 111 | save_state[self.amp_scaler.state_dict_key] = self.amp_scaler.state_dict() 112 | if self.model_ema is not None: 113 | save_state['state_dict_ema'] = get_state_dict(self.model_ema, self.unwrap_fn) 114 | if metric is not None: 115 | save_state['metric'] = metric 116 | torch.save(save_state, save_path) 117 | 118 | def _cleanup_checkpoints(self, trim=0): 119 | trim = min(len(self.checkpoint_files), trim) 120 | delete_index = self.max_history - trim 121 | if delete_index < 0 or len(self.checkpoint_files) <= delete_index: 122 | return 123 | to_delete = self.checkpoint_files[delete_index:] 124 | for d in to_delete: 125 | try: 126 | _logger.debug("Cleaning checkpoint: {}".format(d)) 127 | os.remove(d[0]) 128 | except Exception as e: 129 | _logger.error("Exception '{}' while deleting checkpoint".format(e)) 130 | self.checkpoint_files = self.checkpoint_files[:delete_index] 131 | 132 | def save_recovery(self, epoch, batch_idx=0): 133 | assert epoch >= 0 134 | filename = '-'.join([self.recovery_prefix, str(epoch), str(batch_idx)]) + self.extension 135 | save_path = os.path.join(self.recovery_dir, filename) 136 | self._save(save_path, epoch) 137 | if os.path.exists(self.last_recovery_file): 138 | try: 139 | _logger.debug("Cleaning recovery: {}".format(self.last_recovery_file)) 140 | os.remove(self.last_recovery_file) 141 | except Exception as e: 142 | _logger.error("Exception '{}' while removing {}".format(e, self.last_recovery_file)) 143 | self.last_recovery_file = self.curr_recovery_file 144 | self.curr_recovery_file = save_path 145 | 146 | def find_recovery(self): 147 | recovery_path = os.path.join(self.recovery_dir, self.recovery_prefix) 148 | files = glob.glob(recovery_path + 
'*' + self.extension)
149 |         files = sorted(files)
150 |         return files[0] if len(files) else ''
151 | 
--------------------------------------------------------------------------------
/timm/utils/clip_grad.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | from timm.utils.agc import adaptive_clip_grad
4 | 
5 | 
6 | def dispatch_clip_grad(parameters, value: float, mode: str = 'norm', norm_type: float = 2.0):
7 |     """ Dispatch to gradient clipping method
8 | 
9 |     Args:
10 |         parameters (Iterable): model parameters to clip
11 |         value (float): clipping value/factor/norm, mode dependent
12 |         mode (str): clipping mode, one of 'norm', 'value', 'agc'
13 |         norm_type (float): p-norm, default 2.0
14 |     """
15 |     if mode == 'norm':
16 |         torch.nn.utils.clip_grad_norm_(parameters, value, norm_type=norm_type)
17 |     elif mode == 'value':
18 |         torch.nn.utils.clip_grad_value_(parameters, value)
19 |     elif mode == 'agc':
20 |         adaptive_clip_grad(parameters, value, norm_type=norm_type)
21 |     else:
22 |         assert False, f"Unknown clip mode ({mode})."
23 | 
--------------------------------------------------------------------------------
/timm/utils/cuda.py:
--------------------------------------------------------------------------------
1 | """ CUDA / AMP utils
2 | 
3 | Hacked together by / Copyright 2020 Ross Wightman
4 | """
5 | import torch
6 | 
7 | try:
8 |     from apex import amp
9 |     has_apex = True
10 | except ImportError:
11 |     amp = None
12 |     has_apex = False
13 | 
14 | from .clip_grad import dispatch_clip_grad
15 | 
16 | 
17 | class ApexScaler:
18 |     state_dict_key = "amp"
19 | 
20 |     def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False):
21 |         with amp.scale_loss(loss, optimizer) as scaled_loss:
22 |             scaled_loss.backward(create_graph=create_graph)
23 |         if clip_grad is not None:
24 |             dispatch_clip_grad(amp.master_params(optimizer), clip_grad, mode=clip_mode)
25 |         optimizer.step()
26 | 
27 |     def state_dict(self):
28 |         if 'state_dict' in amp.__dict__:
29 |             return amp.state_dict()
30 | 
31 |     def load_state_dict(self, state_dict):
32 |         if 'load_state_dict' in amp.__dict__:
33 |             amp.load_state_dict(state_dict)
34 | 
35 | 
36 | class NativeScaler:
37 |     state_dict_key = "amp_scaler"
38 | 
39 |     def __init__(self):
40 |         self._scaler = torch.cuda.amp.GradScaler()
41 | 
42 |     def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False):
43 |         self._scaler.scale(loss).backward(create_graph=create_graph)
44 |         if clip_grad is not None:
45 |             assert parameters is not None
46 |             self._scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
47 |             dispatch_clip_grad(parameters, clip_grad, mode=clip_mode)
48 |         self._scaler.step(optimizer)
49 |         self._scaler.update()
50 | 
51 |     def state_dict(self):
52 |         return self._scaler.state_dict()
53 | 
54 |     def load_state_dict(self, state_dict):
55 |         self._scaler.load_state_dict(state_dict)
56 | 
--------------------------------------------------------------------------------
/timm/utils/distributed.py:
--------------------------------------------------------------------------------
1 | """ Distributed training/validation utils
2 | 
3 | Hacked together by / Copyright 2020 Ross Wightman
4 | """
5 | import torch
6 | from torch import distributed as dist
7 | 
8 | from .model import unwrap_model
9 | 
10 | 
11 | def reduce_tensor(tensor, n):
12 |     rt = tensor.clone()
13 |     dist.all_reduce(rt, op=dist.ReduceOp.SUM)
14 |     rt /= n
15 |     return rt
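# Usage sketch (illustrative, assumes an initialized process group):
#   world_size = torch.distributed.get_world_size()
#   reduced_loss = reduce_tensor(loss.detach(), world_size)
# i.e. a mean all-reduce: values are summed across ranks, then divided by n.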
16 | 
17 | 
18 | def distribute_bn(model, world_size, reduce=False):
19 |     # ensure every node has the same running bn stats
20 |     for bn_name, bn_buf in unwrap_model(model).named_buffers(recurse=True):
21 |         if ('running_mean' in bn_name) or ('running_var' in bn_name):
22 |             if reduce:
23 |                 # average bn stats across whole group
24 |                 torch.distributed.all_reduce(bn_buf, op=dist.ReduceOp.SUM)
25 |                 bn_buf /= float(world_size)
26 |             else:
27 |                 # broadcast bn stats from rank 0 to whole group
28 |                 torch.distributed.broadcast(bn_buf, 0)
29 | 
--------------------------------------------------------------------------------
/timm/utils/jit.py:
--------------------------------------------------------------------------------
1 | """ JIT scripting/tracing utils
2 | 
3 | Hacked together by / Copyright 2020 Ross Wightman
4 | """
5 | import torch
6 | 
7 | 
8 | def set_jit_legacy():
9 |     """ Set JIT executor to legacy w/ support for op fusion
10 |     This is hopefully a temporary need in 1.5/1.5.1/1.6 to restore performance due to changes
11 |     in the JIT executor. These APIs are not supported so could change.
12 |     """
13 |     #
14 |     assert hasattr(torch._C, '_jit_set_profiling_executor'), "Old JIT behavior doesn't exist!"
15 |     torch._C._jit_set_profiling_executor(False)
16 |     torch._C._jit_set_profiling_mode(False)
17 |     torch._C._jit_override_can_fuse_on_gpu(True)
18 |     #torch._C._jit_set_texpr_fuser_enabled(True)
19 | 
--------------------------------------------------------------------------------
/timm/utils/log.py:
--------------------------------------------------------------------------------
1 | """ Logging helpers
2 | 
3 | Hacked together by / Copyright 2020 Ross Wightman
4 | """
5 | import logging
6 | import logging.handlers
7 | 
8 | 
9 | class FormatterNoInfo(logging.Formatter):
10 |     def __init__(self, fmt='%(levelname)s: %(message)s'):
11 |         logging.Formatter.__init__(self, fmt)
12 | 
13 |     def format(self, record):
14 |         if record.levelno == logging.INFO:
15 |             return str(record.getMessage())
16 |         return logging.Formatter.format(self, record)
17 | 
18 | 
19 | def setup_default_logging(default_level=logging.INFO, log_path=''):
20 |     console_handler = logging.StreamHandler()
21 |     console_handler.setFormatter(FormatterNoInfo())
22 |     logging.root.addHandler(console_handler)
23 |     logging.root.setLevel(default_level)
24 |     if log_path:
25 |         file_handler = logging.handlers.RotatingFileHandler(log_path, maxBytes=(1024 ** 2 * 2), backupCount=3)
26 |         file_formatter = logging.Formatter("%(asctime)s - %(name)20s: [%(levelname)8s] - %(message)s")
27 |         file_handler.setFormatter(file_formatter)
28 |         logging.root.addHandler(file_handler)
29 | 
--------------------------------------------------------------------------------
/timm/utils/metrics.py:
--------------------------------------------------------------------------------
1 | """ Eval metrics and related
2 | 
3 | Hacked together by / Copyright 2020 Ross Wightman
4 | """
5 | 
6 | 
7 | class AverageMeter:
8 |     """Computes and stores the average and current value"""
9 |     def __init__(self):
10 |         self.reset()
11 | 
12 |     def reset(self):
13 |         self.val = 0
14 |         self.avg = 0
15 |         self.sum = 0
16 |         self.count = 0
17 | 
18 |     def update(self, val, n=1):
19 |         self.val = val
20 |         self.sum += val * n
21 |         self.count += n
22 |         self.avg = self.sum / self.count
23 | 
24 | 
25 | def accuracy(output, target, topk=(1,)):
26 |     """Computes the accuracy over the k top predictions for the specified values of k"""
27 |     maxk = max(topk)
28 |     batch_size = target.size(0)
29 |     _, pred = output.topk(maxk, 1, True, True)
30 |     pred = pred.t()
31 |     correct = pred.eq(target.reshape(1, -1).expand_as(pred))
32 |     return [correct[:k].reshape(-1).float().sum(0) * 100. / batch_size for k in topk]
33 | 
--------------------------------------------------------------------------------
/timm/utils/misc.py:
--------------------------------------------------------------------------------
1 | """ Misc utils
2 | 
3 | Hacked together by / Copyright 2020 Ross Wightman
4 | """
5 | import re
6 | 
7 | 
8 | def natural_key(string_):
9 |     """See http://www.codinghorror.com/blog/archives/001018.html"""
10 |     return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
11 | 
12 | 
13 | def add_bool_arg(parser, name, default=False, help=''):
14 |     dest_name = name.replace('-', '_')
15 |     group = parser.add_mutually_exclusive_group(required=False)
16 |     group.add_argument('--' + name, dest=dest_name, action='store_true', help=help)
17 |     group.add_argument('--no-' + name, dest=dest_name, action='store_false', help=help)
18 |     parser.set_defaults(**{dest_name: default})
19 | 
--------------------------------------------------------------------------------
/timm/utils/model.py:
--------------------------------------------------------------------------------
1 | """ Model / state_dict utils
2 | 
3 | Hacked together by / Copyright 2020 Ross Wightman
4 | """
5 | from .model_ema import ModelEma
6 | import torch
7 | import fnmatch
8 | 
9 | def unwrap_model(model):
10 |     if isinstance(model, ModelEma):
11 |         return unwrap_model(model.ema)
12 |     else:
13 |         return model.module if hasattr(model, 'module') else model
14 | 
15 | 
16 | def get_state_dict(model, unwrap_fn=unwrap_model):
17 |     return unwrap_fn(model).state_dict()
18 | 
19 | 
20 | def avg_sq_ch_mean(model, input, output):
21 |     "calculate average channel square mean of output activations"
22 |     return torch.mean(output.mean(axis=[0,2,3])**2).item()
23 | 
24 | 
25 | def avg_ch_var(model, input, output):
26 |     "calculate average channel variance of output activations"
27 |     return torch.mean(output.var(axis=[0,2,3])).item()
28 | 
29 | 
30 | def avg_ch_var_residual(model, input, output):
31 |     "calculate average channel variance of output activations"
32 |     return torch.mean(output.var(axis=[0,2,3])).item()
33 | 
34 | 
35 | class ActivationStatsHook:
36 |     """Iterates through each of `model`'s modules and matches modules using unix pattern
37 |     matching based on `hook_fn_locs` and registers `hook_fn` to the module if there is
38 |     a match.
39 | 
40 |     Arguments:
41 |         model (nn.Module): model from which we will extract the activation stats
42 |         hook_fn_locs (List[str]): List of `hook_fn` locations based on Unix type string
43 |             matching with the name of model's modules.
44 |         hook_fns (List[Callable]): List of hook functions to be registered at every
45 |             module matched by `hook_fn_locs`.
46 | 
47 |     Inspiration from https://docs.fast.ai/callback.hook.html.
48 | 
49 |     Refer to https://gist.github.com/amaarora/6e56942fcb46e67ba203f3009b30d950 for an example
50 |     on how to plot Signal Propagation Plots using `ActivationStatsHook`.
51 |     """
52 | 
53 |     def __init__(self, model, hook_fn_locs, hook_fns):
54 |         self.model = model
55 |         self.hook_fn_locs = hook_fn_locs
56 |         self.hook_fns = hook_fns
57 |         if len(hook_fn_locs) != len(hook_fns):
58 |             raise ValueError("Please provide `hook_fns` for each `hook_fn_locs`, "
59 |                              "their lengths are different.")
60 |         self.stats = dict((hook_fn.__name__, []) for hook_fn in hook_fns)
61 |         for hook_fn_loc, hook_fn in zip(hook_fn_locs, hook_fns):
62 |             self.register_hook(hook_fn_loc, hook_fn)
63 | 
64 |     def _create_hook(self, hook_fn):
65 |         def append_activation_stats(module, input, output):
66 |             out = hook_fn(module, input, output)
67 |             self.stats[hook_fn.__name__].append(out)
68 |         return append_activation_stats
69 | 
70 |     def register_hook(self, hook_fn_loc, hook_fn):
71 |         for name, module in self.model.named_modules():
72 |             if not fnmatch.fnmatch(name, hook_fn_loc):
73 |                 continue
74 |             module.register_forward_hook(self._create_hook(hook_fn))
75 | 
76 | 
77 | def extract_spp_stats(model,
78 |                       hook_fn_locs,
79 |                       hook_fns,
80 |                       input_shape=[8, 3, 224, 224]):
81 |     """Extract average square channel mean and variance of activations during
82 |     forward pass to plot Signal Propagation Plots (SPP).
83 | 
84 |     Paper: https://arxiv.org/abs/2101.08692
85 | 
86 |     Example Usage: https://gist.github.com/amaarora/6e56942fcb46e67ba203f3009b30d950
87 |     """
88 |     x = torch.normal(0., 1., input_shape)
89 |     hook = ActivationStatsHook(model, hook_fn_locs=hook_fn_locs, hook_fns=hook_fns)
90 |     _ = model(x)
91 |     return hook.stats
92 | 
--------------------------------------------------------------------------------
/timm/utils/model_ema.py:
--------------------------------------------------------------------------------
1 | """ Exponential Moving Average (EMA) of model updates
2 | 
3 | Hacked together by / Copyright 2020 Ross Wightman
4 | """
5 | import logging
6 | from collections import OrderedDict
7 | from copy import deepcopy
8 | 
9 | import torch
10 | import torch.nn as nn
11 | 
12 | _logger = logging.getLogger(__name__)
13 | 
14 | 
15 | class ModelEma:
16 |     """ Model Exponential Moving Average (DEPRECATED)
17 | 
18 |     Keep a moving average of everything in the model state_dict (parameters and buffers).
19 |     This version is deprecated, it does not work with scripted models. Will be removed eventually.
20 | 
21 |     This is intended to allow functionality like
22 |     https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
23 | 
24 |     A smoothed version of the weights is necessary for some training schemes to perform well.
25 |     E.g. Google's hyper-params for training MNASNet, MobileNet-V3, EfficientNet, etc that use
26 |     RMSprop with a short 2.4-3 epoch decay period and slow LR decay rate of .96-.99 requires EMA
27 |     smoothing of weights to match results. Pay attention to the decay constant you are using
28 |     relative to your update count per epoch.
29 | 
30 |     To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory but
31 |     disable validation of the EMA weights. Validation will have to be done manually in a separate
32 |     process, or after the training stops converging.
33 | 
34 |     This class is sensitive to where it is initialized in the sequence of model init,
35 |     GPU assignment and distributed training wrappers.
36 |     """
37 |     def __init__(self, model, decay=0.9999, device='', resume=''):
38 |         # make a copy of the model for accumulating moving average of weights
39 |         self.ema = deepcopy(model)
40 |         self.ema.eval()
41 |         self.decay = decay
42 |         self.device = device  # perform ema on different device from model if set
43 |         if device:
44 |             self.ema.to(device=device)
45 |         self.ema_has_module = hasattr(self.ema, 'module')
46 |         if resume:
47 |             self._load_checkpoint(resume)
48 |         for p in self.ema.parameters():
49 |             p.requires_grad_(False)
50 | 
51 |     def _load_checkpoint(self, checkpoint_path):
52 |         checkpoint = torch.load(checkpoint_path, map_location='cpu')
53 |         assert isinstance(checkpoint, dict)
54 |         if 'state_dict_ema' in checkpoint:
55 |             new_state_dict = OrderedDict()
56 |             for k, v in checkpoint['state_dict_ema'].items():
57 |                 # ema model may have been wrapped by DataParallel, and need module prefix
58 |                 if self.ema_has_module:
59 |                     name = 'module.' + k if not k.startswith('module') else k
60 |                 else:
61 |                     name = k
62 |                 new_state_dict[name] = v
63 |             self.ema.load_state_dict(new_state_dict)
64 |             _logger.info("Loaded state_dict_ema")
65 |         else:
66 |             _logger.warning("Failed to find state_dict_ema, starting from loaded model weights")
67 | 
68 |     def update(self, model):
69 |         # correct a mismatch in state dict keys
70 |         needs_module = hasattr(model, 'module') and not self.ema_has_module
71 |         with torch.no_grad():
72 |             msd = model.state_dict()
73 |             for k, ema_v in self.ema.state_dict().items():
74 |                 if needs_module:
75 |                     k = 'module.' + k
76 |                 model_v = msd[k].detach()
77 |                 if self.device:
78 |                     model_v = model_v.to(device=self.device)
79 |                 ema_v.copy_(ema_v * self.decay + (1. - self.decay) * model_v)
80 | 
81 | 
82 | class ModelEmaV2(nn.Module):
83 |     """ Model Exponential Moving Average V2
84 | 
85 |     Keep a moving average of everything in the model state_dict (parameters and buffers).
86 |     V2 of this module is simpler, it does not match params/buffers based on name but simply
87 |     iterates in order. It works with torchscript (JIT of full model).
88 | 
89 |     This is intended to allow functionality like
90 |     https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
91 | 
92 |     A smoothed version of the weights is necessary for some training schemes to perform well.
93 |     E.g. Google's hyper-params for training MNASNet, MobileNet-V3, EfficientNet, etc that use
94 |     RMSprop with a short 2.4-3 epoch decay period and slow LR decay rate of .96-.99 requires EMA
95 |     smoothing of weights to match results. Pay attention to the decay constant you are using
96 |     relative to your update count per epoch.
97 | 
98 |     To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory but
99 |     disable validation of the EMA weights. Validation will have to be done manually in a separate
100 |     process, or after the training stops converging.
101 | 
102 |     This class is sensitive to where it is initialized in the sequence of model init,
103 |     GPU assignment and distributed training wrappers.
104 | """ 105 | def __init__(self, model, decay=0.9999, device=None): 106 | super(ModelEmaV2, self).__init__() 107 | # make a copy of the model for accumulating moving average of weights 108 | self.module = deepcopy(model) 109 | self.module.eval() 110 | self.decay = decay 111 | self.device = device # perform ema on different device from model if set 112 | if self.device is not None: 113 | self.module.to(device=device) 114 | 115 | def _update(self, model, update_fn): 116 | with torch.no_grad(): 117 | for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()): 118 | if self.device is not None: 119 | model_v = model_v.to(device=self.device) 120 | ema_v.copy_(update_fn(ema_v, model_v)) 121 | 122 | def update(self, model): 123 | self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m) 124 | 125 | def set(self, model): 126 | self._update(model, update_fn=lambda e, m: m) 127 | -------------------------------------------------------------------------------- /timm/utils/random.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import torch 4 | 5 | 6 | def random_seed(seed=42, rank=0): 7 | torch.manual_seed(seed + rank) 8 | np.random.seed(seed + rank) 9 | random.seed(seed + rank) 10 | -------------------------------------------------------------------------------- /timm/utils/summary.py: -------------------------------------------------------------------------------- 1 | """ Summary utilities 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import csv 6 | import os 7 | from collections import OrderedDict 8 | try: 9 | import wandb 10 | except ImportError: 11 | pass 12 | 13 | def get_outdir(path, *paths, inc=False): 14 | outdir = os.path.join(path, *paths) 15 | if not os.path.exists(outdir): 16 | os.makedirs(outdir) 17 | elif inc: 18 | count = 1 19 | outdir_inc = outdir + '-' + str(count) 20 | while os.path.exists(outdir_inc): 21 | count = count + 1 22 | outdir_inc = outdir + '-' + str(count) 23 | assert count < 100 24 | outdir = outdir_inc 25 | os.makedirs(outdir) 26 | return outdir 27 | 28 | 29 | def update_summary(epoch, train_metrics, eval_metrics, filename, write_header=False, log_wandb=False): 30 | rowd = OrderedDict(epoch=epoch) 31 | rowd.update([('train_' + k, v) for k, v in train_metrics.items()]) 32 | rowd.update([('eval_' + k, v) for k, v in eval_metrics.items()]) 33 | if log_wandb: 34 | wandb.log(rowd) 35 | with open(filename, mode='a') as cf: 36 | dw = csv.DictWriter(cf, fieldnames=rowd.keys()) 37 | if write_header: # first iteration (epoch == 1 can't be used) 38 | dw.writeheader() 39 | dw.writerow(rowd) 40 | -------------------------------------------------------------------------------- /timm/version.py: -------------------------------------------------------------------------------- 1 | __version__ = '0.4.12' 2 | -------------------------------------------------------------------------------- /utils_evaluate.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Adapted from https://github.com/lupantech/ScienceQA 3 | ''' 4 | 5 | import os 6 | import json 7 | import argparse 8 | import warnings 9 | import pandas as pd 10 | from sentence_transformers import SentenceTransformer 11 | from evaluations import caculate_bleu, caculate_rouge, caculate_similariry 12 | 13 | warnings.filterwarnings('ignore') 14 | 15 | def get_acc_with_contion(res_pd, key, values): 16 | if isinstance(values, list): 17 | 
total_pd = res_pd[res_pd[key].isin(values)] 18 | else: 19 | total_pd = res_pd[res_pd[key] == values] 20 | correct_pd = total_pd[total_pd['true_false'] == True] 21 | acc = "{:.2f}".format(len(correct_pd) / len(total_pd) * 100) 22 | return acc 23 | 24 | 25 | def get_scores(result_data, rationale_data, results_reference, data_file): 26 | # read result file 27 | results = result_data 28 | num = len(results) 29 | assert num == 4241 30 | #print("number of questions:", num) 31 | 32 | # read data file 33 | sqa_data = json.load(open(data_file)) 34 | 35 | # construct pandas data 36 | sqa_pd = pd.DataFrame(sqa_data).T 37 | res_pd = sqa_pd[sqa_pd['split'] == 'test'] # test set 38 | 39 | # update data 40 | for index, row in res_pd.iterrows(): 41 | 42 | res_pd.loc[index, 'no_context'] = True if (not row['hint'] and not row['image']) else False 43 | res_pd.loc[index, 'has_text'] = True if row['hint'] else False 44 | res_pd.loc[index, 'has_image'] = True if row['image'] else False 45 | res_pd.loc[index, 'has_text_image'] = True if (row['hint'] and row['image']) else False 46 | 47 | label = row['answer'] 48 | pred = int(results[index]) 49 | res_pd.loc[index, 'pred'] = pred 50 | res_pd.loc[index, 'true_false'] = (label == pred) 51 | 52 | # accuracy scores 53 | acc_average = len(res_pd[res_pd['true_false'] == True]) / num * 100 54 | #assert result_file.split('_')[-1] == "{:.3f}.json".format(acc_average) 55 | 56 | 57 | # rationale quality 58 | 59 | ## BLEU 60 | bleu1 = caculate_bleu(rationale_data, results_reference, gram=1) 61 | bleu4 = caculate_bleu(rationale_data, results_reference, gram=4) 62 | 63 | ## Rouge-L 64 | rouge = caculate_rouge(rationale_data, results_reference) 65 | 66 | ## Similarity 67 | model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2').cuda() 68 | similariry = caculate_similariry(rationale_data, results_reference, model) 69 | 70 | scores = { 71 | "answer":{ 72 | 'acc_natural': 73 | get_acc_with_contion(res_pd, 'subject', 'natural science'), 74 | 'acc_social': 75 | get_acc_with_contion(res_pd, 'subject', 'social science'), 76 | 'acc_language': 77 | get_acc_with_contion(res_pd, 'subject', 'language science'), 78 | 'acc_has_text': 79 | get_acc_with_contion(res_pd, 'has_text', True), 80 | 'acc_has_image': 81 | get_acc_with_contion(res_pd, 'has_image', True), 82 | 'acc_no_context': 83 | get_acc_with_contion(res_pd, 'no_context', True), 84 | 'acc_grade_1_6': 85 | get_acc_with_contion(res_pd, 'grade', ['grade1', 'grade2', 'grade3', 'grade4', 'grade5', 'grade6']), 86 | 'acc_grade_7_12': 87 | get_acc_with_contion(res_pd, 'grade', ['grade7', 'grade8', 'grade9', 'grade10', 'grade11', 'grade12']), 88 | 'acc_average': 89 | "{:.2f}".format(acc_average), 90 | }, 91 | "rationale":{ 92 | 'bleu1': bleu1 * 100, 93 | 'bleu4': bleu4 * 100, 94 | 'rouge': rouge * 100, 95 | 'similariry': similariry * 100, 96 | } 97 | } 98 | 99 | return scores 100 | 101 | 102 | def print_scores(scores): 103 | latex_output = "" 104 | for key, score in scores.items(): 105 | print(f"{key[4:]}: \t{score}") 106 | latex_output += f"& {score} " 107 | latex_output += "\\\\" 108 | print(latex_output) 109 | -------------------------------------------------------------------------------- /vision_features/mm-cot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/vision_features/mm-cot.png --------------------------------------------------------------------------------
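For context, one plausible way to drive get_scores() from utils_evaluate.py above; the file names and the {question_id: value} layout of the three dicts are assumptions for illustration, not part of this repository:

import json
from utils_evaluate import get_scores

# Assumed layouts: results maps question id -> predicted answer index;
# the rationale dicts map question id -> generated / reference rationale text.
results = json.load(open('predictions.json'))
rationales = json.load(open('rationales.json'))
references = json.load(open('references.json'))

scores = get_scores(results, rationales, references, 'data/problems.json')
print(json.dumps(scores, indent=2))

get_scores() returns the nested {"answer": ..., "rationale": ...} dict built above, so the per-category accuracies and the BLEU/ROUGE/similarity numbers come back from a single call.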