├── .DS_Store
├── .gitignore
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── NOTICE
├── README.md
├── __pycache__
├── evaluations.cpython-39.pyc
├── model.cpython-39.pyc
├── utils_data.cpython-39.pyc
├── utils_evaluate.cpython-39.pyc
└── utils_prompt.cpython-39.pyc
├── data
├── instruct_captions.json
└── name_map.json
├── evaluations.py
├── extract_caption.py
├── extract_features.py
├── main.py
├── model.py
├── requirements.txt
├── run_inference.sh
├── run_training.sh
├── timm
├── __init__.py
├── __pycache__
│ ├── __init__.cpython-37.pyc
│ ├── __init__.cpython-38.pyc
│ ├── version.cpython-37.pyc
│ └── version.cpython-38.pyc
├── data
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── __init__.cpython-37.pyc
│ │ ├── __init__.cpython-38.pyc
│ │ ├── auto_augment.cpython-37.pyc
│ │ ├── auto_augment.cpython-38.pyc
│ │ ├── config.cpython-37.pyc
│ │ ├── config.cpython-38.pyc
│ │ ├── constants.cpython-37.pyc
│ │ ├── constants.cpython-38.pyc
│ │ ├── dataset.cpython-37.pyc
│ │ ├── dataset.cpython-38.pyc
│ │ ├── dataset_factory.cpython-37.pyc
│ │ ├── dataset_factory.cpython-38.pyc
│ │ ├── distributed_sampler.cpython-37.pyc
│ │ ├── distributed_sampler.cpython-38.pyc
│ │ ├── loader.cpython-37.pyc
│ │ ├── loader.cpython-38.pyc
│ │ ├── mixup.cpython-37.pyc
│ │ ├── mixup.cpython-38.pyc
│ │ ├── random_erasing.cpython-37.pyc
│ │ ├── random_erasing.cpython-38.pyc
│ │ ├── real_labels.cpython-37.pyc
│ │ ├── real_labels.cpython-38.pyc
│ │ ├── transforms.cpython-37.pyc
│ │ ├── transforms.cpython-38.pyc
│ │ ├── transforms_factory.cpython-37.pyc
│ │ └── transforms_factory.cpython-38.pyc
│ ├── auto_augment.py
│ ├── config.py
│ ├── constants.py
│ ├── dataset.py
│ ├── dataset_factory.py
│ ├── distributed_sampler.py
│ ├── loader.py
│ ├── mixup.py
│ ├── parsers
│ │ ├── __init__.py
│ │ ├── class_map.py
│ │ ├── constants.py
│ │ ├── parser.py
│ │ ├── parser_factory.py
│ │ ├── parser_image_folder.py
│ │ ├── parser_image_in_tar.py
│ │ ├── parser_image_tar.py
│ │ └── parser_tfds.py
│ ├── random_erasing.py
│ ├── real_labels.py
│ ├── tf_preprocessing.py
│ ├── transforms.py
│ └── transforms_factory.py
├── loss
│ ├── __init__.py
│ ├── asymmetric_loss.py
│ ├── cross_entropy.py
│ └── jsd.py
├── models
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── __init__.cpython-37.pyc
│ │ ├── __init__.cpython-38.pyc
│ │ ├── byoanet.cpython-37.pyc
│ │ ├── byoanet.cpython-38.pyc
│ │ ├── byobnet.cpython-37.pyc
│ │ ├── byobnet.cpython-38.pyc
│ │ ├── cait.cpython-37.pyc
│ │ └── cait.cpython-38.pyc
│ ├── byoanet.py
│ ├── byobnet.py
│ ├── cait.py
│ ├── coat.py
│ ├── convit.py
│ ├── cspnet.py
│ ├── densenet.py
│ ├── dla.py
│ ├── dpn.py
│ ├── efficientnet.py
│ ├── efficientnet_blocks.py
│ ├── efficientnet_builder.py
│ ├── factory.py
│ ├── features.py
│ ├── ghostnet.py
│ ├── gluon_resnet.py
│ ├── gluon_xception.py
│ ├── hardcorenas.py
│ ├── helpers.py
│ ├── hrnet.py
│ ├── hub.py
│ ├── inception_resnet_v2.py
│ ├── inception_v3.py
│ ├── inception_v4.py
│ ├── layers
│ │ └── __pycache__
│ │ │ ├── __init__.cpython-37.pyc
│ │ │ ├── __init__.cpython-38.pyc
│ │ │ ├── activations.cpython-37.pyc
│ │ │ ├── activations.cpython-38.pyc
│ │ │ ├── activations_jit.cpython-37.pyc
│ │ │ ├── activations_jit.cpython-38.pyc
│ │ │ ├── activations_me.cpython-37.pyc
│ │ │ ├── activations_me.cpython-38.pyc
│ │ │ ├── adaptive_avgmax_pool.cpython-37.pyc
│ │ │ ├── adaptive_avgmax_pool.cpython-38.pyc
│ │ │ ├── blur_pool.cpython-37.pyc
│ │ │ ├── blur_pool.cpython-38.pyc
│ │ │ ├── bottleneck_attn.cpython-37.pyc
│ │ │ ├── bottleneck_attn.cpython-38.pyc
│ │ │ ├── cbam.cpython-37.pyc
│ │ │ ├── cbam.cpython-38.pyc
│ │ │ ├── classifier.cpython-37.pyc
│ │ │ ├── classifier.cpython-38.pyc
│ │ │ ├── cond_conv2d.cpython-37.pyc
│ │ │ ├── cond_conv2d.cpython-38.pyc
│ │ │ ├── config.cpython-37.pyc
│ │ │ ├── config.cpython-38.pyc
│ │ │ ├── conv2d_same.cpython-37.pyc
│ │ │ ├── conv2d_same.cpython-38.pyc
│ │ │ ├── conv_bn_act.cpython-37.pyc
│ │ │ ├── conv_bn_act.cpython-38.pyc
│ │ │ ├── create_act.cpython-37.pyc
│ │ │ ├── create_act.cpython-38.pyc
│ │ │ ├── create_attn.cpython-37.pyc
│ │ │ ├── create_attn.cpython-38.pyc
│ │ │ ├── create_conv2d.cpython-37.pyc
│ │ │ ├── create_conv2d.cpython-38.pyc
│ │ │ ├── create_norm_act.cpython-37.pyc
│ │ │ ├── create_norm_act.cpython-38.pyc
│ │ │ ├── drop.cpython-37.pyc
│ │ │ ├── drop.cpython-38.pyc
│ │ │ ├── eca.cpython-37.pyc
│ │ │ ├── eca.cpython-38.pyc
│ │ │ ├── evo_norm.cpython-37.pyc
│ │ │ ├── evo_norm.cpython-38.pyc
│ │ │ ├── gather_excite.cpython-37.pyc
│ │ │ ├── gather_excite.cpython-38.pyc
│ │ │ ├── global_context.cpython-37.pyc
│ │ │ ├── global_context.cpython-38.pyc
│ │ │ ├── halo_attn.cpython-37.pyc
│ │ │ ├── halo_attn.cpython-38.pyc
│ │ │ ├── helpers.cpython-37.pyc
│ │ │ ├── helpers.cpython-38.pyc
│ │ │ ├── inplace_abn.cpython-37.pyc
│ │ │ ├── inplace_abn.cpython-38.pyc
│ │ │ ├── involution.cpython-37.pyc
│ │ │ ├── involution.cpython-38.pyc
│ │ │ ├── lambda_layer.cpython-37.pyc
│ │ │ ├── lambda_layer.cpython-38.pyc
│ │ │ ├── linear.cpython-37.pyc
│ │ │ ├── linear.cpython-38.pyc
│ │ │ ├── mixed_conv2d.cpython-37.pyc
│ │ │ ├── mixed_conv2d.cpython-38.pyc
│ │ │ ├── mlp.cpython-37.pyc
│ │ │ ├── mlp.cpython-38.pyc
│ │ │ ├── non_local_attn.cpython-37.pyc
│ │ │ ├── non_local_attn.cpython-38.pyc
│ │ │ ├── norm.cpython-37.pyc
│ │ │ ├── norm.cpython-38.pyc
│ │ │ ├── norm_act.cpython-37.pyc
│ │ │ ├── norm_act.cpython-38.pyc
│ │ │ ├── padding.cpython-37.pyc
│ │ │ ├── padding.cpython-38.pyc
│ │ │ ├── patch_embed.cpython-37.pyc
│ │ │ ├── patch_embed.cpython-38.pyc
│ │ │ ├── pool2d_same.cpython-37.pyc
│ │ │ ├── pool2d_same.cpython-38.pyc
│ │ │ ├── selective_kernel.cpython-37.pyc
│ │ │ ├── selective_kernel.cpython-38.pyc
│ │ │ ├── separable_conv.cpython-37.pyc
│ │ │ ├── separable_conv.cpython-38.pyc
│ │ │ ├── space_to_depth.cpython-37.pyc
│ │ │ ├── space_to_depth.cpython-38.pyc
│ │ │ ├── split_attn.cpython-37.pyc
│ │ │ ├── split_attn.cpython-38.pyc
│ │ │ ├── split_batchnorm.cpython-37.pyc
│ │ │ ├── split_batchnorm.cpython-38.pyc
│ │ │ ├── squeeze_excite.cpython-37.pyc
│ │ │ ├── squeeze_excite.cpython-38.pyc
│ │ │ ├── std_conv.cpython-37.pyc
│ │ │ ├── std_conv.cpython-38.pyc
│ │ │ ├── swin_attn.cpython-37.pyc
│ │ │ ├── swin_attn.cpython-38.pyc
│ │ │ ├── test_time_pool.cpython-37.pyc
│ │ │ ├── test_time_pool.cpython-38.pyc
│ │ │ ├── weight_init.cpython-37.pyc
│ │ │ └── weight_init.cpython-38.pyc
│ ├── levit.py
│ ├── mlp_mixer.py
│ ├── mobilenetv3.py
│ ├── nasnet.py
│ ├── nfnet.py
│ ├── pit.py
│ ├── pnasnet.py
│ ├── registry.py
│ ├── regnet.py
│ ├── res2net.py
│ ├── resnest.py
│ ├── resnet.py
│ ├── resnetv2.py
│ ├── rexnet.py
│ ├── selecsls.py
│ ├── senet.py
│ ├── sknet.py
│ ├── swin_transformer.py
│ ├── tnt.py
│ ├── tresnet.py
│ ├── twins.py
│ ├── vgg.py
│ ├── visformer.py
│ ├── vision_transformer.py
│ ├── vision_transformer_hybrid.py
│ ├── vovnet.py
│ ├── xception.py
│ └── xception_aligned.py
├── optim
│ ├── __init__.py
│ ├── adabelief.py
│ ├── adafactor.py
│ ├── adahessian.py
│ ├── adamp.py
│ ├── adamw.py
│ ├── lookahead.py
│ ├── nadam.py
│ ├── novograd.py
│ ├── nvnovograd.py
│ ├── optim_factory.py
│ ├── radam.py
│ ├── rmsprop_tf.py
│ └── sgdp.py
├── scheduler
│ ├── __init__.py
│ ├── cosine_lr.py
│ ├── plateau_lr.py
│ ├── scheduler.py
│ ├── scheduler_factory.py
│ ├── step_lr.py
│ └── tanh_lr.py
├── utils
│ ├── __init__.py
│ ├── agc.py
│ ├── checkpoint_saver.py
│ ├── clip_grad.py
│ ├── cuda.py
│ ├── distributed.py
│ ├── jit.py
│ ├── log.py
│ ├── metrics.py
│ ├── misc.py
│ ├── model.py
│ ├── model_ema.py
│ ├── random.py
│ └── summary.py
└── version.py
├── utils_data.py
├── utils_evaluate.py
├── utils_prompt.py
└── vision_features
└── mm-cot.png
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/.DS_Store
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | **/.DS_Store
2 |
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | ## Code of Conduct
2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
4 | opensource-codeofconduct@amazon.com with any additional questions or comments.
5 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing Guidelines
2 |
3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional
4 | documentation, we greatly value feedback and contributions from our community.
5 |
6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary
7 | information to effectively respond to your bug report or contribution.
8 |
9 |
10 | ## Reporting Bugs/Feature Requests
11 |
12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features.
13 |
14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already
15 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful:
16 |
17 | * A reproducible test case or series of steps
18 | * The version of our code being used
19 | * Any modifications you've made relevant to the bug
20 | * Anything unusual about your environment or deployment
21 |
22 |
23 | ## Contributing via Pull Requests
24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that:
25 |
26 | 1. You are working against the latest source on the *main* branch.
27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already.
28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted.
29 |
30 | To send us a pull request, please:
31 |
32 | 1. Fork the repository.
33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change.
34 | 3. Ensure local tests pass.
35 | 4. Commit to your fork using clear commit messages.
36 | 5. Send us a pull request, answering any default questions in the pull request interface.
37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation.
38 |
39 | GitHub provides additional documentation on [forking a repository](https://help.github.com/articles/fork-a-repo/) and
40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/).
41 |
42 |
43 | ## Finding contributions to work on
44 | Looking at the existing issues is a great way to find something to contribute on. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start.
45 |
46 |
47 | ## Code of Conduct
48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
50 | opensource-codeofconduct@amazon.com with any additional questions or comments.
51 |
52 |
53 | ## Security issue notifications
54 | If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public github issue.
55 |
56 |
57 | ## Licensing
58 |
59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution.
60 |
--------------------------------------------------------------------------------
/NOTICE:
--------------------------------------------------------------------------------
1 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Multimodal Chain-of-Thought Reasoning in Language Models
2 |
3 |
"Imagine learning a textbook without figures or tables."
4 |
5 | Multimodal-CoT incorporates vision features in a decoupled training framework. The framework consists of two training stages: (i) rationale generation and (ii) answer inference. Both stages share the same model architecture but differ in the input and output.
6 |
7 | 
8 |
9 |
10 | ## Requirements
11 |
12 | Install all required python dependencies:
13 |
14 | ```
15 | pip install -r requirements.txt
16 | ```
17 |
18 | ## Datasets
19 |
20 | Download the dataset from the following repository:
21 |
22 | ```
23 | https://github.com/lupantech/ScienceQA/tree/main/data
24 | ```
25 | The vision features (detr, resnet, clip, vit) are available at https://huggingface.co/cooelf/vision_features/tree/main
26 |
27 | Alternatively, you may download the extracted vision features (detr, resnet, clip) from [vision_features](https://drive.google.com/file/d/13B0hc_F_45-UlqPLKSgRz-ALtFQ8kIJr/view?usp=share_link) and unzip the files under `vision_features`
28 |
29 | ## Extract Features (optional)
30 |
31 | The processed vision features for ScienceQA are available at https://huggingface.co/cooelf/vision_features/tree/main.
32 |
33 | The following instructions show how we obtain those features.
34 |
35 | Download the image files from [Google Drive](https://drive.google.com/drive/folders/1w8imCXWYn2LxajmGeGH_g5DaL2rabHev?usp=sharing) and unzip all the images (train, dev, test) in the same folder (```images```). The structure should be:
36 |
37 | ```
38 | images
39 | ├── 1
40 | │ └── image.png
41 | ├── 2
42 | │ └── image.png
43 | ├── 3
44 | │ └── image.png
45 | ├── 5
46 | │ └── image.png
47 | ├── 7
48 | │ └── image.png
49 | ```
50 |
51 | Run ```python extract_features.py --data_root images --output_dir vision_features --img_type vit```
52 |
53 | If you hope to use your own images, please structure those images in the way above, or modify the script ```extract_features.py```.
54 |
55 | ## Extract Captions (optional)
56 |
57 | The processed captions for ScienceQA are available at ```data/instruct_captions.json```.
58 |
59 | The following instructions show how we obtain those captions.
60 |
61 | Install LAVIS and prepare the Vicuna weights to use InstructBLIP for caption extraction.
62 |
63 | https://github.com/salesforce/LAVIS/tree/f982acc73288408bceda2d35471a8fcf55aa04ca/projects/instructblip
64 |
65 | Assume that the images are stored in the ```images``` folder.
66 |
67 | ```
68 | python extract_caption.py
69 | ```
70 |
71 | ## Instructions
72 |
73 | ### Training
74 |
75 | ```
76 | # rationale generation
77 | CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py \
78 | --data_root data/ScienceQA/data \
79 | --caption_file data/instruct_captions.json \
80 | --model declare-lab/flan-alpaca-large \
81 | --user_msg rationale --img_type vit \
82 | --bs 2 --eval_bs 4 --epoch 50 --lr 5e-5 --output_len 512 \
83 | --use_caption --use_generate --prompt_format QCM-E \
84 | --output_dir experiments
85 |
86 | # answer inference
87 | CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py \
88 | --data_root data/ScienceQA/data \
89 | --caption_file data/instruct_captions.json \
90 | --model declare-lab/flan-alpaca-large \
91 | --user_msg answer --img_type vit \
92 | --bs 4 --eval_bs 8 --epoch 50 --lr 5e-5 --output_len 64 \
93 | --use_caption --use_generate --prompt_format QCMG-A \
94 | --output_dir experiments \
95 | --eval_le experiments/rationale_declare-lab-flan-alpaca-large_vit_QCM-E_lr5e-05_bs8_op512_ep50/predictions_ans_eval.json \
96 | --test_le experiments/rationale_declare-lab-flan-alpaca-large_vit_QCM-E_lr5e-05_bs8_op512_ep50/predictions_ans_test.json
97 |
98 | ```
99 |
100 | ### Inference
101 |
102 | Our trained models are available at https://huggingface.co/cooelf/mm-cot/tree/main. To use our trained models, please put them under the ```models``` folder.
103 |
104 | ```
105 | # rationale generation
106 | CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py \
107 | --data_root data/ScienceQA/data \
108 | --caption_file data/instruct_captions.json \
109 | --model declare-lab/flan-alpaca-large \
110 | --user_msg rationale --img_type vit \
111 | --bs 2 --eval_bs 4 --epoch 50 --lr 5e-5 --output_len 512 \
112 | --use_caption --use_generate --prompt_format QCM-E \
113 | --output_dir experiments \
114 | --evaluate_dir models/mm-cot-large-rationale
115 |
116 | # answer inference
117 | CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py \
118 | --data_root data/ScienceQA/data \
119 | --caption_file data/instruct_captions.json \
120 | --model declare-lab/flan-alpaca-large \
121 | --user_msg answer --img_type vit \
122 | --bs 4 --eval_bs 8 --epoch 50 --lr 5e-5 --output_len 64 \
123 | --use_caption --use_generate --prompt_format QCMG-A \
124 | --output_dir experiments \
125 | --eval_le experiments/rationale_declare-lab-flan-alpaca-large_vit_QCM-E_lr5e-05_bs8_op512_ep50/predictions_ans_eval.json \
126 | --test_le experiments/rationale_declare-lab-flan-alpaca-large_vit_QCM-E_lr5e-05_bs8_op512_ep50/predictions_ans_test.json \
127 | --evaluate_dir models/mm-cot-large-answer
128 | ```
129 |
130 | ## Citing MM-CoT
131 |
132 | ```
133 | @article{zhang2023multicot,
134 | title={Multimodal Chain-of-Thought Reasoning in Language Models},
135 | author={Zhang, Zhuosheng and Zhang, Aston and Li, Mu and Zhao, Hai and Karypis, George and Smola, Alex},
136 | journal={arXiv preprint arXiv:2302.00923},
137 | year={2023}
138 | }
139 | ```
140 |
141 | ## License
142 |
143 | This project is licensed under the Apache-2.0 License.
144 |
145 | ## Acknowledgement
146 |
147 | Parts of our code are adapted from [ScienceQA](https://github.com/lupantech/ScienceQA), [Transformers](https://github.com/huggingface/transformers), and [pytorch-image-models](https://github.com/huggingface/pytorch-image-models).
148 |
149 | We thank [Pan Lu](https://lupantech.github.io/) for providing the parameter sizes of the ScienceQA baselines.
150 |
--------------------------------------------------------------------------------
/__pycache__/evaluations.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/__pycache__/evaluations.cpython-39.pyc
--------------------------------------------------------------------------------
/__pycache__/model.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/__pycache__/model.cpython-39.pyc
--------------------------------------------------------------------------------
/__pycache__/utils_data.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/__pycache__/utils_data.cpython-39.pyc
--------------------------------------------------------------------------------
/__pycache__/utils_evaluate.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/__pycache__/utils_evaluate.cpython-39.pyc
--------------------------------------------------------------------------------
/__pycache__/utils_prompt.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/__pycache__/utils_prompt.cpython-39.pyc
--------------------------------------------------------------------------------
/evaluations.py:
--------------------------------------------------------------------------------
1 | '''
2 | Adapted from https://github.com/lupantech/ScienceQA
3 | '''
4 |
5 | import re
6 | from rouge import Rouge
7 | from nltk.translate.bleu_score import sentence_bleu
8 | from sentence_transformers import util
9 |
10 | ########################
11 | ## BLEU
12 | ########################
def tokenize(text):
    """Split ``text`` into tokens on whitespace characters and periods.

    Empty fragments (produced by consecutive separators, or by a separator
    at either end of the string) are dropped from the result.
    """
    return [piece for piece in re.split(r'\s|\.', text) if piece]
17 |
18 |
def bleu_score(reference, hypothesis, gram):
    """Return the sentence-level BLEU-``gram`` score of ``hypothesis``.

    Both strings are split with ``tokenize`` and scored with
    ``sentence_bleu`` against the single reference, using uniform weights
    over the n-gram orders 1..``gram`` (``gram=1`` -> ``(1.,)``,
    ``gram=2`` -> ``(1/2, 1/2)``, and so on).

    Args:
        reference: ground-truth text.
        hypothesis: predicted text.
        gram: highest n-gram order to include; a positive whole number.

    Returns:
        The value returned by ``sentence_bleu``.

    Raises:
        ValueError: if ``gram`` is not a positive whole number.  (Previously
            any value outside 1..4 crashed with ``UnboundLocalError``.)
    """
    # Validate before doing any work so a bad order fails fast and clearly.
    try:
        order = int(gram)
    except (TypeError, ValueError, OverflowError):
        raise ValueError(f"gram must be a positive integer, got {gram!r}") from None
    if order != gram or order < 1:
        raise ValueError(f"gram must be a positive integer, got {gram!r}")

    # Uniform weights: BLEU-1 -> (1.,), BLEU-2 -> (1/2, 1/2), ...
    weights = (1. / order, ) * order

    reference_tokens = tokenize(reference)
    hypothesis_tokens = tokenize(hypothesis)
    return sentence_bleu([reference_tokens], hypothesis_tokens, weights)
33 |
34 |
def caculate_bleu(results, data, gram):
    """Average the BLEU-``gram`` score of every prediction in ``results``.

    Args:
        results: mapping of question id -> predicted text.
        data: mapping of question id -> target text; must contain every id
            present in ``results``.
        gram: n-gram order forwarded to ``bleu_score``.

    Returns:
        The mean score over all pairs whose stripped target is non-empty,
        or ``0.0`` when there is no such pair (previously this case raised
        ``ZeroDivisionError``).
    """
    bleus = []
    for qid, prediction in results.items():
        target = data[qid].strip()
        # Pairs without a reference text cannot be scored; leave them out.
        if target == "":
            continue
        bleus.append(bleu_score(target, prediction, gram))

    if not bleus:
        return 0.0
    return sum(bleus) / len(bleus)
49 |
50 |
51 | ########################
52 | ## Rouge-L
53 | ########################
def score_rouge(str1, str2):
    """Return the averaged ROUGE-L ``f`` value for the two strings.

    A ``Rouge`` scorer restricted to the ``rouge-l`` metric is created for
    the call, and ``str1``/``str2`` are passed to ``get_scores`` in that
    order with ``avg=True``.
    """
    scorer = Rouge(metrics=["rouge-l"])
    return scorer.get_scores(str1, str2, avg=True)['rouge-l']['f']
59 |
60 |
def caculate_rouge(results, data):
    """Average the ROUGE-L score of every prediction in ``results``.

    Args:
        results: mapping of question id -> predicted text.
        data: mapping of question id -> target text; must contain every id
            present in ``results``.

    Returns:
        The mean ``score_rouge`` value over all pairs where both the
        prediction and the stripped target are non-empty, or ``0.0`` when
        there is no such pair (previously this case raised
        ``ZeroDivisionError``).
    """
    rouges = []
    for qid, prediction in results.items():
        target = data[qid].strip()
        # Pairs with an empty prediction or an empty target are not scored.
        if prediction == "" or target == "":
            continue
        rouges.append(score_rouge(target, prediction))

    if not rouges:
        return 0.0
    return sum(rouges) / len(rouges)
76 |
77 |
78 | ########################
79 | ## Sentence Similarity
80 | ########################
def similariry_score(str1, str2, model):
    """Return the cosine similarity between the embeddings of two texts.

    ``model`` must provide ``encode(text, convert_to_tensor=True)``; the two
    resulting embeddings are compared with ``util.pytorch_cos_sim`` and the
    result is converted to a Python number with ``.item()``.
    """
    # Encode str1 first, then str2, as separate calls.
    first, second = (model.encode(text, convert_to_tensor=True) for text in (str1, str2))
    return util.pytorch_cos_sim(first, second).item()
87 |
88 |
def caculate_similariry(results, data, model):
    """Average the embedding similarity of every prediction in ``results``.

    Args:
        results: mapping of question id -> predicted text.
        data: mapping of question id -> target text; must contain every id
            present in ``results``.
        model: sentence-embedding model forwarded to ``similariry_score``.

    Returns:
        The mean ``similariry_score`` between each stripped target and its
        prediction, or ``0.0`` when ``results`` is empty (previously this
        case raised ``ZeroDivisionError``).
    """
    scores = [
        similariry_score(data[qid].strip(), prediction, model)
        for qid, prediction in results.items()
    ]

    if not scores:
        return 0.0
    return sum(scores) / len(scores)
101 |
--------------------------------------------------------------------------------
/extract_caption.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from PIL import Image
3 | import os
4 | from tqdm import tqdm
5 | from lavis.models import load_model_and_preprocess
6 | import json
7 |
# Load the InstructBLIP model (on the GPU when one is available).
device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
model, vis_processors, _ = load_model_and_preprocess(name="blip2_vicuna_instruct", model_type="vicuna7b", is_eval=True, device=device)

data_root = "data/images"
# NOTE: despite its name, this is the path of the JSON *file* that is written.
output_dir = "data/instruct_captions.json"

# Every image lives in a folder named after its integer id. Entries that are
# not decimal numbers (e.g. a stray ".DS_Store") are skipped, because they
# would make the numeric sort key raise ValueError.
all_images = sorted(
    (entry for entry in os.listdir(data_root) if entry.isdecimal()),
    key=int,
)

name_map = {}

for image in tqdm(all_images):
    # Prefer image.png and fall back to choice_0.png when it is absent.
    curr_dir = os.path.join(data_root, image, "image.png")
    if not os.path.exists(curr_dir):
        curr_dir = os.path.join(data_root, image, "choice_0.png")
    # convert() returns a fully loaded copy, so the file can be closed at once.
    with Image.open(curr_dir) as image_file:
        raw_image = image_file.convert("RGB")
    # prepare the image
    image_features = vis_processors["eval"](raw_image).unsqueeze(0).to(device)
    output = model.generate({"image": image_features, "prompt": "Write a detailed description."})
    name_map[str(image)] = output

with open(output_dir, 'w') as outfile:
    json.dump(name_map, outfile, indent=2)
--------------------------------------------------------------------------------
/extract_features.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from PIL import Image
3 | import torchvision.transforms as T
4 | import timm
5 | from timm.data import resolve_data_config
6 | from timm.data.transforms_factory import create_transform
7 | import os
8 | import argparse
9 | import json
10 | from tqdm import tqdm
11 |
def parse_args():
    """Parse the command-line options of the feature-extraction script.

    Options (all strings):
        --data_root: defaults to 'images'.
        --output_dir: defaults to 'vision_features'.
        --img_type: 'detr' or 'vit', defaults to 'vit'.

    Returns:
        The ``argparse.Namespace`` built from ``sys.argv``.
    """
    cli = argparse.ArgumentParser()
    option_specs = (
        ('--data_root', {'default': 'images'}),
        ('--output_dir', {'default': 'vision_features'}),
        ('--img_type', {'default': "vit", 'choices': ['detr', 'vit'], 'help': 'type of image features'}),
    )
    for flag, extra in option_specs:
        cli.add_argument(flag, type=str, **extra)
    return cli.parse_args()
19 |
def extract_features(img_type, input_image):
    """Extract vision features for a single image file.

    Relies on the module-level ``vit_model`` / ``detr_model`` globals, which
    the ``__main__`` block creates for the selected ``img_type`` before
    calling this function.

    Args:
        img_type: ``"vit"`` or ``"detr"``.
        input_image: path of the image file to load.

    Returns:
        For ``"vit"``: the output of ``vit_model.forward_features``.
        For ``"detr"``: the last element of the ``detr_model`` output.

    Raises:
        ValueError: if ``img_type`` is not supported (previously the
            function silently returned ``None`` in that case).
    """
    if img_type not in ("vit", "detr"):
        raise ValueError(f"unsupported img_type: {img_type!r} (expected 'vit' or 'detr')")

    if img_type == "vit":
        # Use the preprocessing pipeline that matches the ViT model's config.
        transform = create_transform(**resolve_data_config({}, model=vit_model))
    else:
        transform = T.Compose([
            T.Resize(224),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

    # convert() returns a fully loaded copy, so the file can be closed at once.
    with Image.open(input_image) as image_file:
        img = image_file.convert("RGB")

    with torch.no_grad():
        batch = transform(img).unsqueeze(0)  # add the batch dimension
        if img_type == "vit":
            return vit_model.forward_features(batch)
        return detr_model(batch)[-1]
41 |
if __name__ == '__main__':
    # Extract features for every image folder under --data_root and save
    # them, together with a folder-name -> row-index map, in --output_dir.
    args = parse_args()
    print("args",args)
    # Every image lives in a folder named after its integer id. Entries that
    # are not decimal numbers (e.g. a stray ".DS_Store") are skipped, because
    # they would make the numeric sort key raise ValueError.
    all_images = sorted(
        (entry for entry in os.listdir(args.data_root) if entry.isdecimal()),
        key=int,
    )
    tmp = []
    name_map = {}
    print(len(all_images))
    if not all_images:
        raise SystemExit(f"no numeric image folders found in {args.data_root!r}")

    # The models are module-level globals read by extract_features().
    if args.img_type == "vit":
        vit_model = timm.create_model("vit_large_patch32_384", pretrained=True, num_classes=0)
        vit_model.eval()
    elif args.img_type == "detr":
        detr_model = torch.hub.load('cooelf/detr', 'detr_resnet101_dc5', pretrained=True)
        detr_model.eval()

    # Create the output folder up front so a bad path fails before the long
    # extraction loop rather than after it.
    os.makedirs(args.output_dir, exist_ok=True)

    for idx, image in enumerate(tqdm(all_images)):
        if idx % 100 == 0: print(idx)
        # Prefer image.png and fall back to choice_0.png when it is absent.
        curr_dir = os.path.join(args.data_root, image, "image.png")
        if not os.path.exists(curr_dir):
            curr_dir = os.path.join(args.data_root, image, "choice_0.png")
        feature = extract_features(args.img_type, curr_dir)
        tmp.append(feature.detach().cpu())
        # Row idx of the saved tensor holds the features of this folder.
        name_map[str(image)] = idx

    res = torch.cat(tmp)  # every element is already on the CPU
    print(res.shape)
    torch.save(res, os.path.join(args.output_dir, args.img_type +'.pth'))
    with open(os.path.join(args.output_dir, 'name_map.json'), 'w') as outfile:
        json.dump(name_map, outfile)
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | huggingface-hub>=0.4.0
2 | numpy==1.23.2
3 | openai==0.23.0
4 | pandas==1.4.3
5 | rouge==1.0.1
6 | sentence-transformers==2.2.2
7 | transformers==4.30.0
8 | nltk==3.6.6
9 | evaluate==0.4.0
10 | rouge==1.0.1
11 | rouge_score==0.1.2
12 | rich>=13.3.2
13 |
--------------------------------------------------------------------------------
/run_inference.sh:
--------------------------------------------------------------------------------
# Inference with the released checkpoints, expected under models/.
# Each model size runs two stages: rationale generation (QCM-E) followed by
# answer inference (QCMG-A), which reads the rationale stage's prediction
# files through --eval_le / --test_le.
# The entry point is main.py: the repository has no main_central.py.

# base
# rationale generation
CUDA_VISIBLE_DEVICES=0 python main.py \
    --data_root data/ScienceQA/data \
    --caption_file data/instruct_captions.json \
    --model declare-lab/flan-alpaca-base \
    --user_msg rationale --img_type vit \
    --bs 8 --eval_bs 8 --epoch 20 --lr 8e-5 --output_len 512 \
    --use_caption --use_generate --final_eval --prompt_format QCM-E \
    --output_dir experiments \
    --evaluate_dir models/mm-cot-base-rationale

# answer inference
CUDA_VISIBLE_DEVICES=0 python main.py \
    --data_root data/ScienceQA/data \
    --caption_file data/instruct_captions.json \
    --model declare-lab/flan-alpaca-base \
    --user_msg answer --img_type vit \
    --bs 8 --eval_bs 8 --epoch 20 --lr 8e-5 --output_len 64 \
    --use_caption --use_generate --prompt_format QCMG-A \
    --output_dir experiments \
    --eval_le models/mm-cot-base-rationale/predictions_ans_eval.json \
    --test_le models/mm-cot-base-rationale/predictions_ans_test.json \
    --evaluate_dir models/mm-cot-base-answer

# large
# rationale generation
CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py \
    --data_root data/ScienceQA/data \
    --caption_file data/instruct_captions.json \
    --model declare-lab/flan-alpaca-large \
    --user_msg rationale --img_type vit \
    --bs 2 --eval_bs 4 --epoch 50 --lr 5e-5 --output_len 512 \
    --use_caption --use_generate --prompt_format QCM-E \
    --output_dir experiments \
    --evaluate_dir models/mm-cot-large-rationale

# answer inference
CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py \
    --data_root data/ScienceQA/data \
    --caption_file data/instruct_captions.json \
    --model declare-lab/flan-alpaca-large \
    --user_msg answer --img_type vit \
    --bs 4 --eval_bs 8 --epoch 50 --lr 5e-5 --output_len 64 \
    --use_caption --use_generate --prompt_format QCMG-A \
    --output_dir experiments \
    --eval_le models/mm-cot-large-rationale/predictions_ans_eval.json \
    --test_le models/mm-cot-large-rationale/predictions_ans_test.json \
    --evaluate_dir models/mm-cot-large-answer
--------------------------------------------------------------------------------
/run_training.sh:
--------------------------------------------------------------------------------
# Two-stage training. Each model size first trains rationale generation
# (QCM-E) and then answer inference (QCMG-A), which reads the rationale
# stage's prediction files through --eval_le / --test_le.
# The entry point is main.py: the repository has no main_central.py.
# NOTE(review): the base rationale stage now writes to "experiments" (it was
# "experiments0620") so that the answer stage finds its input files under
# experiments/rationale_...; confirm the run-folder naming against main.py.

# base
# rationale generation
CUDA_VISIBLE_DEVICES=0 python main.py \
    --data_root data/ScienceQA/data \
    --caption_file data/instruct_captions.json \
    --model declare-lab/flan-alpaca-base \
    --user_msg rationale --img_type vit \
    --bs 8 --eval_bs 8 --epoch 20 --lr 8e-5 --output_len 512 \
    --use_caption --use_generate --final_eval --prompt_format QCM-E \
    --output_dir experiments

# answer inference
CUDA_VISIBLE_DEVICES=0 python main.py \
    --data_root data/ScienceQA/data \
    --caption_file data/instruct_captions.json \
    --model declare-lab/flan-alpaca-base \
    --user_msg answer --img_type vit \
    --bs 8 --eval_bs 8 --epoch 20 --lr 8e-5 --output_len 64 \
    --use_caption --use_generate --prompt_format QCMG-A \
    --output_dir experiments \
    --eval_le experiments/rationale_declare-lab-flan-alpaca-base_vit_QCM-E_lr8e-05_bs8_op512_ep20/predictions_ans_eval.json \
    --test_le experiments/rationale_declare-lab-flan-alpaca-base_vit_QCM-E_lr8e-05_bs8_op512_ep20/predictions_ans_test.json

# large
# rationale generation
CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py \
    --data_root data/ScienceQA/data \
    --caption_file data/instruct_captions.json \
    --model declare-lab/flan-alpaca-large \
    --user_msg rationale --img_type vit \
    --bs 2 --eval_bs 4 --epoch 50 --lr 5e-5 --output_len 512 \
    --use_caption --use_generate --prompt_format QCM-E \
    --output_dir experiments

# answer inference
CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py \
    --data_root data/ScienceQA/data \
    --caption_file data/instruct_captions.json \
    --model declare-lab/flan-alpaca-large \
    --user_msg answer --img_type vit \
    --bs 4 --eval_bs 8 --epoch 50 --lr 5e-5 --output_len 64 \
    --use_caption --use_generate --prompt_format QCMG-A \
    --output_dir experiments \
    --eval_le experiments/rationale_declare-lab-flan-alpaca-large_vit_QCM-E_lr5e-05_bs8_op512_ep50/predictions_ans_eval.json \
    --test_le experiments/rationale_declare-lab-flan-alpaca-large_vit_QCM-E_lr5e-05_bs8_op512_ep50/predictions_ans_test.json
--------------------------------------------------------------------------------
/timm/__init__.py:
--------------------------------------------------------------------------------
1 | from .version import __version__
2 | from .models import create_model, list_models, is_model, list_modules, model_entrypoint, \
3 | is_scriptable, is_exportable, set_scriptable, set_exportable, has_model_default_key, is_model_default_key, \
4 | get_model_default_value, is_model_pretrained
5 |
--------------------------------------------------------------------------------
/timm/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/timm/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/timm/__pycache__/version.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/__pycache__/version.cpython-37.pyc
--------------------------------------------------------------------------------
/timm/__pycache__/version.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/__pycache__/version.cpython-38.pyc
--------------------------------------------------------------------------------
/timm/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .auto_augment import RandAugment, AutoAugment, rand_augment_ops, auto_augment_policy,\
2 | rand_augment_transform, auto_augment_transform
3 | from .config import resolve_data_config
4 | from .constants import *
5 | from .dataset import ImageDataset, IterableImageDataset, AugMixDataset
6 | from .dataset_factory import create_dataset
7 | from .loader import create_loader
8 | from .mixup import Mixup, FastCollateMixup
9 | from .parsers import create_parser
10 | from .real_labels import RealLabelsImagenet
11 | from .transforms import *
12 | from .transforms_factory import create_transform
--------------------------------------------------------------------------------
/timm/data/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/data/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/timm/data/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/data/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/timm/data/__pycache__/auto_augment.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/data/__pycache__/auto_augment.cpython-37.pyc
--------------------------------------------------------------------------------
/timm/data/__pycache__/auto_augment.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/data/__pycache__/auto_augment.cpython-38.pyc
--------------------------------------------------------------------------------
/timm/data/__pycache__/config.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/data/__pycache__/config.cpython-37.pyc
--------------------------------------------------------------------------------
/timm/data/__pycache__/config.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/data/__pycache__/config.cpython-38.pyc
--------------------------------------------------------------------------------
/timm/data/__pycache__/constants.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/data/__pycache__/constants.cpython-37.pyc
--------------------------------------------------------------------------------
/timm/data/__pycache__/constants.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/data/__pycache__/constants.cpython-38.pyc
--------------------------------------------------------------------------------
/timm/data/__pycache__/dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/data/__pycache__/dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/timm/data/__pycache__/dataset.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/data/__pycache__/dataset.cpython-38.pyc
--------------------------------------------------------------------------------
/timm/data/__pycache__/dataset_factory.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/data/__pycache__/dataset_factory.cpython-37.pyc
--------------------------------------------------------------------------------
/timm/data/__pycache__/dataset_factory.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/data/__pycache__/dataset_factory.cpython-38.pyc
--------------------------------------------------------------------------------
/timm/data/__pycache__/distributed_sampler.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/data/__pycache__/distributed_sampler.cpython-37.pyc
--------------------------------------------------------------------------------
/timm/data/__pycache__/distributed_sampler.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/data/__pycache__/distributed_sampler.cpython-38.pyc
--------------------------------------------------------------------------------
/timm/data/__pycache__/loader.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/data/__pycache__/loader.cpython-37.pyc
--------------------------------------------------------------------------------
/timm/data/__pycache__/loader.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/data/__pycache__/loader.cpython-38.pyc
--------------------------------------------------------------------------------
/timm/data/__pycache__/mixup.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/data/__pycache__/mixup.cpython-37.pyc
--------------------------------------------------------------------------------
/timm/data/__pycache__/mixup.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/data/__pycache__/mixup.cpython-38.pyc
--------------------------------------------------------------------------------
/timm/data/__pycache__/random_erasing.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/data/__pycache__/random_erasing.cpython-37.pyc
--------------------------------------------------------------------------------
/timm/data/__pycache__/random_erasing.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/data/__pycache__/random_erasing.cpython-38.pyc
--------------------------------------------------------------------------------
/timm/data/__pycache__/real_labels.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/data/__pycache__/real_labels.cpython-37.pyc
--------------------------------------------------------------------------------
/timm/data/__pycache__/real_labels.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/data/__pycache__/real_labels.cpython-38.pyc
--------------------------------------------------------------------------------
/timm/data/__pycache__/transforms.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/data/__pycache__/transforms.cpython-37.pyc
--------------------------------------------------------------------------------
/timm/data/__pycache__/transforms.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/data/__pycache__/transforms.cpython-38.pyc
--------------------------------------------------------------------------------
/timm/data/__pycache__/transforms_factory.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/data/__pycache__/transforms_factory.cpython-37.pyc
--------------------------------------------------------------------------------
/timm/data/__pycache__/transforms_factory.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/data/__pycache__/transforms_factory.cpython-38.pyc
--------------------------------------------------------------------------------
/timm/data/config.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from .constants import *
3 |
4 |
5 | _logger = logging.getLogger(__name__)
6 |
7 |
def resolve_data_config(args, default_cfg=None, model=None, use_test_size=False, verbose=False):
    """Resolve the data pre-processing configuration for a model + dataset.

    Every setting is taken from the first source that provides it: an explicit
    entry in ``args``, then ``default_cfg``, then a built-in default.

    Args:
        args: Mapping of user overrides. Keys consulted: 'chans', 'input_size',
            'img_size', 'interpolation', 'mean', 'std' and 'crop_pct'.
        default_cfg: Mapping of model defaults. When empty or None, the
            ``default_cfg`` attribute of ``model`` is used if it exists.
        model: Optional object carrying a ``default_cfg`` attribute.
        use_test_size: Prefer ``default_cfg['test_input_size']`` over
            ``default_cfg['input_size']`` when no size is given in ``args``.
        verbose: Log the resolved configuration at INFO level.

    Returns:
        A new dict with the keys 'input_size', 'interpolation', 'mean', 'std'
        and 'crop_pct'.

    Raises:
        AssertionError: If 'input_size' is not a 3-element tuple/list, if
            'img_size' is not an int, or if 'mean'/'std' have a length other
            than 1 or the resolved channel count.
    """
    # A None default avoids sharing one dict object between calls; a model
    # whose default_cfg is None is treated like a model without one.
    if not default_cfg:
        default_cfg = getattr(model, 'default_cfg', None) or {}

    new_config = {}

    # Resolve input/image size
    in_chans = 3
    if 'chans' in args and args['chans'] is not None:
        in_chans = args['chans']

    input_size = (in_chans, 224, 224)
    if 'input_size' in args and args['input_size'] is not None:
        assert isinstance(args['input_size'], (tuple, list))
        assert len(args['input_size']) == 3
        input_size = tuple(args['input_size'])
        in_chans = input_size[0]  # input_size overrides in_chans
    elif 'img_size' in args and args['img_size'] is not None:
        assert isinstance(args['img_size'], int)
        input_size = (in_chans, args['img_size'], args['img_size'])
    elif use_test_size and 'test_input_size' in default_cfg:
        input_size = default_cfg['test_input_size']
    elif 'input_size' in default_cfg:
        input_size = default_cfg['input_size']
    new_config['input_size'] = input_size

    # resolve interpolation method
    new_config['interpolation'] = 'bicubic'
    if 'interpolation' in args and args['interpolation']:
        new_config['interpolation'] = args['interpolation']
    elif 'interpolation' in default_cfg:
        new_config['interpolation'] = default_cfg['interpolation']

    def _per_channel(values):
        # A single value is broadcast to every channel; otherwise the length
        # has to match the channel count exactly.
        values = tuple(values)
        if len(values) == 1:
            return values * in_chans
        assert len(values) == in_chans
        return values

    # resolve dataset + model mean for normalization
    new_config['mean'] = IMAGENET_DEFAULT_MEAN
    if 'mean' in args and args['mean'] is not None:
        new_config['mean'] = _per_channel(args['mean'])
    elif 'mean' in default_cfg:
        new_config['mean'] = default_cfg['mean']

    # resolve dataset + model std deviation for normalization
    new_config['std'] = IMAGENET_DEFAULT_STD
    if 'std' in args and args['std'] is not None:
        new_config['std'] = _per_channel(args['std'])
    elif 'std' in default_cfg:
        new_config['std'] = default_cfg['std']

    # resolve default crop percentage
    new_config['crop_pct'] = DEFAULT_CROP_PCT
    if 'crop_pct' in args and args['crop_pct'] is not None:
        new_config['crop_pct'] = args['crop_pct']
    elif 'crop_pct' in default_cfg:
        new_config['crop_pct'] = default_cfg['crop_pct']

    if verbose:
        _logger.info('Data processing configuration for current model + dataset:')
        for n, v in new_config.items():
            _logger.info('\t%s: %s', n, v)

    return new_config
79 |
--------------------------------------------------------------------------------
/timm/data/constants.py:
--------------------------------------------------------------------------------
# Crop percentage used when no explicit 'crop_pct' is supplied.
DEFAULT_CROP_PCT = 0.875

# Per-channel normalization constants (three channels each).
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)

IMAGENET_INCEPTION_MEAN = (0.5,) * 3
IMAGENET_INCEPTION_STD = (0.5,) * 3

IMAGENET_DPN_MEAN = (124 / 255, 117 / 255, 104 / 255)
IMAGENET_DPN_STD = (1 / (.0167 * 255),) * 3
8 |
--------------------------------------------------------------------------------
/timm/data/dataset.py:
--------------------------------------------------------------------------------
1 | """ Quick n Simple Image Folder, Tarfile based DataSet
2 |
3 | Hacked together by / Copyright 2020 Ross Wightman
4 | """
5 | import torch.utils.data as data
6 | import os
7 | import torch
8 | import logging
9 |
10 | from PIL import Image
11 |
12 | from .parsers import create_parser
13 |
_logger = logging.getLogger(__name__)


# Number of consecutive unreadable samples tolerated before the error raised
# by the last one is propagated to the caller.
_ERROR_RETRY = 50


class ImageDataset(data.Dataset):
    """Map-style dataset that loads samples through a parser.

    Args:
        root: Dataset location, forwarded to ``create_parser`` when a parser
            has to be built.
        parser: A ready parser object, or a parser name / None, in which case
            ``create_parser(parser or '', root=root, class_map=class_map)``
            builds one.
        class_map: Forwarded to ``create_parser``.
        load_bytes: Return the raw bytes of each sample instead of decoding
            it into an RGB PIL image.
        transform: Optional callable applied to every loaded sample.
    """

    def __init__(
            self,
            root,
            parser=None,
            class_map='',
            load_bytes=False,
            transform=None,
    ):
        if parser is None or isinstance(parser, str):
            parser = create_parser(parser or '', root=root, class_map=class_map)
        self.parser = parser
        self.load_bytes = load_bytes
        self.transform = transform
        self._consecutive_errors = 0

    def __getitem__(self, index):
        """Return ``(img, target)`` for ``index``.

        A sample that cannot be read is logged and the following index (with
        wrap-around) is returned instead; after ``_ERROR_RETRY`` consecutive
        failures the last exception is re-raised. A ``None`` target is
        replaced by a ``-1`` long tensor.
        """
        fileobj, target = self.parser[index]
        try:
            try:
                img = fileobj.read() if self.load_bytes else Image.open(fileobj).convert('RGB')
            finally:
                # The parser hands out a fresh handle per sample. read() and
                # convert() have consumed everything needed from it, so close
                # it here instead of leaking one descriptor per item.
                close = getattr(fileobj, 'close', None)
                if callable(close):
                    close()
        except Exception as e:
            _logger.warning(f'Skipped sample (index {index}, file {self.parser.filename(index)}). {str(e)}')
            self._consecutive_errors += 1
            if self._consecutive_errors < _ERROR_RETRY:
                return self.__getitem__((index + 1) % len(self.parser))
            raise
        self._consecutive_errors = 0
        if self.transform is not None:
            img = self.transform(img)
        if target is None:
            target = torch.tensor(-1, dtype=torch.long)
        return img, target

    def __len__(self):
        return len(self.parser)

    def filename(self, index, basename=False, absolute=False):
        """Return the parser's filename for sample ``index``."""
        return self.parser.filename(index, basename, absolute)

    def filenames(self, basename=False, absolute=False):
        """Return the parser's filenames for all samples."""
        return self.parser.filenames(basename, absolute)
63 |
64 |
class IterableImageDataset(data.IterableDataset):
    """Iterable-style dataset that streams ``(img, target)`` pairs from a parser.

    ``parser`` is mandatory: either a ready iterable parser object, or a
    parser name that is turned into one with ``create_parser``. The
    ``class_map`` and ``load_bytes`` arguments are accepted but not used by
    this class.
    """

    def __init__(
            self,
            root,
            parser=None,
            split='train',
            is_training=False,
            batch_size=None,
            class_map='',
            load_bytes=False,
            repeats=0,
            transform=None,
    ):
        assert parser is not None
        if not isinstance(parser, str):
            self.parser = parser
        else:
            # a name was given: let the factory build the actual parser
            self.parser = create_parser(
                parser, root=root, split=split, is_training=is_training, batch_size=batch_size, repeats=repeats)
        self.transform = transform
        self._consecutive_errors = 0

    def __iter__(self):
        """Yield transformed samples; a missing target becomes a ``-1`` long tensor."""
        for sample, label in self.parser:
            if self.transform is not None:
                sample = self.transform(sample)
            if label is None:
                label = torch.tensor(-1, dtype=torch.long)
            yield sample, label

    def __len__(self):
        """Length of the parser when it defines one, otherwise 0."""
        return len(self.parser) if hasattr(self.parser, '__len__') else 0

    def filename(self, index, basename=False, absolute=False):
        assert False, 'Filename lookup by index not supported, use filenames().'

    def filenames(self, basename=False, absolute=False):
        """Return the parser's filenames for all samples."""
        return self.parser.filenames(basename, absolute)
107 |
108 |
class AugMixDataset(torch.utils.data.Dataset):
    """Dataset wrapper to perform AugMix or other clean/augmentation mixes.

    The wrapped dataset's ``transform`` is expected to be a 3-item tuple/list.
    Item 0 stays on the wrapped dataset, item 1 becomes the augmentation, and
    item 2 becomes the final normalization step. Every item returned is a
    tuple of ``num_splits`` views plus the target.
    """

    def __init__(self, dataset, num_splits=2):
        self.dataset = dataset
        self.num_splits = num_splits
        self.augmentation = None
        self.normalize = None
        current = dataset.transform
        if current is not None:
            self._set_transforms(current)

    def _set_transforms(self, x):
        """Split a 3-item transform sequence into base / augmentation / normalize."""
        assert isinstance(x, (list, tuple)) and len(x) == 3, 'Expecting a tuple/list of 3 transforms'
        base, augmentation, normalize = x
        self.dataset.transform = base
        self.augmentation = augmentation
        self.normalize = normalize

    @property
    def transform(self):
        return self.dataset.transform

    @transform.setter
    def transform(self, x):
        self._set_transforms(x)

    def _normalize(self, x):
        if self.normalize is None:
            return x
        return self.normalize(x)

    def __getitem__(self, i):
        # all splits share the same dataset base transform
        sample, target = self.dataset[i]
        # the first split is the 'clean' one: it is only normalized
        views = [self._normalize(sample)]
        # every remaining split runs the full augmentation before normalizing
        views.extend(self._normalize(self.augmentation(sample)) for _ in range(self.num_splits - 1))
        return tuple(views), target

    def __len__(self):
        return len(self.dataset)
147 |
--------------------------------------------------------------------------------
/timm/data/dataset_factory.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from .dataset import IterableImageDataset, ImageDataset
4 |
5 |
6 | def _search_split(root, split):
7 | # look for sub-folder with name of split in root and use that if it exists
8 | split_name = split.split('[')[0]
9 | try_root = os.path.join(root, split_name)
10 | if os.path.exists(try_root):
11 | return try_root
12 | if split_name == 'validation':
13 | try_root = os.path.join(root, 'val')
14 | if os.path.exists(try_root):
15 | return try_root
16 | return root
17 |
18 |
def create_dataset(name, root, split='validation', search_split=True, is_training=False, batch_size=None, **kwargs):
    """Build a dataset for ``name``.

    Names starting with 'tfds' (case-insensitive) produce an
    ``IterableImageDataset``; every other name produces an ``ImageDataset``,
    optionally rooted at the split sub-folder found by ``_search_split``.
    Extra keyword arguments are forwarded to the dataset class.
    """
    name = name.lower()
    if name.startswith('tfds'):
        return IterableImageDataset(
            root, parser=name, split=split, is_training=is_training, batch_size=batch_size, **kwargs)

    # FIXME support more advance split cfg for ImageFolder/Tar datasets in the future
    # FIXME currently only Iterable dataset support the repeat multiplier
    kwargs.pop('repeats', 0)
    if search_split and os.path.isdir(root):
        root = _search_split(root, split)
    return ImageDataset(root, parser=name, **kwargs)
31 |
--------------------------------------------------------------------------------
/timm/data/distributed_sampler.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch.utils.data import Sampler
4 | import torch.distributed as dist
5 |
6 |
class OrderedDistributedSampler(Sampler):
    """Sampler that restricts data loading to a subset of the dataset.

    Indices are handed out in order (no shuffling): replica ``rank`` receives
    ``rank, rank + num_replicas, rank + 2 * num_replicas, ...``. The index
    list is first padded, by cycling through the dataset from its start, so
    that every replica receives the same number of samples.

    .. note::
        Dataset is assumed to be of constant size.

    Arguments:
        dataset: Dataset used for sampling.
        num_replicas (optional): Number of processes participating in
            distributed training. Defaults to ``dist.get_world_size()``.
        rank (optional): Rank of the current process within num_replicas.
            Defaults to ``dist.get_rank()``.

    Raises:
        RuntimeError: If ``num_replicas`` or ``rank`` is omitted and the
            distributed package is not available.
    """

    def __init__(self, dataset, num_replicas=None, rank=None):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas

    def __iter__(self):
        indices = list(range(len(self.dataset)))

        # add extra samples to make it evenly divisible; the index list is
        # repeated as often as needed because the padding can be longer than
        # the dataset itself (dataset smaller than the number of replicas)
        padding = self.total_size - len(indices)
        if padding > 0:
            repeats = math.ceil(padding / len(indices))
            indices += (indices * repeats)[:padding]
        assert len(indices) == self.total_size

        # subsample
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self):
        return self.num_samples
52 |
--------------------------------------------------------------------------------
/timm/data/parsers/__init__.py:
--------------------------------------------------------------------------------
1 | from .parser_factory import create_parser
2 |
--------------------------------------------------------------------------------
/timm/data/parsers/class_map.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 |
def load_class_map(filename, root=''):
    """Load a class-name -> index mapping from a text file.

    Each line of the file is one class name (surrounding whitespace is
    stripped); the zero-based line number becomes its index.

    Args:
        filename: Path of the class map. If it does not exist as given, it is
            looked up relative to ``root``.
        root: Fallback directory for ``filename``.

    Returns:
        dict mapping each class name to its index.

    Raises:
        AssertionError: If the file cannot be located or its extension is not
            '.txt'. Raised explicitly so the checks also run under
            ``python -O``, where ``assert`` statements are removed.
    """
    class_map_path = filename
    if not os.path.exists(class_map_path):
        class_map_path = os.path.join(root, filename)
        if not os.path.exists(class_map_path):
            raise AssertionError('Cannot locate specified class map file (%s)' % filename)
    class_map_ext = os.path.splitext(filename)[-1].lower()
    if class_map_ext != '.txt':
        raise AssertionError('Unsupported class map extension')
    with open(class_map_path) as f:
        return {v.strip(): k for k, v in enumerate(f)}
16 |
17 |
--------------------------------------------------------------------------------
/timm/data/parsers/constants.py:
--------------------------------------------------------------------------------
# Lower-case file extensions (including the leading dot) treated as images.
IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg')
2 |
--------------------------------------------------------------------------------
/timm/data/parsers/parser.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 |
3 |
class Parser:
    """Base class for dataset parsers.

    Subclasses supply ``_filename`` and ``__len__``; the public ``filename``
    and ``filenames`` helpers are built on top of those two.
    """

    def __init__(self):
        pass

    @abstractmethod
    def _filename(self, index, basename=False, absolute=False):
        """Return the filename of sample ``index``; overridden by subclasses."""
        return None

    def filename(self, index, basename=False, absolute=False):
        """Return the filename of a single sample."""
        return self._filename(index, basename=basename, absolute=absolute)

    def filenames(self, basename=False, absolute=False):
        """Return the filenames of all samples, in index order."""
        names = []
        for position in range(len(self)):
            names.append(self._filename(position, basename=basename, absolute=absolute))
        return names
17 |
18 |
--------------------------------------------------------------------------------
/timm/data/parsers/parser_factory.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from .parser_image_folder import ParserImageFolder
4 | from .parser_image_tar import ParserImageTar
5 | from .parser_image_in_tar import ParserImageInTar
6 |
7 |
def create_parser(name, root, split='train', **kwargs):
    """Create the parser selected by ``name`` for the data at ``root``.

    ``name`` is lower-cased and split on '/': when there is more than one
    part, the first is the prefix and the last is the dataset name. The 'tfds'
    prefix selects ``ParserTfds``. Any other name falls back to
    ``ParserImageInTar`` when ``root`` is a '.tar' file, and to
    ``ParserImageFolder`` otherwise.
    """
    parts = name.lower().split('/', 2)
    prefix = parts[0] if len(parts) > 1 else ''
    name = parts[-1]

    # FIXME improve the selection right now just tfds prefix or fallback path, will need options to
    # explicitly select other options shortly
    if prefix == 'tfds':
        from .parser_tfds import ParserTfds  # defer tensorflow import
        shuffle = kwargs.pop('shuffle', False)
        return ParserTfds(root, name, split=split, shuffle=shuffle, **kwargs)

    assert os.path.exists(root)
    # default fallback path (backwards compat), use image tar if root is a .tar file, otherwise image folder
    # FIXME support split here, in parser?
    is_tar_file = os.path.isfile(root) and os.path.splitext(root)[1] == '.tar'
    parser_cls = ParserImageInTar if is_tar_file else ParserImageFolder
    return parser_cls(root, **kwargs)
30 |
--------------------------------------------------------------------------------
/timm/data/parsers/parser_image_folder.py:
--------------------------------------------------------------------------------
1 | """ A dataset parser that reads images from folders
2 |
3 | Folders are scanned recursively to find image files. Labels are based
4 | on the folder hierarchy, just leaf folders by default.
5 |
6 | Hacked together by / Copyright 2020 Ross Wightman
7 | """
8 | import os
9 |
10 | from timm.utils.misc import natural_key
11 |
12 | from .parser import Parser
13 | from .class_map import load_class_map
14 | from .constants import IMG_EXTENSIONS
15 |
16 |
def find_images_and_targets(folder, types=IMG_EXTENSIONS, class_to_idx=None, leaf_name_only=True, sort=True):
    """Walk ``folder`` and pair every image file with a class index.

    The label of a file is the name of its containing folder, or, when
    ``leaf_name_only`` is false, its folder path relative to ``folder`` with
    path separators replaced by '_'. When ``class_to_idx`` is None the mapping
    is built from the labels found, in natural sort order. Otherwise files
    whose label is missing from the mapping are dropped.

    Returns:
        ``(images_and_targets, class_to_idx)`` where the first item is a list
        of ``(path, class_index)`` tuples, naturally sorted by path when
        ``sort`` is true.
    """
    found = []  # (path, label) pairs in walk order
    for current, _subdirs, names in os.walk(folder, topdown=False, followlinks=True):
        relative = '' if current == folder else os.path.relpath(current, folder)
        if leaf_name_only:
            label = os.path.basename(relative)
        else:
            label = relative.replace(os.path.sep, '_')
        for entry in names:
            extension = os.path.splitext(entry)[1]
            if extension.lower() in types:
                found.append((os.path.join(current, entry), label))

    if class_to_idx is None:
        # building class index
        ordered_labels = sorted({label for _, label in found}, key=natural_key)
        class_to_idx = {label: position for position, label in enumerate(ordered_labels)}

    images_and_targets = [(path, class_to_idx[label]) for path, label in found if label in class_to_idx]
    if sort:
        images_and_targets.sort(key=lambda item: natural_key(item[0]))
    return images_and_targets, class_to_idx
37 |
38 |
class ParserImageFolder(Parser):
    """Parser that serves images found in the folder tree below ``root``.

    Indexing returns ``(binary file object, class index)``; the caller is
    given an open handle for each sample.
    """

    def __init__(
            self,
            root,
            class_map=''):
        super().__init__()

        self.root = root
        # an explicit class map, when given, takes the place of the folder-derived one
        class_to_idx = load_class_map(class_map, root) if class_map else None
        self.samples, self.class_to_idx = find_images_and_targets(root, class_to_idx=class_to_idx)
        if not self.samples:
            raise RuntimeError(
                f'Found 0 images in subfolders of {root}. Supported image extensions are {", ".join(IMG_EXTENSIONS)}')

    def __getitem__(self, index):
        image_path, class_index = self.samples[index]
        return open(image_path, 'rb'), class_index

    def __len__(self):
        return len(self.samples)

    def _filename(self, index, basename=False, absolute=False):
        """Return the sample path: basename only, as stored, or relative to ``root``."""
        image_path = self.samples[index][0]
        if basename:
            return os.path.basename(image_path)
        if absolute:
            return image_path
        return os.path.relpath(image_path, self.root)
70 |
--------------------------------------------------------------------------------
/timm/data/parsers/parser_image_tar.py:
--------------------------------------------------------------------------------
1 | """ A dataset parser that reads single tarfile based datasets
2 |
3 | This parser can read datasets consisting of a single tarfile containing images.
4 | I am planning to deprecate it in favour of ParserImageInTar.
5 |
6 | Hacked together by / Copyright 2020 Ross Wightman
7 | """
8 | import os
9 | import tarfile
10 |
11 | from .parser import Parser
12 | from .class_map import load_class_map
13 | from .constants import IMG_EXTENSIONS
14 | from timm.utils.misc import natural_key
15 |
16 |
def extract_tarinfo(tarfile, class_to_idx=None, sort=True):
    """Pair every image member of an open tar archive with a class index.

    The label of a member is the name of its immediate parent directory. When
    ``class_to_idx`` is None the mapping is built from the labels found, in
    natural sort order. Otherwise members whose label is missing from the
    mapping are dropped.

    Returns:
        ``(tarinfo_and_targets, class_to_idx)`` where the first item is a list
        of ``(TarInfo, class_index)`` tuples, naturally sorted by member path
        when ``sort`` is true.
    """
    found = []  # (member, label) pairs in archive order
    for member in tarfile.getmembers():
        if not member.isfile():
            continue
        parent, leaf = os.path.split(member.path)
        if os.path.splitext(leaf)[1].lower() in IMG_EXTENSIONS:
            found.append((member, os.path.basename(parent)))

    if class_to_idx is None:
        ordered_labels = sorted({label for _, label in found}, key=natural_key)
        class_to_idx = {label: position for position, label in enumerate(ordered_labels)}

    tarinfo_and_targets = [(member, class_to_idx[label]) for member, label in found if label in class_to_idx]
    if sort:
        tarinfo_and_targets.sort(key=lambda item: natural_key(item[0].path))
    return tarinfo_and_targets, class_to_idx
37 |
38 |
class ParserImageTar(Parser):
    """ Single tarfile dataset where classes are mapped to folders within tar
    NOTE: This class is being deprecated in favour of the more capable ParserImageInTar that can
    operate on folders of tars or tars in tars.
    """
    def __init__(self, root, class_map=''):
        super().__init__()

        class_to_idx = load_class_map(class_map, root) if class_map else None
        assert os.path.isfile(root)
        self.root = root

        # cannot keep this open across processes, reopen later
        with tarfile.open(root) as archive:
            self.samples, self.class_to_idx = extract_tarinfo(archive, class_to_idx)
        self.imgs = self.samples
        self.tarfile = None  # lazy init in __getitem__

    def __getitem__(self, index):
        """Return ``(file object extracted from the tar, class index)``."""
        if self.tarfile is None:
            # first access: open the archive and keep it for later calls
            self.tarfile = tarfile.open(self.root)
        member, class_index = self.samples[index]
        return self.tarfile.extractfile(member), class_index

    def __len__(self):
        return len(self.samples)

    def _filename(self, index, basename=False, absolute=False):
        """Return the member name inside the tar, optionally reduced to its basename."""
        member_name = self.samples[index][0].name
        return os.path.basename(member_name) if basename else member_name
73 |
--------------------------------------------------------------------------------
/timm/data/random_erasing.py:
--------------------------------------------------------------------------------
1 | """ Random Erasing (Cutout)
2 |
3 | Originally inspired by impl at https://github.com/zhunzhong07/Random-Erasing, Apache 2.0
4 | Copyright Zhun Zhong & Liang Zheng
5 |
6 | Hacked together by / Copyright 2020 Ross Wightman
7 | """
8 | import random
9 | import math
10 | import torch
11 |
12 |
def _get_pixels(per_pixel, rand_color, patch_size, dtype=torch.float32, device='cuda'):
    """Create the fill values for an erased patch of shape `patch_size` (chan, h, w).

    `per_pixel` yields a full-size normally distributed patch, `rand_color` yields one
    normally distributed value per channel (shape (chan, 1, 1)), otherwise zeros of
    shape (chan, 1, 1) are returned. `per_pixel` takes precedence over `rand_color`.
    """
    # NOTE I've seen CUDA illegal memory access errors being caused by the normal_()
    # paths, flip the order so normal is run on CPU if this becomes a problem
    # Issue has been fixed in master https://github.com/pytorch/pytorch/issues/19508
    if not (per_pixel or rand_color):
        return torch.zeros((patch_size[0], 1, 1), dtype=dtype, device=device)
    fill_shape = patch_size if per_pixel else (patch_size[0], 1, 1)
    return torch.empty(fill_shape, dtype=dtype, device=device).normal_()


class RandomErasing:
    """ Randomly selects a rectangle region in an image and erases its pixels.
        'Random Erasing Data Augmentation' by Zhong et al.
        See https://arxiv.org/pdf/1708.04896.pdf

        This variant of RandomErasing is intended to be applied to either a batch
        or single image tensor after it has been normalized by dataset mean and std.
    Args:
         probability: Probability that the Random Erasing operation will be performed.
         min_area: Minimum percentage of erased area wrt input image area.
         max_area: Maximum percentage of erased area wrt input image area.
         min_aspect: Minimum aspect ratio of erased area.
         max_aspect: Maximum aspect ratio of erased area, defaults to 1 / min_aspect.
         mode: pixel color mode, one of 'const', 'rand', or 'pixel'
            'const' - erase block is constant color of 0 for all channels
            'rand'  - erase block is same per-channel random (normal) color
            'pixel' - erase block is per-pixel random (normal) color
         min_count: minimum number of erasing blocks per image.
         max_count: maximum number of erasing blocks per image, area per box is scaled by count.
            per-image count is randomly chosen between min_count and this value
            (defaults to min_count).
         num_splits: when > 1, the first batch_size // num_splits samples of a batch are
            left untouched.
         device: device on which the fill values are created.
    """

    def __init__(
            self,
            probability=0.5, min_area=0.02, max_area=1/3, min_aspect=0.3, max_aspect=None,
            mode='const', min_count=1, max_count=None, num_splits=0, device='cuda'):
        self.probability = probability
        self.min_area, self.max_area = min_area, max_area
        if not max_aspect:
            max_aspect = 1 / min_aspect
        # aspect ratios are sampled uniformly in log space
        self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))
        self.min_count = min_count
        self.max_count = max_count if max_count else min_count
        self.num_splits = num_splits
        mode = mode.lower()
        assert mode in ('rand', 'pixel') or not mode or mode == 'const'
        self.rand_color = mode == 'rand'  # per block random normal
        self.per_pixel = mode == 'pixel'  # per pixel random normal
        self.device = device

    def _erase(self, img, chan, img_h, img_w, dtype):
        """Erase up to `count` boxes of the (chan, img_h, img_w) tensor `img` in place."""
        if random.random() > self.probability:
            return
        area = img_h * img_w
        if self.min_count == self.max_count:
            num_boxes = self.min_count
        else:
            num_boxes = random.randint(self.min_count, self.max_count)
        for _ in range(num_boxes):
            attempts_left = 10
            while attempts_left > 0:
                attempts_left -= 1
                box_area = random.uniform(self.min_area, self.max_area) * area / num_boxes
                box_aspect = math.exp(random.uniform(*self.log_aspect_ratio))
                box_h = int(round(math.sqrt(box_area * box_aspect)))
                box_w = int(round(math.sqrt(box_area / box_aspect)))
                if box_w >= img_w or box_h >= img_h:
                    # box does not fit strictly inside the image, draw again
                    continue
                top = random.randint(0, img_h - box_h)
                left = random.randint(0, img_w - box_w)
                img[:, top:top + box_h, left:left + box_w] = _get_pixels(
                    self.per_pixel, self.rand_color, (chan, box_h, box_w),
                    dtype=dtype, device=self.device)
                break

    def __call__(self, input):
        """Apply erasing in place to a 3-d image tensor or a 4-d batch; returns `input`."""
        if input.dim() == 3:
            chan, img_h, img_w = input.size()
            self._erase(input, chan, img_h, img_w, input.dtype)
            return input
        batch_size, chan, img_h, img_w = input.size()
        # skip first slice of batch if num_splits is set (for clean portion of samples)
        first_erased = batch_size // self.num_splits if self.num_splits > 1 else 0
        for sample_idx in range(first_erased, batch_size):
            self._erase(input[sample_idx], chan, img_h, img_w, input.dtype)
        return input
98 |
--------------------------------------------------------------------------------
/timm/data/real_labels.py:
--------------------------------------------------------------------------------
1 | """ Real labels evaluator for ImageNet
2 | Paper: `Are we done with ImageNet?` - https://arxiv.org/abs/2006.07159
3 | Based on Numpy example at https://github.com/google-research/reassessed-imagenet
4 |
5 | Hacked together by / Copyright 2020 Ross Wightman
6 | """
7 | import os
8 | import json
9 | import numpy as np
10 |
11 |
class RealLabelsImagenet:
    """Accumulates top-k correctness of predictions against the labels stored in a JSON file."""

    def __init__(self, filenames, real_json='real.json', topk=(1, 5)):
        """Load the label lists from `real_json` and key them by validation file name.

        The i-th entry (1-based) of the JSON list is stored under
        ``ILSVRC2012_val_{i:08d}.JPEG``. `filenames` must have as many entries as the
        JSON list; results are matched to it in the order they are added.
        """
        with open(real_json) as handle:
            per_image_labels = json.load(handle)
        self.real_labels = {
            f'ILSVRC2012_val_{position:08d}.JPEG': labels
            for position, labels in enumerate(per_image_labels, start=1)}
        self.filenames = filenames
        assert len(self.filenames) == len(self.real_labels)
        self.topk = topk
        self.is_correct = {k: [] for k in topk}
        self.sample_idx = 0

    def add_result(self, output):
        """Record correctness for one batch of model outputs (one row per sample)."""
        maxk = max(self.topk)
        _, pred_batch = output.topk(maxk, 1, True, True)
        for pred in pred_batch.cpu().numpy():
            key = os.path.basename(self.filenames[self.sample_idx])
            accepted = self.real_labels[key]
            # samples with an empty label list are not scored, but still advance the index
            if accepted:
                for k in self.topk:
                    self.is_correct[k].append(any(p in accepted for p in pred[:k]))
            self.sample_idx += 1

    def get_accuracy(self, k=None):
        """Return accuracy in percent for `k`, or a dict over all configured k when k is None."""
        if k is not None:
            return float(np.mean(self.is_correct[k])) * 100
        return {top: float(np.mean(self.is_correct[top])) * 100 for top in self.topk}
43 |
--------------------------------------------------------------------------------
/timm/data/transforms.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision.transforms.functional as F
3 | from PIL import Image
4 | import warnings
5 | import math
6 | import random
7 | import numpy as np
8 |
9 |
class ToNumpy:
    """Convert an image to a uint8 numpy array in channel-first (CHW) layout."""

    def __call__(self, pil_img):
        pixels = np.array(pil_img, dtype=np.uint8)
        if pixels.ndim < 3:
            # add a trailing channel axis for images without one
            pixels = np.expand_dims(pixels, axis=-1)
        return np.moveaxis(pixels, 2, 0)  # HWC to CHW
18 |
19 |
class ToTensor:
    """Convert an image to a channel-first (CHW) torch tensor of the configured dtype.

    Values are read as uint8 and cast to `dtype`; no rescaling is applied.
    """

    def __init__(self, dtype=torch.float32):
        self.dtype = dtype

    def __call__(self, pil_img):
        pixels = np.array(pil_img, dtype=np.uint8)
        if pixels.ndim < 3:
            # add a trailing channel axis for images without one
            pixels = np.expand_dims(pixels, axis=-1)
        chw = np.moveaxis(pixels, 2, 0)  # HWC to CHW
        return torch.from_numpy(chw).to(dtype=self.dtype)
31 |
32 |
# Printable names for PIL resampling constants; looked up when building the
# repr string of RandomResizedCropAndInterpolation.
_pil_interpolation_to_str = {
    Image.NEAREST: 'PIL.Image.NEAREST',
    Image.BILINEAR: 'PIL.Image.BILINEAR',
    Image.BICUBIC: 'PIL.Image.BICUBIC',
    Image.LANCZOS: 'PIL.Image.LANCZOS',
    Image.HAMMING: 'PIL.Image.HAMMING',
    Image.BOX: 'PIL.Image.BOX',
}
41 |
42 |
def _pil_interp(method):
    """Map an interpolation name to the matching PIL resampling constant.

    Recognised names are 'bicubic', 'lanczos' and 'hamming'; anything else
    falls back to bilinear.
    """
    named_filters = (
        ('bicubic', Image.BICUBIC),
        ('lanczos', Image.LANCZOS),
        ('hamming', Image.HAMMING),
    )
    for name, resample in named_filters:
        if method == name:
            return resample
    # default bilinear, do we want to allow nearest?
    return Image.BILINEAR
53 |
54 |
# Filters that RandomResizedCropAndInterpolation picks from per call when
# constructed with interpolation='random'.
_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)
56 |
57 |
class RandomResizedCropAndInterpolation:
    """Crop the given PIL Image to random size and aspect ratio with random interpolation.

    A crop of random size (default: of 0.08 to 1.0) of the original size and a random
    aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop
    is finally resized to given size.
    This is popularly used to train the Inception networks.

    Args:
        size: expected output size of each edge
        scale: range of size of the origin size cropped
        ratio: range of aspect ratio of the origin aspect ratio cropped
        interpolation: interpolation name, or 'random' to pick a filter per call.
            Default: 'bilinear'
    """

    def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.),
                 interpolation='bilinear'):
        self.size = tuple(size) if isinstance(size, (list, tuple)) else (size, size)
        if scale[0] > scale[1] or ratio[0] > ratio[1]:
            warnings.warn("range should be of kind (min, max)")

        if interpolation == 'random':
            self.interpolation = _RANDOM_INTERPOLATION
        else:
            self.interpolation = _pil_interp(interpolation)
        self.scale = scale
        self.ratio = ratio

    @staticmethod
    def get_params(img, scale, ratio):
        """Get parameters for ``crop`` for a random sized crop.

        Args:
            img (PIL Image): Image to be cropped.
            scale (tuple): range of size of the origin size cropped
            ratio (tuple): range of aspect ratio of the origin aspect ratio cropped

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for a random
                sized crop.
        """
        img_w, img_h = img.size[0], img.size[1]
        area = img_w * img_h
        # aspect ratios are sampled uniformly in log space
        log_ratio = (math.log(ratio[0]), math.log(ratio[1]))

        for _ in range(10):
            target_area = random.uniform(*scale) * area
            aspect_ratio = math.exp(random.uniform(*log_ratio))

            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))

            if w > img_w or h > img_h:
                continue
            i = random.randint(0, img_h - h)
            j = random.randint(0, img_w - w)
            return i, j, h, w

        # Fallback to central crop
        in_ratio = img_w / img_h
        lowest, highest = min(ratio), max(ratio)
        if in_ratio < lowest:
            w = img_w
            h = int(round(w / lowest))
        elif in_ratio > highest:
            h = img_h
            w = int(round(h * highest))
        else:  # whole image
            w, h = img_w, img_h
        return (img_h - h) // 2, (img_w - w) // 2, h, w

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be cropped and resized.

        Returns:
            PIL Image: Randomly cropped and resized image.
        """
        i, j, h, w = self.get_params(img, self.scale, self.ratio)
        interpolation = self.interpolation
        if isinstance(interpolation, (tuple, list)):
            interpolation = random.choice(interpolation)
        return F.resized_crop(img, i, j, h, w, self.size, interpolation)

    def __repr__(self):
        if isinstance(self.interpolation, (tuple, list)):
            interpolate_str = ' '.join(_pil_interpolation_to_str[x] for x in self.interpolation)
        else:
            interpolate_str = _pil_interpolation_to_str[self.interpolation]
        pieces = [
            self.__class__.__name__ + '(size={0}'.format(self.size),
            ', scale={0}'.format(tuple(round(s, 4) for s in self.scale)),
            ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio)),
            ', interpolation={0})'.format(interpolate_str),
        ]
        return ''.join(pieces)
157 |
158 |
159 |
--------------------------------------------------------------------------------
/timm/loss/__init__.py:
--------------------------------------------------------------------------------
1 | from .cross_entropy import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
2 | from .jsd import JsdCrossEntropy
3 | from .asymmetric_loss import AsymmetricLossMultiLabel, AsymmetricLossSingleLabel
--------------------------------------------------------------------------------
/timm/loss/asymmetric_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
class AsymmetricLossMultiLabel(nn.Module):
    """Asymmetric loss for multi-label targets.

    Computes a sigmoid cross-entropy where positive and negative terms are
    down-weighted by ``(1 - pt) ** gamma`` with separate exponents, and where
    the negative probability is shifted by ``clip`` before the log.

    Args:
        gamma_neg: focusing exponent applied to negative targets.
        gamma_pos: focusing exponent applied to positive targets.
        clip: value added to the negative probability (result clamped to at most 1);
            disabled when None or <= 0.
        eps: lower clamp applied to probabilities before taking the log.
        disable_torch_grad_focal_loss: when true, the focusing weight is computed
            without autograd tracking, so no gradient flows through it.
    """
    def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8, disable_torch_grad_focal_loss=False):
        super(AsymmetricLossMultiLabel, self).__init__()

        self.gamma_neg = gamma_neg
        self.gamma_pos = gamma_pos
        self.clip = clip
        self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss
        self.eps = eps

    def _focusing_weight(self, xs_pos, xs_neg, y):
        """Return the per-element weight (1 - pt) ** gamma for the asymmetric focusing term."""
        pt = xs_pos * y + xs_neg * (1 - y)  # pt = p if t > 0 else 1-p
        one_sided_gamma = self.gamma_pos * y + self.gamma_neg * (1 - y)
        return torch.pow(1 - pt, one_sided_gamma)

    def forward(self, x, y):
        """Compute the summed loss.

        Parameters
        ----------
        x: input logits
        y: targets (multi-label binarized vector)

        Returns
        -------
        The loss summed over all elements.
        """
        # Calculating Probabilities
        xs_pos = torch.sigmoid(x)
        xs_neg = 1 - xs_pos

        # Asymmetric Clipping
        if self.clip is not None and self.clip > 0:
            xs_neg = (xs_neg + self.clip).clamp(max=1)

        # Basic CE calculation
        los_pos = y * torch.log(xs_pos.clamp(min=self.eps))
        los_neg = (1 - y) * torch.log(xs_neg.clamp(min=self.eps))
        loss = los_pos + los_neg

        # Asymmetric Focusing
        if self.gamma_neg > 0 or self.gamma_pos > 0:
            if self.disable_torch_grad_focal_loss:
                # A scoped no_grad restores whatever grad mode the caller had, instead of
                # unconditionally switching autograd back on afterwards.
                with torch.no_grad():
                    one_sided_w = self._focusing_weight(xs_pos, xs_neg, y)
            else:
                one_sided_w = self._focusing_weight(xs_pos, xs_neg, y)
            loss = loss * one_sided_w

        return -loss.sum()
51 |
52 |
class AsymmetricLossSingleLabel(nn.Module):
    """Asymmetric loss for single-label (class index) targets.

    Log-softmax predictions are weighted by ``(1 - pt) ** gamma`` with separate
    exponents for the target class and the other classes, then combined with
    (optionally label-smoothed) one-hot targets.

    Args:
        gamma_pos: focusing exponent for the target class.
        gamma_neg: focusing exponent for non-target classes.
        eps: label smoothing factor; smoothing is applied only when > 0.
        reduction: 'mean' averages over samples; any other value returns per-sample losses.
    """
    def __init__(self, gamma_pos=1, gamma_neg=4, eps: float = 0.1, reduction='mean'):
        super(AsymmetricLossSingleLabel, self).__init__()

        self.eps = eps
        self.logsoftmax = nn.LogSoftmax(dim=-1)
        self.targets_classes = []  # placeholder, replaced by the target matrix on each forward call
        self.gamma_pos = gamma_pos
        self.gamma_neg = gamma_neg
        self.reduction = reduction

    def forward(self, inputs, target, reduction=None):
        """Compute the loss.

        Parameters
        ----------
        inputs: input logits, one row per sample
        target: class index per sample
        reduction: optional per-call override of the reduction configured on the
            module; None keeps the module setting

        Returns
        -------
        Scalar mean loss when the effective reduction is 'mean', otherwise the
        per-sample losses.
        """
        reduction = self.reduction if reduction is None else reduction

        num_classes = inputs.size()[-1]
        log_preds = self.logsoftmax(inputs)
        one_hot = torch.zeros_like(inputs).scatter_(1, target.long().unsqueeze(1), 1)

        # ASL weights
        anti_targets = 1 - one_hot
        xs_pos = torch.exp(log_preds)
        xs_neg = 1 - xs_pos
        xs_pos = xs_pos * one_hot
        xs_neg = xs_neg * anti_targets
        asymmetric_w = torch.pow(1 - xs_pos - xs_neg,
                                 self.gamma_pos * one_hot + self.gamma_neg * anti_targets)
        log_preds = log_preds * asymmetric_w

        # Label smoothing is done out of place: `one_hot` took part in the multiplications
        # above, and modifying it in place would invalidate what autograd saved for backward.
        if self.eps > 0:
            self.targets_classes = one_hot.mul(1 - self.eps).add(self.eps / num_classes)
        else:
            self.targets_classes = one_hot

        # loss calculation
        loss = -self.targets_classes.mul(log_preds)
        loss = loss.sum(dim=-1)
        if reduction == 'mean':
            loss = loss.mean()

        return loss
98 |
--------------------------------------------------------------------------------
/timm/loss/cross_entropy.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
class LabelSmoothingCrossEntropy(nn.Module):
    """
    NLL loss with label smoothing.
    """
    def __init__(self, smoothing=0.1):
        """
        Constructor for the LabelSmoothing module.
        :param smoothing: label smoothing factor, must be below 1.0
        """
        super(LabelSmoothingCrossEntropy, self).__init__()
        assert smoothing < 1.0
        self.smoothing = smoothing
        self.confidence = 1. - smoothing

    def forward(self, x, target):
        """Blend the target-class NLL with the mean negative log-probability over classes.

        :param x: logits, one row per sample
        :param target: class index per sample
        :return: loss averaged over samples
        """
        log_probs = F.log_softmax(x, dim=-1)
        target_nll = log_probs.gather(dim=-1, index=target.unsqueeze(1)).neg().squeeze(1)
        uniform_nll = log_probs.mean(dim=-1).neg()
        blended = self.confidence * target_nll + self.smoothing * uniform_nll
        return blended.mean()
27 |
28 |
class SoftTargetCrossEntropy(nn.Module):
    """Cross-entropy against a per-class target weight vector instead of a class index."""

    def __init__(self):
        super(SoftTargetCrossEntropy, self).__init__()

    def forward(self, x, target):
        """Return the mean over samples of ``sum(-target * log_softmax(x))`` along the last dim."""
        log_probs = F.log_softmax(x, dim=-1)
        per_sample = (-target * log_probs).sum(dim=-1)
        return per_sample.mean()
37 |
--------------------------------------------------------------------------------
/timm/loss/jsd.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from .cross_entropy import LabelSmoothingCrossEntropy
6 |
7 |
class JsdCrossEntropy(nn.Module):
    """ Jensen-Shannon Divergence + Cross-Entropy Loss

    Based on impl here: https://github.com/google-research/augmix/blob/master/imagenet.py
    From paper: 'AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty' -
    https://arxiv.org/abs/1912.02781

    Hacked together by / Copyright 2020 Ross Wightman
    """
    def __init__(self, num_splits=3, alpha=12, smoothing=0.1):
        """
        Args:
            num_splits: number of equal chunks the batch is divided into.
            alpha: multiplier for the divergence term.
            smoothing: label smoothing for the cross-entropy term; plain
                cross-entropy is used when None or <= 0.
        """
        super().__init__()
        self.num_splits = num_splits
        self.alpha = alpha
        use_smoothing = smoothing is not None and smoothing > 0
        if use_smoothing:
            self.cross_entropy_loss = LabelSmoothingCrossEntropy(smoothing)
        else:
            self.cross_entropy_loss = torch.nn.CrossEntropyLoss()

    def __call__(self, output, target):
        """Return cross-entropy on the first split plus alpha times the mean KL to the mixture.

        The batch dimension of `output` must be divisible by `num_splits`.
        """
        batch = output.shape[0]
        split_size = batch // self.num_splits
        assert split_size * self.num_splits == batch
        logits_split = torch.split(output, split_size)

        # Cross-entropy is only computed on clean images
        loss = self.cross_entropy_loss(logits_split[0], target[:split_size])
        probs = [F.softmax(chunk, dim=1) for chunk in logits_split]

        # Clamp mixture distribution to avoid exploding KL divergence
        mixture = torch.stack(probs).mean(axis=0)
        logp_mixture = torch.clamp(mixture, 1e-7, 1).log()
        kl_total = sum(F.kl_div(logp_mixture, p_split, reduction='batchmean') for p_split in probs)
        loss += self.alpha * kl_total / len(probs)
        return loss
40 |
--------------------------------------------------------------------------------
/timm/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .byoanet import *
2 | from .byobnet import *
3 | from .cait import *
4 | from .coat import *
5 | from .convit import *
6 | from .cspnet import *
7 | from .densenet import *
8 | from .dla import *
9 | from .dpn import *
10 | from .efficientnet import *
11 | from .ghostnet import *
12 | from .gluon_resnet import *
13 | from .gluon_xception import *
14 | from .hardcorenas import *
15 | from .hrnet import *
16 | from .inception_resnet_v2 import *
17 | from .inception_v3 import *
18 | from .inception_v4 import *
19 | from .levit import *
20 | from .mlp_mixer import *
21 | from .mobilenetv3 import *
22 | from .nasnet import *
23 | from .nfnet import *
24 | from .pit import *
25 | from .pnasnet import *
26 | from .regnet import *
27 | from .res2net import *
28 | from .resnest import *
29 | from .resnet import *
30 | from .resnetv2 import *
31 | from .rexnet import *
32 | from .selecsls import *
33 | from .senet import *
34 | from .sknet import *
35 | from .swin_transformer import *
36 | from .tnt import *
37 | from .tresnet import *
38 | from .vgg import *
39 | from .visformer import *
40 | from .vision_transformer import *
41 | from .vision_transformer_hybrid import *
42 | from .vovnet import *
43 | from .xception import *
44 | from .xception_aligned import *
45 | from .twins import *
46 |
47 | from .factory import create_model, split_model_name, safe_model_name
48 | from .helpers import load_checkpoint, resume_checkpoint, model_parameters
49 | from .layers import TestTimePoolHead, apply_test_time_pool
50 | from .layers import convert_splitbn_model
51 | from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit
52 | from .registry import register_model, model_entrypoint, list_models, is_model, list_modules, is_model_in_modules,\
53 | has_model_default_key, is_model_default_key, get_model_default_value, is_model_pretrained
54 |
--------------------------------------------------------------------------------
/timm/models/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/timm/models/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/timm/models/__pycache__/byoanet.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/__pycache__/byoanet.cpython-37.pyc
--------------------------------------------------------------------------------
/timm/models/__pycache__/byoanet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/__pycache__/byoanet.cpython-38.pyc
--------------------------------------------------------------------------------
/timm/models/__pycache__/byobnet.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/__pycache__/byobnet.cpython-37.pyc
--------------------------------------------------------------------------------
/timm/models/__pycache__/byobnet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/__pycache__/byobnet.cpython-38.pyc
--------------------------------------------------------------------------------
/timm/models/__pycache__/cait.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/__pycache__/cait.cpython-37.pyc
--------------------------------------------------------------------------------
/timm/models/__pycache__/cait.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/__pycache__/cait.cpython-38.pyc
--------------------------------------------------------------------------------
/timm/models/factory.py:
--------------------------------------------------------------------------------
1 | from .registry import is_model, is_model_in_modules, model_entrypoint
2 | from .helpers import load_checkpoint
3 | from .layers import set_layer_config
4 | from .hub import load_model_config_from_hf
5 |
6 |
def split_model_name(model_name):
    """Split an optional ``source:`` prefix off a model name.

    Returns ``(source_name, model_name)``; ``source_name`` is '' when no ':' is
    present. A present source must be 'timm' or 'hf_hub'.
    """
    source_name, separator, remainder = model_name.partition(':')
    if not separator:
        return '', source_name
    assert source_name in ('timm', 'hf_hub')
    return source_name, remainder
15 |
16 |
def safe_model_name(model_name, remove_source=True):
    """Return `model_name` with every non-alphanumeric character replaced by '_'.

    Trailing underscores are stripped. When `remove_source` is true, a leading
    ``source:`` prefix is dropped first via `split_model_name`.
    """
    if remove_source:
        model_name = split_model_name(model_name)[-1]
    sanitized = [ch if ch.isalnum() else '_' for ch in model_name]
    return ''.join(sanitized).rstrip('_')
23 |
24 |
def create_model(
        model_name,
        pretrained=False,
        checkpoint_path='',
        scriptable=None,
        exportable=None,
        no_jit=None,
        **kwargs):
    """Create a model

    Args:
        model_name (str): name of model to instantiate
        pretrained (bool): load pretrained ImageNet-1k weights if true
        checkpoint_path (str): path of checkpoint to load after model is initialized
        scriptable (bool): set layer config so that model is jit scriptable (not working for all models yet)
        exportable (bool): set layer config so that model is traceable / ONNX exportable (not fully impl/obeyed yet)
        no_jit (bool): set layer config so that model doesn't utilize jit scripted layers (so far activations only)

    Keyword Args:
        drop_rate (float): dropout rate for training (default: 0.0)
        global_pool (str): global pool type (default: 'avg')
        **: other kwargs are model specific

    Raises:
        RuntimeError: if the resolved model name is not registered.
    """
    source_name, model_name = split_model_name(model_name)

    # Only EfficientNet and MobileNetV3 models have support for batchnorm params or drop_connect_rate passed as args
    if not is_model_in_modules(model_name, ['efficientnet', 'mobilenetv3']):
        for bn_key in ('bn_tf', 'bn_momentum', 'bn_eps'):
            kwargs.pop(bn_key, None)

    # handle backwards compat with drop_connect -> drop_path change
    drop_connect_rate = kwargs.pop('drop_connect_rate', None)
    if drop_connect_rate is not None and kwargs.get('drop_path_rate', None) is None:
        print("WARNING: 'drop_connect' as an argument is deprecated, please use 'drop_path'."
              " Setting drop_path to %f." % drop_connect_rate)
        kwargs['drop_path_rate'] = drop_connect_rate

    # Parameters that aren't supported by all models or are intended to only override model defaults if set
    # should default to None in command line args/cfg. Remove them if they are present and not set so that
    # non-supporting models don't break and default args remain in effect.
    kwargs = {key: value for key, value in kwargs.items() if value is not None}

    if source_name == 'hf_hub':
        # For model names specified in the form `hf_hub:path/architecture_name@revision`,
        # load model weights + default_cfg from Hugging Face hub.
        hf_default_cfg, model_name = load_model_config_from_hf(model_name)
        kwargs['external_default_cfg'] = hf_default_cfg  # FIXME revamp default_cfg interface someday

    if not is_model(model_name):
        raise RuntimeError('Unknown model (%s)' % model_name)
    create_fn = model_entrypoint(model_name)

    with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit):
        model = create_fn(pretrained=pretrained, **kwargs)

    if checkpoint_path:
        load_checkpoint(model, checkpoint_path)

    return model
87 |
--------------------------------------------------------------------------------
/timm/models/hub.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | import os
4 | from functools import partial
5 | from typing import Union, Optional
6 |
7 | import torch
8 | from torch.hub import load_state_dict_from_url, download_url_to_file, urlparse, HASH_REGEX
9 | try:
10 | from torch.hub import get_dir
11 | except ImportError:
12 | from torch.hub import _get_torch_home as get_dir
13 |
14 | from timm import __version__
15 | try:
16 | from huggingface_hub import hf_hub_url
17 | from huggingface_hub import cached_download
18 | cached_download = partial(cached_download, library_name="timm", library_version=__version__)
19 | except ImportError:
20 | hf_hub_url = None
21 | cached_download = None
22 |
23 | _logger = logging.getLogger(__name__)
24 |
25 |
def get_cache_dir(child_dir=''):
    """
    Returns the location of the directory where models are cached (and creates it if necessary).

    The path is ``<hub dir>/checkpoints`` with `child_dir` appended when it is non-empty.
    """
    # Issue warning to move data if old env is set
    if os.getenv('TORCH_MODEL_ZOO'):
        _logger.warning('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead')

    path_parts = [get_dir(), 'checkpoints']
    if child_dir:
        path_parts.append(child_dir)
    model_dir = os.path.join(*path_parts)
    os.makedirs(model_dir, exist_ok=True)
    return model_dir
39 |
40 |
def download_cached_file(url, check_hash=True, progress=False):
    """Download `url` into the cache directory unless a file of that name already exists.

    When `check_hash` is true and the file name matches HASH_REGEX, the captured
    prefix is passed to the downloader. Returns the path of the cached file.
    """
    filename = os.path.basename(urlparse(url).path)
    cached_file = os.path.join(get_cache_dir(), filename)
    if os.path.exists(cached_file):
        return cached_file

    _logger.info('Downloading: "{}" to {}\n'.format(url, cached_file))
    hash_prefix = None
    if check_hash:
        match = HASH_REGEX.search(filename)  # match is Optional[Match[str]]
        if match:
            hash_prefix = match.group(1)
    download_url_to_file(url, cached_file, hash_prefix, progress=progress)
    return cached_file
53 |
54 |
def has_hf_hub(necessary=False):
    """Report whether the Hugging Face hub helpers were imported.

    Raises RuntimeError instead of returning False when `necessary` is true.
    """
    available = hf_hub_url is not None
    if necessary and not available:
        # if no HF Hub module installed and it is necessary to continue, raise error
        raise RuntimeError(
            'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.')
    return available
61 |
62 |
def hf_split(hf_id):
    """Split ``model_id[@revision]`` into ``(model_id, revision)``; revision is None when absent."""
    pieces = hf_id.split('@')
    assert 0 < len(pieces) <= 2, 'hf_hub id should only contain one @ character to identify revision.'
    hf_model_id, *trailing = pieces
    hf_revision = trailing[-1] if trailing else None
    return hf_model_id, hf_revision
69 |
70 |
def load_cfg_from_json(json_file: Union[str, os.PathLike]):
    """Read the UTF-8 encoded JSON file at `json_file` and return the parsed object."""
    with open(json_file, "r", encoding="utf-8") as reader:
        return json.load(reader)
75 |
76 |
def _download_from_hf(model_id: str, filename: str):
    """Fetch `filename` for the hub id `model_id` (optionally ``id@revision``) into the 'hf' cache dir.

    Returns whatever `cached_download` returns for the resolved URL.
    """
    repo_id, revision = hf_split(model_id)
    file_url = hf_hub_url(repo_id, filename, revision=revision)
    return cached_download(file_url, cache_dir=get_cache_dir('hf'))
81 |
82 |
def load_model_config_from_hf(model_id: str):
    """Download and parse ``config.json`` for `model_id`.

    Returns ``(default_cfg, model_name)`` where `model_name` is the config's
    'architecture' entry (None if missing).
    """
    assert has_hf_hub(True)
    default_cfg = load_cfg_from_json(_download_from_hf(model_id, 'config.json'))
    # insert hf_hub id for pretrained weight load during model creation
    default_cfg['hf_hub'] = model_id
    return default_cfg, default_cfg.get('architecture')
90 |
91 |
def load_state_dict_from_hf(model_id: str):
    """Download ``pytorch_model.bin`` for `model_id` and load it onto the CPU."""
    assert has_hf_hub(True)
    weights_path = _download_from_hf(model_id, 'pytorch_model.bin')
    # NOTE(review): torch.load deserializes a downloaded file; presumably only trusted
    # hub repositories are intended here - confirm.
    return torch.load(weights_path, map_location='cpu')
97 |
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/activations.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/activations.cpython-37.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/activations.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/activations.cpython-38.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/activations_jit.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/activations_jit.cpython-37.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/activations_jit.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/activations_jit.cpython-38.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/activations_me.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/activations_me.cpython-37.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/activations_me.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/activations_me.cpython-38.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/adaptive_avgmax_pool.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/adaptive_avgmax_pool.cpython-37.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/adaptive_avgmax_pool.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/adaptive_avgmax_pool.cpython-38.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/blur_pool.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/blur_pool.cpython-37.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/blur_pool.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/blur_pool.cpython-38.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/bottleneck_attn.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/bottleneck_attn.cpython-37.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/bottleneck_attn.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/bottleneck_attn.cpython-38.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/cbam.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/cbam.cpython-37.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/cbam.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/cbam.cpython-38.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/classifier.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/classifier.cpython-37.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/classifier.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/classifier.cpython-38.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/cond_conv2d.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/cond_conv2d.cpython-37.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/cond_conv2d.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/cond_conv2d.cpython-38.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/config.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/config.cpython-37.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/config.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/config.cpython-38.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/conv2d_same.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/conv2d_same.cpython-37.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/conv2d_same.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/conv2d_same.cpython-38.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/conv_bn_act.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/conv_bn_act.cpython-37.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/conv_bn_act.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/conv_bn_act.cpython-38.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/create_act.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/create_act.cpython-37.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/create_act.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/create_act.cpython-38.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/create_attn.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/create_attn.cpython-37.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/create_attn.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/create_attn.cpython-38.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/create_conv2d.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/create_conv2d.cpython-37.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/create_conv2d.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/create_conv2d.cpython-38.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/create_norm_act.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/create_norm_act.cpython-37.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/create_norm_act.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/create_norm_act.cpython-38.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/drop.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/drop.cpython-37.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/drop.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/drop.cpython-38.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/eca.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/eca.cpython-37.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/eca.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/eca.cpython-38.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/evo_norm.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/evo_norm.cpython-37.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/evo_norm.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/evo_norm.cpython-38.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/gather_excite.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/gather_excite.cpython-37.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/gather_excite.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/gather_excite.cpython-38.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/global_context.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/global_context.cpython-37.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/global_context.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/global_context.cpython-38.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/halo_attn.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/halo_attn.cpython-37.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/halo_attn.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/halo_attn.cpython-38.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/helpers.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/helpers.cpython-37.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/helpers.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/helpers.cpython-38.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/inplace_abn.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/inplace_abn.cpython-37.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/inplace_abn.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/inplace_abn.cpython-38.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/involution.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/involution.cpython-37.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/involution.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/involution.cpython-38.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/lambda_layer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/lambda_layer.cpython-37.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/lambda_layer.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/lambda_layer.cpython-38.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/linear.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/linear.cpython-37.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/linear.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/linear.cpython-38.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/mixed_conv2d.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/mixed_conv2d.cpython-37.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/mixed_conv2d.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/mixed_conv2d.cpython-38.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/mlp.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/mlp.cpython-37.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/mlp.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/mlp.cpython-38.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/non_local_attn.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/non_local_attn.cpython-37.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/non_local_attn.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/non_local_attn.cpython-38.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/norm.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/norm.cpython-37.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/norm.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/norm.cpython-38.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/norm_act.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/norm_act.cpython-37.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/norm_act.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/norm_act.cpython-38.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/padding.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/padding.cpython-37.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/padding.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/padding.cpython-38.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/patch_embed.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/patch_embed.cpython-37.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/patch_embed.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/patch_embed.cpython-38.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/pool2d_same.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/pool2d_same.cpython-37.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/pool2d_same.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/pool2d_same.cpython-38.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/selective_kernel.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/selective_kernel.cpython-37.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/selective_kernel.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/selective_kernel.cpython-38.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/separable_conv.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/separable_conv.cpython-37.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/separable_conv.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/separable_conv.cpython-38.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/space_to_depth.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/space_to_depth.cpython-37.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/space_to_depth.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/space_to_depth.cpython-38.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/split_attn.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/split_attn.cpython-37.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/split_attn.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/split_attn.cpython-38.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/split_batchnorm.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/split_batchnorm.cpython-37.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/split_batchnorm.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/split_batchnorm.cpython-38.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/squeeze_excite.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/squeeze_excite.cpython-37.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/squeeze_excite.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/squeeze_excite.cpython-38.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/std_conv.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/std_conv.cpython-37.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/std_conv.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/std_conv.cpython-38.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/swin_attn.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/swin_attn.cpython-37.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/swin_attn.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/swin_attn.cpython-38.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/test_time_pool.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/test_time_pool.cpython-37.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/test_time_pool.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/test_time_pool.cpython-38.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/weight_init.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/weight_init.cpython-37.pyc
--------------------------------------------------------------------------------
/timm/models/layers/__pycache__/weight_init.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/timm/models/layers/__pycache__/weight_init.cpython-38.pyc
--------------------------------------------------------------------------------
/timm/models/registry.py:
--------------------------------------------------------------------------------
1 | """ Model Registry
2 | Hacked together by / Copyright 2020 Ross Wightman
3 | """
4 |
5 | import sys
6 | import re
7 | import fnmatch
8 | from collections import defaultdict
9 | from copy import deepcopy
10 |
# Public query API of the registry.
__all__ = [
    'list_models',
    'is_model',
    'model_entrypoint',
    'list_modules',
    'is_model_in_modules',
    'is_model_default_key',
    'has_model_default_key',
    'get_model_default_value',
    'is_model_pretrained',
]

# Registry state; every container below is filled by `register_model`.
_module_to_models = defaultdict(set)  # module name -> set of model names registered from that module
_model_to_module = {}  # model name -> name of the module that registered it
_model_entrypoints = {}  # model name -> entrypoint function
_model_has_pretrained = set()  # model names whose default cfg carries an 'http' url
_model_default_cfgs = {}  # model name -> deep copy of the module's default cfg for that model
19 |
20 |
def register_model(fn):
    """Decorator that records a model entrypoint function in the registry.

    The function's ``__name__`` is used as the model name and the last
    component of its ``__module__`` as the module name. The model name is
    also appended to the defining module's ``__all__`` (created if absent).
    If the defining module has a ``default_cfgs`` mapping with an entry under
    the same name, a deep copy of that entry is stored, and the model is
    flagged as pretrained when the entry's ``'url'`` contains ``'http'``.

    Args:
        fn: model entrypoint function to register.

    Returns:
        ``fn`` itself, unchanged, so this can be used as a decorator.
    """
    owner = sys.modules[fn.__module__]
    dotted_parts = fn.__module__.split('.')
    module_name = dotted_parts[-1] if len(dotted_parts) else ''
    model_name = fn.__name__

    # Export the entrypoint from its defining module.
    if not hasattr(owner, '__all__'):
        owner.__all__ = [model_name]
    else:
        owner.__all__.append(model_name)

    # Record the entrypoint in the lookup tables.
    _model_entrypoints[model_name] = fn
    _model_to_module[model_name] = module_name
    _module_to_models[module_name].add(model_name)

    # Only entrypoints whose name equals a default_cfgs key are picked up here;
    # aliased entrypoints or non-matching names are missed.
    if hasattr(owner, 'default_cfgs') and model_name in owner.default_cfgs:
        has_weights = (
            'url' in owner.default_cfgs[model_name]
            and 'http' in owner.default_cfgs[model_name]['url']
        )
        _model_default_cfgs[model_name] = deepcopy(owner.default_cfgs[model_name])
        if has_weights:
            _model_has_pretrained.add(model_name)
    return fn
47 |
48 |
def _natural_key(string_):
    """Build a sort key that orders embedded digit runs numerically.

    The lower-cased string is split on runs of digits; digit runs become
    ``int`` and everything else stays ``str``.
    """
    key = []
    for token in re.split(r'(\d+)', string_.lower()):
        key.append(int(token) if token.isdigit() else token)
    return key
51 |
52 |
def list_models(filter='', module='', pretrained=False, exclude_filters='', name_matches_cfg=False):
    """ Return list of available model names, sorted in natural (digit-aware) order

    Args:
        filter (str or list[str]) - Wildcard filter string(s) that work with fnmatch
        module (str) - Limit model selection to a specific sub-module (ie 'gen_efficientnet')
        pretrained (bool) - Include only models with pretrained weights if True
        exclude_filters (str or list[str]) - Wildcard filters to exclude models after including them with filter
        name_matches_cfg (bool) - Include only models w/ model_name matching default_cfg name (excludes some aliases)

    Returns:
        list[str] of matching model names.

    Example:
        list_models('gluon_resnet*') -- returns all models starting with 'gluon_resnet'
        list_models('*resnext*', 'resnet') -- returns all models with 'resnext' in 'resnet' module
    """
    if module:
        # Use .get() instead of indexing: _module_to_models is a defaultdict, so indexing an
        # unknown module name would insert an empty entry that later shows up in list_modules().
        all_models = list(_module_to_models.get(module, ()))
    else:
        all_models = _model_entrypoints.keys()
    if filter:
        include_filters = filter if isinstance(filter, (tuple, list)) else [filter]
        models = set()
        for f in include_filters:
            models.update(fnmatch.filter(all_models, f))  # include these models
    else:
        models = all_models
    if exclude_filters:
        if not isinstance(exclude_filters, (tuple, list)):
            exclude_filters = [exclude_filters]
        for xf in exclude_filters:
            exclude_models = fnmatch.filter(models, xf)  # exclude these models
            if exclude_models:
                models = set(models).difference(exclude_models)
    if pretrained:
        models = _model_has_pretrained.intersection(models)
    if name_matches_cfg:
        models = set(_model_default_cfgs).intersection(models)
    return sorted(models, key=_natural_key)
92 |
93 |
def is_model(model_name):
    """Tell whether an entrypoint is registered under `model_name`."""
    registered = _model_entrypoints
    return model_name in registered
98 |
99 |
def model_entrypoint(model_name):
    """Return the entrypoint function registered under `model_name`.

    Raises:
        KeyError: if no entrypoint is registered under that name.
    """
    entrypoint_fn = _model_entrypoints[model_name]
    return entrypoint_fn
104 |
105 |
def list_modules():
    """Return the sorted names of all modules that have registered models."""
    return sorted(_module_to_models)
111 |
112 |
def is_model_in_modules(model_name, module_names):
    """Check if a model exists within a subset of modules

    Args:
        model_name (str) - name of model to check
        module_names (tuple, list, set) - names of modules to search in

    Returns:
        True if any of the named modules registered `model_name`, else False.
    """
    assert isinstance(module_names, (tuple, list, set))
    # Use .get() instead of indexing: _module_to_models is a defaultdict, so indexing an
    # unknown module name would insert an empty entry as a side effect of a read-only query.
    return any(model_name in _module_to_models.get(n, ()) for n in module_names)
121 |
122 |
def has_model_default_key(model_name, cfg_key):
    """Report whether the stored default cfg of `model_name` contains `cfg_key`.

    Returns False when no default cfg is stored for the model.
    """
    if model_name not in _model_default_cfgs:
        return False
    return cfg_key in _model_default_cfgs[model_name]
129 |
130 |
def is_model_default_key(model_name, cfg_key):
    """Report whether `cfg_key` holds a truthy value in the model's default cfg.

    Returns False when the model has no stored default cfg, when the key is
    missing, or when its value is falsy.
    """
    if model_name not in _model_default_cfgs:
        return False
    return bool(_model_default_cfgs[model_name].get(cfg_key, False))
137 |
138 |
def get_model_default_value(model_name, cfg_key):
    """Fetch one value from the model's stored default cfg.

    Returns None when the model has no stored default cfg or the key is absent.
    """
    try:
        cfg = _model_default_cfgs[model_name]
    except KeyError:
        return None
    return cfg.get(cfg_key, None)
146 |
147 |
def is_model_pretrained(model_name):
    """Tell whether `model_name` was registered with a pretrained weight url."""
    pretrained_names = _model_has_pretrained
    return model_name in pretrained_names
150 |
--------------------------------------------------------------------------------
/timm/optim/__init__.py:
--------------------------------------------------------------------------------
1 | from .adamp import AdamP
2 | from .adamw import AdamW
3 | from .adafactor import Adafactor
4 | from .adahessian import Adahessian
5 | from .lookahead import Lookahead
6 | from .nadam import Nadam
7 | from .novograd import NovoGrad
8 | from .nvnovograd import NvNovoGrad
9 | from .radam import RAdam
10 | from .rmsprop_tf import RMSpropTF
11 | from .sgdp import SGDP
12 | from .adabelief import AdaBelief
13 | from .optim_factory import create_optimizer, create_optimizer_v2, optimizer_kwargs
14 |
--------------------------------------------------------------------------------
/timm/optim/adamp.py:
--------------------------------------------------------------------------------
1 | """
2 | AdamP Optimizer Implementation copied from https://github.com/clovaai/AdamP/blob/master/adamp/adamp.py
3 |
4 | Paper: `Slowing Down the Weight Norm Increase in Momentum-based Optimizers` - https://arxiv.org/abs/2006.08217
5 | Code: https://github.com/clovaai/AdamP
6 |
7 | Copyright (c) 2020-present NAVER Corp.
8 | MIT license
9 | """
10 |
11 | import torch
12 | import torch.nn as nn
13 | from torch.optim.optimizer import Optimizer, required
14 | import math
15 |
class AdamP(Optimizer):
    """AdamP optimizer: an Adam update with an optional projection step.

    For parameters with more than one dimension, `step` measures the cosine
    similarity between the gradient and the weights. When it is below
    ``delta / sqrt(dim)``, the component of the update parallel to the weights
    is removed and weight decay is scaled by ``wd_ratio``.

    Args:
        params: iterable of parameters to optimize or dicts defining parameter groups
        lr (float): learning rate (default: 1e-3)
        betas (Tuple[float, float]): coefficients for the running averages of the
            gradient and its square (default: (0.9, 0.999))
        eps (float): term added to denominators for numerical stability (default: 1e-8)
        weight_decay (float): decoupled weight decay coefficient (default: 0)
        delta (float): threshold used by the projection test (default: 0.1)
        wd_ratio (float): weight decay multiplier applied when the projection fires (default: 0.1)
        nesterov (bool): use the Nesterov-style update direction (default: False)
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0, delta=0.1, wd_ratio=0.1, nesterov=False):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
                        delta=delta, wd_ratio=wd_ratio, nesterov=nesterov)
        super(AdamP, self).__init__(params, defaults)

    def _channel_view(self, x):
        """View `x` as (first-dim, everything-else)."""
        return x.view(x.size(0), -1)

    def _layer_view(self, x):
        """View `x` as a single row."""
        return x.view(1, -1)

    def _cosine_similarity(self, x, y, eps, view_func):
        """Absolute row-wise cosine similarity of `x` and `y` under `view_func`."""
        x = view_func(x)
        y = view_func(y)

        x_norm = x.norm(dim=1).add_(eps)
        y_norm = y.norm(dim=1).add_(eps)
        dot = (x * y).sum(dim=1)

        return dot.abs() / x_norm / y_norm

    def _projection(self, p, grad, perturb, delta, wd_ratio, eps):
        """Project `perturb` away from the weight direction when the test fires.

        Tries the channel view first, then the layer view; the first view whose
        maximum cosine similarity is below the threshold is used.

        Returns:
            (perturb, wd): the (possibly modified in place) update and the weight
            decay multiplier (`wd_ratio` if a projection was applied, else 1).
        """
        wd = 1
        expand_size = [-1] + [1] * (len(p.shape) - 1)
        for view_func in [self._channel_view, self._layer_view]:

            cosine_sim = self._cosine_similarity(grad, p.data, eps, view_func)

            if cosine_sim.max() < delta / math.sqrt(view_func(p.data).size(1)):
                p_n = p.data / view_func(p.data).norm(dim=1).view(expand_size).add_(eps)
                perturb -= p_n * view_func(p_n * perturb).sum(dim=1).view(expand_size)
                wd = wd_ratio

                return perturb, wd

        return perturb, wd

    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure (callable, optional): re-evaluates the model and returns the loss.

        Returns:
            The closure's return value, or None when no closure is given.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad.data
                beta1, beta2 = group['betas']
                nesterov = group['nesterov']

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p.data)
                    state['exp_avg_sq'] = torch.zeros_like(p.data)

                # Adam
                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                state['step'] += 1
                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']

                # Keyword alpha/value: the positional-scalar overloads are deprecated in PyTorch.
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
                step_size = group['lr'] / bias_correction1

                if nesterov:
                    perturb = (beta1 * exp_avg + (1 - beta1) * grad) / denom
                else:
                    perturb = exp_avg / denom

                # Projection (only for parameters with more than one dimension)
                wd_ratio = 1
                if len(p.shape) > 1:
                    perturb, wd_ratio = self._projection(
                        p, grad, perturb, group['delta'], group['wd_ratio'], group['eps'])

                # Weight decay
                if group['weight_decay'] > 0:
                    p.data.mul_(1 - group['lr'] * group['weight_decay'] * wd_ratio)

                # Step
                p.data.add_(perturb, alpha=-step_size)

        return loss
