├── .gitignore ├── Checkpoints └── readme.md ├── LICENSE ├── README.md ├── config ├── IsoGD.yml ├── Jester.yml ├── NTU.yml ├── NetworkConfig.yml ├── NvGesture.yml ├── THU.yml ├── __init__.py └── config.py ├── data └── data_preprose.py ├── demo ├── decouple_recouple.jpg ├── pipline.jpg └── readme.md ├── lib ├── __init__.py ├── datasets │ ├── IsoGD.py │ ├── Jester.py │ ├── NTU.py │ ├── NvGesture.py │ ├── THU_READ.py │ ├── __init__.py │ ├── base.py │ ├── build.py │ └── distributed_sampler.py └── model │ ├── DSN.py │ ├── DSN_Fusion.py │ ├── DTN.py │ ├── FRP.py │ ├── __init__.py │ ├── build.py │ ├── fusion_Net.py │ ├── trans_module.py │ └── utils.py ├── run.sh ├── tools ├── fusion.py ├── readme.md └── train.py └── utils ├── __init__.py ├── build.py ├── evaluate_metric.py ├── print_function.py ├── utils.py └── visualizer.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.swp 2 | **/__pycache__/** 3 | .dumbo.json 4 | Checkpoints/ 5 | out/ 6 | demo/ 7 | *.tar 8 | core* 9 | bk/ 10 | *.ipynb 11 | *.ipynb* -------------------------------------------------------------------------------- /Checkpoints/readme.md: -------------------------------------------------------------------------------- 1 | This folder is necessary because it is used to save all training logs and models. 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 DamoCV 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [CVPR2022](https://openaccess.thecvf.com/content/CVPR2022/html/Zhou_Decoupling_and_Recoupling_Spatiotemporal_Representation_for_RGB-D-Based_Motion_Recognition_CVPR_2022_paper.html) Decoupling and Recoupling Spatiotemporal Representation for RGB-D-based Motion Recognition 2 | 3 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/decoupling-and-recoupling-spatiotemporal/hand-gesture-recognition-on-nvgesture-1)](https://paperswithcode.com/sota/hand-gesture-recognition-on-nvgesture-1?p=decoupling-and-recoupling-spatiotemporal) 4 | 5 | This repo is the official implementation of "Decoupling and Recoupling Spatiotemporal Representation for RGB-D-based Motion Recognition" as well as its follow-ups. It currently includes code and models for the following tasks: 6 | > **RGB-D-based Action Recognition**: Included in this repo. 7 | 8 | > **RGB-D-based Gesture Recognition**: Included in this repo. 9 | 10 | > **Dynamic motion attention capture based on native video frames**: Included in this repo. See the FRP module in the paper. 11 | 12 | ## Updates 13 | ***27/07/2023*** 14 | 1. Updated the link to the journal extension [UMDR-Net](https://github.com/zhoubenjia/MotionRGBD-PAMI) (TPAMI'23) of this conference paper. 15 | 2. Updated the link to its improved version [MFST](https://arxiv.org/pdf/2308.12006.pdf) (MM'23). 16 | 17 | ***27/10/2022*** 18 | 1. Updated the NTU data preprocessing code. 19 | 2. Fixed a bug in the DTN. 20 | 21 | ***18/10/2022*** 22 | 1. Updated the NvGesture training code. 23 | 24 | ## 1. Requirements 25 | This is a PyTorch implementation of our paper. 26 | 27 | torch>=1.7.0; torchvision>=0.8.0; Visdom (optional) 28 | 29 | Data preparation: organize the database with the following folder structure: 30 | 31 | ``` 32 | │NTURGBD/ 33 | ├──dataset_splits/ 34 | │ ├── @CS 35 | │ │ ├── train.txt 36 | video name total frames label 37 | │ │ │ ├──S001C001P001R001A001_rgb 103 0 38 | │ │ │ ├──S001C001P001R001A004_rgb 99 3 39 | │ │ │ ├──...... 40 | │ │ ├── valid.txt 41 | │ ├── @CV 42 | │ │ ├── train.txt 43 | │ │ ├── valid.txt 44 | ├──Images/ 45 | │ │ ├── S001C002P001R001A002_rgb 46 | │ │ │ ├──000000.jpg 47 | │ │ │ ├──000001.jpg 48 | │ │ │ ├──...... 49 | ├──nturgb+d_depth_masked/ 50 | │ │ ├── S001C002P001R001A002 51 | │ │ │ ├──MDepth-00000000.png 52 | │ │ │ ├──MDepth-00000001.png 53 | │ │ │ ├──...... 54 | ``` 55 | Note that because the RGB video resolution in the NTU dataset is relatively high, we do not resize the frames directly from the original resolution to 320x240. Instead, we first crop an object-centered ROI region (640x480) from each frame and then resize it to 320x240 for training and testing. 56 | 57 | ## 2. Methodology 58 |
![Pipeline of the proposed multi-modal spatiotemporal representation learning framework](demo/pipline.jpg)
![Decoupling and recoupling of unimodal spatiotemporal representation](demo/decouple_recouple.jpg)
62 | We propose to decouple and recouple spatiotemporal representation for RGB-D-based motion recognition. The first figure above illustrates the proposed multi-modal spatiotemporal representation learning framework: RGB-D-based motion recognition is decomposed into spatiotemporal information decoupling modeling, compact representation recoupling learning, and cross-modal representation interactive learning. 63 | The second figure shows the process of decoupling and recoupling the spatiotemporal representation of unimodal data. 64 | 65 | ## 3. Train and Evaluate 66 | All of our models are pre-trained on the [20BN Jester V1 dataset](https://www.kaggle.com/toxicmender/20bn-jester), and the pretrained models can be downloaded [here](https://drive.google.com/drive/folders/1eBXED3uXlzBZzix7TvtDlJrZ3SlDCSF6?usp=sharing). Before cross-modal representation interactive learning, we first perform unimodal representation learning separately on the RGB and depth modalities. 67 | ### Unimodal Training 68 | Take training an RGB model with 8 GPUs on the NTU-RGBD dataset as an example; 69 | the basic configuration is: 70 | ```yaml 71 | common: 72 | dataset: NTU 73 | batch_size: 6 74 | test_batch_size: 6 75 | num_workers: 6 76 | learning_rate: 0.01 77 | learning_rate_min: 0.00001 78 | momentum: 0.9 79 | weight_decay: 0.0003 80 | init_epochs: 0 81 | epochs: 100 82 | optim: SGD 83 | scheduler: 84 | name: cosin # Decay the learning rate with a cosine schedule. 85 | warm_up_epochs: 3 86 | loss: 87 | name: CE # Cross-entropy loss function. 88 | labelsmooth: True 89 | MultiLoss: True # Enable the multi-loss training strategy. 90 | loss_lamdb: [ 1, 0.5, 0.5, 0.5 ] # The loss weight coefficient assigned to each sub-branch. 91 | distill: 1. # The loss weight coefficient assigned to the distillation task. 92 | 93 | model: 94 | Network: I3DWTrans # I3DWTrans for unimodal training; set FusionNet for multi-modal fusion training. 95 | sample_duration: 64 # Number of frames sampled from a video. 96 | sample_size: 224 # The image is cropped to 224x224. 97 | grad_clip: 5. 98 | SYNC_BN: 1 # Utilize SyncBatchNorm. 99 | w: 10 # Sliding window size. 100 | temper: 0.5 # Distillation temperature setting. 101 | recoupling: True # Enable the recoupling strategy during training. 102 | knn_attention: 0.7 # Hyperparameter used in k-NN attention: select the Top-70% tokens. 103 | sharpness: True # Enable sharpening of each sub-branch's output. 104 | temp: [ 0.04, 0.07 ] # The temperature parameter follows a cosine schedule from 0.04 to 0.07 during training. 105 | frp: True # Enable the FRP module. 106 | SEHeads: 1 # Number of heads used in the RCM module. 107 | N: 6 # Number of Transformer blocks configured for each sub-branch. 108 | 109 | dataset: 110 | type: M # M: RGB modality, K: Depth modality. 111 | flip: 0.5 # Horizontal flip probability. 112 | rotated: 0.5 # Random rotation probability. 113 | angle: (-10, 10) # Rotation angle range. 114 | Blur: False # Enable random blur for each video frame. 115 | resize: (320, 240) # The input is spatially resized to 320x240 for the NTU dataset. 116 | crop_size: 224 117 | low_frames: 16 # Number of frames sampled for the small Transformer. 118 | media_frames: 32 # Number of frames sampled for the medium Transformer. 119 | high_frames: 48 # Number of frames sampled for the large Transformer.
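  # For cross-modal fusion training (Network: FusionNet), the config additionally
  # points to the two unimodal checkpoints to be fused (see the rgb_checkpoint /
  # depth_checkpoint entries in config/NTU.yml), e.g.:
  # rgb_checkpoint: /path/to/rgb/model_best.pth.tar
  # depth_checkpoint: /path/to/depth/model_best.pth.tar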
120 | ``` 121 | 122 | ```bash 123 | bash run.sh tools/train.py config/NTU.yml 0,1,2,3,4,5,6,7 8 124 | ``` 125 | or 126 | ```bash 127 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --nproc_per_node=8 --master_port=1234 tools/train.py --config config/NTU.yml --nprocs 8 128 | ``` 129 | 130 | ### Cross-modal Representation Interactive Learning 131 | Take training a fusion model with 8 GPUs on the NTU-RGBD dataset as an example: 132 | ```bash 133 | bash run.sh tools/fusion.py config/NTU.yml 0,1,2,3,4,5,6,7 8 134 | ``` 135 | or 136 | ```bash 137 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --nproc_per_node=8 --master_port=1234 tools/fusion.py --config config/NTU.yml --nprocs 8 138 | ``` 139 | 140 | ### Evaluation 141 | ```bash 142 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --master_port=1234 tools/train.py --config config/NTU.yml --nprocs 1 --eval_only --resume /path/to/model_best.pth.tar 143 | ``` 144 | 145 | ## 4. Models Download 146 |
| Dataset | Modality | Accuracy (%) | Download |
| ------- | -------- | ------------ | -------- |
| NvGesture | RGB | 89.58 | Google Drive |
| NvGesture | Depth | 90.62 | Google Drive |
| NvGesture | RGB-D | 91.70 | Google Drive |
| THU-READ | RGB | 81.25 | Google Drive |
| THU-READ | Depth | 77.92 | Google Drive |
| THU-READ | RGB-D | 87.04 | Google Drive |
| NTU-RGBD(CS) | RGB | 90.3 | Google Drive |
| NTU-RGBD(CS) | Depth | 92.7 | Google Drive |
| NTU-RGBD(CS) | RGB-D | 94.2 | Google Drive |
| NTU-RGBD(CV) | RGB | 95.4 | Google Drive |
| NTU-RGBD(CV) | Depth | 96.2 | Google Drive |
| NTU-RGBD(CV) | RGB-D | 97.3 | Google Drive |
| IsoGD | RGB | 60.87 | Google Drive |
| IsoGD | Depth | 60.17 | Google Drive |
| IsoGD | RGB-D | 66.79 | Google Drive |
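After downloading a checkpoint, a quick local sanity check can save a failed distributed launch. The snippet below is a minimal sketch that loads a `.pth.tar` file on the CPU and prints its top-level contents; the path is a placeholder and no particular key layout is assumed, since the stored keys depend on how the checkpoint was saved.

```python
# Minimal sketch: inspect a downloaded checkpoint before running --eval_only.
# The path below is a placeholder; no particular key layout is assumed.
import torch

ckpt_path = "/path/to/model_best.pth.tar"
ckpt = torch.load(ckpt_path, map_location="cpu")  # load onto CPU for inspection

if isinstance(ckpt, dict):
    for key, value in ckpt.items():
        if hasattr(value, "shape"):            # a single tensor
            print(key, tuple(value.shape))
        elif isinstance(value, dict):          # e.g. a nested state_dict
            print(key, f"dict with {len(value)} entries")
        else:
            print(key, type(value).__name__)
else:
    print(type(ckpt).__name__)
```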
248 | 249 | # Citation 250 | ``` 251 | @InProceedings{Zhou_2022_CVPR, 252 | author = {Zhou, Benjia and Wang, Pichao and Wan, Jun and Liang, Yanyan and Wang, Fan and Zhang, Du and Lei, Zhen and Li, Hao and Jin, Rong}, 253 | title = {Decoupling and Recoupling Spatiotemporal Representation for RGB-D-Based Motion Recognition}, 254 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 255 | month = {June}, 256 | year = {2022}, 257 | pages = {20154-20163} 258 | } 259 | ``` 260 | # LICENSE 261 | The code is released under the MIT license. 262 | # Copyright 263 | Copyright (C) 2010-2021 Alibaba Group Holding Limited. 264 | -------------------------------------------------------------------------------- /config/IsoGD.yml: -------------------------------------------------------------------------------- 1 | common: 2 | data: /path/to/IsoGD/Dataset 3 | splits: /path/to/IsoGD/Dataset/dataset_splits 4 | 5 | #-------basic Hyparameter---------- 6 | visdom: 7 | enable: False 8 | visname: IsoGD 9 | 10 | dataset: IsoGD #Database name e.g., NTU, THUREAD ... 11 | batch_size: 6 12 | test_batch_size: 6 13 | num_workers: 6 14 | learning_rate: 0.01 15 | learning_rate_min: 0.00001 16 | momentum: 0.9 17 | weight_decay: 0.0003 18 | init_epochs: 0 19 | epochs: 300 20 | report_freq: 10 21 | optim: SGD 22 | dist: True 23 | vis_feature: True # Feature Visualization? 24 | DEBUG: False 25 | 26 | scheduler: 27 | name: ReduceLR 28 | patience: 4 29 | warm_up_epochs: 3 30 | loss: 31 | name: CE 32 | labelsmooth: True 33 | MultiLoss: True 34 | loss_lamdb: [ 1, 0.5, 0.5, 0.5 ] 35 | distill: 1. 36 | resume_scheduler: 0 37 | model: 38 | Network: I3DWTrans # e.g., I3DWTrans or FusionNet 39 | pretrained: '' 40 | # resume: '' 41 | resumelr: 0.0001 42 | sample_duration: 64 43 | sample_size: 224 44 | grad_clip: 5. 45 | SYNC_BN: 1 46 | w: 10 47 | temper: 0.5 48 | recoupling: True 49 | knn_attention: 0.7 50 | sharpness: True 51 | temp: [ 0.04, 0.07 ] 52 | frp: True 53 | SEHeads: 1 54 | N: 6 # Number of Transformer Blocks 55 | #-------Used for fusion network---------- 56 | rgb_checkpoint: '' 57 | depth_checkpoint: '' 58 | dataset: 59 | type: M # M: rgb, K: depth 60 | flip: 0.0 61 | rotated: 0.5 62 | angle: (-10, 10) # Rotation angle 63 | Blur: False 64 | resize: (256, 256) 65 | crop_size: 224 66 | low_frames: 16 67 | media_frames: 32 68 | high_frames: 48 69 | -------------------------------------------------------------------------------- /config/Jester.yml: -------------------------------------------------------------------------------- 1 | common: 2 | data: /media/ssd1/bjzhou/ssd2/bjzhou/Jester/20bn-jester-v1 3 | splits: /media/ssd1/bjzhou/ssd2/bjzhou/Jester/dataset_splits 4 | 5 | #-------basic Hyparameter---------- 6 | visname: Jester 7 | dataset: Jester 8 | batch_size: 2 9 | test_batch_size: 2 10 | num_workers: 4 11 | learning_rate: 0.01 12 | learning_rate_min: 0.00001 13 | momentum: 0.9 14 | weight_decay: 0.0003 15 | init_epochs: 0 16 | epochs: 100 17 | report_freq: 10 18 | metric_freq: 10 19 | show_cluster_result: 100 20 | optim: SGD 21 | dist: True 22 | vis_feature: True # Visualization? 23 | 24 | scheduler: 25 | name: cosin 26 | patience: 4 27 | warm_up_epochs: 0 28 | loss: 29 | name: CE 30 | labelsmooth: True 31 | mse_weight: 10.0 32 | MultiLoss: True 33 | loss_lamdb: [ 1, 0.5, 0.5, 0.5 ] 34 | distill_lamdb: 1. 35 | 36 | model: 37 | Network: I3DWTrans 38 | pretrained: '' 39 | # resume: '' 40 | resumelr: False 41 | sample_duration: 64 42 | sample_size: 224 43 | grad_clip: 5. 
44 | SYNC_BN: 1 45 | w: 10 46 | temper: 0.4 47 | recoupling: True 48 | knn_attention: 0.8 49 | sharpness: True 50 | temp: [ 0.04, 0.07 ] 51 | frp: True 52 | SEHeads: 1 53 | N: 6 54 | 55 | dataset: 56 | type: M 57 | flip: 0.0 58 | rotated: 0.5 59 | angle: (-10, 10) 60 | Blur: False 61 | resize: (256, 256) 62 | crop_size: 224 63 | low_frames: 16 64 | media_frames: 32 65 | high_frames: 48 -------------------------------------------------------------------------------- /config/NTU.yml: -------------------------------------------------------------------------------- 1 | # ''' 2 | # Copyright (C) 2010-2021 Alibaba Group Holding Limited. 3 | # ''' 4 | 5 | common: 6 | data: /mnt/workspace//Dataset/NTU-RGBD/ 7 | splits: /mnt/workspace//Dataset/NTU-RGBD/dataset_splits/@CS 8 | 9 | #-------basic Hyparameter---------- 10 | visdom: 11 | enable: False 12 | visname: NTU 13 | dataset: NTU #Database name e.g., NTU, THU ... 14 | batch_size: 4 15 | test_batch_size: 6 16 | num_workers: 10 17 | learning_rate: 0.005 18 | learning_rate_min: 0.00001 19 | momentum: 0.9 20 | weight_decay: 0.0003 21 | init_epochs: 0 22 | epochs: 100 23 | report_freq: 10 24 | optim: SGD 25 | dist: True 26 | vis_feature: True # Visualization? 27 | 28 | scheduler: 29 | name: cosin 30 | patience: 4 31 | warm_up_epochs: 3 32 | loss: 33 | name: CE 34 | labelsmooth: True 35 | MultiLoss: True 36 | loss_lamdb: [ 1, 0.5, 0.5, 0.5 ] 37 | distill: 1. 38 | 39 | model: 40 | Network: I3DWTrans # e.g., I3DWTrans or FusionNet 41 | pretrained: '' #./Checkpoints/I3DWTrans-EXP-20211204-211826//model_best.pth.tar' #'../MultiScale/Checkpoints/I3DWTrans-EXP-20211024-171224/model_best.pth.tar' #'../MultiScale/Checkpoints/I3DWTrans-EXP-20211019-124405/model_best.pth.tar' 42 | # resume: '' #'./Checkpoints/I3DWTrans-NTU-M-20211214-195641/model_best-DTN.pth.tar' 43 | resumelr: '' 44 | sample_duration: 64 45 | sample_size: 224 46 | grad_clip: 5. 47 | SYNC_BN: 1 48 | w: 10 49 | temper: 0.5 50 | recoupling: True 51 | knn_attention: 0.7 52 | sharpness: True 53 | temp: [ 0.04, 0.07 ] 54 | frp: True 55 | SEHeads: 1 56 | N: 6 # Number of Transformer Blocks 57 | 58 | #-------Used for fusion network---------- 59 | rgb_checkpoint: './Checkpoints/I3DWTrans-EXP-20211204-211826//model_best.pth.tar' 60 | depth_checkpoint: './Checkpoints/I3DWTrans-EXP-20211204-214434//model_best.pth.tar' 61 | 62 | dataset: 63 | type: M # M: rgb, K: depth 64 | flip: 0.5 65 | rotated: 0.5 66 | angle: (-10, 10) # Rotation angle 67 | Blur: False 68 | resize: (320, 240) 69 | crop_size: 224 70 | low_frames: 16 71 | media_frames: 32 72 | high_frames: 48 73 | 74 | # I3DWTrans-EXP-20211204-211826 M 90.25 75 | # I3DWTrans-EXP-20211204-214434 K 92.81 76 | -------------------------------------------------------------------------------- /config/NetworkConfig.yml: -------------------------------------------------------------------------------- 1 | common: 2 | data: /path/to/dataset/NTU-RGBD 3 | splits: /path/to/dataset/dataset/NTU-RGBD/dataset_splits/@CS # include: train.txt and test.txt 4 | 5 | #-------basic Hyparameter---------- 6 | visdom: 7 | enable: True 8 | visname: NTU 9 | dataset: NTU #Database name e.g., NTU, THUREAD, NvGesture and IsoGD ... 10 | batch_size: 6 11 | test_batch_size: 6 12 | num_workers: 6 13 | learning_rate: 0.01 14 | learning_rate_min: 0.00001 15 | momentum: 0.9 16 | weight_decay: 0.0003 17 | init_epochs: 0 18 | epochs: 100 # if training on IsoGD dataset, set 300 is better. 19 | report_freq: 100 20 | optim: SGD 21 | dist: True 22 | vis_feature: True # Visualization? 
23 | 24 | scheduler: 25 | name: cosin 26 | patience: 4 27 | warm_up_epochs: 3 28 | loss: 29 | name: CE 30 | labelsmooth: True 31 | MultiLoss: True 32 | loss_lamdb: [ 1, 0.5, 0.5, 0.5 ] 33 | distill: 1. 34 | 35 | model: 36 | Network: I3DWTrans # e.g., I3DWTrans or FusionNet 37 | pretrained: '' # all of experiments are pre-trained on 20BN Jester V1 dataset except for NTU-RGBD. 38 | resume: '' 39 | resumelr: False 40 | sample_duration: 64 41 | sample_size: 224 42 | grad_clip: 5. 43 | SYNC_BN: 1 44 | w: 10 45 | temper: 0.5 # 0.5 for THUREAD and NTU-RGBD; 0.4 for NvGesture and IsoGD 46 | recoupling: True 47 | knn_attention: 0.7 48 | sharpness: True 49 | temp: [ 0.04, 0.07 ] 50 | frp: True 51 | SEHeads: 1 52 | N: 6 # Number of Transformer Blocks 53 | 54 | #-------Used for fusion network---------- 55 | rgb_checkpoint: '' 56 | depth_checkpoint: '' 57 | 58 | dataset: 59 | type: M # M: rgb, K: depth 60 | flip: 0.5 # set 0.0 for NvGesture and IsoGD 61 | rotated: 0.5 # THUREAD: 0.8, others: 0.5 62 | angle: (-10, 10) # Rotation angle. THUREAD: (-45, 45), others: (-10, 10) 63 | Blur: False 64 | resize: (320, 240) #NTU and THUREAD: (320, 240), others:(256, 256) 65 | crop_size: 224 # THUREAD: 200, others: 224 66 | low_frames: 16 67 | media_frames: 32 68 | high_frames: 48 69 | 70 | -------------------------------------------------------------------------------- /config/NvGesture.yml: -------------------------------------------------------------------------------- 1 | 2 | common: 3 | data: /mnt/workspace/Dataset/NvGesture/ 4 | splits: /mnt/workspace/Dataset/NvGesture/dataset_splits/ 5 | 6 | #-------basic Hyparameter---------- 7 | visdom: 8 | enable: False 9 | visname: NvGesture 10 | 11 | dataset: NvGesture #Database name e.g., NTU, THUREAD ... 12 | batch_size: 4 13 | test_batch_size: 2 14 | num_workers: 10 15 | learning_rate: 0.01 16 | learning_rate_min: 0.00001 17 | momentum: 0.9 18 | weight_decay: 0.0003 19 | init_epochs: 0 20 | epochs: 100 21 | report_freq: 10 22 | optim: SGD 23 | dist: True 24 | vis_feature: True # Visualization? 25 | DEBUG: False 26 | 27 | scheduler: 28 | name: cosin 29 | patience: 4 30 | warm_up_epochs: 3 31 | loss: 32 | name: CE 33 | labelsmooth: True 34 | MultiLoss: True 35 | loss_lamdb: [ 1, 0.5, 0.5, 0.5 ] 36 | distill: 1. 37 | model: 38 | Network: I3DWTrans # e.g., I3DWTrans or FusionNet 39 | pretrained: /mnt/workspace/Code/CVPR/Checkpoints/I3DWTrans-NvGesture-K-20221017-113943/model_best-Nv-K.pth.tar-v1 40 | resumelr: False 41 | sample_duration: 64 42 | sample_size: 224 43 | grad_clip: 5. 
44 | SYNC_BN: 1 45 | w: 4 # 4 is best for Nv 46 | temper: 0.4 47 | recoupling: True 48 | knn_attention: 0.7 49 | sharpness: True 50 | temp: [ 0.04, 0.07 ] 51 | frp: True 52 | SEHeads: 1 53 | N: 6 # Number of Transformer Blocks 54 | #-------Used for fusion network---------- 55 | rgb_checkpoint: '/mnt/workspace/Code/CVPR/Checkpoints/model_best-Nv-K.pth.tar' 56 | depth_checkpoint: '/mnt/workspace/Code/CVPR/Checkpoints/model_best-Nv-M.pth.tar' 57 | 58 | dataset: 59 | type: M 60 | flip: 0.0 61 | rotated: 0.5 62 | angle: (-10, 10) 63 | Blur: False 64 | resize: (256, 256) 65 | crop_size: 224 66 | low_frames: 16 67 | media_frames: 32 68 | high_frames: 48 69 | 70 | # I3DWTrans-NvGesture-K-20221017-113943 90.83 71 | # I3DWTrans-NvGesture-M-20221018-123442 88.75 -------------------------------------------------------------------------------- /config/THU.yml: -------------------------------------------------------------------------------- 1 | common: 2 | data: /mnt/workspace//Dataset/THU-READ/frames 3 | splits: /mnt/workspace//Dataset/THU-READ/dataset_splits/@2 4 | 5 | 6 | #-------basic Hyparameter---------- 7 | visdom: 8 | enable: False 9 | visname: THU 10 | dataset: THUREAD 11 | batch_size: 6 12 | test_batch_size: 6 13 | num_workers: 6 14 | learning_rate: 0.01 15 | learning_rate_min: 0.00001 16 | momentum: 0.9 17 | weight_decay: 0.0003 18 | init_epochs: 0 19 | epochs: 100 20 | report_freq: 10 21 | optim: SGD 22 | dist: True 23 | vis_feature: True # Visualization? 24 | DEBUG: False 25 | 26 | scheduler: 27 | name: cosin 28 | patience: 4 29 | warm_up_epochs: 3 # 10 may work 30 | loss: 31 | name: CE 32 | labelsmooth: True 33 | MultiLoss: True 34 | loss_lamdb: [ 1, 0.5, 0.5, 0.5 ] 35 | distill: 1. 36 | 37 | model: 38 | Network: I3DWTrans # e.g., I3DWTrans or FusionNet 39 | pretrained: '/mnt/workspace/Code/CVPR/Checkpoints/I3DWTrans-NvGesture-K-20221017-113943/model_best-Nv-K.pth.tar-v1' #'./Checkpoints/I3DWTrans-THUREAD-M-20211211-194730/model_best.pth.tar' 40 | # resume: ./Checkpoints/FusionNet-THUREAD-M-20211213-195422/model_best.pth.tar 41 | resumelr: False 42 | sample_duration: 64 43 | sample_size: 224 44 | grad_clip: 5. 
45 | SYNC_BN: 1 46 | w: 10 47 | temper: 0.5 48 | recoupling: True 49 | knn_attention: 0.7 50 | sharpness: True 51 | temp: [ 0.04, 0.07 ] 52 | frp: True 53 | SEHeads: 1 54 | N: 6 # Number of Transformer Blocks 55 | 56 | rgb_checkpoint: './Checkpoints/I3DWTrans-THUREAD-M-20211211-194730/model_best.pth.tar' 57 | depth_checkpoint: './Checkpoints/I3DWTrans-THUREAD-K-20211211-124150/model_best.pth.tar' 58 | 59 | dataset: 60 | type: M 61 | flip: 0.5 62 | rotated: 0.8 63 | angle: (-45, 45) 64 | Blur: False 65 | resize: (320, 240) 66 | crop_size: 200 67 | low_frames: 16 68 | media_frames: 32 69 | high_frames: 48 70 | 71 | # I3DWTrans-THUREAD-K-20211212-122941 K 83.33 72 | # I3DWTrans-THUREAD-M-20211211-194730 M 82.08 73 | 74 | # Local + Global + multi loss I3DWTrans_SAtt-EXP-20210823-115604 55.0 75 | # THU-READPre I3DWTrans-EXP-20210825-000754 74.58 76 | # READPreonlyGlobal I3DWTrans-EXP-20210825-085754 72.08% 77 | # THU-READPreSingleLoss I3DWTrans-EXP-20210825-154036 72.50 78 | # READPreNew I3DWTrans-EXP-20210827-110508 75.00 79 | # THU-READPreNewTopKSche I3DWTrans-EXP-20210828-215011 75.83 80 | # THU-READPreNewTopKSche depth I3DWTrans-EXP-20210829-125851 80.0 81 | # THU-READPreNewTopKSchecross I3DWTrans-EXP-20210829-161637 77.08 82 | # THU-READPreCrossSche I3DWTrans-EXP-20210830-002827 74.58 83 | # THU-READPreCrossTopKSchegradclip 75.42 84 | # @2 THU-READDatt I3DWTrans-EXP-20210831-214258 77.92 85 | # THU-READDattSize(320, 320) I3DWTrans-EXP-20210901-014119 75.00 --> no work 86 | # THU-READDatt depth I3DWTrans-EXP-20210901-014025 79.58 87 | # THU-READDatt@4 I3DWTrans-EXP-20210901-092801 76.25 88 | # THU-READDatt@1Blur I3DWTrans-EXP-20210909-105432 74.17 89 | 90 | # THU-READDatt@1New M I3DWTrans-EXP-20210911-103936 80.00 91 | # THU-READDatt@1New K I3DWTrans-EXP-20210911-104131 75.42 92 | # Fusion add @1 84.17 | strategy 86.67 93 | # THU-READDatt@2New M I3DWTrans-EXP-20210910-224232 81.25 94 | # THU-READDatt@2New K I3DWTrans-EXP-20210910-223415 82.08 95 | # Fusion add @2 86.25 | strategy 90.41 96 | # THU-READDatt@3New M I3DWTrans-EXP-20210911-191641 77.50 97 | # THU-READDatt@3New K I3DWTrans-EXP-20210911-191748 77.92 98 | # Fusion add @3 | strategy 82.02 99 | # THU-READDatt@4New M I3DWTrans-EXP-20210912-005825 82.92 100 | # THU-READDatt@4New K I3DWTrans-EXP-20210912-101129 76.25 101 | # Fusion add @4 | strategy 102 | 103 | # @1 M 104 | # @1 K 105 | # fusion 85.42 106 | 107 | # @2 M 84.17 108 | # @2 K 82.91 109 | # fusion 88.75 110 | 111 | # ./Checkpoints/FusionNet-EXP-20211008-232502/model_best.pth.tar 112 | # @3 M 81.94 113 | # @3 K 84.86 114 | # fusion 85.69 -------------------------------------------------------------------------------- /config/__init__.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (C) 2010-2021 Alibaba Group Holding Limited. 3 | ''' 4 | 5 | from .config import Config -------------------------------------------------------------------------------- /config/config.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (C) 2010-2021 Alibaba Group Holding Limited. 
3 | ''' 4 | 5 | import yaml 6 | # from easydict import EasyDict as edict 7 | def Config(args): 8 | print() 9 | print('='*80) 10 | with open(args.config) as f: 11 | config = yaml.load(f, Loader=yaml.FullLoader) 12 | for dic in config: 13 | for k, v in config[dic].items(): 14 | setattr(args, k, v) 15 | print(k, ':\t', v) 16 | print('='*80) 17 | print() 18 | return args -------------------------------------------------------------------------------- /data/data_preprose.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (C) 2010-2021 Alibaba Group Holding Limited. 3 | ''' 4 | 5 | import cv2 6 | from PIL import Image 7 | import numpy as np 8 | 9 | import os, glob, re 10 | import argparse 11 | import csv 12 | import random 13 | from tqdm import tqdm 14 | from multiprocessing import Process 15 | import shutil 16 | from multiprocessing import Pool, cpu_count 17 | 18 | def resize_pos(center,src_size,tar_size): 19 | x, y = center 20 | w1=src_size[1] 21 | h1=src_size[0] 22 | w=tar_size[1] 23 | h=tar_size[0] 24 | 25 | y1 = int((h / h1) * y) 26 | x1 = int((w / w1) * x) 27 | return (x1, y1) 28 | 29 | ''' 30 | For NTU-RGBD 31 | ''' 32 | def video2image(v_p): 33 | m_path='nturgb+d_depth_masked/' 34 | img_path = os.path.join('Images', v_p[:-4].split('/')[-1]) 35 | if not os.path.exists(img_path): 36 | os.makedirs(img_path) 37 | cap = cv2.VideoCapture(v_p) 38 | suc, frame = cap.read() 39 | frame_count = 1 40 | while suc: 41 | # frame [1920, 1080] 42 | mask_path = os.path.join(m_path, v_p[:-8].split('/')[-1], 'MDepth-%08d.png'%frame_count) 43 | mask = cv2.imread(mask_path) 44 | mask = mask*255 45 | w, h, c = mask.shape 46 | h2, w2, _ = frame.shape 47 | ori = frame 48 | frame = cv2.resize(frame, (h, w)) 49 | h1, w1, _ = frame.shape 50 | 51 | # image = cv2.add(frame, mask) 52 | 53 | # find contour 54 | mask = cv2.erode(mask, np.ones((3, 3),np.uint8)) 55 | mask = cv2.dilate(mask ,np.ones((10, 10),np.uint8)) 56 | mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY) 57 | contours, hierarchy = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) 58 | 59 | # Find Max Maxtri 60 | Idx = [] 61 | for i in range(len(contours)): 62 | Area = cv2.contourArea(contours[i]) 63 | if Area > 500: 64 | Idx.append(i) 65 | # max_idx = np.argmax(area) 66 | 67 | centers = [] 68 | for i in Idx: 69 | rect = cv2.minAreaRect(contours[i]) 70 | center, (h, w), degree = rect 71 | centers.append(center) 72 | 73 | finall_center = np.int0(np.array(centers)) 74 | c_x = min(finall_center[:, 0]) 75 | c_y = min(finall_center[:, 1]) 76 | 77 | center = (c_x, c_y) 78 | # finall_center = finall_center.sum(0)/len(finall_center) 79 | 80 | # rect = cv2.minAreaRect(contours[max_idx]) 81 | # center, (h, w), degree = rect 82 | # center = tuple(np.int0(finall_center)) 83 | center_new = resize_pos(center, (h1, w1), (h2, w2)) 84 | 85 | #----------------------------------- 86 | # Image Crop 87 | #----------------------------------- 88 | # ori = cv2.circle(ori, center_new, 2, (0, 0, 255), 2) 89 | crop_y, crop_x = h2//2, w2//2 90 | # print(crop_x, crop_y) 91 | left = center_new[0] - crop_x//2 if center_new[0] - crop_x//2 > 0 else 0 92 | top = center_new[1] - crop_y//2 if center_new[1] - crop_y//2 > 0 else 0 93 | # ori = cv2.circle(ori, (left, top), 2, (0, 0, 255), 2) 94 | # cv2.imwrite('demo/ori.png', ori) 95 | crop_w = left + crop_x if left + crop_x < w2 else w2 96 | crop_h = top + crop_y if top + crop_y < h2 else h2 97 | rect = (left, top, crop_w, crop_h) 98 | image = 
Image.fromarray(cv2.cvtColor(ori, cv2.COLOR_BGR2RGB)) 99 | image = image.crop(rect) 100 | image.save('{}/{:0>6d}.jpg'.format(img_path, frame_count)) 101 | 102 | # box = cv2.boxPoints(rect) 103 | # box = np.int0(box) 104 | # drawImage = frame.copy() 105 | # drawImage = cv2.drawContours(drawImage, [box], 0, (255, 0, 0), -1) # draw one contour 106 | # cv2.imwrite('demo/drawImage.png', drawImage) 107 | # frame = cv2.circle(frame, center, 2, (0, 255, 255), 2) 108 | # cv2.imwrite('demo/Image.png', frame) 109 | # cv2.imwrite('demo/mask.png', mask) 110 | # ori = cv2.circle(ori, center_new, 2, (0, 0, 255), 2) 111 | # cv2.imwrite('demo/ORI.png', ori) 112 | # cv2.imwrite('demo/maskImage.png', image) 113 | 114 | # cv2.imwrite('{}/{:0>6d}.jpg'.format(img_path, frame_count), frame) 115 | frame_count += 1 116 | suc, frame = cap.read() 117 | cap.release() 118 | 119 | ''' 120 | For IsoGD, Nv... 121 | ''' 122 | # def video2image(v_p): 123 | # img_path = v_p[:-4].replace('UCF-101', 'UCF-101-images') 124 | # if not os.path.exists(img_path): 125 | # os.makedirs(img_path) 126 | # cap = cv2.VideoCapture(v_p) 127 | # suc, frame = cap.read() 128 | # frame_count = 0 129 | # while suc: 130 | # h, w, c = frame.shape 131 | # cv2.imwrite('{}/{:0>6d}.jpg'.format(img_path, frame_count), frame) 132 | # frame_count += 1 133 | # suc, frame = cap.read() 134 | # cap.release() 135 | 136 | def GeneratLabel(sample): 137 | path = sample[:-4].split('/')[-1] 138 | cap = cv2.VideoCapture(sample) 139 | frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) 140 | label = int(sample.split('A')[-1][:3])-1 141 | txt = ' '.join(map(str, [path, frame_count, label, '\n'])) 142 | if args.proto == '@CV': 143 | if 'C001' in sample: 144 | with open(args.validTXT, 'a') as vf: 145 | vf.writelines(txt) 146 | else: 147 | with open(args.trainTXT, 'a') as tf: 148 | tf.writelines(txt) 149 | elif args.proto == '@CS': 150 | pattern = re.findall(r'P\d+', sample) 151 | if int(pattern[0][1:]) in [1, 2, 4, 5, 8, 9, 13, 14, 15,16, 17, 18, 19, 25, 27, 28, 31, 34, 35, 38]: 152 | with open(args.trainTXT, 'a') as tf: 153 | tf.writelines(txt) 154 | else: 155 | with open(args.validTXT, 'a') as vf: 156 | vf.writelines(txt) 157 | 158 | def ResizeImage(img_path): 159 | save_path = img_path.replace('Images', 'ImagesResize') 160 | if not os.path.exists(save_path): 161 | os.makedirs(save_path) 162 | for img in os.listdir(img_path): 163 | im_path = os.path.join(img_path, img) 164 | image = cv2.imread(im_path) 165 | image = cv2.resize(image, (320, 240)) 166 | cv2.imwrite(os.path.join(save_path, img), image) 167 | 168 | data_root = '/mnt/workspace/Dataset/NTU-RGBD' 169 | Image_paths = glob.glob(os.path.join(data_root, 'nturgb+d_rgb/*.avi')) 170 | print('Total Images: {}'.format(len(Image_paths))) 171 | mask_paths = os.listdir(os.path.join(data_root, 'nturgb+d_depth_masked/')) 172 | print('Total Masks: {}'.format(len(mask_paths))) 173 | 174 | parser = argparse.ArgumentParser() 175 | parser.add_argument('--proto', default='@CS') 176 | args = parser.parse_args() 177 | 178 | 179 | #--------------------------------------------- 180 | # Generate label .txt 181 | #--------------------------------------------- 182 | trainTXT = os.path.join(data_root, 'dataset_splits', args.proto, 'train.txt') 183 | validTXT = os.path.join(data_root, 'dataset_splits', args.proto, 'valid.txt') 184 | args.trainTXT = trainTXT 185 | args.validTXT = validTXT 186 | if os.path.isfile(args.trainTXT): 187 | os.system('rm {}'.format(args.trainTXT)) 188 | if os.path.isfile(args.validTXT): 189 | os.system('rm 
{}'.format(args.validTXT)) 190 | 191 | with Pool(20) as pool: 192 | for a in tqdm(pool.imap_unordered(GeneratLabel, Image_paths), total=len(Image_paths), desc='Processes'): 193 | if a is not None: 194 | pass 195 | print('Write file list done'.center(80, '*')) 196 | 197 | #--------------------------------------------- 198 | # video --> Images 199 | #--------------------------------------------- 200 | print(len(Image_paths)) 201 | with Pool(20) as pool: 202 | for a in tqdm(pool.imap_unordered(video2image, Image_paths), total=len(Image_paths), desc='Processes'): 203 | if a is not None: 204 | pass 205 | print('Write Image done'.center(80, '*')) 206 | 207 | #--------------------------------------------- 208 | # Images size to (320, 240) 209 | #--------------------------------------------- 210 | trainTXT = '/mnt/workspace/Dataset/NTU-RGBD/dataset_splits/@CS/train.txt' 211 | validTXT = '/mnt/workspace/Dataset/NTU-RGBD/dataset_splits/@CS/valid.txt' 212 | Image_paths = ['./Images/'+ l.split()[0] for l in open(validTXT, 'r').readlines()] 213 | with Pool(40) as pool: 214 | for a in tqdm(pool.imap_unordered(ResizeImage, Image_paths), total=len(Image_paths), desc='Processes'): 215 | if a is not None: 216 | pass 217 | print('Write Image done'.center(80, '*')) 218 | 219 | # data_root = '/mnt/workspace/Dataset/UCF-101/' 220 | # label_dict = dict([(lambda x: (x[1], int(x[0])-1))(l.strip().split(' ')) for l in open(data_root + 'dataset_splits/lableind.txt').readlines()]) 221 | # print(label_dict) 222 | 223 | # def split_func(file_list): 224 | # class_list = [] 225 | # fl = open(file_list).readlines() 226 | # for d in tqdm(fl): 227 | # path = d.strip().split()[0][:-4] 228 | # label = label_dict[path.split('/')[0]] 229 | # frame_num = len(os.listdir(os.path.join(data_root, 'UCF-101-images', path))) 230 | # class_list.append([path, str(frame_num), str(label), '\n']) 231 | # return class_list 232 | 233 | # def save_list(file_list, file_name): 234 | # with open(file_name, 'w') as f: 235 | # class_list = split_func(file_list) 236 | # for l in class_list: 237 | # f.writelines(' '.join(l)) 238 | 239 | # prot = '@3' 240 | # data_train_split = data_root + f'dataset_splits/{prot}/trainlist.txt' 241 | # data_test_split = data_root + f'dataset_splits/{prot}/testlist.txt' 242 | 243 | # train_file_name = data_root + f'dataset_splits/{prot}/train.txt' 244 | # test_file_name = data_root + f'dataset_splits/{prot}/valid.txt' 245 | # save_list(data_train_split, train_file_name) 246 | # save_list(data_test_split, test_file_name) 247 | 248 | 249 | -------------------------------------------------------------------------------- /demo/decouple_recouple.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/damo-cv/MotionRGBD/b13673a10e3f259ddef4911a2a91b6eedaf104a1/demo/decouple_recouple.jpg -------------------------------------------------------------------------------- /demo/pipline.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/damo-cv/MotionRGBD/b13673a10e3f259ddef4911a2a91b6eedaf104a1/demo/pipline.jpg -------------------------------------------------------------------------------- /demo/readme.md: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /lib/__init__.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (C) 
2010-2021 Alibaba Group Holding Limited. 3 | ''' 4 | from .datasets import * 5 | from .model import * -------------------------------------------------------------------------------- /lib/datasets/IsoGD.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (C) 2010-2021 Alibaba Group Holding Limited. 3 | ''' 4 | 5 | import torch 6 | from .base import Datasets 7 | from torchvision import transforms, set_image_backend 8 | import random, os 9 | from PIL import Image 10 | import numpy as np 11 | 12 | class IsoGDData(Datasets): 13 | def __init__(self, args, ground_truth, modality, phase='train'): 14 | super(IsoGDData, self).__init__(args, ground_truth, modality, phase) 15 | def __getitem__(self, index): 16 | """ 17 | Args: 18 | index (int): Index 19 | Returns: 20 | tuple: (image, target) where target is class_index of the target class. 21 | """ 22 | sl = self.get_sl(self.inputs[index][1]) 23 | self.data_path = os.path.join(self.dataset_root, self.typ, self.inputs[index][0]) 24 | if self.typ == 'depth': 25 | self.data_path = self.data_path.replace('M_', 'K_') 26 | 27 | if self.args.Network == 'FusionNet': 28 | assert self.typ == 'rgb' 29 | self.data_path1 = self.data_path.replace('rgb', 'depth') 30 | self.data_path1 = self.data_path1.replace('M', 'K') 31 | 32 | self.clip, skgmaparr = self.image_propose(self.data_path, sl) 33 | self.clip1, skgmaparr1 = self.image_propose(self.data_path1, sl) 34 | return (self.clip.permute(0, 3, 1, 2), skgmaparr), (self.clip1.permute(0, 3, 1, 2), skgmaparr1), self.inputs[index][2], self.inputs[index][0] 35 | 36 | else: 37 | self.clip, skgmaparr = self.image_propose(self.data_path, sl) 38 | return self.clip.permute(0, 3, 1, 2), skgmaparr, self.inputs[index][2], self.inputs[index][0] 39 | 40 | def __len__(self): 41 | return len(self.inputs) 42 | -------------------------------------------------------------------------------- /lib/datasets/Jester.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (C) 2010-2021 Alibaba Group Holding Limited. 
3 | ''' 4 | 5 | import torch 6 | from .base import Datasets 7 | from torchvision import transforms, set_image_backend 8 | import random, os 9 | from PIL import Image 10 | import numpy as np 11 | import logging 12 | np.random.seed(123) 13 | 14 | class JesterData(Datasets): 15 | def __init__(self, args, ground_truth, modality, phase='train'): 16 | super(JesterData, self).__init__(args, ground_truth, modality, phase) 17 | 18 | def LoadKeypoints(self): 19 | if self.phase == 'train': 20 | kpt_file = os.path.join(self.dataset_root, self.args.splits, 'train_kp.data') 21 | else: 22 | kpt_file = os.path.join(self.dataset_root, self.args.splits, 'valid_kp.data') 23 | with open(kpt_file, 'r') as f: 24 | kpt_data = [(lambda arr: (os.path.join(self.dataset_root, self.typ, self.phase, arr[0]), list(map(lambda x: int(float(x)), arr[1:]))))(l[:-1].split()) for l in f.readlines()] 25 | kpt_data = dict(kpt_data) 26 | 27 | for k, v in kpt_data.items(): 28 | pose = v[:18*2] 29 | r_hand = v[18*2: 18*2+21*2] 30 | l_hand = v[18*2+21*2: 18*2+21*2+21*2] 31 | kpt_data[k] = {'people': [{'pose_keypoints_2d': pose, 'hand_right_keypoints_2d': r_hand, 'hand_left_keypoints_2d': l_hand}]} 32 | 33 | logging.info('Load Keypoints files Done, Total: {}'.format(len(kpt_data))) 34 | return kpt_data 35 | def get_path(self, imgs_path, a): 36 | return os.path.join(imgs_path, "%05d.jpg" % int(a + 1)) 37 | def __getitem__(self, index): 38 | """ 39 | Args: 40 | index (int): Index 41 | Returns: 42 | tuple: (image, target) where target is class_index of the target class. 43 | """ 44 | sl = self.get_sl(self.inputs[index][1]) 45 | self.data_path = os.path.join(self.dataset_root, self.inputs[index][0]) 46 | # self.clip = self.image_propose(self.data_path, sl) 47 | self.clip, skgmaparr = self.image_propose(self.data_path, sl) 48 | 49 | return self.clip.permute(0, 3, 1, 2), skgmaparr, self.inputs[index][2], self.data_path 50 | 51 | def __len__(self): 52 | return len(self.inputs) 53 | -------------------------------------------------------------------------------- /lib/datasets/NTU.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (C) 2010-2021 Alibaba Group Holding Limited. 3 | ''' 4 | 5 | import torch 6 | from .base import Datasets 7 | from torchvision import transforms, set_image_backend 8 | import random, os 9 | from PIL import Image 10 | import numpy as np 11 | 12 | class NTUData(Datasets): 13 | def __init__(self, args, ground_truth, modality, phase='train'): 14 | super(NTUData, self).__init__(args, ground_truth, modality, phase) 15 | 16 | def __getitem__(self, index): 17 | """ 18 | Args: 19 | index (int): Index 20 | Returns: 21 | tuple: (image, target) where target is class_index of the target class. 
22 | """ 23 | sl = self.get_sl(self.inputs[index][1]) 24 | 25 | if self.typ == 'rgb': 26 | self.data_path = os.path.join(self.dataset_root, 'Images', self.inputs[index][0]) 27 | 28 | if self.typ == 'depth': 29 | self.data_path = os.path.join(self.dataset_root, 'nturgb+d_depth_masked', self.inputs[index][0][:-4]) 30 | 31 | self.clip, skgmaparr = self.image_propose(self.data_path, sl) 32 | 33 | if self.args.Network == 'FusionNet': 34 | assert self.typ == 'rgb' 35 | self.data_path = os.path.join(self.dataset_root, 'nturgb+d_depth_masked', self.inputs[index][0][:-4]) 36 | self.clip1, skgmaparr1 = self.image_propose(self.data_path, sl) 37 | return (self.clip.permute(0, 3, 1, 2), self.clip1.permute(0, 3, 1, 2)), (skgmaparr, skgmaparr1), \ 38 | self.inputs[index][2], self.data_path 39 | 40 | return self.clip.permute(0, 3, 1, 2), skgmaparr, self.inputs[index][2], self.inputs[index][0] 41 | 42 | def get_path(self, imgs_path, a): 43 | 44 | if self.typ == 'rgb': 45 | return os.path.join(imgs_path, "%06d.jpg" % int(a + 1)) 46 | else: 47 | return os.path.join(imgs_path, "MDepth-%08d.png" % int(a + 1)) 48 | 49 | def __len__(self): 50 | return len(self.inputs) 51 | -------------------------------------------------------------------------------- /lib/datasets/NvGesture.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (C) 2010-2021 Alibaba Group Holding Limited. 3 | ''' 4 | 5 | import torch 6 | from .base import Datasets 7 | from torchvision import transforms, set_image_backend 8 | import random, os 9 | from PIL import Image 10 | import numpy as np 11 | import logging 12 | set_image_backend('accimage') 13 | np.random.seed(123) 14 | 15 | class NvData(Datasets): 16 | def __init__(self, args, ground_truth, modality, phase='train'): 17 | super(NvData, self).__init__(args, ground_truth, modality, phase) 18 | def transform_params(self, resize=(320, 240), crop_size=224, flip=0.5): 19 | if self.phase == 'train': 20 | left, top = random.randint(10, resize[0] - crop_size), random.randint(10, resize[1] - crop_size) 21 | is_flip = True if random.uniform(0, 1) < flip else False 22 | else: 23 | left, top = 32, 32 24 | is_flip = False 25 | return (left, top, left + crop_size, top + crop_size), is_flip 26 | 27 | def __getitem__(self, index): 28 | """ 29 | Args: 30 | index (int): Index 31 | Returns: 32 | tuple: (image, target) where target is class_index of the target class. 33 | """ 34 | sl = self.get_sl(self.inputs[index][1]) 35 | self.data_path = os.path.join(self.dataset_root, self.typ, self.inputs[index][0]) 36 | self.clip, skgmaparr = self.image_propose(self.data_path, sl) 37 | 38 | if self.args.Network == 'FusionNet': 39 | assert self.typ == 'rgb' 40 | self.data_path = self.data_path.replace('rgb', 'depth') 41 | self.clip1, skgmaparr1 = self.image_propose(self.data_path, sl) 42 | 43 | return (self.clip.permute(0, 3, 1, 2), self.clip1.permute(0, 3, 1, 2)), (skgmaparr, skgmaparr1), self.inputs[index][2], self.data_path 44 | 45 | return self.clip.permute(0, 3, 1, 2), skgmaparr, self.inputs[index][2], self.data_path 46 | 47 | def __len__(self): 48 | return len(self.inputs) 49 | -------------------------------------------------------------------------------- /lib/datasets/THU_READ.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (C) 2010-2021 Alibaba Group Holding Limited. 
3 | ''' 4 | 5 | import torch 6 | from .base import Datasets 7 | from torchvision import transforms, set_image_backend 8 | import random, os 9 | from PIL import Image 10 | import numpy as np 11 | import logging 12 | 13 | np.random.seed(123) 14 | 15 | 16 | class THUREAD(Datasets): 17 | def __init__(self, args, ground_truth, modality, phase='train'): 18 | super(THUREAD, self).__init__(args, ground_truth, modality, phase) 19 | 20 | def __getitem__(self, index): 21 | """ 22 | Args: 23 | index (int): Index 24 | Returns: 25 | tuple: (image, target) where target is class_index of the target class. 26 | """ 27 | sl = self.get_sl(self.inputs[index][1]) 28 | self.data_path = os.path.join(self.dataset_root, self.inputs[index][0]) 29 | self.clip, skgmaparr = self.image_propose(self.data_path, sl) 30 | 31 | if self.args.Network == 'FusionNet': 32 | assert self.typ == 'rgb' 33 | self.data_path1 = self.data_path.replace('RGB', 'Depth') 34 | self.data_path1 = '/'.join(self.data_path1.split('/')[:-1]) + '/{}'.format( 35 | self.data_path1.split('/')[-1].replace('Depth', 'D')) 36 | 37 | self.clip1, skgmaparr1 = self.image_propose(self.data_path1, sl) 38 | 39 | return (self.clip.permute(0, 3, 1, 2), self.clip1.permute(0, 3, 1, 2)), (skgmaparr, skgmaparr1), \ 40 | self.inputs[index][2], self.data_path 41 | 42 | return self.clip.permute(0, 3, 1, 2), skgmaparr, self.inputs[index][2], self.data_path 43 | 44 | def __len__(self): 45 | return len(self.inputs) 46 | -------------------------------------------------------------------------------- /lib/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (C) 2010-2021 Alibaba Group Holding Limited. 3 | ''' 4 | 5 | from .build import * 6 | -------------------------------------------------------------------------------- /lib/datasets/base.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This file is modified from: 3 | https://github.com/zhoubenjia/RAAR3DNet/blob/master/Network_Train/lib/datasets/base.py 4 | ''' 5 | 6 | import torch 7 | from torch.utils.data import Dataset, DataLoader 8 | from torchvision import transforms, set_image_backend 9 | import torch.nn.functional as F 10 | 11 | from PIL import Image 12 | from PIL import ImageFilter, ImageOps 13 | import os, glob 14 | import math, random 15 | import numpy as np 16 | import logging 17 | from tqdm import tqdm as tqdm 18 | import pandas as pd 19 | from multiprocessing import Pool, cpu_count 20 | import multiprocessing as mp 21 | import cv2 22 | import json 23 | from scipy.ndimage.filters import gaussian_filter 24 | 25 | # import functools 26 | import matplotlib.pyplot as plt # For graphics 27 | np.random.seed(123) 28 | 29 | class GaussianBlur(object): 30 | """ 31 | Apply Gaussian Blur to the PIL image. 
32 | """ 33 | def __init__(self, p=0.5, radius_min=0.1, radius_max=2.): 34 | self.prob = p 35 | self.radius_min = radius_min 36 | self.radius_max = radius_max 37 | 38 | def __call__(self, img): 39 | do_it = random.random() <= self.prob 40 | if not do_it: 41 | return img 42 | 43 | return img.filter( 44 | ImageFilter.GaussianBlur( 45 | radius=random.uniform(self.radius_min, self.radius_max) 46 | ) 47 | ) 48 | class Normaliztion(object): 49 | """ 50 | same as mxnet, normalize into [-1, 1] 51 | image = (image - 127.5)/128 52 | """ 53 | 54 | def __call__(self, Image): 55 | new_video_x = (Image - 127.5) / 128 56 | return new_video_x 57 | 58 | class Datasets(Dataset): 59 | global kpt_dict 60 | def __init__(self, args, ground_truth, modality, phase='train'): 61 | 62 | def get_data_list_and_label(data_df): 63 | return [(lambda arr: (arr[0], int(arr[1]), int(arr[2])))(i[:-1].split(' ')) 64 | for i in open(data_df).readlines()] 65 | 66 | self.dataset_root = args.data 67 | self.sample_duration = args.sample_duration 68 | self.sample_size = args.sample_size 69 | self.phase = phase 70 | args.phase = phase 71 | self.typ = modality 72 | self.args = args 73 | self._w = args.w 74 | 75 | self.transform = transforms.Compose([Normaliztion(), transforms.ToTensor()]) 76 | 77 | self.inputs = list(filter(lambda x: x[1] > 16, get_data_list_and_label(ground_truth))) 78 | self.inputs = list(self.inputs) 79 | if phase == 'train': 80 | while len(self.inputs) % (args.batch_size * args.nprocs) != 0: 81 | sample = random.choice(self.inputs) 82 | self.inputs.append(sample) 83 | logging.info('Training Data Size is: {}'.format(len(self.inputs))) 84 | frames = [n[1] for n in self.inputs] 85 | logging.info('Average Train Data frames are: {}, max frames: {}, min frames: {}'.format(sum(frames)//len(self.inputs), max(frames), min(frames))) 86 | else: 87 | logging.info('Validation Data Size is: {} '.format(len(self.inputs))) 88 | frames = [n[1] for n in self.inputs] 89 | logging.info('Average Train Data frames are: {}, max frames: {}, min frames: {}'.format( 90 | sum(frames) // len(self.inputs), max(frames), min(frames))) 91 | 92 | def transform_params(self, resize=(320, 240), crop_size=224, flip=0.5): 93 | if self.phase == 'train': 94 | left, top = np.random.randint(0, resize[0] - crop_size), np.random.randint(0, resize[1] - crop_size) 95 | is_flip = True if np.random.uniform(0, 1) < flip else False 96 | else: 97 | left, top = (resize[0] - crop_size) // 2, (resize[1] - crop_size) // 2 98 | 99 | is_flip = False 100 | return (left, top, left + crop_size, top + crop_size), is_flip 101 | 102 | def rotate(self, image, angle, center=None, scale=1.0): 103 | (h, w) = image.shape[:2] 104 | if center is None: 105 | center = (w / 2, h / 2) 106 | M = cv2.getRotationMatrix2D(center, angle, scale) 107 | rotated = cv2.warpAffine(image, M, (w, h)) 108 | return rotated 109 | 110 | def get_path(self, imgs_path, a): 111 | return os.path.join(imgs_path, "%06d.jpg" % a) 112 | 113 | def depthProposess(self, img): 114 | h2, w2 = img.shape 115 | 116 | mask = img.copy() 117 | mask = cv2.erode(mask, np.ones((3, 3), np.uint8)) 118 | mask = cv2.dilate(mask, np.ones((10, 10), np.uint8)) 119 | contours, hierarchy = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) 120 | # Find Max Maxtri 121 | Idx = [] 122 | for i in range(len(contours)): 123 | Area = cv2.contourArea(contours[i]) 124 | if Area > 500: 125 | Idx.append(i) 126 | centers = [] 127 | 128 | for i in Idx: 129 | rect = cv2.minAreaRect(contours[i]) 130 | center, (h, w), degree = rect 
131 | centers.append(center) 132 | 133 | finall_center = np.int0(np.array(centers)) 134 | c_x = min(finall_center[:, 0]) 135 | c_y = min(finall_center[:, 1]) 136 | center = (c_x, c_y) 137 | 138 | crop_x, crop_y = 320, 240 139 | left = center[0] - crop_x // 2 if center[0] - crop_x // 2 > 0 else 0 140 | top = center[1] - crop_y // 2 if center[1] - crop_y // 2 > 0 else 0 141 | crop_w = left + crop_x if left + crop_x < w2 else w2 142 | crop_h = top + crop_y if top + crop_y < h2 else h2 143 | rect = (left, top, crop_w, crop_h) 144 | image = Image.fromarray(img) 145 | image = image.crop(rect) 146 | return image 147 | 148 | def image_propose(self, data_path, sl): 149 | sample_size = self.sample_size 150 | resize = eval(self.args.resize) 151 | crop_rect, is_flip = self.transform_params(resize=resize, crop_size=self.args.crop_size, flip=self.args.flip) # no flip 152 | if np.random.uniform(0, 1) < self.args.rotated and self.phase == 'train': 153 | r, l = eval(self.args.angle) 154 | rotated = np.random.randint(r, l) 155 | else: 156 | rotated = 0 157 | 158 | def transform(img): 159 | img = np.asarray(img) 160 | if img.shape[-1] != 3: 161 | img = np.uint8(255 * img) 162 | img = self.depthProposess(img) 163 | img = cv2.applyColorMap(np.asarray(img), cv2.COLORMAP_JET) 164 | img = self.rotate(np.asarray(img), rotated) 165 | img = Image.fromarray(img) 166 | img = img.resize(resize) 167 | img = img.crop(crop_rect) 168 | if self.args.Blur and self.args.phase == 'train': 169 | img = GaussianBlur()(img) 170 | if is_flip: 171 | img = img.transpose(Image.FLIP_LEFT_RIGHT) 172 | return np.array(img.resize((sample_size, sample_size))) 173 | 174 | def Sample_Image(imgs_path, sl): 175 | frams = [] 176 | for a in sl: 177 | try: 178 | ori_image = Image.open(self.get_path(imgs_path, a)) 179 | except: 180 | ori_image = Image.open(os.path.join(imgs_path, "MDepth-%08d.png" % int(a+1))) # For NTU fusion 181 | img = transform(ori_image) 182 | frams.append(self.transform(img).view(3, sample_size, sample_size, 1)) 183 | skgmaparr = DynamicImage(frams, dynamic_only=False) 184 | return torch.cat(frams, dim=3).type(torch.FloatTensor), skgmaparr.unsqueeze(0) 185 | 186 | def DynamicImage(frames, dynamic_only): # frames: [[3, 224, 224, 1], ] 187 | def tensor_arr_rp(arr): 188 | l = len(arr) 189 | statics = [] 190 | def tensor_rankpooling(video_arr, lamb=1.): 191 | def get_w(N): 192 | return [float(i) * 2 - N - 1 for i in range(1, N + 1)] 193 | 194 | # re = torch.zeros(video_arr[0].size(0), 1, video_arr[0].size(2), video_arr[0].size(3)).cuda() 195 | re = torch.zeros(video_arr[0].size()) 196 | for a, b in zip(video_arr, get_w(len(video_arr))): 197 | # a = transforms.Grayscale(1)(a) 198 | re += a * b 199 | re = F.relu(re) * lamb 200 | re -= torch.min(re) 201 | re = re / torch.max(re) if torch.max(re) != 0 else re / (torch.max(re) + 0.00001) 202 | 203 | re = transforms.Grayscale(1)(re.squeeze()) 204 | # Static Attention 205 | static = torch.where(re > torch.mean(re), re, torch.full_like(re, 0)) 206 | static = np.asarray(static.squeeze()) 207 | # static = cv2.morphologyEx(static, cv2.MORPH_OPEN, kernel=np.ones((3, 3), np.uint8)) 208 | kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (2, 2)) 209 | static = cv2.erode(static, kernel) 210 | kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3)) 211 | static = cv2.dilate(static, kernel) 212 | static -= np.min(static) 213 | static = static / np.max(static) if np.max(static) != 0 else static / (np.max(static) + 0.00001) 214 | statics.append(torch.from_numpy(static).unsqueeze(0)) 215 | 
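                # At this point 're' holds the rank-pooled dynamic image for the
                # current sliding window (min-max normalized and reduced to a single
                # grayscale channel), and 'statics' has collected the thresholded,
                # morphologically cleaned static-attention map derived from it;
                # DynamicImage() later combines the stacked dynamic images with
                # these static maps via (garrs + statics) * statics.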
return re 216 | 217 | return [tensor_rankpooling(arr[i:i + self._w]) for i in range(l)], statics 218 | arrrp, statics = tensor_arr_rp(frames) 219 | arrrp = torch.cat(arrrp, dim=0) # torch.Size([64, 224, 224]) 220 | t, h, w = arrrp.shape 221 | mask = torch.zeros(self._w - 1, h, w) 222 | garrs = torch.cat((mask, arrrp), dim=0)[:t, :] 223 | statics = torch.cat(statics) 224 | statics = torch.cat((mask, statics))[:t, :] 225 | if dynamic_only: 226 | return garrs 227 | return (garrs + statics) * statics 228 | return Sample_Image(data_path, sl) 229 | 230 | def get_sl(self, clip): 231 | sn = self.sample_duration 232 | if self.phase == 'train': 233 | f = lambda n: [(lambda n, arr: n if arr == [] else np.random.choice(arr))(n * i / sn, 234 | range(int(n * i / sn), 235 | max(int(n * i / sn) + 1, 236 | int(n * ( 237 | i + 1) / sn)))) 238 | for i in range(sn)] 239 | else: 240 | f = lambda n: [(lambda n, arr: n if arr == [] else int(np.mean(arr)))(n * i / sn, range(int(n * i / sn), 241 | max(int( 242 | n * i / sn) + 1, 243 | int(n * ( 244 | i + 1) / sn)))) 245 | for i in range(sn)] 246 | return f(int(clip)) 247 | def __getitem__(self, index): 248 | """ 249 | Args: 250 | index (int): Index 251 | Returns: 252 | tuple: (image, target) where target is class_index of the target class. 253 | """ 254 | sl = self.get_sl(self.inputs[index][1]) 255 | self.data_path = os.path.join(self.dataset_root, self.inputs[index][0]) 256 | self.clip = self.image_propose(self.data_path, sl) 257 | return self.clip.permute(0, 3, 1, 2), self.inputs[index][2] 258 | def __len__(self): 259 | return len(self.inputs) 260 | 261 | if __name__ == '__main__': 262 | import argparse 263 | from config import Config 264 | from lib import * 265 | parser = argparse.ArgumentParser() 266 | parser.add_argument('--config', default='', help='Place config Congfile!') 267 | parser.add_argument('--eval_only', action='store_true', help='Eval only. True or False?') 268 | parser.add_argument('--local_rank', type=int, default=0) 269 | parser.add_argument('--nprocs', type=int, default=1) 270 | 271 | parser.add_argument('--save_grid_image', action='store_true', help='Save samples?') 272 | parser.add_argument('--save_output', action='store_true', help='Save logits?') 273 | parser.add_argument('--demo_dir', type=str, default='./demo', help='The dir for save all the demo') 274 | 275 | parser.add_argument('--drop_path_prob', type=float, default=0.5, help='drop path probability') 276 | parser.add_argument('--save', type=str, default='Checkpoints/', help='experiment name') 277 | parser.add_argument('--seed', type=int, default=123, help='random seed') 278 | args = parser.parse_args() 279 | args = Config(args) 280 | np.random.seed(args.seed) 281 | torch.manual_seed(args.seed) 282 | args.dist = False 283 | args.eval_only = True 284 | args.test_batch_size = 1 285 | 286 | valid_queue, valid_sampler = build_dataset(args, phase='val') 287 | for step, (inputs, heatmap, target, _) in enumerate(valid_queue): 288 | print(inputs.shape) 289 | input() -------------------------------------------------------------------------------- /lib/datasets/build.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (C) 2010-2021 Alibaba Group Holding Limited. 
3 | ''' 4 | 5 | import torch 6 | from .distributed_sampler import DistributedSampler 7 | from .IsoGD import IsoGDData 8 | from .NvGesture import NvData 9 | from .THU_READ import THUREAD 10 | from .Jester import JesterData 11 | from .NTU import NTUData 12 | import logging 13 | 14 | def build_dataset(args, phase): 15 | modality = dict( 16 | M='rgb', 17 | K='depth', 18 | F='Flow' 19 | ) 20 | assert args.type in modality, 'Error in modality!' 21 | Datasets_func = dict( 22 | NvGesture=NvData, 23 | IsoGD=IsoGDData, 24 | THUREAD=THUREAD, 25 | Jester=JesterData, 26 | NTU=NTUData 27 | ) 28 | assert args.dataset in Datasets_func, 'Error in dataset Function!' 29 | if args.local_rank == 0: 30 | logging.info('Dataset:{}, Modality:{}'.format(args.dataset, modality[args.type])) 31 | 32 | if args.dataset in ['THUREAD'] and args.type == 'K': 33 | splits = args.splits + '/depth_{}_lst.txt'.format(phase) 34 | else: 35 | splits = args.splits + '/{}.txt'.format(phase) 36 | dataset = Datasets_func[args.dataset](args, splits, modality[args.type], phase=phase) 37 | if args.dist: 38 | data_sampler = DistributedSampler(dataset) 39 | else: 40 | data_sampler = None 41 | 42 | if phase == 'train': 43 | return torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, num_workers=args.num_workers, 44 | shuffle=(data_sampler is None), 45 | sampler=data_sampler, pin_memory=True), data_sampler 46 | else: 47 | # if args.eval_only and args.nprocs == 1: 48 | # args.test_batch_size = 8 49 | return torch.utils.data.DataLoader(dataset, batch_size=args.test_batch_size, num_workers=args.num_workers, 50 | shuffle=False, 51 | sampler=data_sampler, pin_memory=True, drop_last=False if args.eval_only else True), data_sampler -------------------------------------------------------------------------------- /lib/datasets/distributed_sampler.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This file is modified from: 3 | https://github.com/open-mmlab/mmdetection/blob/master/mmdet/datasets/samplers/distributed_sampler.py 4 | ''' 5 | 6 | import math 7 | import torch 8 | from torch.utils.data import DistributedSampler as _DistributedSampler 9 | 10 | class DistributedSampler(_DistributedSampler): 11 | 12 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True): 13 | super().__init__(dataset, num_replicas=num_replicas, rank=rank) 14 | self.shuffle = shuffle 15 | 16 | def __iter__(self): 17 | # deterministically shuffle based on epoch 18 | if self.shuffle: 19 | g = torch.Generator() 20 | g.manual_seed(self.epoch) 21 | indices = torch.randperm(len(self.dataset), generator=g).tolist() 22 | else: 23 | indices = torch.arange(len(self.dataset)).tolist() 24 | 25 | # add extra samples to make it evenly divisible 26 | # in case that indices is shorter than half of total_size 27 | indices = (indices * 28 | math.ceil(self.total_size / len(indices)))[:self.total_size] 29 | assert len(indices) == self.total_size 30 | 31 | # subsample 32 | indices = indices[self.rank:self.total_size:self.num_replicas] 33 | assert len(indices) == self.num_samples 34 | 35 | return iter(indices) 36 | -------------------------------------------------------------------------------- /lib/model/DSN.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This file is modified from: 3 | https://github.com/deepmind/kinetics-i3d/i3d.py 4 | ''' 5 | 6 | import torch 7 | import torch.nn as nn 8 | from einops.layers.torch import Rearrange 9 | import torch.nn.functional as F 10 | 
from torch.autograd import Variable 11 | import numpy as np 12 | import cv2 13 | import os, math 14 | import sys 15 | from .DTN import DTNNet 16 | from .FRP import FRP_Module 17 | from .utils import * 18 | 19 | import os, math 20 | import sys 21 | sys.path.append('../../') 22 | from collections import OrderedDict 23 | from utils import load_pretrained_checkpoint 24 | import logging 25 | 26 | class DSNNet(nn.Module): 27 | VALID_ENDPOINTS = ( 28 | 'Conv3d_1a_7x7', 29 | 'MaxPool3d_2a_3x3', 30 | 'Conv3d_2b_1x1', 31 | 'Conv3d_2c_3x3', 32 | 'MaxPool3d_3a_3x3', 33 | 34 | 'Mixed_3b', 35 | 'Mixed_3c', 36 | 'MaxPool3d_4a_3x3', 37 | 'Mixed_4b', 38 | 'Mixed_4c', 39 | 'MaxPool3d_5a_2x2', 40 | 'Mixed_5b', 41 | 'Mixed_5c' 42 | ) 43 | 44 | def __init__(self, args, num_classes=400, spatial_squeeze=True, name='inception_i3d', in_channels=3, dropout_keep_prob=0.5, 45 | pretrained: str = False, 46 | dropout_spatial: float = 0.0): 47 | 48 | super(DSNNet, self).__init__() 49 | self._num_classes = num_classes 50 | self._spatial_squeeze = spatial_squeeze 51 | self.logits = None 52 | self.args = args 53 | 54 | self.end_points = {} 55 | 56 | ''' 57 | Low Level Features Extraction 58 | ''' 59 | end_point = 'Conv3d_1a_7x7' 60 | self.end_points[end_point] = Unit3D(in_channels=in_channels, output_channels=64, kernel_shape=[1, 7, 7], 61 | stride=(1, 2, 2), padding=(0, 3, 3), name=name + end_point) 62 | 63 | end_point = 'MaxPool3d_2a_3x3' 64 | self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[1, 3, 3], stride=(1, 2, 2), 65 | padding=0) 66 | 67 | end_point = 'Conv3d_2b_1x1' 68 | self.end_points[end_point] = Unit3D(in_channels=64, output_channels=64, kernel_shape=[1, 1, 1], padding=0, 69 | name=name + end_point) 70 | 71 | end_point = 'Conv3d_2c_3x3' 72 | self.end_points[end_point] = Unit3D(in_channels=64, output_channels=192, kernel_shape=[1, 3, 3], 73 | padding=(0, 1, 1), 74 | name=name + end_point) 75 | 76 | end_point = 'MaxPool3d_3a_3x3' 77 | self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[1, 3, 3], stride=(1, 2, 2), 78 | padding=0) 79 | 80 | ''' 81 | Spatial Multi-scale Features Learning 82 | ''' 83 | end_point = 'Mixed_3b' 84 | self.end_points[end_point] = SpatialInceptionModule(192, [64, 96, 128, 16, 32, 32], name + end_point) 85 | 86 | end_point = 'Mixed_3c' 87 | self.end_points[end_point] = SpatialInceptionModule(256, [128, 128, 192, 32, 96, 64], name + end_point) 88 | 89 | end_point = 'MaxPool3d_4a_3x3' 90 | self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[1, 3, 3], stride=(1, 2, 2), 91 | padding=0) 92 | 93 | end_point = 'Mixed_4b' 94 | self.end_points[end_point] = SpatialInceptionModule(128 + 192 + 96 + 64, [192, 96, 208, 16, 48, 64], name + end_point) 95 | 96 | end_point = 'Mixed_4c' 97 | self.end_points[end_point] = SpatialInceptionModule(192 + 208 + 48 + 64, [160, 112, 224, 24, 64, 64], name + end_point) 98 | 99 | end_point = 'MaxPool3d_5a_2x2' 100 | self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[1, 2, 2], stride=(1, 2, 2), 101 | padding=0) 102 | 103 | end_point = 'Mixed_5b' 104 | self.end_points[end_point] = SpatialInceptionModule(160 + 224 + 64 + 64, [256, 160, 320, 32, 128, 128], 105 | name + end_point) 106 | 107 | end_point = 'Mixed_5c' 108 | self.end_points[end_point] = SpatialInceptionModule(256 + 320 + 128 + 128, [384, 192, 384, 48, 128, 128], 109 | name + end_point) 110 | 111 | self.LinearMap = nn.Sequential( 112 | nn.LayerNorm(1024), 113 | nn.Linear(1024, 512), 114 | # nn.Dropout(dropout_spatial) 115 | ) 116 | 117 | self.avg_pool = 
nn.AdaptiveAvgPool3d((None, 1, 1)) 118 | self.dropout = nn.Dropout(dropout_keep_prob) 119 | self.build() 120 | self.dtn = DTNNet(args, num_classes=self._num_classes) 121 | self.rrange = Rearrange('b c t h w -> b t c h w') 122 | 123 | if args.frp: 124 | self.frp_module = FRP_Module(w=args.w, inplanes=64) 125 | 126 | if pretrained: 127 | load_pretrained_checkpoint(self, pretrained) 128 | 129 | def build(self): 130 | for k in self.end_points.keys(): 131 | self.add_module(k, self.end_points[k]) 132 | 133 | def forward(self, x, garr): 134 | inp = x 135 | for end_point in self.VALID_ENDPOINTS: 136 | if end_point in self.end_points: 137 | if end_point in ['Mixed_3b']: 138 | x = self._modules[end_point](x) 139 | if self.args.frp: 140 | x = self.frp_module(x, garr) + x 141 | elif end_point in ['Mixed_4b']: 142 | x = self._modules[end_point](x) 143 | if self.args.frp: 144 | x = self.frp_module(x, garr) + x 145 | f = x 146 | elif end_point in ['Mixed_5b']: 147 | x = self._modules[end_point](x) 148 | if self.args.frp: 149 | x = self.frp_module(x, garr) + x 150 | else: 151 | x = self._modules[end_point](x) 152 | feat = x 153 | 154 | x = self.avg_pool(x).view(x.size(0), x.size(1), -1).permute(0, 2, 1) 155 | x = self.LinearMap(x) 156 | cnn_vison = self.rrange(f.sum(dim=1, keepdim=True)) 157 | logits, distillation_loss, (att_map, cosin_similar, MHAS, visweight) = self.dtn(x) 158 | # return logits, distillation_loss, (cnn_vison[0].detach(), att_map, inp[0, :], 159 | # cosin_similar, MHAS, (feat, logits[0])) 160 | return logits, distillation_loss, (cnn_vison[0], att_map, cosin_similar, visweight, MHAS, (feat, inp[0, :])) 161 | -------------------------------------------------------------------------------- /lib/model/DSN_Fusion.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This file is modified from: 3 | https://github.com/deepmind/kinetics-i3d/i3d.py 4 | ''' 5 | 6 | import torch 7 | import torch.nn as nn 8 | from einops.layers.torch import Rearrange 9 | import torch.nn.functional as F 10 | from torch.autograd import Variable 11 | import numpy as np 12 | import cv2 13 | import os, math 14 | import sys 15 | from .DTN import DTNNet 16 | from .FRP import FRP_Module 17 | from .utils import * 18 | 19 | import os, math 20 | import sys 21 | sys.path.append('../../') 22 | from collections import OrderedDict 23 | from utils import load_pretrained_checkpoint 24 | import logging 25 | 26 | 27 | class DSNNet(nn.Module): 28 | VALID_ENDPOINTS = ( 29 | 'Conv3d_1a_7x7', 30 | 'MaxPool3d_2a_3x3', 31 | 'Conv3d_2b_1x1', 32 | 'Conv3d_2c_3x3', 33 | 'MaxPool3d_3a_3x3', 34 | 35 | 'Mixed_3b', 36 | 'Mixed_3c', 37 | 'MaxPool3d_4a_3x3', 38 | 'Mixed_4b', 39 | 'Mixed_4c', 40 | 'MaxPool3d_5a_2x2', 41 | 'Mixed_5b', 42 | 'Mixed_5c' 43 | ) 44 | 45 | def __init__(self, args, num_classes=400, spatial_squeeze=True, name='inception_i3d', in_channels=3, dropout_keep_prob=0.5, 46 | pretrained: str = False): 47 | 48 | super(DSNNet, self).__init__() 49 | self._num_classes = num_classes 50 | self._spatial_squeeze = spatial_squeeze 51 | self.logits = None 52 | self.args = args 53 | 54 | self.end_points = {} 55 | 56 | ''' 57 | Low Level Features Extraction 58 | ''' 59 | end_point = 'Conv3d_1a_7x7' 60 | self.end_points[end_point] = Unit3D(in_channels=in_channels, output_channels=64, kernel_shape=[1, 7, 7], 61 | stride=(1, 2, 2), padding=(0, 3, 3), name=name + end_point) 62 | 63 | end_point = 'MaxPool3d_2a_3x3' 64 | self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[1, 3, 3], stride=(1, 2, 
2), 65 | padding=0) 66 | 67 | end_point = 'Conv3d_2b_1x1' 68 | self.end_points[end_point] = Unit3D(in_channels=64, output_channels=64, kernel_shape=[1, 1, 1], padding=0, 69 | name=name + end_point) 70 | 71 | end_point = 'Conv3d_2c_3x3' 72 | self.end_points[end_point] = Unit3D(in_channels=64, output_channels=192, kernel_shape=[1, 3, 3], 73 | padding=(0, 1, 1), 74 | name=name + end_point) 75 | 76 | end_point = 'MaxPool3d_3a_3x3' 77 | self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[1, 3, 3], stride=(1, 2, 2), 78 | padding=0) 79 | 80 | ''' 81 | Spatial Multi-scale Features Learning 82 | ''' 83 | end_point = 'Mixed_3b' 84 | self.end_points[end_point] = SpatialInceptionModule(192, [64, 96, 128, 16, 32, 32], name + end_point) 85 | 86 | end_point = 'Mixed_3c' 87 | self.end_points[end_point] = SpatialInceptionModule(256, [128, 128, 192, 32, 96, 64], name + end_point) 88 | 89 | end_point = 'MaxPool3d_4a_3x3' 90 | self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[1, 3, 3], stride=(1, 2, 2), 91 | padding=0) 92 | 93 | end_point = 'Mixed_4b' 94 | self.end_points[end_point] = SpatialInceptionModule(128 + 192 + 96 + 64, [192, 96, 208, 16, 48, 64], name + end_point) 95 | 96 | end_point = 'Mixed_4c' 97 | self.end_points[end_point] = SpatialInceptionModule(192 + 208 + 48 + 64, [160, 112, 224, 24, 64, 64], name + end_point) 98 | 99 | end_point = 'MaxPool3d_5a_2x2' 100 | self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[1, 2, 2], stride=(1, 2, 2), 101 | padding=0) 102 | 103 | end_point = 'Mixed_5b' 104 | self.end_points[end_point] = SpatialInceptionModule(160 + 224 + 64 + 64, [256, 160, 320, 32, 128, 128], 105 | name + end_point) 106 | 107 | end_point = 'Mixed_5c' 108 | self.end_points[end_point] = SpatialInceptionModule(256 + 320 + 128 + 128, [384, 192, 384, 48, 128, 128], 109 | name + end_point) 110 | 111 | self.LinearMap = nn.Sequential( 112 | nn.LayerNorm(1024), 113 | nn.Linear(1024, 512), 114 | 115 | ) 116 | 117 | self.avg_pool = nn.AdaptiveAvgPool3d((None, 1, 1)) 118 | self.dropout = nn.Dropout(dropout_keep_prob) 119 | self.build() 120 | self.dtn = DTNNet(args, num_classes=self._num_classes) 121 | self.rrange = Rearrange('b c t h w -> b t c h w') 122 | 123 | if args.frp: 124 | self.frp_module = FRP_Module(w=args.w, inplanes=64) 125 | 126 | if pretrained: 127 | load_pretrained_checkpoint(self, pretrained) 128 | 129 | def build(self): 130 | for k in self.end_points.keys(): 131 | self.add_module(k, self.end_points[k]) 132 | 133 | def forward(self, x=None, garr=None, endpoint=None): 134 | if endpoint == 'spatial': 135 | for end_point in self.VALID_ENDPOINTS: 136 | if end_point in self.end_points: 137 | if end_point in ['Mixed_3b']: 138 | x = self._modules[end_point](x) 139 | if self.args.frp: 140 | x = self.frp_module(x, garr) + x 141 | elif end_point in ['Mixed_4b']: 142 | x = self._modules[end_point](x) 143 | if self.args.frp: 144 | x = self.frp_module(x, garr) + x 145 | f = x 146 | elif end_point in ['Mixed_5b']: 147 | x = self._modules[end_point](x) 148 | if self.args.frp: 149 | x = self.frp_module(x, garr) + x 150 | else: 151 | x = self._modules[end_point](x) 152 | 153 | x = self.avg_pool(x).view(x.size(0), x.size(1), -1).permute(0, 2, 1) 154 | x = self.LinearMap(x) 155 | return x 156 | else: 157 | logits, distillation_loss, (att_map, cosin_similar, MHAS, visweight) = self.dtn(x) 158 | return logits, distillation_loss, (att_map, cosin_similar, MHAS, visweight) 159 | -------------------------------------------------------------------------------- 
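Note: both the data pipeline earlier in this dump (the `DynamicImage` helper inside the dataset classes) and the `FRP_Module` defined further below derive their motion-attention cues from the same approximate rank pooling: each sliding window of frames is collapsed with linear weights w_i = 2i − N − 1, rectified, and min–max normalised. The snippet below is a minimal, self-contained sketch of that weighting on a generic frame tensor; the function name `approx_rank_pool` and the `[N, C, H, W]` layout are illustrative assumptions for this note, not identifiers from the repository.

```python
import torch
import torch.nn.functional as F

def approx_rank_pool(frames: torch.Tensor) -> torch.Tensor:
    """Collapse a window of frames into one 'dynamic image'.

    frames: [N, C, H, W]; returns [C, H, W] scaled to [0, 1].
    Weights follow the rank-pooling approximation used in this repo:
    w_i = 2*i - N - 1 for i = 1..N (early frames weighted negative,
    late frames positive), so the result emphasises forward motion.
    """
    n = frames.size(0)
    weights = torch.tensor([2.0 * i - n - 1 for i in range(1, n + 1)],
                           dtype=frames.dtype)
    pooled = (frames * weights.view(n, 1, 1, 1)).sum(dim=0)
    pooled = F.relu(pooled)              # keep forward-in-time responses only
    pooled = pooled - pooled.min()       # min-max normalise to [0, 1]
    denom = pooled.max()
    return pooled / denom if denom > 0 else pooled

if __name__ == '__main__':
    clip = torch.rand(8, 3, 224, 224)            # a toy window of 8 RGB frames
    print(approx_rank_pool(clip).shape)          # torch.Size([3, 224, 224])
```

The dataset-side variant additionally converts the pooled map to grayscale, thresholds it at its mean, and applies a small erode/dilate pass to obtain the "static attention" mask that is multiplied back onto the dynamic map; the `FRP_Module` variant instead feeds the pooled maps into the network as heatmaps at the Mixed_3b/4b/5b stages.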
/lib/model/DTN.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (C) 2010-2021 Alibaba Group Holding Limited. 3 | ''' 4 | 5 | import torch 6 | from torch.autograd import Variable 7 | from torch import nn, einsum 8 | import torch.nn.functional as F 9 | from torch.nn import init 10 | 11 | from einops import rearrange, repeat 12 | from einops.layers.torch import Rearrange 13 | import numpy as np 14 | import random, math 15 | from .utils import * 16 | from .trans_module import * 17 | 18 | np.random.seed(123) 19 | random.seed(123) 20 | 21 | 22 | class Transformer(nn.Module): 23 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0., apply_transform=False, knn_attention=0.7): 24 | super().__init__() 25 | self.layers = nn.ModuleList([]) 26 | for _ in range(depth): 27 | self.layers.append(nn.ModuleList([ 28 | PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout, 29 | apply_transform=apply_transform, knn_attention=knn_attention)), 30 | PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout)) 31 | ])) 32 | 33 | def forward(self, x): 34 | for attn, ff in self.layers: 35 | x = attn(x) + x 36 | x = ff(x) + x 37 | return x 38 | 39 | 40 | class MultiScaleTransformerEncoder(nn.Module): 41 | 42 | def __init__(self, args, small_dim=1024, small_depth=4, small_heads=8, small_dim_head=64, hidden_dim_small=768, 43 | media_dim=1024, media_depth=4, media_heads=8, media_dim_head=64, hidden_dim_media=768, 44 | large_dim=1024, large_depth=4, large_heads=8, large_dim_head=64, hidden_dim_large=768, 45 | dropout=0., Local_flag=True): 46 | super().__init__() 47 | 48 | self.transformer_enc_small = Transformer(small_dim, small_depth, small_heads, small_dim_head, 49 | mlp_dim=hidden_dim_small, dropout=dropout, knn_attention=args.knn_attention) 50 | self.transformer_enc_media = Transformer(media_dim, media_depth, media_heads, media_dim_head, 51 | mlp_dim=hidden_dim_media, dropout=dropout, knn_attention=args.knn_attention) 52 | self.transformer_enc_large = Transformer(large_dim, large_depth, large_heads, large_dim_head, 53 | mlp_dim=hidden_dim_large, dropout=dropout, knn_attention=args.knn_attention) 54 | if Local_flag: 55 | self.Mixed_small = TemporalInceptionModule(512, [160,112,224,24,64,64], 'Mixed_small') 56 | self.Mixed_media = TemporalInceptionModule(512, [160,112,224,24,64,64], 'Mixed_media') 57 | self.Mixed_large = TemporalInceptionModule(512, [160, 112, 224, 24, 64, 64], 'Mixed_large') 58 | self.MaxPool = MaxPool3dSamePadding(kernel_size=[3, 1, 1], stride=(1, 1, 1), padding=0) 59 | 60 | def forward(self, xs, xm, xl, Local_flag=False): 61 | # Local Modeling 62 | if Local_flag: 63 | cls_small = xs[:, 0] 64 | xs = self.Mixed_small(xs[:, 1:, :].permute(0, 2, 1).view(xs.size(0), xs.size(-1), -1, 1, 1)) 65 | xs = self.MaxPool(xs) 66 | xs = torch.cat((cls_small.unsqueeze(1), xs.view(xs.size(0), xs.size(1), -1).permute(0, 2, 1)), dim=1) 67 | 68 | cls_media = xm[:, 0] 69 | xm = self.Mixed_media(xm[:, 1:, :].permute(0, 2, 1).view(xm.size(0), xm.size(-1), -1, 1, 1)) 70 | xm = self.MaxPool(xm) 71 | xm = torch.cat((cls_media.unsqueeze(1), xm.view(xm.size(0), xm.size(1), -1).permute(0, 2, 1)), dim=1) 72 | 73 | cls_large = xl[:, 0] 74 | xl = self.Mixed_large(xl[:, 1:, :].permute(0, 2, 1).view(xl.size(0), xl.size(-1), -1, 1, 1)) 75 | xl = self.MaxPool(xl) 76 | xl = torch.cat((cls_large.unsqueeze(1), xl.view(xl.size(0), xl.size(1), -1).permute(0, 2, 1)), dim=1) 77 | 78 | # Global Modeling 79 | xs = self.transformer_enc_small(xs) 80 | xm = 
self.transformer_enc_media(xm) 81 | xl = self.transformer_enc_large(xl) 82 | 83 | return xs, xm, xl 84 | 85 | 86 | class RCMModule(nn.Module): 87 | def __init__(self, args, dim_head=64, method='New', merge='GAP'): 88 | super(RCMModule, self).__init__() 89 | self.merge = merge 90 | self.heads = args.SEHeads 91 | self.avg_pool = nn.AdaptiveAvgPool1d(1) 92 | self.avg_pool3d = nn.AdaptiveAvgPool3d((None, 1, None)) 93 | 94 | # Self Attention Layers 95 | self.q = nn.Linear(64, dim_head * self.heads, bias=False) 96 | self.k = nn.Linear(64, dim_head * self.heads, bias=False) 97 | self.scale = dim_head ** -0.5 98 | 99 | self.method = method 100 | if method == 'Ori': 101 | self.norm = nn.LayerNorm(128) 102 | self.project = nn.Sequential( 103 | nn.Linear(64, 512, bias=False), 104 | nn.GELU(), 105 | nn.Linear(512, 512, bias=False), 106 | nn.LayerNorm(512) 107 | ) 108 | elif method == 'New': 109 | if args.dataset == 'THU': 110 | hidden_dim = 128 111 | else: 112 | hidden_dim = 256 113 | self.project = nn.Sequential( 114 | nn.Linear(64, hidden_dim, bias=False), 115 | nn.GELU(), 116 | nn.Linear(hidden_dim, 64, bias=False), 117 | nn.LayerNorm(64), 118 | ) 119 | self.linear = nn.Linear(64, 512) 120 | # init.kaiming_uniform_(self.linear, a=math.sqrt(5)) 121 | 122 | if self.heads > 1: 123 | self.mergefc = nn.Sequential( 124 | nn.Dropout(0.4), 125 | nn.Linear(512 * self.heads, 512, bias=False), 126 | nn.LayerNorm(512) 127 | ) 128 | 129 | def forward(self, x): 130 | b, c, t = x.shape 131 | inp = x.clone() 132 | 133 | # Sequence (Y) direction 134 | xd_weight = self.project(self.avg_pool(inp.permute(0, 2, 1)).view(b, -1)) 135 | xd_weight = torch.sigmoid(xd_weight).view(b, -1, 1) 136 | 137 | # Feature (X) direction 138 | q, k = self.q(x), self.k(x) 139 | q, k = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), [q, k]) 140 | dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale 141 | if self.merge == 'mean': 142 | dots = dots.mean(dim=2) 143 | elif self.merge == 'GAP': 144 | dots = self.avg_pool3d(dots).squeeze() 145 | 146 | if self.heads > 1: 147 | dots = dots.view(b, -1) 148 | dots = self.mergefc(dots) 149 | else: 150 | dots = dots.squeeze() 151 | y = torch.sigmoid(dots).view(b, c, 1) 152 | 153 | if self.method == 'Ori': 154 | out = x * (y.expand_as(x) + xd_weight.expand_as(x)) 155 | visweight = xd_weight # for visualization 156 | return out, xd_weight, visweight 157 | 158 | elif self.method == 'New': 159 | weight = einsum('b i d, b j d -> b i j', xd_weight, y) 160 | out = x * weight.permute(0, 2, 1) 161 | visweight = weight # for visualization 162 | return out, self.linear(xd_weight.squeeze()), visweight 163 | 164 | class DTNNet(nn.Module): 165 | def __init__(self, args, num_classes=249, small_dim=512, media_dim=512, large_dim=512, 166 | small_depth=1, media_depth=1, large_depth=1, 167 | heads=8, pool='cls', dropout=0.1, emb_dropout=0.0, branch_merge='pool', 168 | init: bool = False, 169 | warmup_temp_epochs: int = 30): 170 | super().__init__() 171 | self.low_frames = args.low_frames 172 | self.media_frames = args.media_frames 173 | self.high_frames = args.high_frames 174 | self.branch_merge = branch_merge 175 | self._args = args 176 | warmup_temp, temp = map(float, args.temp) 177 | 178 | multi_scale_enc_depth = args.N 179 | num_patches_small = self.low_frames 180 | num_patches_media = self.media_frames 181 | num_patches_large = self.high_frames 182 | 183 | self.pos_embedding_small = nn.Parameter(torch.randn(1, num_patches_small + 1, small_dim)) 184 | self.cls_token_small = 
nn.Parameter(torch.randn(1, 1, small_dim)) 185 | self.dropout_small = nn.Dropout(emb_dropout) 186 | 187 | self.pos_embedding_media = nn.Parameter(torch.randn(1, num_patches_media + 1, media_dim)) 188 | self.cls_token_media = nn.Parameter(torch.randn(1, 1, media_dim)) 189 | self.dropout_media = nn.Dropout(emb_dropout) 190 | 191 | self.pos_embedding_large = nn.Parameter(torch.randn(1, num_patches_large + 1, large_dim)) 192 | self.cls_token_large = nn.Parameter(torch.randn(1, 1, large_dim)) 193 | self.dropout_large = nn.Dropout(emb_dropout) 194 | 195 | self.multi_scale_transformers = nn.ModuleList([]) 196 | Local_flag = True 197 | for _ in range(multi_scale_enc_depth): 198 | self.multi_scale_transformers.append( 199 | MultiScaleTransformerEncoder(args, small_dim=small_dim, small_depth=small_depth, 200 | small_heads=heads, 201 | 202 | media_dim=media_dim, media_depth=media_depth, 203 | media_heads=heads, 204 | 205 | large_dim=large_dim, large_depth=large_depth, 206 | large_heads=heads, 207 | dropout=dropout, 208 | Local_flag=Local_flag)) 209 | Local_flag = False 210 | self.pool = pool 211 | # self.to_latent = nn.Identity() 212 | self.avg_pool = nn.AdaptiveAvgPool1d(1) 213 | self.max_pool = nn.AdaptiveMaxPool1d(1) 214 | 215 | if self._args.recoupling: 216 | self.rcm = RCMModule(args) 217 | 218 | if args.Network != 'FusionNet': 219 | self.mlp_head_small = nn.Sequential( 220 | nn.LayerNorm(small_dim), 221 | nn.Linear(small_dim, num_classes), 222 | # nn.Dropout(0.4) 223 | ) 224 | self.mlp_head_media = nn.Sequential( 225 | nn.LayerNorm(media_dim), 226 | nn.Linear(media_dim, num_classes), 227 | # nn.Dropout(0.4) 228 | 229 | ) 230 | 231 | self.mlp_head_large = nn.Sequential( 232 | nn.LayerNorm(large_dim), 233 | nn.Linear(large_dim, num_classes), 234 | # nn.Dropout(0.4) 235 | ) 236 | 237 | self.show_res = Rearrange('b t (c p1 p2) -> b t c p1 p2', p1=int(small_dim ** 0.5), p2=int(small_dim ** 0.5)) 238 | self.temp_schedule = np.concatenate(( 239 | np.linspace(warmup_temp, 240 | temp, warmup_temp_epochs), 241 | np.ones(args.epochs - warmup_temp_epochs) * temp 242 | )) 243 | 244 | if init: 245 | self.init_weights() 246 | 247 | @torch.no_grad() 248 | def init_weights(self): 249 | def _init(m): 250 | if isinstance(m, nn.Linear): 251 | nn.init.xavier_uniform_( 252 | m.weight) # _trunc_normal(m.weight, std=0.02) # from .initialization import _trunc_normal 253 | if hasattr(m, 'bias') and m.bias is not None: 254 | nn.init.normal_(m.bias, std=1e-6) # nn.init.constant(m.bias, 0) 255 | 256 | self.apply(_init) 257 | 258 | # ---------------------------------- 259 | # frames simple function 260 | # ---------------------------------- 261 | def f(self, n, sn): 262 | SL = lambda n, sn: [(lambda n, arr: n if arr == [] else random.choice(arr))(n * i / sn, 263 | range(int(n * i / sn), 264 | max(int(n * i / sn) + 1, 265 | int(n * ( 266 | i + 1) / sn)))) 267 | for i in range(sn)] 268 | return SL(n, sn) 269 | 270 | def forward(self, img): # img size: [2, 64, 1024] 271 | # ---------------------------------- 272 | # Recoupling: 273 | # ---------------------------------- 274 | if self._args.recoupling: 275 | img, spatial_weights, visweight = self.rcm(img.permute(0, 2, 1)) 276 | img = img.permute(0, 2, 1) 277 | else: 278 | visweight = img 279 | 280 | # ---------------------------------- 281 | sl_low = self.f(img.size(1), self.low_frames) 282 | xs = img[:, sl_low, :] 283 | b, n, _ = xs.shape 284 | 285 | cls_token_small = repeat(self.cls_token_small, '() n d -> b n d', b=b) 286 | xs = torch.cat((cls_token_small, xs), dim=1) 287 | 
xs += self.pos_embedding_small[:, :(n + 1)] 288 | xs = self.dropout_small(xs) 289 | 290 | # ---------------------------------- 291 | sl_media = self.f(img.size(1), self.media_frames) 292 | xm = img[:, sl_media, :] 293 | b, n, _ = xm.shape 294 | 295 | cls_token_media = repeat(self.cls_token_media, '() n d -> b n d', b=b) 296 | xm = torch.cat((cls_token_media, xm), dim=1) 297 | xm += self.pos_embedding_media[:, :(n + 1)] 298 | xm = self.dropout_media(xm) 299 | 300 | # ---------------------------------- 301 | sl_high = self.f(img.size(1), self.high_frames) 302 | xl = img[:, sl_high, :] 303 | b, n, _ = xl.shape 304 | 305 | cls_token_large = repeat(self.cls_token_large, '() n d -> b n d', b=b) 306 | xl = torch.cat((cls_token_large, xl), dim=1) 307 | xl += self.pos_embedding_large[:, :(n + 1)] 308 | xl = self.dropout_large(xl) 309 | 310 | # ---------------------------------- 311 | # Temporal Multi-scale features learning 312 | # ---------------------------------- 313 | Local_flag = True 314 | for multi_scale_transformer in self.multi_scale_transformers: 315 | xs, xm, xl = multi_scale_transformer(xs, xm, xl, Local_flag) 316 | Local_flag = False 317 | 318 | xs = xs.mean(dim=1) if self.pool == 'mean' else xs[:, 0] 319 | xm = xm.mean(dim=1) if self.pool == 'mean' else xm[:, 0] 320 | xl = xl.mean(dim=1) if self.pool == 'mean' else xl[:, 0] 321 | 322 | if self._args.recoupling: 323 | T = self._args.temper 324 | distillation_loss = F.kl_div(F.log_softmax(spatial_weights.squeeze() / T, dim=-1), 325 | F.softmax(((xs + xm + xl) / 3.).detach() / T, dim=-1), 326 | reduction='sum') 327 | else: 328 | distillation_loss = torch.zeros(1).cuda() 329 | 330 | if self._args.Network != 'FusionNet': 331 | if self._args.sharpness: 332 | temp = self.temp_schedule[self._args.epoch] 333 | xs = self.mlp_head_small(xs) / temp 334 | xm = self.mlp_head_media(xm) / temp 335 | xl = self.mlp_head_large(xl) / temp 336 | else: 337 | xs = self.mlp_head_small(xs) 338 | xm = self.mlp_head_media(xm) 339 | xl = self.mlp_head_large(xl) 340 | 341 | if self.branch_merge == 'sum': 342 | x = xs + xm + xl 343 | elif self.branch_merge == 'pool': 344 | x = self.max_pool(torch.cat((xs.unsqueeze(2), xm.unsqueeze(2), xl.unsqueeze(2)), dim=-1)).squeeze() 345 | 346 | # --------------------------------- 347 | # Get score from multi-branch Trans for visualization 348 | # --------------------------------- 349 | scores_small = self.multi_scale_transformers[2].transformer_enc_small.layers[-1][0].fn.scores 350 | scores_media = self.multi_scale_transformers[2].transformer_enc_media.layers[-1][0].fn.scores 351 | scores_large = self.multi_scale_transformers[2].transformer_enc_large.layers[-1][0].fn.scores 352 | 353 | # resize attn 354 | attn_media = scores_media.detach().clone() 355 | attn_media.resize_(*scores_small.size()) 356 | 357 | attn_large = scores_large.detach().clone() 358 | attn_large.resize_(*scores_small.size()) 359 | 360 | att_small = scores_small.detach().clone() 361 | 362 | scores = torch.cat((att_small, attn_media, attn_large), dim=1) # [2, 24, 17, 17] 363 | att_map = torch.zeros(scores.size(0), scores.size(1), scores.size(1), dtype=torch.float) 364 | for b in range(scores.size(0)): 365 | for i, s1 in enumerate(scores[b]): 366 | for j, s2 in enumerate(scores[b]): 367 | cosin_simil = torch.cosine_similarity(s1.view(1, -1), s2.view(1, -1)) 368 | att_map[b][i][j] = cosin_simil 369 | 370 | # -------------------------------- 371 | # Measure cosine similarity of xs and xl 372 | # -------------------------------- 373 | cosin_similar_xs_xm = 
torch.cosine_similarity(xs[0], xm[0], dim=-1) 374 | cosin_similar_xs_xl = torch.cosine_similarity(xs[0], xl[0], dim=-1) 375 | cosin_similar_xm_xl = torch.cosine_similarity(xm[0], xl[0], dim=-1) 376 | cosin_similar_sum = cosin_similar_xs_xm + cosin_similar_xs_xl + cosin_similar_xm_xl 377 | 378 | return (x, xs, xm, xl), distillation_loss, (att_map, cosin_similar_sum.cpu(), 379 | (scores_small[0], scores_media[0], scores_large[0]), visweight[0]) -------------------------------------------------------------------------------- /lib/model/FRP.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This file is modified from: 3 | https://github.com/zhoubenjia/RAAR3DNet/blob/master/Network_Train/lib/model/RAAR3DNet.py 4 | ''' 5 | 6 | import torch 7 | import torch.nn as nn 8 | from einops.layers.torch import Rearrange 9 | import torch.nn.functional as F 10 | from torch.autograd import Variable 11 | from torchvision import transforms 12 | import numpy as np 13 | import cv2 14 | from torchvision.utils import save_image, make_grid 15 | 16 | def tensor_split(t): 17 | arr = torch.split(t, 1, dim=2) 18 | arr = [x.squeeze(2) for x in arr] 19 | return arr 20 | 21 | def tensor_merge(arr): 22 | arr = [x.unsqueeze(1) for x in arr] 23 | t = torch.cat(arr, dim=1) 24 | return t.permute(0, 2, 1, 3, 4) 25 | 26 | class FRP_Module(nn.Module): 27 | def __init__(self, w, inplanes): 28 | super(FRP_Module, self).__init__() 29 | self._w = w 30 | self.rpconv1d = nn.Conv1d(2, 1, 1, bias=False) # Rank Pooling Conv1d, Kernel Size 2x1x1 31 | self.rpconv1d.weight.data = torch.FloatTensor([[[1.0], [0.0]]]) 32 | # self.bnrp = nn.BatchNorm3d(inplanes) # BatchNorm Rank Pooling 33 | # self.relu = nn.ReLU(inplace=True) 34 | self.hapooling = nn.MaxPool2d(kernel_size=2) 35 | 36 | def forward(self, x, datt=None): 37 | inp = x 38 | if self._w < 1: 39 | return x 40 | def run_layer_on_arr(arr, l): 41 | return [l(x) for x in arr] 42 | def oneconv(a, b): 43 | s = a.size() 44 | c = torch.cat([a.contiguous().view(s[0], -1, 1), b.contiguous().view(s[0], -1, 1)], dim=2) 45 | c = self.rpconv1d(c.permute(0, 2, 1)).permute(0, 2, 1) 46 | return c.view(s) 47 | if datt is not None: 48 | tarr = tensor_split(x) 49 | garr = tensor_split(datt) 50 | while tarr[0].size()[3] < garr[0].size()[3]: # keep feature map and heatmap the same size 51 | garr = run_layer_on_arr(garr, self.hapooling) 52 | 53 | attarr = [a * (b + torch.ones(a.size()).cuda()) for a, b in zip(tarr, garr)] 54 | datt = [oneconv(a, b) for a, b in zip(tarr, attarr)] 55 | return tensor_merge(datt) 56 | 57 | def tensor_arr_rp(arr): 58 | l = len(arr) 59 | def tensor_rankpooling(video_arr): 60 | def get_w(N): 61 | return [float(i) * 2 - N - 1 for i in range(1, N + 1)] 62 | 63 | # re = torch.zeros(video_arr[0].size(0), 1, video_arr[0].size(2), video_arr[0].size(3)).cuda() 64 | re = torch.zeros(video_arr[0].size()).cuda() 65 | for a, b in zip(video_arr, get_w(len(video_arr))): 66 | # a = transforms.Grayscale(1)(a) 67 | re += a * b 68 | re = F.gelu(re) 69 | re -= torch.min(re) 70 | re = re / torch.max(re) if torch.max(re) != 0 else re / (torch.max(re) + 0.00001) 71 | return transforms.Grayscale(1)(re) 72 | 73 | return [tensor_rankpooling(arr[i:i + self._w]) for i in range(l)] 74 | 75 | arrrp = tensor_arr_rp(tensor_split(x)) 76 | 77 | b, c, t, h, w = tensor_merge(arrrp).shape 78 | mask = torch.zeros(b, c, self._w-1, h, w, device=tensor_merge(arrrp).device) 79 | garrs = torch.cat((mask, tensor_merge(arrrp)), dim=2) 80 | return garrs 81 | 82 | if __name__ == 
'__main__': 83 | model = SATT_Module().cuda() 84 | inp = torch.randn(2, 3, 64, 224, 224).cuda() 85 | out = model(inp) 86 | print(out.shape) 87 | -------------------------------------------------------------------------------- /lib/model/__init__.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (C) 2010-2021 Alibaba Group Holding Limited. 3 | ''' 4 | 5 | from .build import * -------------------------------------------------------------------------------- /lib/model/build.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (C) 2010-2021 Alibaba Group Holding Limited. 3 | ''' 4 | 5 | from .DSN import DSNNet 6 | from .fusion_Net import CrossFusionNet 7 | 8 | import logging 9 | 10 | def build_model(args): 11 | num_classes = dict( 12 | IsoGD=249, 13 | NvGesture=25, 14 | Jester=27, 15 | THUREAD=40, 16 | NTU=60 17 | ) 18 | func_dict = dict( 19 | I3DWTrans=DSNNet, 20 | FusionNet=CrossFusionNet 21 | ) 22 | assert args.dataset in num_classes, 'Error in load dataset !' 23 | assert args.Network in func_dict, 'Error in Network function !' 24 | args.num_classes = num_classes[args.dataset] 25 | if args.local_rank == 0: 26 | logging.info('Model:{}, Total Categories:{}'.format(args.Network, args.num_classes)) 27 | return func_dict[args.Network](args, num_classes=args.num_classes, pretrained=args.pretrained) 28 | -------------------------------------------------------------------------------- /lib/model/fusion_Net.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (C) 2010-2021 Alibaba Group Holding Limited. 3 | ''' 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torch.autograd import Variable 9 | from collections import OrderedDict 10 | 11 | import numpy as np 12 | 13 | import os 14 | import sys 15 | from collections import OrderedDict 16 | 17 | sys.path.append(['../../', '../']) 18 | from utils import load_pretrained_checkpoint, load_checkpoint 19 | import logging 20 | from .DSN_Fusion import DSNNet 21 | 22 | class LabelSmoothingCrossEntropy(torch.nn.Module): 23 | def __init__(self, smoothing: float = 0.1, 24 | reduction="mean", weight=None): 25 | super(LabelSmoothingCrossEntropy, self).__init__() 26 | self.smoothing = smoothing 27 | self.reduction = reduction 28 | self.weight = weight 29 | 30 | def reduce_loss(self, loss): 31 | return loss.mean() if self.reduction == 'mean' else loss.sum() \ 32 | if self.reduction == 'sum' else loss 33 | 34 | def linear_combination(self, x, y): 35 | return self.smoothing * x + (1 - self.smoothing) * y 36 | 37 | def forward(self, preds, target): 38 | assert 0 <= self.smoothing < 1 39 | 40 | if self.weight is not None: 41 | self.weight = self.weight.to(preds.device) 42 | 43 | n = preds.size(-1) 44 | log_preds = F.log_softmax(preds, dim=-1) 45 | loss = self.reduce_loss(-log_preds.sum(dim=-1)) 46 | nll = F.nll_loss( 47 | log_preds, target, reduction=self.reduction, weight=self.weight 48 | ) 49 | return self.linear_combination(loss / n, nll) 50 | 51 | 52 | class Encoder(nn.Module): 53 | def __init__(self, C_in, C_out, dilation=2): 54 | super(Encoder, self).__init__() 55 | self.enconv = nn.Sequential( 56 | nn.Conv2d(C_in, C_in, kernel_size=1, stride=1, padding=0, bias=False), 57 | nn.BatchNorm2d(C_in), 58 | nn.ReLU(inplace=False), 59 | 60 | nn.Conv2d(C_in, C_in // 2, kernel_size=1, stride=1, padding=0, bias=False), 61 | nn.BatchNorm2d(C_in // 2), 62 | 
nn.ReLU(inplace=False), 63 | 64 | nn.Conv2d(C_in // 2, C_in // 4, kernel_size=1, stride=1, padding=0, bias=False), 65 | nn.BatchNorm2d(C_in // 4), 66 | nn.ReLU(inplace=False), 67 | 68 | nn.Conv2d(C_in // 4, C_out, kernel_size=1, stride=1, padding=0, bias=False), 69 | ) 70 | 71 | def forward(self, x1, x2): 72 | b, c = x1.shape 73 | x = torch.cat((x1, x2), dim=1).view(b, -1, 1, 1) 74 | x = self.enconv(x) 75 | return x 76 | 77 | 78 | class Decoder(nn.Module): 79 | def __init__(self, C_in, C_out, dilation=2): 80 | super(Decoder, self).__init__() 81 | self.deconv = nn.Sequential( 82 | nn.Conv2d(C_in, C_out // 4, kernel_size=1, padding=0, bias=False), 83 | nn.BatchNorm2d(C_out // 4), 84 | nn.ReLU(), 85 | 86 | nn.Conv2d(C_out // 4, C_out // 2, kernel_size=1, padding=0, bias=False), 87 | nn.BatchNorm2d(C_out // 2), 88 | nn.ReLU(), 89 | ) 90 | 91 | def forward(self, x): 92 | x = self.deconv(x) 93 | return x 94 | 95 | 96 | class FusionModule(nn.Module): 97 | def __init__(self, channel_in=1024, channel_out=256, num_classes=60): 98 | super(FusionModule, self).__init__() 99 | self.encoder = Encoder(channel_in, channel_out) 100 | self.decoder = Decoder(channel_out, channel_in) 101 | self.efc = nn.Conv2d(channel_out, num_classes, kernel_size=1, padding=0, bias=False) 102 | 103 | def forward(self, r, d): 104 | en_x = self.encoder(r, d) # [4, 256, 1, 1] 105 | de_x = self.decoder(en_x) 106 | en_x = self.efc(en_x) 107 | return en_x.squeeze(), de_x 108 | 109 | class CrossFusionNet(nn.Module): 110 | def __init__(self, args, num_classes, pretrained, spatial_interact=True, temporal_interact=True): 111 | super(CrossFusionNet, self).__init__() 112 | self._MES = torch.nn.MSELoss() 113 | self._BCE = torch.nn.BCELoss() 114 | self._CE = LabelSmoothingCrossEntropy() 115 | self.spatial_interact = spatial_interact 116 | self.temporal_interact = temporal_interact 117 | 118 | self.fusion_model = FusionModule(channel_out=256, num_classes=num_classes) 119 | self.avg_pool = nn.AdaptiveAvgPool3d(1) 120 | self.fc = nn.Conv2d(512, 1, kernel_size=1, padding=0, bias=False) 121 | self.dropout = nn.Dropout(0.5) 122 | 123 | assert args.rgb_checkpoint and args.depth_checkpoint 124 | self.Modalit_rgb = DSNNet(args, num_classes=num_classes, 125 | pretrained=args.rgb_checkpoint) 126 | 127 | self.Modalit_depth = DSNNet(args, num_classes=num_classes, 128 | pretrained=args.depth_checkpoint) 129 | 130 | if self.spatial_interact: 131 | self.crossFusion = nn.Sequential( 132 | nn.Conv2d(512 * 2, 512, kernel_size=1, stride=1, padding=0, bias=False), 133 | nn.BatchNorm2d(512), 134 | nn.ReLU(), 135 | nn.Conv2d(512, 512, kernel_size=1, stride=1, padding=0, bias=False), 136 | nn.Dropout(0.4) 137 | 138 | ) 139 | if self.temporal_interact: 140 | self.crossFusionT = nn.Sequential( 141 | nn.Conv2d(512 * 2, 512, kernel_size=1, stride=1, padding=0, bias=False), 142 | nn.BatchNorm2d(512), 143 | nn.ReLU(), 144 | nn.Conv2d(512, 512, kernel_size=1, stride=1, padding=0, bias=False), 145 | nn.Dropout(0.4) 146 | ) 147 | 148 | self.classifier1 = nn.Sequential( 149 | nn.LayerNorm(512), 150 | nn.Linear(512, num_classes) 151 | ) 152 | self.classifier2 = nn.Sequential( 153 | nn.LayerNorm(512), 154 | nn.Linear(512, num_classes) 155 | ) 156 | 157 | if pretrained: 158 | load_pretrained_checkpoint(self, pretrained) 159 | logging.info("Load Pre-trained model state_dict Done !") 160 | 161 | def forward(self, inputs, garrs, target): 162 | rgb, depth = inputs 163 | rgb_garr, depth_garr = garrs 164 | 165 | spatial_M = self.Modalit_rgb(rgb, rgb_garr, endpoint='spatial') 166 | 
spatial_K = self.Modalit_depth(depth, depth_garr, endpoint='spatial') 167 | 168 | if self.spatial_interact: 169 | b, t, c = spatial_M.shape 170 | spatial_fusion_features = self.crossFusion(F.normalize(torch.cat((spatial_M, spatial_K), dim=-1), p = 2, dim=-1).view(b, c*2, t, 1)).squeeze() 171 | 172 | (temporal_M, M_xs, M_xm, M_xl), distillationM, _ = self.Modalit_rgb(x=spatial_M + spatial_fusion_features.view(spatial_M.shape) if self.spatial_interact else spatial_M, 173 | endpoint='temporal') # size[4, 512] 174 | (temporal_K, K_xs, K_xm, K_xl), distillationK, _ = self.Modalit_depth(x=spatial_K + spatial_fusion_features.view(spatial_M.shape) if self.spatial_interact else spatial_K, 175 | endpoint='temporal') 176 | logit_r = self.classifier1(temporal_M) 177 | logit_d = self.classifier2(temporal_K) 178 | 179 | if self.temporal_interact: 180 | b, c = temporal_M.shape 181 | temporal_fusion_features = self.crossFusionT(F.normalize(torch.cat((temporal_M, temporal_K), dim=-1), p = 2, dim=-1).view(b, c*2, 1, 1)).squeeze() 182 | temporal_M, temporal_K = temporal_M+temporal_fusion_features, temporal_K+temporal_fusion_features 183 | 184 | en_x, de_x = self.fusion_model(temporal_M, temporal_K) 185 | b, c = temporal_M.shape 186 | bce_r = torch.sigmoid(self.fc(self.dropout(temporal_M).view(b, c, 1, 1))).view(b, -1) 187 | bce_d = torch.sigmoid(self.fc(self.dropout(temporal_K).view(b, c, 1, 1))).view(b, -1) 188 | 189 | BCE_loss = self._BCE(bce_r, torch.ones(bce_r.size(0), 1).cuda()) + self._BCE(bce_d, torch.zeros(bce_d.size(0), 190 | 1).cuda()) 191 | MSE_loss = self._MES(de_x.view(b, c), temporal_M) + self._MES(de_x.view(b, c), temporal_K) 192 | CE_loss = self._CE(en_x, target) + self._CE(logit_r, target) + self._CE(logit_d, target) 193 | distillation = distillationM + distillationK 194 | 195 | return (en_x, logit_r, logit_d), (CE_loss, BCE_loss, MSE_loss, distillation) 196 | -------------------------------------------------------------------------------- /lib/model/trans_module.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This file is modified from: 3 | https://github.com/rishikksh20/CrossViT-pytorch/blob/master/crossvit.py 4 | ''' 5 | 6 | import torch 7 | from torch import nn, einsum 8 | import torch.nn.functional as F 9 | 10 | import math 11 | 12 | from einops import rearrange, repeat 13 | from einops.layers.torch import Rearrange 14 | 15 | class Residual(nn.Module): 16 | def __init__(self, fn): 17 | super().__init__() 18 | self.fn = fn 19 | 20 | def forward(self, x, **kwargs): 21 | return self.fn(x, **kwargs) + x 22 | 23 | 24 | class PreNorm(nn.Module): 25 | def __init__(self, dim, fn): 26 | super().__init__() 27 | self.norm = nn.LayerNorm(dim) 28 | self.fn = fn 29 | 30 | def forward(self, x, **kwargs): 31 | return self.fn(self.norm(x), **kwargs) 32 | 33 | 34 | 35 | # class FeedForward(nn.Module): 36 | # def __init__(self, dim, hidden_dim, dropout=0.): 37 | # super().__init__() 38 | # self.net = nn.Sequential( 39 | # nn.Linear(dim, hidden_dim), 40 | # nn.GELU(), 41 | # nn.Dropout(dropout), 42 | # nn.Linear(hidden_dim, dim), 43 | # nn.Dropout(dropout) 44 | # ) 45 | 46 | # def forward(self, x): 47 | # return self.net(x) 48 | 49 | class FeedForward(nn.Module): 50 | """FeedForward Neural Networks for each position""" 51 | def __init__(self, dim, hidden_dim, dropout=0.): 52 | super().__init__() 53 | self.fc1 = nn.Linear(dim, hidden_dim) 54 | self.fc2 = nn.Linear(hidden_dim, dim) 55 | self.dropout = nn.Dropout(dropout) 56 | 57 | def forward(self, x): 58 | # 
(B, S, D) -> (B, S, D_ff) -> (B, S, D) 59 | return self.dropout(self.fc2(self.dropout(F.gelu(self.fc1(x))))) 60 | 61 | class Attention(nn.Module): 62 | def __init__(self, dim, heads=8, dim_head=64, dropout=0., apply_transform=False, transform_scale=True, knn_attention=0.7): 63 | super().__init__() 64 | inner_dim = dim_head * heads 65 | project_out = not (heads == 1 and dim_head == dim) 66 | 67 | self.heads = heads 68 | self.scale = dim_head ** -0.5 69 | self.apply_transform = apply_transform 70 | self.knn_attention = bool(knn_attention) 71 | self.topk = knn_attention 72 | 73 | if apply_transform: 74 | self.reatten_matrix = torch.nn.Conv2d(heads, heads, 1, 1) 75 | self.var_norm = torch.nn.BatchNorm2d(heads) 76 | self.reatten_scale = self.scale if transform_scale else 1.0 77 | 78 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False) 79 | 80 | self.to_out = nn.Sequential( 81 | nn.Linear(inner_dim, dim), 82 | nn.Dropout(dropout) 83 | ) if project_out else nn.Identity() 84 | self.scores = None 85 | 86 | def forward(self, x): 87 | b, n, _, h = *x.shape, self.heads 88 | qkv = self.to_qkv(x).chunk(3, dim=-1) 89 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), qkv) 90 | dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale 91 | 92 | if self.knn_attention: 93 | mask = torch.zeros(b, self.heads, n, n, device=x.device, requires_grad=False) 94 | index = torch.topk(dots, k=int(dots.size(-1)*self.topk), dim=-1, largest=True)[1] 95 | mask.scatter_(-1, index, 1.) 96 | dots = torch.where(mask > 0, dots, torch.full_like(dots, float('-inf'))) 97 | attn = dots.softmax(dim=-1) 98 | if self.apply_transform: 99 | attn = self.var_norm(self.reatten_matrix(attn)) * self.reatten_scale 100 | 101 | self.scores = attn 102 | out = einsum('b h i j, b h j d -> b h i d', attn, v) 103 | 104 | 105 | out = rearrange(out, 'b h n d -> b n (h d)') 106 | out = self.to_out(out) 107 | return out 108 | 109 | 110 | class CrossAttention(nn.Module): 111 | def __init__(self, dim, heads=8, dim_head=64, dropout=0.): 112 | super().__init__() 113 | inner_dim = dim_head * heads 114 | project_out = not (heads == 1 and dim_head == dim) 115 | 116 | self.heads = heads 117 | self.scale = dim_head ** -0.5 118 | 119 | self.to_k = nn.Linear(dim, inner_dim, bias=False) 120 | self.to_v = nn.Linear(dim, inner_dim, bias=False) 121 | self.to_q = nn.Linear(dim, inner_dim, bias=False) 122 | 123 | self.to_out = nn.Sequential( 124 | nn.Linear(inner_dim, dim), 125 | nn.Dropout(dropout) 126 | ) if project_out else nn.Identity() 127 | 128 | def forward(self, x_qkv): 129 | b, n, _, h = *x_qkv.shape, self.heads 130 | 131 | k = self.to_k(x_qkv) 132 | k = rearrange(k, 'b n (h d) -> b h n d', h=h) 133 | 134 | v = self.to_v(x_qkv) 135 | v = rearrange(v, 'b n (h d) -> b h n d', h=h) 136 | 137 | q = self.to_q(x_qkv[:, 0].unsqueeze(1)) 138 | q = rearrange(q, 'b n (h d) -> b h n d', h=h) 139 | 140 | dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale 141 | 142 | attn = dots.softmax(dim=-1) 143 | 144 | out = einsum('b h i j, b h j d -> b h i d', attn, v) 145 | out = rearrange(out, 'b h n d -> b n (h d)') 146 | out = self.to_out(out) 147 | return out 148 | -------------------------------------------------------------------------------- /lib/model/utils.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This file is modified from: 3 | https://github.com/deepmind/kinetics-i3d/blob/master/i3d.py 4 | ''' 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 
| from torch.autograd import Variable 10 | import numpy as np 11 | import os 12 | import sys 13 | 14 | class MaxPool3dSamePadding(nn.MaxPool3d): 15 | def compute_pad(self, dim, s): 16 | if s % self.stride[dim] == 0: 17 | return max(self.kernel_size[dim] - self.stride[dim], 0) 18 | else: 19 | return max(self.kernel_size[dim] - (s % self.stride[dim]), 0) 20 | 21 | def forward(self, x): 22 | (batch, channel, t, h, w) = x.size() 23 | pad_t = self.compute_pad(0, t) 24 | pad_h = self.compute_pad(1, h) 25 | pad_w = self.compute_pad(2, w) 26 | pad_t_f = pad_t // 2 27 | pad_t_b = pad_t - pad_t_f 28 | pad_h_f = pad_h // 2 29 | pad_h_b = pad_h - pad_h_f 30 | pad_w_f = pad_w // 2 31 | pad_w_b = pad_w - pad_w_f 32 | 33 | pad = (pad_w_f, pad_w_b, pad_h_f, pad_h_b, pad_t_f, pad_t_b) 34 | x = F.pad(x, pad) 35 | return super(MaxPool3dSamePadding, self).forward(x) 36 | 37 | class Unit3D(nn.Module): 38 | 39 | def __init__(self, in_channels, 40 | output_channels, 41 | kernel_shape=(1, 1, 1), 42 | stride=(1, 1, 1), 43 | padding=0, 44 | activation_fn=F.relu, 45 | use_batch_norm=True, 46 | use_bias=False, 47 | name='unit_3d'): 48 | 49 | """Initializes Unit3D module.""" 50 | super(Unit3D, self).__init__() 51 | 52 | self._output_channels = output_channels 53 | self._kernel_shape = kernel_shape 54 | self._stride = stride 55 | self._use_batch_norm = use_batch_norm 56 | self._activation_fn = activation_fn 57 | self._use_bias = use_bias 58 | self.name = name 59 | self.padding = padding 60 | 61 | self.conv3d = nn.Conv3d(in_channels=in_channels, 62 | out_channels=self._output_channels, 63 | kernel_size=self._kernel_shape, 64 | stride=self._stride, 65 | padding=0, 66 | bias=self._use_bias) 67 | 68 | if self._use_batch_norm: 69 | self.bn = nn.BatchNorm3d(self._output_channels, eps=0.001, momentum=0.01) 70 | 71 | def compute_pad(self, dim, s): 72 | if s % self._stride[dim] == 0: 73 | return max(self._kernel_shape[dim] - self._stride[dim], 0) 74 | else: 75 | return max(self._kernel_shape[dim] - (s % self._stride[dim]), 0) 76 | 77 | 78 | def forward(self, x): 79 | (batch, channel, t, h, w) = x.size() 80 | pad_t = self.compute_pad(0, t) 81 | pad_h = self.compute_pad(1, h) 82 | pad_w = self.compute_pad(2, w) 83 | pad_t_f = pad_t // 2 84 | pad_t_b = pad_t - pad_t_f 85 | pad_h_f = pad_h // 2 86 | pad_h_b = pad_h - pad_h_f 87 | pad_w_f = pad_w // 2 88 | pad_w_b = pad_w - pad_w_f 89 | 90 | pad = (pad_w_f, pad_w_b, pad_h_f, pad_h_b, pad_t_f, pad_t_b) 91 | x = F.pad(x, pad) 92 | x = self.conv3d(x) 93 | if self._use_batch_norm: 94 | x = self.bn(x) 95 | if self._activation_fn is not None: 96 | x = self._activation_fn(x) 97 | return x 98 | 99 | class TemporalInceptionModule(nn.Module): 100 | def __init__(self, in_channels, out_channels, name): 101 | super(TemporalInceptionModule, self).__init__() 102 | 103 | self.b0 = Unit3D(in_channels=in_channels, output_channels=out_channels[0], kernel_shape=[1, 1, 1], padding=0, 104 | name=name+'/Branch_0/Conv3d_0a_1x1') 105 | self.b1a = Unit3D(in_channels=in_channels, output_channels=out_channels[1], kernel_shape=[1, 1, 1], padding=0, 106 | name=name+'/Branch_1/Conv3d_0a_1x1') 107 | self.b1b = Unit3D(in_channels=out_channels[1], output_channels=out_channels[2], kernel_shape=[3, 1, 1], 108 | name=name+'/Branch_1/Conv3d_0b_3x3') 109 | self.b2a = Unit3D(in_channels=in_channels, output_channels=out_channels[3], kernel_shape=[1, 1, 1], padding=0, 110 | name=name+'/Branch_2/Conv3d_0a_1x1') 111 | self.b2b = Unit3D(in_channels=out_channels[3], output_channels=out_channels[4], kernel_shape=[3, 1, 1], 112 | 
name=name+'/Branch_2/Conv3d_0b_3x3') 113 | self.b3a = MaxPool3dSamePadding(kernel_size=[3, 1, 1], 114 | stride=(1, 1, 1), padding=0) 115 | self.b3b = Unit3D(in_channels=in_channels, output_channels=out_channels[5], kernel_shape=[1, 1, 1], padding=0, 116 | name=name+'/Branch_3/Conv3d_0b_1x1') 117 | self.name = name 118 | 119 | def forward(self, x): 120 | b0 = self.b0(x) 121 | b1 = self.b1b(self.b1a(x)) 122 | b2 = self.b2b(self.b2a(x)) 123 | b3 = self.b3b(self.b3a(x)) 124 | return torch.cat([b0,b1,b2,b3], dim=1) 125 | 126 | 127 | class SpatialInceptionModule(nn.Module): 128 | def __init__(self, in_channels, out_channels, name): 129 | super(SpatialInceptionModule, self).__init__() 130 | 131 | self.b0 = Unit3D(in_channels=in_channels, output_channels=out_channels[0], kernel_shape=[1, 1, 1], padding=0, 132 | name=name + '/Branch_0/Conv3d_0a_1x1') 133 | self.b1a = Unit3D(in_channels=in_channels, output_channels=out_channels[1], kernel_shape=[1, 1, 1], padding=0, 134 | name=name + '/Branch_1/Conv3d_0a_1x1') 135 | self.b1b = Unit3D(in_channels=out_channels[1], output_channels=out_channels[2], kernel_shape=[1, 3, 3], 136 | name=name + '/Branch_1/Conv3d_0b_3x3') 137 | self.b2a = Unit3D(in_channels=in_channels, output_channels=out_channels[3], kernel_shape=[1, 1, 1], padding=0, 138 | name=name + '/Branch_2/Conv3d_0a_1x1') 139 | self.b2b = Unit3D(in_channels=out_channels[3], output_channels=out_channels[4], kernel_shape=[1, 3, 3], 140 | name=name + '/Branch_2/Conv3d_0b_3x3') 141 | self.b3a = MaxPool3dSamePadding(kernel_size=[3, 3, 3], 142 | stride=(1, 1, 1), padding=0) 143 | self.b3b = Unit3D(in_channels=in_channels, output_channels=out_channels[5], kernel_shape=[1, 1, 1], padding=0, 144 | name=name + '/Branch_3/Conv3d_0b_1x1') 145 | self.name = name 146 | 147 | def forward(self, x): 148 | b0 = self.b0(x) 149 | b1 = self.b1b(self.b1a(x)) 150 | b2 = self.b2b(self.b2a(x)) 151 | b3 = self.b3b(self.b3a(x)) 152 | return torch.cat([b0, b1, b2, b3], dim=1) -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | TRAIN=$1 3 | CONFIG=$2 4 | GPUID=$3 5 | GPUNUM=$4 6 | PORT=${PORT:-29509} 7 | CUDA_VISIBLE_DEVICES=$GPUID python -m torch.distributed.launch --nproc_per_node=$GPUNUM --master_port=$PORT $TRAIN --config $CONFIG --nprocs $GPUNUM --save_output 8 | -------------------------------------------------------------------------------- /tools/fusion.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (C) 2010-2021 Alibaba Group Holding Limited. 3 | ''' 4 | import os, random, math 5 | import time 6 | import glob 7 | import numpy as np 8 | import shutil 9 | 10 | import torch 11 | 12 | import logging 13 | import argparse 14 | import traceback 15 | import torch.nn as nn 16 | import torch.utils 17 | import torchvision.datasets as dset 18 | import torch.backends.cudnn as cudnn 19 | 20 | import sys 21 | sys.path.append(os.path.abspath(os.path.join("..", os.getcwd()))) 22 | from config import Config 23 | from lib import * 24 | import torch.distributed as dist 25 | from utils import * 26 | from utils.build import * 27 | 28 | 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument('--config', help='Place config Congfile!') 31 | parser.add_argument('--eval_only', action='store_true', help='Eval only. 
True or False?') 32 | parser.add_argument('--local_rank', type=int, default=0) 33 | parser.add_argument('--nprocs', type=int, default=1) 34 | 35 | parser.add_argument('--save_grid_image', action='store_true', help='Save samples?') 36 | parser.add_argument('--save_output', action='store_true', help='Save logits?') 37 | parser.add_argument('--fp16', action='store_true', help='Training with fp16') 38 | parser.add_argument('--demo_dir', type=str, default='./demo', help='The dir for save all the demo') 39 | parser.add_argument('--resume', type=str, default='', help='resume model path.') 40 | 41 | parser.add_argument('--drop_path_prob', type=float, default=0.5, help='drop path probability') 42 | parser.add_argument('--save', type=str, default='Checkpoints/', help='experiment name') 43 | parser.add_argument('--seed', type=int, default=123, help='random seed') 44 | args = parser.parse_args() 45 | args = Config(args) 46 | 47 | #==================================================== 48 | # Some configuration 49 | #==================================================== 50 | 51 | try: 52 | if args.resume: 53 | args.save = os.path.split(args.resume)[0] 54 | else: 55 | args.save = '{}/{}-{}-{}-{}'.format(args.save, args.Network, args.dataset, args.type, time.strftime("%Y%m%d-%H%M%S")) 56 | utils.create_exp_dir(args.save, scripts_to_save=[args.config] + glob.glob('./tools/*.py') + glob.glob('./lib/*')) 57 | except: 58 | pass 59 | log_format = '%(asctime)s %(message)s' 60 | logging.basicConfig(stream=sys.stdout, level=logging.INFO, 61 | format=log_format, datefmt='%m/%d %I:%M:%S %p') 62 | fh = logging.FileHandler(os.path.join(args.save, 'log{}.txt'.format(time.strftime("%Y%m%d-%H%M%S")))) 63 | fh.setFormatter(logging.Formatter(log_format)) 64 | logging.getLogger().addHandler(fh) 65 | 66 | #--------------------------------- 67 | # Fusion Net Training 68 | #--------------------------------- 69 | def reduce_mean(tensor, nprocs): 70 | rt = tensor.clone() 71 | dist.all_reduce(rt, op=dist.ReduceOp.SUM) 72 | rt /= nprocs 73 | return rt.item() 74 | 75 | 76 | def main(local_rank, nprocs, args): 77 | if not torch.cuda.is_available(): 78 | logging.info('no gpu device available') 79 | sys.exit(1) 80 | 81 | np.random.seed(args.seed) 82 | cudnn.benchmark = True 83 | torch.manual_seed(args.seed) 84 | cudnn.enabled = True 85 | torch.cuda.manual_seed(args.seed) 86 | logging.info('gpu device = %d' % local_rank) 87 | 88 | # --------------------------- 89 | # Init distribution 90 | # --------------------------- 91 | torch.cuda.set_device(local_rank) 92 | torch.distributed.init_process_group(backend='nccl') 93 | 94 | # ---------------------------- 95 | # build function 96 | # ---------------------------- 97 | model = build_model(args) 98 | model = model.cuda(local_rank) 99 | 100 | criterion = build_loss(args) 101 | optimizer = build_optim(args, model) 102 | scheduler = build_scheduler(args, optimizer) 103 | 104 | train_queue, train_sampler = build_dataset(args, phase='train') 105 | valid_queue, valid_sampler = build_dataset(args, phase='valid') 106 | 107 | if args.resume: 108 | model, optimizer, strat_epoch, best_acc = load_checkpoint(model, args.resume, optimizer) 109 | logging.info("The network will resume training.") 110 | logging.info("Start Epoch: {}, Learning rate: {}, Best accuracy: {}".format(strat_epoch, [g['lr'] for g in 111 | optimizer.param_groups], 112 | round(best_acc, 4))) 113 | if args.resumelr: 114 | for g in optimizer.param_groups: g['lr'] = args.resumelr 115 | args.resume_scheduler = 
cosine_scheduler(args.resumelr, 1e-5, args.epochs - strat_epoch, len(train_queue)) 116 | args.resume_epoch = strat_epoch - 1 117 | 118 | else: 119 | strat_epoch = 0 120 | best_acc = 0.0 121 | args.resume_epoch = 0 122 | scheduler[0].last_epoch = strat_epoch 123 | 124 | 125 | if args.SYNC_BN and args.nprocs > 1: 126 | model = nn.SyncBatchNorm.convert_sync_batchnorm(model) 127 | model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], find_unused_parameters=True) 128 | if local_rank == 0: 129 | logging.info("param size = %fMB", utils.count_parameters_in_MB(model)) 130 | 131 | 132 | train_results = dict( 133 | train_score=[], 134 | train_loss=[], 135 | valid_score=[], 136 | valid_loss=[], 137 | best_score=0.0 138 | ) 139 | if args.eval_only: 140 | valid_acc, _, _, meter_dict = infer(valid_queue, model, criterion, local_rank, 0) 141 | valid_acc = max(meter_dict['Acc_all'].avg, meter_dict['Acc'].avg, meter_dict['Acc_3'].avg) 142 | logging.info('valid_acc: {}, Acc_1: {}, Acc_2: {}, Acc_3: {}'.format(valid_acc, meter_dict['Acc_1'].avg, meter_dict['Acc_2'].avg, meter_dict['Acc_3'].avg)) 143 | return 144 | 145 | #--------------------------- 146 | # Mixed Precision Training 147 | # -------------------------- 148 | if args.fp16: 149 | scaler = torch.cuda.amp.GradScaler() 150 | else: 151 | scaler = None 152 | for epoch in range(strat_epoch, args.epochs): 153 | train_sampler.set_epoch(epoch) 154 | model.drop_path_prob = args.drop_path_prob * epoch / args.epochs 155 | 156 | if epoch < args.scheduler['warm_up_epochs']: 157 | for g in optimizer.param_groups: 158 | g['lr'] = scheduler[-1](epoch) 159 | 160 | args.epoch = epoch 161 | train_acc, train_obj, meter_dict_train = train(train_queue, model, criterion, optimizer, epoch, local_rank, scaler) 162 | valid_acc, valid_obj, valid_dict, meter_dict_val = infer(valid_queue, model, criterion, local_rank, epoch) 163 | valid_acc = max(meter_dict_val['Acc_all'].avg, meter_dict_val['Acc'].avg, meter_dict_val['Acc_3'].avg) 164 | if epoch >= args.scheduler['warm_up_epochs']: 165 | if args.scheduler['name'] == 'ReduceLR': 166 | scheduler[0].step(valid_acc) 167 | else: 168 | scheduler[0].step() 169 | 170 | if local_rank == 0: 171 | if valid_acc > best_acc: 172 | best_acc = valid_acc 173 | isbest = True 174 | else: 175 | isbest = False 176 | logging.info('train_acc %f', train_acc) 177 | logging.info('valid_acc: {}, Acc_1: {}, Acc_2: {}, Acc_3: {}, best acc: {}'.format(meter_dict_val['Acc'].avg, meter_dict_val['Acc_1'].avg, 178 | meter_dict_val['Acc_2'].avg, 179 | meter_dict_val['Acc_3'].avg, best_acc)) 180 | 181 | state = {'model': model.module.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch + 1, 'bestacc': best_acc} 182 | save_checkpoint(state, isbest, args.save) 183 | 184 | train_results['train_score'].append(train_acc) 185 | train_results['train_loss'].append(train_obj) 186 | train_results['valid_score'].append(valid_acc) 187 | train_results['valid_loss'].append(valid_obj) 188 | train_results['best_score'] = best_acc 189 | train_results.update(valid_dict) 190 | train_results['categories'] = np.unique(valid_dict['grounds']) 191 | 192 | if isbest: 193 | EvaluateMetric(PREDICTIONS_PATH=args.save, train_results=train_results, idx=epoch) 194 | for k, v in train_results.items(): 195 | if isinstance(v, list): 196 | v.clear() 197 | 198 | def train(train_queue, model, criterion, optimizer, epoch, local_rank, scaler): 199 | model.train() 200 | meter_dict = dict( 201 | Total_loss=AverageMeter(), 202 | MSE_loss=AverageMeter(), 203 | 
CE_loss=AverageMeter(), 204 | BCE_loss=AverageMeter(), 205 | Distill_loss = AverageMeter() 206 | 207 | ) 208 | meter_dict['Data_Time'] = AverageMeter() 209 | meter_dict.update(dict( 210 | Acc_1=AverageMeter(), 211 | Acc_2=AverageMeter(), 212 | Acc_3=AverageMeter(), 213 | Acc=AverageMeter() 214 | )) 215 | 216 | end = time.time() 217 | for step, (inputs, heatmap, target, _) in enumerate(train_queue): 218 | meter_dict['Data_Time'].update((time.time() - end)/args.batch_size) 219 | inputs, target, heatmap = map(lambda x: [d.cuda(local_rank, non_blocking=True) for d in x] if isinstance(x, list) else x.cuda(local_rank, non_blocking=True), [inputs, target, heatmap]) 220 | 221 | if args.resumelr: 222 | for g in optimizer.param_groups: 223 | g['lr'] = args.resume_scheduler[len(train_queue) * args.resume_epoch + step] 224 | # --------------------------- 225 | # Mixed Precision Training 226 | # -------------------------- 227 | if args.fp16: 228 | print('Train with FP16') 229 | optimizer.zero_grad() 230 | # Runs the forward pass with autocasting. 231 | with torch.cuda.amp.autocast(): 232 | (logits, logit_r, logit_d), (CE_loss, BCE_loss, MSE_loss, distillation) = model(inputs, heatmap, target) 233 | globals()['CE_loss'] = CE_loss 234 | globals()['MSE_loss'] = MSE_loss 235 | globals()['BCE_loss'] = BCE_loss 236 | globals()['Distill_loss'] = distillation 237 | globals()['Total_loss'] = CE_loss + MSE_loss + BCE_loss + distillation 238 | 239 | scaler.scale(Total_loss).backward() 240 | # Unscales the gradients of optimizer's assigned params in-place 241 | scaler.unscale_(optimizer) 242 | nn.utils.clip_grad_norm_(model.module.parameters(), args.grad_clip) 243 | scaler.step(optimizer) 244 | scaler.update() 245 | else: 246 | # --------------------------- 247 | # Fp32 Precision Training 248 | # -------------------------- 249 | (logits, logit_r, logit_d), (CE_loss, BCE_loss, MSE_loss, distillation) = model(inputs, heatmap, target) 250 | globals()['CE_loss'] = CE_loss 251 | globals()['MSE_loss'] = MSE_loss 252 | globals()['BCE_loss'] = BCE_loss 253 | globals()['Distill_loss'] = distillation 254 | globals()['Total_loss'] = CE_loss + MSE_loss + BCE_loss + distillation 255 | 256 | optimizer.zero_grad() 257 | Total_loss.backward() 258 | nn.utils.clip_grad_norm_(model.module.parameters(), args.grad_clip) 259 | optimizer.step() 260 | 261 | #--------------------- 262 | # Meter performance 263 | #--------------------- 264 | torch.distributed.barrier() 265 | globals()['Acc'] = calculate_accuracy(logits, target) 266 | globals()['Acc_1'] = calculate_accuracy(logit_r, target) 267 | globals()['Acc_2'] = calculate_accuracy(logit_d, target) 268 | globals()['Acc_3'] = calculate_accuracy(logit_r+logit_d, target) 269 | 270 | for name in meter_dict: 271 | if 'loss' in name: 272 | meter_dict[name].update(reduce_mean(globals()[name], args.nprocs)) 273 | if 'Acc' in name: 274 | meter_dict[name].update(reduce_mean(globals()[name], args.nprocs)) 275 | 276 | if step % args.report_freq == 0 and local_rank == 0: 277 | 278 | log_info = { 279 | 'Epoch': '{}/{}'.format(epoch + 1, args.epochs), 280 | 'Mini-Batch': '{:0>5d}/{:0>5d}'.format(step + 1, 281 | len(train_queue.dataset) // (args.batch_size * args.nprocs)), 282 | 'Lr': ['{:.4f}'.format(g['lr']) for g in optimizer.param_groups], 283 | } 284 | log_info.update(dict((name, '{:.4f}'.format(value.avg)) for name, value in meter_dict.items())) 285 | print_func(log_info) 286 | end = time.time() 287 | args.resume_epoch += 1 288 | return meter_dict['Acc'].avg, meter_dict['Total_loss'].avg, 
meter_dict 289 | 290 | @torch.no_grad() 291 | def concat_all_gather(tensor): 292 | """ 293 | Performs all_gather operation on the provided tensors. 294 | *** Warning ***: torch.distributed.all_gather has no gradient. 295 | """ 296 | tensors_gather = [torch.ones_like(tensor) 297 | for _ in range(torch.distributed.get_world_size())] 298 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False) 299 | 300 | output = torch.cat(tensors_gather, dim=0) 301 | return output 302 | 303 | @torch.no_grad() 304 | def infer(valid_queue, model, criterion, local_rank, epoch): 305 | model.eval() 306 | 307 | meter_dict = dict( 308 | Total_loss=AverageMeter(), 309 | MSE_loss=AverageMeter(), 310 | CE_loss=AverageMeter(), 311 | Distill_loss=AverageMeter() 312 | ) 313 | meter_dict.update(dict( 314 | Acc_1=AverageMeter(), 315 | Acc_2=AverageMeter(), 316 | Acc_3=AverageMeter(), 317 | Acc = AverageMeter(), 318 | Acc_all=AverageMeter(), 319 | )) 320 | 321 | meter_dict['Infer_Time'] = AverageMeter() 322 | grounds, preds, v_paths = [], [], [] 323 | for step, (inputs, heatmap, target, v_path) in enumerate(valid_queue): 324 | end = time.time() 325 | inputs, target, heatmap = map( 326 | lambda x: [d.cuda(local_rank, non_blocking=True) for d in x] if isinstance(x, list) else x.cuda(local_rank, 327 | non_blocking=True), 328 | [inputs, target, heatmap]) 329 | if args.fp16: 330 | with torch.cuda.amp.autocast(): 331 | (logits, logit_r, logit_d), (CE_loss, BCE_loss, MSE_loss, distillation) = model(inputs, heatmap, target) 332 | else: 333 | (logits, logit_r, logit_d), (CE_loss, BCE_loss, MSE_loss, distillation) = model(inputs, heatmap, target) 334 | meter_dict['Infer_Time'].update((time.time() - end) / args.test_batch_size) 335 | globals()['CE_loss'] = CE_loss 336 | globals()['MSE_loss'] = MSE_loss 337 | globals()['BCE_loss'] = BCE_loss 338 | globals()['Distill_loss'] = distillation 339 | globals()['Total_loss'] = CE_loss + MSE_loss + BCE_loss + distillation 340 | 341 | torch.distributed.barrier() 342 | globals()['Acc'] = calculate_accuracy(logits, target) 343 | globals()['Acc_1'] = calculate_accuracy(logit_r, target) 344 | globals()['Acc_2'] = calculate_accuracy(logit_d, target) 345 | globals()['Acc_3'] = calculate_accuracy(logit_r+logit_d, target) 346 | globals()['Acc_all'] = calculate_accuracy(logit_r+logit_d+logits, target) 347 | 348 | 349 | grounds += target.cpu().tolist() 350 | preds += torch.argmax(logits, dim=1).cpu().tolist() 351 | v_paths += v_path 352 | for name in meter_dict: 353 | if 'loss' in name: 354 | meter_dict[name].update(reduce_mean(globals()[name], args.nprocs)) 355 | if 'Acc' in name: 356 | meter_dict[name].update(reduce_mean(globals()[name], args.nprocs)) 357 | 358 | if step % args.report_freq == 0 and local_rank == 0: 359 | log_info = { 360 | 'Epoch': epoch + 1, 361 | 'Mini-Batch': '{:0>4d}/{:0>4d}'.format(step + 1, len(valid_queue.dataset) // ( 362 | args.test_batch_size * args.nprocs)), 363 | } 364 | log_info.update(dict((name, '{:.4f}'.format(value.avg)) for name, value in meter_dict.items())) 365 | print_func(log_info) 366 | 367 | torch.distributed.barrier() 368 | grounds_gather = concat_all_gather(torch.tensor(grounds).cuda(local_rank)) 369 | preds_gather = concat_all_gather(torch.tensor(preds).cuda(local_rank)) 370 | grounds_gather, preds_gather = list(map(lambda x: x.cpu().numpy(), [grounds_gather, preds_gather])) 371 | 372 | if local_rank == 0: 373 | v_paths = np.array(v_paths) 374 | grounds = np.array(grounds) 375 | preds = np.array(preds) 376 | wrong_idx = np.where(grounds != 
preds) 377 | v_paths = v_paths[wrong_idx[0]] 378 | grounds = grounds[wrong_idx[0]] 379 | preds = preds[wrong_idx[0]] 380 | return meter_dict['Acc'].avg, meter_dict['Total_loss'].avg, dict(grounds=grounds_gather, preds=preds_gather, valid_images=(v_paths, grounds, preds)), meter_dict 381 | 382 | if __name__ == '__main__': 383 | try: 384 | main(args.local_rank, args.nprocs, args) 385 | except KeyboardInterrupt: 386 | torch.cuda.empty_cache() 387 | if os.path.exists(args.save) and len(os.listdir(args.save)) < 3: 388 | print(f'remove {args.save}: Directory') 389 | os.system('rm -rf {} \n mv {} ./Checkpoints/trash'.format(args.save, args.save)) 390 | os._exit(0) 391 | except Exception: 392 | print(traceback.print_exc()) 393 | if os.path.exists(args.save) and len(os.listdir(args.save)) < 3: 394 | print(f'remove {args.save}: Directory') 395 | os.system('rm -rf {} \n mv {} ./Checkpoints/trash'.format(args.save, args.save)) 396 | os._exit(0) 397 | finally: 398 | torch.cuda.empty_cache() 399 | -------------------------------------------------------------------------------- /tools/readme.md: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /tools/train.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (C) 2010-2021 Alibaba Group Holding Limited. 3 | ''' 4 | 5 | import time 6 | import glob 7 | import numpy as np 8 | import shutil 9 | import cv2 10 | import os, random, math 11 | import sys 12 | sys.path.append(os.path.join('..', os.path.abspath(os.path.join(os.getcwd()))) ) 13 | 14 | import torch 15 | import utils 16 | import logging 17 | import argparse 18 | import traceback 19 | import torch.nn as nn 20 | import torch.utils 21 | import torchvision.datasets as dset 22 | import torch.backends.cudnn as cudnn 23 | import torch.distributed as dist 24 | 25 | # import flops_benchmark 26 | from utils.visualizer import Visualizer 27 | from config import Config 28 | from lib import * 29 | from utils import * 30 | 31 | #------------------------ 32 | # evaluation metrics 33 | #------------------------ 34 | from sklearn.decomposition import PCA 35 | from sklearn import manifold 36 | import pandas as pd 37 | import matplotlib.pyplot as plt # For graphics 38 | import seaborn as sns 39 | from torchvision.utils import save_image, make_grid 40 | 41 | parser = argparse.ArgumentParser() 42 | parser.add_argument('--config', help='Load Congfile.') 43 | parser.add_argument('--eval_only', action='store_true', help='Eval only. 
True or False?') 44 | parser.add_argument('--local_rank', type=int, default=0) 45 | parser.add_argument('--nprocs', type=int, default=1) 46 | 47 | parser.add_argument('--save_grid_image', action='store_true', help='Save samples?') 48 | parser.add_argument('--save_output', action='store_true', help='Save logits?') 49 | parser.add_argument('--demo_dir', type=str, default='./demo', help='The dir for save all the demo') 50 | parser.add_argument('--resume', type=str, default='', help='resume model path.') 51 | 52 | parser.add_argument('--distill-lamdb', type=float, default=0.0, help='initial distillation loss weight') 53 | 54 | parser.add_argument('--drop_path_prob', type=float, default=0.5, help='drop path probability') 55 | parser.add_argument('--save', type=str, default='Checkpoints/', help='experiment dir') 56 | parser.add_argument('--seed', type=int, default=123, help='random seed') 57 | args = parser.parse_args() 58 | args = Config(args) 59 | 60 | try: 61 | if args.resume: 62 | args.save = os.path.split(args.resume)[0] 63 | else: 64 | args.save = '{}/{}-{}-{}-{}'.format(args.save, args.Network, args.dataset, args.type, time.strftime("%Y%m%d-%H%M%S")) 65 | utils.create_exp_dir(args.save, scripts_to_save=[args.config] + glob.glob('./tools/*.py') + glob.glob('./lib/*')) 66 | except: 67 | pass 68 | log_format = '%(asctime)s %(message)s' 69 | logging.basicConfig(stream=sys.stdout, level=logging.INFO, 70 | format=log_format, datefmt='%m/%d %I:%M:%S %p') 71 | fh = logging.FileHandler(os.path.join(args.save, 'log{}.txt'.format(time.strftime("%Y%m%d-%H%M%S")))) 72 | fh.setFormatter(logging.Formatter(log_format)) 73 | logging.getLogger().addHandler(fh) 74 | 75 | 76 | def reduce_mean(tensor, nprocs): 77 | rt = tensor.clone() 78 | dist.all_reduce(rt, op=dist.ReduceOp.SUM) 79 | rt /= nprocs 80 | return rt.item() 81 | 82 | def main(local_rank, nprocs, args): 83 | if not torch.cuda.is_available(): 84 | logging.info('no gpu device available') 85 | sys.exit(1) 86 | 87 | np.random.seed(args.seed) 88 | cudnn.benchmark = True 89 | torch.manual_seed(args.seed) 90 | cudnn.enabled = True 91 | torch.cuda.manual_seed(args.seed) 92 | logging.info('gpu device = %d' % local_rank) 93 | 94 | #--------------------------- 95 | # Init distribution 96 | #--------------------------- 97 | torch.cuda.set_device(local_rank) 98 | torch.distributed.init_process_group(backend='nccl') 99 | 100 | #---------------------------- 101 | # build function 102 | #---------------------------- 103 | model = build_model(args) 104 | model = model.cuda(local_rank) 105 | 106 | criterion = build_loss(args) 107 | optimizer = build_optim(args, model) 108 | scheduler = build_scheduler(args, optimizer) 109 | 110 | train_queue, train_sampler = build_dataset(args, phase='train') 111 | valid_queue, valid_sampler = build_dataset(args, phase='valid') 112 | 113 | 114 | if args.resume: 115 | model, optimizer, strat_epoch, best_acc = load_checkpoint(model, args.resume, optimizer) 116 | logging.info("Start Epoch: {}, Learning rate: {}, Best accuracy: {}".format(strat_epoch, [g['lr'] for g in 117 | optimizer.param_groups], 118 | round(best_acc, 4))) 119 | if args.resumelr: 120 | for g in optimizer.param_groups: 121 | args.resumelr = g['lr'] if not isinstance(args.resumelr, float) else args.resumelr 122 | g['lr'] = args.resumelr 123 | #resume_scheduler = np.linspace(args.resumelr, 1e-5, args.epochs - strat_epoch) 124 | resume_scheduler = cosine_scheduler(args.resumelr, 1e-5, args.epochs - strat_epoch + 1, niter_per_ep=1).tolist() 125 | resume_scheduler.pop(0) 
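            # cosine_scheduler() (defined in utils/utils.py) returns one LR value per remaining
            # epoch here (niter_per_ep=1), decaying from args.resumelr down to 1e-5. The first
            # value is popped because the optimizer's LR was just reset to args.resumelr above;
            # each later epoch pops the next value inside the training loop.
            # Hypothetical example: resumelr=1e-3, epochs=100, strat_epoch=60 -> 41 schedule
            # values; the first is discarded and one value is consumed per epoch after warm-up.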
126 | 127 | args.epoch = strat_epoch - 1 128 | else: 129 | strat_epoch = 0 130 | best_acc = 0.0 131 | args.epoch = strat_epoch 132 | 133 | scheduler[0].last_epoch = strat_epoch 134 | 135 | if args.SYNC_BN and args.nprocs > 1: 136 | model = nn.SyncBatchNorm.convert_sync_batchnorm(model) 137 | model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], find_unused_parameters=False) 138 | if local_rank == 0: 139 | logging.info("param size = %fMB", utils.count_parameters_in_MB(model)) 140 | # logging.info('FLOPs: {}'.format(flops_benchmark.count_flops(model))) 141 | 142 | train_results = dict( 143 | train_score=[], 144 | train_loss=[], 145 | valid_score=[], 146 | valid_loss=[], 147 | best_score=0.0 148 | ) 149 | if args.eval_only: 150 | valid_acc, _, _, meter_dict, output = infer(valid_queue, model, criterion, local_rank, strat_epoch) 151 | logging.info('valid_acc: {}, Acc_1: {}, Acc_2: {}, Acc_3: {}'.format(valid_acc, meter_dict['Acc_1'].avg, meter_dict['Acc_2'].avg, meter_dict['Acc_3'].avg)) 152 | if args.save_output: 153 | torch.save(output, os.path.join(args.save, '{}-output.pth'.format(args.type))) 154 | return 155 | 156 | for epoch in range(strat_epoch, args.epochs): 157 | train_sampler.set_epoch(epoch) 158 | model.drop_path_prob = args.drop_path_prob * epoch / args.epochs 159 | 160 | if epoch < args.scheduler['warm_up_epochs']-1: 161 | for g in optimizer.param_groups: 162 | g['lr'] = scheduler[-1](epoch) 163 | else: 164 | args.distill_lamdb = args.distill 165 | 166 | args.epoch = epoch 167 | train_acc, train_obj, meter_dict_train = train(train_queue, model, criterion, optimizer, epoch, local_rank) 168 | valid_acc, valid_obj, valid_dict, meter_dict_val, output = infer(valid_queue, model, criterion, local_rank, epoch) 169 | 170 | # scheduler_func.step(scheduler, valid_acc) 171 | if epoch >= args.scheduler['warm_up_epochs']: 172 | if args.resume and args.resumelr: 173 | for g in optimizer.param_groups: 174 | g['lr'] = resume_scheduler[0] 175 | resume_scheduler.pop(0) 176 | elif args.scheduler['name'] == 'ReduceLR': 177 | scheduler[0].step(valid_acc) 178 | else: 179 | scheduler[0].step() 180 | 181 | if local_rank == 0: 182 | if valid_acc > best_acc: 183 | best_acc = valid_acc 184 | isbest = True 185 | else: 186 | isbest = False 187 | logging.info('train_acc %f', train_acc) 188 | logging.info('valid_acc %f, best_acc %f', valid_acc, best_acc) 189 | state = {'model': model.module.state_dict(),'optimizer': optimizer.state_dict(), 'epoch': epoch + 1, 'bestacc': best_acc} 190 | save_checkpoint(state, isbest, args.save) 191 | 192 | train_results['train_score'].append(train_acc) 193 | train_results['train_loss'].append(train_obj) 194 | train_results['valid_score'].append(valid_acc) 195 | train_results['valid_loss'].append(valid_obj) 196 | train_results['best_score'] = best_acc 197 | train_results.update(valid_dict) 198 | train_results['categories'] = np.unique(valid_dict['grounds']) 199 | 200 | if args.visdom['enable']: 201 | vis.plot_many({'train_acc': train_acc, 'loss': train_obj, 202 | 'cosin_similar': meter_dict_train['cosin_similar'].avg}, 'Train-' + args.type, epoch) 203 | vis.plot_many({'valid_acc': valid_acc, 'loss': valid_obj, 204 | 'cosin_similar': meter_dict_val['cosin_similar'].avg}, 'Valid-' + args.type, epoch) 205 | 206 | if isbest: 207 | if args.save_output: 208 | torch.save(output, os.path.join(args.save, '{}-output.pth'.format(args.type))) 209 | EvaluateMetric(PREDICTIONS_PATH=args.save, train_results=train_results, idx=epoch) 210 | for k, v in 
train_results.items(): 211 | if isinstance(v, list): 212 | v.clear() 213 | 214 | def Visfeature(inputs, feature, v_path=None): 215 | if args.visdom['enable']: 216 | vis.featuremap('CNNVision', 217 | torch.sum(make_grid(feature[0].detach(), nrow=int(feature[0].size(0) ** 0.5), padding=2), dim=0).flipud()) 218 | vis.featuremap('Attention Maps Similarity', 219 | make_grid(feature[1], nrow=int(feature[1].detach().cpu().size(0) ** 0.5), padding=2)[0].flipud()) 220 | 221 | vis.featuremap('Enhancement Weights', feature[3].flipud()) 222 | else: 223 | fig = plt.figure() 224 | ax = fig.add_subplot() 225 | sns.heatmap( 226 | torch.sum(make_grid(feature[0].detach(), nrow=int(feature[0].size(0) ** 0.5), padding=2), dim=0).cpu().numpy(), 227 | annot=False, fmt='g', ax=ax) 228 | ax.set_title('CNNVision', fontsize=10) 229 | fig.savefig(os.path.join(args.save, 'CNNVision.jpg'), dpi=fig.dpi) 230 | plt.close() 231 | 232 | fig = plt.figure() 233 | ax = fig.add_subplot() 234 | sns.heatmap(make_grid(feature[1].detach(), nrow=int(feature[1].size(0) ** 0.5), padding=2)[0].cpu().numpy(), annot=False, 235 | fmt='g', ax=ax) 236 | ax.set_title('Attention Maps Similarity', fontsize=10) 237 | fig.savefig(os.path.join(args.save, 'AttMapSimilarity.jpg'), dpi=fig.dpi) 238 | plt.close() 239 | 240 | fig = plt.figure() 241 | ax = fig.add_subplot() 242 | sns.heatmap(feature[3].detach().cpu().numpy(), annot=False, fmt='g', ax=ax) 243 | ax.set_title('Enhancement Weights', fontsize=10) 244 | fig.savefig(os.path.join(args.save, 'EnhancementWeights.jpg'), dpi=fig.dpi) 245 | plt.close() 246 | 247 | #------------------------------------------ 248 | # Spatial feature visualization 249 | #------------------------------------------ 250 | headmap = feature[-1][0][0,:].detach().cpu().numpy() 251 | headmap = np.mean(headmap, axis=0) 252 | headmap /= np.max(headmap) # torch.Size([64, 7, 7]) 253 | headmap = torch.from_numpy(headmap) 254 | img = feature[-1][1] 255 | 256 | result = [] 257 | for map, mg in zip(headmap.unsqueeze(1), img.permute(1,2,3,0)): 258 | map = cv2.resize(map.squeeze().cpu().numpy(), (mg.shape[0]//2, mg.shape[1]//2)) 259 | map = np.uint8(255 * map) 260 | map = cv2.applyColorMap(map, cv2.COLORMAP_JET) 261 | 262 | mg = np.uint8(mg.cpu().numpy() * 128 + 127.5) 263 | mg = cv2.resize(mg, (mg.shape[0]//2, mg.shape[1]//2)) 264 | superimposed_img = cv2.addWeighted(mg, 0.4, map, 0.6, 0) 265 | 266 | result.append(torch.from_numpy(superimposed_img).unsqueeze(0)) 267 | superimposed_imgs = torch.cat(result).permute(0, 3, 1, 2) 268 | # save_image(superimposed_imgs, os.path.join(args.save, 'CAM-Features.png'), nrow=int(superimposed_imgs.size(0) ** 0.5), padding=2).permute(1,2,0) 269 | superimposed_imgs = make_grid(superimposed_imgs, nrow=int(superimposed_imgs.size(0) ** 0.5), padding=2).permute(1,2,0) 270 | cv2.imwrite(os.path.join(args.save, 'CAM-Features.png'), superimposed_imgs.numpy()) 271 | 272 | if args.eval_only: 273 | MHAS_s, MHAS_m, MHAS_l = feature[4] 274 | MHAS_s, MHAS_m, MHAS_l = MHAS_s.detach().cpu(), MHAS_m.detach().cpu(), MHAS_l.detach().cpu() 275 | # Normalize 276 | att_max, index_max = torch.max(MHAS_s.view(MHAS_s.size(0), -1), dim=-1) 277 | att_min, index_min = torch.min(MHAS_s.view(MHAS_s.size(0), -1), dim=-1) 278 | MHAS_s = (MHAS_s - att_min.view(-1, 1, 1))/(att_max.view(-1, 1, 1) - att_min.view(-1, 1, 1)) 279 | 280 | att_max, index_max = torch.max(MHAS_m.view(MHAS_m.size(0), -1), dim=-1) 281 | att_min, index_min = torch.min(MHAS_m.view(MHAS_m.size(0), -1), dim=-1) 282 | MHAS_m = (MHAS_m - att_min.view(-1, 1, 
1))/(att_max.view(-1, 1, 1) - att_min.view(-1, 1, 1)) 283 | 284 | att_max, index_max = torch.max(MHAS_l.view(MHAS_l.size(0), -1), dim=-1) 285 | att_min, index_min = torch.min(MHAS_l.view(MHAS_l.size(0), -1), dim=-1) 286 | MHAS_l = (MHAS_l - att_min.view(-1, 1, 1))/(att_max.view(-1, 1, 1) - att_min.view(-1, 1, 1)) 287 | 288 | mhas_s = make_grid(MHAS_s.unsqueeze(1), nrow=int(MHAS_s.size(0) ** 0.5), padding=2)[0] 289 | mhas_m = make_grid(MHAS_m.unsqueeze(1), nrow=int(MHAS_m.size(0) ** 0.5), padding=2)[0] 290 | mhas_l = make_grid(MHAS_l.unsqueeze(1), nrow=int(MHAS_l.size(0) ** 0.5), padding=2)[0] 291 | if args.visdom['enable']: 292 | vis.featuremap('MHAS Map', mhas_l) 293 | 294 | fig = plt.figure(figsize=(20, 10)) 295 | ax = fig.add_subplot(131) 296 | sns.heatmap(mhas_s.squeeze(), annot=False, fmt='g', ax=ax) 297 | ax.set_title('\nMHSA Small', fontsize=10) 298 | 299 | ax = fig.add_subplot(132) 300 | sns.heatmap(mhas_m.squeeze(), annot=False, fmt='g', ax=ax) 301 | ax.set_title('\nMHSA Medium', fontsize=10) 302 | 303 | ax = fig.add_subplot(133) 304 | sns.heatmap(mhas_l.squeeze(), annot=False, fmt='g', ax=ax) 305 | ax.set_title('\nMHSA Large', fontsize=10) 306 | plt.suptitle('{}'.format(v_path[0].split('/')[-1]), fontsize=20) 307 | fig.savefig('demo/{}-MHAS.jpg'.format(args.save.split('/')[-1]), dpi=fig.dpi) 308 | plt.close() 309 | 310 | def train(train_queue, model, criterion, optimizer, epoch, local_rank): 311 | model.train() 312 | 313 | meter_dict = dict( 314 | Total_loss=AverageMeter(), 315 | CE_loss=AverageMeter(), 316 | Distil_loss=AverageMeter() 317 | ) 318 | meter_dict.update(dict( 319 | cosin_similar=AverageMeter() 320 | )) 321 | meter_dict['Data_Time'] = AverageMeter() 322 | meter_dict.update(dict( 323 | Acc_1=AverageMeter(), 324 | Acc_2=AverageMeter(), 325 | Acc_3=AverageMeter(), 326 | Acc=AverageMeter() 327 | )) 328 | 329 | end = time.time() 330 | CE = criterion 331 | for step, (inputs, heatmap, target, _) in enumerate(train_queue): 332 | meter_dict['Data_Time'].update((time.time() - end)/args.batch_size) 333 | inputs, target, heatmap = map(lambda x: x.cuda(local_rank, non_blocking=True), [inputs, target, heatmap]) 334 | 335 | (logits, xs, xm, xl), distillation_loss, feature = model(inputs, heatmap) 336 | if args.MultiLoss: 337 | lamd1, lamd2, lamd3, lamd4 = map(float, args.loss_lamdb) 338 | globals()['CE_loss'] = lamd1*CE(logits, target) + lamd2*CE(xs, target) + lamd3*CE(xm, target) + lamd4*CE(xl, target) 339 | else: 340 | globals()['CE_loss'] = CE(logits, target) 341 | globals()['Distil_loss'] = distillation_loss * args.distill_lamdb 342 | globals()['Total_loss'] = CE_loss + Distil_loss 343 | 344 | optimizer.zero_grad() 345 | Total_loss.backward() 346 | nn.utils.clip_grad_norm_(model.module.parameters(), args.grad_clip) 347 | optimizer.step() 348 | 349 | #--------------------- 350 | # Meter performance 351 | #--------------------- 352 | torch.distributed.barrier() 353 | globals()['Acc'] = calculate_accuracy(logits, target) 354 | globals()['Acc_1'] = calculate_accuracy(xs, target) 355 | globals()['Acc_2'] = calculate_accuracy(xm, target) 356 | globals()['Acc_3'] = calculate_accuracy(xl, target) 357 | 358 | for name in meter_dict: 359 | if 'loss' in name: 360 | meter_dict[name].update(reduce_mean(globals()[name], args.nprocs)) 361 | if 'cosin' in name: 362 | meter_dict[name].update(float(feature[2])) 363 | if 'Acc' in name: 364 | meter_dict[name].update(reduce_mean(globals()[name], args.nprocs)) 365 | 366 | if (step+1) % args.report_freq == 0 and local_rank == 0: 367 | log_info = { 
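                # One console record every report_freq steps: epoch, mini-batch index, per-group
                # LR, then the running average of every meter (losses, accuracies, timings).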
368 | 'Epoch': '{}/{}'.format(epoch + 1, args.epochs), 369 | 'Mini-Batch': '{:0>5d}/{:0>5d}'.format(step + 1, 370 | len(train_queue.dataset) // (args.batch_size * args.nprocs)), 371 | 'Lr': [round(float(g['lr']), 7) for g in optimizer.param_groups], 372 | } 373 | log_info.update(dict((name, '{:.4f}'.format(value.avg)) for name, value in meter_dict.items())) 374 | print_func(log_info) 375 | 376 | if args.vis_feature: 377 | Visfeature(inputs, feature) 378 | end = time.time() 379 | 380 | return meter_dict['Acc'].avg, meter_dict['Total_loss'].avg, meter_dict 381 | 382 | @torch.no_grad() 383 | def concat_all_gather(tensor): 384 | """ 385 | Performs all_gather operation on the provided tensors. 386 | *** Warning ***: torch.distributed.all_gather has no gradient. 387 | """ 388 | tensors_gather = [torch.ones_like(tensor) 389 | for _ in range(torch.distributed.get_world_size())] 390 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False) 391 | output = torch.cat(tensors_gather, dim=0) 392 | return output 393 | 394 | @torch.no_grad() 395 | def infer(valid_queue, model, criterion, local_rank, epoch): 396 | model.eval() 397 | 398 | meter_dict = dict( 399 | Total_loss=AverageMeter(), 400 | CE_loss=AverageMeter(), 401 | Distil_loss=AverageMeter() 402 | ) 403 | meter_dict.update(dict( 404 | cosin_similar=AverageMeter(), 405 | )) 406 | meter_dict.update(dict( 407 | Acc_1=AverageMeter(), 408 | Acc_2=AverageMeter(), 409 | Acc_3=AverageMeter(), 410 | Acc=AverageMeter() 411 | )) 412 | 413 | meter_dict['Infer_Time'] = AverageMeter() 414 | CE = criterion 415 | grounds, preds, v_paths = [], [], [] 416 | output = {} 417 | for step, (inputs, heatmap, target, v_path) in enumerate(valid_queue): 418 | n = inputs.size(0) 419 | end = time.time() 420 | inputs, target, heatmap = map(lambda x: x.cuda(local_rank, non_blocking=True), [inputs, target, heatmap]) 421 | 422 | (xs, xm, xl, logits), distillation_loss, feature = model(inputs, heatmap) 423 | meter_dict['Infer_Time'].update((time.time() - end) / n) 424 | 425 | if args.MultiLoss: 426 | lamd1, lamd2, lamd3, lamd4 = map(float, args.loss_lamdb) 427 | globals()['CE_loss'] = lamd1 * CE(logits, target) + lamd2 * CE(xs, target) + lamd3 * CE(xm, 428 | target) + lamd4 * CE( 429 | xl, target) 430 | else: 431 | globals()['CE_loss'] = CE(logits, target) 432 | globals()['Distil_loss'] = distillation_loss * args.distill_lamdb 433 | globals()['Total_loss'] = CE_loss + Distil_loss 434 | 435 | grounds += target.cpu().tolist() 436 | preds += torch.argmax(logits, dim=1).cpu().tolist() 437 | v_paths += v_path 438 | torch.distributed.barrier() 439 | globals()['Acc'] = calculate_accuracy(logits, target) 440 | globals()['Acc_1'] = calculate_accuracy(xs+xm, target) 441 | globals()['Acc_2'] = calculate_accuracy(xs+xl, target) 442 | globals()['Acc_3'] = calculate_accuracy(xl+xm, target) 443 | 444 | for name in meter_dict: 445 | if 'loss' in name: 446 | meter_dict[name].update(reduce_mean(globals()[name], args.nprocs)) 447 | if 'cosin' in name: 448 | meter_dict[name].update(float(feature[2])) 449 | if 'Acc' in name: 450 | meter_dict[name].update(reduce_mean(globals()[name], args.nprocs)) 451 | 452 | if step % args.report_freq == 0 and local_rank == 0: 453 | log_info = { 454 | 'Epoch': epoch + 1, 455 | 'Mini-Batch': '{:0>4d}/{:0>4d}'.format(step + 1, len(valid_queue.dataset) // ( 456 | args.test_batch_size * args.nprocs)), 457 | } 458 | log_info.update(dict((name, '{:.4f}'.format(value.avg)) for name, value in meter_dict.items())) 459 | print_func(log_info) 460 | if 
args.vis_feature: 461 | Visfeature(inputs, feature, v_path) 462 | 463 | if args.save_output: 464 | for t, logit in zip(v_path, logits): 465 | output[t] = logit 466 | torch.distributed.barrier() 467 | grounds_gather = concat_all_gather(torch.tensor(grounds).cuda(local_rank)) 468 | preds_gather = concat_all_gather(torch.tensor(preds).cuda(local_rank)) 469 | grounds_gather, preds_gather = list(map(lambda x: x.cpu().numpy(), [grounds_gather, preds_gather])) 470 | 471 | if local_rank == 0: 472 | # v_paths = np.array(v_paths)[random.sample(list(wrong), 10)] 473 | v_paths = np.array(v_paths) 474 | grounds = np.array(grounds) 475 | preds = np.array(preds) 476 | wrong_idx = np.where(grounds != preds) 477 | v_paths = v_paths[wrong_idx[0]] 478 | grounds = grounds[wrong_idx[0]] 479 | preds = preds[wrong_idx[0]] 480 | return max(meter_dict['Acc'].avg, meter_dict['Acc_1'].avg, meter_dict['Acc_2'].avg, meter_dict['Acc_3'].avg), meter_dict['Total_loss'].avg, dict(grounds=grounds_gather, preds=preds_gather, valid_images=(v_paths, grounds, preds)), meter_dict, output 481 | 482 | if __name__ == '__main__': 483 | if args.visdom['enable']: 484 | vis = Visualizer(args.visdom['visname']) 485 | try: 486 | main(args.local_rank, args.nprocs, args) 487 | except KeyboardInterrupt: 488 | torch.cuda.empty_cache() 489 | if os.path.exists(args.save) and len(os.listdir(args.save)) < 3: 490 | print('remove {}: Directory'.format(args.save)) 491 | os.system('rm -rf {} \n mv {} ./Checkpoints/trash'.format(args.save, args.save)) 492 | os._exit(0) 493 | except Exception: 494 | print(traceback.print_exc()) 495 | if os.path.exists(args.save) and len(os.listdir(args.save)) < 3: 496 | print('remove {}: Directory'.format(args.save)) 497 | os.system('rm -rf {} \n mv {} ./Checkpoints/trash'.format(args.save, args.save)) 498 | os._exit(0) 499 | finally: 500 | torch.cuda.empty_cache() 501 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (C) 2010-2021 Alibaba Group Holding Limited. 3 | ''' 4 | 5 | from .print_function import print_func 6 | from .build import * 7 | from .evaluate_metric import EvaluateMetric 8 | from .utils import * 9 | -------------------------------------------------------------------------------- /utils/build.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (C) 2010-2021 Alibaba Group Holding Limited. 
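Builders shared by tools/train.py and tools/fusion.py: loss function (including a
label-smoothing cross-entropy), optimizer (SGD / Adam / AdamW) and learning-rate
scheduler (cosine annealing or ReduceLROnPlateau, with optional linear warm-up).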
3 | ''' 4 | 5 | import torch 6 | import math 7 | import torch.nn.functional as F 8 | from .utils import cosine_scheduler 9 | import matplotlib.pyplot as plt 10 | import numpy as np 11 | 12 | 13 | class LabelSmoothingCrossEntropy(torch.nn.Module): 14 | def __init__(self, smoothing: float = 0.1, 15 | reduction="mean", weight=None): 16 | super(LabelSmoothingCrossEntropy, self).__init__() 17 | self.smoothing = smoothing 18 | self.reduction = reduction 19 | self.weight = weight 20 | 21 | def reduce_loss(self, loss): 22 | return loss.mean() if self.reduction == 'mean' else loss.sum() \ 23 | if self.reduction == 'sum' else loss 24 | 25 | def linear_combination(self, x, y): 26 | return self.smoothing * x + (1 - self.smoothing) * y 27 | 28 | def forward(self, preds, target): 29 | assert 0 <= self.smoothing < 1 30 | 31 | if self.weight is not None: 32 | self.weight = self.weight.to(preds.device) 33 | 34 | n = preds.size(-1) 35 | log_preds = F.log_softmax(preds, dim=-1) 36 | loss = self.reduce_loss(-log_preds.sum(dim=-1)) 37 | nll = F.nll_loss( 38 | log_preds, target, reduction=self.reduction, weight=self.weight 39 | ) 40 | return self.linear_combination(loss / n, nll) 41 | 42 | def build_optim(args, model): 43 | if args.optim == 'SGD': 44 | optimizer = torch.optim.SGD( 45 | model.parameters(), 46 | args.learning_rate, 47 | momentum=args.momentum, 48 | weight_decay=args.weight_decay 49 | ) 50 | elif args.optim == 'Adam': 51 | optimizer = torch.optim.Adam( 52 | model.parameters(), 53 | lr=args.learning_rate 54 | ) 55 | elif args.optim == 'AdamW': 56 | optimizer = torch.optim.AdamW( 57 | model.parameters(), 58 | lr=args.learning_rate 59 | ) 60 | return optimizer 61 | # 62 | def build_scheduler(args, optimizer): 63 | if args.scheduler['name'] == 'cosin': 64 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 65 | optimizer, float(args.epochs-args.scheduler['warm_up_epochs']), eta_min=args.learning_rate_min) 66 | elif args.scheduler['name'] == 'ReduceLR': 67 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.1, 68 | patience=args.scheduler['patience'], verbose=True, 69 | threshold=0.0001, 70 | threshold_mode='rel', cooldown=3, min_lr=0.00001, 71 | eps=1e-08) 72 | else: 73 | raise NameError('build scheduler error!') 74 | 75 | if args.scheduler['warm_up_epochs'] > 0: 76 | warmup_schedule = lambda epoch: np.linspace(1e-8, args.learning_rate, args.scheduler['warm_up_epochs'])[epoch] 77 | return (scheduler, warmup_schedule) 78 | return (scheduler,) 79 | 80 | def build_loss(args): 81 | loss_Function=dict( 82 | CE_smooth = LabelSmoothingCrossEntropy(), 83 | CE = torch.nn.CrossEntropyLoss(), 84 | MSE = torch.nn.MSELoss(), 85 | BCE = torch.nn.BCELoss(), 86 | ) 87 | if args.loss['name'] == 'CE' and args.loss['labelsmooth']: 88 | return loss_Function['CE_smooth'] 89 | return loss_Function[args.loss['name']] 90 | -------------------------------------------------------------------------------- /utils/evaluate_metric.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (C) 2010-2021 Alibaba Group Holding Limited. 
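Evaluation and plotting helpers: standard-error estimate for the best score,
accuracy/loss curves, confusion matrix with per-class accuracy and precision bars,
a fast rank-pooling image generator, and grids of wrongly classified clips.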
3 | ''' 4 | 5 | # ------------------- 6 | # import modules 7 | # ------------------- 8 | import random, os 9 | import numpy as np 10 | import cv2 11 | import heapq 12 | import shutil 13 | from textwrap import wrap 14 | 15 | import matplotlib 16 | import matplotlib.pyplot as plt 17 | import matplotlib.image as mpimage 18 | 19 | from sklearn.model_selection import train_test_split 20 | from sklearn.metrics import confusion_matrix, auc, roc_curve, roc_auc_score 21 | import seaborn as sns 22 | from torchvision import transforms 23 | from PIL import Image 24 | import torch 25 | from torchvision.utils import save_image, make_grid 26 | # --------------------------------------- 27 | # Plot accuracy and loss 28 | # --------------------------------------- 29 | def get_error_bar(best_score, valid_examples): 30 | print("--------------------------------------------") 31 | print("Standard Error") # best_score: Average al of scores, valid_examples: num of all samples 32 | print("--------------------------------------------") 33 | 34 | err = np.sqrt((best_score * (1 - best_score)) / valid_examples) 35 | err_rounded_68 = round(err, 2) 36 | err_rounded_95 = round((err_rounded_68 * 2), 2) 37 | 38 | print('Error (68% CI): +- ' + str(err_rounded_68)) 39 | print('Error (95% CI): +- ' + str(err_rounded_95)) 40 | print() 41 | return err_rounded_68 42 | 43 | def plot_train_results(PREDICTIONS_PATH, train_results, idx): 44 | ''' 45 | 46 | :param PREDICTIONS_PATH: plot image save path 47 | :param train_results: {'valid_score':[...], 'valid_loss':[...], 'train_score': [...], 'train_loss':[...]} 48 | :param best_score: validation best acc 49 | :param idx: epoch index 50 | :return: None 51 | ''' 52 | 53 | # best_score = sum(train_results['valid_score']) / len(train_results['valid_score']) 54 | valid_examples = len(train_results['valid_score']) 55 | super_category = str(idx) 56 | 57 | best_score = train_results["best_score"] 58 | standard_error = get_error_bar(best_score, valid_examples) 59 | y_upper = train_results["valid_score"] + standard_error 60 | y_lower = train_results["valid_score"] - standard_error 61 | 62 | print("--------------------------------------------") 63 | print("Results") 64 | print("--------------------------------------------") 65 | 66 | fig = plt.figure(figsize=(15, 5)) 67 | 68 | plt.subplot(1, 2, 1) 69 | plt.plot(range(0, len(train_results["train_score"])), train_results["train_score"], label='train') 70 | 71 | plt.plot(range(0, len(train_results["valid_score"])), train_results["valid_score"], label='valid') 72 | 73 | kwargs = {'color': 'black', 'linewidth': 1, 'linestyle': '--', 'dashes': (5, 5)} 74 | plt.plot(range(0, len(train_results["valid_score"])), y_lower, **kwargs) 75 | plt.plot(range(0, len(train_results["valid_score"])), y_upper, **kwargs, label='validation SE (68% CI)') 76 | 77 | plt.title('Accuracy Plot - ' + super_category, fontsize=20) 78 | plt.ylabel('Accuracy', fontsize=16) 79 | plt.xlabel('Training Epochs', fontsize=16) 80 | plt.ylim(0, 1) 81 | plt.legend() 82 | 83 | plt.subplot(1, 2, 2) 84 | plt.plot(range(0, len(train_results["train_loss"])), train_results["train_loss"], label='train') 85 | plt.plot(range(0, len(train_results["valid_loss"])), train_results["valid_loss"], label='valid') 86 | 87 | plt.title('Loss Plot - ' + super_category, fontsize=20) 88 | plt.ylabel('Loss', fontsize=16) 89 | plt.xlabel('Training Epochs', fontsize=16) 90 | max_train_loss = max(train_results["train_loss"]) 91 | max_valid_loss = max(train_results["valid_loss"]) 92 | y_max_t_v = max_valid_loss if 
max_valid_loss > max_train_loss else max_train_loss 93 | ylim_loss = y_max_t_v if y_max_t_v > 1 else 1 94 | plt.ylim(0, ylim_loss) 95 | plt.legend() 96 | 97 | plt.show() 98 | 99 | fig.savefig(os.path.join(PREDICTIONS_PATH, "train_results_{}.png".format(idx)), dpi=fig.dpi) 100 | 101 | 102 | # --------------------------------------- 103 | # Plot Confusion Matrix 104 | # --------------------------------------- 105 | def plot_confusion_matrix(PREDICTIONS_PATH, grounds, preds, categories, idx, top=20): 106 | print("--------------------------------------------") 107 | print("Confusion Matrix") 108 | print("--------------------------------------------") 109 | 110 | super_category = str(idx) 111 | num_cat = [] 112 | for ind, cat in enumerate(categories): 113 | print("Class {0} : {1}".format(ind, cat)) 114 | num_cat.append(ind) 115 | print() 116 | numclass = len(num_cat) 117 | 118 | cm = confusion_matrix(grounds, preds, labels=num_cat) 119 | fig = plt.figure(figsize=(10, 8)) 120 | ax = fig.add_subplot() 121 | sns.heatmap(cm, annot=False, fmt='g', ax=ax); # annot=True to annotate cells, ftm='g' to disable scientific notation 122 | 123 | # labels, title and ticks 124 | ax.set_title('Confusion Matrix - ' + super_category, fontsize=20) 125 | ax.set_xlabel('Predicted labels', fontsize=16) 126 | ax.set_ylabel('True labels', fontsize=16) 127 | 128 | ax.set_xticks(range(0,len(num_cat), 1)) 129 | ax.set_yticks(range(0,len(num_cat), 1)) 130 | ax.xaxis.set_ticklabels(num_cat) 131 | ax.yaxis.set_ticklabels(num_cat) 132 | 133 | plt.pause(0.1) 134 | fig.savefig(os.path.join(PREDICTIONS_PATH, "confusion_matrix"), dpi=fig.dpi) 135 | 136 | # ------------------------------------------------- 137 | # Plot Accuracy and Precision 138 | # ------------------------------------------------- 139 | Accuracy = [(cm[i, i] / sum(cm[i, :])) * 100 if sum(cm[i, :]) != 0 else 0.000001 for i in range(cm.shape[0])] 140 | Precision = [(cm[i, i] / sum(cm[:, i])) * 100 if sum(cm[:, i]) != 0 else 0.000001 for i in range(cm.shape[1])] 141 | 142 | fig = plt.figure(figsize=(int((numclass*3)%300), 8)) 143 | ax = fig.add_subplot() 144 | 145 | bar_width = 0.4 146 | x = np.arange(len(Accuracy)) 147 | b1 = ax.bar(x, Accuracy, width=bar_width, label='Accuracy', color=sns.xkcd_rgb["pale red"], tick_label=x) 148 | 149 | ax2 = ax.twinx() 150 | b2 = ax2.bar(x + bar_width, Precision, width=bar_width, label='Precision', color=sns.xkcd_rgb["denim blue"]) 151 | 152 | average_acc = sum(Accuracy)/len(Accuracy) 153 | average_prec = sum(Precision)/len(Precision) 154 | b3 = plt.hlines(y=average_acc, xmin=-bar_width, xmax=numclass - 1 + bar_width * 2, linewidth=2, linestyles='--', color='r', 155 | label='Average Acc : %0.2f' % average_acc) 156 | b4 = plt.hlines(y=average_prec, xmin=-bar_width, xmax=numclass - 1 + bar_width * 2, linewidth=2, linestyles='--', color='b', 157 | label='Average Prec : %0.2f' % average_prec) 158 | plt.xticks(np.arange(numclass) + bar_width / 2, np.arange(numclass)) 159 | 160 | # labels, title and ticks 161 | ax.set_title('Accuracy and Precision Epoch #{}'.format(idx), fontsize=20) 162 | ax.set_xlabel('labels', fontsize=16) 163 | ax.set_ylabel('Acc(%)', fontsize=16) 164 | ax2.set_ylabel('Prec(%)', fontsize=16) 165 | ax.set_xticklabels(ax.get_xticklabels(), rotation=45) 166 | 167 | ax.tick_params(axis='y', colors=b1[0].get_facecolor()) 168 | ax2.tick_params(axis='y', colors=b2[0].get_facecolor()) 169 | 170 | plt.legend(handles=[b1, b2, b3, b4]) 171 | # fig.savefig(os.path.join(PREDICTIONS_PATH, "Accuracy-Precision_{}.png".format(idx)), 
dpi=fig.dpi) 172 | fig.savefig(os.path.join(PREDICTIONS_PATH, "Accuracy-Precision.png"), dpi=fig.dpi) 173 | 174 | plt.close() 175 | 176 | TopK_idx_acc = heapq.nlargest(top, range(len(Accuracy)), Accuracy.__getitem__) 177 | TopK_idx_prec = heapq.nlargest(top, range(len(Precision)), Precision.__getitem__) 178 | 179 | TopK_low_idx = heapq.nsmallest(top, range(len(Precision)), Precision.__getitem__) 180 | 181 | 182 | print('=' * 80) 183 | print('Accuracy Tok {0}: \n'.format(top)) 184 | print('| Class ID \t Accuracy(%) \t Precision(%) |') 185 | for i in TopK_idx_acc: 186 | print('| {0} \t {1} \t {2} |'.format(i, round(Accuracy[i], 2), round(Precision[i], 2))) 187 | print('-' * 80) 188 | print('Precision Tok {0}: \n'.format(top)) 189 | print('| Class ID \t Accuracy(%) \t Precision(%) |') 190 | for i in TopK_idx_prec: 191 | print('| {0} \t {1} \t {2} |'.format(i, round(Accuracy[i], 2), round(Precision[i], 2))) 192 | print('=' * 80) 193 | 194 | return TopK_low_idx 195 | 196 | 197 | # Fast Rank Pooling 198 | sample_size = 128 199 | def GenerateRPImage(imgs_path, sl): 200 | def get_DDI(video_arr): 201 | def get_w(N): 202 | return [float(i) * 2 - N - 1 for i in range(1, N + 1)] 203 | 204 | w_arr = get_w(len(video_arr)) 205 | re = np.zeros((sample_size, sample_size, 3)) 206 | for a, b in zip(video_arr, w_arr): 207 | img = cv2.imread(os.path.join(imgs_path, "%06d.jpg" % a)) 208 | img = cv2.resize(img, (sample_size, sample_size)) 209 | re += img * b 210 | re -= np.min(re) 211 | re = 255.0 * re / np.max(re) if np.max(re) != 0 else 255.0 * re / (np.max(re) + 0.00001) 212 | 213 | return re.astype('uint8') 214 | 215 | return get_DDI(sl) 216 | 217 | # --------------------------------------- 218 | # Wrongly Classified Images 219 | # --------------------------------------- 220 | def plot_wrongly_classified_images(PREDICTIONS_PATH, TopK_low_idx, valid_images, idx): 221 | print("--------------------------------------------") 222 | print("Wrongly Classified Images") 223 | print("--------------------------------------------") 224 | 225 | v_paths, grounds, preds = valid_images 226 | f = lambda n, sn: [(lambda n, arr: n if arr == [] else int(np.mean(arr)))(n * i / sn, range(int(n * i / sn), 227 | max(int( 228 | n * i / sn) + 1, 229 | int(n * ( 230 | i + 1) / sn)))) 231 | for i in range(sn)] 232 | 233 | train_images = [] 234 | ground, pred, pred_lbl_file = [], [], [] 235 | for g, p, v in zip(grounds, preds, v_paths): 236 | assert p != g, 'Pred: {} equ to ground-truth: {}'.format(p, g) 237 | if g in TopK_low_idx[:10]: 238 | imgs = [transforms.ToTensor()(Image.open(os.path.join(v, "%06d.jpg" % a)).resize((200, 200))).unsqueeze(0) for a in f(len(os.listdir(v))//2, 10)] 239 | train_images.append(make_grid(torch.cat(imgs), nrow=10, padding=2).permute(1, 2, 0)) 240 | ground.append(g) 241 | pred.append(p) 242 | pred_lbl_file.append(v) 243 | if len(train_images) > 9: 244 | break 245 | 246 | fig = plt.figure(figsize=(30, 20)) 247 | k = 0 248 | for i in range(0, len(train_images)): 249 | fig.add_subplot(10, 1, k + 1) 250 | plt.axis('off') 251 | if i == 0: 252 | title = "Orig lbl: " + str(ground[i]) + " Pred lbl: " + str(pred[i]) + " " + pred_lbl_file[i] 253 | else: 254 | title = '\n'*10 + "Orig lbl: " + str(ground[i]) + " Pred lbl: " + str(pred[i]) + " " + pred_lbl_file[i] 255 | plt.title(title) 256 | plt.imshow(train_images[i]) 257 | k += 1 258 | 259 | plt.pause(0.1) 260 | print() 261 | fig.savefig(os.path.join(PREDICTIONS_PATH, "wrongly_classified_images.png".format(idx)), dpi=fig.dpi) 262 | plt.close() 263 | 264 | def 
EvaluateMetric(PREDICTIONS_PATH, train_results, idx): 265 | TopK_low_idx = plot_confusion_matrix(PREDICTIONS_PATH, train_results['grounds'], train_results['preds'], train_results['categories'], idx) 266 | -------------------------------------------------------------------------------- /utils/print_function.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (C) 2010-2021 Alibaba Group Holding Limited. 3 | ''' 4 | 5 | import logging 6 | 7 | def print_func(info): 8 | ''' 9 | :param info: {name: value} 10 | :return: 11 | ''' 12 | txts = [] 13 | for name, value in info.items(): 14 | txts.append('{}: {}'.format(name, value)) 15 | logging.info('\t'.join(txts)) -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This file is modified from: 3 | https://github.com/yuhuixu1993/PC-DARTS/blob/master/utils.py 4 | ''' 5 | 6 | import os 7 | import numpy as np 8 | import torch 9 | import shutil 10 | import torchvision.transforms as transforms 11 | from torch.autograd import Variable 12 | from collections import OrderedDict 13 | 14 | class ClassAcc(): 15 | def __init__(self, GESTURE_CLASSES): 16 | self.class_acc = dict(zip([i for i in range(GESTURE_CLASSES)], [0]*GESTURE_CLASSES)) 17 | self.single_class_num = [0]*GESTURE_CLASSES 18 | def update(self, logits, target): 19 | pred = torch.argmax(logits, dim=1) 20 | for p, t in zip(pred.cpu().numpy(), target.cpu().numpy()): 21 | if p == t: 22 | self.class_acc[t] += 1 23 | self.single_class_num[t] += 1 24 | def result(self): 25 | return [round(v / (self.single_class_num[k]+0.000000001), 4) for k, v in self.class_acc.items()] 26 | class AverageMeter(object): 27 | 28 | def __init__(self): 29 | self.reset() 30 | 31 | def reset(self): 32 | self.avg = 0 33 | self.sum = 0 34 | self.cnt = 0 35 | 36 | def update(self, val, n=1): 37 | self.sum += val * n 38 | self.cnt += n 39 | self.avg = self.sum / self.cnt 40 | 41 | def adjust_learning_rate(optimizer, step, lr): 42 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" 43 | df = 0.7 44 | ds = 40000.0 45 | lr = lr * np.power(df, step / ds) 46 | # lr = args.lr * (0.1**(epoch // 30)) 47 | for param_group in optimizer.param_groups: 48 | param_group['lr'] = lr 49 | return lr 50 | 51 | def accuracy(output, target, topk=(1,)): 52 | maxk = max(topk) 53 | batch_size = target.size(0) 54 | 55 | _, pred = output.topk(maxk, 1, True, True) 56 | pred = pred.t() 57 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 58 | 59 | res = [] 60 | for k in topk: 61 | correct_k = correct[:k].view(-1).float().sum(0) 62 | res.append(correct_k.mul_(100.0/batch_size)) 63 | return res 64 | 65 | def calculate_accuracy(outputs, targets): 66 | with torch.no_grad(): 67 | batch_size = targets.size(0) 68 | _, pred = outputs.topk(1, 1, True) 69 | pred = pred.t() 70 | correct = pred.eq(targets.view(1, -1)) 71 | correct_k = correct.view(-1).float().sum(0, keepdim=True) 72 | #n_correct_elems = correct.float().sum().data[0] 73 | # n_correct_elems = correct.float().sum().item() 74 | # return n_correct_elems / batch_size 75 | return correct_k.mul_(1.0 / batch_size) 76 | 77 | def count_parameters_in_MB(model): 78 | return np.sum(np.prod(v.size()) for name, v in model.named_parameters() if "auxiliary" not in name)/1e6 79 | 80 | 81 | def save_checkpoint(state, is_best=False, save='./', filename='checkpoint.pth.tar'): 82 | filename = 
os.path.join(save, filename) 83 | torch.save(state, filename) 84 | if is_best: 85 | best_filename = os.path.join(save, 'model_best.pth.tar') 86 | shutil.copyfile(filename, best_filename) 87 | 88 | def load_checkpoint(model, model_path, optimizer=None): 89 | # checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage.cuda(4)) 90 | checkpoint = torch.load(model_path, map_location='cpu') 91 | model.load_state_dict(checkpoint['model']) 92 | if optimizer: 93 | optimizer.load_state_dict(checkpoint['optimizer']) 94 | epoch = checkpoint['epoch'] 95 | bestacc = checkpoint['bestacc'] 96 | return model, optimizer, epoch, bestacc 97 | 98 | def load_pretrained_checkpoint(model, model_path): 99 | # params = torch.load(model_path, map_location=lambda storage, loc: storage.cuda(local_rank))['model'] 100 | params = torch.load(model_path, map_location='cpu')['model'] 101 | new_state_dict = OrderedDict() 102 | 103 | for k, v in params.items(): 104 | name = k[7:] if k[:7] == 'module.' else k 105 | try: 106 | if v.shape == model.state_dict()[name].shape: 107 | if name not in ['dtn.mlp_head_small.1.bias', "dtn.mlp_head_small.1.weight", 108 | 'dtn.mlp_head_media.1.bias', "dtn.mlp_head_media.1.weight", 109 | 'dtn.mlp_head_large.1.bias', "dtn.mlp_head_large.1.weight"]: 110 | new_state_dict[name] = v 111 | except: 112 | continue 113 | ret = model.load_state_dict(new_state_dict, strict=False) 114 | print('Missing keys: \n', ret.missing_keys) 115 | return model 116 | 117 | def drop_path(x, drop_prob): 118 | if drop_prob > 0.: 119 | keep_prob = 1.-drop_prob 120 | mask = Variable(torch.cuda.FloatTensor(x.size(0), 1, 1, 1).bernoulli_(keep_prob)) 121 | x.div_(keep_prob) 122 | x.mul_(mask) 123 | return x 124 | 125 | 126 | def create_exp_dir(path, scripts_to_save=None): 127 | if not os.path.exists(path): 128 | os.mkdir(path) 129 | print('Experiment dir : {}'.format(path)) 130 | 131 | if scripts_to_save is not None: 132 | os.mkdir(os.path.join(path, 'scripts')) 133 | for script in scripts_to_save: 134 | if os.path.isdir(script) and script != '__pycache__': 135 | dst_file = os.path.join(path, 'scripts', script) 136 | shutil.copytree(script, dst_file) 137 | else: 138 | dst_file = os.path.join(path, 'scripts', os.path.basename(script)) 139 | shutil.copyfile(script, dst_file) 140 | 141 | def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0): 142 | warmup_schedule = np.array([]) 143 | warmup_iters = warmup_epochs * niter_per_ep 144 | if warmup_epochs > 0: 145 | warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters) 146 | 147 | iters = np.arange(epochs * niter_per_ep - warmup_iters) 148 | schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters))) 149 | 150 | schedule = np.concatenate((warmup_schedule, schedule)) 151 | assert len(schedule) == epochs * niter_per_ep 152 | return schedule -------------------------------------------------------------------------------- /utils/visualizer.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This file is modified from: 3 | https://github.com/zhoubenjia/RAAR3DNet/blob/master/Network_Train/utils/visualizer.py 4 | ''' 5 | 6 | 7 | #coding: utf8 8 | 9 | import numpy as np 10 | import time 11 | 12 | 13 | class Visualizer(): 14 | def __init__(self, env='default', **kwargs): 15 | import visdom 16 | self.vis = visdom.Visdom(env=env, use_incoming_socket=False, **kwargs) 17 | 18 | self.index = {} 19 | self.log_text = '' 
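        # self.index keeps one x-axis counter per window name (used by plot/plot_many);
        # self.log_text accumulates the text shown in the Visdom 'log_text' window.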
20 | 21 | def reinit(self, env='defult', **kwargs): 22 | self.vis = visdom.Visdom(env=env, use_incoming_socket=False, **kwargs) 23 | return self 24 | 25 | def plot_many(self, d, modality, epoch=None): 26 | colmu_stac = [] 27 | for k, v in d.items(): 28 | colmu_stac.append(np.array(v)) 29 | if epoch: 30 | x = epoch 31 | else: 32 | x = self.index.get(modality, 0) 33 | # self.vis.line(Y=np.column_stack((np.array(dicts['loss1']), np.array(dicts['loss2']))), 34 | self.vis.line(Y=np.column_stack(tuple(colmu_stac)), 35 | X=np.array([x]), 36 | win=(modality), 37 | # opts=dict(title=modality,legend=['loss1', 'loss2'], ylabel='loss value'), 38 | opts=dict(title=modality, legend=list(d.keys()), ylabel='Value', xlabel='Iteration'), 39 | update=None if x == 0 else 'append') 40 | if not epoch: 41 | self.index[modality] = x + 1 42 | 43 | def plot(self, name, y): 44 | """ 45 | self.plot('loss',1.00) 46 | """ 47 | x = self.index.get(name, 0) 48 | self.vis.line(Y=np.array([y]), X=np.array([x]), 49 | win=(name), 50 | opts=dict(title=name), 51 | update=None if x == 0 else 'append' 52 | ) 53 | self.index[name] = x + 1 54 | 55 | def log(self, info, win='log_text'): 56 | """ 57 | self.log({'loss':1,'lr':0.0001}) 58 | """ 59 | 60 | self.log_text += ('[{time}] {info}
'.format( 61 | time=time.strftime('%m.%d %H:%M:%S'), 62 | info=info)) 63 | self.vis.text(self.log_text, win=win) 64 | 65 | def img_grid(self, name, input_3d, heatmap=False): 66 | self.vis.images( 67 | # np.random.randn(20, 3, 64, 64), 68 | show_image_grid(input_3d, name, heatmap), 69 | win=name, 70 | opts=dict(title=name, caption='img_grid.') 71 | ) 72 | def img(self, name, input): 73 | self.vis.images( 74 | input, 75 | win=name, 76 | opts=dict(title=name, caption='RGB Images.') 77 | ) 78 | 79 | def draw_curve(self, name, data): 80 | self.vis.line(Y=np.array(data), X=np.array(range(len(data))), 81 | win=(name), 82 | opts=dict(title=name), 83 | update=None 84 | ) 85 | 86 | def featuremap(self, name, input): 87 | self.vis.heatmap(input, win=name, opts=dict(title=name)) 88 | 89 | def draw_bar(self, name, inp): 90 | self.vis.bar( 91 | X=np.abs(np.array(inp)), 92 | win=name, 93 | opts=dict( 94 | stacked=True, 95 | legend=list(map(str, range(inp.shape[-1]))), 96 | rownames=list(map(str, range(inp.shape[0]))) 97 | ) 98 | ) 99 | 100 | 101 | def __getattr__(self, name): 102 | return getattr(self.vis, name) 103 | --------------------------------------------------------------------------------