├── LICENSE ├── README.md ├── models ├── __init__.py ├── adamml.py ├── common.py ├── joint_resnet_mobilenetv2.py ├── model_builder.py ├── policy_net.py ├── resnet.py └── sound_mobilenet_v2.py ├── opts.py ├── tools ├── extract_audio.py └── extract_rgb.py ├── train_adamml.py ├── train_unimodal.py └── utils ├── dataset_config.py ├── utils.py ├── video_dataset.py └── video_transforms.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AdaMML: Adaptive Multi-Modal Learning for Efficient Video Recognition [[ArXiv]](https://arxiv.org/pdf/2105.05165.pdf) [[Project Page]](https://rpand002.github.io/adamml.html) 2 | 3 | This repository is the official implementation of AdaMML: Adaptive Multi-Modal Learning for Efficient Video Recognition. 
4 | 
5 | Rameswar Panda*, Chun-Fu (Richard) Chen*, Quanfu Fan, Ximeng Sun, Kate Saenko, Aude Oliva, Rogerio Feris, "AdaMML: Adaptive Multi-Modal Learning for Efficient Video Recognition", ICCV 2021. (*: Equal Contribution)
6 | 
7 | If you use the codes and models from this repo, please cite our work. Thanks!
8 | 
9 | ```
10 | @inproceedings{panda2021adamml,
11 | title={{AdaMML: Adaptive Multi-Modal Learning for Efficient Video Recognition}},
12 | author={Panda, Rameswar and Chen, Chun-Fu and Fan, Quanfu and Sun, Ximeng and Saenko, Kate and Oliva, Aude and Feris, Rogerio},
13 | booktitle={International Conference on Computer Vision (ICCV)},
14 | year={2021}
15 | }
16 | ```
17 | ## Requirements
18 | 
19 | ```
20 | pip3 install torch torchvision librosa tqdm Pillow numpy
21 | ```
22 | ## Data Preparation
23 | The dataloader (utils/video_dataset.py) can load RGB frames stored in the following format:
24 | ```
25 | -- dataset_dir
26 | ---- train.txt
27 | ---- val.txt
28 | ---- test.txt
29 | ---- videos
30 | ------ video_0_folder
31 | -------- 00001.jpg
32 | -------- 00002.jpg
33 | -------- ...
34 | ------ video_1_folder
35 | ------ ...
36 | ```
37 | 
38 | Each line in `train.txt` and `val.txt` includes 4 elements separated by a symbol, e.g. a space (` `) or a semicolon (`;`).
39 | The four elements are, in order: (1) the relative path to `video_x_folder` from `dataset_dir`, (2) the starting frame number (usually 1), (3) the ending frame number, and (4) the label id (a numeric value).
40 | 
41 | E.g., a `video_x` that has `300` frames and belongs to label `1`:
42 | ```
43 | path/to/video_x_folder 1 300 1
44 | ```
45 | The difference for `test.txt` is that each line only has 3 elements (no label information).
46 | 
47 | The same format is used for `optical flow`, but each frame (`00001.jpg`) needs to be stored as `x_00001.jpg` and `y_00001.jpg`.
48 | 
49 | For audio data, you need to change the first element to the path of the corresponding `wav` file, like
50 | 
51 | ```
52 | path/to/audio_x.wav 1 300 1
53 | ```
54 | 
55 | After that, you need to update `utils/dataset_config.py` for the datasets accordingly.
56 | 
57 | We provide scripts in the `tools` folder to extract RGB frames and audio from a video. To extract the optical flow, we use the docker image provided by [TSN](https://hub.docker.com/r/bitxiong/tsn/). Please see the help in the scripts.
58 | 
59 | ## Pretrained models
60 | 
61 | We provide the pretrained models on the Kinetics-Sounds dataset, including the unimodality models and our AdaMML models. You can find all the models [here](https://github.com/IBM/AdaMML/releases/tag/weights-v0.1).
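For reference, here is a minimal sketch (not part of the repo) of how a downloaded checkpoint can be inspected before it is passed to `--unimodality_pretrained` below; it assumes the files store their weights under a `state_dict` key, possibly with `module.` prefixes, which is how `models/joint_resnet_mobilenetv2.py` reads them:

```python
import torch

# '/PATH/TO/RGB_MODEL' is a placeholder for one of the downloaded checkpoints
ckpt = torch.load('/PATH/TO/RGB_MODEL', map_location='cpu')
state_dict = ckpt['state_dict']
# strip the DataParallel/DistributedDataParallel prefix, if present
state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
print(len(state_dict), sorted(state_dict.keys())[:5])
```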
62 | 
63 | ## Training AdaMML Models
64 | 
65 | After downloading the unimodality pretrained models (see below for training instructions), here is the command template to train AdaMML:
66 | 
67 | ```shell script
68 | python3 train_adamml.py --multiprocessing-distributed --backbone_net adamml -d 50 \
69 | --groups 8 --frames_per_group 4 -b 72 -j 96 --epochs 20 --warmup_epochs 5 --finetune_epochs 10 \
70 | --modality MODALITY1 MODALITY2 --datadir /PATH/TO/MODALITY1 /PATH/TO/MODALITY2 --dataset DATASET --logdir LOGDIR \
71 | --dense_sampling --fusion_point logits --unimodality_pretrained /PATH/TO/MODEL_MODALITY1 /PATH/TO/MODEL_MODALITY2 \
72 | --learnable_lf_weights --num_segments 5 --cost_weights 1.0 0.005 --causality_modeling lstm --gammas 10.0 --sync-bn \
73 | --lr 0.001 --p_lr 0.01 --lr_scheduler multisteps --lr_steps 10 15
74 | ```
75 | 
76 | The number of values for each of the following arguments depends on how many modalities you would like to include in AdaMML:
77 | - `--modality`: the modalities; the other arguments below need to follow this order
78 | - `--datadir`: the data directory for each modality
79 | - `--unimodality_pretrained`: the pretrained unimodality models
80 | 
81 | Note that, to use `rgbdiff` as a proxy, both `rgbdiff` and `flow` need to be specified in `--modality`, along with their corresponding `--datadir`.
82 | However, you only need to provide the `flow` pretrained model in `--unimodality_pretrained`.
83 | 
84 | Here are examples of training AdaMML with different modality combinations.
85 | 
86 | RGB + Audio
87 | 
88 | ```shell script
89 | python3 train_adamml.py --multiprocessing-distributed --backbone_net adamml -d 50 \
90 | --groups 8 --frames_per_group 4 -b 72 -j 96 --epochs 20 --warmup_epochs 5 --finetune_epochs 10 \
91 | --modality rgb sound --datadir /PATH/TO/RGB_DATA /PATH/TO/AUDIO_DATA --dataset DATASET --logdir LOGDIR \
92 | --dense_sampling --fusion_point logits --unimodality_pretrained /PATH/TO/RGB_MODEL /PATH/TO/AUDIO_MODEL \
93 | --learnable_lf_weights --num_segments 5 --cost_weights 1.0 0.05 --causality_modeling lstm --gammas 10.0 --sync-bn \
94 | --lr 0.001 --p_lr 0.01 --lr_scheduler multisteps --lr_steps 10 15
95 | ```
96 | 
97 | RGB + Flow (with RGBDiff as Proxy)
98 | 
99 | ```shell script
100 | python3 train_adamml.py --multiprocessing-distributed --backbone_net adamml -d 50 \
101 | --groups 8 --frames_per_group 4 -b 72 -j 96 --epochs 20 --warmup_epochs 5 --finetune_epochs 10 \
102 | --modality rgb flow rgbdiff --datadir /PATH/TO/RGB_DATA /PATH/TO/FLOW_DATA /PATH/TO/RGB_DATA --dataset DATASET --logdir LOGDIR \
103 | --dense_sampling --fusion_point logits --unimodality_pretrained /PATH/TO/RGB_MODEL /PATH/TO/FLOW_MODEL \
104 | --learnable_lf_weights --num_segments 5 --cost_weights 1.0 1.0 --causality_modeling lstm --gammas 10.0 --sync-bn \
105 | --lr 0.001 --p_lr 0.01 --lr_scheduler multisteps --lr_steps 10 15
106 | ```
107 | 
108 | RGB + Audio + Flow (with RGBDiff as Proxy)
109 | 
110 | ```shell script
111 | python3 train_adamml.py --multiprocessing-distributed --backbone_net adamml -d 50 \
112 | --groups 8 --frames_per_group 4 -b 72 -j 96 --epochs 20 --warmup_epochs 5 --finetune_epochs 10 \
113 | --modality rgb sound flow rgbdiff --datadir /PATH/TO/RGB_DATA /PATH/TO/AUDIO_DATA /PATH/TO/FLOW_DATA /PATH/TO/RGB_DATA --dataset DATASET --logdir LOGDIR \
114 | --dense_sampling --fusion_point logits --unimodality_pretrained /PATH/TO/RGB_MODEL /PATH/TO/SOUND_MODEL /PATH/TO/FLOW_MODEL \
115 | --learnable_lf_weights --num_segments 5 --cost_weights 0.5 0.05 0.8 --causality_modeling lstm --gammas 10.0 --sync-bn \
116 | --lr 0.001 --p_lr 0.01 --lr_scheduler multisteps --lr_steps 10 15
117 | ```
118 | 
119 | ## Training Unimodal Models
120 | 
121 | Here are example commands to train the unimodal models on different datasets:
122 | 
123 | RGB
124 | 
125 | ```shell script
126 | python3 train_unimodal.py --multiprocessing-distributed --backbone_net resnet -d 50 \
127 | --groups 8 --frames_per_group 4 -b 72 -j 96 --epochs 60 --modality rgb \
128 | --datadir /PATH/TO/RGB_DATA --dataset DATASET --logdir LOGDIR \
129 | --dense_sampling --wd 0.0001 --augmentor_ver v2 --lr_scheduler multisteps --lr_steps 20 40 50
130 | ```
131 | 
132 | Flow
133 | 
134 | ```shell script
135 | python3 train_unimodal.py --multiprocessing-distributed --backbone_net resnet -d 50 \
136 | --groups 8 --frames_per_group 4 -b 72 -j 96 --epochs 60 --modality flow \
137 | --datadir /PATH/TO/FLOW_DATA --dataset DATASET --logdir LOGDIR \
138 | --dense_sampling --wd 0.0001 --augmentor_ver v2 --lr_scheduler multisteps --lr_steps 20 40 50
139 | ```
140 | 
141 | Audio
142 | 
143 | ```shell script
144 | python3 train_unimodal.py --multiprocessing-distributed --backbone_net sound_mobilenet_v2 \
145 | -b 72 -j 96 --epochs 60 --modality sound --wd 0.0001 --lr_scheduler multisteps --lr_steps 20 40 50 \
146 | --datadir /PATH/TO/AUDIO_DATA --dataset DATASET --logdir LOGDIR
147 | ```
148 | 
149 | 
150 | ## Evaluation
151 | 
152 | Testing an AdaMML model is straightforward; simply use the training command with the following modifications:
153 | - add `-e` to the command
154 | - use `--pretrained /PATH/TO/MODEL` to load the trained model
155 | - remove `--multiprocessing-distributed` and `--unimodality_pretrained`
156 | - set `--val_num_clips` if you would like to test with a different number of video segments (default is 10)
157 | 
158 | Here is the command template:
159 | 
160 | ```shell script
161 | python3 train_adamml.py -e --backbone_net adamml -d 50 \
162 | --groups 8 --frames_per_group 4 -b 72 -j 96 \
163 | --modality MODALITY1 MODALITY2 --datadir /PATH/TO/MODALITY1 /PATH/TO/MODALITY2 --dataset DATASET --logdir LOGDIR \
164 | --dense_sampling --fusion_point logits --pretrained /PATH/TO/ADAMML_MODEL \
165 | --learnable_lf_weights --num_segments 5 --causality_modeling lstm --sync-bn
166 | ```
167 | 
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from models.adamml import adamml
2 | from models.resnet import resnet
3 | from models.sound_mobilenet_v2 import sound_mobilenet_v2
4 | from .model_builder import build_model
5 | 
6 | __all__ = [
7 |     'adamml',
8 |     'resnet',
9 |     'sound_mobilenet_v2',
10 |     'build_model'
11 | ]
12 | 
--------------------------------------------------------------------------------
/models/adamml.py:
--------------------------------------------------------------------------------
1 | 
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from models.policy_net import p_joint_mobilenet
7 | from models.joint_resnet_mobilenetv2 import joint_resnet_mobilenetv2
8 | 
9 | __all__ = ['adamml']
10 | 
11 | 
12 | class AdaMML(nn.Module):
13 | 
14 |     def __init__(self, policy_net, main_net, num_frames, num_segments, modality, rng_policy, rng_threshold, num_classes):
15 |         super().__init__()
16 |         self.rng_policy = rng_policy
17 |         self.policy_net = policy_net
18 |         self.main_net = main_net
19 |         self.num_segments = num_segments
20 |         self.num_frames = num_frames *
num_segments 21 | self.num_frames_per_segment = num_frames 22 | self.modality = modality 23 | 24 | if 'rgbdiff' in modality and 'flow' in modality: 25 | self.num_modality = len(modality) - 1 26 | else: 27 | self.num_modality = len(modality) 28 | self.p_data_idx = [self.modality.index(x) for x in self.policy_net.modality] 29 | self.m_data_idx = [self.modality.index(x) for x in self.main_net.modality] 30 | 31 | self.rng_threshold = rng_threshold 32 | 33 | self.decay_ratio = 0.965 34 | 35 | self.update_policy_net = True 36 | self.update_main_net = True 37 | 38 | if self.rng_policy: 39 | self.freeze_policy_net() 40 | del self.policy_net.fcs 41 | 42 | def data_layer(self, x, num_segments, p_rgb_size=(160, 160)): 43 | p_x, m_x = [], [] 44 | idx = 0 45 | for x_, m in zip(x, self.modality): 46 | if m == 'sound': 47 | # when getting consecutive segments for sound, the signals are stacked at the last dim 48 | # however, if getting 10 clips, the signals are stacked at the second dim 49 | if x_.size(-1) != x_.size(-2): # in training 50 | tmp = x_.chunk(num_segments, dim=-1) 51 | tmp = torch.stack(tmp, dim=0).contiguous() 52 | else: 53 | tmp = x_.view((x_.size(0), num_segments, -1,) + x_.shape[-2:]).transpose(0, 1).contiguous() 54 | p_x.append(tmp) 55 | m_x.append(tmp) 56 | else: # only subsampling non-sound 57 | if idx in self.p_data_idx: 58 | b, s_f_c, h, w = x_.shape 59 | tmp = F.interpolate(x_, size=p_rgb_size, mode='bilinear') 60 | tmp = tmp.view((b, num_segments, self.num_frames_per_segment, -1, ) + p_rgb_size) 61 | tmp = tmp[:, :, range(0, self.num_frames_per_segment, 2), ...] 62 | tmp = tmp.view((b, num_segments, -1,) + p_rgb_size).transpose(0, 1).contiguous() 63 | p_x.append(tmp) 64 | if idx in self.m_data_idx: 65 | m_x.append(x_.view((x_.size(0), num_segments, -1,) + x_.shape[-2:]).transpose(0, 1).contiguous()) 66 | idx += 1 67 | return p_x, m_x, num_segments 68 | 69 | def forward(self, x, num_segments=None): 70 | # x: [Nx(SFC)xHxW], N is batch size, S segment/clip, F frames per clip, C channels, length of list is M 71 | # [SxNxFCxHxW], S normal input of networks, conversion for all modalities, length of list is M 72 | num_segments = num_segments if num_segments else self.num_segments 73 | p_x, m_x, num_segments = self.data_layer(x, num_segments) 74 | if not self.rng_policy: 75 | decisions, decision_logits = self.policy_net(p_x) 76 | else: 77 | decisions = (torch.rand((num_segments, self.num_modality, x[0].size(0)), 78 | dtype=x[0].dtype, device=x[0].device) > self.rng_threshold).float() 79 | # SxMxN tensors, M is number modality and N is batch size, each element is 0 or 1 80 | 81 | # in main net, run each segment one by one to save memory and 82 | # use decision to mask out the output, but still run whole main network 83 | all_logits = [] # NxSxC_class 84 | for i in range(num_segments): 85 | tmp_x = ([m_x[m_i][i, ...] 
for m_i in range(self.num_modality)]) 86 | all_logits.append(self.main_net(tmp_x, decisions[i])) # NxC 87 | 88 | final_logits = torch.stack(all_logits, dim=1).mean(dim=1) 89 | # reshape, let batch as the first index 90 | decisions = decisions.permute((2, 0, 1)) 91 | return final_logits, decisions 92 | 93 | def mean(self, modality='rgb'): 94 | return [0.485, 0.456, 0.406] if modality == 'rgb' or modality == 'rgbdiff' \ 95 | else [0.5] 96 | 97 | def std(self, modality='rgb'): 98 | return [0.229, 0.224, 0.225] if modality == 'rgb' or modality == 'rgbdiff' \ 99 | else [np.mean([0.229, 0.224, 0.225])] 100 | 101 | @property 102 | def network_name(self): 103 | name = 'adamml' 104 | if self.rng_policy: 105 | name += '-rng-{:.1f}'.format(self.rng_threshold) 106 | else: 107 | name += '-{}'.format(self.policy_net.network_name) 108 | name += '-{}'.format(self.main_net.network_name) 109 | return name 110 | 111 | def decay_temperature(self, decay_ratio=None): 112 | self.policy_net.decay_temperature(decay_ratio if decay_ratio else self.decay_ratio) 113 | 114 | def freeze_policy_net(self): 115 | self.update_policy_net = False 116 | for param in self.policy_net.parameters(): 117 | param.requires_grad = False 118 | 119 | def unfreeze_policy_net(self): 120 | self.update_policy_net = True 121 | for param in self.policy_net.parameters(): 122 | param.requires_grad = True 123 | 124 | def freeze_main_net(self): 125 | self.update_main_net = False 126 | for param in self.main_net.parameters(): 127 | param.requires_grad = False 128 | 129 | def unfreeze_main_net(self): 130 | self.update_main_net = True 131 | for param in self.main_net.parameters(): 132 | param.requires_grad = True 133 | 134 | def adamml( 135 | # shared parameters 136 | groups, modality, input_channels, num_segments, rng_policy, rng_threshold, 137 | # policy net parameters 138 | causality_modeling, 139 | # main net parameters 140 | num_classes, depth, without_t_stride, dropout, pooling_method, fusion_point, 141 | unimodality_pretrained, learnable_lf_weights, **kwargs): 142 | 143 | if 'rgbdiff' in modality and 'flow' in modality: 144 | p_modality = [x for x in modality if x != 'flow'] 145 | m_modality = [x for x in modality if x != 'rgbdiff'] 146 | p_input_channels = [x for x, m in zip(input_channels, modality) if m != 'flow'] 147 | m_input_channels = [x for x, m in zip(input_channels, modality) if m != 'rgbdiff'] 148 | else: 149 | p_modality = modality 150 | m_modality = modality 151 | p_input_channels = input_channels 152 | m_input_channels = input_channels 153 | 154 | # policy net 155 | policy_net = p_joint_mobilenet(num_frames=max(1, groups // 2), modality=p_modality, 156 | input_channels=p_input_channels, causality_modeling=causality_modeling) 157 | 158 | main_net = joint_resnet_mobilenetv2(depth=depth, num_classes=num_classes, 159 | without_t_stride=without_t_stride, 160 | groups=groups, dropout=dropout, 161 | pooling_method=pooling_method, 162 | input_channels=m_input_channels, 163 | fusion_point=fusion_point, modality=m_modality, 164 | unimodality_pretrained=unimodality_pretrained, 165 | learnable_lf_weights=learnable_lf_weights) 166 | 167 | model = AdaMML(policy_net, main_net, num_frames=groups, 168 | num_segments=num_segments, modality=modality, rng_policy=rng_policy, 169 | rng_threshold=rng_threshold, num_classes=num_classes) 170 | 171 | return model 172 | -------------------------------------------------------------------------------- /models/common.py: -------------------------------------------------------------------------------- 1 | 
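# Editorial note (not part of the original file): TemporalPooling below pools
# per-frame feature maps along the temporal dimension. It expects a tensor of
# shape (N*T, C, H, W) with T == `frames` and returns roughly half the frames,
# e.g. (shapes assumed for illustration):
#
#     pool = TemporalPooling(frames=8, kernel_size=3, stride=2, mode='max')
#     y = pool(torch.randn(2 * 8, 64, 56, 56))   # -> (2 * 4, 64, 56, 56)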
import torch.nn as nn 2 | 3 | 4 | class TemporalPooling(nn.Module): 5 | 6 | def __init__(self, frames, kernel_size=3, stride=2, mode='avg'): 7 | """ 8 | 9 | Parameters 10 | ---------- 11 | frames (int): number of input frames 12 | kernel_size 13 | stride 14 | mode 15 | """ 16 | super().__init__() 17 | self.frames = frames 18 | pad_size = (kernel_size - 1) // stride 19 | if mode == 'avg': 20 | self.pool = nn.AvgPool3d(kernel_size=(kernel_size, 1, 1), stride=(stride, 1, 1), 21 | padding=(pad_size, 0, 0)) 22 | elif mode == 'max': 23 | self.pool = nn.MaxPool3d(kernel_size=(kernel_size, 1, 1), stride=(stride, 1, 1), 24 | padding=(pad_size, 0, 0)) 25 | else: 26 | raise ValueError("only support avg or max") 27 | 28 | def forward(self, x): 29 | nt, c, h, w = x.shape 30 | x = x.view((-1, self.frames) + x.size()[1:]).transpose(1, 2) 31 | x = self.pool(x) 32 | x = x.transpose(1, 2).contiguous().view(-1, c, h, w) 33 | return x 34 | -------------------------------------------------------------------------------- /models/joint_resnet_mobilenetv2.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | 6 | from models.resnet import ResNet 7 | from models.sound_mobilenet_v2 import MobileNetV2 8 | 9 | __all__ = ['joint_resnet_mobilenetv2'] 10 | 11 | class JointResNetMobileNetV2(nn.Module): 12 | 13 | def __init__(self, depth, num_frames, modality, num_classes=1000, dropout=0.5, zero_init_residual=False, 14 | without_t_stride=False, pooling_method='max', input_channels=None, 15 | fusion_point='logits', learnable_lf_weights=False): 16 | super().__init__() 17 | 18 | self.depth = depth 19 | self.num_frames = num_frames 20 | self.without_t_stride = without_t_stride 21 | self.pooling_method = pooling_method 22 | self.fusion_point = fusion_point 23 | self.modality = modality 24 | self.learnable_lf_weights = learnable_lf_weights 25 | 26 | self.nets = nn.ModuleList() 27 | self.last_channels = [] 28 | for i, m in enumerate(modality): 29 | if m != 'sound': 30 | net = ResNet(depth, num_frames, num_classes, dropout, zero_init_residual, 31 | without_t_stride, pooling_method, input_channels[i]) 32 | if self.fusion_point != 'logits': 33 | del net.avgpool 34 | del net.dropout 35 | del net.fc 36 | 37 | if depth >= 50: 38 | self.last_channels.append(2048) 39 | else: 40 | self.last_channels.append(512) 41 | else: 42 | net = MobileNetV2(num_classes, dropout=dropout, input_channels=input_channels[i]) 43 | if self.fusion_point != 'logits': 44 | del net.classifier 45 | self.last_channels.append(net.last_channel) 46 | self.nets.append(net) 47 | 48 | self.lf_weights = None 49 | 50 | if self.fusion_point != 'logits': 51 | self.avgpool = nn.AdaptiveAvgPool2d(1) 52 | in_feature_c = sum(self.last_channels) 53 | out_feature_c = 2048 54 | self.joint = nn.Sequential( 55 | nn.Linear(in_feature_c, out_feature_c), nn.ReLU(True), 56 | nn.Linear(out_feature_c, out_feature_c), nn.ReLU(True) 57 | ) 58 | self.dropout = nn.Dropout(p=dropout) 59 | self.fc = nn.Linear(out_feature_c, num_classes) 60 | else: 61 | init_prob = 1.0 / len(self.modality) 62 | if learnable_lf_weights: 63 | self.lf_weights = nn.Parameter(torch.tensor([init_prob] * (len(self.modality) - 1))) 64 | self.register_parameter('lf_weights', self.lf_weights) 65 | 66 | def mean(self, modality='rgb'): 67 | return [0.485, 0.456, 0.406] if modality == 'rgb' or modality == 'rgbdiff' \ 68 | else [0.5] 69 | 70 | def std(self, modality='rgb'): 71 | return [0.229, 0.224, 0.225] if modality == 
'rgb' or modality == 'rgbdiff' \ 72 | else [np.mean([0.229, 0.224, 0.225])] 73 | 74 | @property 75 | def network_name(self): 76 | name = 'joint_resnet-{}_mobilenet_v2-{}'.format(self.depth, self.fusion_point) 77 | if self.lf_weights is not None: 78 | name += "-llf" if self.learnable_lf_weights else '-llfc' 79 | if not self.without_t_stride: 80 | name += "-ts-{}".format(self.pooling_method) 81 | 82 | return name 83 | 84 | def forward(self, multi_modalities, decisions=None): 85 | # multi_modalities is a list 86 | bs, _, _, _ = multi_modalities[0].shape 87 | out = [] 88 | for i, x in enumerate(multi_modalities): 89 | tmp = self.nets[i].features(x) if self.fusion_point != 'logits' else self.nets[i].forward(x) 90 | tmp = self.avgpool(tmp) if self.fusion_point != 'logits' else tmp 91 | 92 | if decisions is not None: 93 | if self.fusion_point == 'logits': 94 | tmp = tmp * decisions[i].view((tmp.size(0), 1)) 95 | else: 96 | raise ValueError("only support logits mode") 97 | out.append(tmp) 98 | 99 | if self.fusion_point != 'logits': 100 | out = torch.cat(out, dim=1) 101 | out = out.view(out.size(0), -1) 102 | out = self.joint(out) 103 | out = self.dropout(out) 104 | out = self.fc(out) 105 | 106 | n_t, c = out.shape 107 | out = out.view(bs, -1, c) 108 | 109 | # average the prediction from all frames 110 | out = torch.mean(out, dim=1) 111 | else: 112 | out = torch.stack(out, dim=0) 113 | out.squeeze_(-1) 114 | out.squeeze_(-1) # MxNxC 115 | if self.lf_weights is not None: 116 | if self.lf_weights.dim() > 1: 117 | comple_weights = torch.ones((1, self.lf_weights.size(-1)), dtype=self.lf_weights.dtype, 118 | device=self.lf_weights.device) \ 119 | - torch.sum(self.lf_weights, dim=0) 120 | else: 121 | comple_weights = torch.ones(1, dtype=self.lf_weights.dtype, device=self.lf_weights.device) - torch.sum(self.lf_weights, dim=0) 122 | weights = torch.cat((self.lf_weights, comple_weights), dim=0) 123 | weights = weights.view(weights.size(0), 1, -1) 124 | out = out * weights 125 | out = torch.sum(out, dim=0) 126 | else: 127 | out = torch.mean(out, dim=0) 128 | return out 129 | 130 | 131 | def joint_resnet_mobilenetv2(depth, num_classes, without_t_stride, groups, dropout, pooling_method, 132 | input_channels, fusion_point, modality, unimodality_pretrained, 133 | learnable_lf_weights, **kwargs): 134 | 135 | model = JointResNetMobileNetV2(depth, num_frames=groups, num_classes=num_classes, 136 | without_t_stride=without_t_stride, dropout=dropout, 137 | pooling_method=pooling_method, input_channels=input_channels, 138 | fusion_point=fusion_point, modality=modality, 139 | learnable_lf_weights=learnable_lf_weights) 140 | 141 | if len(unimodality_pretrained) > 0: 142 | if len(unimodality_pretrained) != len(model.nets): 143 | raise ValueError("the number of pretrained models is incorrect.") 144 | for i, m in enumerate(modality): 145 | print("Loading unimodality pretrained model from: {}".format(unimodality_pretrained[i])) 146 | state_dict = torch.load(unimodality_pretrained[i], map_location='cpu')['state_dict'] 147 | new_state_dict = {key.replace("module.", ""): v for key, v in state_dict.items()} 148 | if fusion_point != 'logits': 149 | if m != 'sound': 150 | new_state_dict.pop('fc.weight', None) 151 | new_state_dict.pop('fc.bias', None) 152 | else: 153 | new_state_dict.pop('classifier.1.weight', None) 154 | new_state_dict.pop('classifier.1.bias', None) 155 | model.nets[i].load_state_dict(new_state_dict, strict=True) 156 | 157 | return model 158 | 
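# Editorial usage sketch (not part of the original file): how the factory above
# can be called for late fusion at the logits with RGB + sound and no
# unimodality checkpoints; num_classes and the input sizes are illustrative
# assumptions only.
if __name__ == '__main__':
    net = joint_resnet_mobilenetv2(
        depth=50, num_classes=10, without_t_stride=False, groups=8, dropout=0.5,
        pooling_method='max', input_channels=[3, 1], fusion_point='logits',
        modality=['rgb', 'sound'], unimodality_pretrained=[],
        learnable_lf_weights=True)

    rgb = torch.randn(2, 8 * 3, 224, 224)   # N x (F*C) x H x W
    sound = torch.randn(2, 1, 256, 256)     # spectrogram-like input, size assumed
    logits = net([rgb, sound])              # N x num_classes
    print(logits.shape)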
-------------------------------------------------------------------------------- /models/model_builder.py: -------------------------------------------------------------------------------- 1 | from . import (adamml, resnet, sound_mobilenet_v2) 2 | 3 | MODEL_TABLE = { 4 | 'adamml': adamml, 5 | 'resnet': resnet, 6 | 'sound_mobilenet_v2': sound_mobilenet_v2 7 | } 8 | 9 | 10 | def build_model(args, test_mode=False): 11 | """ 12 | Args: 13 | args: all options defined in opts.py and num_classes 14 | test_mode: 15 | Returns: 16 | network model 17 | architecture name 18 | """ 19 | model = MODEL_TABLE[args.backbone_net](**vars(args)) 20 | network_name = model.network_name if hasattr(model, 'network_name') else args.backbone_net 21 | 22 | if isinstance(args.modality, list): 23 | modality = '-'.join([x for x in args.modality]) 24 | else: 25 | modality = args.modality 26 | 27 | arch_name = "{dataset}-{modality}-{arch_name}".format( 28 | dataset=args.dataset, modality=modality, arch_name=network_name) 29 | arch_name += "-f{}".format(args.groups) 30 | if args.dense_sampling: 31 | arch_name += "-s{}".format(args.frames_per_group) 32 | 33 | 34 | # add setting info only in training 35 | if not test_mode: 36 | arch_name += "-{}{}-bs{}{}-e{}".format(args.lr_scheduler, "-syncbn" if args.sync_bn else "", 37 | args.batch_size, '-' + args.prefix if args.prefix else "", args.epochs) 38 | return model, arch_name 39 | -------------------------------------------------------------------------------- /models/policy_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import math 4 | import torch.distributions 5 | import torch.nn.functional as F 6 | import numpy as np 7 | import torch.utils.model_zoo as model_zoo 8 | 9 | 10 | from models.common import TemporalPooling 11 | 12 | 13 | model_urls = { 14 | 'mobilenet_v2': 'https://raw.githubusercontent.com/d-li14/mobilenetv2.pytorch/master/pretrained/mobilenetv2_160x160-64dc7fa1.pth' 15 | } 16 | 17 | 18 | def _make_divisible(v, divisor, min_value=None): 19 | """ 20 | This function is taken from the original tf repo. 21 | It ensures that all layers have a channel number that is divisible by 8 22 | It can be seen here: 23 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py 24 | :param v: 25 | :param divisor: 26 | :param min_value: 27 | :return: 28 | """ 29 | if min_value is None: 30 | min_value = divisor 31 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 32 | # Make sure that round down does not go down by more than 10%. 
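    # Editorial worked example (assuming divisor=8): _make_divisible(24, 8) -> 24,
    # _make_divisible(67, 8) -> 64, _make_divisible(100, 8) -> 104, i.e. v is
    # rounded to the nearest multiple of the divisor and then bumped up one step
    # by the check below if rounding lost more than 10% of the channels.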
33 | if new_v < 0.9 * v: 34 | new_v += divisor 35 | return new_v 36 | 37 | 38 | def conv_3x3_bn(inp, oup, stride): 39 | return nn.Sequential( 40 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 41 | nn.BatchNorm2d(oup), 42 | nn.ReLU6(inplace=True) 43 | ) 44 | 45 | 46 | def conv_1x1_bn(inp, oup): 47 | return nn.Sequential( 48 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 49 | nn.BatchNorm2d(oup), 50 | nn.ReLU6(inplace=True) 51 | ) 52 | 53 | 54 | class InvertedResidual(nn.Module): 55 | def __init__(self, inp, oup, stride, expand_ratio, num_frames=None): 56 | super(InvertedResidual, self).__init__() 57 | assert stride in [1, 2] 58 | 59 | self.temporal_pool = TemporalPooling(num_frames, mode='max') if num_frames else None 60 | hidden_dim = round(inp * expand_ratio) 61 | self.identity = stride == 1 and inp == oup 62 | 63 | if expand_ratio == 1: 64 | self.conv = nn.Sequential( 65 | # dw 66 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), 67 | nn.BatchNorm2d(hidden_dim), 68 | nn.ReLU6(inplace=True), 69 | # pw-linear 70 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 71 | nn.BatchNorm2d(oup), 72 | ) 73 | else: 74 | self.conv = nn.Sequential( 75 | # pw 76 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False), 77 | nn.BatchNorm2d(hidden_dim), 78 | nn.ReLU6(inplace=True), 79 | # dw 80 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), 81 | nn.BatchNorm2d(hidden_dim), 82 | nn.ReLU6(inplace=True), 83 | # pw-linear 84 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 85 | nn.BatchNorm2d(oup), 86 | ) 87 | 88 | def forward(self, x): 89 | if self.temporal_pool: 90 | x = self.temporal_pool(x) 91 | 92 | if self.identity: 93 | return x + self.conv(x) 94 | else: 95 | return self.conv(x) 96 | 97 | 98 | class MobileNetV2(nn.Module): 99 | def __init__(self, num_classes=1000, num_frames=4, input_channels=3, width_mult=1.): 100 | super(MobileNetV2, self).__init__() 101 | # setting of inverted residual blocks 102 | self.cfgs = [ 103 | # t, c, n, s 104 | [1, 16, 1, 1], 105 | [6, 24, 2, 2], 106 | [6, 32, 3, 2], 107 | [6, 64, 4, 2], 108 | [6, 96, 3, 1], 109 | [6, 160, 3, 2], 110 | [6, 320, 1, 1], 111 | ] 112 | self.input_channels = input_channels 113 | self.num_frames = num_frames 114 | self.orig_num_frames = num_frames 115 | # building first layer 116 | input_channel = _make_divisible(32 * width_mult, 4 if width_mult == 0.1 else 8) 117 | layers = [conv_3x3_bn(input_channels, input_channel, 2)] 118 | # building inverted residual blocks 119 | block = InvertedResidual 120 | for t, c, n, s in self.cfgs: 121 | has_tp = True if c == 64 or c == 160 else False 122 | output_channel = _make_divisible(c * width_mult, 4 if width_mult == 0.1 else 8) 123 | for i in range(n): 124 | num_frames = self.num_frames if i == 0 and has_tp and self.num_frames != 1 \ 125 | else None 126 | layers.append(block(input_channel, output_channel, s if i == 0 else 1, t, 127 | num_frames=num_frames)) 128 | input_channel = output_channel 129 | if has_tp: 130 | self.num_frames //= 2 131 | self.features = nn.Sequential(*layers) 132 | # building last several layers 133 | self.last_channel = int(1280 * width_mult) 134 | output_channel = _make_divisible(self.last_channel, 4 if width_mult == 0.1 else 8) if width_mult > 1.0 else 1280 135 | self.conv = conv_1x1_bn(input_channel, output_channel) 136 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 137 | # self.dropout = nn.Dropout(p=dropout) 138 | self.classifier = nn.Linear(output_channel, num_classes) 139 | 140 | self._initialize_weights() 141 | 142 | 
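    # Editorial note (added comment): `x` arrives as (N, F*C, H, W) with
    # F == self.orig_num_frames frames stacked along the channel dimension;
    # it is unfolded to (N*F, C, H, W) before running the 2D backbone.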
def feature_extraction(self, x): 143 | bs, c_t, h, w = x.shape 144 | x = x.view(bs * self.orig_num_frames, c_t // self.orig_num_frames, h, w) 145 | x = self.features(x) 146 | x = self.conv(x) 147 | x = self.avgpool(x) 148 | x = x.view(x.size(0), -1) 149 | return x 150 | 151 | def forward(self, x, with_feature=False): 152 | bs, c_t, h, w = x.shape 153 | fea = self.feature_extraction(x) 154 | # x = self.dropout(fea) 155 | x = self.classifier(fea) 156 | n_t, c = x.shape 157 | out = x.view(bs, -1, c) 158 | 159 | # average the prediction from all frames 160 | out = torch.mean(out, dim=1) 161 | if with_feature: 162 | return out, fea 163 | else: 164 | return out 165 | 166 | def _initialize_weights(self): 167 | for m in self.modules(): 168 | if isinstance(m, nn.Conv2d): 169 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 170 | m.weight.data.normal_(0, math.sqrt(2. / n)) 171 | if m.bias is not None: 172 | m.bias.data.zero_() 173 | elif isinstance(m, nn.BatchNorm2d): 174 | m.weight.data.fill_(1) 175 | m.bias.data.zero_() 176 | elif isinstance(m, nn.Linear): 177 | m.weight.data.normal_(0, 0.01) 178 | m.bias.data.zero_() 179 | 180 | def mean(self, modality='rgb'): 181 | return [0.485, 0.456, 0.406] if modality == 'rgb' or modality == 'rgbdiff'\ 182 | else [0.5] 183 | 184 | def std(self, modality='rgb'): 185 | return [0.229, 0.224, 0.225] if modality == 'rgb' or modality == 'rgbdiff'\ 186 | else [np.mean([0.229, 0.224, 0.225])] 187 | 188 | @property 189 | def network_name(self): 190 | name = 'mobilenet_v2' 191 | return name 192 | 193 | def load_imagenet_model(self): 194 | state_dict = model_zoo.load_url(model_urls['mobilenet_v2'], map_location='cpu') 195 | if self.input_channels != 3: # convert the RGB model to others, like flow 196 | value = state_dict['features.0.0.weight'] 197 | o_c, _, k_h, k_w = value.shape 198 | new_shape = (o_c, self.input_channels, k_h, k_w) 199 | state_dict['features.0.0.weight'] = value.mean(dim=1, keepdim=True).expand( 200 | new_shape).contiguous() 201 | state_dict.pop('classifier.weight', None) 202 | state_dict.pop('classifier.bias', None) 203 | self.load_state_dict(state_dict, strict=False) 204 | 205 | 206 | class JointMobileNetV2(nn.Module): 207 | 208 | def __init__(self, num_frames, modality, num_classes=1000, dropout=0.5, input_channels=None): 209 | super().__init__() 210 | 211 | self.num_frames = num_frames 212 | self.modality = modality 213 | 214 | self.nets = nn.ModuleList() 215 | self.last_channels = [] 216 | for i, m in enumerate(modality): 217 | net = MobileNetV2(num_classes, num_frames=1 if m == 'sound' else num_frames, 218 | input_channels=input_channels[i]) 219 | del net.classifier 220 | self.last_channels.append(net.last_channel) 221 | net.load_imagenet_model() 222 | self.nets.append(net) 223 | 224 | self.avgpool = nn.AdaptiveAvgPool2d(1) 225 | in_feature_c = sum(self.last_channels) 226 | out_feature_c = 2048 227 | self.last_channels = out_feature_c 228 | self.joint = nn.Sequential( 229 | nn.Linear(in_feature_c, out_feature_c), nn.ReLU(True), 230 | nn.Linear(out_feature_c, out_feature_c), nn.ReLU(True) 231 | ) 232 | self.dropout = nn.Dropout(p=dropout) 233 | self.fc = nn.Linear(out_feature_c, num_classes) 234 | 235 | def features(self, multi_modalities): 236 | # multi_modalities is a list 237 | bs, _, _, _ = multi_modalities[0].shape 238 | out = [] 239 | for i, x in enumerate(multi_modalities): 240 | tmp = self.nets[i].feature_extraction(x) 241 | out.append(tmp) 242 | 243 | out = torch.cat(out, dim=1) 244 | out = out.view(out.size(0), -1) 245 | out 
= self.joint(out) 246 | 247 | return out 248 | 249 | def forward(self, multi_modalities): 250 | bs, _, _, _ = multi_modalities[0].shape 251 | out = self.features(multi_modalities) 252 | out = self.dropout(out) 253 | out = self.fc(out) 254 | n_t, c = out.shape 255 | out = out.view(bs, -1, c) 256 | # average the prediction from all frames 257 | out = torch.mean(out, dim=1) 258 | return out 259 | 260 | 261 | class PolicyNet(nn.Module): 262 | 263 | def __init__(self, joint_net, modality, causality_modeling='lstm'): 264 | super().__init__() 265 | self.joint_net = joint_net 266 | if hasattr(self.joint_net, 'fc'): 267 | del self.joint_net.fc 268 | if hasattr(self.joint_net, 'dropout'): 269 | del self.joint_net.dropout 270 | self.modality = modality 271 | self.causality_modeling = causality_modeling 272 | self.num_modality = len(modality) 273 | self.temperature = 5.0 274 | feature_dim = self.joint_net.last_channels 275 | 276 | if causality_modeling is not None: 277 | embedded_dim = 256 278 | self.lstm = nn.LSTMCell(feature_dim + 2 * self.num_modality, embedded_dim) 279 | self.fcs = nn.ModuleList([nn.Linear(embedded_dim, 2) for _ in range(self.num_modality)]) 280 | else: 281 | self.fcs = nn.ModuleList([nn.Linear(feature_dim, 2) for _ in range(self.num_modality)]) 282 | 283 | def wrapper_gumbel_softmax(self, logits): 284 | """ 285 | :param logits: NxM, N is batch size, M is number of possible choices 286 | :return: Nx1: the selected index 287 | """ 288 | distributions = F.gumbel_softmax(logits, tau=self.temperature, hard=True) 289 | decisions = distributions[:, -1] 290 | return decisions 291 | 292 | def set_temperature(self, temperature): 293 | self.temperature = temperature 294 | 295 | def decay_temperature(self, decay_ratio=None): 296 | if decay_ratio: 297 | self.temperature *= decay_ratio 298 | print("Current temperature: {}".format(self.temperature), flush=True) 299 | 300 | def convert_index_to_decisions(self, decisions): 301 | """ 302 | 303 | :param decisions: Nx1, the index of selection 304 | :return: NxM, M is the number of modality, equals to the log2(max(decisions)) 305 | """ 306 | out = torch.zeros((decisions.size(0), self.num_modality), dtype=decisions.dtype, device=decisions.device) 307 | for m_i in range(self.num_modality): 308 | out[:, m_i] = decisions % 2 309 | decisions = torch.floor(decisions / 2) 310 | return out 311 | 312 | def forward(self, x): 313 | """ 314 | 315 | :param x: 316 | :return: all_logits shape is different when using single_fc. 317 | - single_fc: SxNx(2**M) 318 | - separate fc: MxSxNx2 319 | """ 320 | # x: M,SxNx(FC)xHxW 321 | num_segments = x[0].size(0) 322 | outs = [] 323 | for i in range(num_segments): 324 | tmp_x = [x[m_i][i, ...] 
for m_i in range(self.num_modality)] # M,Nx(FC)xHxW 325 | out = self.joint_net.features(tmp_x) # NxCout 326 | outs.append(out) 327 | outs = torch.stack(outs, dim=0) # SxNxCout 328 | 329 | # SxNxCout 330 | if self.causality_modeling is None: 331 | outs = outs.view((-1, outs.size(-1))) # (SN)xC 332 | logits = [] 333 | for m_i in range(self.num_modality): 334 | logits.append(self.fcs[m_i](outs)) # (SN)x2 335 | logits = torch.cat(logits, dim=0) # (MSN)x2 336 | decisions = self.wrapper_gumbel_softmax(logits) # (MSN)x1 337 | # (MSN)x1 338 | decisions = decisions.view((self.num_modality, num_segments, -1)).transpose(0, 1) 339 | all_logits = logits.view((self.num_modality, num_segments, -1, 2)).transpose(0, 1) 340 | # SxMxN 341 | elif self.causality_modeling == 'lstm': 342 | all_logits = [] 343 | decisions = [] 344 | h_xs, c_xs = None, None 345 | for i in range(num_segments): 346 | if i == 0: 347 | lstm_in = torch.cat((outs[i], 348 | torch.zeros((outs[i].shape[0], self.num_modality * 2), 349 | dtype=outs[i].dtype, device=outs[i].device) 350 | ), dim=-1) 351 | h_x, c_x = self.lstm(lstm_in) # h_x: Nxhidden, c_x: Nxhidden 352 | else: 353 | logits = logits.view((self.num_modality, -1, 2)).permute(1, 0, 2).contiguous().view(-1, 2 * self.num_modality) 354 | lstm_in = torch.cat((outs[i], logits), dim=-1) 355 | 356 | h_x, c_x = self.lstm(lstm_in, (h_x, c_x)) # h_x: Nxhidden, c_x: Nxhidden 357 | 358 | logits = [] 359 | for m_i in range(self.num_modality): 360 | tmp = self.fcs[m_i](h_x) # Nx2 361 | logits.append(tmp) 362 | logits = torch.cat(logits, dim=0) # MNx2 363 | all_logits.append(logits.view(self.num_modality, -1, 2)) 364 | selection = self.wrapper_gumbel_softmax(logits) # MNx1 365 | decisions.append(selection) 366 | decisions = torch.stack(decisions, dim=0).view(num_segments, self.num_modality, -1) 367 | all_logits = torch.stack(all_logits, dim=0) 368 | # SxMxN 369 | else: 370 | raise ValueError("unknown mode") 371 | 372 | # dim of decision: SxMxN 373 | return decisions, all_logits 374 | 375 | @property 376 | def network_name(self): 377 | name = 'j_mobilenet_v2{}'.format('-' + self.causality_modeling 378 | if self.causality_modeling else '') 379 | return name 380 | 381 | 382 | def p_joint_mobilenet(num_frames, modality, input_channels, causality_modeling): 383 | 384 | joint_net = JointMobileNetV2(num_frames=num_frames, modality=modality, input_channels=input_channels) 385 | model = PolicyNet(joint_net, modality, causality_modeling=causality_modeling) 386 | 387 | return model 388 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.utils.model_zoo as model_zoo 5 | 6 | from models.common import TemporalPooling 7 | 8 | __all__ = ['ResNet'] 9 | 10 | model_urls = { 11 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 12 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 13 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 14 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 15 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 16 | } 17 | 18 | 19 | def convert_rgb_model_to_others(state_dict, input_channels, ks=7): 20 | new_state_dict = {} 21 | for key, value in state_dict.items(): 22 | if "conv1.weight" in key: 23 | o_c, in_c, k_h, k_w = value.shape 24 | else: 25 | o_c, 
in_c, k_h, k_w = 0, 0, 0, 0 26 | if in_c == 3 and k_h == ks and k_w == ks: 27 | # average the weights and expand to all channels 28 | new_shape = (o_c, input_channels, k_h, k_w) 29 | new_value = value.mean(dim=1, keepdim=True).expand(new_shape).contiguous() 30 | else: 31 | new_value = value 32 | new_state_dict[key] = new_value 33 | return new_state_dict 34 | 35 | def conv3x3(in_planes, out_planes, stride=1): 36 | """3x3 convolution with padding""" 37 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 38 | padding=1, bias=False) 39 | 40 | 41 | def conv1x1(in_planes, out_planes, stride=1): 42 | """1x1 convolution""" 43 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 44 | 45 | 46 | class BasicBlock(nn.Module): 47 | expansion = 1 48 | 49 | def __init__(self, inplanes, planes, stride=1, downsample=None): 50 | super(BasicBlock, self).__init__() 51 | self.conv1 = conv3x3(inplanes, planes, stride) 52 | self.bn1 = nn.BatchNorm2d(planes) 53 | self.relu = nn.ReLU(inplace=True) 54 | self.conv2 = conv3x3(planes, planes) 55 | self.bn2 = nn.BatchNorm2d(planes) 56 | self.downsample = downsample 57 | self.stride = stride 58 | 59 | def forward(self, x): 60 | identity = x 61 | out = self.conv1(x) 62 | out = self.bn1(out) 63 | out = self.relu(out) 64 | 65 | out = self.conv2(out) 66 | out = self.bn2(out) 67 | 68 | if self.downsample is not None: 69 | identity = self.downsample(x) 70 | 71 | out += identity 72 | out = self.relu(out) 73 | 74 | return out 75 | 76 | 77 | class Bottleneck(nn.Module): 78 | expansion = 4 79 | 80 | def __init__(self, inplanes, planes, stride=1, downsample=None): 81 | super(Bottleneck, self).__init__() 82 | 83 | self.conv1 = conv1x1(inplanes, planes) 84 | self.bn1 = nn.BatchNorm2d(planes) 85 | self.conv2 = conv3x3(planes, planes, stride) 86 | self.bn2 = nn.BatchNorm2d(planes) 87 | self.conv3 = conv1x1(planes, planes * self.expansion) 88 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 89 | self.relu = nn.ReLU(inplace=True) 90 | self.downsample = downsample 91 | self.stride = stride 92 | 93 | 94 | def forward(self, x): 95 | identity = x 96 | out = self.conv1(x) 97 | out = self.bn1(out) 98 | out = self.relu(out) 99 | 100 | out = self.conv2(out) 101 | out = self.bn2(out) 102 | out = self.relu(out) 103 | 104 | out = self.conv3(out) 105 | out = self.bn3(out) 106 | 107 | if self.downsample is not None: 108 | identity = self.downsample(x) 109 | 110 | out += identity 111 | out = self.relu(out) 112 | 113 | return out 114 | 115 | 116 | class ResNet(nn.Module): 117 | 118 | def __init__(self, depth, num_frames, num_classes=1000, dropout=0.5, zero_init_residual=False, 119 | without_t_stride=False, pooling_method='max', input_channels=3): 120 | super(ResNet, self).__init__() 121 | 122 | self.pooling_method = pooling_method.lower() 123 | block = BasicBlock if depth < 50 else Bottleneck 124 | layers = { 125 | 18: [2, 2, 2, 2], 126 | 34: [3, 4, 6, 3], 127 | 50: [3, 4, 6, 3], 128 | 101: [3, 4, 23, 3], 129 | 152: [3, 8, 36, 3]}[depth] 130 | 131 | self.depth = depth 132 | self.num_frames = num_frames 133 | self.orig_num_frames = num_frames 134 | self.num_classes = num_classes 135 | self.without_t_stride = without_t_stride 136 | 137 | self.inplanes = 64 138 | self.conv1 = nn.Conv2d(input_channels, 64, kernel_size=7, stride=2, padding=3, bias=False) 139 | self.bn1 = nn.BatchNorm2d(64) 140 | self.relu = nn.ReLU(inplace=True) 141 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 142 | 143 | self.layer1 = self._make_layer(block, 64, 
layers[0])
144 |         if not self.without_t_stride:
145 |             self.pool1 = TemporalPooling(self.num_frames, 3, 2, self.pooling_method)
146 |             self.num_frames = max(1, self.num_frames // 2)
147 |         self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
148 |         if not self.without_t_stride:
149 |             self.pool2 = TemporalPooling(self.num_frames, 3, 2, self.pooling_method)
150 |             self.num_frames = max(1, self.num_frames // 2)
151 |         self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
152 |         if not self.without_t_stride:
153 |             self.pool3 = TemporalPooling(self.num_frames, 3, 2, self.pooling_method)
154 |             self.num_frames = max(1, self.num_frames // 2)
155 |         self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
156 |         self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
157 |         self.dropout = nn.Dropout(dropout)
158 | 
159 |         self.fc = nn.Linear(512 * block.expansion, num_classes)
160 | 
161 |     def _make_layer(self, block, planes, blocks, stride=1):
162 |         downsample = None
163 |         if stride != 1 or self.inplanes != planes * block.expansion:
164 |             downsample = nn.Sequential(
165 |                 conv1x1(self.inplanes, planes * block.expansion, stride),
166 |                 nn.BatchNorm2d(planes * block.expansion),
167 |             )
168 | 
169 |         layers = []
170 |         layers.append(block(self.inplanes, planes, stride, downsample))
171 |         self.inplanes = planes * block.expansion
172 |         for _ in range(1, blocks):
173 |             layers.append(block(self.inplanes, planes))
174 | 
175 |         return nn.Sequential(*layers)
176 | 
177 | 
178 |     def features(self, x):
179 |         batch_size, c_t, h, w = x.shape
180 |         x = x.view(batch_size * self.orig_num_frames, c_t // self.orig_num_frames, h, w)
181 |         x = self.conv1(x)
182 |         x = self.bn1(x)
183 |         x = self.relu(x)
184 |         fp1 = self.maxpool(x)
185 | 
186 |         fp2 = self.layer1(fp1)
187 |         fp2_d = self.pool1(fp2) if not self.without_t_stride else fp2
188 |         fp3 = self.layer2(fp2_d)
189 |         fp3_d = self.pool2(fp3) if not self.without_t_stride else fp3
190 |         fp4 = self.layer3(fp3_d)
191 |         fp4_d = self.pool3(fp4) if not self.without_t_stride else fp4
192 |         fp5 = self.layer4(fp4_d)
193 |         return fp5
194 | 
195 |     def forward(self, x):
196 |         batch_size, c_t, h, w = x.shape
197 |         if c_t != 1:  # handle audio input
198 |             x = x.view(batch_size * self.orig_num_frames, c_t // self.orig_num_frames, h, w)
199 |         x = self.conv1(x)
200 |         x = self.bn1(x)
201 |         x = self.relu(x)
202 |         fp1 = self.maxpool(x)
203 | 
204 |         fp2 = self.layer1(fp1)
205 |         fp2_d = self.pool1(fp2) if not self.without_t_stride else fp2
206 |         fp3 = self.layer2(fp2_d)
207 |         fp3_d = self.pool2(fp3) if not self.without_t_stride else fp3
208 |         fp4 = self.layer3(fp3_d)
209 |         fp4_d = self.pool3(fp4) if not self.without_t_stride else fp4
210 |         fp5 = self.layer4(fp4_d)
211 | 
212 |         x = self.avgpool(fp5)
213 |         x = x.view(x.size(0), -1)
214 |         x = self.dropout(x)
215 |         x = self.fc(x)
216 | 
217 |         n_t, c = x.shape
218 |         out = x.view(batch_size, -1, c)
219 | 
220 |         # average the prediction from all frames
221 |         out = torch.mean(out, dim=1)
222 | 
223 |         return out
224 | 
225 |     def mean(self, modality='rgb'):
226 |         return [0.485, 0.456, 0.406] if modality == 'rgb' or modality == 'rgbdiff'\
227 |             else [0.5]
228 | 
229 |     def std(self, modality='rgb'):
230 |         return [0.229, 0.224, 0.225] if modality == 'rgb' or modality == 'rgbdiff'\
231 |             else [np.mean([0.229, 0.224, 0.225])]
232 | 
233 |     @property
234 |     def network_name(self):
235 |         name = 'resnet-{}'.format(self.depth)
236 |         if not self.without_t_stride:
237 |             name += "-ts-{}".format(self.pooling_method)
238 |         if getattr(self, 'fpn_dim', 0) > 0:  # fpn_dim is never set in this class; guard avoids an AttributeError
239 |             name +=
"-fpn{}".format(self.fpn_dim) 240 | 241 | return name 242 | 243 | 244 | def resnet(depth, num_classes, without_t_stride, groups, dropout, pooling_method, 245 | input_channels, imagenet_pretrained=True, **kwargs): 246 | 247 | model = ResNet(depth, num_frames=groups, num_classes=num_classes, 248 | without_t_stride=without_t_stride, dropout=dropout, 249 | pooling_method=pooling_method, input_channels=input_channels) 250 | 251 | if imagenet_pretrained: 252 | state_dict = model_zoo.load_url(model_urls['resnet{}'.format(depth)], map_location='cpu') 253 | state_dict.pop('fc.weight', None) 254 | state_dict.pop('fc.bias', None) 255 | if input_channels != 3: # convert the RGB model to others, like flow 256 | state_dict = convert_rgb_model_to_others(state_dict, input_channels, 7) 257 | model.load_state_dict(state_dict, strict=False) 258 | 259 | return model 260 | -------------------------------------------------------------------------------- /models/sound_mobilenet_v2.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch.utils.model_zoo as model_zoo 3 | import numpy as np 4 | 5 | __all__ = ['MobileNetV2', 'sound_mobilenet_v2'] 6 | 7 | 8 | model_urls = { 9 | 'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth', 10 | } 11 | 12 | 13 | def _make_divisible(v, divisor, min_value=None): 14 | """ 15 | This function is taken from the original tf repo. 16 | It ensures that all layers have a channel number that is divisible by 8 17 | It can be seen here: 18 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py 19 | :param v: 20 | :param divisor: 21 | :param min_value: 22 | :return: 23 | """ 24 | if min_value is None: 25 | min_value = divisor 26 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 27 | # Make sure that round down does not go down by more than 10%. 
28 | if new_v < 0.9 * v: 29 | new_v += divisor 30 | return new_v 31 | 32 | 33 | class ConvBNReLU(nn.Sequential): 34 | def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1): 35 | padding = (kernel_size - 1) // 2 36 | super(ConvBNReLU, self).__init__( 37 | nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False), 38 | nn.BatchNorm2d(out_planes), 39 | nn.ReLU6(inplace=True) 40 | ) 41 | 42 | 43 | class InvertedResidual(nn.Module): 44 | def __init__(self, inp, oup, stride, expand_ratio): 45 | super(InvertedResidual, self).__init__() 46 | self.stride = stride 47 | assert stride in [1, 2] 48 | 49 | hidden_dim = int(round(inp * expand_ratio)) 50 | self.use_res_connect = self.stride == 1 and inp == oup 51 | 52 | layers = [] 53 | if expand_ratio != 1: 54 | # pw 55 | layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1)) 56 | layers.extend([ 57 | # dw 58 | ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim), 59 | # pw-linear 60 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 61 | nn.BatchNorm2d(oup), 62 | ]) 63 | self.conv = nn.Sequential(*layers) 64 | 65 | def forward(self, x): 66 | if self.use_res_connect: 67 | return x + self.conv(x) 68 | else: 69 | return self.conv(x) 70 | 71 | 72 | class MobileNetV2(nn.Module): 73 | def __init__(self, 74 | num_classes=1000, 75 | width_mult=1.0, 76 | inverted_residual_setting=None, 77 | round_nearest=8, 78 | block=None, 79 | input_channels=3, 80 | dropout=0.5): 81 | """ 82 | MobileNet V2 main class 83 | 84 | Args: 85 | num_classes (int): Number of classes 86 | width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount 87 | inverted_residual_setting: Network structure 88 | round_nearest (int): Round the number of channels in each layer to be a multiple of this number 89 | Set to 1 to turn off rounding 90 | block: Module specifying inverted residual building block for mobilenet 91 | 92 | """ 93 | super(MobileNetV2, self).__init__() 94 | 95 | if block is None: 96 | block = InvertedResidual 97 | input_channel = 32 98 | last_channel = 1280 99 | 100 | if inverted_residual_setting is None: 101 | inverted_residual_setting = [ 102 | # t, c, n, s 103 | [1, 16, 1, 1], 104 | [6, 24, 2, 2], 105 | [6, 32, 3, 2], 106 | [6, 64, 4, 2], 107 | [6, 96, 3, 1], 108 | [6, 160, 3, 2], 109 | [6, 320, 1, 1], 110 | ] 111 | 112 | # only check the first element, assuming user knows t,c,n,s are required 113 | if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4: 114 | raise ValueError("inverted_residual_setting should be non-empty " 115 | "or a 4-element list, got {}".format(inverted_residual_setting)) 116 | 117 | # building first layer 118 | input_channel = _make_divisible(input_channel * width_mult, round_nearest) 119 | self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) 120 | features = [ConvBNReLU(input_channels, input_channel, stride=2)] 121 | # building inverted residual blocks 122 | for t, c, n, s in inverted_residual_setting: 123 | output_channel = _make_divisible(c * width_mult, round_nearest) 124 | for i in range(n): 125 | stride = s if i == 0 else 1 126 | features.append(block(input_channel, output_channel, stride, expand_ratio=t)) 127 | input_channel = output_channel 128 | # building last several layers 129 | features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1)) 130 | # make it nn.Sequential 131 | self.features = nn.Sequential(*features) 132 | 133 | # building classifier 134 | 
self.classifier = nn.Sequential( 135 | nn.Dropout(dropout), 136 | nn.Linear(self.last_channel, num_classes), 137 | ) 138 | 139 | # weight initialization 140 | for m in self.modules(): 141 | if isinstance(m, nn.Conv2d): 142 | nn.init.kaiming_normal_(m.weight, mode='fan_out') 143 | if m.bias is not None: 144 | nn.init.zeros_(m.bias) 145 | elif isinstance(m, nn.BatchNorm2d): 146 | nn.init.ones_(m.weight) 147 | nn.init.zeros_(m.bias) 148 | elif isinstance(m, nn.Linear): 149 | nn.init.normal_(m.weight, 0, 0.01) 150 | nn.init.zeros_(m.bias) 151 | 152 | def _forward_impl(self, x): 153 | # This exists since TorchScript doesn't support inheritance, so the superclass method 154 | # (this one) needs to have a name other than `forward` that can be accessed in a subclass 155 | x = self.features(x) 156 | # Cannot use "squeeze" as batch-size can be 1 => must use reshape with x.shape[0] 157 | x = nn.functional.adaptive_avg_pool2d(x, 1).reshape(x.shape[0], -1) 158 | x = self.classifier(x) 159 | return x 160 | 161 | def forward(self, x): 162 | return self._forward_impl(x) 163 | 164 | def mean(self, modality='rgb'): 165 | return [0.485, 0.456, 0.406] if modality == 'rgb' or modality == 'rgbdiff'\ 166 | else [0.5] 167 | 168 | def std(self, modality='rgb'): 169 | return [0.229, 0.224, 0.225] if modality == 'rgb' or modality == 'rgbdiff'\ 170 | else [np.mean([0.229, 0.224, 0.225])] 171 | 172 | @property 173 | def network_name(self): 174 | name = 'sound_mobilenet_v2' 175 | return name 176 | 177 | def sound_mobilenet_v2(num_classes, input_channels, dropout, imagenet_pretrained=True, **kwargs): 178 | """ 179 | Constructs a MobileNetV2 architecture from 180 | `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" <https://arxiv.org/abs/1801.04381>`_. 181 | 182 | Args: 183 | num_classes (int): number of output classes 184 | imagenet_pretrained (bool): if True, initializes from the ImageNet pre-trained MobileNetV2 185 | """ 186 | model = MobileNetV2(num_classes=num_classes, input_channels=input_channels, dropout=dropout) 187 | if imagenet_pretrained: 188 | state_dict = model_zoo.load_url(model_urls['mobilenet_v2']) 189 | if input_channels != 3: # convert the RGB model to others, like flow 190 | value = state_dict['features.0.0.weight'] 191 | o_c, _, k_h, k_w = value.shape 192 | new_shape = (o_c, input_channels, k_h, k_w) 193 | state_dict['features.0.0.weight'] = value.mean(dim=1, keepdim=True).expand(new_shape).contiguous() 194 | state_dict.pop('classifier.1.weight', None) 195 | state_dict.pop('classifier.1.bias', None) 196 | model.load_state_dict(state_dict, strict=False) 197 | 198 | return model 199 | 200 | 201 | if __name__ == "__main__": 202 | sound_mobilenet_v2(239, 1, 0.5, True) 203 | -------------------------------------------------------------------------------- /opts.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from models.model_builder import MODEL_TABLE 3 | from utils.dataset_config import DATASET_CONFIG 4 | 5 | def arg_parser(): 6 | parser = argparse.ArgumentParser(description='PyTorch Action recognition Training') 7 | 8 | # model definition 9 | parser.add_argument('--backbone_net', default='s3d', type=str, help='backbone network', 10 | choices=list(MODEL_TABLE.keys())) 11 | parser.add_argument('-d', '--depth', default=18, type=int, metavar='N', 12 | help='depth of resnet (default: 18)', choices=[18, 34, 50, 101, 152]) 13 | parser.add_argument('--dropout', default=0.5, type=float, 14 | help='dropout ratio before the final layer')
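    # Frame/clip sampling arguments. `--groups` is the number of frames sampled per clip for the
    # visual modalities (the sound branch uses a single audio spectrogram per clip), while
    # `--frames_per_group` is the frames per group for uniform sampling or the sampling frequency for
    # dense sampling. `--num_segments` is the number of consecutive clips AdaMML sees per video during
    # training (validation uses `--val_num_clips`). An illustrative combination, not a recommendation:
    #   --groups 8 --frames_per_group 4 --num_segments 5 --dense_sampling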
15 | parser.add_argument('--groups', default=8, type=int, help='number of frames') 16 | parser.add_argument('--num_segments', default=1, type=int, help='number of consecutive segments for adamml') 17 | parser.add_argument('--frames_per_group', default=1, type=int, 18 | help='[uniform sampling] number of frames per group; ' 19 | '[dense sampling]: sampling frequency') 20 | parser.add_argument('--without_t_stride', dest='without_t_stride', action='store_true', 21 | help='skip the temporal stride in the model') 22 | parser.add_argument('--pooling_method', default='max', 23 | choices=['avg', 'max'], help='method for temporal pooling or ' 24 | 'which pool3d module to use') 25 | parser.add_argument('--fusion_point', default='logits', type=str, help='where to combine the features', 26 | choices=['fc2', 'logits']) 27 | parser.add_argument('--prefix', default='', type=str, help='model prefix') 28 | parser.add_argument('--learnable_lf_weights', action='store_true') 29 | parser.add_argument('--causality_modeling', default=None, type=str, 30 | help='causality modeling in policy net', choices=[None, 'lstm']) 31 | parser.add_argument('--cost_weights', default=None, type=float, nargs="+") 32 | parser.add_argument('--rng_policy', action='store_true', help='use rng as policy, baseline') 33 | parser.add_argument('--rng_threshold', type=float, default=0.5, help='rng threshold') 34 | parser.add_argument('--gammas', default=10.0, type=float) 35 | parser.add_argument('--penalty_type', default='blockdrop', type=str, choices=['mean', 'blockdrop']) 36 | 37 | # training setting 38 | parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.') 39 | parser.add_argument('--gpu_id', help='comma separated list of GPU(s) to use.', default=None) 40 | parser.add_argument('--disable_cudnn_benchmark', dest='cudnn_benchmark', action='store_false', 41 | help='disable the cudnn benchmark (searching for the best algorithms) to avoid OOM') 42 | parser.add_argument('-b', '--batch-size', default=72, type=int, 43 | metavar='N', help='mini-batch size (default: 72)') 44 | parser.add_argument('--lr', '--learning-rate', default=0.01, type=float, 45 | metavar='LR', help='initial learning rate') 46 | parser.add_argument('--p_lr', '--p_learning-rate', default=0.01, type=float, 47 | metavar='LR', help='initial learning rate for policy net') 48 | parser.add_argument('--lr_scheduler', default='cosine', type=str, 49 | help='learning rate scheduler', 50 | choices=['step', 'multisteps', 'cosine', 'plateau']) 51 | parser.add_argument('--lr_steps', default=[15, 30, 45], type=float, nargs="+", 52 | metavar='LRSteps', help='[step]: use a single value: the period to decay ' 53 | 'learning rate by 10. ' 54 | '[multisteps] epochs to decay learning rate by 10')
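    # For reference, the training scripts in this repo (train_unimodal.py / train_adamml.py) map the
    # scheduler flags above onto torch.optim.lr_scheduler roughly as follows (condensed sketch):
    #
    #   if args.lr_scheduler == 'step':
    #       scheduler = lr_scheduler.StepLR(optimizer, args.lr_steps[0], gamma=0.1)
    #   elif args.lr_scheduler == 'multisteps':
    #       scheduler = lr_scheduler.MultiStepLR(optimizer, args.lr_steps, gamma=0.1)
    #   elif args.lr_scheduler == 'cosine':
    #       scheduler = lr_scheduler.CosineAnnealingLR(optimizer, args.epochs, eta_min=0)
    #   elif args.lr_scheduler == 'plateau':
    #       scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'min', verbose=True)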
55 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum') 56 | parser.add_argument('--nesterov', action='store_true', 57 | help='enable nesterov momentum optimizer') 58 | parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, 59 | metavar='W', help='weight decay (default: 1e-4)') 60 | parser.add_argument('--epochs', default=50, type=int, metavar='N', 61 | help='number of total epochs to run') 62 | parser.add_argument('--warmup_epochs', default=5, type=int, metavar='N', 63 | help='number of total epochs for warmup') 64 | parser.add_argument('--finetune_epochs', default=10, type=int, metavar='N', 65 | help='number of total epochs for post-finetuning') 66 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 67 | help='path to latest checkpoint (default: none)') 68 | parser.add_argument('--auto_resume', action='store_true', help='if the log folder includes a checkpoint, automatically resume') 69 | parser.add_argument('--pretrained', dest='pretrained', type=str, metavar='PATH', 70 | help='use pre-trained model') 71 | parser.add_argument('--unimodality_pretrained', type=str, nargs="+", 72 | help='use pre-trained unimodality model', default=[]) 73 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 74 | help='manual epoch number (useful on restarts)') 75 | parser.add_argument('--clip_gradient', '--cg', default=None, type=float, 76 | help='clip the total gradient norm before updating parameters') 77 | parser.add_argument('--curr_stage', type=str, help='set the stage for staged training', 78 | default='warmup', choices=['warmup', 'alternative_training', 'finetune']) 79 | # data-related 80 | parser.add_argument('-j', '--workers', default=18, type=int, metavar='N', 81 | help='number of data loading workers (default: 18)') 82 | parser.add_argument('--datadir', metavar='DIR', help='path to dataset file list', 83 | nargs="+", type=str) 84 | parser.add_argument('--dataset', default='activitynet', 85 | choices=list(DATASET_CONFIG.keys()), help='name of the dataset') 86 | parser.add_argument('--threed_data', action='store_true', 87 | help='load data in the layout for 3D conv') 88 | parser.add_argument('--input_size', default=224, type=int, metavar='N', help='input image size') 89 | parser.add_argument('--disable_scaleup', action='store_true', 90 | help='do not scale up and then crop a small region, directly crop the input_size') 91 | parser.add_argument('--random_sampling', action='store_true', 92 | help='perform random sampling in the data loader') 93 | parser.add_argument('--dense_sampling', action='store_true', 94 | help='perform dense sampling for data loader') 95 | parser.add_argument('--augmentor_ver', default='v2', type=str, choices=['v1', 'v2'], 96 | help='[v1] TSN data augmentation, [v2] resize the shorter side to `scale_range`') 97 | parser.add_argument('--scale_range', default=[256, 320], type=int, nargs="+", 98 | metavar='scale_range', help='scale range for augmentor v2') 99 | parser.add_argument('--modality', default=['rgb'], type=str, help='rgb, flow, rgbdiff or sound', 100 | choices=['rgb', 'flow', 'rgbdiff', 'sound'], nargs="+") 101 | parser.add_argument('--mean', type=float, nargs="+", 102 | metavar='MEAN', help='mean, dimension should be 3 for RGB and RGBdiff, 1 for flow') 103 | parser.add_argument('--std', type=float, nargs="+", 104 | metavar='STD', help='std, dimension should be 3 for RGB and RGBdiff, 1 for flow')
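    # Per-modality defaults applied by the training scripts when --mean/--std are omitted (values come
    # from the models' mean()/std() methods; input channels follow train_unimodal.py / train_adamml.py):
    #
    #   modality   input channels   mean                     std
    #   rgb        3                [0.485, 0.456, 0.406]    [0.229, 0.224, 0.225]
    #   rgbdiff    3 * 5            [0.485, 0.456, 0.406]    [0.229, 0.224, 0.225]
    #   flow       2 * 5            [0.5]                    [np.mean([0.229, 0.224, 0.225])]
    #   sound      1                [0.5]                    [np.mean([0.229, 0.224, 0.225])]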
105 | parser.add_argument('--skip_normalization', action='store_true', 106 | help='skip mean and std normalization; by default use the ImageNet mean and std.') 107 | parser.add_argument('--fps', type=float, metavar='FPS', default=29.97, help='fps of the video') 108 | parser.add_argument('--audio_length', type=float, default=1.28, help='length of audio segment') 109 | parser.add_argument('--resampling_rate', type=float, default=24000, 110 | help='resampling rate of audio data') 111 | # logging 112 | parser.add_argument('--logdir', default='', type=str, help='log path') 113 | parser.add_argument('--print-freq', default=100, type=int, 114 | help='frequency (in iterations) to print the log during training') 115 | parser.add_argument('--show_model', action='store_true', help='show model summary') 116 | 117 | # for testing and validation 118 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 119 | help='evaluate model on validation set') 120 | parser.add_argument('--num_crops', default=1, type=int, choices=[1, 3, 5, 10]) 121 | parser.add_argument('--num_clips', default=1, type=int) 122 | parser.add_argument('--val_num_clips', default=10, type=int) 123 | parser.add_argument('--pred_files', type=str, nargs="+", 124 | help='prediction score files to ensemble') 125 | parser.add_argument('--pred_weights', type=float, nargs="+", 126 | help='weights of the prediction files in the ensemble') 127 | parser.add_argument('--after_softmax', action='store_true', help="perform softmax before ensemble") 128 | parser.add_argument('--lazy_eval', action='store_true', help="evaluate every 10 epochs and during the last 10 percent of epochs") 129 | 130 | # for distributed learning, not supported yet 131 | parser.add_argument('--sync-bn', action='store_true', 132 | help='sync BN across GPUs') 133 | parser.add_argument('--world-size', default=1, type=int, 134 | help='number of nodes for distributed training') 135 | parser.add_argument('--rank', default=0, type=int, 136 | help='node rank for distributed training') 137 | parser.add_argument('--dist-url', default='tcp://127.0.0.1:23456', type=str, 138 | help='url used to set up distributed training') 139 | parser.add_argument('--hostfile', default='', type=str, 140 | help='hostfile for distributed learning') 141 | parser.add_argument('--dist-backend', default='nccl', type=str, 142 | help='distributed backend') 143 | parser.add_argument('--multiprocessing-distributed', action='store_true', 144 | help='Use multi-processing distributed training to launch ' 145 | 'N processes per node, which has N GPUs. 
This is the ' 146 | 'fastest way to use PyTorch for either single node or ' 147 | 'multi node data parallel training') 148 | 149 | return parser 150 | -------------------------------------------------------------------------------- /tools/extract_audio.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import subprocess 3 | import os 4 | import glob 5 | from tqdm import tqdm 6 | 7 | 8 | def ffmpeg_extraction(input_video, output_sound, sample_rate): 9 | ffmpeg_command = ['ffmpeg', '-i', input_video, 10 | '-vn', '-acodec', 'pcm_s16le', 11 | '-loglevel', 'panic', 12 | '-ac', '1', '-ar', sample_rate, 13 | output_sound] 14 | 15 | subprocess.call(ffmpeg_command) 16 | 17 | 18 | if __name__ == '__main__': 19 | 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument('videos_dir', help='Input directory of videos with audio') 22 | parser.add_argument('output_dir', help='Output directory to store .wav files') 23 | parser.add_argument('--sample_rate', default='24000', help='Rate to resample audio') 24 | parser.add_argument('--ext', default=['.mp4'], nargs='+', help='The extension of videos') 25 | 26 | args = parser.parse_args() 27 | 28 | video_list = glob.glob(args.videos_dir + '/**/*.*', recursive=True) 29 | 30 | if not os.path.exists(args.output_dir): 31 | os.mkdir(args.output_dir) 32 | 33 | with tqdm(total=len(video_list)) as t_bar: 34 | for video in video_list: 35 | ffmpeg_extraction(video, 36 | os.path.join(args.output_dir, 37 | os.path.basename(video).split(".")[0] + ".wav"), 38 | args.sample_rate) 39 | t_bar.update() 40 | -------------------------------------------------------------------------------- /tools/extract_rgb.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import argparse 4 | import os 5 | import skvideo.io 6 | import concurrent.futures 7 | import subprocess 8 | import glob 9 | from tqdm import tqdm 10 | 11 | 12 | def video_to_images(video, targetdir, short_side=256): 13 | filename = video 14 | output_foldername = os.path.join(targetdir, os.path.basename(video).split(".")[0]) 15 | if not os.path.exists(filename): 16 | print(f"{filename} is not existed.") 17 | return video, False 18 | else: 19 | try: 20 | video_meta = skvideo.io.ffprobe(filename) 21 | height = int(video_meta['video']['@height']) 22 | width = int(video_meta['video']['@width']) 23 | except Exception as e: 24 | print(f"Can not get video info: {filename}, error {e}") 25 | return video, False 26 | 27 | if width > height: 28 | scale = "scale=-1:{}".format(short_side) 29 | else: 30 | scale = "scale={}:-1".format(short_side) 31 | if not os.path.exists(output_foldername): 32 | os.makedirs(output_foldername) 33 | 34 | command = ['ffmpeg', 35 | '-i', '"%s"' % filename, 36 | '-vf', scale, 37 | '-threads', '1', 38 | '-loglevel', 'panic', 39 | '-q:v', '2', 40 | '{}/'.format(output_foldername) + '"%05d.jpg"'] 41 | command = ' '.join(command) 42 | try: 43 | subprocess.check_output(command, shell=True, stderr=subprocess.STDOUT) 44 | except Exception as e: 45 | print(f"fail to convert {filename}, error: {e}") 46 | return video, False 47 | 48 | return video, True 49 | 50 | 51 | if __name__ == '__main__': 52 | 53 | parser = argparse.ArgumentParser() 54 | parser.add_argument('videos_dir', help='Input directory of videos with audio') 55 | parser.add_argument('output_dir', help='Output directory to store JPEG files') 56 | parser.add_argument('--num_workers', help='Number of workers', default=8, type=int) 
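    # Example invocation (paths are placeholders):
    #   python tools/extract_rgb.py /path/to/videos /path/to/frames --num_workers 16
    # Each video is decoded into JPEG frames named 00001.jpg, 00002.jpg, ... inside a folder named
    # after the video file, with the shorter side resized to 256 pixels (see video_to_images above).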
57 | args = parser.parse_args() 58 | 59 | video_list = glob.glob(args.videos_dir + '/**/*.*', recursive=True) 60 | with concurrent.futures.ProcessPoolExecutor(max_workers=args.num_workers) as executor: 61 | futures = [executor.submit(video_to_images, video, args.output_dir, 256) 62 | for video in video_list] 63 | with tqdm(total=len(futures)) as t_bar: 64 | for future in concurrent.futures.as_completed(futures): 65 | video_id, success = future.result() 66 | if not success: 67 | print(f"Something wrong for {video_id}") 68 | t_bar.update() 69 | print("Completed") 70 | -------------------------------------------------------------------------------- /train_adamml.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import time 4 | import numpy as np 5 | import sys 6 | import warnings 7 | import platform 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.parallel 12 | import torch.backends.cudnn as cudnn 13 | import torch.distributed as dist 14 | import torch.optim 15 | import torch.utils.data 16 | import torch.utils.data.distributed 17 | import torch.multiprocessing as mp 18 | from torch.optim import lr_scheduler 19 | 20 | from models import build_model 21 | from utils.utils import (build_dataflow, get_augmentor, 22 | save_checkpoint, accuracy, 23 | train_adamml, validate_adamml) 24 | from utils.video_dataset import MultiVideoDataSet 25 | from utils.dataset_config import get_dataset_config 26 | from opts import arg_parser 27 | 28 | 29 | warnings.filterwarnings("ignore", category=UserWarning) 30 | 31 | 32 | def main(): 33 | global args 34 | parser = arg_parser() 35 | args = parser.parse_args() 36 | 37 | if args.gpu_id: 38 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id 39 | 40 | if args.hostfile != '': 41 | curr_node_name = platform.node().split(".")[0] 42 | with open(args.hostfile) as f: 43 | nodes = [x.strip() for x in f.readlines() if x.strip() != ''] 44 | master_node = nodes[0].split(" ")[0] 45 | for idx, node in enumerate(nodes): 46 | if curr_node_name in node: 47 | args.rank = idx 48 | break 49 | args.world_size = len(nodes) 50 | args.dist_url = "tcp://{}:10598".format(master_node) 51 | 52 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed 53 | ngpus_per_node = torch.cuda.device_count() 54 | if args.multiprocessing_distributed: 55 | # Since we have ngpus_per_node processes per node, the total world_size 56 | # needs to be adjusted accordingly 57 | args.world_size = ngpus_per_node * args.world_size 58 | # Use torch.multiprocessing.spawn to launch distributed processes: the 59 | # main_worker process function 60 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) 61 | else: 62 | # Simply call main_worker function 63 | main_worker(args.gpu, ngpus_per_node, args) 64 | 65 | 66 | def main_worker(gpu, ngpus_per_node, args): 67 | cudnn.benchmark = args.cudnn_benchmark 68 | args.gpu = gpu 69 | 70 | num_classes, train_list_name, val_list_name, test_list_name, filename_seperator, image_tmpl, filter_video, label_file = get_dataset_config(args.dataset) 71 | args.num_classes = num_classes 72 | 73 | if args.gpu is not None: 74 | print("Use GPU: {} for training".format(args.gpu)) 75 | 76 | if args.distributed: 77 | if args.dist_url == "env://" and args.rank == -1: 78 | args.rank = int(os.environ["RANK"]) 79 | if args.multiprocessing_distributed: 80 | # For multiprocessing distributed training, rank needs to be the 81 | # global rank among all the processes 82 | 
args.rank = args.rank * ngpus_per_node + gpu 83 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 84 | world_size=args.world_size, rank=args.rank) 85 | 86 | args.input_channels = [] 87 | for modality in args.modality: 88 | if modality == 'rgb': 89 | args.input_channels.append(3) 90 | elif modality == 'flow': 91 | args.input_channels.append(2 * 5) 92 | elif modality == 'rgbdiff': 93 | args.input_channels.append(3 * 5) 94 | elif modality == 'sound': 95 | args.input_channels.append(1) 96 | 97 | model, arch_name = build_model(args) 98 | 99 | mean = [model.mean(x) for x in args.modality] 100 | std = [model.std(x) for x in args.modality] 101 | model = model.cuda(args.gpu) 102 | model.eval() 103 | 104 | if args.rank == 0: 105 | torch.cuda.empty_cache() 106 | 107 | if args.show_model and args.rank == 0: 108 | print(model) 109 | return 0 110 | 111 | if args.distributed: 112 | # For multiprocessing distributed, DistributedDataParallel constructor 113 | # should always set the single device scope, otherwise, 114 | # DistributedDataParallel will use all available devices. 115 | if args.gpu is not None: 116 | torch.cuda.set_device(args.gpu) 117 | model.cuda(args.gpu) 118 | # When using a single GPU per process and per 119 | # DistributedDataParallel, we need to divide the batch size 120 | # ourselves based on the total number of GPUs we have 121 | # the batch size should be divided by number of nodes as well 122 | args.batch_size = int(args.batch_size / args.world_size) 123 | args.workers = int(args.workers / ngpus_per_node) 124 | 125 | if args.sync_bn: 126 | process_group = torch.distributed.new_group(list(range(args.world_size))) 127 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model, process_group) 128 | 129 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True) 130 | else: 131 | model.cuda() 132 | # DistributedDataParallel will divide and allocate batch_size to all 133 | # available GPUs if device_ids are not set 134 | model = torch.nn.parallel.DistributedDataParallel(model) 135 | elif args.gpu is not None: 136 | torch.cuda.set_device(args.gpu) 137 | model = model.cuda(args.gpu) 138 | else: 139 | # DataParallel will divide and allocate batch_size to all available GPUs 140 | # assign rank to 0 141 | model = torch.nn.DataParallel(model).cuda() 142 | args.rank = 0 143 | 144 | if args.pretrained is not None: 145 | if args.rank == 0: 146 | print("=> using pre-trained model '{}'".format(arch_name)) 147 | if args.gpu is None: 148 | checkpoint = torch.load(args.pretrained, map_location='cpu') 149 | else: 150 | checkpoint = torch.load(args.pretrained, map_location='cuda:{}'.format(args.gpu)) 151 | new_dict = checkpoint['state_dict'] 152 | model.module.policy_net.set_temperature(checkpoint['temperature']) 153 | if args.rank == 0: 154 | print(f"Temperature: {model.module.policy_net.temperature}", flush=True) 155 | 156 | model.load_state_dict(new_dict, strict=False) 157 | del checkpoint # dereference seems crucial 158 | torch.cuda.empty_cache() 159 | else: 160 | if args.rank == 0: 161 | print("=> creating model '{}'".format(arch_name)) 162 | 163 | # define loss function (criterion) and optimizer 164 | train_criterion = nn.CrossEntropyLoss().cuda(args.gpu) 165 | val_criterion = nn.CrossEntropyLoss().cuda(args.gpu) 166 | eval_criterion = accuracy 167 | 168 | # Data loading code, using rgbdiff as proxy if both rgbdiff and flow are in the arguments 169 | if 'rgbdiff' in args.modality and 'flow' in args.modality: 170 | 
major_modality = [x for x in args.modality if x != 'rgbdiff'] 171 | else: 172 | major_modality = args.modality 173 | 174 | # val_list = os.path.join(args.datadir, val_list_name) 175 | val_augmentors = [] 176 | for idx, modality in enumerate(args.modality): 177 | val_augmentor = get_augmentor(False, args.input_size, scale_range=args.scale_range, 178 | mean=mean[idx], std=std[idx], disable_scaleup=args.disable_scaleup, 179 | threed_data=args.threed_data, 180 | modality=args.modality[idx], 181 | version=args.augmentor_ver, num_clips=args.val_num_clips) 182 | val_augmentors.append(val_augmentor) 183 | video_data_cls = MultiVideoDataSet 184 | val_dataset = video_data_cls(args.datadir, val_list_name, args.groups, args.frames_per_group, 185 | num_clips=args.val_num_clips, 186 | num_classes=args.num_classes, 187 | modality=args.modality, image_tmpl=image_tmpl, 188 | dense_sampling=args.dense_sampling, 189 | transform=val_augmentors, is_train=False, test_mode=False, 190 | seperator=filename_seperator, filter_video=filter_video, 191 | fps=args.fps, audio_length=args.audio_length, 192 | resampling_rate=args.resampling_rate) 193 | 194 | val_loader = build_dataflow(val_dataset, is_train=False, batch_size=max(1, args.batch_size), 195 | workers=args.workers, 196 | is_distributed=args.distributed) 197 | 198 | log_folder = os.path.join(args.logdir, arch_name) 199 | if args.rank == 0: 200 | if not os.path.exists(log_folder): 201 | os.makedirs(log_folder) 202 | 203 | if args.evaluate: 204 | val_top1, val_top5, val_losses, val_speed, val_selection, mAP, all_selections, flops, output = validate_adamml(val_loader, model, val_criterion, 205 | args.val_num_clips, major_modality, gpu_id=args.gpu, return_output=True) 206 | 207 | if args.rank == 0: 208 | logfile = open(os.path.join(log_folder, 'evaluate_log.log'), 'a') 209 | all_selections = all_selections.cpu().numpy().astype(bool) 210 | np.savez(os.path.join(log_folder, f'all_selection.npz'), modality='_'.join(major_modality), selections=all_selections) 211 | selection_msg = "Selection: " 212 | for k, v in val_selection.items(): 213 | selection_msg += "{}:{:.2f} ".format(k, v.avg * 100) 214 | print(f'Val@{args.input_size}@{args.val_num_clips}: \tLoss: {val_losses:4.4f}\tTop@1: {val_top1:.4f}' 215 | f'\tTop@5: {val_top5:.4f}\tmAP: {mAP:.4f}\tSpeed: {val_speed:.2f} ms/batch\tflops: {flops:.2f}\t{selection_msg}', flush=True) 216 | print(args.pretrained, flush=True, file=logfile) 217 | print(f'Val@{args.input_size}@{args.val_num_clips}: \tLoss: {val_losses:4.4f}\tTop@1: {val_top1:.4f}' 218 | f'\tTop@5: {val_top5:.4f}\tmAP: {mAP:.4f}\tSpeed: {val_speed:.2f} ms/batch\tflops: {flops:.2f}\t{selection_msg}', flush=True, file=logfile) 219 | if args.pretrained is not None: 220 | postfix = os.path.basename(args.pretrained).split(".")[0] 221 | else: 222 | postfix = '' 223 | np.save(os.path.join(log_folder, f'val_{args.num_crops}crops_{args.val_num_clips}clips_{args.input_size}_details_{postfix}.npy'), output.data.cpu()) 224 | return 225 | 226 | train_augmentors = [] 227 | for idx, modality in enumerate(args.modality): 228 | train_augmentor = get_augmentor(True, args.input_size, scale_range=args.scale_range, 229 | mean=mean[idx], std=std[idx], 230 | disable_scaleup=args.disable_scaleup, 231 | threed_data=args.threed_data, modality=args.modality[idx], 232 | version=args.augmentor_ver, num_clips=args.num_segments) 233 | 234 | train_augmentors.append(train_augmentor) 235 | 236 | train_dataset = video_data_cls(args.datadir, train_list_name, args.groups, args.frames_per_group, 237 | 
num_clips=args.num_segments, 238 | modality=args.modality, image_tmpl=image_tmpl, 239 | num_classes=args.num_classes, 240 | dense_sampling=args.dense_sampling, 241 | transform=train_augmentors, is_train=True, test_mode=False, 242 | seperator=filename_seperator, filter_video=filter_video, 243 | fps=args.fps, audio_length=args.audio_length, 244 | resampling_rate=args.resampling_rate) 245 | train_loader = build_dataflow(train_dataset, is_train=True, batch_size=args.batch_size, 246 | workers=args.workers, is_distributed=args.distributed) 247 | 248 | ################## 249 | # use two separate optimizers 250 | policy_net_params = model.module.policy_net.parameters() 251 | p_optimizer = torch.optim.Adam(policy_net_params, args.p_lr, weight_decay=args.weight_decay) 252 | 253 | main_net_params = model.module.main_net.parameters() 254 | optimizer = torch.optim.SGD(main_net_params, args.lr, 255 | momentum=args.momentum, 256 | weight_decay=args.weight_decay, 257 | nesterov=args.nesterov) 258 | 259 | if args.lr_scheduler == 'step': 260 | p_scheduler = lr_scheduler.StepLR(p_optimizer, args.lr_steps[0], gamma=0.1) 261 | scheduler = lr_scheduler.StepLR(optimizer, args.lr_steps[0], gamma=0.1) 262 | elif args.lr_scheduler == 'multisteps': 263 | p_scheduler = lr_scheduler.MultiStepLR(p_optimizer, args.lr_steps, gamma=0.1) 264 | scheduler = lr_scheduler.MultiStepLR(optimizer, args.lr_steps, gamma=0.1) 265 | elif args.lr_scheduler == 'cosine': 266 | p_scheduler = lr_scheduler.CosineAnnealingLR(p_optimizer, args.epochs, eta_min=0) 267 | scheduler = lr_scheduler.CosineAnnealingLR(optimizer, args.epochs, eta_min=0) 268 | elif args.lr_scheduler == 'plateau': 269 | p_scheduler = lr_scheduler.ReduceLROnPlateau(p_optimizer, 'min', verbose=True) 270 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'min', verbose=True) 271 | 272 | best_top1 = 0.0 273 | curr_stage = args.curr_stage 274 | # optionally resume from a checkpoint 275 | if args.auto_resume: 276 | checkpoint_path = os.path.join(log_folder, 'checkpoint.pth.tar') 277 | if os.path.exists(checkpoint_path): 278 | args.resume = checkpoint_path 279 | print("Found the checkpoint in the log folder, will resume from there.") 280 | 281 | if args.resume: 282 | if args.rank == 0: 283 | logfile = open(os.path.join(log_folder, 'log.log'), 'a') 284 | if os.path.isfile(args.resume): 285 | if args.rank == 0: 286 | print("=> loading checkpoint '{}'".format(args.resume)) 287 | if args.gpu is None: 288 | checkpoint = torch.load(args.resume, map_location='cpu') 289 | else: 290 | checkpoint = torch.load(args.resume, map_location='cuda:{}'.format(args.gpu)) 291 | args.start_epoch = checkpoint['epoch'] 292 | best_top1 = checkpoint['best_top1'] 293 | curr_stage = checkpoint['stage'] 294 | if args.gpu is not None: 295 | if not isinstance(best_top1, float): 296 | best_top1 = best_top1.to(args.gpu) 297 | model.load_state_dict(checkpoint['state_dict']) 298 | optimizer.load_state_dict(checkpoint['optimizer']) 299 | p_optimizer.load_state_dict(checkpoint['p_optimizer']) 300 | try: 301 | p_scheduler.load_state_dict(checkpoint['p_scheduler']) 302 | scheduler.load_state_dict(checkpoint['scheduler']) 303 | except: 304 | pass 305 | model.module.policy_net.set_temperature(checkpoint['temperature']) 306 | if args.rank == 0: 307 | print("=> loaded checkpoint '{}' (epoch {})" 308 | .format(args.resume, checkpoint['epoch'])) 309 | del checkpoint # dereference seems crucial 310 | torch.cuda.empty_cache() 311 | else: 312 | raise ValueError("Checkpoint is not found: {}".format(args.resume)) 313 | 
else: 314 | if os.path.exists(os.path.join(log_folder, 'log.log')) and args.rank == 0: 315 | shutil.copyfile(os.path.join(log_folder, 'log.log'), os.path.join( 316 | log_folder, 'log.log.{}'.format(int(time.time())))) 317 | if args.rank == 0: 318 | logfile = open(os.path.join(log_folder, 'log.log'), 'w') 319 | 320 | if args.rank == 0: 321 | command = " ".join(sys.argv) 322 | # tensorboard_logger.configure(os.path.join(log_folder)) 323 | print(command, flush=True) 324 | print(args, flush=True) 325 | print(model, flush=True) 326 | print(command, file=logfile, flush=True) 327 | # print(model_summary, flush=True) 328 | print(args, file=logfile, flush=True) 329 | 330 | if args.resume == '' and args.rank == 0: 331 | print(model, file=logfile, flush=True) 332 | # print(model_summary, flush=True, file=logfile) 333 | 334 | ##### 335 | """ 336 | Stage 2: Warmup the main network 337 | Needed args: warmup epochs, lr, lr_scheduler, 338 | - Freeze the policy net 339 | """ 340 | if curr_stage == 'warmup': 341 | if args.warmup_epochs > 0: 342 | if args.rank == 0: 343 | print("Stage [Warming up]: Main network with {} epochs".format(args.warmup_epochs)) 344 | model.module.freeze_policy_net() 345 | model.module.unfreeze_main_net() 346 | for epoch in range(args.start_epoch, args.warmup_epochs): 347 | # train for one epoch 348 | train_top1, train_top5, train_losses, train_speed, speed_data_loader, train_steps, \ 349 | train_selection = train_adamml(train_loader, model, train_criterion, optimizer, p_optimizer, 350 | epoch + 1, major_modality, display=args.print_freq, 351 | clip_gradient=args.clip_gradient, gpu_id=args.gpu, 352 | rank=args.rank, eval_criterion=eval_criterion, 353 | cost_weights=[0.0] * len(major_modality), gammas=args.gammas, penalty_type=args.penalty_type) 354 | if args.distributed: 355 | dist.barrier() 356 | if args.rank == 0: 357 | selection_msg = "Selection: " 358 | for k, v in train_selection.items(): 359 | selection_msg += "{}:{:.2f} ".format(k, v.avg * 100) 360 | print('Train: [{:03d}/{:03d}]\tLoss: {:4.4f}\tTop@1: {:.4f}\tTop@5: {:.4f}\t' 361 | 'Speed: {:.2f} ms/batch\tData loading: {:.2f} ms/batch\t' 362 | '{}'.format(epoch + 1, args.warmup_epochs, train_losses, train_top1, 363 | train_top5, train_speed * 1000.0, 364 | speed_data_loader * 1000.0, 365 | selection_msg), file=logfile, flush=True) 366 | print('Train: [{:03d}/{:03d}]\tLoss: {:4.4f}\tTop@1: {:.4f}\tTop@5: {:.4f}\t' 367 | 'Speed: {:.2f} ms/batch\tData loading: {:.2f} ms/batch\t' 368 | '{}'.format(epoch + 1, args.warmup_epochs, train_losses, train_top1, 369 | train_top5, train_speed * 1000.0, 370 | speed_data_loader * 1000.0, 371 | selection_msg), flush=True) 372 | # change to policy net: 373 | save_dict = {'epoch': epoch + 1, 374 | 'arch': arch_name, 375 | 'state_dict': model.state_dict(), 376 | 'best_top1': best_top1, 377 | 'p_optimizer': p_optimizer.state_dict(), 378 | 'optimizer': optimizer.state_dict(), 379 | 'p_scheduler': p_scheduler.state_dict(), 380 | 'scheduler': scheduler.state_dict(), 381 | 'temperature': model.module.policy_net.temperature, 382 | 'stage': 'warmup' 383 | } 384 | save_checkpoint(save_dict, is_best=False, filepath=log_folder, epoch=epoch+1, suffix='_warmup') 385 | # move to the next stage 386 | curr_stage = 'alternative_training' 387 | policy_net_params = model.module.policy_net.parameters() 388 | p_optimizer = torch.optim.Adam(policy_net_params, args.p_lr, 389 | weight_decay=args.weight_decay) 390 | main_net_params = model.module.main_net.parameters() 391 | optimizer = 
torch.optim.SGD(main_net_params, args.lr, 392 | momentum=args.momentum, 393 | weight_decay=args.weight_decay, 394 | nesterov=args.nesterov) 395 | args.start_epoch = 0 396 | """ 397 | Stage 3: alternative training 398 | Main net (k epochs) -> Policy net (k epochs) -> Main net (k epochs) ... 399 | k is a configurable args, default = 1 400 | - alternative_k: 401 | - number of alternatives is determined by total_epochs // alternative_k 402 | """ 403 | if curr_stage == 'alternative_training': 404 | if args.rank == 0: 405 | print("Stage [Alternative training]: {} epochs".format(args.epochs)) 406 | for epoch in range(args.start_epoch, args.epochs): 407 | if args.rank == 0: 408 | print("Stage [Alternative training]: Training Main net") 409 | # start with main net: 410 | model.module.freeze_policy_net() 411 | model.module.unfreeze_main_net() 412 | # train for one epoch 413 | train_top1, train_top5, train_losses, train_speed, speed_data_loader, train_steps, train_selection = \ 414 | train_adamml(train_loader, model, train_criterion, optimizer, p_optimizer, 415 | epoch + 1, major_modality, 416 | display=args.print_freq, 417 | clip_gradient=args.clip_gradient, gpu_id=args.gpu, rank=args.rank, 418 | eval_criterion=eval_criterion, cost_weights=[0.0] * len(major_modality), 419 | gammas=args.gammas, penalty_type=args.penalty_type) 420 | if args.distributed: 421 | dist.barrier() 422 | 423 | if args.rank == 0: 424 | selection_msg = "Selection: " 425 | for k, v in train_selection.items(): 426 | selection_msg += "{}:{:.2f} ".format(k, v.avg * 100) 427 | print('Train: [{:03d}/{:03d}]\tLoss: {:4.4f}\tTop@1: {:.4f}\tTop@5: {:.4f}\t' 428 | 'Speed: {:.2f} ms/batch\tData loading: {:.2f} ms/batch\t' 429 | '{}'.format(epoch + 1, args.epochs, train_losses, train_top1, 430 | train_top5, train_speed * 1000.0, 431 | speed_data_loader * 1000.0, 432 | selection_msg), file=logfile, flush=True) 433 | print('Train: [{:03d}/{:03d}]\tLoss: {:4.4f}\tTop@1: {:.4f}\tTop@5: {:.4f}\t' 434 | 'Speed: {:.2f} ms/batch\tData loading: {:.2f} ms/batch\t' 435 | '{}'.format(epoch + 1, args.epochs, train_losses, train_top1, 436 | train_top5, train_speed * 1000.0, 437 | speed_data_loader * 1000.0, 438 | selection_msg), flush=True) 439 | # change to policy net: 440 | if args.rank == 0: 441 | print("Stage [Alternative training]: Training Policy net") 442 | model.module.unfreeze_policy_net() 443 | model.module.freeze_main_net() 444 | # train for one epoch 445 | train_top1, train_top5, train_losses, train_speed, speed_data_loader, train_steps, train_selection = \ 446 | train_adamml(train_loader, model, train_criterion, optimizer, p_optimizer, 447 | epoch + 1, major_modality, 448 | display=args.print_freq, 449 | clip_gradient=args.clip_gradient, gpu_id=args.gpu, rank=args.rank, 450 | eval_criterion=eval_criterion, cost_weights=args.cost_weights, 451 | gammas=args.gammas, penalty_type=args.penalty_type) 452 | if args.distributed: 453 | dist.barrier() 454 | 455 | # evaluate on validation set 456 | val_top1, val_top5, val_losses, val_speed, val_selection, mAP, all_selections, flops = validate_adamml( 457 | val_loader, model, val_criterion, args.val_num_clips, major_modality, gpu_id=args.gpu, eval_criterion=eval_criterion) 458 | 459 | # update current learning rate 460 | if args.lr_scheduler == 'plateau': 461 | p_scheduler.step(val_losses) 462 | scheduler.step(val_losses) 463 | else: 464 | p_scheduler.step(epoch + 1) 465 | scheduler.step(epoch + 1) 466 | 467 | if args.distributed: 468 | dist.barrier() 469 | 470 | # only logging at rank 0 471 | if 
args.rank == 0: 472 | selection_msg = "Selection: " 473 | for k, v in train_selection.items(): 474 | selection_msg += "{}:{:.2f} ".format(k, v.avg * 100) 475 | 476 | print('Train: [{:03d}/{:03d}]\tLoss: {:4.4f}\tTop@1: {:.4f}\tTop@5: {:.4f}\t' 477 | 'Speed: {:.2f} ms/batch\tData loading: {:.2f} ms/batch\t' 478 | '{}'.format(epoch + 1, args.epochs, train_losses, train_top1, 479 | train_top5, train_speed * 1000.0, 480 | speed_data_loader * 1000.0, 481 | selection_msg), file=logfile, flush=True) 482 | print('Train: [{:03d}/{:03d}]\tLoss: {:4.4f}\tTop@1: {:.4f}\tTop@5: {:.4f}\t' 483 | 'Speed: {:.2f} ms/batch\tData loading: {:.2f} ms/batch\t' 484 | '{}'.format(epoch + 1, args.epochs, train_losses, train_top1, 485 | train_top5, train_speed * 1000.0, 486 | speed_data_loader * 1000.0, 487 | selection_msg), flush=True) 488 | 489 | all_selections = all_selections.cpu().numpy().astype(bool) 490 | np.savez(os.path.join(log_folder, f'all_selection_main_{epoch + 1}.npz'), modality='_'.format(major_modality), selections=all_selections) 491 | selection_msg = "Selection: " 492 | for k, v in val_selection.items(): 493 | selection_msg += "{}:{:.2f} ".format(k, v.avg * 100) 494 | print(f'Val: [{epoch + 1:03d}/{args.epochs:03d}]: \tLoss: {val_losses:4.4f}\tTop@1: {val_top1:.4f}' 495 | f'\tTop@5: {val_top5:.4f}\tmAP: {mAP:.4f}\tSpeed: {val_speed:.2f} ms/batch\tflops: {flops:.2f}\t{selection_msg}', flush=True) 496 | print(f'Val: [{epoch + 1:03d}/{args.epochs:03d}]: \tLoss: {val_losses:4.4f}\tTop@1: {val_top1:.4f}' 497 | f'\tTop@5: {val_top5:.4f}\tmAP: {mAP:.4f}\tSpeed: {val_speed:.2f} ms/batch\tflops: {flops:.2f}\t{selection_msg}', flush=True, file=logfile) 498 | 499 | # remember best prec@1 and save checkpoint 500 | is_best = val_top1 > best_top1 501 | best_top1 = max(val_top1, best_top1) 502 | 503 | save_dict = {'epoch': epoch + 1, 504 | 'arch': arch_name, 505 | 'state_dict': model.state_dict(), 506 | 'best_top1': best_top1, 507 | 'p_optimizer': p_optimizer.state_dict(), 508 | 'optimizer': optimizer.state_dict(), 509 | 'p_scheduler': p_scheduler.state_dict(), 510 | 'scheduler': scheduler.state_dict(), 511 | 'temperature': model.module.policy_net.temperature, 512 | 'stage': 'alternative_training' 513 | } 514 | 515 | save_checkpoint(save_dict, is_best, filepath=log_folder, epoch=epoch+1, suffix='_main') 516 | model.module.decay_temperature() 517 | 518 | # move to next stage 519 | curr_stage = 'finetune' 520 | policy_net_params = model.module.policy_net.parameters() 521 | p_optimizer = torch.optim.Adam(policy_net_params, args.p_lr, 522 | weight_decay=args.weight_decay) 523 | main_net_params = model.module.main_net.parameters() 524 | optimizer = torch.optim.SGD(main_net_params, args.lr, 525 | momentum=args.momentum, 526 | weight_decay=args.weight_decay, 527 | nesterov=args.nesterov) 528 | args.start_epoch = 0 529 | """ 530 | Stage 4: fixed policy net and then finetune main network for few epochs 531 | finetune epochs 532 | """ 533 | if curr_stage == 'finetune': 534 | if args.rank == 0: 535 | print("Stage [Post finetuning]: Finetune the main network {} epochs".format(args.finetune_epochs)) 536 | if args.finetune_epochs > 0: 537 | # finetune on top of the best model when it moves from the previous stage 538 | if args.start_epoch == 0: 539 | try: 540 | best_model_path = os.path.join(log_folder, 'model_best.pth.tar') 541 | if args.gpu is None: 542 | checkpoint = torch.load(best_model_path, map_location='cpu') 543 | else: 544 | checkpoint = torch.load(best_model_path, map_location='cuda:{}'.format(args.gpu)) 545 | 
new_dict = checkpoint['state_dict'] 546 | model.module.policy_net.set_temperature(checkpoint['temperature']) 547 | model.load_state_dict(new_dict, strict=True) 548 | del checkpoint # dereference seems crucial 549 | torch.cuda.empty_cache() 550 | except Exception as e: 551 | print("Can not find the best model at {}. Use the last checkpoint.".format(log_folder)) 552 | model.module.freeze_policy_net() 553 | model.module.unfreeze_main_net() 554 | for epoch in range(args.start_epoch, args.finetune_epochs): 555 | # NOTE: disable cost weight here 556 | train_top1, train_top5, train_losses, train_speed, speed_data_loader, train_steps, \ 557 | train_selection = train_adamml(train_loader, model, train_criterion, optimizer, 558 | p_optimizer, 559 | epoch + 1, major_modality, display=args.print_freq, 560 | clip_gradient=args.clip_gradient, gpu_id=args.gpu, 561 | rank=args.rank, 562 | eval_criterion=eval_criterion, 563 | cost_weights=[0.0] * len(major_modality), gammas=args.gammas, 564 | penalty_type=args.penalty_type) 565 | if args.distributed: 566 | dist.barrier() 567 | # evaluate on validation set 568 | val_top1, val_top5, val_losses, val_speed, val_selection, mAP, all_selections, flops = validate_adamml( 569 | val_loader, model, val_criterion, args.val_num_clips, major_modality, gpu_id=args.gpu, eval_criterion=eval_criterion) 570 | 571 | # update current learning rate 572 | if args.lr_scheduler == 'plateau': 573 | p_scheduler.step(val_losses) 574 | scheduler.step(val_losses) 575 | else: 576 | p_scheduler.step(epoch + 1) 577 | scheduler.step(epoch + 1) 578 | 579 | if args.distributed: 580 | dist.barrier() 581 | 582 | # only logging at rank 0 583 | if args.rank == 0: 584 | selection_msg = "Selection: " 585 | for k, v in train_selection.items(): 586 | selection_msg += "{}:{:.2f} ".format(k, v.avg * 100) 587 | 588 | print('Train: [{:03d}/{:03d}]\tLoss: {:4.4f}\tTop@1: {:.4f}\tTop@5: {:.4f}\t' 589 | 'Speed: {:.2f} ms/batch\tData loading: {:.2f} ms/batch\t' 590 | '{}'.format(epoch + 1, args.finetune_epochs, train_losses, train_top1, 591 | train_top5, train_speed * 1000.0, 592 | speed_data_loader * 1000.0, 593 | selection_msg), file=logfile, flush=True) 594 | print('Train: [{:03d}/{:03d}]\tLoss: {:4.4f}\tTop@1: {:.4f}\tTop@5: {:.4f}\t' 595 | 'Speed: {:.2f} ms/batch\tData loading: {:.2f} ms/batch\t' 596 | '{}'.format(epoch + 1, args.finetune_epochs, train_losses, train_top1, 597 | train_top5, train_speed * 1000.0, 598 | speed_data_loader * 1000.0, 599 | selection_msg), flush=True) 600 | all_selections = all_selections.cpu().numpy().astype(bool) 601 | np.savez(os.path.join(log_folder, f'all_selection_finetune_{epoch+1}.npz'), modality='_'.format(major_modality), selections=all_selections) 602 | selection_msg = "Selection: " 603 | for k, v in val_selection.items(): 604 | selection_msg += "{}:{:.2f} ".format(k, v.avg * 100) 605 | print(f'Val: [{epoch + 1:03d}/{args.finetune_epochs:03d}]: \tLoss: {val_losses:4.4f}\tTop@1: {val_top1:.4f}' 606 | f'\tTop@5: {val_top5:.4f}\tmAP: {mAP:.4f}\tSpeed: {val_speed:.2f} ms/batch\tflops: {flops:.2f}\t{selection_msg}', flush=True) 607 | print(f'Val: [{epoch + 1:03d}/{args.finetune_epochs:03d}]: \tLoss: {val_losses:4.4f}\tTop@1: {val_top1:.4f}' 608 | f'\tTop@5: {val_top5:.4f}\tmAP: {mAP:.4f}\tSpeed: {val_speed:.2f} ms/batch\tflops: {flops:.2f}\t{selection_msg}', flush=True, file=logfile) 609 | 610 | # remember best prec@1 and save checkpoint 611 | is_best = val_top1 > best_top1 612 | best_top1 = max(val_top1, best_top1) 613 | 614 | save_dict = {'epoch': epoch + 1, 615 | 
'arch': arch_name, 616 | 'state_dict': model.state_dict(), 617 | 'best_top1': best_top1, 618 | 'p_optimizer': p_optimizer.state_dict(), 619 | 'optimizer': optimizer.state_dict(), 620 | 'p_scheduler': p_scheduler.state_dict(), 621 | 'scheduler': scheduler.state_dict(), 622 | 'temperature': model.module.policy_net.temperature, 623 | 'stage': 'finetune' 624 | } 625 | 626 | save_checkpoint(save_dict, is_best, filepath=log_folder, epoch=epoch+1, suffix='_finetune') 627 | 628 | if args.rank == 0: 629 | logfile.close() 630 | 631 | 632 | if __name__ == '__main__': 633 | main() 634 | -------------------------------------------------------------------------------- /train_unimodal.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import time 4 | import numpy as np 5 | import sys 6 | import warnings 7 | import platform 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.parallel 12 | import torch.backends.cudnn as cudnn 13 | import torch.distributed as dist 14 | import torch.optim 15 | import torch.utils.data 16 | import torch.utils.data.distributed 17 | import torch.multiprocessing as mp 18 | from torch.optim import lr_scheduler 19 | import torchsummary 20 | 21 | from models import build_model 22 | from utils.utils import (train, validate, build_dataflow, get_augmentor, 23 | save_checkpoint, extract_total_flops_params, 24 | accuracy, actnet_acc) 25 | from utils.video_dataset import VideoDataSet 26 | from utils.dataset_config import get_dataset_config 27 | from opts import arg_parser 28 | 29 | 30 | warnings.filterwarnings("ignore", category=UserWarning) 31 | 32 | 33 | def main(): 34 | global args 35 | parser = arg_parser() 36 | args = parser.parse_args() 37 | 38 | if args.hostfile != '': 39 | curr_node_name = platform.node().split(".")[0] 40 | with open(args.hostfile) as f: 41 | nodes = [x.strip() for x in f.readlines() if x.strip() != ''] 42 | master_node = nodes[0].split(" ")[0] 43 | for idx, node in enumerate(nodes): 44 | if curr_node_name in node: 45 | args.rank = idx 46 | break 47 | args.world_size = len(nodes) 48 | args.dist_url = "tcp://{}:10598".format(master_node) 49 | 50 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed 51 | ngpus_per_node = torch.cuda.device_count() 52 | if args.multiprocessing_distributed: 53 | # Since we have ngpus_per_node processes per node, the total world_size 54 | # needs to be adjusted accordingly 55 | args.world_size = ngpus_per_node * args.world_size 56 | # Use torch.multiprocessing.spawn to launch distributed processes: the 57 | # main_worker process function 58 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) 59 | else: 60 | # Simply call main_worker function 61 | main_worker(args.gpu, ngpus_per_node, args) 62 | 63 | 64 | def main_worker(gpu, ngpus_per_node, args): 65 | cudnn.benchmark = args.cudnn_benchmark 66 | args.gpu = gpu 67 | 68 | args.datadir = args.datadir[0] 69 | args.modality = args.modality[0] 70 | 71 | num_classes, train_list_name, val_list_name, test_list_name, filename_seperator, image_tmpl, filter_video, label_file = get_dataset_config(args.dataset) 72 | args.num_classes = num_classes 73 | 74 | if args.gpu is not None: 75 | print("Use GPU: {} for training".format(args.gpu)) 76 | 77 | if args.distributed: 78 | if args.dist_url == "env://" and args.rank == -1: 79 | args.rank = int(os.environ["RANK"]) 80 | if args.multiprocessing_distributed: 81 | # For multiprocessing distributed training, rank needs to be 
the 82 | # global rank among all the processes 83 | args.rank = args.rank * ngpus_per_node + gpu 84 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 85 | world_size=args.world_size, rank=args.rank) 86 | 87 | if args.modality == 'rgb': 88 | args.input_channels = 3 89 | elif args.modality == 'flow': 90 | args.input_channels = 2 * 5 91 | elif args.modality == 'rgbdiff': 92 | args.input_channels = 3 * 5 93 | elif args.modality == 'sound': 94 | args.input_channels = 1 95 | 96 | model, arch_name = build_model(args) 97 | mean = model.mean(args.modality) 98 | std = model.std(args.modality) 99 | 100 | # overwrite mean and std if they are provided on the command line 101 | if args.mean is not None: 102 | if args.modality == 'rgb' or args.modality == 'rgbdiff': 103 | if len(args.mean) != 3: 104 | raise ValueError("When training with rgb, dim of mean must be three.") 105 | elif args.modality == 'flow': 106 | if len(args.mean) != 1: 107 | raise ValueError("When training with flow, dim of mean must be one.") 108 | mean = args.mean 109 | 110 | if args.std is not None: 111 | if args.modality == 'rgb' or args.modality == 'rgbdiff': 112 | if len(args.std) != 3: 113 | raise ValueError("When training with rgb, dim of std must be three.") 114 | elif args.modality == 'flow': 115 | if len(args.std) != 1: 116 | raise ValueError("When training with flow, dim of std must be one.") 117 | std = args.std 118 | 119 | model = model.cuda(args.gpu) 120 | model.eval() 121 | 122 | if args.modality == 'sound': 123 | num_frames = 1 124 | else: 125 | num_frames = args.groups 126 | 127 | if args.threed_data: 128 | dummy_data = (args.input_channels, num_frames, args.input_size, args.input_size) 129 | else: 130 | dummy_data = (args.input_channels * num_frames, args.input_size, args.input_size) 131 | 132 | if args.rank == 0: 133 | model_summary = torchsummary.summary(model, input_size=dummy_data) 134 | torch.cuda.empty_cache() 135 | 136 | if args.show_model and args.rank == 0: 137 | print(model) 138 | print(model_summary) 139 | return 0 140 | 141 | if args.distributed: 142 | # For multiprocessing distributed, DistributedDataParallel constructor 143 | # should always set the single device scope, otherwise, 144 | # DistributedDataParallel will use all available devices. 
145 | if args.gpu is not None: 146 | torch.cuda.set_device(args.gpu) 147 | model.cuda(args.gpu) 148 | # When using a single GPU per process and per 149 | # DistributedDataParallel, we need to divide the batch size 150 | # ourselves based on the total number of GPUs we have 151 | # the batch size should be divided by number of nodes as well 152 | args.batch_size = int(args.batch_size / args.world_size) 153 | args.workers = int(args.workers / ngpus_per_node) 154 | 155 | if args.sync_bn: 156 | process_group = torch.distributed.new_group(list(range(args.world_size))) 157 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model, process_group) 158 | 159 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 160 | else: 161 | model.cuda() 162 | # DistributedDataParallel will divide and allocate batch_size to all 163 | # available GPUs if device_ids are not set 164 | model = torch.nn.parallel.DistributedDataParallel(model) 165 | elif args.gpu is not None: 166 | torch.cuda.set_device(args.gpu) 167 | model = model.cuda(args.gpu) 168 | else: 169 | # DataParallel will divide and allocate batch_size to all available GPUs 170 | # assign rank to 0 171 | model = torch.nn.DataParallel(model).cuda() 172 | args.rank = 0 173 | 174 | if args.pretrained is not None: 175 | if args.rank == 0: 176 | print("=> using pre-trained model '{}'".format(arch_name)) 177 | if args.gpu is None: 178 | checkpoint = torch.load(args.pretrained, map_location='cpu') 179 | else: 180 | checkpoint = torch.load(args.pretrained, map_location='cuda:{}'.format(args.gpu)) 181 | if args.transfer: 182 | new_dict = {} 183 | for k, v in checkpoint['state_dict'].items(): 184 | if k.startswith("module.fc") or k.startswith('module._fc'): 185 | continue 186 | new_dict[k] = v 187 | if args.input_channels != 3: 188 | from models.inflate_from_2d_model import convert_rgb_model_to_others 189 | new_dict = convert_rgb_model_to_others(new_dict, args.input_channels, ks=7) 190 | else: 191 | new_dict = checkpoint['state_dict'] 192 | msg = model.load_state_dict(new_dict, strict=False) 193 | if args.rank == 0: 194 | print(msg, flush=True) 195 | del checkpoint # dereference seems crucial 196 | torch.cuda.empty_cache() 197 | else: 198 | if args.rank == 0: 199 | print("=> creating model '{}'".format(arch_name)) 200 | 201 | # define loss function (criterion) and optimizer 202 | train_criterion = nn.CrossEntropyLoss().cuda(args.gpu) 203 | val_criterion = nn.CrossEntropyLoss().cuda(args.gpu) 204 | eval_criterion = accuracy 205 | 206 | # Data loading code 207 | val_augmentor = get_augmentor(False, args.input_size, scale_range=args.scale_range, mean=mean, std=std, disable_scaleup=args.disable_scaleup, 208 | threed_data=args.threed_data, modality=args.modality, 209 | version=args.augmentor_ver) 210 | 211 | video_data_cls = VideoDataSet 212 | val_dataset = video_data_cls(args.datadir, val_list_name, args.groups, args.frames_per_group, num_clips=args.num_clips, 213 | modality=args.modality, image_tmpl=image_tmpl, dense_sampling=args.dense_sampling, 214 | num_classes=args.num_classes, 215 | transform=val_augmentor, is_train=False, test_mode=False, 216 | seperator=filename_seperator, filter_video=filter_video, 217 | fps=args.fps, audio_length=args.audio_length, 218 | resampling_rate=args.resampling_rate) 219 | 220 | val_loader = build_dataflow(val_dataset, is_train=False, batch_size=args.batch_size, workers=args.workers, 221 | is_distributed=args.distributed) 222 | 223 | log_folder = os.path.join(args.logdir, arch_name) 224 | if args.rank 
== 0: 225 | if not os.path.exists(log_folder): 226 | os.makedirs(log_folder) 227 | 228 | if args.evaluate: 229 | val_top1, val_top5, val_losses, val_speed = validate(val_loader, model, val_criterion, gpu_id=args.gpu) 230 | if args.rank == 0: 231 | logfile = open(os.path.join(log_folder, 'evaluate_log.log'), 'a') 232 | flops, params = extract_total_flops_params(model_summary) 233 | print( 234 | 'Val@{}: \tLoss: {:4.4f}\tTop@1: {:.4f}\tTop@5: {:.4f}\tSpeed: {:.2f} ms/batch\tFlops: {}\tParams: {}'.format( 235 | args.input_size, val_losses, val_top1, val_top5, val_speed * 1000.0, flops, params), flush=True) 236 | print(args.pretrained, flush=True, file=logfile) 237 | print( 238 | 'Val@{}: \tLoss: {:4.4f}\tTop@1: {:.4f}\tTop@5: {:.4f}\tSpeed: {:.2f} ms/batch\tFlops: {}\tParams: {}'.format( 239 | args.input_size, val_losses, val_top1, val_top5, val_speed * 1000.0, flops, params), flush=True, 240 | file=logfile) 241 | return 242 | 243 | train_augmentor = get_augmentor(True, args.input_size, scale_range=args.scale_range, mean=mean, std=std, 244 | disable_scaleup=args.disable_scaleup, threed_data=args.threed_data, 245 | modality=args.modality, version=args.augmentor_ver) 246 | 247 | train_dataset = video_data_cls(args.datadir, train_list_name, args.groups, args.frames_per_group, num_clips=args.num_clips, 248 | modality=args.modality, image_tmpl=image_tmpl, dense_sampling=args.dense_sampling, 249 | num_classes=args.num_classes, 250 | transform=train_augmentor, is_train=True, test_mode=False, 251 | seperator=filename_seperator, filter_video=filter_video, 252 | fps=args.fps, audio_length=args.audio_length, 253 | resampling_rate=args.resampling_rate) 254 | 255 | train_loader = build_dataflow(train_dataset, is_train=True, batch_size=args.batch_size, 256 | workers=args.workers, is_distributed=args.distributed) 257 | 258 | sgd_polices = model.parameters() 259 | optimizer = torch.optim.SGD(sgd_polices, args.lr, 260 | momentum=args.momentum, 261 | weight_decay=args.weight_decay, 262 | nesterov=args.nesterov) 263 | 264 | if args.lr_scheduler == 'step': 265 | scheduler = lr_scheduler.StepLR(optimizer, args.lr_steps[0], gamma=0.1) 266 | elif args.lr_scheduler == 'multisteps': 267 | scheduler = lr_scheduler.MultiStepLR(optimizer, args.lr_steps, gamma=0.1) 268 | elif args.lr_scheduler == 'cosine': 269 | scheduler = lr_scheduler.CosineAnnealingLR(optimizer, args.epochs, eta_min=0) 270 | elif args.lr_scheduler == 'plateau': 271 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'min', verbose=True) 272 | 273 | best_top1 = 0.0 274 | if args.auto_resume: 275 | checkpoint_path = os.path.join(log_folder, 'checkpoint.pth.tar') 276 | if os.path.exists(checkpoint_path): 277 | args.resume = checkpoint_path 278 | print("Found the checkpoint in the log folder, will resume from there.") 279 | 280 | # optionally resume from a checkpoint 281 | if args.resume: 282 | if args.rank == 0: 283 | logfile = open(os.path.join(log_folder, 'log.log'), 'a') 284 | if os.path.isfile(args.resume): 285 | if args.rank == 0: 286 | print("=> loading checkpoint '{}'".format(args.resume)) 287 | if args.gpu is None: 288 | checkpoint = torch.load(args.resume, map_location='cpu') 289 | else: 290 | checkpoint = torch.load(args.resume, map_location='cuda:{}'.format(args.gpu)) 291 | args.start_epoch = checkpoint['epoch'] 292 | best_top1 = checkpoint['best_top1'] 293 | if not isinstance(best_top1, float): 294 | if args.gpu is not None: 295 | best_top1 = best_top1.to(args.gpu) 296 | else: 297 | best_top1 = best_top1.cuda() 298 | 
model.load_state_dict(checkpoint['state_dict']) 299 | optimizer.load_state_dict(checkpoint['optimizer']) 300 | try: 301 | scheduler.load_state_dict(checkpoint['scheduler']) 302 | except: 303 | pass 304 | if args.rank == 0: 305 | print("=> loaded checkpoint '{}' (epoch {})" 306 | .format(args.resume, checkpoint['epoch'])) 307 | del checkpoint # dereference seems crucial 308 | torch.cuda.empty_cache() 309 | else: 310 | raise ValueError("Checkpoint is not found: {}".format(args.resume)) 311 | else: 312 | if os.path.exists(os.path.join(log_folder, 'log.log')) and args.rank == 0: 313 | shutil.copyfile(os.path.join(log_folder, 'log.log'), os.path.join( 314 | log_folder, 'log.log.{}'.format(int(time.time())))) 315 | if args.rank == 0: 316 | logfile = open(os.path.join(log_folder, 'log.log'), 'w') 317 | 318 | if args.rank == 0: 319 | command = " ".join(sys.argv) 320 | print(command, flush=True) 321 | print(args, flush=True) 322 | print(model, flush=True) 323 | print(command, file=logfile, flush=True) 324 | print(model_summary, flush=True) 325 | print(args, file=logfile, flush=True) 326 | 327 | if args.resume == '' and args.rank == 0: 328 | print(model, file=logfile, flush=True) 329 | print(model_summary, flush=True, file=logfile) 330 | 331 | for epoch in range(args.start_epoch, args.epochs): 332 | # train for one epoch 333 | train_top1, train_top5, train_losses, train_speed, speed_data_loader, train_steps = \ 334 | train(train_loader, model, train_criterion, optimizer, epoch + 1, 335 | display=args.print_freq, clip_gradient=args.clip_gradient, 336 | gpu_id=args.gpu, rank=args.rank, 337 | eval_criterion=eval_criterion) 338 | if args.distributed: 339 | dist.barrier() 340 | 341 | eval_this_epoch = True 342 | if args.lazy_eval: 343 | if (epoch + 1) % 10 == 0 or (epoch + 1) >= args.epochs * 0.9: 344 | eval_this_epoch = True 345 | else: 346 | eval_this_epoch = False 347 | 348 | if eval_this_epoch: 349 | # evaluate on validation set 350 | val_top1, val_top5, val_losses, val_speed = validate( 351 | val_loader, model, val_criterion, gpu_id=args.gpu, eval_criterion=eval_criterion) 352 | else: 353 | val_top1, val_top5, val_losses, val_speed = 0.0, 0.0, 0.0, 0.0 354 | 355 | # update current learning rate 356 | if args.lr_scheduler == 'plateau': 357 | scheduler.step(val_losses) 358 | else: 359 | scheduler.step(epoch+1) 360 | 361 | if args.distributed: 362 | dist.barrier() 363 | 364 | # only logging at rank 0 365 | if args.rank == 0: 366 | print( 367 | 'Train: [{:03d}/{:03d}]\tLoss: {:4.4f}\tTop@1: {:.4f}\tTop@5: {:.4f}\tSpeed: {:.2f} ms/batch\tData loading: {:.2f} ms/batch'.format( 368 | epoch + 1, args.epochs, train_losses, train_top1, train_top5, train_speed * 1000.0, 369 | speed_data_loader * 1000.0), file=logfile, flush=True) 370 | print( 371 | 'Train: [{:03d}/{:03d}]\tLoss: {:4.4f}\tTop@1: {:.4f}\tTop@5: {:.4f}\tSpeed: {:.2f} ms/batch\tData loading: {:.2f} ms/batch'.format( 372 | epoch + 1, args.epochs, train_losses, train_top1, train_top5, train_speed * 1000.0, 373 | speed_data_loader * 1000.0), flush=True) 374 | if eval_this_epoch: 375 | print('Val : [{:03d}/{:03d}]\tLoss: {:4.4f}\tTop@1: {:.4f}\tTop@5: {:.4f}\tSpeed: {:.2f} ms/batch'.format( 376 | epoch + 1, args.epochs, val_losses, val_top1, val_top5, val_speed * 1000.0), file=logfile, flush=True) 377 | print('Val : [{:03d}/{:03d}]\tLoss: {:4.4f}\tTop@1: {:.4f}\tTop@5: {:.4f}\tSpeed: {:.2f} ms/batch'.format( 378 | epoch + 1, args.epochs, val_losses, val_top1, val_top5, val_speed * 1000.0), flush=True) 379 | 380 | # remember best prec@1 and save 
checkpoint 381 | is_best = val_top1 > best_top1 382 | best_top1 = max(val_top1, best_top1) 383 | 384 | 385 | save_dict = {'epoch': epoch + 1, 386 | 'arch': arch_name, 387 | 'state_dict': model.state_dict(), 388 | 'best_top1': best_top1, 389 | 'optimizer': optimizer.state_dict(), 390 | 'scheduler': scheduler.state_dict() 391 | } 392 | 393 | save_checkpoint(save_dict, is_best, filepath=log_folder) 394 | try: 395 | # get_lr get all lrs for every layer of current epoch, assume the lr for all layers are identical 396 | lr = scheduler.optimizer.param_groups[0]['lr'] 397 | except Exception as e: 398 | lr = None 399 | 400 | if args.distributed: 401 | dist.barrier() 402 | 403 | if args.rank == 0: 404 | logfile.close() 405 | 406 | 407 | if __name__ == '__main__': 408 | main() 409 | -------------------------------------------------------------------------------- /utils/dataset_config.py: -------------------------------------------------------------------------------- 1 | 2 | """ 3 | For each dataset, the following fields are required: 4 | - num_classes: number of classes 5 | - train_list_name: the filename of train list 6 | - val_list_name: the filename of val list 7 | - filename_separator: the separator used in train/val/test list 8 | - image_tmpl: the template of images in the video folder 9 | - filter_video: the threshold to remove videos whose frame number is less than this value 10 | - label_file: a file contains mapping between label index to class name 11 | 12 | Those are optional: 13 | - test_list_name: the filename of test list 14 | - label_file: name of classes, used to map the prediction from a model to real label name 15 | 16 | """ 17 | 18 | 19 | DATASET_CONFIG = { 20 | 'kinetics-sounds': { 21 | 'num_classes': 31, 22 | 'train_list_name': 'train.txt', 23 | 'val_list_name': 'val.txt', 24 | 'filename_seperator': ";", 25 | 'image_tmpl': '{:05d}.jpg', 26 | 'filter_video': 0, 27 | 'label_file': 'categories.txt' 28 | } 29 | } 30 | 31 | 32 | def get_dataset_config(dataset): 33 | ret = DATASET_CONFIG[dataset] 34 | num_classes = ret['num_classes'] 35 | train_list_name = ret['train_list_name'] 36 | val_list_name = ret['val_list_name'] 37 | test_list_name = ret.get('test_list_name', None) 38 | if test_list_name is not None: 39 | test_list_name = test_list_name 40 | filename_seperator = ret['filename_seperator'] 41 | image_tmpl = ret['image_tmpl'] 42 | filter_video = ret.get('filter_video', 0) 43 | label_file = ret.get('label_file', None) 44 | 45 | return num_classes, train_list_name, val_list_name, test_list_name, filename_seperator, \ 46 | image_tmpl, filter_video, label_file 47 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | import os 3 | import time 4 | import multiprocessing 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | from torch.nn.utils import clip_grad_norm_ 9 | import torch.distributed as dist 10 | import torch.nn.parallel 11 | import torch.optim 12 | import torch.utils.data 13 | import torch.utils.data.distributed 14 | import torchvision.transforms as transforms 15 | from tqdm import tqdm 16 | 17 | from .video_transforms import (GroupRandomHorizontalFlip, 18 | GroupMultiScaleCrop, GroupScale, GroupCenterCrop, GroupRandomCrop, 19 | GroupNormalize, Stack, ToTorchFormatTensor, GroupRandomScale) 20 | 21 | from torch.utils.data import DataLoader 22 | 23 | 24 | class AverageMeter(object): 25 | """Computes and stores the average 
and current value""" 26 | def __init__(self): 27 | self.reset() 28 | 29 | def reset(self): 30 | self.val = 0 31 | self.avg = 0 32 | self.sum = 0 33 | self.count = 0 34 | 35 | def update(self, val, n=1): 36 | self.val = val 37 | self.sum += val * n 38 | self.count += n 39 | self.avg = self.sum / self.count 40 | 41 | 42 | def accuracy(output, target, topk=(1, 5)): 43 | """Computes the precision@k for the specified values of k""" 44 | with torch.no_grad(): 45 | maxk = max(topk) 46 | batch_size = target.size(0) 47 | 48 | _, pred = output.topk(maxk, 1, True, True) 49 | pred = pred.t() 50 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 51 | 52 | res = [] 53 | for k in topk: 54 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 55 | res.append(correct_k.mul_(100.0 / batch_size)) 56 | return res 57 | 58 | def actnet_acc(logits, test_y, topk=None, have_softmaxed=False): 59 | from torchnet import meter 60 | 61 | """ 62 | 63 | :param logits: (NxK) 64 | :param test_y: (Nx1) 65 | :param topk (tuple(int)): 66 | :return: 67 | - list[float]: topk acc 68 | - float: mAP 69 | """ 70 | num_classes = logits.shape[1] 71 | topk = [1, min(5, num_classes)] if topk is None else topk 72 | single_label = True if len(test_y.shape) == 1 else False 73 | probs = F.softmax(logits, dim=1) if not have_softmaxed else logits 74 | if single_label: 75 | acc_meter = meter.ClassErrorMeter(topk=topk, accuracy=True) 76 | acc_meter.add(logits, test_y) 77 | acc = acc_meter.value() 78 | gt = torch.zeros_like(logits) 79 | gt[torch.LongTensor(range(gt.size(0))), test_y.view(-1)] = 1 80 | else: 81 | gt = test_y 82 | acc = [0] * len(topk) 83 | map_meter = meter.mAPMeter() 84 | map_meter.add(probs, gt) 85 | ap = map_meter.value() * 100.0 86 | return acc, ap.item() 87 | 88 | 89 | def save_checkpoint(state, is_best, filepath='', epoch=None, suffix=''): 90 | curr_checkpoint_path = os.path.join(filepath, 'checkpoint.pth.tar') 91 | torch.save(state, curr_checkpoint_path) 92 | 93 | if epoch: 94 | shutil.copyfile(curr_checkpoint_path, os.path.join(filepath, 'checkpoint{}_{:02d}.pth.tar'.format(suffix, epoch))) 95 | if is_best: 96 | shutil.copyfile(curr_checkpoint_path, os.path.join(filepath, 'model_best.pth.tar')) 97 | 98 | def extract_total_flops_params(summary): 99 | for line in summary.split("\n"): 100 | line = line.strip() 101 | if line == "": 102 | continue 103 | if "Total flops" in line: 104 | total_flops = line.split(":")[-1].strip() 105 | elif "Total params" in line: 106 | total_params = line.split(":")[-1].strip() 107 | 108 | return total_flops, total_params 109 | 110 | def get_augmentor(is_train, image_size, mean=None, 111 | std=None, disable_scaleup=False, 112 | threed_data=False, version='v1', scale_range=None, 113 | modality='rgb', num_clips=1, num_crops=1): 114 | 115 | mean = [0.485, 0.456, 0.406] if mean is None else mean 116 | std = [0.229, 0.224, 0.225] if std is None else std 117 | scale_range = [256, 320] if scale_range is None else scale_range 118 | 119 | if modality == 'sound': 120 | augments = [ 121 | Stack(threed_data=threed_data), 122 | ToTorchFormatTensor(div=False, num_clips_crops=num_clips * num_crops) 123 | ] 124 | else: 125 | augments = [] 126 | if is_train: 127 | if version == 'v1': 128 | augments += [ 129 | GroupMultiScaleCrop(image_size, [1, .875, .75, .66]) 130 | ] 131 | elif version == 'v2': 132 | augments += [ 133 | GroupRandomScale(scale_range), 134 | GroupRandomCrop(image_size), 135 | ] 136 | augments += [GroupRandomHorizontalFlip(is_flow=(modality == 'flow'))] 137 | else: 138 | 
scaled_size = image_size if disable_scaleup else int(image_size / 0.875 + 0.5) 139 | augments += [ 140 | GroupScale(scaled_size), 141 | GroupCenterCrop(image_size) 142 | ] 143 | augments += [ 144 | Stack(threed_data=threed_data), 145 | ToTorchFormatTensor(num_clips_crops=num_clips * num_crops), 146 | GroupNormalize(mean=mean, std=std, threed_data=threed_data) 147 | ] 148 | 149 | augmentor = transforms.Compose(augments) 150 | return augmentor 151 | 152 | 153 | def build_dataflow(dataset, is_train, batch_size, workers=36, is_distributed=False): 154 | workers = min(workers, multiprocessing.cpu_count()) 155 | shuffle = False 156 | 157 | sampler = torch.utils.data.distributed.DistributedSampler(dataset) if is_distributed else None 158 | if is_train: 159 | shuffle = sampler is None 160 | 161 | data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, 162 | num_workers=workers, pin_memory=True, sampler=sampler) 163 | return data_loader 164 | 165 | 166 | def compute_policy_loss(penalty_type, selection, cost_weights, gammas, cls_logits, cls_targets): 167 | num_modality = selection.shape[-1] 168 | policy_loss = torch.tensor(0.0, dtype=selection.dtype, device=selection.device) 169 | if penalty_type == 'mean': 170 | for w, pl in zip(cost_weights, selection.chunk(chunks=num_modality, dim=-1)): 171 | policy_loss = policy_loss + w * torch.mean(pl) 172 | 173 | elif penalty_type == 'blockdrop': 174 | top1_pred = torch.argmax(cls_logits.detach(), dim=-1) 175 | correctness = (top1_pred == cls_targets).type_as(cls_logits) 176 | 177 | selection = torch.mean(selection, dim=1) # compute the selection per video per modality 178 | selection = selection * selection # square it 179 | for w, pl in zip(cost_weights, selection.chunk(chunks=num_modality, dim=-1)): 180 | # pl: Nx1 181 | loss = w * torch.mean(correctness * pl) 182 | policy_loss = policy_loss + loss 183 | policy_loss = policy_loss + torch.mean((torch.ones_like(correctness) - correctness) * gammas) 184 | return policy_loss 185 | 186 | 187 | def train(data_loader, model, criterion, optimizer, epoch, display=100, 188 | steps_per_epoch=99999999999, num_classes=None, 189 | clip_gradient=None, gpu_id=None, rank=0, eval_criterion=accuracy, **kwargs): 190 | batch_time = AverageMeter() 191 | data_time = AverageMeter() 192 | losses = AverageMeter() 193 | top1 = AverageMeter() 194 | top5 = AverageMeter() 195 | 196 | # set different random see every epoch 197 | if dist.is_initialized(): 198 | data_loader.sampler.set_epoch(epoch) 199 | 200 | # switch to train mode 201 | model.train() 202 | end = time.time() 203 | num_batch = 0 204 | if gpu_id is None or gpu_id == 0: 205 | disable_status_bar = False 206 | else: 207 | disable_status_bar = True 208 | 209 | with tqdm(total=len(data_loader), disable=disable_status_bar) as t_bar: 210 | for i, (images, target) in enumerate(data_loader): 211 | # measure data loading time 212 | data_time.update(time.time() - end) 213 | # compute output 214 | if gpu_id is not None: 215 | if isinstance(images, list): 216 | images = [x.cuda(gpu_id, non_blocking=True) for x in images] 217 | else: 218 | images = images.cuda(gpu_id, non_blocking=True) 219 | output = model(images) 220 | target = target.cuda(gpu_id, non_blocking=True) 221 | # target = target.cuda(non_blocking=True) 222 | loss = criterion(output, target) 223 | # measure accuracy and record loss 224 | prec1, prec5 = eval_criterion(output, target) 225 | prec1 = prec1.to(device=loss.device) 226 | prec5 = prec5.to(device=loss.device) 227 | 228 | if dist.is_initialized(): 229 
| world_size = dist.get_world_size() 230 | dist.all_reduce(prec1) 231 | dist.all_reduce(prec5) 232 | prec1 /= world_size 233 | prec5 /= world_size 234 | 235 | losses.update(loss.item(), target.size(0)) 236 | top1.update(prec1[0], target.size(0)) 237 | top5.update(prec5[0], target.size(0)) 238 | # compute gradient and do SGD step 239 | loss.backward() 240 | 241 | if clip_gradient is not None: 242 | _ = clip_grad_norm_(model.parameters(), clip_gradient) 243 | 244 | optimizer.step() 245 | optimizer.zero_grad() 246 | 247 | # measure elapsed time 248 | batch_time.update(time.time() - end) 249 | end = time.time() 250 | if i % display == 0 and rank == 0: 251 | print('Epoch: [{0}][{1}/{2}]\t' 252 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 253 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 254 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 255 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 256 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( 257 | epoch, i, len(data_loader), batch_time=batch_time, 258 | data_time=data_time, loss=losses, top1=top1, top5=top5), flush=True) 259 | num_batch += 1 260 | t_bar.update(1) 261 | 262 | if i > steps_per_epoch: 263 | break 264 | torch.cuda.empty_cache() 265 | return top1.avg, top5.avg, losses.avg, batch_time.avg, data_time.avg, num_batch 266 | 267 | 268 | def validate(data_loader, model, criterion, gpu_id=None, eval_criterion=accuracy): 269 | batch_time = AverageMeter() 270 | losses = AverageMeter() 271 | top1 = AverageMeter() 272 | top5 = AverageMeter() 273 | 274 | # switch to evaluate mode 275 | model.eval() 276 | if gpu_id is None or gpu_id == 0: 277 | disable_status_bar = False 278 | else: 279 | disable_status_bar = True 280 | 281 | with torch.no_grad(), tqdm(total=len(data_loader), disable=disable_status_bar) as t_bar: 282 | end = time.time() 283 | for i, (images, target) in enumerate(data_loader): 284 | 285 | if gpu_id is not None: 286 | if isinstance(images, list): 287 | images = [x.cuda(gpu_id, non_blocking=True) for x in images] 288 | else: 289 | images = images.cuda(gpu_id, non_blocking=True) 290 | 291 | target = target.cuda(gpu_id, non_blocking=True) 292 | 293 | # compute output 294 | output = model(images) 295 | loss = criterion(output, target) 296 | 297 | # measure accuracy and record loss 298 | prec1, prec5 = eval_criterion(output, target) 299 | prec1 = prec1.to(device=loss.device) 300 | prec5 = prec5.to(device=loss.device) 301 | 302 | if dist.is_initialized(): 303 | world_size = dist.get_world_size() 304 | dist.all_reduce(prec1) 305 | dist.all_reduce(prec5) 306 | prec1 /= world_size 307 | prec5 /= world_size 308 | losses.update(loss.item(), target.size(0)) 309 | top1.update(prec1[0], target.size(0)) 310 | top5.update(prec5[0], target.size(0)) 311 | 312 | # measure elapsed time 313 | batch_time.update(time.time() - end) 314 | end = time.time() 315 | t_bar.update(1) 316 | torch.cuda.empty_cache() 317 | return top1.avg, top5.avg, losses.avg, batch_time.avg 318 | 319 | def train_adamml(data_loader, model, criterion, optimizer, p_optimizer, epoch, modality, display=100, 320 | steps_per_epoch=99999999999, clip_gradient=None, gpu_id=None, 321 | rank=0, eval_criterion=accuracy, cost_weights=None, 322 | gammas=None, penalty_type='blockdrop'): 323 | batch_time = AverageMeter() 324 | data_time = AverageMeter() 325 | losses = AverageMeter() 326 | top1 = AverageMeter() 327 | top5 = AverageMeter() 328 | selection_meter = { m:AverageMeter() for m in modality } 329 | 330 | # set different random see every epoch 331 | if dist.is_initialized(): 332 | 
data_loader.sampler.set_epoch(epoch) 333 | 334 | # switch to train mode 335 | model.train() 336 | model.zero_grad() 337 | end = time.time() 338 | num_batch = 0 339 | cost_weights = [0.0] * len(modality) if cost_weights is None else cost_weights 340 | cost_weights = torch.tensor(cost_weights).cuda() 341 | 342 | gammas = torch.tensor(gammas).cuda() 343 | if gpu_id is None or gpu_id == 0: 344 | disable_status_bar = False 345 | else: 346 | disable_status_bar = True 347 | 348 | with tqdm(total=len(data_loader), disable=disable_status_bar) as t_bar: 349 | for i, (images, target) in enumerate(data_loader): 350 | # measure data loading time 351 | data_time.update(time.time() - end) 352 | # compute output 353 | if gpu_id is not None: 354 | if isinstance(images, list): 355 | images = [x.cuda(gpu_id, non_blocking=True) for x in images] 356 | else: 357 | images = images.cuda(gpu_id, non_blocking=True) 358 | 359 | output, selection = model(images) 360 | # dim of selection: NxSxM 361 | target = target.cuda(gpu_id, non_blocking=True) 362 | policy_loss = compute_policy_loss(penalty_type, selection, cost_weights, gammas, output, target) 363 | selection_ratio = selection.detach().mean(0).mean(0) 364 | cls_loss = criterion(output, target) 365 | # measure accuracy and record loss 366 | prec1, prec5 = eval_criterion(output, target) 367 | prec1 = prec1.to(device=target.device) 368 | prec5 = prec5.to(device=target.device) 369 | if dist.is_initialized(): 370 | world_size = dist.get_world_size() 371 | dist.all_reduce(prec1) 372 | dist.all_reduce(prec5) 373 | prec1 /= world_size 374 | prec5 /= world_size 375 | 376 | dist.all_reduce(selection_ratio) 377 | selection_ratio /= world_size 378 | 379 | # classification is always considered but selection loss only used in training policy 380 | loss = cls_loss 381 | if model.module.update_policy_net: 382 | loss = loss + policy_loss 383 | 384 | losses.update(loss.item(), target.size(0)) 385 | top1.update(prec1[0], target.size(0)) 386 | top5.update(prec5[0], target.size(0)) 387 | for ii, m in enumerate(modality): 388 | selection_meter[m].update(selection_ratio[ii].item()) 389 | # compute gradient and do SGD step 390 | loss.backward() 391 | 392 | if clip_gradient is not None: 393 | _ = clip_grad_norm_(model.parameters(), clip_gradient) 394 | 395 | if model.module.update_policy_net: 396 | p_optimizer.step() 397 | p_optimizer.zero_grad() 398 | if model.module.update_main_net: 399 | optimizer.step() 400 | optimizer.zero_grad() 401 | 402 | # measure elapsed time 403 | batch_time.update(time.time() - end) 404 | end = time.time() 405 | if i % display == 0 and rank == 0: 406 | selection_msg = "Selection: " 407 | for k, v in selection_meter.items(): 408 | selection_msg += "{}:{:.2f} ".format(k, v.avg * 100) 409 | print('Epoch: [{0}][{1}/{2}]\t' 410 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 411 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 412 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 413 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 414 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})\t' 415 | '{select}'.format( 416 | epoch, i, len(data_loader), batch_time=batch_time, 417 | data_time=data_time, loss=losses, top1=top1, top5=top5, select=selection_msg), flush=True) 418 | num_batch += 1 419 | t_bar.update(1) 420 | 421 | if i > steps_per_epoch: 422 | break 423 | torch.cuda.empty_cache() 424 | return top1.avg, top5.avg, losses.avg, batch_time.avg, data_time.avg, num_batch, selection_meter 425 | 426 | 427 | def validate_adamml(data_loader, model, criterion, num_segments, 
modality, gpu_id=None, 428 | eval_criterion=accuracy, return_output=False): 429 | batch_time = AverageMeter() 430 | losses = AverageMeter() 431 | top1 = AverageMeter() 432 | top5 = AverageMeter() 433 | selection_meter = { m:AverageMeter() for m in modality } 434 | 435 | # switch to evaluate mode 436 | model.eval() 437 | if gpu_id is None or gpu_id == 0: 438 | disable_status_bar = False 439 | else: 440 | disable_status_bar = True 441 | 442 | outputs = None 443 | labels = None 444 | all_selections = None 445 | 446 | with torch.no_grad(), tqdm(total=len(data_loader), disable=disable_status_bar) as t_bar: 447 | end = time.time() 448 | for i, (images, target) in enumerate(data_loader): 449 | 450 | if gpu_id is not None: 451 | if isinstance(images, list): 452 | images = [x.cuda(gpu_id, non_blocking=True) for x in images] 453 | else: 454 | images = images.cuda(gpu_id, non_blocking=True) 455 | target = target.cuda(gpu_id, non_blocking=True) 456 | 457 | # compute output 458 | output, selection = model(images, num_segments) 459 | # dim of selection: NxSxM 460 | 461 | loss = criterion(output, target) 462 | 463 | selection_ratio = selection.detach().mean(0).mean(0) 464 | # measure accuracy and record loss 465 | prec1, prec5 = eval_criterion(output, target) 466 | prec1 = prec1.to(device=loss.device) 467 | prec5 = prec5.to(device=loss.device) 468 | 469 | if dist.is_initialized(): 470 | world_size = dist.get_world_size() 471 | dist.all_reduce(prec1) 472 | dist.all_reduce(prec5) 473 | prec1 /= world_size 474 | prec5 /= world_size 475 | dist.all_reduce(selection_ratio) 476 | selection_ratio /= world_size 477 | 478 | losses.update(loss.item(), target.size(0)) 479 | top1.update(prec1[0], target.size(0)) 480 | top5.update(prec5[0], target.size(0)) 481 | for ii, m in enumerate(modality): 482 | selection_meter[m].update(selection_ratio[ii].item()) 483 | if outputs is None: 484 | outputs = concat_all_gather(output) if dist.is_initialized() else output 485 | labels = concat_all_gather(target) if dist.is_initialized() else target 486 | all_selections = concat_all_gather(selection) if dist.is_initialized() else selection 487 | else: 488 | outputs = torch.cat((outputs, concat_all_gather(output)), dim=0) if dist.is_initialized() else torch.cat((outputs, output), dim=0) 489 | labels = torch.cat((labels, concat_all_gather(target)), dim=0) if dist.is_initialized() else torch.cat((labels, target), dim=0) 490 | all_selections = torch.cat((all_selections, concat_all_gather(selection)), dim=0) if dist.is_initialized() else torch.cat((all_selections, selection), dim=0) 491 | 492 | # measure elapsed time 493 | batch_time.update(time.time() - end) 494 | end = time.time() 495 | t_bar.update(1) 496 | 497 | acc, mAP = actnet_acc(outputs, labels) 498 | top1, top5 = acc 499 | 500 | torch.cuda.empty_cache() 501 | 502 | flops = flops_computation(modality, selection_meter, num_segments) 503 | 504 | if return_output: 505 | return top1, top5, losses.avg, batch_time.avg, selection_meter, mAP, all_selections, flops, outputs 506 | else: 507 | return top1, top5, losses.avg, batch_time.avg, selection_meter, mAP, all_selections, flops 508 | 509 | 510 | def flops_computation(modality, ratios, num_segments, net='resnet'): 511 | 512 | main_flops = { 513 | 'rgb': 14135984128, 514 | 'flow': 16338911232, 515 | 'sound': 381739008, 516 | } 517 | 518 | policy_flops = { 519 | 'rgb': 375446400, 520 | 'sound': 381739008, 521 | 'rgbdiff': 909283200, 522 | 'lstm': 2359296 523 | } 524 | 525 | total_flops = 0 526 | 527 | for m in modality: 528 | if m == 
'sound' or m == 'rgb': 529 | total_flops += (main_flops[m] * num_segments * ratios[m].avg) + (policy_flops[m] * num_segments) 530 | else: 531 | total_flops += (main_flops['flow'] * num_segments * ratios['flow'].avg) + (policy_flops['rgbdiff'] * num_segments) 532 | total_flops += policy_flops['lstm'] * num_segments 533 | total_flops /= 1e9 534 | 535 | return total_flops 536 | 537 | 538 | # utils 539 | @torch.no_grad() 540 | def concat_all_gather(tensor): 541 | """ 542 | Performs all_gather operation on the provided tensors. 543 | *** Warning ***: torch.distributed.all_gather has no gradient. 544 | """ 545 | tensors_gather = [torch.ones_like(tensor) 546 | for _ in range(torch.distributed.get_world_size())] 547 | torch.distributed.all_gather(tensors_gather, tensor.contiguous(), async_op=False) 548 | 549 | output = torch.cat(tensors_gather, dim=0) 550 | return output 551 | -------------------------------------------------------------------------------- /utils/video_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | from PIL import Image 5 | import torch.utils.data as data 6 | 7 | def random_clip(video_frames, sampling_rate, frames_per_clip, fixed_offset=False, start_frame_idx=0, end_frame_idx=None): 8 | """ 9 | 10 | Args: 11 | video_frames (int): total frame number of a video 12 | sampling_rate (int): sampling rate for clip, pick one every k frames 13 | frames_per_clip (int): number of frames of a clip 14 | fixed_offset (bool): used with sample offset to decide the offset value deterministically. 15 | 16 | Returns: 17 | list[int]: frame indices (started from zero) 18 | """ 19 | new_sampling_rate = sampling_rate 20 | highest_idx = video_frames - new_sampling_rate * frames_per_clip if end_frame_idx is None else end_frame_idx 21 | if highest_idx <= 0: 22 | random_offset = 0 23 | else: 24 | if fixed_offset: 25 | random_offset = (video_frames - new_sampling_rate * frames_per_clip) // 2 26 | else: 27 | random_offset = int(np.random.randint(start_frame_idx, highest_idx, 1)) 28 | frame_idx = [int(random_offset + i * sampling_rate) % video_frames for i in range(frames_per_clip)] 29 | return frame_idx 30 | 31 | 32 | def compute_img_diff(image_1, image_2, bound=255.0): 33 | image_diff = np.asarray(image_1, dtype=np.float) - np.asarray(image_2, dtype=np.float) 34 | image_diff += bound 35 | image_diff *= (255.0 / float(2 * bound)) 36 | image_diff = image_diff.astype(np.uint8) 37 | image_diff = Image.fromarray(image_diff) 38 | return image_diff 39 | 40 | 41 | def load_image(root_path, directory, image_tmpl, idx, modality): 42 | """ 43 | 44 | :param root_path: 45 | :param directory: 46 | :param image_tmpl: 47 | :param idx: if it is a list, load a batch of images 48 | :param modality: 49 | :return: 50 | """ 51 | def _safe_load_image(img_path): 52 | img = None 53 | num_try = 0 54 | while num_try < 10: 55 | try: 56 | img_tmp = Image.open(img_path) 57 | img = img_tmp.copy() 58 | img_tmp.close() 59 | break 60 | except Exception as e: 61 | print('[Will try load again] error loading image: {}, ' 62 | 'error: {}'.format(img_path, str(e))) 63 | num_try += 1 64 | if img is None: 65 | raise ValueError('[Fail 10 times] error loading image: {}'.format(img_path)) 66 | return img 67 | 68 | if not isinstance(idx, list): 69 | idx = [idx] 70 | out = [] 71 | if modality == 'rgb': 72 | for i in idx: 73 | image_path_file = os.path.join(root_path, directory, image_tmpl.format(i)) 74 | out.append(_safe_load_image(image_path_file)) 
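    # 'rgbdiff' below loads frame i and i+1 for every sampled index and feeds their
    # per-pixel difference (computed by compute_img_diff) to the network instead of raw RGB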
75 | elif modality == 'rgbdiff': 76 | tmp = {} 77 | new_idx = np.unique(np.concatenate((np.asarray(idx), np.asarray(idx) + 1))) 78 | for i in new_idx: 79 | image_path_file = os.path.join(root_path, directory, image_tmpl.format(i)) 80 | tmp[i] = _safe_load_image(image_path_file) 81 | for k in idx: 82 | img_ = compute_img_diff(tmp[k+1], tmp[k]) 83 | out.append(img_) 84 | del tmp 85 | elif modality == 'flow': 86 | for i in idx: 87 | flow_x_name = os.path.join(root_path, directory, "x_" + image_tmpl.format(i)) 88 | flow_y_name = os.path.join(root_path, directory, "y_" + image_tmpl.format(i)) 89 | out.extend([_safe_load_image(flow_x_name), _safe_load_image(flow_y_name)]) 90 | 91 | return out 92 | 93 | 94 | def load_sound(data_dir, record, idx, fps, audio_length, resampling_rate, 95 | window_size=10, step_size=5, eps=1e-6): 96 | import librosa 97 | """idx must be the center frame of a clip""" 98 | centre_sec = (record.start_frame + idx) / fps 99 | left_sec = centre_sec - (audio_length / 2.0) 100 | right_sec = centre_sec + (audio_length / 2.0) 101 | audio_fname = os.path.join(data_dir, record.path) 102 | if not os.path.exists(audio_fname): 103 | return [Image.fromarray(np.zeros((256, 256 * int(audio_length / 1.28))))] 104 | samples, sr = librosa.core.load(audio_fname, sr=None, mono=True) 105 | duration = samples.shape[0] / float(resampling_rate) 106 | 107 | left_sample = int(round(left_sec * resampling_rate)) 108 | right_sample = int(round(right_sec * resampling_rate)) 109 | 110 | required_samples = int(round(resampling_rate * audio_length)) 111 | 112 | if left_sec < 0: 113 | samples = samples[:required_samples] 114 | elif right_sec > duration: 115 | samples = samples[-required_samples:] 116 | else: 117 | samples = samples[left_sample:right_sample] 118 | 119 | # if the samples is not long enough, repeat the waveform 120 | if len(samples) < required_samples: 121 | multiplies = required_samples / len(samples) 122 | samples = np.tile(samples, int(multiplies + 0.5) + 1) 123 | samples = samples[:required_samples] 124 | 125 | # log sepcgram 126 | nperseg = int(round(window_size * resampling_rate / 1e3)) 127 | noverlap = int(round(step_size * resampling_rate / 1e3)) 128 | spec = librosa.stft(samples, n_fft=511, window='hann', hop_length=noverlap, 129 | win_length=nperseg, pad_mode='constant') 130 | spec = np.log(np.real(spec * np.conj(spec)) + eps) 131 | img = Image.fromarray(spec) 132 | return [img] 133 | 134 | 135 | def sample_train_clip(video_length, num_consecutive_frames, num_frames, sample_freq, dense_sampling, num_clips=1): 136 | 137 | max_frame_idx = max(1, video_length - num_consecutive_frames + 1) 138 | if dense_sampling: 139 | frame_idx = np.zeros((num_clips, num_frames), dtype=int) 140 | if num_clips == 1: # backward compatibility 141 | frame_idx[0] = np.asarray(random_clip(max_frame_idx, sample_freq, num_frames, False)) 142 | else: 143 | max_start_frame_idx = max_frame_idx - sample_freq * num_frames 144 | frames_per_segment = max_start_frame_idx // num_clips 145 | for i in range(num_clips): 146 | if frames_per_segment <= 0: 147 | frame_idx[i] = np.asarray(random_clip(max_frame_idx, sample_freq, num_frames, False)) 148 | else: 149 | frame_idx[i] = np.asarray(random_clip(max_frame_idx, sample_freq, num_frames, False, i * frames_per_segment, (i+1) * frames_per_segment)) 150 | 151 | frame_idx = frame_idx.flatten() 152 | else: # uniform sampling 153 | total_frames = num_frames * sample_freq 154 | ave_frames_per_group = max_frame_idx // num_frames 155 | if ave_frames_per_group >= sample_freq: 156 
| # randomly sample f images per segement 157 | frame_idx = np.arange(0, num_frames) * ave_frames_per_group 158 | frame_idx = np.repeat(frame_idx, repeats=sample_freq) 159 | offsets = np.random.choice(ave_frames_per_group, sample_freq, replace=False) 160 | offsets = np.tile(offsets, num_frames) 161 | frame_idx = frame_idx + offsets 162 | elif max_frame_idx < total_frames: 163 | # need to sample the same images 164 | frame_idx = np.random.choice(max_frame_idx, total_frames) 165 | else: 166 | # sample cross all images 167 | frame_idx = np.random.choice(max_frame_idx, total_frames, replace=False) 168 | frame_idx = np.sort(frame_idx) 169 | frame_idx = frame_idx + 1 170 | return frame_idx 171 | 172 | 173 | def sample_val_test_clip(video_length, num_consecutive_frames, num_frames, sample_freq, dense_sampling, 174 | fixed_offset, num_clips): 175 | max_frame_idx = max(1, video_length - num_consecutive_frames + 1) 176 | if dense_sampling: 177 | if fixed_offset: 178 | sample_pos = max(1, 1 + max_frame_idx - sample_freq * num_frames) 179 | t_stride = sample_freq 180 | start_list = np.linspace(0, sample_pos - 1, num=num_clips, dtype=int) 181 | frame_idx = [] 182 | for start_idx in start_list.tolist(): 183 | frame_idx += [(idx * t_stride + start_idx) % max_frame_idx for idx in 184 | range(num_frames)] 185 | else: 186 | frame_idx = [] 187 | for i in range(num_clips): 188 | frame_idx.extend(random_clip(max_frame_idx, sample_freq, num_frames)) 189 | frame_idx = np.asarray(frame_idx) + 1 190 | else: # uniform sampling 191 | if fixed_offset: 192 | frame_idices = [] 193 | sample_offsets = list(range(-num_clips // 2 + 1, num_clips // 2 + 1)) 194 | for sample_offset in sample_offsets: 195 | if max_frame_idx > num_frames: 196 | tick = max_frame_idx / float(num_frames) 197 | curr_sample_offset = sample_offset 198 | if curr_sample_offset >= tick / 2.0: 199 | curr_sample_offset = tick / 2.0 - 1e-4 200 | elif curr_sample_offset < -tick / 2.0: 201 | curr_sample_offset = -tick / 2.0 202 | frame_idx = np.array([int(tick / 2.0 + curr_sample_offset + tick * x) for x in 203 | range(num_frames)]) 204 | else: 205 | np.random.seed(sample_offset - (-num_clips // 2 + 1)) 206 | frame_idx = np.random.choice(max_frame_idx, num_frames) 207 | frame_idx = np.sort(frame_idx) 208 | frame_idices.extend(frame_idx.tolist()) 209 | else: 210 | frame_idices = [] 211 | for i in range(num_clips): 212 | total_frames = num_frames * sample_freq 213 | ave_frames_per_group = max_frame_idx // num_frames 214 | if ave_frames_per_group >= sample_freq: 215 | # randomly sample f images per segment 216 | frame_idx = np.arange(0, num_frames) * ave_frames_per_group 217 | frame_idx = np.repeat(frame_idx, repeats=sample_freq) 218 | offsets = np.random.choice(ave_frames_per_group, sample_freq, 219 | replace=False) 220 | offsets = np.tile(offsets, num_frames) 221 | frame_idx = frame_idx + offsets 222 | elif max_frame_idx < total_frames: 223 | # need to sample the same images 224 | np.random.seed(i) 225 | frame_idx = np.random.choice(max_frame_idx, total_frames) 226 | else: 227 | # sample cross all images 228 | np.random.seed(i) 229 | frame_idx = np.random.choice(max_frame_idx, total_frames, replace=False) 230 | frame_idx = np.sort(frame_idx) 231 | frame_idices.extend(frame_idx.tolist()) 232 | frame_idx = np.asarray(frame_idices) + 1 233 | return frame_idx 234 | 235 | 236 | class VideoRecord(object): 237 | def __init__(self, path, start_frame, end_frame, label, reverse=False): 238 | self.path = path 239 | self.video_id = os.path.basename(path) 240 | 
self.start_frame = start_frame 241 | self.end_frame = end_frame 242 | self.label = label 243 | self.reverse = reverse 244 | 245 | @property 246 | def num_frames(self): 247 | return self.end_frame - self.start_frame + 1 248 | 249 | def __str__(self): 250 | return self.path 251 | 252 | 253 | class VideoDataSet(data.Dataset): 254 | 255 | def __init__(self, root_path, list_file, num_groups=64, frames_per_group=1, sample_offset=0, num_clips=1, 256 | modality='rgb', dense_sampling=False, fixed_offset=True, 257 | image_tmpl='{:05d}.jpg', transform=None, is_train=True, test_mode=False, seperator=' ', 258 | filter_video=0, num_classes=None, 259 | fps=29.97, audio_length=1.28, resampling_rate=24000): 260 | """ 261 | 262 | Arguments have different meaning when dense_sampling is True: 263 | - num_groups ==> number of frames 264 | - frames_per_group ==> sample every K frame 265 | - sample_offset ==> number of clips used in validation or test mode 266 | 267 | Args: 268 | root_path (str): the file path to the root of video folder 269 | list_file (str): the file list, each line with folder_path, start_frame, end_frame, label_id 270 | num_groups (int): number of frames per data sample 271 | frames_per_group (int): number of frames within one group 272 | sample_offset (int): used in validation/test, the offset when sampling frames from a group 273 | modality (str): rgb or flow 274 | dense_sampling (bool): dense sampling in I3D 275 | fixed_offset (bool): used for generating the same videos used in TSM 276 | image_tmpl (str): template of image ids 277 | transform: the transformer for preprocessing 278 | is_train (bool): shuffle the video but keep the causality 279 | test_mode (bool): testing mode, no label 280 | fps (float): frame rate per second, used to localize sound when frame idx is selected. 281 | audio_length (float): the time window to extract audio feature. 
282 | resampling_rate (int): used to resampling audio extracted from wav 283 | """ 284 | if modality not in ['flow', 'rgb', 'rgbdiff', 'sound']: 285 | raise ValueError("modality should be 'flow' or 'rgb' or 'rgbdiff' or 'sound'.") 286 | 287 | self.root_path = root_path 288 | self.list_file = os.path.join(root_path, list_file) 289 | self.num_groups = num_groups 290 | self.num_frames = num_groups 291 | self.frames_per_group = frames_per_group 292 | self.sample_freq = frames_per_group 293 | self.num_clips = num_clips 294 | self.sample_offset = sample_offset 295 | self.fixed_offset = fixed_offset 296 | self.dense_sampling = dense_sampling 297 | self.modality = modality.lower() 298 | self.image_tmpl = image_tmpl 299 | self.transform = transform 300 | self.is_train = is_train 301 | self.test_mode = test_mode 302 | self.separator = seperator 303 | self.filter_video = filter_video 304 | self.fps = fps 305 | self.audio_length = audio_length 306 | self.resampling_rate = resampling_rate 307 | self.video_length = (self.num_frames * self.sample_freq) / self.fps 308 | 309 | 310 | if self.modality in ['flow', 'rgbdiff']: 311 | self.num_consecutive_frames = 5 312 | else: 313 | self.num_consecutive_frames = 1 314 | 315 | self.video_list, self.multi_label = self._parse_list() 316 | self.num_classes = num_classes 317 | 318 | def _parse_list(self): 319 | # usually it is [video_id, num_frames, class_idx] 320 | # or [video_id, start_frame, end_frame, list of class_idx] 321 | tmp = [] 322 | original_video_numbers = 0 323 | for x in open(self.list_file): 324 | elements = x.strip().split(self.separator) 325 | start_frame = int(elements[1]) 326 | end_frame = int(elements[2]) 327 | total_frame = end_frame - start_frame + 1 328 | original_video_numbers += 1 329 | if self.test_mode: 330 | tmp.append(elements) 331 | else: 332 | if total_frame >= self.filter_video: 333 | tmp.append(elements) 334 | 335 | num = len(tmp) 336 | print("The number of videos is {} (with more than {} frames) " 337 | "(original: {})".format(num, self.filter_video, original_video_numbers), flush=True) 338 | assert (num > 0) 339 | multi_label = np.mean(np.asarray([len(x) for x in tmp])) > 4.0 340 | file_list = [] 341 | for item in tmp: 342 | if self.test_mode: 343 | file_list.append([item[0], int(item[1]), int(item[2]), -1]) 344 | else: 345 | labels = [] 346 | for i in range(3, len(item)): 347 | labels.append(float(item[i])) 348 | if not multi_label: 349 | labels = labels[0] if len(labels) == 1 else labels 350 | file_list.append([item[0], int(item[1]), int(item[2]), labels]) 351 | 352 | video_list = [VideoRecord(item[0], item[1], item[2], item[3]) for item in file_list] 353 | # flow model has one frame less 354 | if self.modality in ['rgbdiff']: 355 | for i in range(len(video_list)): 356 | video_list[i].end_frame -= 1 357 | 358 | return video_list, multi_label 359 | 360 | def remove_data(self, idx): 361 | original_video_num = len(self.video_list) 362 | self.video_list = [v for i, v in enumerate(self.video_list) if i not in idx] 363 | print("Original videos: {}\t remove {} videos, remaining {} videos".format(original_video_num, len(idx), len(self.video_list))) 364 | 365 | def _sample_indices(self, record): 366 | return sample_train_clip(record.num_frames, self.num_consecutive_frames, self.num_frames, 367 | self.sample_freq, self.dense_sampling, self.num_clips) 368 | 369 | def _get_val_indices(self, record): 370 | return sample_val_test_clip(record.num_frames, self.num_consecutive_frames, self.num_frames, 371 | self.sample_freq, self.dense_sampling, 
self.fixed_offset, 372 | self.num_clips) 373 | 374 | def __getitem__(self, index): 375 | """ 376 | Returns: 377 | torch.FloatTensor: (3xgxf)xHxW dimension, g is number of groups and f is the frames per group. 378 | torch.FloatTensor: the label 379 | """ 380 | record = self.video_list[index] 381 | # check this is a legit video folder 382 | indices = self._sample_indices(record) if self.is_train else self._get_val_indices(record) 383 | images = self.get_data(record, indices) 384 | images = self.transform(images) 385 | label = self.get_label(record) 386 | 387 | # re-order data to targeted format. 388 | return images, label 389 | 390 | def get_data(self, record, indices): 391 | images = [] 392 | num_clips = self.num_clips 393 | if self.modality == 'sound': 394 | new_indices = [indices[i * self.num_frames: (i + 1) * self.num_frames] 395 | for i in range(num_clips)] 396 | for curr_indiecs in new_indices: 397 | center_idx = (curr_indiecs[self.num_frames // 2 - 1] + curr_indiecs[self.num_frames // 2]) // 2 \ 398 | if self.num_frames % 2 == 0 else curr_indiecs[self.num_frames // 2] 399 | center_idx = min(record.num_frames, center_idx) 400 | seg_imgs = load_sound(self.root_path, record, center_idx, 401 | self.fps, self.audio_length, self.resampling_rate) 402 | images.extend(seg_imgs) 403 | else: 404 | images = [] 405 | for seg_ind in indices: 406 | new_seg_ind = [min(seg_ind + record.start_frame - 1 + i, record.num_frames) 407 | for i in range(self.num_consecutive_frames)] 408 | seg_imgs = load_image(self.root_path, record.path, self.image_tmpl, 409 | new_seg_ind, self.modality) 410 | images.extend(seg_imgs) 411 | return images 412 | 413 | def get_label(self, record): 414 | if self.test_mode: 415 | # in test mode, return the video id as label 416 | label = record.video_id 417 | else: 418 | if not self.multi_label: 419 | label = int(record.label) 420 | else: 421 | # create a binary vector. 
422 | label = torch.zeros(self.num_classes, dtype=torch.float) 423 | for x in record.label: 424 | label[int(x)] = 1.0 425 | return label 426 | 427 | def __len__(self): 428 | return len(self.video_list) 429 | 430 | 431 | class MultiVideoDataSet(data.Dataset): 432 | 433 | def __init__(self, root_path, list_file, num_groups=64, frames_per_group=1, sample_offset=0, num_clips=1, 434 | modality='rgb', dense_sampling=False, fixed_offset=True, 435 | image_tmpl='{:05d}.jpg', transform=None, is_train=True, test_mode=False, seperator=' ', 436 | filter_video=0, num_classes=None, 437 | fps=29.97, audio_length=1.28, resampling_rate=24000): 438 | """ 439 | # root_path, modality and transform become list, each for one modality 440 | 441 | Argments have different meaning when dense_sampling is True: 442 | - num_groups ==> number of frames 443 | - frames_per_group ==> sample every K frame 444 | - sample_offset ==> number of clips used in validation or test mode 445 | 446 | Args: 447 | root_path (str): the file path to the root of video folder 448 | list_file (str): the file list, each line with folder_path, start_frame, end_frame, label_id 449 | num_groups (int): number of frames per data sample 450 | frames_per_group (int): number of frames within one group 451 | sample_offset (int): used in validation/test, the offset when sampling frames from a group 452 | modality (str): rgb or flow 453 | dense_sampling (bool): dense sampling in I3D 454 | fixed_offset (bool): used for generating the same videos used in TSM 455 | image_tmpl (str): template of image ids 456 | transform: the transformer for preprocessing 457 | is_train (bool): shuffle the video but keep the causality 458 | test_mode (bool): testing mode, no label 459 | """ 460 | 461 | video_datasets = [] 462 | for i in range(len(modality)): 463 | tmp = VideoDataSet(root_path[i], os.path.join(root_path[i], list_file), 464 | num_groups, frames_per_group, sample_offset, 465 | num_clips, modality[i], dense_sampling, fixed_offset, 466 | image_tmpl, transform[i], is_train, test_mode, seperator, 467 | filter_video, num_classes, fps, audio_length, resampling_rate) 468 | video_datasets.append(tmp) 469 | 470 | self.video_datasets = video_datasets 471 | self.is_train = is_train 472 | self.test_mode = test_mode 473 | self.num_frames = num_groups 474 | self.sample_freq = frames_per_group 475 | self.dense_sampling = dense_sampling 476 | self.num_clips = num_clips 477 | self.fixed_offset = fixed_offset 478 | self.modality = modality 479 | self.num_classes = num_classes 480 | 481 | self.video_list = video_datasets[0].video_list 482 | self.num_consecutive_frames = max([x.num_consecutive_frames for x in self.video_datasets]) 483 | 484 | def _sample_indices(self, record): 485 | return sample_train_clip(record.num_frames, self.num_consecutive_frames, self.num_frames, 486 | self.sample_freq, self.dense_sampling, self.num_clips) 487 | 488 | def _get_val_indices(self, record): 489 | return sample_val_test_clip(record.num_frames, self.num_consecutive_frames, self.num_frames, 490 | self.sample_freq, self.dense_sampling, self.fixed_offset, 491 | self.num_clips) 492 | 493 | def remove_data(self, idx): 494 | for i in range(len(self.video_datasets)): 495 | self.video_datasets[i].remove_data(idx) 496 | self.video_list = self.video_datasets[0].video_list 497 | 498 | def __getitem__(self, index): 499 | """ 500 | Returns: 501 | torch.FloatTensor: (3xgxf)xHxW dimension, g is number of groups and f is the frames per group. 
502 | torch.FloatTensor: the label 503 | """ 504 | 505 | record = self.video_list[index] 506 | if self.is_train: 507 | indices = self._sample_indices(record) 508 | else: 509 | indices = self._get_val_indices(record) 510 | 511 | multi_modalities = [] 512 | for modality, video_dataset in zip(self.modality, self.video_datasets): 513 | record = video_dataset.video_list[index] 514 | images = video_dataset.get_data(record, indices) 515 | images = video_dataset.transform(images) 516 | label = video_dataset.get_label(record) 517 | multi_modalities.append((images, label)) 518 | 519 | return [x for x, y in multi_modalities], multi_modalities[0][1] 520 | 521 | def __len__(self): 522 | return len(self.video_list) 523 | 524 | -------------------------------------------------------------------------------- /utils/video_transforms.py: -------------------------------------------------------------------------------- 1 | import torchvision 2 | import random 3 | from PIL import Image, ImageOps 4 | import numpy as np 5 | import numbers 6 | import math 7 | import torch 8 | 9 | class GroupRandomCrop(object): 10 | def __init__(self, size): 11 | if isinstance(size, numbers.Number): 12 | self.size = (int(size), int(size)) 13 | else: 14 | self.size = size 15 | 16 | def __call__(self, img_group): 17 | 18 | w, h = img_group[0].size 19 | th, tw = self.size 20 | 21 | out_images = list() 22 | 23 | x1 = random.randint(0, w - tw) 24 | y1 = random.randint(0, h - th) 25 | 26 | for img in img_group: 27 | assert(img.size[0] == w and img.size[1] == h) 28 | if w == tw and h == th: 29 | out_images.append(img) 30 | else: 31 | out_images.append(img.crop((x1, y1, x1 + tw, y1 + th))) 32 | 33 | return out_images 34 | 35 | 36 | class GroupCenterCrop(object): 37 | def __init__(self, size): 38 | self.worker = torchvision.transforms.CenterCrop(size) 39 | 40 | def __call__(self, img_group): 41 | return [self.worker(img) for img in img_group] 42 | 43 | 44 | class GroupRandomHorizontalFlip(object): 45 | """Randomly horizontally flips the given PIL.Image with a probability of 0.5 46 | """ 47 | def __init__(self, is_flow=False): 48 | self.is_flow = is_flow 49 | 50 | def __call__(self, img_group, is_flow=False): 51 | v = random.random() 52 | if v < 0.5: 53 | ret = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group] 54 | if self.is_flow: 55 | for i in range(0, len(ret), 2): 56 | ret[i] = ImageOps.invert(ret[i]) # invert flow pixel values when flipping 57 | return ret 58 | else: 59 | return img_group 60 | 61 | 62 | class GroupNormalize(object): 63 | def __init__(self, mean, std, threed_data=False): 64 | self.threed_data = threed_data 65 | if self.threed_data: 66 | # convert to the proper format 67 | self.mean = torch.FloatTensor(mean).view(len(mean), 1, 1, 1) 68 | self.std = torch.FloatTensor(std).view(len(std), 1, 1, 1) 69 | else: 70 | self.mean = mean 71 | self.std = std 72 | 73 | def __call__(self, tensor): 74 | 75 | if self.threed_data: 76 | tensor.sub_(self.mean).div_(self.std) 77 | else: 78 | rep_mean = self.mean * (tensor.size()[0] // len(self.mean)) 79 | rep_std = self.std * (tensor.size()[0] // len(self.std)) 80 | 81 | for t, m, s in zip(tensor, rep_mean, rep_std): 82 | t.sub_(m).div_(s) 83 | 84 | return tensor 85 | 86 | 87 | class GroupScale(object): 88 | """ Rescales the input PIL.Image to the given 'size'. 89 | 'size' will be the size of the smaller edge. 
90 | For example, if height > width, then image will be 91 | rescaled to (size * height / width, size) 92 | size: size of the smaller edge 93 | interpolation: Default: PIL.Image.BILINEAR 94 | """ 95 | 96 | def __init__(self, size, interpolation=Image.BILINEAR): 97 | self.worker = torchvision.transforms.Resize(size, interpolation) 98 | 99 | def __call__(self, img_group): 100 | return [self.worker(img) for img in img_group] 101 | 102 | class GroupRandomScale(object): 103 | """ Rescales the input PIL.Image to the given 'size'. 104 | 'size' will be the size of the smaller edge. 105 | For example, if height > width, then image will be 106 | rescaled to (size * height / width, size) 107 | size: size of the smaller edge 108 | interpolation: Default: PIL.Image.BILINEAR 109 | 110 | Randomly select the smaller edge from the range of 'size'. 111 | """ 112 | def __init__(self, size, interpolation=Image.BILINEAR): 113 | self.size = size 114 | self.interpolation = interpolation 115 | 116 | def __call__(self, img_group): 117 | selected_size = np.random.randint(low=self.size[0], high=self.size[1] + 1, dtype=int) 118 | scale = GroupScale(selected_size, interpolation=self.interpolation) 119 | return scale(img_group) 120 | 121 | class GroupOverSample(object): 122 | def __init__(self, crop_size, scale_size=None, num_crops=5, flip=False): 123 | self.crop_size = crop_size if not isinstance(crop_size, int) else (crop_size, crop_size) 124 | 125 | if scale_size is not None: 126 | self.scale_worker = GroupScale(scale_size) 127 | else: 128 | self.scale_worker = None 129 | 130 | if num_crops not in [1, 3, 5, 10]: 131 | raise ValueError("num_crops should be in [1, 3, 5, 10] but ({})".format(num_crops)) 132 | self.num_crops = num_crops 133 | 134 | self.flip = flip 135 | 136 | def __call__(self, img_group): 137 | 138 | if self.scale_worker is not None: 139 | img_group = self.scale_worker(img_group) 140 | 141 | image_w, image_h = img_group[0].size 142 | crop_w, crop_h = self.crop_size 143 | 144 | if self.num_crops == 3: 145 | w_step = (image_w - crop_w) // 4 146 | h_step = (image_h - crop_h) // 4 147 | offsets = list() 148 | if image_w < image_h: 149 | offsets.append((2 * w_step, 0 * h_step)) # top 150 | offsets.append((2 * w_step, 4 * h_step)) # bottom 151 | offsets.append((2 * w_step, 2 * h_step)) # center 152 | else: 153 | offsets.append((0 * w_step, 2 * h_step)) # left 154 | offsets.append((4 * w_step, 2 * h_step)) # right 155 | offsets.append((2 * w_step, 2 * h_step)) # center 156 | 157 | else: 158 | offsets = GroupMultiScaleCrop.fill_fix_offset(False, image_w, image_h, crop_w, crop_h) 159 | 160 | oversample_group = list() 161 | for o_w, o_h in offsets: 162 | normal_group = list() 163 | flip_group = list() 164 | for i, img in enumerate(img_group): 165 | crop = img.crop((o_w, o_h, o_w + crop_w, o_h + crop_h)) 166 | normal_group.append(crop) 167 | if self.flip: 168 | flip_crop = crop.copy().transpose(Image.FLIP_LEFT_RIGHT) 169 | 170 | if img.mode == 'L' and i % 2 == 0: 171 | flip_group.append(ImageOps.invert(flip_crop)) 172 | else: 173 | flip_group.append(flip_crop) 174 | 175 | oversample_group.extend(normal_group) 176 | if self.flip: 177 | oversample_group.extend(flip_group) 178 | return oversample_group 179 | 180 | 181 | class GroupMultiScaleCrop(object): 182 | 183 | def __init__(self, input_size, scales=None, max_distort=1, fix_crop=True, more_fix_crop=True): 184 | self.scales = scales if scales is not None else [1, 875, .75, .66] 185 | self.max_distort = max_distort 186 | self.fix_crop = fix_crop 187 | 
self.more_fix_crop = more_fix_crop 188 | self.input_size = input_size if not isinstance(input_size, int) else [input_size, input_size] 189 | self.interpolation = Image.BILINEAR 190 | 191 | def __call__(self, img_group): 192 | 193 | im_size = img_group[0].size 194 | 195 | crop_w, crop_h, offset_w, offset_h = self._sample_crop_size(im_size) 196 | crop_img_group = [img.crop((offset_w, offset_h, offset_w + crop_w, offset_h + crop_h)) for img in img_group] 197 | ret_img_group = [img.resize((self.input_size[0], self.input_size[1]), self.interpolation) 198 | for img in crop_img_group] 199 | return ret_img_group 200 | 201 | def _sample_crop_size(self, im_size): 202 | image_w, image_h = im_size[0], im_size[1] 203 | 204 | # find a crop size 205 | base_size = min(image_w, image_h) 206 | crop_sizes = [int(base_size * x) for x in self.scales] 207 | crop_h = [self.input_size[1] if abs(x - self.input_size[1]) < 3 else x for x in crop_sizes] 208 | crop_w = [self.input_size[0] if abs(x - self.input_size[0]) < 3 else x for x in crop_sizes] 209 | 210 | pairs = [] 211 | for i, h in enumerate(crop_h): 212 | for j, w in enumerate(crop_w): 213 | if abs(i - j) <= self.max_distort: 214 | pairs.append((w, h)) 215 | 216 | crop_pair = random.choice(pairs) 217 | if not self.fix_crop: 218 | w_offset = random.randint(0, image_w - crop_pair[0]) 219 | h_offset = random.randint(0, image_h - crop_pair[1]) 220 | else: 221 | w_offset, h_offset = self._sample_fix_offset(image_w, image_h, crop_pair[0], crop_pair[1]) 222 | 223 | return crop_pair[0], crop_pair[1], w_offset, h_offset 224 | 225 | def _sample_fix_offset(self, image_w, image_h, crop_w, crop_h): 226 | offsets = self.fill_fix_offset(self.more_fix_crop, image_w, image_h, crop_w, crop_h) 227 | return random.choice(offsets) 228 | 229 | @staticmethod 230 | def fill_fix_offset(more_fix_crop, image_w, image_h, crop_w, crop_h): 231 | w_step = (image_w - crop_w) // 4 232 | h_step = (image_h - crop_h) // 4 233 | 234 | ret = list() 235 | ret.append((0, 0)) # upper left 236 | ret.append((4 * w_step, 0)) # upper right 237 | ret.append((0, 4 * h_step)) # lower left 238 | ret.append((4 * w_step, 4 * h_step)) # lower right 239 | ret.append((2 * w_step, 2 * h_step)) # center 240 | 241 | if more_fix_crop: 242 | ret.append((0, 2 * h_step)) # center left 243 | ret.append((4 * w_step, 2 * h_step)) # center right 244 | ret.append((2 * w_step, 4 * h_step)) # lower center 245 | ret.append((2 * w_step, 0 * h_step)) # upper center 246 | 247 | ret.append((1 * w_step, 1 * h_step)) # upper left quarter 248 | ret.append((3 * w_step, 1 * h_step)) # upper right quarter 249 | ret.append((1 * w_step, 3 * h_step)) # lower left quarter 250 | ret.append((3 * w_step, 3 * h_step)) # lower righ quarter 251 | 252 | return ret 253 | 254 | 255 | class GroupRandomSizedCrop(object): 256 | """Random crop the given PIL.Image to a random size of (0.08 to 1.0) of the original size 257 | and and a random aspect ratio of 3/4 to 4/3 of the original aspect ratio 258 | This is popularly used to train the Inception networks 259 | size: size of the smaller edge 260 | interpolation: Default: PIL.Image.BILINEAR 261 | """ 262 | def __init__(self, size, interpolation=Image.BILINEAR): 263 | self.size = size 264 | self.interpolation = interpolation 265 | 266 | def __call__(self, img_group): 267 | for attempt in range(10): 268 | area = img_group[0].size[0] * img_group[0].size[1] 269 | target_area = random.uniform(0.08, 1.0) * area 270 | aspect_ratio = random.uniform(3. / 4, 4. 
271 | 
272 |             w = int(round(math.sqrt(target_area * aspect_ratio)))
273 |             h = int(round(math.sqrt(target_area / aspect_ratio)))
274 | 
275 |             if random.random() < 0.5:
276 |                 w, h = h, w
277 | 
278 |             if w <= img_group[0].size[0] and h <= img_group[0].size[1]:
279 |                 x1 = random.randint(0, img_group[0].size[0] - w)
280 |                 y1 = random.randint(0, img_group[0].size[1] - h)
281 |                 found = True
282 |                 break
283 |         else:
284 |             found = False
285 |             x1 = 0
286 |             y1 = 0
287 | 
288 |         if found:
289 |             out_group = list()
290 |             for img in img_group:
291 |                 img = img.crop((x1, y1, x1 + w, y1 + h))
292 |                 assert(img.size == (w, h))
293 |                 out_group.append(img.resize((self.size, self.size), self.interpolation))
294 |             return out_group
295 |         else:
296 |             # Fallback
297 |             scale = GroupScale(self.size, interpolation=self.interpolation)
298 |             crop = GroupRandomCrop(self.size)
299 |             return crop(scale(img_group))
300 | 
301 | 
302 | class Stack(object):
303 | 
304 |     def __init__(self, roll=False, threed_data=False):
305 |         self.roll = roll
306 |         self.threed_data = threed_data
307 | 
308 |     def __call__(self, img_group):
309 |         if img_group[0].mode == 'L' or img_group[0].mode == 'F':
310 |             return np.concatenate([np.expand_dims(np.array(x), 2) for x in img_group], axis=2)
311 |         elif img_group[0].mode == 'RGB':
312 |             if self.threed_data:
313 |                 return np.stack(img_group, axis=0)
314 |             else:
315 |                 if self.roll:
316 |                     return np.concatenate([np.array(x)[:, :, ::-1] for x in img_group], axis=2)
317 |                 else:
318 |                     return np.concatenate(img_group, axis=2)
319 | 
320 | 
321 | class ToTorchFormatTensor(object):
322 |     """ Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255]
323 |     to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """
324 |     def __init__(self, div=True, num_clips_crops=1):
325 |         self.div = div
326 |         self.num_clips_crops = num_clips_crops
327 | 
328 |     def __call__(self, pic):
329 |         if isinstance(pic, np.ndarray):
330 |             # handle numpy array
331 |             if len(pic.shape) == 4:
332 |                 # ((NF)xHxWxC) --> (Cx(NF)xHxW)
333 |                 img = torch.from_numpy(pic).permute(3, 0, 1, 2).contiguous()
334 |             else:  # data is HW(FC)
335 |                 img = torch.from_numpy(pic).permute(2, 0, 1).contiguous()
336 |         else:
337 |             # handle PIL Image
338 |             img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
339 |             img = img.view(pic.size[1], pic.size[0], len(pic.mode))
340 |             # put it from HWC to CHW format
341 |             # yikes, this transpose takes 80% of the loading time/CPU
342 |             img = img.transpose(0, 1).transpose(0, 2).contiguous()
343 |         return img.float().div(255) if self.div else img.float()
344 | 
345 | 
346 | class IdentityTransform(object):
347 | 
348 |     def __call__(self, data):
349 |         return data
350 | 
351 | 
352 | if __name__ == "__main__":
353 |     trans = torchvision.transforms.Compose([
354 |         GroupScale(256),
355 |         GroupRandomCrop(224),
356 |         Stack(),
357 |         ToTorchFormatTensor(),
358 |         GroupNormalize(
359 |             mean=[.485, .456, .406],
360 |             std=[.229, .224, .225]
361 |         )]
362 |     )
363 | 
364 |     im = Image.open('../tensorflow-model-zoo.torch/lena_299.png')
365 | 
366 |     color_group = [im] * 3
367 |     rst = trans(color_group)
368 | 
369 |     gray_group = [im.convert('L')] * 9
370 |     gray_rst = trans(gray_group)
371 | 
372 |     trans2 = torchvision.transforms.Compose([
373 |         GroupRandomSizedCrop(256),
374 |         Stack(),
375 |         ToTorchFormatTensor(),
376 |         GroupNormalize(
377 |             mean=[.485, .456, .406],
378 |             std=[.229, .224, .225])
379 |     ])
380 |     print(trans2(color_group))
381 | 
--------------------------------------------------------------------------------
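Usage sketch (illustrative only, not a file from this repository): it composes the group transforms above into a training-style pipeline (GroupMultiScaleCrop) and a multi-crop evaluation pipeline (GroupOverSample) over a dummy 8-frame clip, mirroring the layout used in the __main__ block of utils/video_transforms.py. It assumes the repository root is on PYTHONPATH so the module is importable as utils.video_transforms, and that GroupNormalize (defined earlier in the same file) repeats the per-channel mean/std over the stacked frame channels, as in the standard TSN transforms.

import torchvision
from PIL import Image

from utils.video_transforms import (GroupMultiScaleCrop, GroupNormalize,
                                    GroupOverSample, Stack, ToTorchFormatTensor)

# Dummy 8-frame RGB clip standing in for decoded video frames.
clip = [Image.new('RGB', (320, 240)) for _ in range(8)]

# Training-style pipeline: random multi-scale crop to 224x224, stack frames along
# the channel axis, convert to a float tensor in [0, 1], then normalize per channel.
train_transform = torchvision.transforms.Compose([
    GroupMultiScaleCrop(224),
    Stack(),                    # (224, 224, 8*3) numpy array
    ToTorchFormatTensor(),      # (8*3, 224, 224) float tensor in [0, 1]
    GroupNormalize(mean=[.485, .456, .406], std=[.229, .224, .225]),
])

# Evaluation-style pipeline: resize the shorter side to 256, then take three
# spatial crops per frame (left, right, center for landscape frames).
eval_transform = torchvision.transforms.Compose([
    GroupOverSample(224, scale_size=256, num_crops=3),
    Stack(),                    # (224, 224, 3*8*3)
    ToTorchFormatTensor(),
    GroupNormalize(mean=[.485, .456, .406], std=[.229, .224, .225]),
])

print(train_transform(clip).shape)  # torch.Size([24, 224, 224])
print(eval_transform(clip).shape)   # torch.Size([72, 224, 224])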
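A second minimal sketch (also illustrative, under the same import assumption) traces the tensor shapes produced by the two Stack layouts; the threed_data=True path is what the 4-D branch of ToTorchFormatTensor handles.

from PIL import Image

from utils.video_transforms import Stack, ToTorchFormatTensor

frames = [Image.new('RGB', (224, 224)) for _ in range(4)]

# Channel-stacked layout: frames concatenated along the channel axis.
flat = ToTorchFormatTensor()(Stack()(frames))
print(flat.shape)    # torch.Size([12, 224, 224])  -> (F*C, H, W)

# 3-D layout: frames kept as a separate dimension, channels moved first.
threed = ToTorchFormatTensor()(Stack(threed_data=True)(frames))
print(threed.shape)  # torch.Size([3, 4, 224, 224])  -> (C, F, H, W)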