108 |
--------------------------------------------------------------------------------
/timm/optim/adamw.py:
--------------------------------------------------------------------------------
1 | """ AdamW Optimizer
2 | Impl copied from PyTorch master
3 | """
4 | import math
5 | import torch
6 | from torch.optim.optimizer import Optimizer
7 |
8 |
class AdamW(Optimizer):
    r"""Implements AdamW algorithm.

    The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_.
    The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay coefficient (default: 1e-2)
        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_
            (default: False)

    Raises:
        ValueError: if `lr` or `eps` is negative, or a beta is outside [0, 1).

    .. _Adam\: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
    .. _Decoupled Weight Decay Regularization:
        https://arxiv.org/abs/1711.05101
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=1e-2, amsgrad=False):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay, amsgrad=amsgrad)
        super(AdamW, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(AdamW, self).__setstate__(state)
        # Param groups restored from older pickles may lack the 'amsgrad' option.
        for group in self.param_groups:
            group.setdefault('amsgrad', False)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The closure's return value, or None when no closure is given.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                # Perform decoupled weight decay directly on the weights
                p.data.mul_(1 - group['lr'] * group['weight_decay'])

                # Perform optimization step
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
                amsgrad = group['amsgrad']

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state['max_exp_avg_sq'] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                if amsgrad:
                    max_exp_avg_sq = state['max_exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1
                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']

                # Decay the first and second moment running average coefficient.
                # Keyword alpha/value: the positional-scalar overloads are deprecated in PyTorch.
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                if amsgrad:
                    # Maintains the maximum of all 2nd moment running avg. till now
                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    # Use the max. for normalizing running avg. of gradient
                    denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
                else:
                    denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])

                step_size = group['lr'] / bias_correction1

                p.data.addcdiv_(exp_avg, denom, value=-step_size)

        return loss
118 |
--------------------------------------------------------------------------------
/timm/optim/lookahead.py:
--------------------------------------------------------------------------------
1 | """ Lookahead Optimizer Wrapper.
2 | Implementation modified from: https://github.com/alphadl/lookahead.pytorch
3 | Paper: `Lookahead Optimizer: k steps forward, 1 step back` - https://arxiv.org/abs/1907.08610
4 |
5 | Hacked together by / Copyright 2020 Ross Wightman
6 | """
7 | import torch
8 | from torch.optim.optimizer import Optimizer
9 | from collections import defaultdict
10 |
11 |
class Lookahead(Optimizer):
    """Wrap a base optimizer with Lookahead slow-weight updates.

    Every `k` calls to `step`, each parameter's slow buffer is moved toward the
    current (fast) weights by a factor `alpha`, and the fast weights are reset
    to the slow buffer.

    Args:
        base_optimizer: the optimizer performing the fast updates; its
            `param_groups` and `defaults` are shared (and extended) by this wrapper.
        alpha (float): slow update rate in [0, 1] (default: 0.5)
        k (int): number of fast steps between slow updates, >= 1 (default: 6)

    Raises:
        ValueError: if `alpha` or `k` is out of range.
    """

    def __init__(self, base_optimizer, alpha=0.5, k=6):
        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f'Invalid slow update rate: {alpha}')
        if not 1 <= k:
            raise ValueError(f'Invalid lookahead steps: {k}')
        defaults = dict(lookahead_alpha=alpha, lookahead_k=k, lookahead_step=0)
        # NOTE(review): Optimizer.__init__ is not called; the attributes below are set by hand
        # so that param_groups is the very same list object the base optimizer uses.
        self.base_optimizer = base_optimizer
        self.param_groups = self.base_optimizer.param_groups
        self.defaults = base_optimizer.defaults
        self.defaults.update(defaults)
        self.state = defaultdict(dict)
        # manually add our defaults to the param groups
        for name, default in defaults.items():
            for group in self.param_groups:
                group.setdefault(name, default)

    def update_slow(self, group):
        """Blend slow buffers toward the fast weights of `group`, then copy them back."""
        for fast_p in group["params"]:
            if fast_p.grad is None:
                continue
            param_state = self.state[fast_p]
            if 'slow_buffer' not in param_state:
                # First update for this parameter: the slow buffer starts at the fast weights.
                param_state['slow_buffer'] = torch.empty_like(fast_p.data)
                param_state['slow_buffer'].copy_(fast_p.data)
            slow = param_state['slow_buffer']
            # Keyword alpha: the positional-scalar overload of add_ is deprecated in PyTorch.
            slow.add_(fast_p.data - slow, alpha=group['lookahead_alpha'])
            fast_p.data.copy_(slow)

    def sync_lookahead(self):
        """Run a slow update on every param group immediately."""
        for group in self.param_groups:
            self.update_slow(group)

    def step(self, closure=None):
        """Run one base-optimizer step and, every `lookahead_k` steps, a slow update.

        Returns:
            Whatever the base optimizer's `step` returns.
        """
        #assert id(self.param_groups) == id(self.base_optimizer.param_groups)
        loss = self.base_optimizer.step(closure)
        for group in self.param_groups:
            group['lookahead_step'] += 1
            if group['lookahead_step'] % group['lookahead_k'] == 0:
                self.update_slow(group)
        return loss

    def state_dict(self):
        """Return the base optimizer's state dict plus a 'slow_state' entry."""
        fast_state_dict = self.base_optimizer.state_dict()
        slow_state = {
            (id(k) if isinstance(k, torch.Tensor) else k): v
            for k, v in self.state.items()
        }
        fast_state = fast_state_dict['state']
        param_groups = fast_state_dict['param_groups']
        return {
            'state': fast_state,
            'slow_state': slow_state,
            'param_groups': param_groups,
        }

    def load_state_dict(self, state_dict):
        """Restore fast state into the base optimizer and slow state into this wrapper.

        A state dict without 'slow_state' (saved without Lookahead) is accepted;
        `state_dict` is then modified in place to hold an empty 'slow_state'.
        """
        fast_state_dict = {
            'state': state_dict['state'],
            'param_groups': state_dict['param_groups'],
        }
        self.base_optimizer.load_state_dict(fast_state_dict)

        # We want to restore the slow state, but share param_groups reference
        # with base_optimizer. This is a bit redundant but least code
        slow_state_new = False
        if 'slow_state' not in state_dict:
            print('Loading state_dict from optimizer without Lookahead applied.')
            state_dict['slow_state'] = defaultdict(dict)
            slow_state_new = True
        slow_state_dict = {
            'state': state_dict['slow_state'],
            'param_groups': state_dict['param_groups'],  # this is pointless but saves code
        }
        super(Lookahead, self).load_state_dict(slow_state_dict)
        self.param_groups = self.base_optimizer.param_groups  # make both ref same container
        if slow_state_new:
            # reapply defaults to catch missing lookahead specific ones
            for name, default in self.defaults.items():
                for group in self.param_groups:
                    group.setdefault(name, default)
