├── LICENSE ├── README.md ├── benchmark.py ├── dataset ├── video_dataset.py ├── video_dataset_aug.py ├── video_dataset_config.py └── video_transforms.py ├── engine.py ├── hubconf.py ├── losses.py ├── main.py ├── models.py ├── my_models ├── __init__.py ├── action_conv.py ├── sifar_swin.py └── sifar_util.py ├── requirements.txt ├── samplers.py ├── sifar.png ├── simclr.py ├── tools ├── convert_contrastive_model.py └── convert_vit_model_to_diff_input_size.py ├── utils.py ├── video_dataset.py ├── video_dataset_aug.py ├── video_dataset_config.py └── video_transforms.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SIFAR: Super Image for Action Recognition 2 | 3 | This repository contains a PyTorch implementation of SIFAR, an approach that repurposes image classifiers for efficient action recognition by rearranging input video frames into super images. 
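For intuition, the core rearrangement can be sketched in a few lines (a minimal illustration using `einops`, which is already a dependency of this repo; the tensor names and the zero-padding of the unused grid cell are assumptions for illustration, not the exact code in `my_models/sifar_swin.py`):

```python
import torch
from einops import rearrange

# Hypothetical example: a batch of 8-frame clips arranged on a 3x3 grid.
frames = torch.randn(2, 8, 3, 224, 224)   # (batch, time, channels, height, width)
rows, cols = 3, 3

# Pad the temporal dimension so the frames fill the grid (8 frames -> 9 tiles).
pad = rows * cols - frames.shape[1]
if pad > 0:
    frames = torch.cat([frames, frames.new_zeros(frames.shape[0], pad, *frames.shape[2:])], dim=1)

# Tile the frames into one large "super image" that a plain 2D image classifier can consume.
super_image = rearrange(frames, 'b (r c) ch h w -> b ch (r h) (c w)', r=rows, c=cols)
print(super_image.shape)  # torch.Size([2, 3, 672, 672])
```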
4 | 5 | ![sifar image](./sifar.png) 6 | 7 | For details, please see the paper, [Can an Image Classifier Suffice for Action Recognition?](https://openreview.net/pdf?id=qhkFX-HLuHV) by Quanfu Fan*, Richard Chen* and Rameswar Panda*. 8 | 9 | If you use this code in a paper, please cite: 10 | 11 | ``` 12 | @INPROCEEDINGS{fan-iclr2022, 13 | title={Can an Image Classifier Suffice for Action Recognition?}, 14 | author={Quanfu Fan and Richard Chen and Rameswar Panda}, 15 | booktitle={International Conference on Learning Representations (ICLR)}, 16 | year={2022} 17 | } 18 | ``` 19 | 20 | # Usage 21 | 22 | First, clone the repository locally: 23 | ``` 24 | git clone https://github.com/IBM/sifar-pytorch 25 | ``` 26 | ## Requirements 27 | ``` 28 | pip install -r requirements.txt 29 | ``` 30 | 31 | To load video input, you need to install the [PyAV package](https://pyav.org/docs/develop/overview/installation.html). 32 | 33 | # Data Preparation 34 | Please refer to https://github.com/IBM/action-recognition-pytorch for how to prepare action recognition benchmark datasets such as Kinetics400 and Something-Something. For Kinetics400, we used the URLs provided at [this link](https://github.com/youngwanLEE/VoV3D/blob/main/DATA.md#kinetics-400) to download the data. 35 | 36 | 37 | ## Training and Evaluation 38 | 39 | | Model | Frames | Super Image | Image Size | Model Size (M) | FLOPs (G) | 40 | | --- | --- | --- | --- | --- | --- | 41 | | SIFAR-B-7 (`sifar_base_patch4_window7_224`) | 8 | 3x3 | 224 | 87 | 138 | 42 | | SIFAR-B-12 (`sifar_base_patch4_window12_192_3x3`) | 8 | 3x3 | 192 | 87 | 106 | 43 | | SIFAR-B-14 (`sifar_base_patch4_window14_224_3x3`) | 8 | 3x3 | 224 | 87 | 147 | 44 | | SIFAR-B-12† (`sifar_base_patch4_window12_192_4x4`) | 16 | 4x4 | 192 | 87 | 189 | 45 | | SIFAR-B-14† (`sifar_base_patch4_window12_224_4x4`) | 16 | 4x4 | 224 | 87 | 263 | 46 | | SIFAR-B-12‡ (`sifar_base_patch4_window12_192_3x3`) | 8 | 3x3 | 384 | 87 | 423 | 47 | 48 | The table above lists the configurations of the models supported by SIFAR. When training or testing a model, please make sure that the input arguments match a configuration in the table. 49 | 50 | Here is an example of training an 8-frame Kinetics400 model with `Uniform Sampling` on a single node with 6 GPUs, 51 | 52 | ``` 53 | python -m torch.distributed.launch --nproc_per_node=6 main.py --data_dir [path-to-video] --use_pyav --dataset kinetics400 \ 54 | --opt adamw --lr 1e-4 --epochs 30 --sched cosine --duration 8 --batch-size 2 --super_img_rows 3 --disable_scaleup \ 55 | --mixup 0.8 --cutmix 1.0 --drop-path 0.1 --pretrained --warmup-epochs 5 --no-amp --model sifar_base_patch4_window14_224_3x3 \ 56 | --output_dir [output_dir] 57 | ``` 58 | To enable position embedding, add '--hpe_to_token' to the script. 59 | 60 | Below is another example of fine-tuning an SSV2 model from a Kinetics400 pretrained checkpoint, 61 | ``` 62 | python -m torch.distributed.launch --nproc_per_node=6 main.py --data_dir [path-to-video] --use_pyav --dataset sth2stv2 \ 63 | --opt adamw --lr 1e-4 --epochs 20 --sched cosine --duration 8 --batch-size 2 --super_img_rows 3 --disable_scaleup \ 64 | --mixup 0.8 --cutmix 1.0 --drop-path 0.1 --pretrained --warmup-epochs 0 --no-amp --model sifar_base_patch4_window14_224_3x3 \ 65 | --logdir [output_dir] --hpe_to_token --initial_checkpoint [path-to-pretrain] 66 | ``` 67 | 68 | More options for training SIFAR models can be found in `main.py`.
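As a quick sanity check on the model table above, the spatial size of the super image fed to the image backbone follows directly from the frame count, the grid, and the per-frame image size (a small illustrative calculation; the ceil-based column count and padding behaviour are assumptions, not code from this repo):

```python
import math

def super_image_size(num_frames: int, super_img_rows: int, image_size: int):
    """Approximate (height, width) of the tiled super image, assuming padding to a full grid."""
    cols = math.ceil(num_frames / super_img_rows)
    return super_img_rows * image_size, cols * image_size

print(super_image_size(8, 3, 224))   # (672, 672), e.g. SIFAR-B-14 with --duration 8 --super_img_rows 3
print(super_image_size(16, 4, 192))  # (768, 768), e.g. SIFAR-B-12 (dagger) with --duration 16 --super_img_rows 4
```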
You can get the full list of options via 69 | ``` 70 | python3 main.py --help 71 | ``` 72 | 73 | To evaluate a model, add '--eval' to the training script and specify the path of the checkpoint to be tested via '--initial_checkpoint'. The number of crops and clips used for evaluation can be set via '--num_crops' and '--num_clips'. Below is an example of running a model with 3 crops and 3 clips, 74 | ``` 75 | python -m torch.distributed.launch --nproc_per_node=6 main.py --data_dir [path-to-video] --use_pyav --dataset sth2stv2 \ 76 | --opt adamw --lr 1e-4 --epochs 30 --sched cosine --duration 8 --batch-size 2 --super_img_rows 3 --disable_scaleup \ 77 | --mixup 0.8 --cutmix 1.0 --drop-path 0.1 --pretrained --warmup-epochs 5 --no-amp --model sifar_base_patch4_window14_224_3x3 \ 78 | --output_dir [output_dir] --hpe_to_token --initial_checkpoint [path-to-pretrain] --eval --num_crops 3 --num_clips 3 79 | ``` 80 | 81 | | Dataset | Model | Frames | Top1 | Top5 | Download | 82 | | --- | --- | --- | --- | --- | --- | 83 | | Kinetics400 | SIFAR-B-12 | 8 | 80.0 | 94.5 | - | 84 | | | SIFAR-B-12† | 16 | 80.4 | 94.4 | - | 85 | | | SIFAR-B-14 | 8 | 80.2 | 94.4 | [link](https://github.com/IBM/sifar-pytorch/releases/download/action-models/sifar_base_patch4_window14_224_3x3-kinetics400_f8_pe_aug.pth) | 86 | | | SIFAR-B-14† | 16 | 81.8 | 95.2 | [link](https://github.com/IBM/sifar-pytorch/releases/download/action-models/sifar_base_patch4_window14_224_4x4-kinetics400_f16_pe_aug_v1.pth) | 87 | | SSV2 | SIFAR-B-12 | 8 | 60.8 | 87.3 | - | 88 | | | SIFAR-B-12† | 16 | 61.4 | 87.6 | - | 89 | | | SIFAR-B-14 | 8 | 61.6 | 87.9 | [link](https://github.com/IBM/sifar-pytorch/releases/download/action-models/sifar_base_patch4_window14_224_3x3-st2stv2_kineticsft_f8_pe_aug.pth) | 90 | | | SIFAR-B-14† | 16 | 62.6 | 88.5 | [link](https://github.com/IBM/sifar-pytorch/releases/download/action-models/sifar_base_patch4_window14_224_4x4-st2stv2_f16_kineticsft_pe_aug_v1.pth) | 91 | 92 | # License 93 | This repository is released under the Apache-2.0 license, as found in the [LICENSE](LICENSE) file. 94 | 95 | -------------------------------------------------------------------------------- /benchmark.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the CC-by-NC license found in the 5 | # LICENSE file in the root directory of this source tree.
6 | # 7 | import glob 8 | import argparse 9 | import datetime 10 | import numpy as np 11 | import time 12 | import torch 13 | import torch.backends.cudnn as cudnn 14 | import json 15 | import os 16 | import warnings 17 | 18 | from pathlib import Path 19 | 20 | from timm.data import Mixup 21 | from timm.models import create_model 22 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy 23 | from timm.scheduler import create_scheduler 24 | from timm.optim import create_optimizer 25 | from timm.utils import NativeScaler, get_state_dict, ModelEma 26 | 27 | from datasets import build_dataset 28 | from engine import train_one_epoch, evaluate 29 | from samplers import RASampler 30 | import models 31 | import my_models 32 | from torch.utils.tensorboard import SummaryWriter 33 | import torch.nn as nn 34 | import simclr 35 | import utils 36 | from losses import DeepMutualLoss, ONELoss, MulMixturelLoss, SelfDistillationLoss 37 | from vtab import DATASET_REGISTRY 38 | 39 | 40 | 41 | from collections import OrderedDict 42 | 43 | #from timm.models.vision_transformer import Block, Attention 44 | from my_models import action_vit_ts, action_vit_hub, action_vit_swin 45 | 46 | 47 | from timm.models.registry import register_model 48 | from timm.models.helpers import build_model_with_cfg 49 | from timm.models.resnet import Bottleneck, ResNet, default_cfgs 50 | from video_dataset_config import get_dataset_config, DATASET_CONFIG 51 | 52 | from main import get_args_parser 53 | 54 | def _create_resnet(variant, pretrained=False, **kwargs): 55 | return build_model_with_cfg( 56 | ResNet, variant, default_cfg=default_cfgs[variant], pretrained=pretrained, **kwargs) 57 | 58 | 59 | @register_model 60 | def ecaresnet152d(pretrained=False, **kwargs): 61 | """Constructs a ResNet-101-D model with eca. 
62 | """ 63 | model_args = dict( 64 | block=Bottleneck, layers=[3, 4, 36, 3], stem_width=32, stem_type='deep', avg_down=True, 65 | block_args=dict(attn_layer='eca'), **kwargs) 66 | return _create_resnet('ecaresnet101d', pretrained, **model_args) 67 | 68 | 69 | warnings.filterwarnings("ignore", category=UserWarning) 70 | #torch.multiprocessing.set_start_method('spawn', force=True) 71 | 72 | def summary(model, input_tensor, attention_cls): 73 | def register_hook(module): 74 | def hook(module, input, output): 75 | class_name = str(module.__class__).split('.')[-1].split("'")[0] 76 | module_idx = len(summary) 77 | 78 | m_key = '%s-%i' % (class_name, module_idx + 1) 79 | summary[m_key] = OrderedDict({'input_shape': 'N/A', 'output_shape': 'N/A', 'flops': 0, 'nb_params': 0}) 80 | if not isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, attention_cls)): 81 | return 82 | if isinstance(input[0], (list, tuple)): 83 | return 84 | 85 | summary[m_key]['input_shape'] = list(input[0].size()) 86 | batch_size = summary[m_key]['input_shape'][0] 87 | # summary[m_key]['input_shape'][0] = -1 88 | if isinstance(output, (list, tuple)): 89 | summary[m_key]['output_shape'] = [[-1] + list(o.size())[1:] for o in output] 90 | else: 91 | summary[m_key]['output_shape'] = list(output.size()) 92 | summary[m_key]['output_shape'][0] = -1 93 | summary[m_key]['output_shape'] = list(output.size()) 94 | params = 0 95 | if hasattr(module, 'weight') and hasattr(module.weight, 'size'): 96 | params += torch.prod(torch.LongTensor(list(module.weight.size()))) 97 | summary[m_key]['trainable'] = module.weight.requires_grad 98 | if hasattr(module, 'bias') and hasattr(module.bias, 'size'): 99 | params += torch.prod(torch.LongTensor(list(module.bias.size()))) 100 | 101 | summary[m_key]['nb_params'] = params 102 | 103 | 104 | if hasattr(module, 'kernel_size') and hasattr(module, 'out_channels') and hasattr(module, 'in_channels'): 105 | output_size = torch.prod(torch.LongTensor(summary[m_key]['output_shape'][1:])) 106 | flops_per_point = np.prod(module.kernel_size) * module.in_channels / module.groups 107 | summary[m_key]['flops'] = int(output_size * flops_per_point) 108 | else: 109 | if isinstance(module, nn.Linear): 110 | if len(summary[m_key]['output_shape']) == 4: 111 | summary[m_key]['flops'] = summary[m_key]['input_shape'][-1] * summary[m_key]['output_shape'][1] * summary[m_key]['output_shape'][2] * summary[m_key]['output_shape'][3] 112 | elif len(summary[m_key]['output_shape']) == 3: 113 | summary[m_key]['flops'] = summary[m_key]['input_shape'][-1] * summary[m_key]['output_shape'][1] * summary[m_key]['output_shape'][2] 114 | elif len(summary[m_key]['output_shape']) == 2: 115 | summary[m_key]['flops'] = summary[m_key]['input_shape'][-1] * summary[m_key]['output_shape'][-1] 116 | else: 117 | summary[m_key]['flops'] = 0 118 | elif isinstance(module, (attention_cls)): 119 | n = summary[m_key]['input_shape'][1] 120 | c = summary[m_key]['input_shape'][-1] 121 | summary[m_key]['flops'] = (2 * (n * n * c)) 122 | else: 123 | summary[m_key]['flops'] = 0 124 | summary[m_key]['flops'] *= batch_size 125 | 126 | if (not isinstance(module, nn.Sequential) and 127 | not isinstance(module, nn.ModuleList) and 128 | not (module == model)): 129 | hooks.append(module.register_forward_hook(hook)) 130 | 131 | model.eval() 132 | # create properties 133 | summary = OrderedDict() 134 | hooks = [] 135 | # register hook 136 | model.apply(register_hook) 137 | # make a forward pass 138 | model(input_tensor) 139 | # remove these hooks 140 | for h in 
hooks: 141 | h.remove() 142 | 143 | ret = "" 144 | ret += '-----------------------------------------------------------------------------------\n' 145 | line_new = '{:>24} {:>25} {:>15} {:>15}\n'.format('Layer (type)', 'Output Shape', 'Param #', 'FLOPs #') 146 | ret += line_new 147 | ret += '===================================================================================\n' 148 | total_params = 0 149 | trainable_params = 0 150 | total_flops = 0 151 | for layer in summary: 152 | 153 | # if summary[layer]['flops'] == 0: 154 | # continue 155 | # input_shape, output_shape, trainable, nb_params 156 | line_new = '{:>24} {:>25} {:>15} {:>15}\n'.format(layer, str(summary[layer]['output_shape']), '{0:,}'.format(summary[layer]['nb_params']), '{0:,}'.format(summary[layer]['flops'])) 157 | total_params += summary[layer]['nb_params'] 158 | total_flops += summary[layer]['flops'] 159 | if 'trainable' in summary[layer]: 160 | if summary[layer]['trainable'] == True: 161 | trainable_params += summary[layer]['nb_params'] 162 | ret += line_new 163 | 164 | total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 165 | ret += '===================================================================================\n' 166 | ret += 'Total flops: {0:,}\n'.format(total_flops) 167 | ret += 'Total params: {0:,}\n'.format(total_params) 168 | ret += 'Trainable params: {0:,}\n'.format(trainable_params) 169 | ret += 'Non-trainable params: {0:,}\n'.format(total_params - trainable_params) 170 | ret += '-----------------------------------------------------------------------------------' 171 | return ret, total_flops, total_params 172 | # return summary 173 | 174 | 175 | 176 | def main(args): 177 | #utils.init_distributed_mode(args) 178 | #print(args) 179 | # Patch 180 | if not hasattr(args, 'hard_contrastive'): 181 | args.hard_contrastive = False 182 | if not hasattr(args, 'selfdis_w'): 183 | args.selfdis_w = 0.0 184 | 185 | 186 | device = torch.device(args.device) 187 | 188 | # fix the seed for reproducibility 189 | seed = args.seed + utils.get_rank() 190 | torch.manual_seed(seed) 191 | np.random.seed(seed) 192 | # random.seed(seed) 193 | 194 | cudnn.benchmark = True 195 | num_classes, train_list_name, val_list_name, test_list_name, filename_seperator, image_tmpl, filter_video, label_file = get_dataset_config( 196 | args.dataset, args.use_lmdb) 197 | 198 | args.num_classes = num_classes 199 | if args.modality == 'rgb': 200 | args.input_channels = 3 201 | elif args.modality == 'flow': 202 | args.input_channels = 2 * 5 203 | 204 | if 'action_vit_ts' in args.model: 205 | Attention = action_vit_ts.Attention 206 | elif 'action_vit_hub' in args.model: 207 | Attention = action_vit_hub.Attention 208 | elif 'action_vit_swin' in args.model: 209 | Attention = action_vit_swin.WindowAttention 210 | else: 211 | from timm.models.vision_transformer import Block, Attention 212 | 213 | #print(f"Creating model: {args.model}") 214 | model = create_model( 215 | args.model, 216 | pretrained=args.pretrained, 217 | duration=args.duration, 218 | frame_cls_tokens=args.frame_cls_tokens, 219 | temporal_module_name=args.temporal_module_name, 220 | temporal_attention_only=args.temporal_attention_only, 221 | temporal_heads_scale=args.temporal_heads_scale, 222 | temporal_mlp_scale = args.temporal_mlp_scale, 223 | hpe_to_token = args.hpe_to_token, 224 | spatial_hub_size = args.spatial_hub_size, 225 | hub_attention=args.hub_attention, 226 | hub_aggregation=args.hub_aggregation, 227 | temporal_pooling = args.temporal_pooling, 228 | 
bottleneck = args.bottleneck, 229 | rel_pos = args.rel_pos, 230 | window_size=args.window_size, 231 | super_img_rows = args.super_img_rows, 232 | token_mask=not args.no_token_mask, 233 | online_learning = args.one_w >0.0 or args.dml_w >0.0, 234 | num_classes=args.num_classes, 235 | drop_rate=args.drop, 236 | drop_path_rate=args.drop_path, 237 | drop_block_rate=args.drop_block, 238 | ) 239 | 240 | optimizer = create_optimizer(args, model) 241 | 242 | model.to(device) 243 | criterion = torch.nn.CrossEntropyLoss() 244 | #total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 245 | data = torch.randn((args.batch_size, 3 * args.duration, args.input_size, args.input_size), device=device, dtype=torch.float) 246 | data_ = torch.randn((1, 3 * args.duration, args.input_size, args.input_size), device=device, dtype=torch.float) 247 | with torch.no_grad(): 248 | o, flops, params = summary(model, data_, Attention) 249 | #print(o) 250 | print(f"FLOPs: {flops}, Params: {params}") 251 | exit(0) 252 | #flops /= args.batch_size 253 | target = torch.ones((args.batch_size), device=device, dtype=torch.long) 254 | start = torch.cuda.Event(enable_timing=True) 255 | end = torch.cuda.Event(enable_timing=True) 256 | 257 | # training 258 | #print("Start!") 259 | if args.eval: 260 | model.eval() 261 | with torch.no_grad(): 262 | for i in range(10): 263 | model(data) 264 | start.record() 265 | with torch.no_grad(): 266 | for i in range(args.iters): 267 | model(data) 268 | end.record() 269 | else: 270 | for i in range(10): 271 | optimizer.zero_grad() 272 | out = model(data) 273 | loss = torch.mean(out) 274 | loss.backward() 275 | optimizer.step() 276 | start.record() 277 | for i in range(args.iters): 278 | optimizer.zero_grad() 279 | out = model(data) 280 | loss = torch.mean(out) 281 | loss.backward() 282 | optimizer.step() 283 | end.record() 284 | torch.cuda.synchronize() 285 | 286 | all_accs = {} 287 | try: 288 | log_paths = sorted(glob.glob(f'checkpoint/**/log.txt', recursive=True)) 289 | for log_path in log_paths: 290 | model_name = os.path.basename(os.path.dirname(log_path)) 291 | if args.model not in model_name: 292 | continue 293 | finish = False 294 | best_acc, best_epoch = 0, 0 295 | with open(log_path) as f: 296 | for line in f.readlines(): 297 | line = line.strip() 298 | if line == '': 299 | continue 300 | stat = json.loads(line) 301 | curr_acc = stat['test_acc1'] 302 | curr_epoch = stat['epoch'] 303 | if curr_acc > best_acc: 304 | best_acc = curr_acc 305 | best_epoch = curr_epoch 306 | if curr_epoch == 299: 307 | finish = True 308 | if not finish: 309 | model_name = model_name + f"({curr_epoch})" 310 | all_accs[model_name] = best_acc 311 | except Exception as e: 312 | print(e) 313 | model_name = "X_" + args.model 314 | best_acc = 0.0 315 | 316 | if all_accs == {}: 317 | all_accs[args.model] = 0 318 | for model_name, best_acc in all_accs.items(): 319 | print(f"{model_name}\t{params / 1e6:.1f}\t{flops / 1e9:.1f}\t{best_acc:.2f}") 320 | print(f"{args.model}{'@Val' if args.eval else '@Train'}: {flops / 1e9:.1f} & {args.iters * args.batch_size / start.elapsed_time(end) * 1000.0:.1f} & {params / 1e6:.1f}") 321 | #print(f"{args.model}{'@Val' if args.eval else '@Train'}: {args.iters * args.batch_size / start.elapsed_time(end) * 1000.0:.2f} Images/second. 
FLOPs:{flops},Parameters: {params}") 322 | 323 | if __name__ == '__main__': 324 | parser = argparse.ArgumentParser('DeiT training and evaluation script', parents=[get_args_parser()]) 325 | args = parser.parse_args() 326 | if args.output_dir: 327 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 328 | main(args) 329 | -------------------------------------------------------------------------------- /dataset/video_dataset_aug.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | from typing import Union, List, Tuple 3 | 4 | import torch 5 | import torch.nn.parallel 6 | import torch.optim 7 | import torch.utils.data 8 | import torch.utils.data.distributed 9 | import torchvision.transforms as transforms 10 | from .video_transforms import (GroupRandomHorizontalFlip, GroupOverSample, 11 | GroupMultiScaleCrop, GroupScale, GroupCenterCrop, GroupRandomCrop, 12 | GroupNormalize, Stack, ToTorchFormatTensor, GroupRandomScale) 13 | 14 | def get_augmentor(is_train: bool, image_size: int, mean: List[float] = None, 15 | std: List[float] = None, disable_scaleup: bool = False, 16 | threed_data: bool = False, version: str = 'v1', scale_range: [int] = None, 17 | modality: str = 'rgb', num_clips: int = 1, num_crops: int = 1, dataset: str = ''): 18 | 19 | mean = [0.485, 0.456, 0.406] if mean is None else mean 20 | std = [0.229, 0.224, 0.225] if std is None else std 21 | scale_range = [256, 320] if scale_range is None else scale_range 22 | 23 | if modality == 'sound': 24 | augments = [ 25 | Stack(threed_data=threed_data), 26 | ToTorchFormatTensor(div=False, num_clips_crops=num_clips * num_crops) 27 | ] 28 | else: 29 | augments = [] 30 | if is_train: 31 | if version == 'v1': 32 | augments += [ 33 | GroupMultiScaleCrop(image_size, [1, .875, .75, .66]) 34 | ] 35 | elif version == 'v2': 36 | augments += [ 37 | GroupRandomScale(scale_range), 38 | GroupRandomCrop(image_size), 39 | ] 40 | if not (dataset.startswith('ststv') or 'jester' in dataset or 'mini_ststv' in dataset): 41 | augments += [GroupRandomHorizontalFlip(is_flow=(modality == 'flow'))] 42 | else: 43 | scaled_size = image_size if disable_scaleup else int(image_size / 0.875 + 0.5) 44 | if num_crops == 1: 45 | augments += [ 46 | GroupScale(scaled_size), 47 | GroupCenterCrop(image_size) 48 | ] 49 | else: 50 | flip = True if num_crops == 10 else False 51 | augments += [ 52 | GroupOverSample(image_size, scaled_size, num_crops=num_crops, flip=flip), 53 | ] 54 | augments += [ 55 | Stack(threed_data=threed_data), 56 | ToTorchFormatTensor(num_clips_crops=num_clips * num_crops), 57 | GroupNormalize(mean=mean, std=std, threed_data=threed_data) 58 | ] 59 | 60 | augmentor = transforms.Compose(augments) 61 | return augmentor 62 | 63 | 64 | def build_dataflow(dataset, is_train, batch_size, workers=36, is_distributed=False): 65 | workers = min(workers, multiprocessing.cpu_count()) 66 | shuffle = False 67 | 68 | sampler = torch.utils.data.distributed.DistributedSampler(dataset) if is_distributed else None 69 | if is_train: 70 | shuffle = sampler is None 71 | 72 | data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, 73 | num_workers=workers, pin_memory=True, sampler=sampler) 74 | 75 | return data_loader 76 | 77 | -------------------------------------------------------------------------------- /dataset/video_dataset_config.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | DATASET_CONFIG = { 4 | 'st2stv2': { 5 | 
'num_classes': 174, 6 | 'train_list_name': 'train.txt', 7 | 'val_list_name': 'val.txt', 8 | 'test_list_name': 'test.txt', 9 | 'filename_seperator': " ", 10 | 'image_tmpl': '{:05d}.jpg', 11 | 'filter_video': 3, 12 | 'label_file': 'categories.txt' 13 | }, 14 | 'mini_st2stv2': { 15 | 'num_classes': 87, 16 | 'train_list_name': 'mini_train.txt', 17 | 'val_list_name': 'mini_val.txt', 18 | 'test_list_name': 'mini_test.txt', 19 | 'filename_seperator': " ", 20 | 'image_tmpl': '{:05d}.jpg', 21 | 'filter_video': 3, 22 | }, 23 | 'st2stv1': { 24 | 'num_classes': 174, 25 | 'train_list_name': 'training_256.txt', 26 | 'val_list_name': 'validation_256.txt', 27 | 'test_list_name': 'testing_256.txt', 28 | 'filename_seperator': " ", 29 | 'image_tmpl': '{:05d}.jpg', 30 | 'filter_video': 3, 31 | 'label_file': 'something-something-v1-labels.csv' 32 | }, 33 | 'kinetics400': { 34 | 'num_classes': 400, 35 | 'train_list_name': 'train.txt', 36 | 'val_list_name': 'val.txt', 37 | 'test_list_name': 'test.txt', 38 | 'filename_seperator': ";", 39 | 'image_tmpl': '{:05d}.jpg', 40 | 'filter_video': 30, 41 | 'label_file': 'image/kinetics-400_label.csv' 42 | }, 43 | 'mini_kinetics400': { 44 | 'num_classes': 200, 45 | 'train_list_name': 'mini_train.txt', 46 | 'val_list_name': 'mini_val.txt', 47 | 'test_list_name': 'mini_test.txt', 48 | 'filename_seperator': ";", 49 | 'image_tmpl': '{:05d}.jpg', 50 | 'filter_video': 30 51 | }, 52 | 'charades': { 53 | 'num_classes': 157, 54 | 'train_list_name': 'train.txt', 55 | 'val_list_name': 'val.txt', 56 | 'filename_seperator': " ", 57 | 'image_tmpl': '{:06d}.jpg', 58 | 'filter_video': 0 59 | }, 60 | 'diva': { 61 | 'num_classes': 19, 62 | 'train_list_name': 'DIVA_GT_RGB_TSM_train.txt', 63 | 'val_list_name': 'DIVA_GT_RGB_TSM_validate.txt', 64 | 'filename_seperator': " ", 65 | 'image_tmpl': '{:08d}.jpg', 66 | 'filter_video': 0 67 | }, 68 | 'diva_pvi': { 69 | 'num_classes': 8, 70 | 'train_list_name': 'DIVA_PVI_GT_RGB_TSM_train.txt', 71 | 'val_list_name': 'DIVA_PVI_GT_RGB_TSM_validate.txt', 72 | 'filename_seperator': " ", 73 | 'image_tmpl': '{:08d}.jpg', 74 | 'filter_video': 0 75 | }, 76 | 'moments': { 77 | 'num_classes': 339, 78 | 'train_list_name': 'train.txt', 79 | 'val_list_name': 'val.txt', 80 | 'filename_seperator': " ", 81 | 'image_tmpl': '{:05d}.jpg', 82 | 'filter_video': 0 83 | }, 84 | 'mini_moments': { 85 | 'num_classes': 200, 86 | 'train_list_name': 'mini_train.txt', 87 | 'val_list_name': 'mini_val.txt', 88 | 'filename_seperator': " ", 89 | 'image_tmpl': '{:05d}.jpg', 90 | 'filter_video': 0 91 | }, 92 | 'ucf101': { 93 | 'num_classes': 101, 94 | 'train_list_name': 'train.txt', 95 | 'val_list_name': 'val.txt', 96 | 'filename_seperator': " ", 97 | 'image_tmpl': '{:05d}.jpg', 98 | 'filter_video': 0 99 | }, 100 | 'hmdb51': { 101 | 'num_classes': 51, 102 | 'train_list_name': 'train.txt', 103 | 'val_list_name': 'val.txt', 104 | 'filename_seperator': " ", 105 | 'image_tmpl': '{:05d}.jpg', 106 | 'filter_video': 0 107 | }, 108 | 'jester': { 109 | 'num_classes': 27, 110 | 'train_list_name': 'train.txt', 111 | 'val_list_name': 'val.txt', 112 | 'filename_seperator': " ", 113 | 'image_tmpl': '{:05d}.jpg', 114 | 'filter_video': 0 115 | }, 116 | } 117 | 118 | 119 | def get_dataset_config(dataset, use_lmdb=False): 120 | ret = DATASET_CONFIG[dataset] 121 | num_classes = ret['num_classes'] 122 | train_list_name = ret['train_list_name'].replace("txt", "lmdb") if use_lmdb \ 123 | else ret['train_list_name'] 124 | val_list_name = ret['val_list_name'].replace("txt", "lmdb") if use_lmdb \ 125 | else 
ret['val_list_name'] 126 | test_list_name = ret.get('test_lmdb_name', None) 127 | if test_list_name is not None: 128 | test_list_name = test_list_name.replace("txt", "lmdb") 129 | filename_seperator = ret['filename_seperator'] 130 | image_tmpl = ret['image_tmpl'] 131 | filter_video = ret.get('filter_video', 0) 132 | label_file = ret.get('label_file', None) 133 | 134 | return num_classes, train_list_name, val_list_name, test_list_name, filename_seperator, \ 135 | image_tmpl, filter_video, label_file 136 | -------------------------------------------------------------------------------- /dataset/video_transforms.py: -------------------------------------------------------------------------------- 1 | import torchvision 2 | import random 3 | from PIL import Image, ImageOps 4 | import numpy as np 5 | import numbers 6 | import math 7 | import torch 8 | 9 | class GroupRandomCrop(object): 10 | def __init__(self, size): 11 | if isinstance(size, numbers.Number): 12 | self.size = (int(size), int(size)) 13 | else: 14 | self.size = size 15 | 16 | def __call__(self, img_group): 17 | 18 | w, h = img_group[0].size 19 | th, tw = self.size 20 | 21 | out_images = list() 22 | 23 | x1 = random.randint(0, w - tw) 24 | y1 = random.randint(0, h - th) 25 | 26 | for img in img_group: 27 | assert(img.size[0] == w and img.size[1] == h) 28 | if w == tw and h == th: 29 | out_images.append(img) 30 | else: 31 | out_images.append(img.crop((x1, y1, x1 + tw, y1 + th))) 32 | 33 | return out_images 34 | 35 | 36 | class GroupCenterCrop(object): 37 | def __init__(self, size): 38 | self.worker = torchvision.transforms.CenterCrop(size) 39 | 40 | def __call__(self, img_group): 41 | return [self.worker(img) for img in img_group] 42 | 43 | 44 | class GroupRandomHorizontalFlip(object): 45 | """Randomly horizontally flips the given PIL.Image with a probability of 0.5 46 | """ 47 | def __init__(self, is_flow=False): 48 | self.is_flow = is_flow 49 | 50 | def __call__(self, img_group, is_flow=False): 51 | v = random.random() 52 | if v < 0.5: 53 | ret = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group] 54 | if self.is_flow: 55 | for i in range(0, len(ret), 2): 56 | ret[i] = ImageOps.invert(ret[i]) # invert flow pixel values when flipping 57 | return ret 58 | else: 59 | return img_group 60 | 61 | 62 | class GroupNormalize(object): 63 | def __init__(self, mean, std, threed_data=False): 64 | self.threed_data = threed_data 65 | if self.threed_data: 66 | # convert to the proper format 67 | self.mean = torch.FloatTensor(mean).view(len(mean), 1, 1, 1) 68 | self.std = torch.FloatTensor(std).view(len(std), 1, 1, 1) 69 | else: 70 | self.mean = mean 71 | self.std = std 72 | 73 | def __call__(self, tensor): 74 | 75 | if self.threed_data: 76 | tensor.sub_(self.mean).div_(self.std) 77 | else: 78 | rep_mean = self.mean * (tensor.size()[0] // len(self.mean)) 79 | rep_std = self.std * (tensor.size()[0] // len(self.std)) 80 | 81 | # TODO: make efficient 82 | for t, m, s in zip(tensor, rep_mean, rep_std): 83 | t.sub_(m).div_(s) 84 | 85 | return tensor 86 | 87 | 88 | class GroupScale(object): 89 | """ Rescales the input PIL.Image to the given 'size'. 90 | 'size' will be the size of the smaller edge. 
91 | For example, if height > width, then image will be 92 | rescaled to (size * height / width, size) 93 | size: size of the smaller edge 94 | interpolation: Default: PIL.Image.BILINEAR 95 | """ 96 | 97 | def __init__(self, size, interpolation=Image.BILINEAR): 98 | self.worker = torchvision.transforms.Resize(size, interpolation) 99 | 100 | def __call__(self, img_group): 101 | return [self.worker(img) for img in img_group] 102 | 103 | class GroupRandomScale(object): 104 | """ Rescales the input PIL.Image to the given 'size'. 105 | 'size' will be the size of the smaller edge. 106 | For example, if height > width, then image will be 107 | rescaled to (size * height / width, size) 108 | size: size of the smaller edge 109 | interpolation: Default: PIL.Image.BILINEAR 110 | 111 | Randomly select the smaller edge from the range of 'size'. 112 | """ 113 | def __init__(self, size, interpolation=Image.BILINEAR): 114 | self.size = size 115 | self.interpolation = interpolation 116 | 117 | def __call__(self, img_group): 118 | selected_size = np.random.randint(low=self.size[0], high=self.size[1] + 1, dtype=int) 119 | scale = GroupScale(selected_size, interpolation=self.interpolation) 120 | return scale(img_group) 121 | 122 | class GroupOverSample(object): 123 | def __init__(self, crop_size, scale_size=None, num_crops=5, flip=False): 124 | self.crop_size = crop_size if not isinstance(crop_size, int) else (crop_size, crop_size) 125 | 126 | if scale_size is not None: 127 | self.scale_worker = GroupScale(scale_size) 128 | else: 129 | self.scale_worker = None 130 | 131 | if num_crops not in [1, 3, 5, 10]: 132 | raise ValueError("num_crops should be in [1, 3, 5, 10] but ({})".format(num_crops)) 133 | self.num_crops = num_crops 134 | 135 | self.flip = flip 136 | 137 | def __call__(self, img_group): 138 | 139 | if self.scale_worker is not None: 140 | img_group = self.scale_worker(img_group) 141 | 142 | image_w, image_h = img_group[0].size 143 | crop_w, crop_h = self.crop_size 144 | 145 | if self.num_crops == 3: 146 | w_step = (image_w - crop_w) // 4 147 | h_step = (image_h - crop_h) // 4 148 | offsets = list() 149 | if image_w != crop_w and image_h != crop_h: 150 | offsets.append((0 * w_step, 0 * h_step)) # top 151 | offsets.append((4 * w_step, 4 * h_step)) # bottom 152 | offsets.append((2 * w_step, 2 * h_step)) # center 153 | else: 154 | if image_w < image_h: 155 | offsets.append((2 * w_step, 0 * h_step)) # top 156 | offsets.append((2 * w_step, 4 * h_step)) # bottom 157 | offsets.append((2 * w_step, 2 * h_step)) # center 158 | else: 159 | offsets.append((0 * w_step, 2 * h_step)) # left 160 | offsets.append((4 * w_step, 2 * h_step)) # right 161 | offsets.append((2 * w_step, 2 * h_step)) # center 162 | 163 | else: 164 | offsets = GroupMultiScaleCrop.fill_fix_offset(False, image_w, image_h, crop_w, crop_h) 165 | 166 | oversample_group = list() 167 | for o_w, o_h in offsets: 168 | normal_group = list() 169 | flip_group = list() 170 | for i, img in enumerate(img_group): 171 | crop = img.crop((o_w, o_h, o_w + crop_w, o_h + crop_h)) 172 | normal_group.append(crop) 173 | if self.flip: 174 | flip_crop = crop.copy().transpose(Image.FLIP_LEFT_RIGHT) 175 | 176 | if img.mode == 'L' and i % 2 == 0: 177 | flip_group.append(ImageOps.invert(flip_crop)) 178 | else: 179 | flip_group.append(flip_crop) 180 | 181 | oversample_group.extend(normal_group) 182 | if self.flip: 183 | oversample_group.extend(flip_group) 184 | return oversample_group 185 | 186 | 187 | class GroupMultiScaleCrop(object): 188 | 189 | def __init__(self, 
input_size, scales=None, max_distort=1, fix_crop=True, more_fix_crop=True): 190 | self.scales = scales if scales is not None else [1, .875, .75, .66] 191 | self.max_distort = max_distort 192 | self.fix_crop = fix_crop 193 | self.more_fix_crop = more_fix_crop 194 | self.input_size = input_size if not isinstance(input_size, int) else [input_size, input_size] 195 | self.interpolation = Image.BILINEAR 196 | 197 | def __call__(self, img_group): 198 | 199 | im_size = img_group[0].size 200 | 201 | crop_w, crop_h, offset_w, offset_h = self._sample_crop_size(im_size) 202 | crop_img_group = [img.crop((offset_w, offset_h, offset_w + crop_w, offset_h + crop_h)) for img in img_group] 203 | ret_img_group = [img.resize((self.input_size[0], self.input_size[1]), self.interpolation) 204 | for img in crop_img_group] 205 | return ret_img_group 206 | 207 | def _sample_crop_size(self, im_size): 208 | image_w, image_h = im_size[0], im_size[1] 209 | 210 | # find a crop size 211 | base_size = min(image_w, image_h) 212 | crop_sizes = [int(base_size * x) for x in self.scales] 213 | crop_h = [self.input_size[1] if abs(x - self.input_size[1]) < 3 else x for x in crop_sizes] 214 | crop_w = [self.input_size[0] if abs(x - self.input_size[0]) < 3 else x for x in crop_sizes] 215 | 216 | pairs = [] 217 | for i, h in enumerate(crop_h): 218 | for j, w in enumerate(crop_w): 219 | if abs(i - j) <= self.max_distort: 220 | pairs.append((w, h)) 221 | 222 | crop_pair = random.choice(pairs) 223 | if not self.fix_crop: 224 | w_offset = random.randint(0, image_w - crop_pair[0]) 225 | h_offset = random.randint(0, image_h - crop_pair[1]) 226 | else: 227 | w_offset, h_offset = self._sample_fix_offset(image_w, image_h, crop_pair[0], crop_pair[1]) 228 | 229 | return crop_pair[0], crop_pair[1], w_offset, h_offset 230 | 231 | def _sample_fix_offset(self, image_w, image_h, crop_w, crop_h): 232 | offsets = self.fill_fix_offset(self.more_fix_crop, image_w, image_h, crop_w, crop_h) 233 | return random.choice(offsets) 234 | 235 | @staticmethod 236 | def fill_fix_offset(more_fix_crop, image_w, image_h, crop_w, crop_h): 237 | w_step = (image_w - crop_w) // 4 238 | h_step = (image_h - crop_h) // 4 239 | 240 | ret = list() 241 | ret.append((0, 0)) # upper left 242 | ret.append((4 * w_step, 0)) # upper right 243 | ret.append((0, 4 * h_step)) # lower left 244 | ret.append((4 * w_step, 4 * h_step)) # lower right 245 | ret.append((2 * w_step, 2 * h_step)) # center 246 | 247 | if more_fix_crop: 248 | ret.append((0, 2 * h_step)) # center left 249 | ret.append((4 * w_step, 2 * h_step)) # center right 250 | ret.append((2 * w_step, 4 * h_step)) # lower center 251 | ret.append((2 * w_step, 0 * h_step)) # upper center 252 | 253 | ret.append((1 * w_step, 1 * h_step)) # upper left quarter 254 | ret.append((3 * w_step, 1 * h_step)) # upper right quarter 255 | ret.append((1 * w_step, 3 * h_step)) # lower left quarter 256 | ret.append((3 * w_step, 3 * h_step)) # lower right quarter 257 | 258 | return ret 259 | 260 | 261 | class GroupRandomSizedCrop(object): 262 | """Random crop the given PIL.Image to a random size of (0.08 to 1.0) of the original size 263 | and a random aspect ratio of 3/4 to 4/3 of the original aspect ratio 264 | This is popularly used to train the Inception networks 265 | size: size of the smaller edge 266 | interpolation: Default: PIL.Image.BILINEAR 267 | """ 268 | def __init__(self, size, interpolation=Image.BILINEAR): 269 | self.size = size 270 | self.interpolation = interpolation 271 | 272 | def __call__(self, img_group): 273 | for attempt
in range(10): 274 | area = img_group[0].size[0] * img_group[0].size[1] 275 | target_area = random.uniform(0.08, 1.0) * area 276 | aspect_ratio = random.uniform(3. / 4, 4. / 3) 277 | 278 | w = int(round(math.sqrt(target_area * aspect_ratio))) 279 | h = int(round(math.sqrt(target_area / aspect_ratio))) 280 | 281 | if random.random() < 0.5: 282 | w, h = h, w 283 | 284 | if w <= img_group[0].size[0] and h <= img_group[0].size[1]: 285 | x1 = random.randint(0, img_group[0].size[0] - w) 286 | y1 = random.randint(0, img_group[0].size[1] - h) 287 | found = True 288 | break 289 | else: 290 | found = False 291 | x1 = 0 292 | y1 = 0 293 | 294 | if found: 295 | out_group = list() 296 | for img in img_group: 297 | img = img.crop((x1, y1, x1 + w, y1 + h)) 298 | assert(img.size == (w, h)) 299 | out_group.append(img.resize((self.size, self.size), self.interpolation)) 300 | return out_group 301 | else: 302 | # Fallback 303 | scale = GroupScale(self.size, interpolation=self.interpolation) 304 | crop = GroupRandomCrop(self.size) 305 | return crop(scale(img_group)) 306 | 307 | 308 | class Stack(object): 309 | 310 | def __init__(self, roll=False, threed_data=False): 311 | self.roll = roll 312 | self.threed_data = threed_data 313 | 314 | def __call__(self, img_group): 315 | if img_group[0].mode == 'L': 316 | return np.concatenate([np.expand_dims(x, 2) for x in img_group], axis=2) 317 | elif img_group[0].mode == 'RGB': 318 | if self.threed_data: 319 | return np.stack(img_group, axis=0) 320 | else: 321 | if self.roll: 322 | return np.concatenate([np.array(x)[:, :, ::-1] for x in img_group], axis=2) 323 | else: 324 | return np.concatenate(img_group, axis=2) 325 | 326 | 327 | class ToTorchFormatTensor(object): 328 | """ Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255] 329 | to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """ 330 | def __init__(self, div=True, num_clips_crops=1): 331 | self.div = div 332 | self.num_clips_crops = num_clips_crops 333 | 334 | def __call__(self, pic): 335 | if isinstance(pic, np.ndarray): 336 | # handle numpy array 337 | if len(pic.shape) == 4: 338 | # ((NF)xCxHxW) --> (Cx(NF)xHxW) 339 | img = torch.from_numpy(pic).permute(3, 0, 1, 2).contiguous() 340 | else: # data is HW(FC) 341 | img = torch.from_numpy(pic).permute(2, 0, 1).contiguous() 342 | else: 343 | # handle PIL Image 344 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) 345 | img = img.view(pic.size[1], pic.size[0], len(pic.mode)) 346 | # put it from HWC to CHW format 347 | # yikes, this transpose takes 80% of the loading time/CPU 348 | img = img.transpose(0, 1).transpose(0, 2).contiguous() 349 | return img.float().div(255) if self.div else img.float() 350 | 351 | 352 | class IdentityTransform(object): 353 | 354 | def __call__(self, data): 355 | return data 356 | 357 | 358 | if __name__ == "__main__": 359 | trans = torchvision.transforms.Compose([ 360 | GroupScale(256), 361 | GroupRandomCrop(224), 362 | Stack(), 363 | ToTorchFormatTensor(), 364 | GroupNormalize( 365 | mean=[.485, .456, .406], 366 | std=[.229, .224, .225] 367 | )] 368 | ) 369 | 370 | im = Image.open('../tensorflow-model-zoo.torch/lena_299.png') 371 | 372 | color_group = [im] * 3 373 | rst = trans(color_group) 374 | 375 | gray_group = [im.convert('L')] * 9 376 | gray_rst = trans(gray_group) 377 | 378 | trans2 = torchvision.transforms.Compose([ 379 | GroupRandomSizedCrop(256), 380 | Stack(), 381 | ToTorchFormatTensor(), 382 | GroupNormalize( 383 | mean=[.485, .456, .406], 384 | std=[.229, .224, 
.225]) 385 | ]) 386 | print(trans2(color_group)) 387 | -------------------------------------------------------------------------------- /engine.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the CC-by-NC license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | """ 8 | Train and eval functions used in main.py 9 | """ 10 | import math 11 | import sys 12 | from typing import Iterable, Optional 13 | from einops import rearrange 14 | import torch 15 | 16 | from timm.data import Mixup 17 | from timm.utils import accuracy, ModelEma, reduce_tensor 18 | 19 | import utils 20 | from losses import DeepMutualLoss, ONELoss, SelfDistillationLoss 21 | 22 | def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, 23 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 24 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0, 25 | model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, 26 | world_size: int = 1, distributed: bool = True, amp=True, 27 | simclr_criterion=None, simclr_w=0., 28 | branch_div_criterion=None, branch_div_w=0., 29 | simsiam_criterion=None, simsiam_w=0., 30 | moco_criterion=None, moco_w=0., 31 | byol_criterion=None, byol_w=0., 32 | contrastive_nomixup=False, hard_contrastive=False, 33 | finetune=False 34 | ): 35 | # TODO fix this for finetuning 36 | if finetune: 37 | model.train(not finetune) 38 | else: 39 | model.train() 40 | #criterion.train() 41 | metric_logger = utils.MetricLogger(delimiter=" ") 42 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 43 | header = 'Epoch: [{}]'.format(epoch) 44 | print_freq = 50 45 | 46 | for samples, targets in metric_logger.log_every(data_loader, print_freq, header): 47 | 48 | batch_size = targets.size(0) 49 | if simclr_criterion is not None or simsiam_criterion is not None or moco_criterion is not None or byol_criterion is not None: 50 | samples = [samples[0].to(device, non_blocking=True), samples[1].to(device, non_blocking=True)] 51 | targets = targets.to(device, non_blocking=True) 52 | ori_samples = [x.clone() for x in samples] # copy the original samples 53 | 54 | if mixup_fn is not None: 55 | samples[0], targets_ = mixup_fn(samples[0], targets) 56 | if contrastive_nomixup: # remain one copy for ce loss 57 | samples[1] = ori_samples[0] 58 | samples.append(ori_samples[1]) 59 | elif hard_contrastive: 60 | samples[1] = samples[1] 61 | else: 62 | samples[1], _ = mixup_fn(samples[1], targets) 63 | targets = targets_ 64 | 65 | else: 66 | samples = samples.to(device, non_blocking=True) 67 | targets = targets.to(device, non_blocking=True) 68 | 69 | if mixup_fn is not None: 70 | # batch size has to be an even number 71 | if batch_size == 1: 72 | continue 73 | if batch_size % 2 != 0: 74 | samples, targets = samples[:-1], targets[:-1] 75 | samples, targets = mixup_fn(samples, targets) 76 | 77 | with torch.cuda.amp.autocast(enabled=amp): 78 | outputs = model(samples) 79 | if simclr_criterion is not None: 80 | # outputs 0: ce logits, bs x class, outputs 1: normalized embeddings of two views, bs x 2 x dim 81 | loss_ce = criterion(outputs[0], targets) 82 | loss_simclr = simclr_criterion(outputs[1]) 83 | loss = loss_ce * (1.0 - simclr_w) + loss_simclr * simclr_w 84 | elif simsiam_criterion is not None: 85 | # outputs 0: ce logits, bs x class, outputs 1: normalized embeddings of 
two views, 4[bs x dim], [p1, z1, p2, z2] 86 | loss_ce = criterion(outputs[0], targets) 87 | loss_simsiam = simsiam_criterion(*outputs[1]) 88 | loss = loss_ce * (1.0 - simsiam_w) + loss_simsiam * simsiam_w 89 | elif branch_div_criterion is not None: 90 | # outputs 0: ce logits, bs x class, outputs 1: embeddings of K branches, K[bs x dim] 91 | loss_ce = criterion(outputs[0], targets) 92 | loss_div = 0.0 93 | for i in range(0, len(outputs[1]), 2): 94 | loss_div += torch.mean(branch_div_criterion(outputs[1][i], outputs[1][i + 1])) 95 | loss = loss_ce * (1.0 - branch_div_w) + loss_div * branch_div_w 96 | elif moco_criterion is not None: 97 | loss_ce = criterion(outputs[0], targets) 98 | loss_moco = moco_criterion(outputs[1][0], outputs[1][1]) 99 | loss = loss_ce * (1.0 - moco_w) + loss_moco * moco_w 100 | elif byol_criterion is not None: 101 | loss_ce = criterion(outputs[0], targets) 102 | loss_byol = byol_criterion(*outputs[1]) 103 | loss = loss_ce * (1.0 - byol_w) + loss_byol * byol_w 104 | else: 105 | if isinstance(criterion, (DeepMutualLoss, ONELoss, SelfDistillationLoss)): 106 | loss, loss_ce, loss_kd = criterion(outputs, targets) 107 | else: 108 | loss = criterion(outputs, targets) 109 | 110 | loss_value = loss.item() 111 | 112 | if not math.isfinite(loss_value): 113 | print("Loss is {}, stopping training".format(loss_value)) 114 | raise ValueError("Loss is {}, stopping training".format(loss_value)) 115 | 116 | optimizer.zero_grad() 117 | 118 | # this attribute is added by timm on one optimizer (adahessian) 119 | is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order 120 | 121 | if amp: 122 | loss_scaler(loss, optimizer, clip_grad=max_norm, 123 | parameters=model.parameters(), create_graph=is_second_order) 124 | else: 125 | loss.backward(create_graph=is_second_order) 126 | if max_norm is not None and max_norm != 0.0: 127 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) 128 | optimizer.step() 129 | 130 | torch.cuda.synchronize() 131 | if model_ema is not None: 132 | model_ema.update(model) 133 | 134 | if simclr_criterion is not None: 135 | metric_logger.update(loss_ce=loss_ce.item()) 136 | metric_logger.update(loss_simclr=loss_simclr.item()) 137 | elif simsiam_criterion is not None: 138 | metric_logger.update(loss_ce=loss_ce.item()) 139 | metric_logger.update(loss_simsiam=loss_simsiam.item()) 140 | elif branch_div_criterion is not None: 141 | metric_logger.update(loss_ce=loss_ce.item()) 142 | metric_logger.update(loss_div=loss_div.item()) 143 | elif moco_criterion is not None: 144 | metric_logger.update(loss_ce=loss_ce.item()) 145 | metric_logger.update(loss_moco=loss_moco.item()) 146 | elif byol_criterion is not None: 147 | metric_logger.update(loss_ce=loss_ce.item()) 148 | metric_logger.update(loss_byol=loss_byol.item()) 149 | elif isinstance(criterion, (DeepMutualLoss, ONELoss)): 150 | metric_logger.update(loss_ce=loss_ce.item()) 151 | metric_logger.update(loss_kd=loss_kd.item()) 152 | metric_logger.update(loss=loss_value) 153 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 154 | # gather the stats from all processes 155 | metric_logger.synchronize_between_processes() 156 | print("Averaged stats:", metric_logger) 157 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 158 | 159 | 160 | @torch.no_grad() 161 | def evaluate(data_loader, model, device, world_size, distributed=True, amp=False, num_crops=1, num_clips=1): 162 | criterion = torch.nn.CrossEntropyLoss() 163 | 164 | metric_logger = 
utils.MetricLogger(delimiter=" ") 165 | header = 'Test:' 166 | 167 | # switch to evaluation mode 168 | model.eval() 169 | 170 | outputs = [] 171 | targets = [] 172 | 173 | for images, target in metric_logger.log_every(data_loader, 10, header): 174 | images = images.to(device, non_blocking=True) 175 | target = target.to(device, non_blocking=True) 176 | # compute output 177 | batch_size = images.shape[0] 178 | #images = images.view((batch_size * num_crops * num_clips, -1) + images.size()[2:]) 179 | with torch.cuda.amp.autocast(enabled=amp): 180 | output = model(images) 181 | #loss = criterion(output, target) 182 | output = output.reshape(batch_size, num_crops * num_clips, -1).mean(dim=1) 183 | #acc1, acc5 = accuracy(output, target, topk=(1, 5)) 184 | 185 | if distributed: 186 | outputs.append(concat_all_gather(output)) 187 | targets.append(concat_all_gather(target)) 188 | else: 189 | outputs.append(output) 190 | targets.append(target) 191 | 192 | batch_size = images.shape[0] 193 | #metric_logger.update(loss=reduced_loss.item()) 194 | #metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) 195 | #metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) 196 | 197 | 198 | num_data = len(data_loader.dataset) 199 | outputs = torch.cat(outputs, dim=0) 200 | targets = torch.cat(targets, dim=0) 201 | import os 202 | if os.environ.get('TEST', False): 203 | import numpy as np 204 | print("dumping results...") 205 | tmp = outputs[:num_data].cpu().numpy() 206 | tt = targets[:num_data].cpu().numpy() 207 | np.savez("con_mix.npz", pred=tmp, gt=tt) 208 | 209 | real_acc1, real_acc5 = accuracy(outputs[:num_data], targets[:num_data], topk=(1, 5)) 210 | real_loss = criterion(outputs, targets) 211 | metric_logger.update(loss=real_loss.item()) 212 | metric_logger.meters['acc1'].update(real_acc1.item()) 213 | metric_logger.meters['acc5'].update(real_acc5.item()) 214 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}' 215 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss)) 216 | 217 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 218 | 219 | 220 | @torch.no_grad() 221 | def concat_all_gather(tensor): 222 | """ 223 | Performs all_gather operation on the provided tensors. 224 | *** Warning ***: torch.distributed.all_gather has no gradient. 225 | """ 226 | tensors_gather = [torch.ones_like(tensor) 227 | for _ in range(torch.distributed.get_world_size())] 228 | torch.distributed.all_gather(tensors_gather, tensor.contiguous(), async_op=False) 229 | 230 | #output = torch.cat(tensors_gather, dim=0) 231 | if tensor.dim() == 1: 232 | output = rearrange(tensors_gather, 'n b -> (b n)') 233 | else: 234 | output = rearrange(tensors_gather, 'n b c -> (b n) c') 235 | 236 | return output 237 | -------------------------------------------------------------------------------- /hubconf.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the CC-by-NC license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | from models import deit_tiny_patch16_224, deit_small_patch16_224, deit_base_patch16_224 8 | 9 | dependencies = ["torch", "torchvision", "timm"] 10 | -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy 7 | 8 | 9 | class DeepMutualLoss(nn.Module): 10 | 11 | def __init__(self, base_criterion, w, temperature=1.0): 12 | super().__init__() 13 | self.base_criterion = base_criterion 14 | self.kd_criterion = nn.KLDivLoss(reduction='batchmean', log_target=True) 15 | self.w = w if w > 0 else -w 16 | self.T = temperature 17 | 18 | self.neg = w < 0 19 | 20 | def forward(self, logits, targets): 21 | n = len(logits) 22 | 23 | # CE losses 24 | ce_loss = [self.base_criterion(logits[i], targets) for i in range(n)] 25 | ce_loss = torch.sum(torch.stack(ce_loss, dim=0), dim=0) 26 | 27 | # KD Loss 28 | kd_loss = [1. / (n-1) * 29 | self.kd_criterion( 30 | F.log_softmax(logits[i] / self.T, dim=1), 31 | F.log_softmax(logits[j] / self.T, dim=1).detach() 32 | ) * self.T * self.T 33 | for i, j in itertools.permutations(range(n), 2)] 34 | kd_loss = torch.sum(torch.stack(kd_loss, dim=0), dim=0) 35 | if self.neg: 36 | kd_loss = -1.0 * kd_loss 37 | 38 | total_loss = (1.0 - self.w) * ce_loss + self.w * kd_loss 39 | return total_loss, ce_loss.detach(), kd_loss.detach() 40 | 41 | 42 | class ONELoss(nn.Module): 43 | 44 | def __init__(self, base_criterion, w, temperature=1.0): 45 | super().__init__() 46 | self.base_criterion = base_criterion 47 | self.kd_criterion = nn.KLDivLoss(reduction='batchmean', log_target=True) 48 | self.w = w 49 | self.T = temperature 50 | 51 | def forward(self, logits, targets): 52 | n = len(logits) 53 | ensemble_logits = torch.mean(torch.stack(logits, dim=0), dim=0) 54 | 55 | # CE losses 56 | ce_loss = [self.base_criterion(logits[i], targets) for i in range(n)] + [self.base_criterion(ensemble_logits, targets)] 57 | #ce_loss = torch.sum(torch.stack(ce_loss, dim=0), dim=0) 58 | ce_loss = torch.mean(torch.stack(ce_loss, dim=0), dim=0) 59 | 60 | # One Loss 61 | kd_loss = [self.kd_criterion( 62 | F.log_softmax(logits[i] / self.T, dim=1), 63 | F.log_softmax(ensemble_logits / self.T, dim=1).detach() 64 | ) * self.T * self.T for i in range(n)] 65 | #kd_loss = torch.sum(torch.stack(kd_loss, dim=0), dim=0) 66 | kd_loss = torch.mean(torch.stack(kd_loss, dim=0), dim=0) 67 | 68 | #total_loss = (1.0 - self.w) * ce_loss + self.w * kd_loss 69 | #total_loss = (1.0 - self.w) * ce_loss - self.w * kd_loss 70 | total_loss = ce_loss + self.w * kd_loss 71 | return total_loss, ce_loss.detach(), kd_loss.detach() 72 | 73 | 74 | class MulMixLabelSmoothingCrossEntropy(nn.Module): 75 | """ 76 | NLL loss with label smoothing. 77 | """ 78 | def __init__(self, smoothing=0.1): 79 | """ 80 | Constructor for the LabelSmoothing module. 81 | :param smoothing: label smoothing factor 82 | """ 83 | super(MulMixLabelSmoothingCrossEntropy, self).__init__() 84 | assert smoothing < 1.0 85 | self.smoothing = smoothing 86 | self.confidence = 1. 
- smoothing 87 | 88 | def forward(self, x, target, beta=1.0): 89 | inv_prob = torch.pow(1.0 - F.softmax(x, dim=-1), beta) 90 | logprobs = F.log_softmax(x, dim=-1) 91 | logprobs = logprobs * inv_prob 92 | nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1)) 93 | nll_loss = nll_loss.squeeze(1) 94 | smooth_loss = -logprobs.mean(dim=-1) 95 | loss = self.confidence * nll_loss + self.smoothing * smooth_loss 96 | return loss.mean() 97 | 98 | 99 | class MulMixSoftTargetCrossEntropy(nn.Module): 100 | 101 | def __init__(self): 102 | super(MulMixSoftTargetCrossEntropy, self).__init__() 103 | 104 | def forward(self, x, target, beta=1.0): 105 | inv_prob = torch.pow(1.0 - F.softmax(x, dim=-1), beta) 106 | loss = torch.sum(-target * F.log_softmax(x, dim=-1) * inv_prob, dim=-1) 107 | return loss.mean() 108 | 109 | 110 | class MulMixturelLoss(nn.Module): 111 | 112 | def __init__(self, base_criterion, beta): 113 | super().__init__() 114 | 115 | if isinstance(base_criterion, LabelSmoothingCrossEntropy): 116 | self.base_criterion = MulMixLabelSmoothingCrossEntropy(base_criterion.smoothing) 117 | elif isinstance(base_criterion, SoftTargetCrossEntropy): 118 | self.base_criterion = MulMixSoftTargetCrossEntropy() 119 | else: 120 | raise ValueError("Unknown type") 121 | 122 | self.beta = beta 123 | 124 | def forward(self, logits, targets): 125 | n = len(logits) 126 | 127 | # CE losses 128 | ce_loss = [self.base_criterion(logits[i], targets, self.beta / (n - 1)) for i in range(n)] 129 | ce_loss = torch.sum(torch.stack(ce_loss, dim=0), dim=0) 130 | return ce_loss 131 | 132 | 133 | class SelfDistillationLoss(nn.Module): 134 | 135 | def __init__(self, base_criterion, w, temperature=1.0): 136 | super().__init__() 137 | self.base_criterion = base_criterion 138 | self.kd_criterion = nn.KLDivLoss(reduction='batchmean', log_target=True) 139 | self.w = w 140 | self.T = temperature 141 | 142 | def forward(self, logits, targets): 143 | # logits is a list, the first one is the reference logits for self-distillation 144 | 145 | # CE losses 146 | ce_loss = self.base_criterion(logits[1], targets) 147 | 148 | # KD Loss 149 | kd_loss = self.kd_criterion( 150 | F.log_softmax(logits[1] / self.T, dim=1), 151 | F.log_softmax(logits[0] / self.T, dim=1).detach() 152 | ) * self.T * self.T 153 | 154 | total_loss = (1.0 - self.w) * ce_loss + self.w * kd_loss 155 | return total_loss, ce_loss.detach(), kd_loss.detach() 156 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the CC-by-NC license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | import argparse 8 | import datetime 9 | import numpy as np 10 | import time 11 | import torch 12 | import torch.backends.cudnn as cudnn 13 | import json 14 | import os 15 | import warnings 16 | 17 | from pathlib import Path 18 | 19 | from timm.data import Mixup 20 | from timm.models import create_model 21 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy 22 | from timm.scheduler import create_scheduler 23 | from timm.optim import create_optimizer 24 | from timm.utils import NativeScaler, get_state_dict, ModelEma 25 | 26 | #from datasets import build_dataset 27 | from engine import train_one_epoch, evaluate 28 | from samplers import RASampler 29 | import models 30 | import my_models 31 | import torch.nn as nn 32 | #import simclr 33 | import utils 34 | from losses import DeepMutualLoss, ONELoss, MulMixturelLoss, SelfDistillationLoss 35 | 36 | from video_dataset import VideoDataSet, VideoDataSetLMDB, VideoDataSetOnline 37 | from video_dataset_aug import get_augmentor, build_dataflow 38 | from video_dataset_config import get_dataset_config, DATASET_CONFIG 39 | 40 | warnings.filterwarnings("ignore", category=UserWarning) 41 | #torch.multiprocessing.set_start_method('spawn', force=True) 42 | 43 | def get_args_parser(): 44 | parser = argparse.ArgumentParser('DeiT training and evaluation script', add_help=False) 45 | parser.add_argument('--batch-size', default=64, type=int) 46 | parser.add_argument('--epochs', default=300, type=int) 47 | 48 | # Dataset parameters 49 | parser.add_argument('--data_dir', type=str, metavar='DIR', help='path to dataset') 50 | parser.add_argument('--dataset', default='st2stv2', 51 | choices=list(DATASET_CONFIG.keys()), help='path to dataset file list') 52 | parser.add_argument('--duration', default=8, type=int, help='number of frames') 53 | parser.add_argument('--frames_per_group', default=1, type=int, 54 | help='[uniform sampling] number of frames per group; ' 55 | '[dense sampling]: sampling frequency') 56 | parser.add_argument('--threed_data', action='store_true', 57 | help='load data in the layout for 3D conv') 58 | parser.add_argument('--input_size', default=224, type=int, metavar='N', help='input image size') 59 | parser.add_argument('--disable_scaleup', action='store_true', 60 | help='do not scale up and then crop a small region, directly crop the input_size') 61 | parser.add_argument('--random_sampling', action='store_true', 62 | help='perform determinstic sampling for data loader') 63 | parser.add_argument('--dense_sampling', action='store_true', 64 | help='perform dense sampling for data loader') 65 | parser.add_argument('--augmentor_ver', default='v1', type=str, choices=['v1', 'v2'], 66 | help='[v1] TSN data argmentation, [v2] resize the shorter side to `scale_range`') 67 | parser.add_argument('--scale_range', default=[256, 320], type=int, nargs="+", 68 | metavar='scale_range', help='scale range for augmentor v2') 69 | parser.add_argument('--modality', default='rgb', type=str, help='rgb or flow', 70 | choices=['rgb', 'flow']) 71 | parser.add_argument('--use_lmdb', action='store_true', help='use lmdb instead of jpeg.') 72 | parser.add_argument('--use_pyav', action='store_true', help='use video directly.') 73 | 74 | # temporal module 75 | parser.add_argument('--pretrained', action='store_true', default=False, 76 | help='Start with pretrained version of specified network (if avail)') 77 | parser.add_argument('--temporal_module_name', default=None, type=str, metavar='TEM', choices=['ResNet3d', 'TAM', 'TTAM', 'TSM', 'TTSM', 
'MSA'], 78 | help='temporal module applied. [TAM]') 79 | parser.add_argument('--temporal_attention_only', action='store_true', default=False, 80 | help='use attention only in temporal module]') 81 | parser.add_argument('--no_token_mask', action='store_true', default=False, help='do not apply token mask') 82 | parser.add_argument('--temporal_heads_scale', default=1.0, type=float, help='scale of the number of spatial heads') 83 | parser.add_argument('--temporal_mlp_scale', default=1.0, type=float, help='scale of spatial mlp') 84 | parser.add_argument('--rel_pos', action='store_true', default=False, 85 | help='use relative positioning in temporal module]') 86 | parser.add_argument('--temporal_pooling', type=str, default=None, choices=['avg', 'max', 'conv', 'depthconv'], 87 | help='perform temporal pooling]') 88 | parser.add_argument('--bottleneck', default=None, choices=['regular', 'dw'], 89 | help='use depth-wise bottleneck in temporal attention') 90 | 91 | parser.add_argument('--window_size', default=7, type=int, help='number of frames') 92 | parser.add_argument('--super_img_rows', default=1, type=int, help='number of frames per row') 93 | 94 | parser.add_argument('--hpe_to_token', default=False, action='store_true', 95 | help='add hub position embedding to image tokens') 96 | # Model parameters 97 | parser.add_argument('--model', default='deit_base_patch16_224', type=str, metavar='MODEL', 98 | help='Name of model to train') 99 | # parser.add_argument('--input-size', default=224, type=int, help='images input size') 100 | 101 | parser.add_argument('--drop', type=float, default=0.0, metavar='PCT', 102 | help='Dropout rate (default: 0.)') 103 | parser.add_argument('--drop-path', type=float, default=0.0, metavar='PCT', 104 | help='Drop path rate (default: 0.1)') 105 | parser.add_argument('--drop-block', type=float, default=None, metavar='PCT', 106 | help='Drop block rate (default: None)') 107 | 108 | parser.add_argument('--model-ema', action='store_true') 109 | parser.add_argument('--no-model-ema', action='store_false', dest='model_ema') 110 | parser.set_defaults(model_ema=True) 111 | parser.add_argument('--model-ema-decay', type=float, default=0.99996, help='') 112 | parser.add_argument('--model-ema-force-cpu', action='store_true', default=False, help='') 113 | 114 | # Optimizer parameters 115 | parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER', 116 | help='Optimizer (default: "adamw"') 117 | parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON', 118 | help='Optimizer Epsilon (default: 1e-8)') 119 | parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA', 120 | help='Optimizer Betas (default: None, use opt default)') 121 | parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM', 122 | help='Clip gradient norm (default: None, no clipping)') 123 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', 124 | help='SGD momentum (default: 0.9)') 125 | parser.add_argument('--weight-decay', type=float, default=0.05, 126 | help='weight decay (default: 0.05)') 127 | # Learning rate schedule parameters 128 | parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER', 129 | help='LR scheduler (default: "cosine"') 130 | parser.add_argument('--lr', type=float, default=5e-4, metavar='LR', 131 | help='learning rate (default: 5e-4)') 132 | parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct', 133 | help='learning rate noise on/off epoch 
percentages') 134 | parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT', 135 | help='learning rate noise limit percent (default: 0.67)') 136 | parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV', 137 | help='learning rate noise std-dev (default: 1.0)') 138 | parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR', 139 | help='warmup learning rate (default: 1e-6)') 140 | parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR', 141 | help='lower lr bound for cyclic schedulers that hit 0 (1e-5)') 142 | 143 | parser.add_argument('--decay-epochs', type=float, default=30, metavar='N', 144 | help='epoch interval to decay LR') 145 | parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N', 146 | help='epochs to warmup LR, if scheduler supports') 147 | parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N', 148 | help='epochs to cooldown LR at min_lr, after cyclic schedule ends') 149 | parser.add_argument('--patience-epochs', type=int, default=10, metavar='N', 150 | help='patience epochs for Plateau LR scheduler (default: 10') 151 | parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE', 152 | help='LR decay rate (default: 0.1)') 153 | 154 | # Augmentation parameters 155 | parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT', 156 | help='Color jitter factor (default: 0.4)') 157 | parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME', 158 | help='Use AutoAugment policy. "v0" or "original". " + \ 159 | "(default: rand-m9-mstd0.5-inc1)'), 160 | parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)') 161 | parser.add_argument('--train-interpolation', type=str, default='bicubic', 162 | help='Training interpolation (random, bilinear, bicubic default: "bicubic")') 163 | 164 | parser.add_argument('--repeated-aug', action='store_true') 165 | parser.add_argument('--no-repeated-aug', action='store_false', dest='repeated_aug') 166 | parser.set_defaults(repeated_aug=False) 167 | 168 | # * Random Erase params 169 | parser.add_argument('--reprob', type=float, default=0.0, metavar='PCT', 170 | help='Random erase prob (default: 0.25)') 171 | parser.add_argument('--remode', type=str, default='pixel', 172 | help='Random erase mode (default: "pixel")') 173 | parser.add_argument('--recount', type=int, default=1, 174 | help='Random erase count (default: 1)') 175 | parser.add_argument('--resplit', action='store_true', default=False, 176 | help='Do not random erase first (clean) augmentation split') 177 | 178 | # * Mixup params 179 | parser.add_argument('--mixup', type=float, default=0.0, 180 | help='mixup alpha, mixup enabled if > 0. (default: 0.8)') 181 | parser.add_argument('--cutmix', type=float, default=0.0, 182 | help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)') 183 | parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None, 184 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)') 185 | parser.add_argument('--mixup-prob', type=float, default=1.0, 186 | help='Probability of performing mixup or cutmix when either/both is enabled') 187 | parser.add_argument('--mixup-switch-prob', type=float, default=0.5, 188 | help='Probability of switching to cutmix when both mixup and cutmix enabled') 189 | parser.add_argument('--mixup-mode', type=str, default='batch', 190 | help='How to apply mixup/cutmix params. 
Per "batch", "pair", or "elem"') 191 | 192 | # Dataset parameters 193 | # parser.add_argument('--data-path', default=os.path.join(os.path.expanduser("~"), 'datasets/image_cls/imagenet1k/'), type=str, 194 | # help='dataset path') 195 | # parser.add_argument('--data-set', default='IMNET', choices=['CIFAR10', 'CIFAR100', 'IMNET', 'INAT', 'INAT19', 'IMNET21K', 'Flowers102', 'StanfordCars', 'iNaturalist2019', 'Caltech101'], 196 | # type=str, help='Image Net dataset path') 197 | # parser.add_argument('--inat-category', default='name', 198 | # choices=['kingdom', 'phylum', 'class', 'order', 'supercategory', 'family', 'genus', 'name'], 199 | # type=str, help='semantic granularity') 200 | 201 | parser.add_argument('--output_dir', default='', 202 | help='path where to save, empty for no saving') 203 | parser.add_argument('--device', default='cuda', 204 | help='device to use for training / testing') 205 | parser.add_argument('--seed', default=0, type=int) 206 | parser.add_argument('--resume', default='', help='resume from checkpoint') 207 | parser.add_argument('--no-resume-loss-scaler', action='store_false', dest='resume_loss_scaler') 208 | parser.add_argument('--no-amp', action='store_false', dest='amp', help='disable amp') 209 | parser.add_argument('--use_checkpoint', default=False, action='store_true', help='use checkpoint to save memory') 210 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 211 | help='start epoch') 212 | parser.add_argument('--eval', action='store_true', help='Perform evaluation only') 213 | parser.add_argument('--num_workers', default=10, type=int) 214 | parser.add_argument('--pin-mem', action='store_true', 215 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 216 | parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem', 217 | help='') 218 | parser.set_defaults(pin_mem=True) 219 | 220 | # for testing and validation 221 | parser.add_argument('--num_crops', default=1, type=int, choices=[1, 3, 5, 10]) 222 | parser.add_argument('--num_clips', default=1, type=int) 223 | 224 | # distributed training parameters 225 | parser.add_argument('--world_size', default=1, type=int, 226 | help='number of distributed processes') 227 | parser.add_argument("--local_rank", type=int) 228 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') 229 | 230 | 231 | parser.add_argument('--auto-resume', action='store_true', help='auto resume') 232 | # exp 233 | parser.add_argument('--simclr_w', type=float, default=0., help='weights for simclr loss') 234 | parser.add_argument('--contrastive_nomixup', action='store_true', help='do not involve mixup in contrastive learning') 235 | parser.add_argument('--temperature', type=float, default=0.07, help='temperature of NCE') 236 | parser.add_argument('--branch_div_w', type=float, default=0., help='add branch divergence in the loss') 237 | parser.add_argument('--simsiam_w', type=float, default=0., help='weights for simsiam loss') 238 | parser.add_argument('--moco_w', type=float, default=0., help='weights for moco loss') 239 | parser.add_argument('--byol_w', type=float, default=0., help='weights for byol loss') 240 | parser.add_argument('--finetune', action='store_true', help='finetune model') 241 | parser.add_argument('--initial_checkpoint', type=str, default='', help='path to the pretrained model') 242 | parser.add_argument('--dml_w', type=float, default=0., help='enable deep mutual learning') 243 | parser.add_argument('--one_w', type=float, 
default=0., help='enable ONE') 244 | parser.add_argument('--kd_temp', type=float, default=1.0, help='temperature for kd loss') 245 | parser.add_argument('--mulmix_b', type=float, default=0., help='mulmix beta') 246 | parser.add_argument('--hard_contrastive', action='store_true', help='use HEXA') 247 | parser.add_argument('--selfdis_w', type=float, default=0., help='enable self distillation') 248 | 249 | return parser 250 | 251 | 252 | def main(args): 253 | utils.init_distributed_mode(args) 254 | print(args) 255 | # Patch 256 | if not hasattr(args, 'hard_contrastive'): 257 | args.hard_contrastive = False 258 | if not hasattr(args, 'selfdis_w'): 259 | args.selfdis_w = 0.0 260 | 261 | #is_imnet21k = args.data_set == 'IMNET21K' 262 | 263 | device = torch.device(args.device) 264 | 265 | # fix the seed for reproducibility 266 | seed = args.seed + utils.get_rank() 267 | torch.manual_seed(seed) 268 | np.random.seed(seed) 269 | # random.seed(seed) 270 | 271 | cudnn.benchmark = True 272 | 273 | num_classes, train_list_name, val_list_name, test_list_name, filename_seperator, image_tmpl, filter_video, label_file = get_dataset_config( 274 | args.dataset, args.use_lmdb) 275 | 276 | args.num_classes = num_classes 277 | if args.modality == 'rgb': 278 | args.input_channels = 3 279 | elif args.modality == 'flow': 280 | args.input_channels = 2 * 5 281 | 282 | # mean = IMAGENET_DEFAULT_MEAN 283 | # std = IMAGENET_DEFAULT_STD 284 | 285 | mixup_fn = None 286 | mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None 287 | if mixup_active: 288 | mixup_fn = Mixup( 289 | mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, 290 | prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, 291 | label_smoothing=args.smoothing, num_classes=args.num_classes) 292 | 293 | print(f"Creating model: {args.model}") 294 | model = create_model( 295 | args.model, 296 | pretrained=args.pretrained, 297 | duration=args.duration, 298 | hpe_to_token = args.hpe_to_token, 299 | rel_pos = args.rel_pos, 300 | window_size=args.window_size, 301 | super_img_rows = args.super_img_rows, 302 | token_mask=not args.no_token_mask, 303 | online_learning = args.one_w >0.0 or args.dml_w >0.0, 304 | num_classes=args.num_classes, 305 | drop_rate=args.drop, 306 | drop_path_rate=args.drop_path, 307 | drop_block_rate=args.drop_block, 308 | use_checkpoint=args.use_checkpoint 309 | ) 310 | 311 | # TODO: finetuning 312 | 313 | model.to(device) 314 | 315 | model_ema = None 316 | if args.model_ema: 317 | # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper 318 | model_ema = ModelEma( 319 | model, 320 | decay=args.model_ema_decay, 321 | device='cpu' if args.model_ema_force_cpu else '', 322 | resume='') 323 | 324 | model_without_ddp = model 325 | if args.distributed: 326 | #model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True) 327 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 328 | model_without_ddp = model.module 329 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 330 | print('number of params:', n_parameters) 331 | 332 | #linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size() / 512.0 333 | #args.lr = linear_scaled_lr 334 | optimizer = create_optimizer(args, model) 335 | loss_scaler = NativeScaler() 336 | #print(f"Scaled learning rate (batch size: {args.batch_size * utils.get_world_size()}): 
{linear_scaled_lr}") 337 | lr_scheduler, _ = create_scheduler(args, optimizer) 338 | 339 | criterion = LabelSmoothingCrossEntropy() 340 | 341 | if args.mixup > 0.: 342 | # smoothing is handled with mixup label transform 343 | criterion = SoftTargetCrossEntropy() 344 | elif args.smoothing: 345 | criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing) 346 | else: 347 | criterion = torch.nn.CrossEntropyLoss() 348 | 349 | if args.dml_w > 0.: 350 | criterion = DeepMutualLoss(criterion, args.dml_w, args.kd_temp) 351 | elif args.one_w > 0.: 352 | criterion = ONELoss(criterion, args.one_w, args.kd_temp) 353 | elif args.mulmix_b > 0.: 354 | criterion = MulMixturelLoss(criterion, args.mulmix_b) 355 | elif args.selfdis_w > 0.: 356 | criterion = SelfDistillationLoss(criterion, args.selfdis_w, args.kd_temp) 357 | 358 | simclr_criterion = simclr.NTXent(temperature=args.temperature) if args.simclr_w > 0. else None 359 | branch_div_criterion = torch.nn.CosineSimilarity() if args.branch_div_w > 0. else None 360 | simsiam_criterion = simclr.SimSiamLoss() if args.simsiam_w > 0. else None 361 | moco_criterion = torch.nn.CrossEntropyLoss() if args.moco_w > 0. else None 362 | byol_criterion = simclr.BYOLLoss() if args.byol_w > 0. else None 363 | 364 | max_accuracy = 0.0 365 | output_dir = Path(args.output_dir) 366 | 367 | if args.initial_checkpoint: 368 | print("Loading pretrained model") 369 | checkpoint = torch.load(args.initial_checkpoint, map_location='cpu') 370 | utils.load_checkpoint(model, checkpoint['model']) 371 | 372 | if args.auto_resume: 373 | if args.resume == '': 374 | args.resume = str(output_dir / "checkpoint.pth") 375 | if not os.path.exists(args.resume): 376 | args.resume = '' 377 | 378 | if args.resume: 379 | if args.resume.startswith('https'): 380 | checkpoint = torch.hub.load_state_dict_from_url( 381 | args.resume, map_location='cpu', check_hash=True) 382 | else: 383 | checkpoint = torch.load(args.resume, map_location='cpu') 384 | utils.load_checkpoint(model, checkpoint['model']) 385 | if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint: 386 | optimizer.load_state_dict(checkpoint['optimizer']) 387 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) 388 | args.start_epoch = checkpoint['epoch'] + 1 389 | if 'scaler' in checkpoint and args.resume_loss_scaler: 390 | print("Resume with previous loss scaler state") 391 | loss_scaler.load_state_dict(checkpoint['scaler']) 392 | if args.model_ema: 393 | utils._load_checkpoint_for_ema(model_ema, checkpoint['model_ema']) 394 | max_accuracy = checkpoint['max_accuracy'] 395 | 396 | mean = (0.5, 0.5, 0.5) if 'mean' not in model.module.default_cfg else model.module.default_cfg['mean'] 397 | std = (0.5, 0.5, 0.5) if 'std' not in model.module.default_cfg else model.module.default_cfg['std'] 398 | 399 | # dataset_train, args.nb_classes = build_dataset(is_train=True, args=args) 400 | # create data loaders w/ augmentation pipeiine 401 | if args.use_lmdb: 402 | video_data_cls = VideoDataSetLMDB 403 | elif args.use_pyav: 404 | video_data_cls = VideoDataSetOnline 405 | else: 406 | video_data_cls = VideoDataSet 407 | train_list = os.path.join(args.data_dir, train_list_name) 408 | 409 | train_augmentor = get_augmentor(True, args.input_size, mean, std, threed_data=args.threed_data, 410 | version=args.augmentor_ver, scale_range=args.scale_range, dataset=args.dataset) 411 | dataset_train = video_data_cls(args.data_dir, train_list, args.duration, args.frames_per_group, 412 | num_clips=args.num_clips, 
413 | modality=args.modality, image_tmpl=image_tmpl, 414 | dense_sampling=args.dense_sampling, 415 | transform=train_augmentor, is_train=True, test_mode=False, 416 | seperator=filename_seperator, filter_video=filter_video) 417 | 418 | num_tasks = utils.get_world_size() 419 | data_loader_train = build_dataflow(dataset_train, is_train=True, batch_size=args.batch_size, 420 | workers=args.num_workers, is_distributed=args.distributed) 421 | 422 | val_list = os.path.join(args.data_dir, val_list_name) 423 | val_augmentor = get_augmentor(False, args.input_size, mean, std, args.disable_scaleup, 424 | threed_data=args.threed_data, version=args.augmentor_ver, 425 | scale_range=args.scale_range, num_clips=args.num_clips, num_crops=args.num_crops, dataset=args.dataset) 426 | dataset_val = video_data_cls(args.data_dir, val_list, args.duration, args.frames_per_group, 427 | num_clips=args.num_clips, 428 | modality=args.modality, image_tmpl=image_tmpl, 429 | dense_sampling=args.dense_sampling, 430 | transform=val_augmentor, is_train=False, test_mode=False, 431 | seperator=filename_seperator, filter_video=filter_video) 432 | 433 | data_loader_val = build_dataflow(dataset_val, is_train=False, batch_size=args.batch_size, 434 | workers=args.num_workers, is_distributed=args.distributed) 435 | 436 | 437 | if args.eval: 438 | test_stats = evaluate(data_loader_val, model, device, num_tasks, distributed=True, amp=args.amp, num_crops=args.num_crops, num_clips=args.num_clips) 439 | print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") 440 | return 441 | 442 | print(f"Start training, currnet max acc is {max_accuracy:.2f}") 443 | start_time = time.time() 444 | for epoch in range(args.start_epoch, args.epochs): 445 | 446 | if args.distributed: 447 | data_loader_train.sampler.set_epoch(epoch) 448 | 449 | train_stats = train_one_epoch( 450 | model, criterion, data_loader_train, 451 | optimizer, device, epoch, loss_scaler, 452 | args.clip_grad, model_ema, mixup_fn, num_tasks, True, 453 | amp=args.amp, 454 | simclr_criterion=simclr_criterion, simclr_w=args.simclr_w, 455 | branch_div_criterion=branch_div_criterion, branch_div_w=args.branch_div_w, 456 | simsiam_criterion=simsiam_criterion, simsiam_w=args.simsiam_w, 457 | moco_criterion=moco_criterion, moco_w=args.moco_w, 458 | byol_criterion=byol_criterion, byol_w=args.byol_w, 459 | contrastive_nomixup=args.contrastive_nomixup, 460 | hard_contrastive=args.hard_contrastive, 461 | finetune=args.finetune 462 | ) 463 | 464 | lr_scheduler.step(epoch) 465 | 466 | test_stats = evaluate(data_loader_val, model, device, num_tasks, distributed=True, amp=args.amp) 467 | print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") 468 | 469 | max_accuracy = max(max_accuracy, test_stats["acc1"]) 470 | print(f'Max accuracy: {max_accuracy:.2f}%') 471 | 472 | if args.output_dir: 473 | checkpoint_paths = [output_dir / 'checkpoint.pth'] 474 | if test_stats["acc1"] == max_accuracy: 475 | checkpoint_paths.append(output_dir / 'model_best.pth') 476 | for checkpoint_path in checkpoint_paths: 477 | state_dict = { 478 | 'model': model_without_ddp.state_dict(), 479 | 'optimizer': optimizer.state_dict(), 480 | 'lr_scheduler': lr_scheduler.state_dict(), 481 | 'epoch': epoch, 482 | 'args': args, 483 | 'scaler': loss_scaler.state_dict(), 484 | 'max_accuracy': max_accuracy 485 | } 486 | if args.model_ema: 487 | state_dict['model_ema'] = get_state_dict(model_ema) 488 | utils.save_on_master(state_dict, checkpoint_path) 489 | 
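# The checkpoint written above stores the same keys that the --resume /
# --auto-resume path earlier in main() reads back ('model', 'optimizer',
# 'lr_scheduler', 'epoch', 'scaler', 'max_accuracy' and, when EMA is
# enabled, 'model_ema'), so an interrupted run can resume from the last
# completed epoch.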
490 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 491 | **{f'test_{k}': v for k, v in test_stats.items()}, 492 | 'epoch': epoch, 493 | 'n_parameters': n_parameters} 494 | 495 | if args.output_dir and utils.is_main_process(): 496 | with (output_dir / "log.txt").open("a") as f: 497 | f.write(json.dumps(log_stats) + "\n") 498 | 499 | 500 | total_time = time.time() - start_time 501 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 502 | print('Training time {}'.format(total_time_str)) 503 | 504 | 505 | if __name__ == '__main__': 506 | parser = argparse.ArgumentParser('DeiT training and evaluation script', parents=[get_args_parser()]) 507 | args = parser.parse_args() 508 | if args.output_dir: 509 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 510 | main(args) 511 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the CC-by-NC license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | import torch 8 | import torch.nn as nn 9 | from functools import partial 10 | 11 | from timm.models.vision_transformer import VisionTransformer, _cfg 12 | from timm.models.registry import register_model 13 | 14 | 15 | @register_model 16 | def deit_tiny_patch8_224(pretrained=False, **kwargs): 17 | model = VisionTransformer( 18 | patch_size=8, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True, 19 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 20 | model.default_cfg = _cfg() 21 | if pretrained: 22 | checkpoint = torch.hub.load_state_dict_from_url( 23 | url="https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth", 24 | map_location="cpu", check_hash=True 25 | ) 26 | model.load_state_dict(checkpoint["model"]) 27 | return model 28 | 29 | 30 | @register_model 31 | def deit_tiny_patch16_224(pretrained=False, **kwargs): 32 | model = VisionTransformer( 33 | patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True, 34 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 35 | model.default_cfg = _cfg() 36 | if pretrained: 37 | checkpoint = torch.hub.load_state_dict_from_url( 38 | url="https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth", 39 | map_location="cpu", check_hash=True 40 | ) 41 | model.load_state_dict(checkpoint["model"]) 42 | return model 43 | 44 | 45 | @register_model 46 | def deit_tiny_patch16_d_6_224(pretrained=False, **kwargs): 47 | model = VisionTransformer( 48 | patch_size=16, embed_dim=192, depth=6, num_heads=3, mlp_ratio=4, qkv_bias=True, 49 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 50 | model.default_cfg = _cfg() 51 | if pretrained: 52 | checkpoint = torch.hub.load_state_dict_from_url( 53 | url="https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth", 54 | map_location="cpu", check_hash=True 55 | ) 56 | model.load_state_dict(checkpoint["model"]) 57 | return model 58 | 59 | 60 | @register_model 61 | def deit_tiny_patch32_224(pretrained=False, **kwargs): 62 | model = VisionTransformer( 63 | patch_size=32, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True, 64 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 65 | model.default_cfg = _cfg() 66 | if pretrained: 67 | checkpoint = torch.hub.load_state_dict_from_url( 68 | 
url="https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth", 69 | map_location="cpu", check_hash=True 70 | ) 71 | model.load_state_dict(checkpoint["model"]) 72 | return model 73 | 74 | 75 | @register_model 76 | def deit_small_patch8_224(pretrained=False, **kwargs): 77 | model = VisionTransformer( 78 | patch_size=8, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True, 79 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 80 | model.default_cfg = _cfg() 81 | if pretrained: 82 | checkpoint = torch.hub.load_state_dict_from_url( 83 | url="https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth", 84 | map_location="cpu", check_hash=True 85 | ) 86 | model.load_state_dict(checkpoint["model"]) 87 | return model 88 | 89 | 90 | @register_model 91 | def deit_small_patch16_224(pretrained=False, **kwargs): 92 | model = VisionTransformer( 93 | patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True, 94 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 95 | model.default_cfg = _cfg() 96 | if pretrained: 97 | checkpoint = torch.hub.load_state_dict_from_url( 98 | url="https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth", 99 | map_location="cpu", check_hash=True 100 | ) 101 | model.load_state_dict(checkpoint["model"]) 102 | return model 103 | 104 | 105 | @register_model 106 | def deit_small_patch16_d_6_224(pretrained=False, **kwargs): 107 | model = VisionTransformer( 108 | patch_size=16, embed_dim=384, depth=6, num_heads=6, mlp_ratio=4, qkv_bias=True, 109 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 110 | model.default_cfg = _cfg() 111 | if pretrained: 112 | checkpoint = torch.hub.load_state_dict_from_url( 113 | url="https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth", 114 | map_location="cpu", check_hash=True 115 | ) 116 | model.load_state_dict(checkpoint["model"]) 117 | return model 118 | 119 | 120 | @register_model 121 | def deit_small_patch32_224(pretrained=False, **kwargs): 122 | model = VisionTransformer( 123 | patch_size=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True, 124 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 125 | model.default_cfg = _cfg() 126 | if pretrained: 127 | checkpoint = torch.hub.load_state_dict_from_url( 128 | url="https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth", 129 | map_location="cpu", check_hash=True 130 | ) 131 | model.load_state_dict(checkpoint["model"]) 132 | return model 133 | 134 | 135 | @register_model 136 | def deit_base_patch8_224(pretrained=False, **kwargs): 137 | model = VisionTransformer( 138 | patch_size=8, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 139 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 140 | model.default_cfg = _cfg() 141 | if pretrained: 142 | checkpoint = torch.hub.load_state_dict_from_url( 143 | url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth", 144 | map_location="cpu", check_hash=True 145 | ) 146 | model.load_state_dict(checkpoint["model"]) 147 | return model 148 | 149 | 150 | @register_model 151 | def deit_base_patch16_224(pretrained=False, **kwargs): 152 | model = VisionTransformer( 153 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 154 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 155 | model.default_cfg = _cfg() 156 | if pretrained: 157 | checkpoint = torch.hub.load_state_dict_from_url( 158 | 
url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth", 159 | map_location="cpu", check_hash=True 160 | ) 161 | model.load_state_dict(checkpoint["model"]) 162 | return model 163 | 164 | 165 | @register_model 166 | def deit_base_patch16_ft_224(pretrained=False, **kwargs): 167 | model = VisionTransformer( 168 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 169 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 170 | model.default_cfg = _cfg() 171 | if pretrained: 172 | checkpoint = torch.hub.load_state_dict_from_url( 173 | url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth", 174 | map_location="cpu", check_hash=True 175 | ) 176 | model.load_state_dict(checkpoint["model"]) 177 | 178 | for m in model.parameters(): 179 | m.requires_grad = False 180 | 181 | for m in model.head.parameters(): 182 | m.requires_grad = True 183 | 184 | return model 185 | 186 | 187 | 188 | @register_model 189 | def deit_base24_patch16_224(pretrained=False, **kwargs): 190 | model = VisionTransformer( 191 | patch_size=16, embed_dim=768, depth=24, num_heads=12, mlp_ratio=4, qkv_bias=True, 192 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 193 | model.default_cfg = _cfg() 194 | if pretrained: 195 | checkpoint = torch.hub.load_state_dict_from_url( 196 | url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth", 197 | map_location="cpu", check_hash=True 198 | ) 199 | model.load_state_dict(checkpoint["model"]) 200 | return model 201 | 202 | 203 | @register_model 204 | def deit_base16_patch16_224(pretrained=False, **kwargs): 205 | model = VisionTransformer( 206 | patch_size=16, embed_dim=768, depth=16, num_heads=12, mlp_ratio=4, qkv_bias=True, 207 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 208 | model.default_cfg = _cfg() 209 | if pretrained: 210 | checkpoint = torch.hub.load_state_dict_from_url( 211 | url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth", 212 | map_location="cpu", check_hash=True 213 | ) 214 | model.load_state_dict(checkpoint["model"]) 215 | return model 216 | 217 | 218 | @register_model 219 | def deit_base_patch16_384(pretrained=False, **kwargs): 220 | model = VisionTransformer(img_size=384, 221 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 222 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 223 | model.default_cfg = _cfg() 224 | if pretrained: 225 | checkpoint = torch.hub.load_state_dict_from_url( 226 | url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth", 227 | map_location="cpu", check_hash=True 228 | ) 229 | model.load_state_dict(checkpoint["model"]) 230 | return model 231 | 232 | 233 | @register_model 234 | def deit_base_patch32_224(pretrained=False, **kwargs): 235 | model = VisionTransformer( 236 | patch_size=32, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 237 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 238 | model.default_cfg = _cfg() 239 | if pretrained: 240 | checkpoint = torch.hub.load_state_dict_from_url( 241 | url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth", 242 | map_location="cpu", check_hash=True 243 | ) 244 | model.load_state_dict(checkpoint["model"]) 245 | return model 246 | -------------------------------------------------------------------------------- /my_models/__init__.py: -------------------------------------------------------------------------------- 1 | from .sifar_swin import * 2 | 
-------------------------------------------------------------------------------- /my_models/action_conv.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | import torch 9 | import torch.nn as nn 10 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 11 | from timm.models.registry import register_model 12 | import logging 13 | from einops import rearrange, reduce, repeat 14 | from timm.models import resnet50, tv_resnet101, tv_resnet152 15 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 16 | import torchvision.models as models 17 | 18 | _logger = logging.getLogger(__name__) 19 | 20 | def _cfg(url='', **kwargs): 21 | return { 22 | 'url': url, 23 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), 24 | 'crop_pct': 0.875, 'interpolation': 'bilinear', 25 | 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 26 | 'first_conv': 'conv1', 'classifier': 'fc', 27 | **kwargs 28 | } 29 | 30 | default_cfgs = { 31 | # ResNet 32 | 'resnet18': _cfg(url='https://download.pytorch.org/models/resnet18-5c106cde.pth'), 33 | 'resnet34': _cfg( 34 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet34-43635321.pth'), 35 | 'resnet50': _cfg( 36 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet50_ram-a26f946b.pth', 37 | interpolation='bicubic'), 38 | 'resnet101': _cfg(url='', interpolation='bicubic'), 39 | 'resnet101d': _cfg( 40 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet101d_ra2-2803ffab.pth', 41 | interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8), 42 | crop_pct=1.0, test_input_size=(3, 320, 320)), 43 | 'resnet152': _cfg(url='', interpolation='bicubic'), 44 | 'resnet200': _cfg(url='', interpolation='bicubic'), 45 | } 46 | 47 | class ConvActionModule(nn.Module): 48 | r""" Swin Transformer 49 | A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - 50 | https://arxiv.org/pdf/2103.14030 51 | 52 | Args: 53 | img_size (int | tuple(int)): Input image size. Default 224 54 | patch_size (int | tuple(int)): Patch size. Default: 4 55 | in_chans (int): Number of input image channels. Default: 3 56 | num_classes (int): Number of classes for classification head. Default: 1000 57 | embed_dim (int): Patch embedding dimension. Default: 96 58 | depths (tuple(int)): Depth of each Swin Transformer layer. 59 | num_heads (tuple(int)): Number of attention heads in different layers. 60 | window_size (int): Window size. Default: 7 61 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 62 | qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True 63 | qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None 64 | drop_rate (float): Dropout rate. Default: 0 65 | attn_drop_rate (float): Attention dropout rate. Default: 0 66 | drop_path_rate (float): Stochastic depth rate. Default: 0.1 67 | norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. 68 | ape (bool): If True, add absolute position embedding to the patch embedding. 
Default: False 69 | patch_norm (bool): If True, add normalization after patch embedding. Default: True 70 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False 71 | """ 72 | 73 | def __init__(self, backbone=None, duration=8, img_size=224, in_chans=3, num_classes=1000, num_features=0, 74 | super_img_rows=1, default_cfg=None, **kwargs): 75 | super().__init__() 76 | self.backbone = backbone 77 | self.num_features = int(num_features) 78 | self.duration = duration 79 | self.num_classes = num_classes 80 | self.super_img_rows = super_img_rows 81 | self.default_cfg = default_cfg 82 | 83 | self.img_size = img_size 84 | self.frame_padding = self.duration % super_img_rows 85 | if self.frame_padding != 0: 86 | self.frame_padding = self.super_img_rows - self.frame_padding 87 | self.duration += self.frame_padding 88 | # assert (self.duration % super_img_rows) == 0, 'number of frames must be a multiple of the rows of the super image' 89 | 90 | self.avgpool = nn.AdaptiveAvgPool2d(1) 91 | self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() 92 | 93 | self.apply(self._init_weights) 94 | 95 | print('image_size:', self.img_size, 'padding frame:', self.frame_padding, 'super_img_size:', (super_img_rows, self.duration // super_img_rows)) 96 | 97 | def _init_weights(self, m): 98 | if isinstance(m, nn.Linear): 99 | trunc_normal_(m.weight, std=.02) 100 | if isinstance(m, nn.Linear) and m.bias is not None: 101 | nn.init.constant_(m.bias, 0) 102 | 103 | def pad_frames(self, x): 104 | frame_num = self.duration - self.frame_padding 105 | x = x.view((-1,3*frame_num)+x.size()[2:]) 106 | x_padding = torch.zeros((x.shape[0], 3*self.frame_padding) + x.size()[2:]).cuda() 107 | x = torch.cat((x, x_padding), dim=1) 108 | assert x.shape[1] == 3 * self.duration, 'frame number %d not the same as adjusted input size %d' % (x.shape[1], 3 * self.duration) 109 | 110 | return x 111 | 112 | def create_super_img(self, x): 113 | input_size = x.shape[-2:] 114 | if input_size != self.img_size: 115 | x = nn.functional.interpolate(x, size=self.img_size, mode='bilinear') 116 | 117 | x = rearrange(x, 'b (th tw c) h w -> b c (th h) (tw w)', th=self.super_img_rows, c=3) 118 | return x 119 | 120 | def forward_features(self, x): 121 | # x = rearrange(x, 'b (t c) h w -> b c h (t w)', t=self.duration) 122 | # in evaluation, it's Bx(num_crops*num_clips*num_frames*3)xHxW 123 | if self.frame_padding > 0: 124 | x = self.pad_frames(x) 125 | else: 126 | x = x.view((-1,3*self.duration)+x.size()[2:]) 127 | 128 | x = self.create_super_img(x) 129 | 130 | x = self.backbone.forward_features(x) 131 | x = self.avgpool(x) 132 | x = torch.flatten(x, 1) 133 | return x 134 | 135 | def forward(self, x): 136 | x = self.forward_features(x) 137 | x = self.head(x) 138 | 139 | return x 140 | 141 | 142 | @register_model 143 | def action_conv_resnet50(pretrained=False, **kwargs): 144 | 145 | num_features = 2048 146 | model = ConvActionModule(backbone=None, num_features=num_features, **kwargs) 147 | 148 | backbone = resnet50(pretrained=pretrained) 149 | model.backbone = backbone 150 | model.default_cfg = backbone.default_cfg 151 | 152 | return model 153 | 154 | @register_model 155 | def action_conv_resnet101(pretrained=False, **kwargs): 156 | 157 | num_features = 2048 158 | model = ConvActionModule(backbone=None, num_features=num_features, **kwargs) 159 | 160 | backbone = tv_resnet101(pretrained=pretrained) 161 | model.backbone = backbone 162 | model.default_cfg = backbone.default_cfg 163 | 164 | return model
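# Usage sketch (illustrative, not part of the original code; argument values
# such as duration=8, super_img_rows=2 and num_classes=174 are examples):
# the constructors in this file are registered with timm, so they can be
# built via create_model, with extra kwargs forwarded to ConvActionModule.
# The input is a batch of stacked frames, B x (T*3) x H x W, which
# forward_features tiles into a super image before running the 2D backbone.
#
#   from timm.models import create_model
#   m = create_model('action_conv_resnet50', pretrained=False, duration=8,
#                    super_img_rows=2, img_size=224, num_classes=174)
#   logits = m(torch.randn(2, 8 * 3, 224, 224))   # -> 2 x 174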
165 | 166 | @register_model 167 | def action_conv_resnet152(pretrained=False, **kwargs): 168 | 169 | num_features = 2048 170 | model = ConvActionModule(backbone=None, num_features=num_features, **kwargs) 171 | 172 | backbone = tv_resnet152(pretrained=pretrained) 173 | model.backbone = backbone 174 | model.default_cfg = backbone.default_cfg 175 | 176 | return model 177 | 178 | ''' 179 | @register_model 180 | def action_tf_efficientnetv2_m_in21k(pretrained=False, **kwargs): 181 | num_features = 2048 182 | model_kwargs = dict(num_features=num_features, **kwargs) 183 | model = ConvActionModule(backbone=None, **model_kwargs) 184 | 185 | backbone = action_tf_efficientnetv2_m_in21k(pretrained=pretrained) 186 | print (backone) 187 | model.backbone = backbone 188 | model.default_cfg = backbone.default_cfga 189 | 190 | return model 191 | ''' 192 | -------------------------------------------------------------------------------- /my_models/sifar_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from einops import rearrange 4 | 5 | from timm.models.layers import to_2tuple 6 | 7 | def create_super_img(x, img_size, super_img_rows): 8 | input_size = x.shape[-2:] 9 | 10 | if not isinstance(img_size, tuple): 11 | img_size = to_2tuple(img_size) 12 | 13 | if input_size != img_size: 14 | x = nn.functional.interpolate(x, size=img_size, mode='bilinear') 15 | x = rearrange(x, 'b (th tw c) h w -> b c (th h) (tw w)', th=super_img_rows, c=3) 16 | return x 17 | 18 | def frames_to_super_image(x, super_img_rows, super_img_cols, img_h, img_w): 19 | x = rearrange(x, '(b th tw) (h w) c -> b (th h tw w) c', th=super_img_rows, tw=super_img_cols, h=img_h, w=img_w) 20 | return x 21 | 22 | def super_image_to_frames(x, super_img_rows, super_img_cols, img_h, img_w): 23 | x = rearrange(x, 'b (th h tw w) c -> (b th tw) (h w) c', th=super_img_rows, tw=super_img_cols, h=img_h, w=img_w) 24 | return x 25 | 26 | def pad_frames(x, duration, frame_padding): 27 | frame_num = duration - frame_padding 28 | x = x.view((-1, 3 * frame_num) + x.size()[2:]) 29 | x_padding = torch.zeros((x.shape[0], 3 * frame_padding) + x.size()[2:]).cuda() 30 | x = torch.cat((x, x_padding), dim=1) 31 | assert x.shape[1] == 3 * duration, 'frame number %d not the same as adjusted input size %d' % ( 32 | x.shape[1], 3 * duration) 33 | 34 | return x 35 | 36 | def get_super_img_layout(duration, img_rows): 37 | return (img_rows, duration // img_rows) 38 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.7.1 2 | torchvision==0.8.2 3 | timm==0.3.2 4 | fvcore 5 | einops 6 | -------------------------------------------------------------------------------- /samplers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the CC-by-NC license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | import torch 8 | import torch.distributed as dist 9 | import math 10 | 11 | 12 | class RASampler(torch.utils.data.Sampler): 13 | """Sampler that restricts data loading to a subset of the dataset for distributed, 14 | with repeated augmentation. 
15 | It ensures that each augmented version of a sample will be visible to a 16 | different process (GPU) 17 | Heavily based on torch.utils.data.DistributedSampler 18 | """ 19 | 20 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True): 21 | if num_replicas is None: 22 | if not dist.is_available(): 23 | raise RuntimeError("Requires distributed package to be available") 24 | num_replicas = dist.get_world_size() 25 | if rank is None: 26 | if not dist.is_available(): 27 | raise RuntimeError("Requires distributed package to be available") 28 | rank = dist.get_rank() 29 | self.dataset = dataset 30 | self.num_replicas = num_replicas 31 | self.rank = rank 32 | self.epoch = 0 33 | self.num_samples = int(math.ceil(len(self.dataset) * 3.0 / self.num_replicas)) 34 | self.total_size = self.num_samples * self.num_replicas 35 | # self.num_selected_samples = int(math.ceil(len(self.dataset) / self.num_replicas)) 36 | self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas)) 37 | self.shuffle = shuffle 38 | 39 | def __iter__(self): 40 | # deterministically shuffle based on epoch 41 | g = torch.Generator() 42 | g.manual_seed(self.epoch) 43 | if self.shuffle: 44 | indices = torch.randperm(len(self.dataset), generator=g).tolist() 45 | else: 46 | indices = list(range(len(self.dataset))) 47 | 48 | # add extra samples to make it evenly divisible 49 | indices = [ele for ele in indices for i in range(3)] 50 | indices += indices[:(self.total_size - len(indices))] 51 | assert len(indices) == self.total_size 52 | 53 | # subsample 54 | indices = indices[self.rank:self.total_size:self.num_replicas] 55 | assert len(indices) == self.num_samples 56 | 57 | return iter(indices[:self.num_selected_samples]) 58 | 59 | def __len__(self): 60 | return self.num_selected_samples 61 | 62 | def set_epoch(self, epoch): 63 | self.epoch = epoch 64 | -------------------------------------------------------------------------------- /sifar.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/sifar-pytorch/3ac9103245b98a4916dd45bcdf0167d01b5f9b38/sifar.png -------------------------------------------------------------------------------- /simclr.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Yonglong Tian (yonglong@mit.edu) 3 | Date: May 07, 2020 4 | """ 5 | from __future__ import print_function 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torch.distributed as dist 11 | 12 | 13 | class SupConLoss(nn.Module): 14 | """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf. 15 | It also supports the unsupervised contrastive loss in SimCLR""" 16 | 17 | def __init__(self, temperature=0.07, contrast_mode='all', 18 | base_temperature=0.07): 19 | super(SupConLoss, self).__init__() 20 | self.temperature = temperature 21 | self.contrast_mode = contrast_mode 22 | self.base_temperature = base_temperature 23 | 24 | def forward(self, features, labels=None, mask=None): 25 | """Compute loss for model. If both `labels` and `mask` are None, 26 | it degenerates to SimCLR unsupervised loss: 27 | https://arxiv.org/pdf/2002.05709.pdf 28 | 29 | Args: 30 | features: hidden vector of shape [bsz, n_views, ...]. 31 | labels: ground truth of shape [bsz]. 32 | mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j 33 | has the same class as sample i. Can be asymmetric. 34 | Returns: 35 | A loss scalar.
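        Example (shape sketch; assumes the usual SimCLR-style input of two
        L2-normalized views per sample, e.g. 128-d embeddings):
            feats = F.normalize(torch.randn(8, 2, 128), dim=-1)
            loss = SupConLoss(temperature=0.07)(feats)                    # unsupervised
            loss = SupConLoss()(feats, labels=torch.randint(0, 5, (8,)))  # supervised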
36 | """ 37 | device = (torch.device('cuda') 38 | if features.is_cuda 39 | else torch.device('cpu')) 40 | 41 | if len(features.shape) < 3: 42 | raise ValueError('`features` needs to be [bsz, n_views, ...],' 43 | 'at least 3 dimensions are required') 44 | if len(features.shape) > 3: 45 | features = features.view(features.shape[0], features.shape[1], -1) 46 | 47 | batch_size = features.shape[0] 48 | if labels is not None and mask is not None: 49 | raise ValueError('Cannot define both `labels` and `mask`') 50 | elif labels is None and mask is None: 51 | mask = torch.eye(batch_size, dtype=torch.float32).to(device) 52 | elif labels is not None: 53 | labels = labels.contiguous().view(-1, 1) 54 | if labels.shape[0] != batch_size: 55 | raise ValueError('Num of labels does not match num of features') 56 | mask = torch.eq(labels, labels.T).float().to(device) 57 | else: 58 | mask = mask.float().to(device) 59 | 60 | contrast_count = features.shape[1] 61 | contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0) 62 | if self.contrast_mode == 'one': 63 | anchor_feature = features[:, 0] 64 | anchor_count = 1 65 | elif self.contrast_mode == 'all': 66 | anchor_feature = contrast_feature 67 | anchor_count = contrast_count 68 | else: 69 | raise ValueError('Unknown mode: {}'.format(self.contrast_mode)) 70 | 71 | # compute logits 72 | anchor_dot_contrast = torch.div( 73 | torch.matmul(anchor_feature, contrast_feature.T), 74 | self.temperature) 75 | # for numerical stability 76 | logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True) 77 | logits = anchor_dot_contrast - logits_max.detach() 78 | 79 | # tile mask 80 | mask = mask.repeat(anchor_count, contrast_count) 81 | # mask-out self-contrast cases 82 | logits_mask = torch.scatter( 83 | torch.ones_like(mask), 84 | 1, 85 | torch.arange(batch_size * anchor_count).view(-1, 1).to(device), 86 | 0 87 | ) 88 | mask = mask * logits_mask 89 | 90 | # compute log_prob 91 | exp_logits = torch.exp(logits) * logits_mask 92 | log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True)) 93 | 94 | # compute mean of log-likelihood over positive 95 | mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1) 96 | 97 | # loss 98 | loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos 99 | loss = loss.view(anchor_count, batch_size).mean() 100 | 101 | return loss 102 | 103 | 104 | class TwoCropTransform: 105 | """Create two crops of the same image""" 106 | 107 | def __init__(self, transform): 108 | self.transform = transform 109 | 110 | def __call__(self, x): 111 | return [self.transform(x), self.transform(x)] 112 | 113 | 114 | class SimSiamLoss(nn.Module): 115 | 116 | def __init__(self, version='simplified'): 117 | super().__init__() 118 | self.version = version 119 | 120 | def forward(self, z1, z2, p1, p2): 121 | 122 | def _loss(p, z, version): 123 | if version == 'original': 124 | z = z.detach() # stop gradient 125 | p = F.normalize(p, dim=1) # l2-normalize 126 | z = F.normalize(z, dim=1) # l2-normalize 127 | return -(p * z).sum(dim=1).mean() 128 | elif version == 'simplified': # same thing, much faster. 
129 | return - F.cosine_similarity(p, z.detach(), dim=-1).mean() 130 | else: 131 | raise ValueError(f"Unknown SimSiam loss version: {version}") 132 | 133 | out = torch.mean(_loss(p1, z2, self.version) / 2 + _loss(p2, z1, self.version) / 2) 134 | return out 135 | 136 | 137 | class BYOLLoss(nn.Module): 138 | 139 | def __init__(self): 140 | super().__init__() 141 | 142 | def forward(self, online_pred_one, online_pred_two, target_proj_one, target_proj_two): 143 | 144 | def loss_fn(x, y): 145 | x = F.normalize(x, dim=-1, p=2) 146 | y = F.normalize(y, dim=-1, p=2) 147 | return 2 - 2 * (x * y).sum(dim=-1) 148 | 149 | loss_one = loss_fn(online_pred_one, target_proj_two.detach()) 150 | loss_two = loss_fn(online_pred_two, target_proj_one.detach()) 151 | 152 | loss = loss_one + loss_two 153 | return loss.mean() 154 | 155 | 156 | # https://github.com/Spijkervet/SimCLR/blob/654f05f107ce17c0a9db385f298a2dc6f8b3b870/modules/nt_xent.py 157 | class GatherLayer(torch.autograd.Function): 158 | """Gather tensors from all processes, supporting backward propagation.""" 159 | 160 | @staticmethod 161 | def forward(ctx, input): 162 | ctx.save_for_backward(input) 163 | output = [torch.zeros_like(input) \ 164 | for _ in range(dist.get_world_size())] 165 | dist.all_gather(output, input) 166 | return tuple(output) 167 | 168 | @staticmethod 169 | def backward(ctx, *grads): 170 | input, = ctx.saved_tensors 171 | grad_out = torch.zeros_like(input) 172 | grad_out[:] = grads[dist.get_rank()] 173 | return grad_out 174 | 175 | 176 | class NT_Xent(nn.Module): 177 | def __init__(self, temperature=0.07): 178 | super(NT_Xent, self).__init__() 179 | self.temperature = temperature 180 | self.criterion = nn.CrossEntropyLoss(reduction="sum") 181 | self.similarity_f = nn.CosineSimilarity(dim=2) 182 | 183 | def mask_correlated_samples(self, batch_size, world_size): 184 | N = 2 * batch_size * world_size 185 | mask = torch.ones((N, N), dtype=torch.bool) 186 | mask = mask.fill_diagonal_(0) 187 | for i in range(batch_size * world_size): # positive pairs sit at an offset of batch_size * world_size 188 | mask[i, batch_size * world_size + i] = 0 189 | mask[batch_size * world_size + i, i] = 0 190 | return mask 191 | 192 | def forward(self, zz): 193 | """ 194 | We do not sample negative examples explicitly. 195 | Instead, given a positive pair, similar to (Chen et al., 2017), we treat the other 2(N − 1) augmented examples within a minibatch as negative examples.
196 | """ 197 | z_i, z_j = zz[:, 0], zz[:, 1] 198 | batch_size = z_i.shape[0] 199 | world_size = dist.get_world_size() 200 | N = 2 * batch_size * world_size 201 | 202 | mask = self.mask_correlated_samples(batch_size, world_size) 203 | 204 | z = torch.cat((z_i, z_j), dim=0) 205 | if world_size > 1: 206 | z = torch.cat(GatherLayer.apply(z), dim=0) 207 | 208 | sim = self.similarity_f(z.unsqueeze(1), z.unsqueeze(0)) / self.temperature 209 | 210 | sim_i_j = torch.diag(sim, batch_size * world_size) 211 | 212 | sim_j_i = torch.diag(sim, -batch_size * world_size) 213 | 214 | # We have 2N samples, but with Distributed training every GPU gets N examples too, resulting in: 2xNxN 215 | positive_samples = torch.cat((sim_i_j, sim_j_i), dim=0).reshape( 216 | N, 1 217 | ) 218 | negative_samples = sim[mask].reshape(N, -1) 219 | labels = torch.zeros(N, device=positive_samples.device).long() 220 | logits = torch.cat((positive_samples, negative_samples), dim=1) 221 | loss = self.criterion(logits, labels) 222 | loss /= N 223 | return loss 224 | 225 | 226 | def all_gather(tensor, expand_dim=0, num_replicas=None): 227 | """Gathers a tensor from other replicas, concat on expand_dim and return.""" 228 | num_replicas = dist.get_world_size() if num_replicas is None else num_replicas 229 | other_replica_tensors = [torch.zeros_like(tensor) for _ in range(num_replicas)] 230 | dist.all_gather(other_replica_tensors, tensor) 231 | return torch.cat([o.unsqueeze(expand_dim) for o in other_replica_tensors], expand_dim) 232 | 233 | 234 | class NTXent(nn.Module): 235 | """Wrap a module to get self.training member.""" 236 | 237 | def __init__(self, temperature=0.07): 238 | super(NTXent, self).__init__() 239 | self.temperature = temperature 240 | 241 | def forward(self, embeddings): 242 | """NT-XENT Loss from SimCLR 243 | :param embedding1: embedding of augmentation1 244 | :param embedding2: embedding of augmentation2 245 | :param temperature: nce normalization temp 246 | :param num_replicas: number of compute devices 247 | :returns: scalar loss 248 | :rtype: float32 249 | """ 250 | embedding1, embedding2 = embeddings[:, 0].contiguous(), embeddings[:, 1].contiguous() 251 | batch_size = embedding1.shape[0] 252 | feature_size = embedding1.shape[-1] 253 | num_replicas = dist.get_world_size() 254 | LARGE_NUM = 1e9 255 | 256 | if num_replicas > 1 and self.training: 257 | # First grab the tensor from all other embeddings 258 | embedding1_full = all_gather(embedding1, num_replicas=num_replicas) 259 | embedding2_full = all_gather(embedding2, num_replicas=num_replicas) 260 | 261 | # fold the tensor in to create [B, F] 262 | embedding1_full = embedding1_full.reshape(-1, feature_size) 263 | embedding2_full = embedding2_full.reshape(-1, feature_size) 264 | 265 | # Create pseudo-labels using the current replica id & ont-hotting 266 | replica_id = dist.get_rank() 267 | labels = torch.arange(batch_size, device=embedding1.device) + replica_id * batch_size 268 | labels = labels.type(torch.int64) 269 | full_batch_size = embedding1_full.shape[0] 270 | masks = F.one_hot(labels, full_batch_size).to(embedding1_full.device) 271 | labels = F.one_hot(labels, full_batch_size * 2).to(embedding1_full.device) 272 | else: # no replicas or we are in test mode; test set is same size on all replicas for now 273 | embedding1_full = embedding1 274 | embedding2_full = embedding2 275 | masks = F.one_hot(torch.arange(batch_size), batch_size).to(embedding1.device) 276 | labels = F.one_hot(torch.arange(batch_size), batch_size * 2).to(embedding1.device) 277 | 278 | # 
Matmul-to-mask 279 | logits_aa = torch.matmul(embedding1, embedding1_full.T) / self.temperature 280 | logits_aa = logits_aa - masks * LARGE_NUM 281 | logits_bb = torch.matmul(embedding2, embedding2_full.T) / self.temperature 282 | logits_bb = logits_bb - masks * LARGE_NUM 283 | logits_ab = torch.matmul(embedding1, embedding2_full.T) / self.temperature 284 | logits_ba = torch.matmul(embedding2, embedding1_full.T) / self.temperature 285 | 286 | # Use our standard cross-entropy loss which uses log-softmax internally. 287 | # Concat on the feature dimension to provide all features for standard softmax-xent 288 | loss_a = F.cross_entropy(input=torch.cat([logits_ab, logits_aa], 1), 289 | target=torch.argmax(labels, -1), 290 | reduction="none") 291 | loss_b = F.cross_entropy(input=torch.cat([logits_ba, logits_bb], 1), 292 | target=torch.argmax(labels, -1), 293 | reduction="none") 294 | loss = (loss_a + loss_b) * 0.5 295 | return torch.mean(loss) 296 | -------------------------------------------------------------------------------- /tools/convert_contrastive_model.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import scipy.ndimage 4 | import torch 5 | 6 | parser = argparse.ArgumentParser(description='Convert from ViT for different input size') 7 | parser.add_argument('--mode', default='', type=str, metavar='MODE', 8 | help='moco or byol') 9 | parser.add_argument('model_path', default='', type=str, metavar='MODEL', 10 | help='The pretrained model') 11 | 12 | 13 | def from_contrastive_model(model_path, mode=None): 14 | 15 | model = torch.load(model_path, map_location='cpu') 16 | state_dict_name = 'model' 17 | 18 | if mode == 'moco': 19 | keyword = 'encoder_q.' 20 | elif mode == 'byol': 21 | keyword = 'online_encoder.' 
22 | else: 23 | raise ValueError(f"Unknown mode: {mode}") 24 | 25 | state_dict = model[state_dict_name] 26 | for k in list(state_dict.keys()): 27 | # retain only encoder_q up to before the embedding layer 28 | if k.startswith(keyword): 29 | # remove prefix 30 | state_dict[k[len(keyword):]] = state_dict[k] 31 | # delete renamed or unused k 32 | del state_dict[k] 33 | 34 | for k in state_dict.keys(): 35 | print(k) 36 | 37 | model = {state_dict_name: state_dict} 38 | return model 39 | 40 | 41 | def main(): 42 | args = parser.parse_args() 43 | 44 | model = from_contrastive_model(args.model_path, args.mode) 45 | name = os.path.join(os.path.dirname(args.model_path), os.path.basename(args.model_path).split(".")[0] + f'_remove_contrastive_wrapper' + '.pth.tar') 46 | print(f"Save model to {name}") 47 | torch.save(model, name, _use_new_zipfile_serialization=False) 48 | 49 | 50 | if __name__ == "__main__": 51 | main() 52 | -------------------------------------------------------------------------------- /tools/convert_vit_model_to_diff_input_size.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import scipy.ndimage 4 | import torch 5 | 6 | parser = argparse.ArgumentParser(description='Convert from ViT for different input size') 7 | parser.add_argument('--model', default='', type=str, metavar='MODEL', 8 | help='The pretrained model') 9 | parser.add_argument('--ori_input_size', '--ois', default=[224], type=int, nargs='+') 10 | parser.add_argument('--new_input_size', '--nis', default=[384], type=int, nargs='+') 11 | parser.add_argument('--ori_patch_size', '--ops', default=[16], type=int, nargs='+') 12 | parser.add_argument('--new_patch_size', '--nps', default=[16], type=int, nargs='+') 13 | parser.add_argument('--ema', action='store_true',) 14 | parser.add_argument('--remove_fc', action='store_true',) 15 | 16 | # need to convert the pos embedding. 
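# A hedged usage sketch (the checkpoint path below is a placeholder, not a file shipped with
# this repo): converting a 224px / patch-16 ViT checkpoint for 384px inputs keeps the class-token
# embedding and bicubically resizes the 14x14 grid of patch position embeddings to 24x24,
# which is what convert() and _convert_one_set below implement. For example:
#   python tools/convert_vit_model_to_diff_input_size.py \
#       --model checkpoints/deit_base_patch16_224.pth.tar --ois 224 --nis 384 --ops 16 --nps 16
# The converted weights are written next to the input checkpoint with a "_384" suffix.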
17 | # https://github.com/google-research/vision_transformer/blob/f952d612e1b55d1099b2fbf87fc04218d5c4fe18/vit_jax/checkpoint.py#L185 18 | 19 | 20 | def _convert_one_set(pos_embed, n_p, o_p): 21 | pos_embed = pos_embed.reshape(1, o_p, o_p, -1).permute(0, 3, 1, 2) 22 | pos_embed = torch.nn.functional.interpolate( 23 | pos_embed, size=(n_p, n_p), mode='bicubic', align_corners=False) 24 | new_pos_embed = pos_embed.permute(0, 2, 3, 1).flatten(1, 2) 25 | 26 | """ 27 | pos_embed = pos_embed.squeeze(0).reshape(o_p, o_p, -1).cpu().numpy() 28 | zoom = (n_p / o_p, n_p / o_p, 1) 29 | new_pos_embed = scipy.ndimage.zoom(pos_embed, zoom, order=1) 30 | new_pos_embed = torch.tensor(new_pos_embed).reshape((1, n_p * n_p, -1)) 31 | """ 32 | 33 | return new_pos_embed 34 | 35 | 36 | def convert(model_path, ori_input_size, ori_patch_size, new_input_size, new_patch_size, use_ema=False, remove_fc=False): 37 | 38 | model = torch.load(model_path, map_location='cpu') 39 | state_dict_name = 'model_ema' if use_ema else 'model' 40 | model = {state_dict_name: model[state_dict_name]} 41 | if remove_fc: 42 | model[state_dict_name].pop('head.weight', None) 43 | model[state_dict_name].pop('head.bias', None) 44 | 45 | for idx, (nis, nps, ois, ops) in enumerate(zip(new_input_size, new_patch_size, ori_input_size, ori_patch_size)): 46 | new_num_patches = (nis // nps) * (nis // nps) 47 | ori_num_patches = (ois // ops) * (ois // ops) 48 | 49 | if new_num_patches == ori_num_patches: 50 | continue 51 | else: 52 | print(f"Resize the pos embedding: num_pos in checkpoint: {ori_num_patches}, num_pos in model: {new_num_patches}", flush=True) 53 | if f'pos_embed.{idx}' in model[state_dict_name]: 54 | embed_name = f'pos_embed.{idx}' 55 | else: 56 | embed_name = 'pos_embed' 57 | ori_pos_embed = model[state_dict_name][embed_name] 58 | start_pos = ori_pos_embed.shape[1] - ori_num_patches 59 | cls_token_embed = ori_pos_embed[:, :start_pos, :] 60 | n_p = (nis // nps) 61 | o_p = (ois // ops) 62 | new_pos_embed = _convert_one_set(ori_pos_embed[:, start_pos:, :], n_p, o_p) 63 | out = torch.cat((cls_token_embed, new_pos_embed), dim=1) 64 | model[state_dict_name][embed_name] = out 65 | 66 | """ 67 | ori_pos_embed = ori_pos_embed[:, start_pos:, :] # remove cls 68 | ori_pos_embed = ori_pos_embed.squeeze(0).reshape(o_p, o_p, -1).cpu().numpy() 69 | zoom = (n_p / o_p, n_p / o_p, 1) 70 | new_pos_embed = scipy.ndimage.zoom(ori_pos_embed, zoom, order=1) 71 | new_pos_embed = torch.tensor(new_pos_embed).reshape((1, n_p * n_p, -1)) 72 | out = torch.cat((cls_token_embed, new_pos_embed), dim=1) 73 | print(f"{model[state_dict_name]['pos_embed'].shape}, {out.shape}") 74 | model[state_dict_name]['pos_embed'] = out 75 | """ 76 | return model 77 | 78 | 79 | def main(): 80 | args = parser.parse_args() 81 | 82 | model = convert(args.model, args.ori_input_size, args.ori_patch_size, args.new_input_size, args.new_patch_size, args.ema, args.remove_fc) 83 | name = os.path.join(os.path.dirname(args.model), os.path.basename(args.model).split(".")[0] + f'_{max(args.new_input_size)}' + ("_no_fc" if args.remove_fc else "") + '.pth.tar') 84 | print(f"Save model to {name}") 85 | torch.save(model, name, _use_new_zipfile_serialization=False) 86 | 87 | 88 | if __name__ == '__main__': 89 | main() 90 | 91 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the CC-by-NC license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | """ 8 | Misc functions, including distributed helpers. 9 | 10 | Mostly copy-paste from torchvision references. 11 | """ 12 | import io 13 | import os 14 | import time 15 | from collections import defaultdict, deque 16 | import datetime 17 | import tempfile 18 | 19 | import torch 20 | import torch.distributed as dist 21 | from fvcore.common.checkpoint import Checkpointer 22 | 23 | 24 | class SmoothedValue(object): 25 | """Track a series of values and provide access to smoothed values over a 26 | window or the global series average. 27 | """ 28 | 29 | def __init__(self, window_size=20, fmt=None): 30 | if fmt is None: 31 | fmt = "{median:.4f} ({global_avg:.4f})" 32 | self.deque = deque(maxlen=window_size) 33 | self.total = 0.0 34 | self.count = 0 35 | self.fmt = fmt 36 | 37 | def update(self, value, n=1): 38 | self.deque.append(value) 39 | self.count += n 40 | self.total += value * n 41 | 42 | def synchronize_between_processes(self): 43 | """ 44 | Warning: does not synchronize the deque! 45 | """ 46 | if not is_dist_avail_and_initialized(): 47 | return 48 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 49 | dist.barrier() 50 | dist.all_reduce(t) 51 | t = t.tolist() 52 | self.count = int(t[0]) 53 | self.total = t[1] 54 | 55 | @property 56 | def median(self): 57 | d = torch.tensor(list(self.deque)) 58 | return d.median().item() 59 | 60 | @property 61 | def avg(self): 62 | d = torch.tensor(list(self.deque), dtype=torch.float32) 63 | return d.mean().item() 64 | 65 | @property 66 | def global_avg(self): 67 | return self.total / self.count 68 | 69 | @property 70 | def max(self): 71 | return max(self.deque) 72 | 73 | @property 74 | def value(self): 75 | return self.deque[-1] 76 | 77 | def __str__(self): 78 | return self.fmt.format( 79 | median=self.median, 80 | avg=self.avg, 81 | global_avg=self.global_avg, 82 | max=self.max, 83 | value=self.value) 84 | 85 | 86 | class MetricLogger(object): 87 | def __init__(self, delimiter="\t"): 88 | self.meters = defaultdict(SmoothedValue) 89 | self.delimiter = delimiter 90 | 91 | def update(self, **kwargs): 92 | for k, v in kwargs.items(): 93 | if isinstance(v, torch.Tensor): 94 | v = v.item() 95 | assert isinstance(v, (float, int)) 96 | self.meters[k].update(v) 97 | 98 | def __getattr__(self, attr): 99 | if attr in self.meters: 100 | return self.meters[attr] 101 | if attr in self.__dict__: 102 | return self.__dict__[attr] 103 | raise AttributeError("'{}' object has no attribute '{}'".format( 104 | type(self).__name__, attr)) 105 | 106 | def __str__(self): 107 | loss_str = [] 108 | for name, meter in self.meters.items(): 109 | loss_str.append( 110 | "{}: {}".format(name, str(meter)) 111 | ) 112 | return self.delimiter.join(loss_str) 113 | 114 | def synchronize_between_processes(self): 115 | for meter in self.meters.values(): 116 | meter.synchronize_between_processes() 117 | 118 | def add_meter(self, name, meter): 119 | self.meters[name] = meter 120 | 121 | def log_every(self, iterable, print_freq, header=None): 122 | i = 0 123 | if not header: 124 | header = '' 125 | start_time = time.time() 126 | end = time.time() 127 | iter_time = SmoothedValue(fmt='{avg:.4f}') 128 | data_time = SmoothedValue(fmt='{avg:.4f}') 129 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 130 | log_msg = [ 131 | header, 132 | '[{0' + space_fmt + '}/{1}]', 133 | 'eta: {eta}', 134 | '{meters}', 135 | 
'time: {time}', 136 | 'data: {data}' 137 | ] 138 | if torch.cuda.is_available(): 139 | log_msg.append('max mem: {memory:.0f}') 140 | log_msg = self.delimiter.join(log_msg) 141 | MB = 1024.0 * 1024.0 142 | for obj in iterable: 143 | data_time.update(time.time() - end) 144 | yield obj 145 | iter_time.update(time.time() - end) 146 | if i % print_freq == 0 or i == len(iterable) - 1: 147 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 148 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 149 | if torch.cuda.is_available(): 150 | print(log_msg.format( 151 | i, len(iterable), eta=eta_string, 152 | meters=str(self), 153 | time=str(iter_time), data=str(data_time), 154 | memory=torch.cuda.max_memory_allocated() / MB)) 155 | else: 156 | print(log_msg.format( 157 | i, len(iterable), eta=eta_string, 158 | meters=str(self), 159 | time=str(iter_time), data=str(data_time))) 160 | i += 1 161 | end = time.time() 162 | total_time = time.time() - start_time 163 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 164 | print('{} Total time: {} ({:.4f} s / it)'.format( 165 | header, total_time_str, total_time / len(iterable))) 166 | 167 | 168 | def _load_checkpoint_for_ema(model_ema, checkpoint): 169 | """ 170 | Workaround for ModelEma._load_checkpoint to accept an already-loaded object 171 | """ 172 | mem_file = io.BytesIO() 173 | torch.save(checkpoint, mem_file) 174 | mem_file.seek(0) 175 | model_ema._load_checkpoint(mem_file) 176 | 177 | 178 | def load_checkpoint(model, state_dict, mode=None): 179 | 180 | # reuse Checkpointer in fvcore to support flexible loading 181 | ckpt = Checkpointer(model, save_to_disk=False) 182 | # since Checkpointer requires the weight to be put under `model` field, we need to save it to disk 183 | tmp_path = tempfile.NamedTemporaryFile('w+b') 184 | torch.save({'model': state_dict}, tmp_path.name) 185 | ckpt.load(tmp_path.name) 186 | 187 | def setup_for_distributed(is_master): 188 | """ 189 | This function disables printing when not in master process 190 | """ 191 | import builtins as __builtin__ 192 | builtin_print = __builtin__.print 193 | 194 | def print(*args, **kwargs): 195 | force = kwargs.pop('force', False) 196 | if is_master or force: 197 | builtin_print(*args, **kwargs) 198 | 199 | __builtin__.print = print 200 | 201 | 202 | def is_dist_avail_and_initialized(): 203 | if not dist.is_available(): 204 | return False 205 | if not dist.is_initialized(): 206 | return False 207 | return True 208 | 209 | 210 | def get_world_size(): 211 | if not is_dist_avail_and_initialized(): 212 | return 1 213 | return dist.get_world_size() 214 | 215 | 216 | def get_rank(): 217 | if not is_dist_avail_and_initialized(): 218 | return 0 219 | return dist.get_rank() 220 | 221 | 222 | def is_main_process(): 223 | return get_rank() == 0 224 | 225 | 226 | def save_on_master(*args, **kwargs): 227 | if is_main_process(): 228 | torch.save(*args, **kwargs) 229 | 230 | 231 | def init_distributed_mode(args): 232 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 233 | args.rank = int(os.environ["RANK"]) 234 | args.world_size = int(os.environ['WORLD_SIZE']) 235 | args.gpu = int(os.environ['LOCAL_RANK']) 236 | elif 'SLURM_PROCID' in os.environ: 237 | args.rank = int(os.environ['SLURM_PROCID']) 238 | args.gpu = args.rank % torch.cuda.device_count() 239 | else: 240 | print('Not using distributed mode') 241 | args.distributed = False 242 | return 243 | 244 | args.distributed = True 245 | 246 | torch.cuda.set_device(args.gpu) 247 | args.dist_backend = 'nccl' 248 | 
print('| distributed init (rank {}): {}'.format( 249 | args.rank, args.dist_url), flush=True) 250 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 251 | world_size=args.world_size, rank=args.rank) 252 | torch.distributed.barrier() 253 | setup_for_distributed(args.rank == 0) 254 | -------------------------------------------------------------------------------- /video_dataset_aug.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | from typing import Union, List, Tuple 3 | 4 | import torch 5 | import torch.nn.parallel 6 | import torch.optim 7 | import torch.utils.data 8 | import torch.utils.data.distributed 9 | import torchvision.transforms as transforms 10 | from video_transforms import (GroupRandomHorizontalFlip, GroupOverSample, 11 | GroupMultiScaleCrop, GroupScale, GroupCenterCrop, GroupRandomCrop, 12 | GroupNormalize, Stack, ToTorchFormatTensor, GroupRandomScale) 13 | 14 | def get_augmentor(is_train: bool, image_size: int, mean: List[float] = None, 15 | std: List[float] = None, disable_scaleup: bool = False, 16 | threed_data: bool = False, version: str = 'v1', scale_range: [int] = None, 17 | modality: str = 'rgb', num_clips: int = 1, num_crops: int = 1, dataset: str = ''): 18 | 19 | mean = [0.485, 0.456, 0.406] if mean is None else mean 20 | std = [0.229, 0.224, 0.225] if std is None else std 21 | scale_range = [256, 320] if scale_range is None else scale_range 22 | 23 | if modality == 'sound': 24 | augments = [ 25 | Stack(threed_data=threed_data), 26 | ToTorchFormatTensor(div=False, num_clips_crops=num_clips * num_crops) 27 | ] 28 | else: 29 | augments = [] 30 | if is_train: 31 | if version == 'v1': 32 | augments += [ 33 | GroupMultiScaleCrop(image_size, [1, .875, .75, .66]) 34 | ] 35 | elif version == 'v2': 36 | augments += [ 37 | GroupRandomScale(scale_range), 38 | GroupRandomCrop(image_size), 39 | ] 40 | if not (dataset.startswith('ststv') or 'jester' in dataset or 'mini_ststv' in dataset): 41 | augments += [GroupRandomHorizontalFlip(is_flow=(modality == 'flow'))] 42 | else: 43 | scaled_size = image_size if disable_scaleup else int(image_size / 0.875 + 0.5) 44 | if num_crops == 1: 45 | augments += [ 46 | GroupScale(scaled_size), 47 | GroupCenterCrop(image_size) 48 | ] 49 | else: 50 | flip = True if num_crops == 10 else False 51 | augments += [ 52 | GroupOverSample(image_size, scaled_size, num_crops=num_crops, flip=flip), 53 | ] 54 | augments += [ 55 | Stack(threed_data=threed_data), 56 | ToTorchFormatTensor(num_clips_crops=num_clips * num_crops), 57 | GroupNormalize(mean=mean, std=std, threed_data=threed_data) 58 | ] 59 | 60 | augmentor = transforms.Compose(augments) 61 | return augmentor 62 | 63 | 64 | def build_dataflow(dataset, is_train, batch_size, workers=36, is_distributed=False): 65 | workers = min(workers, multiprocessing.cpu_count()) 66 | shuffle = False 67 | 68 | sampler = torch.utils.data.distributed.DistributedSampler(dataset) if is_distributed else None 69 | if is_train: 70 | shuffle = sampler is None 71 | 72 | data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, 73 | num_workers=workers, pin_memory=True, sampler=sampler) 74 | 75 | return data_loader 76 | 77 | -------------------------------------------------------------------------------- /video_dataset_config.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | DATASET_CONFIG = { 4 | 'st2stv2': { 5 | 'num_classes': 174, 6 | 
'train_list_name': 'train.txt', 7 | 'val_list_name': 'val.txt', 8 | 'test_list_name': 'test.txt', 9 | 'filename_seperator': " ", 10 | 'image_tmpl': '{:05d}.jpg', 11 | 'filter_video': 3, 12 | 'label_file': 'categories.txt' 13 | }, 14 | 'mini_st2stv2': { 15 | 'num_classes': 87, 16 | 'train_list_name': 'mini_train.txt', 17 | 'val_list_name': 'mini_val.txt', 18 | 'test_list_name': 'mini_test.txt', 19 | 'filename_seperator': " ", 20 | 'image_tmpl': '{:05d}.jpg', 21 | 'filter_video': 3, 22 | }, 23 | 'st2stv1': { 24 | 'num_classes': 174, 25 | 'train_list_name': 'training_256.txt', 26 | 'val_list_name': 'validation_256.txt', 27 | 'test_list_name': 'testing_256.txt', 28 | 'filename_seperator': " ", 29 | 'image_tmpl': '{:05d}.jpg', 30 | 'filter_video': 3, 31 | 'label_file': 'something-something-v1-labels.csv' 32 | }, 33 | 'kinetics400': { 34 | 'num_classes': 400, 35 | 'train_list_name': 'train.txt', 36 | 'val_list_name': 'val.txt', 37 | 'test_list_name': 'test.txt', 38 | 'filename_seperator': ";", 39 | 'image_tmpl': '{:05d}.jpg', 40 | 'filter_video': 30, 41 | 'label_file': 'image/kinetics-400_label.csv' 42 | }, 43 | 'mini_kinetics400': { 44 | 'num_classes': 200, 45 | 'train_list_name': 'mini_train.txt', 46 | 'val_list_name': 'mini_val.txt', 47 | 'test_list_name': 'mini_test.txt', 48 | 'filename_seperator': ";", 49 | 'image_tmpl': '{:05d}.jpg', 50 | 'filter_video': 30 51 | }, 52 | 'charades': { 53 | 'num_classes': 157, 54 | 'train_list_name': 'train.txt', 55 | 'val_list_name': 'val.txt', 56 | 'filename_seperator': " ", 57 | 'image_tmpl': '{:06d}.jpg', 58 | 'filter_video': 0 59 | }, 60 | 'diva': { 61 | 'num_classes': 19, 62 | 'train_list_name': 'DIVA_GT_RGB_TSM_train.txt', 63 | 'val_list_name': 'DIVA_GT_RGB_TSM_validate.txt', 64 | 'filename_seperator': " ", 65 | 'image_tmpl': '{:08d}.jpg', 66 | 'filter_video': 0 67 | }, 68 | 'diva_pvi': { 69 | 'num_classes': 8, 70 | 'train_list_name': 'DIVA_PVI_GT_RGB_TSM_train.txt', 71 | 'val_list_name': 'DIVA_PVI_GT_RGB_TSM_validate.txt', 72 | 'filename_seperator': " ", 73 | 'image_tmpl': '{:08d}.jpg', 74 | 'filter_video': 0 75 | }, 76 | 'moments': { 77 | 'num_classes': 339, 78 | 'train_list_name': 'train.txt', 79 | 'val_list_name': 'val.txt', 80 | 'filename_seperator': " ", 81 | 'image_tmpl': '{:05d}.jpg', 82 | 'filter_video': 0 83 | }, 84 | 'mini_moments': { 85 | 'num_classes': 200, 86 | 'train_list_name': 'mini_train.txt', 87 | 'val_list_name': 'mini_val.txt', 88 | 'filename_seperator': " ", 89 | 'image_tmpl': '{:05d}.jpg', 90 | 'filter_video': 0 91 | }, 92 | 'ucf101': { 93 | 'num_classes': 101, 94 | 'train_list_name': 'train.txt', 95 | 'val_list_name': 'val.txt', 96 | 'filename_seperator': " ", 97 | 'image_tmpl': '{:05d}.jpg', 98 | 'filter_video': 0 99 | }, 100 | 'hmdb51': { 101 | 'num_classes': 51, 102 | 'train_list_name': 'train.txt', 103 | 'val_list_name': 'val.txt', 104 | 'filename_seperator': " ", 105 | 'image_tmpl': '{:05d}.jpg', 106 | 'filter_video': 0 107 | }, 108 | 'jester': { 109 | 'num_classes': 27, 110 | 'train_list_name': 'train.txt', 111 | 'val_list_name': 'val.txt', 112 | 'filename_seperator': " ", 113 | 'image_tmpl': '{:05d}.jpg', 114 | 'filter_video': 0 115 | }, 116 | } 117 | 118 | 119 | def get_dataset_config(dataset, use_lmdb=False): 120 | ret = DATASET_CONFIG[dataset] 121 | num_classes = ret['num_classes'] 122 | train_list_name = ret['train_list_name'].replace("txt", "lmdb") if use_lmdb \ 123 | else ret['train_list_name'] 124 | val_list_name = ret['val_list_name'].replace("txt", "lmdb") if use_lmdb \ 125 | else ret['val_list_name'] 126 | 
test_list_name = ret.get('test_lmdb_name', None) 127 | if test_list_name is not None: 128 | test_list_name = test_list_name.replace("txt", "lmdb") 129 | filename_seperator = ret['filename_seperator'] 130 | image_tmpl = ret['image_tmpl'] 131 | filter_video = ret.get('filter_video', 0) 132 | label_file = ret.get('label_file', None) 133 | 134 | return num_classes, train_list_name, val_list_name, test_list_name, filename_seperator, \ 135 | image_tmpl, filter_video, label_file 136 | -------------------------------------------------------------------------------- /video_transforms.py: -------------------------------------------------------------------------------- 1 | import torchvision 2 | import random 3 | from PIL import Image, ImageOps 4 | import numpy as np 5 | import numbers 6 | import math 7 | import torch 8 | 9 | class GroupRandomCrop(object): 10 | def __init__(self, size): 11 | if isinstance(size, numbers.Number): 12 | self.size = (int(size), int(size)) 13 | else: 14 | self.size = size 15 | 16 | def __call__(self, img_group): 17 | 18 | w, h = img_group[0].size 19 | th, tw = self.size 20 | 21 | out_images = list() 22 | 23 | x1 = random.randint(0, w - tw) 24 | y1 = random.randint(0, h - th) 25 | 26 | for img in img_group: 27 | assert(img.size[0] == w and img.size[1] == h) 28 | if w == tw and h == th: 29 | out_images.append(img) 30 | else: 31 | out_images.append(img.crop((x1, y1, x1 + tw, y1 + th))) 32 | 33 | return out_images 34 | 35 | 36 | class GroupCenterCrop(object): 37 | def __init__(self, size): 38 | self.worker = torchvision.transforms.CenterCrop(size) 39 | 40 | def __call__(self, img_group): 41 | return [self.worker(img) for img in img_group] 42 | 43 | 44 | class GroupRandomHorizontalFlip(object): 45 | """Randomly horizontally flips the given PIL.Image with a probability of 0.5 46 | """ 47 | def __init__(self, is_flow=False): 48 | self.is_flow = is_flow 49 | 50 | def __call__(self, img_group, is_flow=False): 51 | v = random.random() 52 | if v < 0.5: 53 | ret = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group] 54 | if self.is_flow: 55 | for i in range(0, len(ret), 2): 56 | ret[i] = ImageOps.invert(ret[i]) # invert flow pixel values when flipping 57 | return ret 58 | else: 59 | return img_group 60 | 61 | 62 | class GroupNormalize(object): 63 | def __init__(self, mean, std, threed_data=False): 64 | self.threed_data = threed_data 65 | if self.threed_data: 66 | # convert to the proper format 67 | self.mean = torch.FloatTensor(mean).view(len(mean), 1, 1, 1) 68 | self.std = torch.FloatTensor(std).view(len(std), 1, 1, 1) 69 | else: 70 | self.mean = mean 71 | self.std = std 72 | 73 | def __call__(self, tensor): 74 | 75 | if self.threed_data: 76 | tensor.sub_(self.mean).div_(self.std) 77 | else: 78 | rep_mean = self.mean * (tensor.size()[0] // len(self.mean)) 79 | rep_std = self.std * (tensor.size()[0] // len(self.std)) 80 | 81 | # TODO: make efficient 82 | for t, m, s in zip(tensor, rep_mean, rep_std): 83 | t.sub_(m).div_(s) 84 | 85 | return tensor 86 | 87 | 88 | class GroupScale(object): 89 | """ Rescales the input PIL.Image to the given 'size'. 90 | 'size' will be the size of the smaller edge. 
91 | For example, if height > width, then image will be 92 | rescaled to (size * height / width, size) 93 | size: size of the smaller edge 94 | interpolation: Default: PIL.Image.BILINEAR 95 | """ 96 | 97 | def __init__(self, size, interpolation=Image.BILINEAR): 98 | self.worker = torchvision.transforms.Resize(size, interpolation) 99 | 100 | def __call__(self, img_group): 101 | return [self.worker(img) for img in img_group] 102 | 103 | class GroupRandomScale(object): 104 | """ Rescales the input PIL.Image to the given 'size'. 105 | 'size' will be the size of the smaller edge. 106 | For example, if height > width, then image will be 107 | rescaled to (size * height / width, size) 108 | size: size of the smaller edge 109 | interpolation: Default: PIL.Image.BILINEAR 110 | 111 | Randomly select the smaller edge from the range of 'size'. 112 | """ 113 | def __init__(self, size, interpolation=Image.BILINEAR): 114 | self.size = size 115 | self.interpolation = interpolation 116 | 117 | def __call__(self, img_group): 118 | selected_size = np.random.randint(low=self.size[0], high=self.size[1] + 1, dtype=int) 119 | scale = GroupScale(selected_size, interpolation=self.interpolation) 120 | return scale(img_group) 121 | 122 | class GroupOverSample(object): 123 | def __init__(self, crop_size, scale_size=None, num_crops=5, flip=False): 124 | self.crop_size = crop_size if not isinstance(crop_size, int) else (crop_size, crop_size) 125 | 126 | if scale_size is not None: 127 | self.scale_worker = GroupScale(scale_size) 128 | else: 129 | self.scale_worker = None 130 | 131 | if num_crops not in [1, 3, 5, 10]: 132 | raise ValueError("num_crops should be in [1, 3, 5, 10] but ({})".format(num_crops)) 133 | self.num_crops = num_crops 134 | 135 | self.flip = flip 136 | 137 | def __call__(self, img_group): 138 | 139 | if self.scale_worker is not None: 140 | img_group = self.scale_worker(img_group) 141 | 142 | image_w, image_h = img_group[0].size 143 | crop_w, crop_h = self.crop_size 144 | 145 | if self.num_crops == 3: 146 | w_step = (image_w - crop_w) // 4 147 | h_step = (image_h - crop_h) // 4 148 | offsets = list() 149 | if image_w != crop_w and image_h != crop_h: 150 | offsets.append((0 * w_step, 0 * h_step)) # top 151 | offsets.append((4 * w_step, 4 * h_step)) # bottom 152 | offsets.append((2 * w_step, 2 * h_step)) # center 153 | else: 154 | if image_w < image_h: 155 | offsets.append((2 * w_step, 0 * h_step)) # top 156 | offsets.append((2 * w_step, 4 * h_step)) # bottom 157 | offsets.append((2 * w_step, 2 * h_step)) # center 158 | else: 159 | offsets.append((0 * w_step, 2 * h_step)) # left 160 | offsets.append((4 * w_step, 2 * h_step)) # right 161 | offsets.append((2 * w_step, 2 * h_step)) # center 162 | 163 | else: 164 | offsets = GroupMultiScaleCrop.fill_fix_offset(False, image_w, image_h, crop_w, crop_h) 165 | 166 | oversample_group = list() 167 | for o_w, o_h in offsets: 168 | normal_group = list() 169 | flip_group = list() 170 | for i, img in enumerate(img_group): 171 | crop = img.crop((o_w, o_h, o_w + crop_w, o_h + crop_h)) 172 | normal_group.append(crop) 173 | if self.flip: 174 | flip_crop = crop.copy().transpose(Image.FLIP_LEFT_RIGHT) 175 | 176 | if img.mode == 'L' and i % 2 == 0: 177 | flip_group.append(ImageOps.invert(flip_crop)) 178 | else: 179 | flip_group.append(flip_crop) 180 | 181 | oversample_group.extend(normal_group) 182 | if self.flip: 183 | oversample_group.extend(flip_group) 184 | return oversample_group 185 | 186 | 187 | class GroupMultiScaleCrop(object): 188 | 189 | def __init__(self, 
input_size, scales=None, max_distort=1, fix_crop=True, more_fix_crop=True): 190 | self.scales = scales if scales is not None else [1, .875, .75, .66] 191 | self.max_distort = max_distort 192 | self.fix_crop = fix_crop 193 | self.more_fix_crop = more_fix_crop 194 | self.input_size = input_size if not isinstance(input_size, int) else [input_size, input_size] 195 | self.interpolation = Image.BILINEAR 196 | 197 | def __call__(self, img_group): 198 | 199 | im_size = img_group[0].size 200 | 201 | crop_w, crop_h, offset_w, offset_h = self._sample_crop_size(im_size) 202 | crop_img_group = [img.crop((offset_w, offset_h, offset_w + crop_w, offset_h + crop_h)) for img in img_group] 203 | ret_img_group = [img.resize((self.input_size[0], self.input_size[1]), self.interpolation) 204 | for img in crop_img_group] 205 | return ret_img_group 206 | 207 | def _sample_crop_size(self, im_size): 208 | image_w, image_h = im_size[0], im_size[1] 209 | 210 | # find a crop size 211 | base_size = min(image_w, image_h) 212 | crop_sizes = [int(base_size * x) for x in self.scales] 213 | crop_h = [self.input_size[1] if abs(x - self.input_size[1]) < 3 else x for x in crop_sizes] 214 | crop_w = [self.input_size[0] if abs(x - self.input_size[0]) < 3 else x for x in crop_sizes] 215 | 216 | pairs = [] 217 | for i, h in enumerate(crop_h): 218 | for j, w in enumerate(crop_w): 219 | if abs(i - j) <= self.max_distort: 220 | pairs.append((w, h)) 221 | 222 | crop_pair = random.choice(pairs) 223 | if not self.fix_crop: 224 | w_offset = random.randint(0, image_w - crop_pair[0]) 225 | h_offset = random.randint(0, image_h - crop_pair[1]) 226 | else: 227 | w_offset, h_offset = self._sample_fix_offset(image_w, image_h, crop_pair[0], crop_pair[1]) 228 | 229 | return crop_pair[0], crop_pair[1], w_offset, h_offset 230 | 231 | def _sample_fix_offset(self, image_w, image_h, crop_w, crop_h): 232 | offsets = self.fill_fix_offset(self.more_fix_crop, image_w, image_h, crop_w, crop_h) 233 | return random.choice(offsets) 234 | 235 | @staticmethod 236 | def fill_fix_offset(more_fix_crop, image_w, image_h, crop_w, crop_h): 237 | w_step = (image_w - crop_w) // 4 238 | h_step = (image_h - crop_h) // 4 239 | 240 | ret = list() 241 | ret.append((0, 0)) # upper left 242 | ret.append((4 * w_step, 0)) # upper right 243 | ret.append((0, 4 * h_step)) # lower left 244 | ret.append((4 * w_step, 4 * h_step)) # lower right 245 | ret.append((2 * w_step, 2 * h_step)) # center 246 | 247 | if more_fix_crop: 248 | ret.append((0, 2 * h_step)) # center left 249 | ret.append((4 * w_step, 2 * h_step)) # center right 250 | ret.append((2 * w_step, 4 * h_step)) # lower center 251 | ret.append((2 * w_step, 0 * h_step)) # upper center 252 | 253 | ret.append((1 * w_step, 1 * h_step)) # upper left quarter 254 | ret.append((3 * w_step, 1 * h_step)) # upper right quarter 255 | ret.append((1 * w_step, 3 * h_step)) # lower left quarter 256 | ret.append((3 * w_step, 3 * h_step)) # lower right quarter 257 | 258 | return ret 259 | 260 | 261 | class GroupRandomSizedCrop(object): 262 | """Random crop the given PIL.Image to a random size of (0.08 to 1.0) of the original size 263 | and a random aspect ratio of 3/4 to 4/3 of the original aspect ratio 264 | This is popularly used to train the Inception networks 265 | size: size of the smaller edge 266 | interpolation: Default: PIL.Image.BILINEAR 267 | """ 268 | def __init__(self, size, interpolation=Image.BILINEAR): 269 | self.size = size 270 | self.interpolation = interpolation 271 | 272 | def __call__(self, img_group): 273 | for attempt
in range(10): 274 | area = img_group[0].size[0] * img_group[0].size[1] 275 | target_area = random.uniform(0.08, 1.0) * area 276 | aspect_ratio = random.uniform(3. / 4, 4. / 3) 277 | 278 | w = int(round(math.sqrt(target_area * aspect_ratio))) 279 | h = int(round(math.sqrt(target_area / aspect_ratio))) 280 | 281 | if random.random() < 0.5: 282 | w, h = h, w 283 | 284 | if w <= img_group[0].size[0] and h <= img_group[0].size[1]: 285 | x1 = random.randint(0, img_group[0].size[0] - w) 286 | y1 = random.randint(0, img_group[0].size[1] - h) 287 | found = True 288 | break 289 | else: 290 | found = False 291 | x1 = 0 292 | y1 = 0 293 | 294 | if found: 295 | out_group = list() 296 | for img in img_group: 297 | img = img.crop((x1, y1, x1 + w, y1 + h)) 298 | assert(img.size == (w, h)) 299 | out_group.append(img.resize((self.size, self.size), self.interpolation)) 300 | return out_group 301 | else: 302 | # Fallback 303 | scale = GroupScale(self.size, interpolation=self.interpolation) 304 | crop = GroupRandomCrop(self.size) 305 | return crop(scale(img_group)) 306 | 307 | 308 | class Stack(object): 309 | 310 | def __init__(self, roll=False, threed_data=False): 311 | self.roll = roll 312 | self.threed_data = threed_data 313 | 314 | def __call__(self, img_group): 315 | if img_group[0].mode == 'L': 316 | return np.concatenate([np.expand_dims(x, 2) for x in img_group], axis=2) 317 | elif img_group[0].mode == 'RGB': 318 | if self.threed_data: 319 | return np.stack(img_group, axis=0) 320 | else: 321 | if self.roll: 322 | return np.concatenate([np.array(x)[:, :, ::-1] for x in img_group], axis=2) 323 | else: 324 | return np.concatenate(img_group, axis=2) 325 | 326 | 327 | class ToTorchFormatTensor(object): 328 | """ Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255] 329 | to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """ 330 | def __init__(self, div=True, num_clips_crops=1): 331 | self.div = div 332 | self.num_clips_crops = num_clips_crops 333 | 334 | def __call__(self, pic): 335 | if isinstance(pic, np.ndarray): 336 | # handle numpy array 337 | if len(pic.shape) == 4: 338 | # ((NF)xCxHxW) --> (Cx(NF)xHxW) 339 | img = torch.from_numpy(pic).permute(3, 0, 1, 2).contiguous() 340 | else: # data is HW(FC) 341 | img = torch.from_numpy(pic).permute(2, 0, 1).contiguous() 342 | else: 343 | # handle PIL Image 344 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) 345 | img = img.view(pic.size[1], pic.size[0], len(pic.mode)) 346 | # put it from HWC to CHW format 347 | # yikes, this transpose takes 80% of the loading time/CPU 348 | img = img.transpose(0, 1).transpose(0, 2).contiguous() 349 | return img.float().div(255) if self.div else img.float() 350 | 351 | 352 | class IdentityTransform(object): 353 | 354 | def __call__(self, data): 355 | return data 356 | 357 | 358 | if __name__ == "__main__": 359 | trans = torchvision.transforms.Compose([ 360 | GroupScale(256), 361 | GroupRandomCrop(224), 362 | Stack(), 363 | ToTorchFormatTensor(), 364 | GroupNormalize( 365 | mean=[.485, .456, .406], 366 | std=[.229, .224, .225] 367 | )] 368 | ) 369 | 370 | im = Image.open('../tensorflow-model-zoo.torch/lena_299.png') 371 | 372 | color_group = [im] * 3 373 | rst = trans(color_group) 374 | 375 | gray_group = [im.convert('L')] * 9 376 | gray_rst = trans(gray_group) 377 | 378 | trans2 = torchvision.transforms.Compose([ 379 | GroupRandomSizedCrop(256), 380 | Stack(), 381 | ToTorchFormatTensor(), 382 | GroupNormalize( 383 | mean=[.485, .456, .406], 384 | std=[.229, .224, 
.225]) 385 | ]) 386 | print(trans2(color_group)) 387 | --------------------------------------------------------------------------------
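The short sketch below is not a file from the repository; it is an illustrative example, assuming a toy linear encoder and random inputs in place of a real SIFAR backbone and augmented video clips, of how the two-view [bsz, n_views, feature_dim] packing expected by SupConLoss in simclr.py is typically produced. NT_Xent and NTXent consume the same packing but call into torch.distributed, so they additionally require an initialized process group.

import torch
import torch.nn as nn
import torch.nn.functional as F

from simclr import SupConLoss  # defined in simclr.py above

# Toy stand-in encoder; in this repo it would be one of the SIFAR models in my_models/.
encoder = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 128))

# Two augmented views of the same batch of 8 samples. In training these would come from
# TwoCropTransform (simclr.py) applied on top of the video augmentation pipeline; random
# tensors stand in for them here.
view1 = torch.randn(8, 3, 32, 32)
view2 = torch.randn(8, 3, 32, 32)

# L2-normalize each view's embedding and stack into [bsz, n_views, feature_dim].
z1 = F.normalize(encoder(view1), dim=1)
z2 = F.normalize(encoder(view2), dim=1)
features = torch.stack([z1, z2], dim=1)  # shape: [8, 2, 128]

criterion = SupConLoss(temperature=0.07)
unsup_loss = criterion(features)                                   # SimCLR-style, no labels
sup_loss = criterion(features, labels=torch.randint(0, 10, (8,)))  # supervised contrastive
print(unsup_loss.item(), sup_loss.item())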