93 |
--------------------------------------------------------------------------------
/timm/optim/nadam.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.optim import Optimizer
3 |
4 |
class Nadam(Optimizer):
    """Implements Nadam algorithm (a variant of Adam based on Nesterov momentum).

    It has been proposed in `Incorporating Nesterov Momentum into Adam`__.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 2e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        schedule_decay (float, optional): momentum schedule decay (default: 4e-3)

    __ http://cs229.stanford.edu/proj2015/054_report.pdf
    __ http://www.cs.toronto.edu/~fritz/absps/momentum.pdf

        Originally taken from: https://github.com/pytorch/pytorch/pull/1408
        NOTE: Has potential issues but does work well on some problems.
    """

    def __init__(self, params, lr=2e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0, schedule_decay=4e-3):
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay, schedule_decay=schedule_decay)
        super(Nadam, self).__init__(params, defaults)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The closure's return value, or None when no closure is given.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    state['m_schedule'] = 1.
                    state['exp_avg'] = torch.zeros_like(grad)
                    state['exp_avg_sq'] = torch.zeros_like(grad)

                # Warming momentum schedule
                m_schedule = state['m_schedule']
                schedule_decay = group['schedule_decay']
                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']
                eps = group['eps']
                state['step'] += 1
                t = state['step']

                if group['weight_decay'] != 0:
                    # L2 penalty folded into the gradient (out of place, p.grad is untouched).
                    grad = grad.add(p.data, alpha=group['weight_decay'])

                momentum_cache_t = beta1 * \
                    (1. - 0.5 * (0.96 ** (t * schedule_decay)))
                momentum_cache_t_1 = beta1 * \
                    (1. - 0.5 * (0.96 ** ((t + 1) * schedule_decay)))
                m_schedule_new = m_schedule * momentum_cache_t
                m_schedule_next = m_schedule * momentum_cache_t * momentum_cache_t_1
                state['m_schedule'] = m_schedule_new

                # Decay the first and second moment running average coefficient.
                # Keyword alpha/value: the positional-scalar overloads are deprecated in PyTorch.
                exp_avg.mul_(beta1).add_(grad, alpha=1. - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1. - beta2)
                exp_avg_sq_prime = exp_avg_sq / (1. - beta2 ** t)
                denom = exp_avg_sq_prime.sqrt_().add_(eps)

                # Two-part update: current-gradient term, then momentum term.
                p.data.addcdiv_(grad, denom, value=-group['lr'] * (1. - momentum_cache_t) / (1. - m_schedule_new))
                p.data.addcdiv_(exp_avg, denom, value=-group['lr'] * momentum_cache_t_1 / (1. - m_schedule_next))

        return loss
89 |
--------------------------------------------------------------------------------
/timm/optim/novograd.py:
--------------------------------------------------------------------------------
1 | """NovoGrad Optimizer.
2 | Original impl by Masashi Kimura (Convergence Lab): https://github.com/convergence-lab/novograd
3 | Paper: `Stochastic Gradient Methods with Layer-wise Adaptive Moments for Training of Deep Networks`
4 | - https://arxiv.org/abs/1905.11286
5 | """
6 |
7 | import torch
8 | from torch.optim.optimizer import Optimizer
9 | import math
10 |
11 |
class NovoGrad(Optimizer):
    """NovoGrad optimizer using the squared gradient norm per parameter tensor.

    Args:
        params: iterable of parameters to optimize or dicts defining parameter groups
        grad_averaging (bool): scale the normalized gradient by ``1 - beta1`` (default: False)
        lr (float): learning rate (default: 0.1)
        betas (Tuple[float, float]): coefficients for the first moment and for the
            running averages of the squared gradient norm (default: (0.95, 0.98))
        eps (float): term added to denominators for numerical stability (default: 1e-8)
        weight_decay (float): weight decay coefficient added into the moment (default: 0)

    Per-group values of `lr`, `betas`, `eps` and `weight_decay` are honoured;
    when a group does not override them they equal the constructor arguments.
    """

    def __init__(self, params, grad_averaging=False, lr=0.1, betas=(0.95, 0.98), eps=1e-8, weight_decay=0):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super(NovoGrad, self).__init__(params, defaults)
        # Kept for backward compatibility with code that reads these attributes.
        self._lr = lr
        self._beta1 = betas[0]
        self._beta2 = betas[1]
        self._eps = eps
        self._wd = weight_decay
        self._grad_averaging = grad_averaging

        self._momentum_initialized = False

    def step(self, closure=None):
        """Perform a single optimization step.

        NOTE: ``p.grad`` is normalized in place as part of the update.

        Args:
            closure (callable, optional): re-evaluates the model and returns the loss.

        Returns:
            The closure's return value, or None when no closure is given.

        Raises:
            RuntimeError: if a gradient is sparse.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            beta1, beta2 = group['betas']
            eps = group['eps']
            weight_decay = group['weight_decay']
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('NovoGrad does not support sparse gradients')

                state = self.state[p]
                # Lazy per-parameter initialization: a parameter that first receives a
                # gradient after the first step still gets its state, and state restored
                # through load_state_dict is not overwritten.
                if len(state) == 0:
                    v = torch.norm(grad)**2
                    state['step'] = 0
                    state['v'] = v
                    state['m'] = grad / (torch.sqrt(v) + eps) + weight_decay * p.data
                    state['grad_ema'] = None

                state['step'] += 1
                step, v, m = state['step'], state['v'], state['m']
                grad_ema = state['grad_ema']

                # Normalize the gradient by the running average of its squared norm.
                g2 = torch.norm(grad)**2
                grad_ema = g2 if grad_ema is None else grad_ema * beta2 + g2 * (1. - beta2)
                grad *= 1.0 / (torch.sqrt(grad_ema) + eps)

                if self._grad_averaging:
                    grad *= (1. - beta1)

                g2 = torch.norm(grad)**2
                v = beta2 * v + (1. - beta2) * g2
                m = beta1 * m + (grad / (torch.sqrt(v) + eps) + weight_decay * p.data)
                bias_correction1 = 1 - beta1 ** step
                bias_correction2 = 1 - beta2 ** step
                step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1

                state['v'], state['m'] = v, m
                state['grad_ema'] = grad_ema
                # Keyword alpha: the positional-scalar overload of add_ is deprecated in PyTorch.
                p.data.add_(m, alpha=-step_size)

        self._momentum_initialized = True
        return loss
78 |
--------------------------------------------------------------------------------
/timm/optim/nvnovograd.py:
--------------------------------------------------------------------------------
1 | """ Nvidia NovoGrad Optimizer.
2 | Original impl by Nvidia from Jasper example:
3 | - https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechRecognition/Jasper
4 | Paper: `Stochastic Gradient Methods with Layer-wise Adaptive Moments for Training of Deep Networks`
5 | - https://arxiv.org/abs/1905.11286
6 | """
7 |
8 | import torch
9 | from torch.optim.optimizer import Optimizer
10 | import math
11 |
12 |
class NvNovoGrad(Optimizer):
    """
    Implements Novograd algorithm.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.95, 0.98))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        grad_averaging: gradient averaging
        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_
            (default: False)

    Raises:
        ValueError: if `lr` or `eps` is negative, or a beta is outside [0, 1).
    """

    def __init__(self, params, lr=1e-3, betas=(0.95, 0.98), eps=1e-8,
                 weight_decay=0, grad_averaging=False, amsgrad=False):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay,
                        grad_averaging=grad_averaging,
                        amsgrad=amsgrad)

        super(NvNovoGrad, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(NvNovoGrad, self).__setstate__(state)
        # Param groups restored from older pickles may lack the 'amsgrad' option.
        for group in self.param_groups:
            group.setdefault('amsgrad', False)

    def step(self, closure=None):
        """Performs a single optimization step.

        NOTE: ``p.grad`` is normalized (and weight decay is added to it) in place.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The closure's return value, or None when no closure is given.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('Sparse gradients are not supported.')
                amsgrad = group['amsgrad']

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of the squared gradient norm (a scalar tensor)
                    state['exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device)
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state['max_exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                if amsgrad:
                    max_exp_avg_sq = state['max_exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1

                norm = torch.sum(torch.pow(grad, 2))

                if exp_avg_sq == 0:
                    # While the running average is still zero, seed it with the current norm.
                    exp_avg_sq.copy_(norm)
                else:
                    # Keyword alpha: the positional-scalar overload of add_ is deprecated in PyTorch.
                    exp_avg_sq.mul_(beta2).add_(norm, alpha=1 - beta2)

                if amsgrad:
                    # Maintains the maximum of all 2nd moment running avg. till now
                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    # Use the max. for normalizing running avg. of gradient
                    denom = max_exp_avg_sq.sqrt().add_(group['eps'])
                else:
                    denom = exp_avg_sq.sqrt().add_(group['eps'])

                grad.div_(denom)
                if group['weight_decay'] != 0:
                    grad.add_(p.data, alpha=group['weight_decay'])
                if group['grad_averaging']:
                    grad.mul_(1 - beta1)
                exp_avg.mul_(beta1).add_(grad)

                p.data.add_(exp_avg, alpha=-group['lr'])

        return loss
119 |
--------------------------------------------------------------------------------
/timm/optim/radam.py:
--------------------------------------------------------------------------------
1 | """RAdam Optimizer.
2 | Implementation lifted from: https://github.com/LiyuanLucasLiu/RAdam
3 | Paper: `On the Variance of the Adaptive Learning Rate and Beyond` - https://arxiv.org/abs/1908.03265
4 | """
5 | import math
6 | import torch
7 | from torch.optim.optimizer import Optimizer, required
8 |
9 |
class RAdam(Optimizer):
    """RAdam optimizer (Adam with a rectified adaptive learning rate).

    While the approximated SMA length ``N_sma`` is below 5 the update is a plain
    bias-corrected momentum step; from then on the adaptive (Adam-style)
    denominator is used together with a rectification factor.

    Args:
        params: iterable of parameters to optimize or dicts defining parameter groups
        lr (float): learning rate (default: 1e-3)
        betas (Tuple[float, float]): coefficients for the running averages of the
            gradient and its square (default: (0.9, 0.999))
        eps (float): term added to the denominator for numerical stability (default: 1e-8)
        weight_decay (float): weight decay coefficient (default: 0)
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        # Small ring cache of [key, N_sma, rectification] indexed by step % 10, shared by
        # all parameters so the per-step scalars are computed only once.
        self.buffer = [[None, None, None] for _ in range(10)]
        super(RAdam, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(RAdam, self).__setstate__(state)

    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure (callable, optional): re-evaluates the model and returns the loss.

        Returns:
            The closure's return value, or None when no closure is given.

        Raises:
            RuntimeError: if a gradient is sparse.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data.float()
                if grad.is_sparse:
                    raise RuntimeError('RAdam does not support sparse gradients')

                # Work on an fp32 copy of the weights; the result is copied back at the end.
                p_data_fp32 = p.data.float()

                state = self.state[p]

                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
                else:
                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                # Keyword alpha/value: the positional-scalar overloads are deprecated in PyTorch.
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

                state['step'] += 1
                step = state['step']

                # The cache holds only quantities that depend on (step, beta2). The learning
                # rate and beta1 are applied below per group, so param groups with different
                # lr values no longer reuse the first group's cached step size.
                cache_key = (step, beta2)
                buffered = self.buffer[int(step % 10)]
                if buffered[0] == cache_key:
                    N_sma, rect = buffered[1], buffered[2]
                else:
                    buffered[0] = cache_key
                    beta2_t = beta2 ** step
                    N_sma_max = 2 / (1 - beta2) - 1
                    N_sma = N_sma_max - 2 * step * beta2_t / (1 - beta2_t)
                    buffered[1] = N_sma

                    # more conservative since it's an approximated value
                    if N_sma >= 5:
                        rect = math.sqrt(
                            (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (
                                N_sma_max - 2))
                    else:
                        rect = 1.0
                    buffered[2] = rect

                step_size = group['lr'] * rect / (1 - beta1 ** step)

                if group['weight_decay'] != 0:
                    p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])

                # more conservative since it's an approximated value
                if N_sma >= 5:
                    denom = exp_avg_sq.sqrt().add_(group['eps'])
                    p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size)
                else:
                    p_data_fp32.add_(exp_avg, alpha=-step_size)

                p.data.copy_(p_data_fp32)

        return loss
86 |
87 |
class PlainRAdam(Optimizer):
    """Rectified Adam (RAdam) optimizer without any step-size caching.

    Same update rule as ``RAdam``: the variance-rectified adaptive step is used once the
    approximated SMA length ``N_sma`` reaches 5, otherwise a bias-corrected momentum
    step is taken. The rectification term is recomputed for every parameter.

    Arguments:
        params (iterable): parameters to optimize or dicts defining parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): decay coefficients of the running
            averages of the gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator (default: 1e-8)
        weight_decay (float, optional): weight decay coefficient; the parameters are
            shrunk by ``weight_decay * lr`` on every step (default: 0)
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)

        super(PlainRAdam, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(PlainRAdam, self).__setstate__(state)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure is given.

        Raises:
            RuntimeError: if a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            beta1, beta2 = group['betas']

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data.float()
                if grad.is_sparse:
                    raise RuntimeError('RAdam does not support sparse gradients')

                # All math is done on a float32 view/copy and written back at the end.
                p_data_fp32 = p.data.float()

                state = self.state[p]

                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
                else:
                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                # Keyword forms replace the deprecated ``add_(scalar, tensor)`` overloads.
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

                state['step'] += 1
                step = state['step']
                beta2_t = beta2 ** step
                N_sma_max = 2 / (1 - beta2) - 1
                N_sma = N_sma_max - 2 * step * beta2_t / (1 - beta2_t)

                if group['weight_decay'] != 0:
                    p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * lr)

                # more conservative since it's an approximated value
                if N_sma >= 5:
                    step_size = lr * math.sqrt(
                        (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (
                            N_sma_max - 2)) / (1 - beta1 ** step)
                    denom = exp_avg_sq.sqrt().add_(group['eps'])
                    p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size)
                else:
                    step_size = lr / (1 - beta1 ** step)
                    p_data_fp32.add_(exp_avg, alpha=-step_size)

                p.data.copy_(p_data_fp32)

        return loss
153 |
--------------------------------------------------------------------------------
/timm/optim/rmsprop_tf.py:
--------------------------------------------------------------------------------
1 | """ RMSProp modified to behave like Tensorflow impl
2 |
3 | Originally cut & paste from PyTorch RMSProp
4 | https://github.com/pytorch/pytorch/blob/063946d2b3f3f1e953a2a3b54e0b34f1393de295/torch/optim/rmsprop.py
5 | Licensed under BSD-Clause 3 (ish), https://github.com/pytorch/pytorch/blob/master/LICENSE
6 |
7 | Modifications Copyright 2020 Ross Wightman
8 | """
9 |
10 | import torch
11 | from torch.optim import Optimizer
12 |
13 |
class RMSpropTF(Optimizer):
    """Implements RMSprop algorithm (TensorFlow style epsilon)

    NOTE: This is a direct cut-and-paste of PyTorch RMSprop with eps applied before sqrt
    and a few other modifications to closer match Tensorflow for matching hyper-params.

    Noteworthy changes include:
    1. Epsilon applied inside square-root
    2. square_avg initialized to ones
    3. LR scaling of update accumulated in momentum buffer

    Proposed by G. Hinton in his
    `course <http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf>`_.

    The centered version first appears in `Generating Sequences
    With Recurrent Neural Networks <https://arxiv.org/pdf/1308.0850v5.pdf>`_.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-2)
        momentum (float, optional): momentum factor (default: 0)
        alpha (float, optional): smoothing (decay) constant (default: 0.9)
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-10)
        centered (bool, optional) : if ``True``, compute the centered RMSProp,
            the gradient is normalized by an estimation of its variance
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        decoupled_decay (bool, optional): decoupled weight decay as per https://arxiv.org/abs/1711.05101
        lr_in_momentum (bool, optional): learning rate scaling is included in the momentum buffer
            update as per defaults in Tensorflow

    """

    def __init__(self, params, lr=1e-2, alpha=0.9, eps=1e-10, weight_decay=0, momentum=0., centered=False,
                 decoupled_decay=False, lr_in_momentum=True):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= momentum:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        if not 0.0 <= alpha:
            raise ValueError("Invalid alpha value: {}".format(alpha))

        defaults = dict(lr=lr, momentum=momentum, alpha=alpha, eps=eps, centered=centered, weight_decay=weight_decay,
                        decoupled_decay=decoupled_decay, lr_in_momentum=lr_in_momentum)
        super(RMSpropTF, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(RMSpropTF, self).__setstate__(state)
        # Fill in options that may be missing from previously saved param groups.
        for group in self.param_groups:
            group.setdefault('momentum', 0)
            group.setdefault('centered', False)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure is given.

        Raises:
            RuntimeError: if a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('RMSprop does not support sparse gradients')
                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    state['square_avg'] = torch.ones_like(p.data)  # PyTorch inits to zero
                    if group['momentum'] > 0:
                        state['momentum_buffer'] = torch.zeros_like(p.data)
                    if group['centered']:
                        state['grad_avg'] = torch.zeros_like(p.data)

                square_avg = state['square_avg']
                one_minus_alpha = 1. - group['alpha']

                state['step'] += 1

                if group['weight_decay'] != 0:
                    if 'decoupled_decay' in group and group['decoupled_decay']:
                        p.data.add_(p.data, alpha=-group['weight_decay'])
                    else:
                        # Out-of-place so the stored p.grad is left untouched.
                        grad = grad.add(p.data, alpha=group['weight_decay'])

                # Tensorflow order of ops for updating squared avg
                square_avg.add_(grad.pow(2) - square_avg, alpha=one_minus_alpha)
                # square_avg.mul_(alpha).addcmul_(grad, grad, value=1 - alpha)  # PyTorch original

                if group['centered']:
                    grad_avg = state['grad_avg']
                    grad_avg.add_(grad - grad_avg, alpha=one_minus_alpha)
                    # grad_avg.mul_(alpha).add_(grad, alpha=1 - alpha)  # PyTorch original
                    avg = square_avg.addcmul(grad_avg, grad_avg, value=-1).add(group['eps']).sqrt_()  # eps moved in sqrt
                else:
                    avg = square_avg.add(group['eps']).sqrt_()  # eps moved in sqrt

                if group['momentum'] > 0:
                    buf = state['momentum_buffer']
                    # Tensorflow accumulates the LR scaling in the momentum buffer
                    if 'lr_in_momentum' in group and group['lr_in_momentum']:
                        buf.mul_(group['momentum']).addcdiv_(grad, avg, value=group['lr'])
                        p.data.add_(-buf)
                    else:
                        # PyTorch scales the param update by LR
                        buf.mul_(group['momentum']).addcdiv_(grad, avg)
                        p.data.add_(buf, alpha=-group['lr'])
                else:
                    p.data.addcdiv_(grad, avg, value=-group['lr'])

        return loss
137 |
--------------------------------------------------------------------------------
/timm/optim/sgdp.py:
--------------------------------------------------------------------------------
1 | """
2 | SGDP Optimizer Implementation copied from https://github.com/clovaai/AdamP/blob/master/adamp/sgdp.py
3 |
4 | Paper: `Slowing Down the Weight Norm Increase in Momentum-based Optimizers` - https://arxiv.org/abs/2006.08217
5 | Code: https://github.com/clovaai/AdamP
6 |
7 | Copyright (c) 2020-present NAVER Corp.
8 | MIT license
9 | """
10 |
11 | import torch
12 | import torch.nn as nn
13 | from torch.optim.optimizer import Optimizer, required
14 | import math
15 |
class SGDP(Optimizer):
    """SGD with momentum plus the projection step of the AdamP/SGDP paper.

    For parameters with more than one dimension, the update is projected away from the
    weight direction whenever gradient and weight are nearly orthogonal (cosine
    similarity below ``delta / sqrt(dim)``), and weight decay is scaled by ``wd_ratio``.

    Arguments:
        params (iterable): parameters to optimize or dicts defining parameter groups
        lr (float): learning rate
        momentum (float, optional): momentum factor (default: 0)
        dampening (float, optional): dampening for momentum (default: 0)
        weight_decay (float, optional): weight decay coefficient (default: 0)
        nesterov (bool, optional): use the Nesterov form of the update (default: False)
        eps (float, optional): term added to norms to avoid division by zero (default: 1e-8)
        delta (float, optional): threshold used in the orthogonality test (default: 0.1)
        wd_ratio (float, optional): weight decay multiplier applied when the projection
            is triggered (default: 0.1)
    """

    def __init__(self, params, lr=required, momentum=0, dampening=0,
                 weight_decay=0, nesterov=False, eps=1e-8, delta=0.1, wd_ratio=0.1):
        defaults = dict(lr=lr, momentum=momentum, dampening=dampening, weight_decay=weight_decay,
                        nesterov=nesterov, eps=eps, delta=delta, wd_ratio=wd_ratio)
        super(SGDP, self).__init__(params, defaults)

    def _channel_view(self, x):
        """Flatten ``x`` to 2-D, keeping the first dimension."""
        return x.view(x.size(0), -1)

    def _layer_view(self, x):
        """Flatten ``x`` into a single row."""
        return x.view(1, -1)

    def _cosine_similarity(self, x, y, eps, view_func):
        """Return the row-wise absolute cosine similarity of ``x`` and ``y`` under ``view_func``."""
        x = view_func(x)
        y = view_func(y)

        x_norm = x.norm(dim=1).add_(eps)
        y_norm = y.norm(dim=1).add_(eps)
        dot = (x * y).sum(dim=1)

        return dot.abs() / x_norm / y_norm

    def _projection(self, p, grad, perturb, delta, wd_ratio, eps):
        """Project ``perturb`` away from the weight direction when grad and weight are near-orthogonal.

        The channel-wise view is tried first, then the layer-wise view; the first view
        that passes the test is used. ``perturb`` is modified in place when a
        projection is applied.

        Returns:
            tuple: ``(perturb, wd)`` where ``wd`` is ``wd_ratio`` if a projection was
            applied and ``1`` otherwise.
        """
        wd = 1
        expand_size = [-1] + [1] * (len(p.shape) - 1)
        for view_func in [self._channel_view, self._layer_view]:

            cosine_sim = self._cosine_similarity(grad, p.data, eps, view_func)

            if cosine_sim.max() < delta / math.sqrt(view_func(p.data).size(1)):
                p_n = p.data / view_func(p.data).norm(dim=1).view(expand_size).add_(eps)
                perturb -= p_n * view_func(p_n * perturb).sum(dim=1).view(expand_size)
                wd = wd_ratio

                return perturb, wd

        return perturb, wd

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure is given.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            dampening = group['dampening']
            nesterov = group['nesterov']

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['momentum'] = torch.zeros_like(p.data)

                # SGD
                buf = state['momentum']
                # Keyword form replaces the deprecated ``add_(scalar, tensor)`` overload.
                buf.mul_(momentum).add_(grad, alpha=1 - dampening)
                if nesterov:
                    d_p = grad + momentum * buf
                else:
                    d_p = buf

                # Projection
                wd_ratio = 1
                if len(p.shape) > 1:
                    d_p, wd_ratio = self._projection(p, grad, d_p, group['delta'], group['wd_ratio'], group['eps'])

                # Weight decay
                if weight_decay != 0:
                    p.data.mul_(1 - group['lr'] * group['weight_decay'] * wd_ratio / (1-momentum))

                # Step
                p.data.add_(d_p, alpha=-group['lr'])

        return loss
97 |
--------------------------------------------------------------------------------
/timm/scheduler/__init__.py:
--------------------------------------------------------------------------------
1 | from .cosine_lr import CosineLRScheduler
2 | from .plateau_lr import PlateauLRScheduler
3 | from .step_lr import StepLRScheduler
4 | from .tanh_lr import TanhLRScheduler
5 | from .scheduler_factory import create_scheduler
6 |
--------------------------------------------------------------------------------
/timm/scheduler/cosine_lr.py:
--------------------------------------------------------------------------------
1 | """ Cosine Scheduler
2 |
3 | Cosine LR schedule with warmup, cycle/restarts, noise.
4 |
5 | Hacked together by / Copyright 2020 Ross Wightman
6 | """
7 | import logging
8 | import math
9 | import numpy as np
10 | import torch
11 |
12 | from .scheduler import Scheduler
13 |
14 |
15 | _logger = logging.getLogger(__name__)
16 |
17 |
class CosineLRScheduler(Scheduler):
    """
    Cosine-annealed learning rate with optional linear warmup, restarts and noise.

    The schedule follows https://arxiv.org/abs/1608.03983. Each cycle ``i`` lasts
    ``t_mul ** i * t_initial`` steps and its peak/minimum values are scaled by
    ``decay_rate ** i``. Once ``cycle_limit`` cycles have elapsed (when non-zero),
    every group is held at ``lr_min``.

    Inspiration from
    https://github.com/allenai/allennlp/blob/master/allennlp/training/learning_rate_schedulers/cosine.py
    """

    def __init__(self,
                 optimizer: torch.optim.Optimizer,
                 t_initial: int,
                 t_mul: float = 1.,
                 lr_min: float = 0.,
                 decay_rate: float = 1.,
                 warmup_t=0,
                 warmup_lr_init=0,
                 warmup_prefix=False,
                 cycle_limit=0,
                 t_in_epochs=True,
                 noise_range_t=None,
                 noise_pct=0.67,
                 noise_std=1.0,
                 noise_seed=42,
                 initialize=True) -> None:
        super().__init__(
            optimizer,
            param_group_field="lr",
            noise_range_t=noise_range_t,
            noise_pct=noise_pct,
            noise_std=noise_std,
            noise_seed=noise_seed,
            initialize=initialize,
        )

        assert t_initial > 0
        assert lr_min >= 0
        schedule_is_flat = t_initial == 1 and t_mul == 1 and decay_rate == 1
        if schedule_is_flat:
            _logger.warning("Cosine annealing scheduler will have no effect on the learning "
                            "rate since t_initial = t_mul = eta_mul = 1.")

        self.t_initial = t_initial
        self.t_mul = t_mul
        self.lr_min = lr_min
        self.decay_rate = decay_rate
        self.cycle_limit = cycle_limit
        self.warmup_t = warmup_t
        self.warmup_lr_init = warmup_lr_init
        self.warmup_prefix = warmup_prefix
        self.t_in_epochs = t_in_epochs

        if not self.warmup_t:
            self.warmup_steps = [1 for _ in self.base_values]
        else:
            # Per-group increment that reaches the base value after warmup_t steps.
            self.warmup_steps = [(base - warmup_lr_init) / self.warmup_t for base in self.base_values]
            super().update_groups(self.warmup_lr_init)

    def _get_lr(self, t):
        """Return the list of per-group values for time index ``t``."""
        if t < self.warmup_t:
            return [self.warmup_lr_init + t * inc for inc in self.warmup_steps]

        if self.warmup_prefix:
            # Warmup is treated as a prefix, so the cosine clock starts after it.
            t = t - self.warmup_t

        if self.t_mul == 1:
            i = t // self.t_initial
            t_i = self.t_initial
            t_curr = t - (self.t_initial * i)
        else:
            i = math.floor(math.log(1 - t / self.t_initial * (1 - self.t_mul), self.t_mul))
            t_i = self.t_mul ** i * self.t_initial
            t_curr = t - (1 - self.t_mul ** i) / (1 - self.t_mul) * self.t_initial

        gamma = self.decay_rate ** i
        lr_min = self.lr_min * gamma
        lr_max_values = [base * gamma for base in self.base_values]

        within_limit = self.cycle_limit == 0 or (self.cycle_limit > 0 and i < self.cycle_limit)
        if not within_limit:
            return [self.lr_min for _ in self.base_values]

        return [
            lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * t_curr / t_i))
            for lr_max in lr_max_values
        ]

    def get_epoch_values(self, epoch: int):
        """Return the values for ``epoch``, or ``None`` when scheduling per update."""
        return self._get_lr(epoch) if self.t_in_epochs else None

    def get_update_values(self, num_updates: int):
        """Return the values for ``num_updates``, or ``None`` when scheduling per epoch."""
        return None if self.t_in_epochs else self._get_lr(num_updates)

    def get_cycle_length(self, cycles=0):
        """Return the total length of ``cycles`` cycles (``cycle_limit`` if falsy, at least one)."""
        n_cycles = cycles if cycles else self.cycle_limit
        n_cycles = max(1, n_cycles)
        if self.t_mul == 1.0:
            return self.t_initial * n_cycles
        return int(math.floor(-self.t_initial * (self.t_mul ** n_cycles - 1) / (1 - self.t_mul)))
117 |
--------------------------------------------------------------------------------
/timm/scheduler/plateau_lr.py:
--------------------------------------------------------------------------------
1 | """ Plateau Scheduler
2 |
3 | Adapts PyTorch plateau scheduler and allows application of noise, warmup.
4 |
5 | Hacked together by / Copyright 2020 Ross Wightman
6 | """
7 | import torch
8 |
9 | from .scheduler import Scheduler
10 |
11 |
class PlateauLRScheduler(Scheduler):
    """Decay the LR by a factor every time the validation loss plateaus.

    Wraps ``torch.optim.lr_scheduler.ReduceLROnPlateau`` and adds linear warmup and
    optional LR noise on top of it.
    """

    def __init__(self,
                 optimizer,
                 decay_rate=0.1,
                 patience_t=10,
                 verbose=True,
                 threshold=1e-4,
                 cooldown_t=0,
                 warmup_t=0,
                 warmup_lr_init=0,
                 lr_min=0,
                 mode='max',
                 noise_range_t=None,
                 noise_type='normal',
                 noise_pct=0.67,
                 noise_std=1.0,
                 noise_seed=None,
                 initialize=True,
                 ):
        super().__init__(optimizer, 'lr', initialize=initialize)

        plateau_kwargs = dict(
            patience=patience_t,
            factor=decay_rate,
            verbose=verbose,
            threshold=threshold,
            cooldown=cooldown_t,
            mode=mode,
            min_lr=lr_min,
        )
        self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, **plateau_kwargs)

        self.noise_range = noise_range_t
        self.noise_pct = noise_pct
        self.noise_type = noise_type
        self.noise_std = noise_std
        self.noise_seed = 42 if noise_seed is None else noise_seed
        self.warmup_t = warmup_t
        self.warmup_lr_init = warmup_lr_init
        if not self.warmup_t:
            self.warmup_steps = [1 for _ in self.base_values]
        else:
            # Per-group increment that reaches the base value after warmup_t steps.
            self.warmup_steps = [(base - warmup_lr_init) / self.warmup_t for base in self.base_values]
            super().update_groups(self.warmup_lr_init)
        # LRs saved before a noise perturbation; None when no perturbation is pending.
        self.restore_lr = None

    def state_dict(self):
        """Return the wrapped scheduler's ``best`` and ``last_epoch`` values."""
        wrapped = self.lr_scheduler
        return {
            'best': wrapped.best,
            'last_epoch': wrapped.last_epoch,
        }

    def load_state_dict(self, state_dict):
        """Restore ``best`` (required) and ``last_epoch`` (if present) on the wrapped scheduler."""
        self.lr_scheduler.best = state_dict['best']
        if 'last_epoch' in state_dict:
            self.lr_scheduler.last_epoch = state_dict['last_epoch']

    # override the base class step fn completely
    def step(self, epoch, metric=None):
        """Apply warmup up to and including ``warmup_t``; afterwards step the plateau scheduler."""
        if epoch <= self.warmup_t:
            warm_lrs = [self.warmup_lr_init + epoch * inc for inc in self.warmup_steps]
            super().update_groups(warm_lrs)
            return

        if self.restore_lr is not None:
            # restore actual LR from before our last noise perturbation before stepping base
            for idx, group in enumerate(self.optimizer.param_groups):
                group['lr'] = self.restore_lr[idx]
            self.restore_lr = None

        self.lr_scheduler.step(metric, epoch)  # step the base scheduler

        if self.noise_range is None:
            return
        if isinstance(self.noise_range, (list, tuple)):
            in_noise_window = self.noise_range[0] <= epoch < self.noise_range[1]
        else:
            in_noise_window = epoch >= self.noise_range
        if in_noise_window:
            self._apply_noise(epoch)

    def _apply_noise(self, epoch):
        """Perturb every group's LR by one shared random factor seeded with ``noise_seed + epoch``."""
        gen = torch.Generator()
        gen.manual_seed(self.noise_seed + epoch)
        if self.noise_type != 'normal':
            noise = 2 * (torch.rand(1, generator=gen).item() - 0.5) * self.noise_pct
        else:
            # resample if noise out of percent limit, brute force but shouldn't spin much
            noise = torch.randn(1, generator=gen).item()
            while not abs(noise) < self.noise_pct:
                noise = torch.randn(1, generator=gen).item()

        # apply the noise on top of previous LR, cache the old value so we can restore for normal
        # stepping of base scheduler
        saved = []
        for group in self.optimizer.param_groups:
            current = float(group['lr'])
            saved.append(current)
            group['lr'] = current + current * noise
        self.restore_lr = saved
114 |
--------------------------------------------------------------------------------
/timm/scheduler/scheduler.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Any
2 |
3 | import torch
4 |
5 |
class Scheduler:
    """ Parameter Scheduler Base Class
    Base class for scheduling an arbitrary field of an optimizer's parameter groups.

    Unlike the builtin PyTorch schedulers, this is intended to be consistently called
    * At the END of each epoch, before incrementing the epoch count, to calculate next epoch's value
    * At the END of each optimizer update, after incrementing the update count, to calculate next update's value

    Subclasses should stay as stateless as possible (for simplicity).

    This family of schedulers avoids the confusion around 'last_epoch' and special -1
    values: epoch and update counts are tracked by the training code and passed in
    explicitly on every step or step_update call.

    Based on ideas from:
     * https://github.com/pytorch/fairseq/tree/master/fairseq/optim/lr_scheduler
     * https://github.com/allenai/allennlp/tree/master/allennlp/training/learning_rate_schedulers
    """

    def __init__(self,
                 optimizer: torch.optim.Optimizer,
                 param_group_field: str,
                 noise_range_t=None,
                 noise_type='normal',
                 noise_pct=0.67,
                 noise_std=1.0,
                 noise_seed=None,
                 initialize: bool = True) -> None:
        self.optimizer = optimizer
        self.param_group_field = param_group_field
        self._initial_param_group_field = f"initial_{param_group_field}"
        initial_key = self._initial_param_group_field
        for idx, group in enumerate(self.optimizer.param_groups):
            if initialize:
                # Record the starting value once; an existing initial value is kept.
                if param_group_field not in group:
                    raise KeyError(f"{param_group_field} missing from param_groups[{idx}]")
                group.setdefault(initial_key, group[param_group_field])
            elif initial_key not in group:
                raise KeyError(f"{initial_key} missing from param_groups[{idx}]")
        self.base_values = [group[initial_key] for group in self.optimizer.param_groups]
        self.metric = None  # any point to having this for all?
        self.noise_range_t = noise_range_t
        self.noise_pct = noise_pct
        self.noise_type = noise_type
        self.noise_std = noise_std
        self.noise_seed = 42 if noise_seed is None else noise_seed
        self.update_groups(self.base_values)

    def state_dict(self) -> Dict[str, Any]:
        """Return every instance attribute except the optimizer."""
        snapshot = {}
        for name, value in self.__dict__.items():
            if name != 'optimizer':
                snapshot[name] = value
        return snapshot

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Overwrite instance attributes with the entries of ``state_dict``."""
        self.__dict__.update(state_dict)

    def get_epoch_values(self, epoch: int):
        """Per-epoch values; the base class schedules nothing and returns ``None``."""
        return None

    def get_update_values(self, num_updates: int):
        """Per-update values; the base class schedules nothing and returns ``None``."""
        return None

    def step(self, epoch: int, metric: float = None) -> None:
        """Store ``metric`` and apply the (noised) epoch values, if any."""
        self.metric = metric
        new_values = self.get_epoch_values(epoch)
        if new_values is None:
            return
        self.update_groups(self._add_noise(new_values, epoch))

    def step_update(self, num_updates: int, metric: float = None):
        """Store ``metric`` and apply the (noised) per-update values, if any."""
        self.metric = metric
        new_values = self.get_update_values(num_updates)
        if new_values is None:
            return
        self.update_groups(self._add_noise(new_values, num_updates))

    def update_groups(self, values):
        """Write ``values`` into the scheduled field; a scalar is broadcast to all groups."""
        groups = self.optimizer.param_groups
        if not isinstance(values, (list, tuple)):
            values = [values] * len(groups)
        for group, new_value in zip(groups, values):
            group[self.param_group_field] = new_value

    def _add_noise(self, lrs, t):
        """Return ``lrs`` scaled by one shared random factor when ``t`` is inside the noise range."""
        if self.noise_range_t is None:
            return lrs
        if isinstance(self.noise_range_t, (list, tuple)):
            in_noise_window = self.noise_range_t[0] <= t < self.noise_range_t[1]
        else:
            in_noise_window = t >= self.noise_range_t
        if not in_noise_window:
            return lrs

        # Seeding with noise_seed + t makes the perturbation reproducible per time index.
        gen = torch.Generator()
        gen.manual_seed(self.noise_seed + t)
        if self.noise_type != 'normal':
            noise = 2 * (torch.rand(1, generator=gen).item() - 0.5) * self.noise_pct
        else:
            # resample if noise out of percent limit, brute force but shouldn't spin much
            noise = torch.randn(1, generator=gen).item()
            while not abs(noise) < self.noise_pct:
                noise = torch.randn(1, generator=gen).item()
        return [v + v * noise for v in lrs]
106 |
--------------------------------------------------------------------------------
/timm/scheduler/scheduler_factory.py:
--------------------------------------------------------------------------------
1 | """ Scheduler Factory
2 | Hacked together by / Copyright 2020 Ross Wightman
3 | """
4 | from .cosine_lr import CosineLRScheduler
5 | from .tanh_lr import TanhLRScheduler
6 | from .step_lr import StepLRScheduler
7 | from .plateau_lr import PlateauLRScheduler
8 |
9 |
def create_scheduler(args, optimizer):
    """Build the LR scheduler selected by ``args.sched``.

    Recognised values are ``'cosine'``, ``'tanh'``, ``'step'`` and ``'plateau'``; for
    any other value no scheduler is created. ``args.lr_noise`` (a scalar or a
    list/tuple of epoch fractions) is converted into epoch units before being passed on.

    Returns:
        tuple: ``(lr_scheduler, num_epochs)``. For cosine/tanh schedules ``num_epochs``
        is the scheduler's cycle length plus ``args.cooldown_epochs``; otherwise it is
        ``args.epochs``.
    """
    num_epochs = args.epochs

    lr_noise = getattr(args, 'lr_noise', None)
    if lr_noise is None:
        noise_range = None
    elif isinstance(lr_noise, (list, tuple)):
        noise_range = [frac * num_epochs for frac in lr_noise]
        if len(noise_range) == 1:
            # A single-element list means "from this epoch onwards".
            noise_range = noise_range[0]
    else:
        noise_range = lr_noise * num_epochs

    # Noise settings are identical for every scheduler type.
    noise_kwargs = dict(
        noise_range_t=noise_range,
        noise_pct=getattr(args, 'lr_noise_pct', 0.67),
        noise_std=getattr(args, 'lr_noise_std', 1.),
        noise_seed=getattr(args, 'seed', 42),
    )

    sched_name = args.sched
    lr_scheduler = None
    if sched_name == 'cosine':
        lr_scheduler = CosineLRScheduler(
            optimizer,
            t_initial=num_epochs,
            t_mul=getattr(args, 'lr_cycle_mul', 1.),
            lr_min=args.min_lr,
            decay_rate=args.decay_rate,
            warmup_lr_init=args.warmup_lr,
            warmup_t=args.warmup_epochs,
            cycle_limit=getattr(args, 'lr_cycle_limit', 1),
            t_in_epochs=True,
            **noise_kwargs,
        )
        num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs
    elif sched_name == 'tanh':
        lr_scheduler = TanhLRScheduler(
            optimizer,
            t_initial=num_epochs,
            t_mul=getattr(args, 'lr_cycle_mul', 1.),
            lr_min=args.min_lr,
            warmup_lr_init=args.warmup_lr,
            warmup_t=args.warmup_epochs,
            cycle_limit=getattr(args, 'lr_cycle_limit', 1),
            t_in_epochs=True,
            **noise_kwargs,
        )
        num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs
    elif sched_name == 'step':
        lr_scheduler = StepLRScheduler(
            optimizer,
            decay_t=args.decay_epochs,
            decay_rate=args.decay_rate,
            warmup_lr_init=args.warmup_lr,
            warmup_t=args.warmup_epochs,
            **noise_kwargs,
        )
    elif sched_name == 'plateau':
        eval_metric = getattr(args, 'eval_metric', '')
        mode = 'min' if 'loss' in eval_metric else 'max'
        lr_scheduler = PlateauLRScheduler(
            optimizer,
            decay_rate=args.decay_rate,
            patience_t=args.patience_epochs,
            lr_min=args.min_lr,
            mode=mode,
            warmup_lr_init=args.warmup_lr,
            warmup_t=args.warmup_epochs,
            cooldown_t=0,
            **noise_kwargs,
        )

    return lr_scheduler, num_epochs
88 |
--------------------------------------------------------------------------------
/timm/scheduler/step_lr.py:
--------------------------------------------------------------------------------
1 | """ Step Scheduler
2 |
3 | Basic step LR schedule with warmup, noise.
4 |
5 | Hacked together by / Copyright 2020 Ross Wightman
6 | """
7 | import math
8 | import torch
9 |
10 | from .scheduler import Scheduler
11 |
12 |
class StepLRScheduler(Scheduler):
    """
    Step decay of the learning rate with optional linear warmup and noise.

    After warmup each group's value is ``base * decay_rate ** (t // decay_t)``.
    """

    def __init__(self,
                 optimizer: torch.optim.Optimizer,
                 decay_t: float,
                 decay_rate: float = 1.,
                 warmup_t=0,
                 warmup_lr_init=0,
                 t_in_epochs=True,
                 noise_range_t=None,
                 noise_pct=0.67,
                 noise_std=1.0,
                 noise_seed=42,
                 initialize=True,
                 ) -> None:
        super().__init__(
            optimizer,
            param_group_field="lr",
            noise_range_t=noise_range_t,
            noise_pct=noise_pct,
            noise_std=noise_std,
            noise_seed=noise_seed,
            initialize=initialize,
        )

        self.decay_t = decay_t
        self.decay_rate = decay_rate
        self.warmup_t = warmup_t
        self.warmup_lr_init = warmup_lr_init
        self.t_in_epochs = t_in_epochs
        if not self.warmup_t:
            self.warmup_steps = [1 for _ in self.base_values]
        else:
            # Per-group increment that reaches the base value after warmup_t steps.
            self.warmup_steps = [(base - warmup_lr_init) / self.warmup_t for base in self.base_values]
            super().update_groups(self.warmup_lr_init)

    def _get_lr(self, t):
        """Return the list of per-group values for time index ``t``."""
        if t < self.warmup_t:
            return [self.warmup_lr_init + t * inc for inc in self.warmup_steps]
        return [base * (self.decay_rate ** (t // self.decay_t)) for base in self.base_values]

    def get_epoch_values(self, epoch: int):
        """Return the values for ``epoch``, or ``None`` when scheduling per update."""
        return self._get_lr(epoch) if self.t_in_epochs else None

    def get_update_values(self, num_updates: int):
        """Return the values for ``num_updates``, or ``None`` when scheduling per epoch."""
        return None if self.t_in_epochs else self._get_lr(num_updates)
64 |
--------------------------------------------------------------------------------
/timm/scheduler/tanh_lr.py:
--------------------------------------------------------------------------------
1 | """ TanH Scheduler
2 |
3 | TanH schedule with warmup, cycle/restarts, noise.
4 |
5 | Hacked together by / Copyright 2020 Ross Wightman
6 | """
7 | import logging
8 | import math
9 | import numpy as np
10 | import torch
11 |
12 | from .scheduler import Scheduler
13 |
14 |
15 | _logger = logging.getLogger(__name__)
16 |
17 |
class TanhLRScheduler(Scheduler):
    """Hyperbolic-tangent learning-rate decay with optional warmup and restarts.

    The schedule is described in the paper https://arxiv.org/abs/1806.01593

    Args:
        optimizer: optimizer whose "lr" param-group field is scheduled.
        t_initial: length of the first cycle (must be > 0).
        lb: tanh argument at the start of a cycle (must be < `ub`).
        ub: tanh argument at the end of a cycle.
        t_mul: factor by which each successive cycle length is multiplied.
        lr_min: lower bound of the decay (scaled by `decay_rate ** cycle`).
        decay_rate: per-cycle multiplier applied to the base values and `lr_min`.
        warmup_t: number of linear warmup steps (0 disables warmup).
        warmup_lr_init: value the warmup starts from.
        warmup_prefix: if True, `warmup_t` is subtracted from `t` before the cycle math.
        cycle_limit: number of cycles after which the value stays constant; 0 means no limit.
        t_in_epochs: if True the schedule is driven by epochs, otherwise by update counts.
        noise_range_t, noise_pct, noise_std, noise_seed, initialize: forwarded to `Scheduler`.
    """

    def __init__(self,
                 optimizer: torch.optim.Optimizer,
                 t_initial: int,
                 lb: float = -6.,
                 ub: float = 4.,
                 t_mul: float = 1.,
                 lr_min: float = 0.,
                 decay_rate: float = 1.,
                 warmup_t=0,
                 warmup_lr_init=0,
                 warmup_prefix=False,
                 cycle_limit=0,
                 t_in_epochs=True,
                 noise_range_t=None,
                 noise_pct=0.67,
                 noise_std=1.0,
                 noise_seed=42,
                 initialize=True) -> None:
        super().__init__(
            optimizer,
            param_group_field="lr",
            noise_range_t=noise_range_t,
            noise_pct=noise_pct,
            noise_std=noise_std,
            noise_seed=noise_seed,
            initialize=initialize,
        )

        assert t_initial > 0
        assert lr_min >= 0
        assert lb < ub
        assert cycle_limit >= 0
        assert warmup_t >= 0
        assert warmup_lr_init >= 0

        self.lb = lb
        self.ub = ub
        self.t_initial = t_initial
        self.t_mul = t_mul
        self.lr_min = lr_min
        self.decay_rate = decay_rate
        self.cycle_limit = cycle_limit
        self.warmup_t = warmup_t
        self.warmup_lr_init = warmup_lr_init
        self.warmup_prefix = warmup_prefix
        self.t_in_epochs = t_in_epochs

        if not self.warmup_t:
            self.warmup_steps = [1 for _ in self.base_values]
        else:
            # With a warmup prefix the ramp ends at the base values; otherwise it
            # ends at whatever the schedule yields at step `warmup_t`.
            targets = self.base_values if self.warmup_prefix else self._get_lr(self.warmup_t)
            self.warmup_steps = [(target - warmup_lr_init) / self.warmup_t for target in targets]
            super().update_groups(self.warmup_lr_init)

    def _get_lr(self, t):
        """Compute the scheduled value for every base value at step `t`."""
        if t < self.warmup_t:
            # Linear warmup ramp.
            return [self.warmup_lr_init + t * step for step in self.warmup_steps]

        if self.warmup_prefix:
            t = t - self.warmup_t

        if self.t_mul != 1:
            # Geometric cycle lengths: find the cycle index and the offset within it.
            i = math.floor(math.log(1 - t / self.t_initial * (1 - self.t_mul), self.t_mul))
            t_i = self.t_mul ** i * self.t_initial
            t_curr = t - (1 - self.t_mul ** i) / (1 - self.t_mul) * self.t_initial
        else:
            i = t // self.t_initial
            t_i = self.t_initial
            t_curr = t - (self.t_initial * i)

        within_limit = self.cycle_limit == 0 or (self.cycle_limit > 0 and i < self.cycle_limit)
        if not within_limit:
            # All cycles consumed: hold the decayed minimum.
            return [self.lr_min * (self.decay_rate ** self.cycle_limit) for _ in self.base_values]

        gamma = self.decay_rate ** i
        lr_min = self.lr_min * gamma
        tr = t_curr / t_i  # fraction of the current cycle that has elapsed
        tanh_term = 1 - math.tanh(self.lb * (1. - tr) + self.ub * tr)
        return [lr_min + 0.5 * (base * gamma - lr_min) * tanh_term for base in self.base_values]

    def get_epoch_values(self, epoch: int):
        """Return values for `epoch`, or None when the schedule is update-based."""
        return self._get_lr(epoch) if self.t_in_epochs else None

    def get_update_values(self, num_updates: int):
        """Return values for `num_updates`, or None when the schedule is epoch-based."""
        return None if self.t_in_epochs else self._get_lr(num_updates)

    def get_cycle_length(self, cycles=0):
        """Return the total number of steps spanned by `cycles` cycles.

        A falsy `cycles` falls back to `cycle_limit`; at least one cycle is counted.
        """
        cycles = max(1, cycles or self.cycle_limit)
        if self.t_mul == 1.0:
            return self.t_initial * cycles
        return int(math.floor(-self.t_initial * (self.t_mul ** cycles - 1) / (1 - self.t_mul)))
121 |
--------------------------------------------------------------------------------
/timm/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .agc import adaptive_clip_grad
2 | from .checkpoint_saver import CheckpointSaver
3 | from .clip_grad import dispatch_clip_grad
4 | from .cuda import ApexScaler, NativeScaler
5 | from .distributed import distribute_bn, reduce_tensor
6 | from .jit import set_jit_legacy
7 | from .log import setup_default_logging, FormatterNoInfo
8 | from .metrics import AverageMeter, accuracy
9 | from .misc import natural_key, add_bool_arg
10 | from .model import unwrap_model, get_state_dict
11 | from .model_ema import ModelEma, ModelEmaV2
12 | from .random import random_seed
13 | from .summary import update_summary, get_outdir
14 |
--------------------------------------------------------------------------------
/timm/utils/agc.py:
--------------------------------------------------------------------------------
1 | """ Adaptive Gradient Clipping
2 |
3 | An impl of AGC, as per (https://arxiv.org/abs/2102.06171):
4 |
5 | @article{brock2021high,
6 | author={Andrew Brock and Soham De and Samuel L. Smith and Karen Simonyan},
7 | title={High-Performance Large-Scale Image Recognition Without Normalization},
8 | journal={arXiv preprint arXiv:2102.06171},
9 | year={2021}
10 | }
11 |
12 | Code references:
13 | * Official JAX impl (paper authors): https://github.com/deepmind/deepmind-research/tree/master/nfnets
14 | * Phil Wang's PyTorch gist: https://gist.github.com/lucidrains/0d6560077edac419ab5d3aa29e674d5c
15 |
16 | Hacked together by / Copyright 2021 Ross Wightman
17 | """
18 | import torch
19 |
20 |
def unitwise_norm(x, norm_type=2.0):
    """Return the p-norm of `x`, computed per output unit for tensors with more than one dim.

    Tensors with `ndim <= 1` are reduced to a single norm. Otherwise the norm is
    taken over every dim except the first, keeping the reduced dims so the result
    broadcasts against `x`.
    """
    if x.ndim > 1:
        # works for nn.ConvNd and nn.Linear where output dim is first in the kernel/weight tensor
        # might need special cases for other weights (possibly MHA) where this may not be true
        reduce_dims = tuple(range(1, x.ndim))
        return x.norm(norm_type, dim=reduce_dims, keepdim=True)
    return x.norm(norm_type)
28 |
29 |
def adaptive_clip_grad(parameters, clip_factor=0.01, eps=1e-3, norm_type=2.0):
    """Apply adaptive gradient clipping (AGC) to `parameters` in place.

    For each parameter with a gradient, the unit-wise gradient norm is limited to
    `clip_factor` times the unit-wise parameter norm (the latter floored at `eps`).

    Args:
        parameters: a single tensor or an iterable of tensors with `.grad`.
        clip_factor: ratio of allowed gradient norm to parameter norm.
        eps: lower bound applied to the parameter norm.
        norm_type: p of the p-norm used for both norms.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    for param in (p for p in parameters if p.grad is not None):
        weights = param.detach()
        grads = param.grad.detach()
        max_norm = unitwise_norm(weights, norm_type=norm_type).clamp_(min=eps).mul_(clip_factor)
        grad_norm = unitwise_norm(grads, norm_type=norm_type)
        # Rescale so the clipped gradient norm equals max_norm; the clamp avoids division by zero.
        rescaled = grads * (max_norm / grad_norm.clamp(min=1e-6))
        kept_or_clipped = torch.where(grad_norm < max_norm, grads, rescaled)
        param.grad.detach().copy_(kept_or_clipped)
43 |
--------------------------------------------------------------------------------
/timm/utils/checkpoint_saver.py:
--------------------------------------------------------------------------------
1 | """ Checkpoint Saver
2 |
3 | Track top-n training checkpoints and maintain recovery checkpoints on specified intervals.
4 |
5 | Hacked together by / Copyright 2020 Ross Wightman
6 | """
7 |
8 | import glob
9 | import operator
10 | import os
11 | import logging
12 |
13 | import torch
14 |
15 | from .model import unwrap_model, get_state_dict
16 |
17 |
18 | _logger = logging.getLogger(__name__)
19 |
20 |
class CheckpointSaver:
    """Track the top-n training checkpoints and rolling recovery checkpoints.

    Args:
        model: model whose state_dict is saved (unwrapped via `unwrap_fn`).
        optimizer: optimizer whose state_dict is saved.
        args: optional args object; when given, `args.model` is stored as 'arch'
            and the object itself under 'args'.
        model_ema: optional EMA model, saved under 'state_dict_ema'.
        amp_scaler: optional scaler exposing `state_dict_key` and `state_dict()`.
        checkpoint_prefix: filename prefix of per-epoch checkpoints.
        recovery_prefix: filename prefix of recovery checkpoints.
        checkpoint_dir: directory for 'last', 'model_best' and per-epoch checkpoints.
        recovery_dir: directory for recovery checkpoints.
        decreasing: True if a lower metric is better.
        max_history: maximum number of per-epoch checkpoints kept (>= 1).
        unwrap_fn: function used to unwrap the model before taking its state_dict.
    """

    def __init__(
            self,
            model,
            optimizer,
            args=None,
            model_ema=None,
            amp_scaler=None,
            checkpoint_prefix='checkpoint',
            recovery_prefix='recovery',
            checkpoint_dir='',
            recovery_dir='',
            decreasing=False,
            max_history=10,
            unwrap_fn=unwrap_model):

        # objects to save state_dicts of
        self.model = model
        self.optimizer = optimizer
        self.args = args
        self.model_ema = model_ema
        self.amp_scaler = amp_scaler

        # state
        self.checkpoint_files = []  # (filename, metric) tuples in order of decreasing betterness
        self.best_epoch = None
        self.best_metric = None
        self.curr_recovery_file = ''
        self.last_recovery_file = ''

        # config
        self.checkpoint_dir = checkpoint_dir
        self.recovery_dir = recovery_dir
        self.save_prefix = checkpoint_prefix
        self.recovery_prefix = recovery_prefix
        self.extension = '.pth.tar'
        self.decreasing = decreasing  # a lower metric is better if True
        self.cmp = operator.lt if decreasing else operator.gt  # True if lhs better than rhs
        self.max_history = max_history
        self.unwrap_fn = unwrap_fn
        assert self.max_history >= 1

    def _ranked(self, entries):
        """Return `entries` ordered best-first, with metric-less entries last.

        Entries whose metric is None cannot be ordered against numbers (or each
        other), so they are kept in insertion order after all scored entries.
        """
        scored = [entry for entry in entries if entry[1] is not None]
        unscored = [entry for entry in entries if entry[1] is None]
        # descending order if a lower metric is not better
        scored.sort(key=operator.itemgetter(1), reverse=not self.decreasing)
        return scored + unscored

    def save_checkpoint(self, epoch, metric=None):
        """Save the 'last' checkpoint and update the top-n / best links.

        Returns:
            `(best_metric, best_epoch)`, or `(None, None)` if no metric has been recorded.
        """
        assert epoch >= 0
        tmp_save_path = os.path.join(self.checkpoint_dir, 'tmp' + self.extension)
        last_save_path = os.path.join(self.checkpoint_dir, 'last' + self.extension)
        # Write to a temp file first so 'last' is only replaced by a complete file.
        self._save(tmp_save_path, epoch, metric)
        if os.path.exists(last_save_path):
            os.unlink(last_save_path)  # required for Windows support.
        os.rename(tmp_save_path, last_save_path)
        worst_file = self.checkpoint_files[-1] if self.checkpoint_files else None
        if (len(self.checkpoint_files) < self.max_history
                or metric is None
                or worst_file[1] is None  # an unscored entry is always replaceable
                or self.cmp(metric, worst_file[1])):
            if len(self.checkpoint_files) >= self.max_history:
                self._cleanup_checkpoints(1)
            filename = '-'.join([self.save_prefix, str(epoch)]) + self.extension
            save_path = os.path.join(self.checkpoint_dir, filename)
            os.link(last_save_path, save_path)
            self.checkpoint_files.append((save_path, metric))
            self.checkpoint_files = self._ranked(self.checkpoint_files)

            checkpoints_str = "Current checkpoints:\n"
            for c in self.checkpoint_files:
                checkpoints_str += ' {}\n'.format(c)
            _logger.info(checkpoints_str)

            if metric is not None and (self.best_metric is None or self.cmp(metric, self.best_metric)):
                self.best_epoch = epoch
                self.best_metric = metric
                best_save_path = os.path.join(self.checkpoint_dir, 'model_best' + self.extension)
                if os.path.exists(best_save_path):
                    os.unlink(best_save_path)
                os.link(last_save_path, best_save_path)

        return (None, None) if self.best_metric is None else (self.best_metric, self.best_epoch)

    def _save(self, save_path, epoch, metric=None):
        """Assemble the checkpoint dict and write it to `save_path` with `torch.save`."""
        save_state = {
            'epoch': epoch,
            'arch': type(self.model).__name__.lower(),
            'state_dict': get_state_dict(self.model, self.unwrap_fn),
            'optimizer': self.optimizer.state_dict(),
            'version': 2,  # version < 2 increments epoch before save
        }
        if self.args is not None:
            save_state['arch'] = self.args.model
            save_state['args'] = self.args
        if self.amp_scaler is not None:
            save_state[self.amp_scaler.state_dict_key] = self.amp_scaler.state_dict()
        if self.model_ema is not None:
            save_state['state_dict_ema'] = get_state_dict(self.model_ema, self.unwrap_fn)
        if metric is not None:
            save_state['metric'] = metric
        torch.save(save_state, save_path)

    def _cleanup_checkpoints(self, trim=0):
        """Delete tracked checkpoints from index `max_history - trim` onward."""
        trim = min(len(self.checkpoint_files), trim)
        delete_index = self.max_history - trim
        if delete_index < 0 or len(self.checkpoint_files) <= delete_index:
            return
        to_delete = self.checkpoint_files[delete_index:]
        for d in to_delete:
            try:
                _logger.debug("Cleaning checkpoint: {}".format(d))
                os.remove(d[0])
            except Exception as e:
                # Best effort: a failed delete is logged, not fatal.
                _logger.error("Exception '{}' while deleting checkpoint".format(e))
        self.checkpoint_files = self.checkpoint_files[:delete_index]

    def save_recovery(self, epoch, batch_idx=0):
        """Write a recovery checkpoint and remove the one from two saves ago."""
        assert epoch >= 0
        filename = '-'.join([self.recovery_prefix, str(epoch), str(batch_idx)]) + self.extension
        save_path = os.path.join(self.recovery_dir, filename)
        self._save(save_path, epoch)
        if os.path.exists(self.last_recovery_file):
            try:
                _logger.debug("Cleaning recovery: {}".format(self.last_recovery_file))
                os.remove(self.last_recovery_file)
            except Exception as e:
                _logger.error("Exception '{}' while removing {}".format(e, self.last_recovery_file))
        self.last_recovery_file = self.curr_recovery_file
        self.curr_recovery_file = save_path

    def find_recovery(self):
        """Return the lexicographically first recovery file in `recovery_dir`, or '' if none."""
        recovery_path = os.path.join(self.recovery_dir, self.recovery_prefix)
        files = glob.glob(recovery_path + '*' + self.extension)
        files = sorted(files)
        return files[0] if len(files) else ''
151 |
--------------------------------------------------------------------------------
/timm/utils/clip_grad.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from timm.utils.agc import adaptive_clip_grad
4 |
5 |
def dispatch_clip_grad(parameters, value: float, mode: str = 'norm', norm_type: float = 2.0):
    """ Dispatch to gradient clipping method

    Args:
        parameters (Iterable): model parameters to clip
        value (float): clipping value/factor/norm, mode dependant
        mode (str): clipping mode, one of 'norm', 'value', 'agc'
        norm_type (float): p-norm, default 2.0

    Raises:
        AssertionError: if `mode` is not one of the supported modes.
    """
    if mode == 'norm':
        torch.nn.utils.clip_grad_norm_(parameters, value, norm_type=norm_type)
    elif mode == 'value':
        torch.nn.utils.clip_grad_value_(parameters, value)
    elif mode == 'agc':
        adaptive_clip_grad(parameters, value, norm_type=norm_type)
    else:
        # Raised explicitly rather than via `assert False`, which is stripped under
        # `python -O` and would let an unknown mode silently skip clipping.
        raise AssertionError(f"Unknown clip mode ({mode}).")
23 |
24 |
--------------------------------------------------------------------------------
/timm/utils/cuda.py:
--------------------------------------------------------------------------------
1 | """ CUDA / AMP utils
2 |
3 | Hacked together by / Copyright 2020 Ross Wightman
4 | """
5 | import torch
6 |
7 | try:
8 | from apex import amp
9 | has_apex = True
10 | except ImportError:
11 | amp = None
12 | has_apex = False
13 |
14 | from .clip_grad import dispatch_clip_grad
15 |
16 |
class ApexScaler:
    """Loss-scaling helper built on the module-level Apex `amp` object.

    Calling an instance runs the scaled backward pass, optionally clips the
    master parameters' gradients, and steps the optimizer.
    """
    state_dict_key = "amp"

    def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False):
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward(create_graph=create_graph)
        if clip_grad is not None:
            master_params = amp.master_params(optimizer)
            dispatch_clip_grad(master_params, clip_grad, mode=clip_mode)
        optimizer.step()

    def state_dict(self):
        """Return `amp.state_dict()` if `amp` defines it, otherwise None."""
        if 'state_dict' not in amp.__dict__:
            return None
        return amp.state_dict()

    def load_state_dict(self, state_dict):
        """Forward to `amp.load_state_dict` if `amp` defines it; otherwise do nothing."""
        if 'load_state_dict' not in amp.__dict__:
            return
        amp.load_state_dict(state_dict)
34 |
35 |
class NativeScaler:
    """Loss-scaling helper wrapping `torch.cuda.amp.GradScaler`.

    Calling an instance runs the scaled backward pass, optionally unscales and
    clips gradients, then steps the optimizer through the scaler and updates it.
    """
    state_dict_key = "amp_scaler"

    def __init__(self):
        self._scaler = torch.cuda.amp.GradScaler()

    def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False):
        scaler = self._scaler
        scaler.scale(loss).backward(create_graph=create_graph)
        if clip_grad is not None:
            # Clipping needs the parameter list explicitly.
            assert parameters is not None
            scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
            dispatch_clip_grad(parameters, clip_grad, mode=clip_mode)
        scaler.step(optimizer)
        scaler.update()

    def state_dict(self):
        """Return the wrapped scaler's state dict."""
        return self._scaler.state_dict()

    def load_state_dict(self, state_dict):
        """Load `state_dict` into the wrapped scaler."""
        self._scaler.load_state_dict(state_dict)
56 |
--------------------------------------------------------------------------------
/timm/utils/distributed.py:
--------------------------------------------------------------------------------
1 | """ Distributed training/validation utils
2 |
3 | Hacked together by / Copyright 2020 Ross Wightman
4 | """
5 | import torch
6 | from torch import distributed as dist
7 |
8 | from .model import unwrap_model
9 |
10 |
def reduce_tensor(tensor, n):
    """Return a copy of `tensor` summed across the process group and divided by `n`.

    The input tensor itself is left untouched; the all-reduce runs on a clone.
    """
    reduced = tensor.clone()
    dist.all_reduce(reduced, op=dist.ReduceOp.SUM)
    reduced.div_(n)
    return reduced
16 |
17 |
def distribute_bn(model, world_size, reduce=False):
    """Synchronize buffers named like BN running stats across the process group.

    Buffers whose name contains 'running_mean' or 'running_var' are either
    averaged over `world_size` (`reduce=True`) or broadcast from rank 0.
    """
    # ensure every node has the same running bn stats
    for name, buf in unwrap_model(model).named_buffers(recurse=True):
        if 'running_mean' not in name and 'running_var' not in name:
            continue
        if reduce:
            # average bn stats across whole group
            torch.distributed.all_reduce(buf, op=dist.ReduceOp.SUM)
            buf.div_(float(world_size))
        else:
            # broadcast bn stats from rank 0 to whole group
            torch.distributed.broadcast(buf, 0)
29 |
--------------------------------------------------------------------------------
/timm/utils/jit.py:
--------------------------------------------------------------------------------
1 | """ JIT scripting/tracing utils
2 |
3 | Hacked together by / Copyright 2020 Ross Wightman
4 | """
5 | import torch
6 |
7 |
def set_jit_legacy():
    """ Set JIT executor to legacy w/ support for op fusion
    This is hopefully a temporary need in 1.5/1.5.1/1.6 to restore performance due to changes
    in the JIT executor. These API are not supported so could change.
    """
    jit_c = torch._C
    assert hasattr(jit_c, '_jit_set_profiling_executor'), "Old JIT behavior doesn't exist!"
    # Disable the profiling executor/mode and allow fusion on GPU.
    jit_c._jit_set_profiling_executor(False)
    jit_c._jit_set_profiling_mode(False)
    jit_c._jit_override_can_fuse_on_gpu(True)
    # NOTE: torch._C._jit_set_texpr_fuser_enabled(True) is intentionally left disabled.
19 |
--------------------------------------------------------------------------------
/timm/utils/log.py:
--------------------------------------------------------------------------------
1 | """ Logging helpers
2 |
3 | Hacked together by / Copyright 2020 Ross Wightman
4 | """
5 | import logging
6 | import logging.handlers
7 |
8 |
class FormatterNoInfo(logging.Formatter):
    """Formatter that emits INFO records as the bare message and formats all others normally."""

    def __init__(self, fmt='%(levelname)s: %(message)s'):
        super().__init__(fmt)

    def format(self, record):
        if record.levelno != logging.INFO:
            return super().format(record)
        # INFO records: message only, no level prefix.
        return str(record.getMessage())
17 |
18 |
def setup_default_logging(default_level=logging.INFO, log_path=''):
    """Attach a console handler (and optionally a rotating file handler) to the root logger.

    Args:
        default_level: level set on the root logger.
        log_path: if non-empty, also log to this file, rotating at 2 MiB with 3 backups.
    """
    root = logging.root
    console = logging.StreamHandler()
    console.setFormatter(FormatterNoInfo())
    root.addHandler(console)
    root.setLevel(default_level)
    if not log_path:
        return
    rotating = logging.handlers.RotatingFileHandler(log_path, maxBytes=(1024 ** 2 * 2), backupCount=3)
    rotating.setFormatter(logging.Formatter("%(asctime)s - %(name)20s: [%(levelname)8s] - %(message)s"))
    root.addHandler(rotating)
29 |
--------------------------------------------------------------------------------
/timm/utils/metrics.py:
--------------------------------------------------------------------------------
1 | """ Eval metrics and related
2 |
3 | Hacked together by / Copyright 2020 Ross Wightman
4 | """
5 |
6 |
class AverageMeter:
    """Keeps the latest value plus a running sum, count and average."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero all tracked statistics."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record `val` as the latest value, weighted as `n` observations."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count
23 |
24 |
def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k

    Args:
        output: score tensor indexed as (sample, class).
        target: tensor of ground-truth class indices, one per sample.
        topk: iterable of k values to evaluate.

    Returns:
        A list with one accuracy (in percent) per entry of `topk`. A k larger
        than the number of classes is clamped to the number of classes.
    """
    # Clamp so top-k never asks for more entries than there are classes.
    maxk = min(max(topk), output.size(1))
    batch_size = target.size(0)
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.reshape(1, -1).expand_as(pred))
    return [correct[:min(k, maxk)].reshape(-1).float().sum(0) * 100. / batch_size for k in topk]
33 |
--------------------------------------------------------------------------------
/timm/utils/misc.py:
--------------------------------------------------------------------------------
1 | """ Misc utils
2 |
3 | Hacked together by / Copyright 2020 Ross Wightman
4 | """
5 | import re
6 |
7 |
def natural_key(string_):
    """Return a sort key that orders embedded digit runs numerically.

    See http://www.codinghorror.com/blog/archives/001018.html

    The lower-cased string is split on digit runs; each run becomes an int and
    the text between runs is kept as a string.
    """
    tokens = re.split(r'(\d+)', string_.lower())
    # With one capturing group, re.split puts the captured digit runs at odd
    # indices. Using the position (rather than str.isdigit) avoids calling int()
    # on characters such as superscripts that isdigit accepts but int rejects.
    return [int(token) if index % 2 else token for index, token in enumerate(tokens)]
11 |
12 |
def add_bool_arg(parser, name, default=False, help=''):
    """Register a `--<name>` / `--no-<name>` flag pair on `parser`.

    Both flags write to the same destination (`name` with '-' replaced by '_')
    and are mutually exclusive; `default` applies when neither is given.
    """
    dest = name.replace('-', '_')
    toggle = parser.add_mutually_exclusive_group(required=False)
    for prefix, action in (('--', 'store_true'), ('--no-', 'store_false')):
        toggle.add_argument(prefix + name, dest=dest, action=action, help=help)
    parser.set_defaults(**{dest: default})
19 |
--------------------------------------------------------------------------------
/timm/utils/model.py:
--------------------------------------------------------------------------------
1 | """ Model / state_dict utils
2 |
3 | Hacked together by / Copyright 2020 Ross Wightman
4 | """
5 | from .model_ema import ModelEma
6 | import torch
7 | import fnmatch
8 |
def unwrap_model(model):
    """Strip `ModelEma` wrappers and a `.module` wrapper (if present) from `model`."""
    # Peel nested ModelEma wrappers down to the underlying model.
    while isinstance(model, ModelEma):
        model = model.ema
    return getattr(model, 'module', model)
14 |
15 |
def get_state_dict(model, unwrap_fn=unwrap_model):
    """Return the state_dict of `model` after unwrapping it with `unwrap_fn`."""
    unwrapped = unwrap_fn(model)
    return unwrapped.state_dict()
18 |
19 |
def avg_sq_ch_mean(model, input, output):
    """Calculate the average squared channel mean of the output activations.

    `model` and `input` are unused; the signature matches a forward hook.
    The mean is taken over dims 0, 2 and 3 of `output`.
    """
    channel_means = output.mean(axis=[0, 2, 3])
    return torch.mean(channel_means ** 2).item()
23 |
24 |
def avg_ch_var(model, input, output):
    """Calculate the average channel variance of the output activations.

    `model` and `input` are unused; the signature matches a forward hook.
    The variance is taken over dims 0, 2 and 3 of `output`.
    """
    return torch.mean(output.var(axis=[0, 2, 3])).item()
28 |
29 |
def avg_ch_var_residual(model, input, output):
    """Calculate the average channel variance of the output activations.

    `model` and `input` are unused; the signature matches a forward hook.
    """
    per_channel_var = output.var(axis=[0, 2, 3])
    return torch.mean(per_channel_var).item()
33 |
34 |
class ActivationStatsHook:
    """Iterates through each of `model`'s modules and matches modules using unix pattern
    matching based on `hook_fn_locs` and registers `hook_fn` to the module if there is
    a match.

    Arguments:
        model (nn.Module): model from which we will extract the activation stats
        hook_fn_locs (List[str]): List of `hook_fn` locations based on Unix type string
            matching with the name of model's modules.
        hook_fns (List[Callable]): List of hook functions, one per entry of
            `hook_fn_locs`, registered on every module matching that entry.

    Results are collected in `self.stats`, a dict mapping each hook function's
    `__name__` to the list of values it returned.

    Raises:
        ValueError: if `hook_fn_locs` and `hook_fns` differ in length.

    Inspiration from https://docs.fast.ai/callback.hook.html.

    Refer to https://gist.github.com/amaarora/6e56942fcb46e67ba203f3009b30d950 for an example
    on how to plot Signal Propagation Plots using `ActivationStatsHook`.
    """

    def __init__(self, model, hook_fn_locs, hook_fns):
        self.model = model
        self.hook_fn_locs = hook_fn_locs
        self.hook_fns = hook_fns
        if len(hook_fn_locs) != len(hook_fns):
            # Adjacent literals instead of a backslash continuation, which would
            # embed the source indentation in the message.
            raise ValueError(
                "Please provide `hook_fns` for each `hook_fn_locs`, "
                "their lengths are different.")
        self.stats = {hook_fn.__name__: [] for hook_fn in hook_fns}
        for hook_fn_loc, hook_fn in zip(hook_fn_locs, hook_fns):
            self.register_hook(hook_fn_loc, hook_fn)

    def _create_hook(self, hook_fn):
        """Wrap `hook_fn` so its return value is appended to `self.stats`."""
        def append_activation_stats(module, input, output):
            out = hook_fn(module, input, output)
            self.stats[hook_fn.__name__].append(out)
        return append_activation_stats

    def register_hook(self, hook_fn_loc, hook_fn):
        """Register `hook_fn` as a forward hook on every module whose name matches `hook_fn_loc`."""
        for name, module in self.model.named_modules():
            if not fnmatch.fnmatch(name, hook_fn_loc):
                continue
            module.register_forward_hook(self._create_hook(hook_fn))
75 |
76 |
def extract_spp_stats(model,
                      hook_fn_locs,
                      hook_fns,
                      input_shape=[8, 3, 224, 224]):
    """Extract average square channel mean and variance of activations during
    forward pass to plot Signal Propagation Plots (SPP).

    A standard-normal tensor of shape `input_shape` is pushed through `model`
    once with the given hooks registered; the collected stats dict is returned.

    Paper: https://arxiv.org/abs/2101.08692

    Example Usage: https://gist.github.com/amaarora/6e56942fcb46e67ba203f3009b30d950
    """
    random_input = torch.normal(0., 1., input_shape)
    stats_hook = ActivationStatsHook(model, hook_fn_locs=hook_fn_locs, hook_fns=hook_fns)
    model(random_input)  # forward pass only to trigger the hooks; output is discarded
    return stats_hook.stats
92 |
--------------------------------------------------------------------------------
/timm/utils/model_ema.py:
--------------------------------------------------------------------------------
1 | """ Exponential Moving Average (EMA) of model updates
2 |
3 | Hacked together by / Copyright 2020 Ross Wightman
4 | """
5 | import logging
6 | from collections import OrderedDict
7 | from copy import deepcopy
8 |
9 | import torch
10 | import torch.nn as nn
11 |
12 | _logger = logging.getLogger(__name__)
13 |
14 |
class ModelEma:
    """ Model Exponential Moving Average (DEPRECATED)

    Keep a moving average of everything in the model state_dict (parameters and buffers).
    This version is deprecated, it does not work with scripted models. Will be removed eventually.

    This is intended to allow functionality like
    https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage

    A smoothed version of the weights is necessary for some training schemes to perform well.
    E.g. Google's hyper-params for training MNASNet, MobileNet-V3, EfficientNet, etc that use
    RMSprop with a short 2.4-3 epoch decay period and slow LR decay rate of .96-.99 requires EMA
    smoothing of weights to match results. Pay attention to the decay constant you are using
    relative to your update count per epoch.

    To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory but
    disable validation of the EMA weights. Validation will have to be done manually in a separate
    process, or after the training stops converging.

    This class is sensitive where it is initialized in the sequence of model init,
    GPU assignment and distributed training wrappers.

    Args:
        model: model to copy; the copy is put in eval mode and its parameters frozen.
        decay: EMA decay factor applied to the stored weights on every update.
        device: if non-empty, the EMA copy is moved to (and updated on) this device.
        resume: if non-empty, checkpoint path to load 'state_dict_ema' from.
    """
    def __init__(self, model, decay=0.9999, device='', resume=''):
        # make a copy of the model for accumulating moving average of weights
        ema_model = deepcopy(model)
        ema_model.eval()
        self.ema = ema_model
        self.decay = decay
        self.device = device  # perform ema on different device from model if set
        if device:
            self.ema.to(device=device)
        self.ema_has_module = hasattr(self.ema, 'module')
        if resume:
            self._load_checkpoint(resume)
        for param in self.ema.parameters():
            param.requires_grad_(False)

    def _load_checkpoint(self, checkpoint_path):
        """Load 'state_dict_ema' from `checkpoint_path` into the EMA model, if present."""
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        assert isinstance(checkpoint, dict)
        if 'state_dict_ema' not in checkpoint:
            _logger.warning("Failed to find state_dict_ema, starting from loaded model weights")
            return
        new_state_dict = OrderedDict()
        for key, value in checkpoint['state_dict_ema'].items():
            # ema model may have been wrapped by DataParallel, and need module prefix
            if self.ema_has_module and not key.startswith('module'):
                key = 'module.' + key
            new_state_dict[key] = value
        self.ema.load_state_dict(new_state_dict)
        _logger.info("Loaded state_dict_ema")

    def update(self, model):
        """Blend `model`'s current state into the EMA copy using `self.decay`."""
        # correct a mismatch in state dict keys
        needs_module = hasattr(model, 'module') and not self.ema_has_module
        key_prefix = 'module.' if needs_module else ''
        with torch.no_grad():
            model_state = model.state_dict()
            for key, ema_v in self.ema.state_dict().items():
                model_v = model_state[key_prefix + key].detach()
                if self.device:
                    model_v = model_v.to(device=self.device)
                ema_v.copy_(ema_v * self.decay + (1. - self.decay) * model_v)
80 |
81 |
class ModelEmaV2(nn.Module):
    """ Model Exponential Moving Average V2

    Keep a moving average of everything in the model state_dict (parameters and buffers).
    V2 of this module is simpler, it does not match params/buffers based on name but simply
    iterates in order. It works with torchscript (JIT of full model).

    This is intended to allow functionality like
    https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage

    A smoothed version of the weights is necessary for some training schemes to perform well.
    E.g. Google's hyper-params for training MNASNet, MobileNet-V3, EfficientNet, etc that use
    RMSprop with a short 2.4-3 epoch decay period and slow LR decay rate of .96-.99 requires EMA
    smoothing of weights to match results. Pay attention to the decay constant you are using
    relative to your update count per epoch.

    To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory but
    disable validation of the EMA weights. Validation will have to be done manually in a separate
    process, or after the training stops converging.

    This class is sensitive where it is initialized in the sequence of model init,
    GPU assignment and distributed training wrappers.
    """
    def __init__(self, model, decay=0.9999, device=None):
        super(ModelEmaV2, self).__init__()
        # make a copy of the model for accumulating moving average of weights
        self.module = deepcopy(model)
        self.module.eval()
        self.decay = decay
        self.device = device  # perform ema on different device from model if set
        if self.device is not None:
            self.module.to(device=device)

    def _update(self, model, update_fn):
        """Overwrite each EMA state entry with `update_fn(ema_value, model_value)`, paired in order."""
        with torch.no_grad():
            ema_values = self.module.state_dict().values()
            model_values = model.state_dict().values()
            for ema_v, model_v in zip(ema_values, model_values):
                if self.device is not None:
                    model_v = model_v.to(device=self.device)
                ema_v.copy_(update_fn(ema_v, model_v))

    def update(self, model):
        """Blend `model`'s state into the EMA copy using `self.decay`."""
        def blend(ema_value, model_value):
            return self.decay * ema_value + (1. - self.decay) * model_value

        self._update(model, update_fn=blend)

    def set(self, model):
        """Copy `model`'s state into the EMA copy without averaging."""
        def take_model(ema_value, model_value):
            return model_value

        self._update(model, update_fn=take_model)
127 |
--------------------------------------------------------------------------------
/timm/utils/random.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | import torch
4 |
5 |
def random_seed(seed=42, rank=0):
    """Seed torch, numpy and Python's `random` with `seed + rank`."""
    effective_seed = seed + rank
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(effective_seed)
10 |
--------------------------------------------------------------------------------
/timm/utils/summary.py:
--------------------------------------------------------------------------------
1 | """ Summary utilities
2 |
3 | Hacked together by / Copyright 2020 Ross Wightman
4 | """
5 | import csv
6 | import os
7 | from collections import OrderedDict
8 | try:
9 | import wandb
10 | except ImportError:
11 | pass
12 |
def get_outdir(path, *paths, inc=False):
    """Create and return an output directory under `path`.

    If the joined path does not exist it is created. If it exists and `inc` is
    True, the first free `<outdir>-<count>` (count starting at 1) is created and
    returned instead. If it exists and `inc` is False it is returned as is.
    """
    outdir = os.path.join(path, *paths)
    if not os.path.exists(outdir):
        os.makedirs(outdir)
        return outdir
    if inc:
        suffix = 1
        candidate = outdir + '-' + str(suffix)
        while os.path.exists(candidate):
            suffix += 1
            candidate = outdir + '-' + str(suffix)
            assert suffix < 100
        outdir = candidate
        os.makedirs(outdir)
    return outdir
27 |
28 |
def update_summary(epoch, train_metrics, eval_metrics, filename, write_header=False, log_wandb=False):
    """Append one row of metrics for `epoch` to the CSV file `filename`.

    Args:
        epoch: value written to the 'epoch' column.
        train_metrics: mapping whose keys are written with a 'train_' prefix.
        eval_metrics: mapping whose keys are written with an 'eval_' prefix.
        filename: CSV file to append to.
        write_header: write the header row before the data row.
        log_wandb: also send the row to `wandb.log`.
    """
    rowd = OrderedDict(epoch=epoch)
    rowd.update([('train_' + k, v) for k, v in train_metrics.items()])
    rowd.update([('eval_' + k, v) for k, v in eval_metrics.items()])
    if log_wandb:
        wandb.log(rowd)
    # newline='' lets the csv module control line endings; without it, platforms
    # that translate '\n' end up with an extra blank row after every record.
    with open(filename, mode='a', newline='') as cf:
        dw = csv.DictWriter(cf, fieldnames=rowd.keys())
        if write_header:  # first iteration (epoch == 1 can't be used)
            dw.writeheader()
        dw.writerow(rowd)
40 |
--------------------------------------------------------------------------------
/timm/version.py:
--------------------------------------------------------------------------------
1 | __version__ = '0.4.12'
2 |
--------------------------------------------------------------------------------
/utils_evaluate.py:
--------------------------------------------------------------------------------
1 | '''
2 | Adapted from https://github.com/lupantech/ScienceQA
3 | '''
4 |
5 | import os
6 | import json
7 | import argparse
8 | import warnings
9 | import pandas as pd
10 | from sentence_transformers import SentenceTransformer
11 | from evaluations import caculate_bleu, caculate_rouge, caculate_similariry
12 |
13 | warnings.filterwarnings('ignore')
14 |
def get_acc_with_contion(res_pd, key, values):
    """Return the accuracy (percent, two decimals, as a string) on a subset of `res_pd`.

    The subset is the rows whose `key` column is in `values` (when `values` is a
    list) or equals `values` (otherwise). A row counts as correct when its
    'true_false' column equals True.
    """
    column = res_pd[key]
    selector = column.isin(values) if isinstance(values, list) else column == values
    subset = res_pd[selector]
    n_correct = len(subset[subset['true_false'] == True])
    return "{:.2f}".format(n_correct / len(subset) * 100)
23 |
24 |
def get_scores(result_data, rationale_data, results_reference, data_file):
    """Compute answer-accuracy and rationale-quality scores for the test split.

    Args:
        result_data: predictions, indexed by the same keys as the rows of the
            test split; each value must be convertible with `int()`.
        rationale_data: generated rationales passed to the BLEU/ROUGE/similarity helpers.
        results_reference: reference rationales passed to the same helpers.
        data_file: path to the JSON problem file (entries have 'split', 'hint',
            'image', 'answer', 'subject' and 'grade' fields).

    Returns:
        A dict with an "answer" section (accuracy strings per category) and a
        "rationale" section (bleu1, bleu4, rouge, similariry, each times 100).

    Raises:
        AssertionError: if `result_data` does not contain exactly 4241 entries.
    """
    # read result file
    results = result_data
    num = len(results)
    assert num == 4241
    #print("number of questions:", num)

    # read data file (context manager so the handle is closed right after parsing)
    with open(data_file) as data_fp:
        sqa_data = json.load(data_fp)

    # construct pandas data
    sqa_pd = pd.DataFrame(sqa_data).T
    res_pd = sqa_pd[sqa_pd['split'] == 'test']  # test set

    # update data
    for index, row in res_pd.iterrows():

        res_pd.loc[index, 'no_context'] = True if (not row['hint'] and not row['image']) else False
        res_pd.loc[index, 'has_text'] = True if row['hint'] else False
        res_pd.loc[index, 'has_image'] = True if row['image'] else False
        res_pd.loc[index, 'has_text_image'] = True if (row['hint'] and row['image']) else False

        label = row['answer']
        pred = int(results[index])
        res_pd.loc[index, 'pred'] = pred
        res_pd.loc[index, 'true_false'] = (label == pred)

    # accuracy scores
    acc_average = len(res_pd[res_pd['true_false'] == True]) / num * 100
    #assert result_file.split('_')[-1] == "{:.3f}.json".format(acc_average)

    # rationale quality

    ## BLEU
    bleu1 = caculate_bleu(rationale_data, results_reference, gram=1)
    bleu4 = caculate_bleu(rationale_data, results_reference, gram=4)

    ## Rouge-L
    rouge = caculate_rouge(rationale_data, results_reference)

    ## Similarity
    model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2').cuda()
    similariry = caculate_similariry(rationale_data, results_reference, model)

    scores = {
        "answer": {
            'acc_natural':
            get_acc_with_contion(res_pd, 'subject', 'natural science'),
            'acc_social':
            get_acc_with_contion(res_pd, 'subject', 'social science'),
            'acc_language':
            get_acc_with_contion(res_pd, 'subject', 'language science'),
            'acc_has_text':
            get_acc_with_contion(res_pd, 'has_text', True),
            'acc_has_image':
            get_acc_with_contion(res_pd, 'has_image', True),
            'acc_no_context':
            get_acc_with_contion(res_pd, 'no_context', True),
            'acc_grade_1_6':
            get_acc_with_contion(res_pd, 'grade', ['grade1', 'grade2', 'grade3', 'grade4', 'grade5', 'grade6']),
            'acc_grade_7_12':
            get_acc_with_contion(res_pd, 'grade', ['grade7', 'grade8', 'grade9', 'grade10', 'grade11', 'grade12']),
            'acc_average':
            "{:.2f}".format(acc_average),
        },
        "rationale": {
            'bleu1': bleu1 * 100,
            'bleu4': bleu4 * 100,
            'rouge': rouge * 100,
            'similariry': similariry * 100,
        }
    }

    return scores
100 |
101 |
def print_scores(scores):
    """Print every score on its own line, then all values as a LaTeX row.

    Args:
        scores: Mapping from score name to value. Values that are themselves
            dicts (as in the ``"answer"`` / ``"rationale"`` result of
            ``get_scores``) are flattened, so only leaf scores are printed.

    A leading ``"acc_"`` is stripped from the printed name; other names are
    printed unchanged. The final line has the form ``& v1 & v2 ... \\\\``.
    """
    def _leaf_items(mapping):
        # Walk nested dicts depth-first, yielding (name, value) for leaves.
        for name, value in mapping.items():
            if isinstance(value, dict):
                yield from _leaf_items(value)
            else:
                yield name, value

    cells = []
    for key, score in _leaf_items(scores):
        # Only drop the prefix when it is really there; slicing blindly
        # mangled names such as "answer" or "bleu1".
        label = key[4:] if key.startswith("acc_") else key
        print(f"{label}: \t{score}")
        cells.append(f"& {score} ")
    latex_output = "".join(cells) + "\\\\"
    print(latex_output)
109 |
--------------------------------------------------------------------------------
/vision_features/mm-cot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/mm-cot/8dd4ac02b94f21347973491f6e6b828502d23f9d/vision_features/mm-cot.png
--------------------------------------------------------------------------------