├── .gitignore
├── Checkpoints
│   └── readme.md
├── LICENSE
├── README.md
├── config
│   ├── IsoGD.yml
│   ├── Jester.yml
│   ├── NTU.yml
│   ├── NetworkConfig.yml
│   ├── NvGesture.yml
│   ├── THU.yml
│   ├── __init__.py
│   └── config.py
├── data
│   └── data_preprose.py
├── demo
│   ├── decouple_recouple.jpg
│   ├── pipline.jpg
│   └── readme.md
├── lib
│   ├── __init__.py
│   ├── datasets
│   │   ├── IsoGD.py
│   │   ├── Jester.py
│   │   ├── NTU.py
│   │   ├── NvGesture.py
│   │   ├── THU_READ.py
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── build.py
│   │   └── distributed_sampler.py
│   └── model
│       ├── DSN.py
│       ├── DSN_Fusion.py
│       ├── DTN.py
│       ├── FRP.py
│       ├── __init__.py
│       ├── build.py
│       ├── fusion_Net.py
│       ├── trans_module.py
│       └── utils.py
├── run.sh
├── tools
│   ├── fusion.py
│   ├── readme.md
│   └── train.py
└── utils
    ├── __init__.py
    ├── build.py
    ├── evaluate_metric.py
    ├── print_function.py
    ├── utils.py
    └── visualizer.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.swp
2 | **/__pycache__/**
3 | .dumbo.json
4 | Checkpoints/
5 | out/
6 | demo/
7 | *.tar
8 | core*
9 | bk/
10 | *.ipynb
11 | *.ipynb*
--------------------------------------------------------------------------------
/Checkpoints/readme.md:
--------------------------------------------------------------------------------
1 | This folder is necessary because it is used to save all training logs and models.
2 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 DamoCV
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # [CVPR2022](https://openaccess.thecvf.com/content/CVPR2022/html/Zhou_Decoupling_and_Recoupling_Spatiotemporal_Representation_for_RGB-D-Based_Motion_Recognition_CVPR_2022_paper.html) Decoupling and Recoupling Spatiotemporal Representation for RGB-D-based Motion Recognition
2 |
3 | [PapersWithCode: SOTA on NvGesture](https://paperswithcode.com/sota/hand-gesture-recognition-on-nvgesture-1?p=decoupling-and-recoupling-spatiotemporal)
4 |
5 | This repo is the official implementation of "Decoupling and Recoupling Spatiotemporal Representation for RGB-D-based Motion Recognition" as well as its follow-ups. It currently includes code and models for the following tasks:
6 | > **RGB-D-based Action Recognition**: Included in this repo.
7 |
8 | > **RGB-D-based Gesture Recognition**: Included in this repo.
9 |
10 | >**Dynamic motion attention capture based on native video frames**: Included in this repo. See the FRP module in the paper.
11 |
12 | ## Updates
13 | ***27/07/2023***
14 | 1. Added a link to the journal extension [UMDR-Net](https://github.com/zhoubenjia/MotionRGBD-PAMI) (TPAMI'23) of this conference paper.
15 | 2. Added a link to its improved version [MFST](https://arxiv.org/pdf/2308.12006.pdf) (MM'23).
16 |
17 | ***27/10/2022***
18 | 1. Updated the NTU data preprocessing code.
19 | 2. Fixed a bug in the DTN.
20 |
21 | ***18/10/2022***
22 | 1. Updated the NvGesture training code.
23 |
24 | ## 1. Requirements
25 | This is a PyTorch implementation of our paper.
26 |
27 | torch>=1.7.0; torchvision>=0.8.0; Visdom(optional)
28 |
29 | Data preparation: organize the dataset with the following folder structure:
30 |
31 | ```
32 | │NTURGBD/
33 | ├──dataset_splits/
34 | │ ├── @CS
35 | │ │ ├── train.txt
36 | video name total frames label
37 | │ │ │ ├──S001C001P001R001A001_rgb 103 0
38 | │ │ │ ├──S001C001P001R001A004_rgb 99 3
39 | │ │ │ ├──......
40 | │ │ ├── valid.txt
41 | │ ├── @CV
42 | │ │ ├── train.txt
43 | │ │ ├── valid.txt
44 | ├──Images/
45 | │ │ ├── S001C002P001R001A002_rgb
46 | │ │ │ ├──000000.jpg
47 | │ │ │ ├──000001.jpg
48 | │ │ │ ├──......
49 | ├──nturgb+d_depth_masked/
50 | │ │ ├── S001C002P001R001A002
51 | │ │ │ ├──MDepth-00000000.png
52 | │ │ │ ├──MDepth-00000001.png
53 | │ │ │ ├──......
54 | ```
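A minimal sketch of how one annotation line in `train.txt` / `valid.txt` (video name, total frames, label) can be read; the helper name here is hypothetical, but the three-field format follows the split files shown above and the loader in `lib/datasets/base.py`:

```python
def parse_annotation(line: str):
    """Parse one split-file line: '<video name> <total frames> <label>'."""
    name, num_frames, label = line.strip().split()
    return name, int(num_frames), int(label)

name, num_frames, label = parse_annotation("S001C001P001R001A001_rgb 103 0")
# name == "S001C001P001R001A001_rgb", num_frames == 103, label == 0
```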
55 | Note that, because the RGB video resolution in the NTU dataset is relatively high, we do not directly resize the frames from the original resolution to 320x240; instead, we first crop an object-centered ROI (640x480) and then resize it to 320x240 for training and testing.
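Below is a minimal sketch of this crop-then-resize step, assuming the object-centered ROI center has already been estimated (in this repo it is derived from the masked depth maps; see `data/data_preprose.py`). The function name and the example center are illustrative only:

```python
from PIL import Image

def crop_roi_and_resize(frame: Image.Image, center, roi=(640, 480), out_size=(320, 240)):
    """Crop an object-centered ROI from a high-resolution NTU frame, then resize it."""
    cx, cy = center
    w, h = roi
    left = max(cx - w // 2, 0)
    top = max(cy - h // 2, 0)
    right = min(left + w, frame.width)
    bottom = min(top + h, frame.height)
    return frame.crop((left, top, right, bottom)).resize(out_size)

# e.g. frame = Image.open('Images/S001C001P001R001A001_rgb/000000.jpg')
# small = crop_roi_and_resize(frame, center=(960, 540))
```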
56 |
57 | ## 2. Methodology
58 |
59 | ![Framework overview](demo/pipline.jpg)
60 |
61 | ![Decoupling and recoupling of unimodal data](demo/decouple_recouple.jpg)
62 | We propose to decouple and recouple spatiotemporal representation for RGB-D-based motion recognition. The first figure illustrates the proposed multi-modal spatiotemporal representation learning framework: RGB-D-based motion recognition is described as spatiotemporal information decoupling modeling, compact representation recoupling learning, and cross-modal representation interactive learning.
63 | The second figure shows the process of decoupling and recoupling the spatiotemporal representation of unimodal data.
64 |
65 | ## 3. Train and Evaluate
66 | All of our models are pre-trained on the [20BN Jester V1 dataset](https://www.kaggle.com/toxicmender/20bn-jester), and the pretrained models can be downloaded [here](https://drive.google.com/drive/folders/1eBXED3uXlzBZzix7TvtDlJrZ3SlDCSF6?usp=sharing). Before cross-modal representation interactive learning, we first perform unimodal representation learning separately on the RGB and depth modalities.
67 | ### Unimodal Training
68 | Take training an RGB model with 8 GPUs on the NTU-RGBD dataset as an example;
69 | the basic configuration is:
70 | ```yaml
71 | common:
72 | dataset: NTU
73 | batch_size: 6
74 | test_batch_size: 6
75 | num_workers: 6
76 | learning_rate: 0.01
77 | learning_rate_min: 0.00001
78 | momentum: 0.9
79 | weight_decay: 0.0003
80 | init_epochs: 0
81 | epochs: 100
82 | optim: SGD
83 | scheduler:
84 |     name: cosin # Decay the learning rate with a cosine schedule
85 | warm_up_epochs: 3
86 | loss:
87 | name: CE # cross entropy loss function
88 | labelsmooth: True
89 | MultiLoss: True # Enable multi-loss training strategy.
90 | loss_lamdb: [ 1, 0.5, 0.5, 0.5 ] # The loss weight coefficient assigned for each sub-branch.
91 | distill: 1. # The loss weight coefficient assigned for distillation task.
92 |
93 | model:
94 |   Network: I3DWTrans # I3DWTrans denotes unimodal training; set FusionNet for multi-modal fusion training.
95 | sample_duration: 64 # Sampled frames in a video.
96 |   sample_size: 224 # The image is cropped to 224x224.
97 | grad_clip: 5.
98 | SYNC_BN: 1 # Utilize SyncBatchNorm.
99 | w: 10 # Sliding window size.
100 | temper: 0.5 # Distillation temperature setting.
101 | recoupling: True # Enable recoupling strategy during training.
102 | knn_attention: 0.7 # Hyperparameter used in k-NN attention: selecting Top-70% tokens.
103 | sharpness: True # Enable sharpness for each sub-branch's output.
104 | temp: [ 0.04, 0.07 ] # Temperature parameter follows a cosine schedule from 0.04 to 0.07 during the training.
105 | frp: True # Enable FRP module.
106 | SEHeads: 1 # Number of heads used in RCM module.
107 |   N: 6 # Number of Transformer blocks configured for each sub-branch.
108 |
109 | dataset:
110 | type: M # M: RGB modality, K: Depth modality.
111 | flip: 0.5 # Horizontal flip.
112 |   rotated: 0.5 # Random rotation probability.
113 | angle: (-10, 10) # Rotation angle
114 | Blur: False # Enable random blur operation for each video frame.
115 | resize: (320, 240) # The input is spatially resized to 320x240 for NTU dataset.
116 | crop_size: 224
117 | low_frames: 16 # Number of frames sampled for small Transformer.
118 | media_frames: 32 # Number of frames sampled for medium Transformer.
119 | high_frames: 48 # Number of frames sampled for large Transformer.
120 | ```
121 |
122 | ```bash
123 | bash run.sh tools/train.py config/NTU.yml 0,1,2,3,4,5,6,7 8
124 | ```
125 | or
126 | ```bash
127 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --nproc_per_node=8 --master_port=1234 train.py --config config/NTU.yml --nprocs 8
128 | ```
129 |
130 | ### Cross-modal Representation Interactive Learning
131 | Take training a fusion model with 8 GPUs on the NTU-RGBD dataset as an example.
132 | ```bash
133 | bash run.sh tools/fusion.py config/NTU.yml 0,1,2,3,4,5,6,7 8
134 | ```
135 | or
136 | ```bash
137 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --nproc_per_node=8 --master_port=1234 tools/fusion.py --config config/NTU.yml --nprocs 8
138 | ```
139 |
140 | ### Evaluation
141 | ```bash
142 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --master_port=1234 tools/train.py --config config/NTU.yml --nprocs 1 --eval_only --resume /path/to/model_best.pth.tar
143 | ```
144 |
145 | ## 4. Models Download
146 |
147 |
148 | | Dataset | Modality | Accuracy | Download |
149 | | :-----: | :------: | :------: | :------: |
150 | | NvGesture | RGB | 89.58 | Google Drive |
151 | | NvGesture | Depth | 90.62 | Google Drive |
152 | | NvGesture | RGB-D | 91.70 | Google Drive |
153 | | THU-READ | RGB | 81.25 | Google Drive |
154 | | THU-READ | Depth | 77.92 | Google Drive |
155 | | THU-READ | RGB-D | 87.04 | Google Drive |
156 | | NTU-RGBD(CS) | RGB | 90.3 | Google Drive |
157 | | NTU-RGBD(CS) | Depth | 92.7 | Google Drive |
158 | | NTU-RGBD(CS) | RGB-D | 94.2 | Google Drive |
159 | | NTU-RGBD(CV) | RGB | 95.4 | Google Drive |
160 | | NTU-RGBD(CV) | Depth | 96.2 | Google Drive |
161 | | NTU-RGBD(CV) | RGB-D | 97.3 | Google Drive |
162 | | IsoGD | RGB | 60.87 | Google Drive |
163 | | IsoGD | Depth | 60.17 | Google Drive |
164 | | IsoGD | RGB-D | 66.79 | Google Drive |
246 |
247 |
248 |
249 | # Citation
250 | ```
251 | @InProceedings{Zhou_2022_CVPR,
252 | author = {Zhou, Benjia and Wang, Pichao and Wan, Jun and Liang, Yanyan and Wang, Fan and Zhang, Du and Lei, Zhen and Li, Hao and Jin, Rong},
253 | title = {Decoupling and Recoupling Spatiotemporal Representation for RGB-D-Based Motion Recognition},
254 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
255 | month = {June},
256 | year = {2022},
257 | pages = {20154-20163}
258 | }
259 | ```
260 | # LICENSE
261 | The code is released under the MIT license.
262 | # Copyright
263 | Copyright (C) 2010-2021 Alibaba Group Holding Limited.
264 |
--------------------------------------------------------------------------------
/config/IsoGD.yml:
--------------------------------------------------------------------------------
1 | common:
2 | data: /path/to/IsoGD/Dataset
3 | splits: /path/to/IsoGD/Dataset/dataset_splits
4 |
5 |   #-------basic Hyperparameters----------
6 | visdom:
7 | enable: False
8 | visname: IsoGD
9 |
10 | dataset: IsoGD #Database name e.g., NTU, THUREAD ...
11 | batch_size: 6
12 | test_batch_size: 6
13 | num_workers: 6
14 | learning_rate: 0.01
15 | learning_rate_min: 0.00001
16 | momentum: 0.9
17 | weight_decay: 0.0003
18 | init_epochs: 0
19 | epochs: 300
20 | report_freq: 10
21 | optim: SGD
22 | dist: True
23 | vis_feature: True # Feature Visualization?
24 | DEBUG: False
25 |
26 | scheduler:
27 | name: ReduceLR
28 | patience: 4
29 | warm_up_epochs: 3
30 | loss:
31 | name: CE
32 | labelsmooth: True
33 | MultiLoss: True
34 | loss_lamdb: [ 1, 0.5, 0.5, 0.5 ]
35 | distill: 1.
36 | resume_scheduler: 0
37 | model:
38 | Network: I3DWTrans # e.g., I3DWTrans or FusionNet
39 | pretrained: ''
40 | # resume: ''
41 | resumelr: 0.0001
42 | sample_duration: 64
43 | sample_size: 224
44 | grad_clip: 5.
45 | SYNC_BN: 1
46 | w: 10
47 | temper: 0.5
48 | recoupling: True
49 | knn_attention: 0.7
50 | sharpness: True
51 | temp: [ 0.04, 0.07 ]
52 | frp: True
53 | SEHeads: 1
54 | N: 6 # Number of Transformer Blocks
55 | #-------Used for fusion network----------
56 | rgb_checkpoint: ''
57 | depth_checkpoint: ''
58 | dataset:
59 | type: M # M: rgb, K: depth
60 | flip: 0.0
61 | rotated: 0.5
62 | angle: (-10, 10) # Rotation angle
63 | Blur: False
64 | resize: (256, 256)
65 | crop_size: 224
66 | low_frames: 16
67 | media_frames: 32
68 | high_frames: 48
69 |
--------------------------------------------------------------------------------
/config/Jester.yml:
--------------------------------------------------------------------------------
1 | common:
2 | data: /media/ssd1/bjzhou/ssd2/bjzhou/Jester/20bn-jester-v1
3 | splits: /media/ssd1/bjzhou/ssd2/bjzhou/Jester/dataset_splits
4 |
5 |   #-------basic Hyperparameters----------
6 | visname: Jester
7 | dataset: Jester
8 | batch_size: 2
9 | test_batch_size: 2
10 | num_workers: 4
11 | learning_rate: 0.01
12 | learning_rate_min: 0.00001
13 | momentum: 0.9
14 | weight_decay: 0.0003
15 | init_epochs: 0
16 | epochs: 100
17 | report_freq: 10
18 | metric_freq: 10
19 | show_cluster_result: 100
20 | optim: SGD
21 | dist: True
22 | vis_feature: True # Visualization?
23 |
24 | scheduler:
25 | name: cosin
26 | patience: 4
27 | warm_up_epochs: 0
28 | loss:
29 | name: CE
30 | labelsmooth: True
31 | mse_weight: 10.0
32 | MultiLoss: True
33 | loss_lamdb: [ 1, 0.5, 0.5, 0.5 ]
34 | distill_lamdb: 1.
35 |
36 | model:
37 | Network: I3DWTrans
38 | pretrained: ''
39 | # resume: ''
40 | resumelr: False
41 | sample_duration: 64
42 | sample_size: 224
43 | grad_clip: 5.
44 | SYNC_BN: 1
45 | w: 10
46 | temper: 0.4
47 | recoupling: True
48 | knn_attention: 0.8
49 | sharpness: True
50 | temp: [ 0.04, 0.07 ]
51 | frp: True
52 | SEHeads: 1
53 | N: 6
54 |
55 | dataset:
56 | type: M
57 | flip: 0.0
58 | rotated: 0.5
59 | angle: (-10, 10)
60 | Blur: False
61 | resize: (256, 256)
62 | crop_size: 224
63 | low_frames: 16
64 | media_frames: 32
65 | high_frames: 48
--------------------------------------------------------------------------------
/config/NTU.yml:
--------------------------------------------------------------------------------
1 | # '''
2 | # Copyright (C) 2010-2021 Alibaba Group Holding Limited.
3 | # '''
4 |
5 | common:
6 | data: /mnt/workspace//Dataset/NTU-RGBD/
7 | splits: /mnt/workspace//Dataset/NTU-RGBD/dataset_splits/@CS
8 |
9 |   #-------basic Hyperparameters----------
10 | visdom:
11 | enable: False
12 | visname: NTU
13 | dataset: NTU #Database name e.g., NTU, THU ...
14 | batch_size: 4
15 | test_batch_size: 6
16 | num_workers: 10
17 | learning_rate: 0.005
18 | learning_rate_min: 0.00001
19 | momentum: 0.9
20 | weight_decay: 0.0003
21 | init_epochs: 0
22 | epochs: 100
23 | report_freq: 10
24 | optim: SGD
25 | dist: True
26 | vis_feature: True # Visualization?
27 |
28 | scheduler:
29 | name: cosin
30 | patience: 4
31 | warm_up_epochs: 3
32 | loss:
33 | name: CE
34 | labelsmooth: True
35 | MultiLoss: True
36 | loss_lamdb: [ 1, 0.5, 0.5, 0.5 ]
37 | distill: 1.
38 |
39 | model:
40 | Network: I3DWTrans # e.g., I3DWTrans or FusionNet
41 | pretrained: '' #./Checkpoints/I3DWTrans-EXP-20211204-211826//model_best.pth.tar' #'../MultiScale/Checkpoints/I3DWTrans-EXP-20211024-171224/model_best.pth.tar' #'../MultiScale/Checkpoints/I3DWTrans-EXP-20211019-124405/model_best.pth.tar'
42 | # resume: '' #'./Checkpoints/I3DWTrans-NTU-M-20211214-195641/model_best-DTN.pth.tar'
43 | resumelr: ''
44 | sample_duration: 64
45 | sample_size: 224
46 | grad_clip: 5.
47 | SYNC_BN: 1
48 | w: 10
49 | temper: 0.5
50 | recoupling: True
51 | knn_attention: 0.7
52 | sharpness: True
53 | temp: [ 0.04, 0.07 ]
54 | frp: True
55 | SEHeads: 1
56 | N: 6 # Number of Transformer Blocks
57 |
58 | #-------Used for fusion network----------
59 | rgb_checkpoint: './Checkpoints/I3DWTrans-EXP-20211204-211826//model_best.pth.tar'
60 | depth_checkpoint: './Checkpoints/I3DWTrans-EXP-20211204-214434//model_best.pth.tar'
61 |
62 | dataset:
63 | type: M # M: rgb, K: depth
64 | flip: 0.5
65 | rotated: 0.5
66 | angle: (-10, 10) # Rotation angle
67 | Blur: False
68 | resize: (320, 240)
69 | crop_size: 224
70 | low_frames: 16
71 | media_frames: 32
72 | high_frames: 48
73 |
74 | # I3DWTrans-EXP-20211204-211826 M 90.25
75 | # I3DWTrans-EXP-20211204-214434 K 92.81
76 |
--------------------------------------------------------------------------------
/config/NetworkConfig.yml:
--------------------------------------------------------------------------------
1 | common:
2 | data: /path/to/dataset/NTU-RGBD
3 | splits: /path/to/dataset/dataset/NTU-RGBD/dataset_splits/@CS # include: train.txt and test.txt
4 |
5 |   #-------basic Hyperparameters----------
6 | visdom:
7 | enable: True
8 | visname: NTU
9 | dataset: NTU #Database name e.g., NTU, THUREAD, NvGesture and IsoGD ...
10 | batch_size: 6
11 | test_batch_size: 6
12 | num_workers: 6
13 | learning_rate: 0.01
14 | learning_rate_min: 0.00001
15 | momentum: 0.9
16 | weight_decay: 0.0003
17 | init_epochs: 0
18 |   epochs: 100 # if training on the IsoGD dataset, setting this to 300 is better.
19 | report_freq: 100
20 | optim: SGD
21 | dist: True
22 | vis_feature: True # Visualization?
23 |
24 | scheduler:
25 | name: cosin
26 | patience: 4
27 | warm_up_epochs: 3
28 | loss:
29 | name: CE
30 | labelsmooth: True
31 | MultiLoss: True
32 | loss_lamdb: [ 1, 0.5, 0.5, 0.5 ]
33 | distill: 1.
34 |
35 | model:
36 | Network: I3DWTrans # e.g., I3DWTrans or FusionNet
37 |   pretrained: '' # All experiments are pre-trained on the 20BN Jester V1 dataset except for NTU-RGBD.
38 | resume: ''
39 | resumelr: False
40 | sample_duration: 64
41 | sample_size: 224
42 | grad_clip: 5.
43 | SYNC_BN: 1
44 | w: 10
45 | temper: 0.5 # 0.5 for THUREAD and NTU-RGBD; 0.4 for NvGesture and IsoGD
46 | recoupling: True
47 | knn_attention: 0.7
48 | sharpness: True
49 | temp: [ 0.04, 0.07 ]
50 | frp: True
51 | SEHeads: 1
52 | N: 6 # Number of Transformer Blocks
53 |
54 | #-------Used for fusion network----------
55 | rgb_checkpoint: ''
56 | depth_checkpoint: ''
57 |
58 | dataset:
59 | type: M # M: rgb, K: depth
60 | flip: 0.5 # set 0.0 for NvGesture and IsoGD
61 | rotated: 0.5 # THUREAD: 0.8, others: 0.5
62 | angle: (-10, 10) # Rotation angle. THUREAD: (-45, 45), others: (-10, 10)
63 | Blur: False
64 | resize: (320, 240) #NTU and THUREAD: (320, 240), others:(256, 256)
65 | crop_size: 224 # THUREAD: 200, others: 224
66 | low_frames: 16
67 | media_frames: 32
68 | high_frames: 48
69 |
70 |
--------------------------------------------------------------------------------
/config/NvGesture.yml:
--------------------------------------------------------------------------------
1 |
2 | common:
3 | data: /mnt/workspace/Dataset/NvGesture/
4 | splits: /mnt/workspace/Dataset/NvGesture/dataset_splits/
5 |
6 |   #-------basic Hyperparameters----------
7 | visdom:
8 | enable: False
9 | visname: NvGesture
10 |
11 | dataset: NvGesture #Database name e.g., NTU, THUREAD ...
12 | batch_size: 4
13 | test_batch_size: 2
14 | num_workers: 10
15 | learning_rate: 0.01
16 | learning_rate_min: 0.00001
17 | momentum: 0.9
18 | weight_decay: 0.0003
19 | init_epochs: 0
20 | epochs: 100
21 | report_freq: 10
22 | optim: SGD
23 | dist: True
24 | vis_feature: True # Visualization?
25 | DEBUG: False
26 |
27 | scheduler:
28 | name: cosin
29 | patience: 4
30 | warm_up_epochs: 3
31 | loss:
32 | name: CE
33 | labelsmooth: True
34 | MultiLoss: True
35 | loss_lamdb: [ 1, 0.5, 0.5, 0.5 ]
36 | distill: 1.
37 | model:
38 | Network: I3DWTrans # e.g., I3DWTrans or FusionNet
39 | pretrained: /mnt/workspace/Code/CVPR/Checkpoints/I3DWTrans-NvGesture-K-20221017-113943/model_best-Nv-K.pth.tar-v1
40 | resumelr: False
41 | sample_duration: 64
42 | sample_size: 224
43 | grad_clip: 5.
44 | SYNC_BN: 1
45 | w: 4 # 4 is best for Nv
46 | temper: 0.4
47 | recoupling: True
48 | knn_attention: 0.7
49 | sharpness: True
50 | temp: [ 0.04, 0.07 ]
51 | frp: True
52 | SEHeads: 1
53 | N: 6 # Number of Transformer Blocks
54 | #-------Used for fusion network----------
55 | rgb_checkpoint: '/mnt/workspace/Code/CVPR/Checkpoints/model_best-Nv-K.pth.tar'
56 | depth_checkpoint: '/mnt/workspace/Code/CVPR/Checkpoints/model_best-Nv-M.pth.tar'
57 |
58 | dataset:
59 | type: M
60 | flip: 0.0
61 | rotated: 0.5
62 | angle: (-10, 10)
63 | Blur: False
64 | resize: (256, 256)
65 | crop_size: 224
66 | low_frames: 16
67 | media_frames: 32
68 | high_frames: 48
69 |
70 | # I3DWTrans-NvGesture-K-20221017-113943 90.83
71 | # I3DWTrans-NvGesture-M-20221018-123442 88.75
--------------------------------------------------------------------------------
/config/THU.yml:
--------------------------------------------------------------------------------
1 | common:
2 | data: /mnt/workspace//Dataset/THU-READ/frames
3 | splits: /mnt/workspace//Dataset/THU-READ/dataset_splits/@2
4 |
5 |
6 |   #-------basic Hyperparameters----------
7 | visdom:
8 | enable: False
9 | visname: THU
10 | dataset: THUREAD
11 | batch_size: 6
12 | test_batch_size: 6
13 | num_workers: 6
14 | learning_rate: 0.01
15 | learning_rate_min: 0.00001
16 | momentum: 0.9
17 | weight_decay: 0.0003
18 | init_epochs: 0
19 | epochs: 100
20 | report_freq: 10
21 | optim: SGD
22 | dist: True
23 | vis_feature: True # Visualization?
24 | DEBUG: False
25 |
26 | scheduler:
27 | name: cosin
28 | patience: 4
29 | warm_up_epochs: 3 # 10 may work
30 | loss:
31 | name: CE
32 | labelsmooth: True
33 | MultiLoss: True
34 | loss_lamdb: [ 1, 0.5, 0.5, 0.5 ]
35 | distill: 1.
36 |
37 | model:
38 | Network: I3DWTrans # e.g., I3DWTrans or FusionNet
39 | pretrained: '/mnt/workspace/Code/CVPR/Checkpoints/I3DWTrans-NvGesture-K-20221017-113943/model_best-Nv-K.pth.tar-v1' #'./Checkpoints/I3DWTrans-THUREAD-M-20211211-194730/model_best.pth.tar'
40 | # resume: ./Checkpoints/FusionNet-THUREAD-M-20211213-195422/model_best.pth.tar
41 | resumelr: False
42 | sample_duration: 64
43 | sample_size: 224
44 | grad_clip: 5.
45 | SYNC_BN: 1
46 | w: 10
47 | temper: 0.5
48 | recoupling: True
49 | knn_attention: 0.7
50 | sharpness: True
51 | temp: [ 0.04, 0.07 ]
52 | frp: True
53 | SEHeads: 1
54 | N: 6 # Number of Transformer Blocks
55 |
56 | rgb_checkpoint: './Checkpoints/I3DWTrans-THUREAD-M-20211211-194730/model_best.pth.tar'
57 | depth_checkpoint: './Checkpoints/I3DWTrans-THUREAD-K-20211211-124150/model_best.pth.tar'
58 |
59 | dataset:
60 | type: M
61 | flip: 0.5
62 | rotated: 0.8
63 | angle: (-45, 45)
64 | Blur: False
65 | resize: (320, 240)
66 | crop_size: 200
67 | low_frames: 16
68 | media_frames: 32
69 | high_frames: 48
70 |
71 | # I3DWTrans-THUREAD-K-20211212-122941 K 83.33
72 | # I3DWTrans-THUREAD-M-20211211-194730 M 82.08
73 |
74 | # Local + Global + multi loss I3DWTrans_SAtt-EXP-20210823-115604 55.0
75 | # THU-READPre I3DWTrans-EXP-20210825-000754 74.58
76 | # READPreonlyGlobal I3DWTrans-EXP-20210825-085754 72.08%
77 | # THU-READPreSingleLoss I3DWTrans-EXP-20210825-154036 72.50
78 | # READPreNew I3DWTrans-EXP-20210827-110508 75.00
79 | # THU-READPreNewTopKSche I3DWTrans-EXP-20210828-215011 75.83
80 | # THU-READPreNewTopKSche depth I3DWTrans-EXP-20210829-125851 80.0
81 | # THU-READPreNewTopKSchecross I3DWTrans-EXP-20210829-161637 77.08
82 | # THU-READPreCrossSche I3DWTrans-EXP-20210830-002827 74.58
83 | # THU-READPreCrossTopKSchegradclip 75.42
84 | # @2 THU-READDatt I3DWTrans-EXP-20210831-214258 77.92
85 | # THU-READDattSize(320, 320) I3DWTrans-EXP-20210901-014119 75.00 --> no work
86 | # THU-READDatt depth I3DWTrans-EXP-20210901-014025 79.58
87 | # THU-READDatt@4 I3DWTrans-EXP-20210901-092801 76.25
88 | # THU-READDatt@1Blur I3DWTrans-EXP-20210909-105432 74.17
89 |
90 | # THU-READDatt@1New M I3DWTrans-EXP-20210911-103936 80.00
91 | # THU-READDatt@1New K I3DWTrans-EXP-20210911-104131 75.42
92 | # Fusion add @1 84.17 | strategy 86.67
93 | # THU-READDatt@2New M I3DWTrans-EXP-20210910-224232 81.25
94 | # THU-READDatt@2New K I3DWTrans-EXP-20210910-223415 82.08
95 | # Fusion add @2 86.25 | strategy 90.41
96 | # THU-READDatt@3New M I3DWTrans-EXP-20210911-191641 77.50
97 | # THU-READDatt@3New K I3DWTrans-EXP-20210911-191748 77.92
98 | # Fusion add @3 | strategy 82.02
99 | # THU-READDatt@4New M I3DWTrans-EXP-20210912-005825 82.92
100 | # THU-READDatt@4New K I3DWTrans-EXP-20210912-101129 76.25
101 | # Fusion add @4 | strategy
102 |
103 | # @1 M
104 | # @1 K
105 | # fusion 85.42
106 |
107 | # @2 M 84.17
108 | # @2 K 82.91
109 | # fusion 88.75
110 |
111 | # ./Checkpoints/FusionNet-EXP-20211008-232502/model_best.pth.tar
112 | # @3 M 81.94
113 | # @3 K 84.86
114 | # fusion 85.69
--------------------------------------------------------------------------------
/config/__init__.py:
--------------------------------------------------------------------------------
1 | '''
2 | Copyright (C) 2010-2021 Alibaba Group Holding Limited.
3 | '''
4 |
5 | from .config import Config
--------------------------------------------------------------------------------
/config/config.py:
--------------------------------------------------------------------------------
1 | '''
2 | Copyright (C) 2010-2021 Alibaba Group Holding Limited.
3 | '''
4 |
5 | import yaml
6 | # from easydict import EasyDict as edict
7 | def Config(args):
8 | print()
9 | print('='*80)
10 | with open(args.config) as f:
11 | config = yaml.load(f, Loader=yaml.FullLoader)
12 | for dic in config:
13 | for k, v in config[dic].items():
14 | setattr(args, k, v)
15 | print(k, ':\t', v)
16 | print('='*80)
17 | print()
18 | return args
--------------------------------------------------------------------------------
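For illustration, a minimal sketch of how the `Config` helper above can be combined with `argparse` to flatten every section of a YAML config onto the argument namespace (the argparse setup here is hypothetical and not copied from `tools/train.py`):

```python
import argparse
from config import Config

parser = argparse.ArgumentParser()
parser.add_argument('--config', default='config/NTU.yml', help='Path to a YAML config file')
args = parser.parse_args()

args = Config(args)  # every key under 'common', 'model', 'dataset', ... becomes an attribute of args
print(args.dataset, args.batch_size, args.Network, args.type)
```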
/data/data_preprose.py:
--------------------------------------------------------------------------------
1 | '''
2 | Copyright (C) 2010-2021 Alibaba Group Holding Limited.
3 | '''
4 |
5 | import cv2
6 | from PIL import Image
7 | import numpy as np
8 |
9 | import os, glob, re
10 | import argparse
11 | import csv
12 | import random
13 | from tqdm import tqdm
14 | from multiprocessing import Process
15 | import shutil
16 | from multiprocessing import Pool, cpu_count
17 |
18 | def resize_pos(center,src_size,tar_size):
19 | x, y = center
20 | w1=src_size[1]
21 | h1=src_size[0]
22 | w=tar_size[1]
23 | h=tar_size[0]
24 |
25 | y1 = int((h / h1) * y)
26 | x1 = int((w / w1) * x)
27 | return (x1, y1)
28 |
29 | '''
30 | For NTU-RGBD
31 | '''
32 | def video2image(v_p):
33 | m_path='nturgb+d_depth_masked/'
34 | img_path = os.path.join('Images', v_p[:-4].split('/')[-1])
35 | if not os.path.exists(img_path):
36 | os.makedirs(img_path)
37 | cap = cv2.VideoCapture(v_p)
38 | suc, frame = cap.read()
39 | frame_count = 1
40 | while suc:
41 | # frame [1920, 1080]
42 | mask_path = os.path.join(m_path, v_p[:-8].split('/')[-1], 'MDepth-%08d.png'%frame_count)
43 | mask = cv2.imread(mask_path)
44 | mask = mask*255
45 | w, h, c = mask.shape
46 | h2, w2, _ = frame.shape
47 | ori = frame
48 | frame = cv2.resize(frame, (h, w))
49 | h1, w1, _ = frame.shape
50 |
51 | # image = cv2.add(frame, mask)
52 |
53 | # find contour
54 | mask = cv2.erode(mask, np.ones((3, 3),np.uint8))
55 | mask = cv2.dilate(mask ,np.ones((10, 10),np.uint8))
56 | mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
57 | contours, hierarchy = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
58 |
59 |         # Find sufficiently large contours
60 | Idx = []
61 | for i in range(len(contours)):
62 | Area = cv2.contourArea(contours[i])
63 | if Area > 500:
64 | Idx.append(i)
65 | # max_idx = np.argmax(area)
66 |
67 | centers = []
68 | for i in Idx:
69 | rect = cv2.minAreaRect(contours[i])
70 | center, (h, w), degree = rect
71 | centers.append(center)
72 |
73 | finall_center = np.int0(np.array(centers))
74 | c_x = min(finall_center[:, 0])
75 | c_y = min(finall_center[:, 1])
76 |
77 | center = (c_x, c_y)
78 | # finall_center = finall_center.sum(0)/len(finall_center)
79 |
80 | # rect = cv2.minAreaRect(contours[max_idx])
81 | # center, (h, w), degree = rect
82 | # center = tuple(np.int0(finall_center))
83 | center_new = resize_pos(center, (h1, w1), (h2, w2))
84 |
85 | #-----------------------------------
86 | # Image Crop
87 | #-----------------------------------
88 | # ori = cv2.circle(ori, center_new, 2, (0, 0, 255), 2)
89 | crop_y, crop_x = h2//2, w2//2
90 | # print(crop_x, crop_y)
91 | left = center_new[0] - crop_x//2 if center_new[0] - crop_x//2 > 0 else 0
92 | top = center_new[1] - crop_y//2 if center_new[1] - crop_y//2 > 0 else 0
93 | # ori = cv2.circle(ori, (left, top), 2, (0, 0, 255), 2)
94 | # cv2.imwrite('demo/ori.png', ori)
95 | crop_w = left + crop_x if left + crop_x < w2 else w2
96 | crop_h = top + crop_y if top + crop_y < h2 else h2
97 | rect = (left, top, crop_w, crop_h)
98 | image = Image.fromarray(cv2.cvtColor(ori, cv2.COLOR_BGR2RGB))
99 | image = image.crop(rect)
100 | image.save('{}/{:0>6d}.jpg'.format(img_path, frame_count))
101 |
102 | # box = cv2.boxPoints(rect)
103 | # box = np.int0(box)
104 | # drawImage = frame.copy()
105 | # drawImage = cv2.drawContours(drawImage, [box], 0, (255, 0, 0), -1) # draw one contour
106 | # cv2.imwrite('demo/drawImage.png', drawImage)
107 | # frame = cv2.circle(frame, center, 2, (0, 255, 255), 2)
108 | # cv2.imwrite('demo/Image.png', frame)
109 | # cv2.imwrite('demo/mask.png', mask)
110 | # ori = cv2.circle(ori, center_new, 2, (0, 0, 255), 2)
111 | # cv2.imwrite('demo/ORI.png', ori)
112 | # cv2.imwrite('demo/maskImage.png', image)
113 |
114 | # cv2.imwrite('{}/{:0>6d}.jpg'.format(img_path, frame_count), frame)
115 | frame_count += 1
116 | suc, frame = cap.read()
117 | cap.release()
118 |
119 | '''
120 | For IsoGD, Nv...
121 | '''
122 | # def video2image(v_p):
123 | # img_path = v_p[:-4].replace('UCF-101', 'UCF-101-images')
124 | # if not os.path.exists(img_path):
125 | # os.makedirs(img_path)
126 | # cap = cv2.VideoCapture(v_p)
127 | # suc, frame = cap.read()
128 | # frame_count = 0
129 | # while suc:
130 | # h, w, c = frame.shape
131 | # cv2.imwrite('{}/{:0>6d}.jpg'.format(img_path, frame_count), frame)
132 | # frame_count += 1
133 | # suc, frame = cap.read()
134 | # cap.release()
135 |
136 | def GeneratLabel(sample):
137 | path = sample[:-4].split('/')[-1]
138 | cap = cv2.VideoCapture(sample)
139 | frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
140 | label = int(sample.split('A')[-1][:3])-1
141 | txt = ' '.join(map(str, [path, frame_count, label, '\n']))
142 | if args.proto == '@CV':
143 | if 'C001' in sample:
144 | with open(args.validTXT, 'a') as vf:
145 | vf.writelines(txt)
146 | else:
147 | with open(args.trainTXT, 'a') as tf:
148 | tf.writelines(txt)
149 | elif args.proto == '@CS':
150 | pattern = re.findall(r'P\d+', sample)
151 | if int(pattern[0][1:]) in [1, 2, 4, 5, 8, 9, 13, 14, 15,16, 17, 18, 19, 25, 27, 28, 31, 34, 35, 38]:
152 | with open(args.trainTXT, 'a') as tf:
153 | tf.writelines(txt)
154 | else:
155 | with open(args.validTXT, 'a') as vf:
156 | vf.writelines(txt)
157 |
158 | def ResizeImage(img_path):
159 | save_path = img_path.replace('Images', 'ImagesResize')
160 | if not os.path.exists(save_path):
161 | os.makedirs(save_path)
162 | for img in os.listdir(img_path):
163 | im_path = os.path.join(img_path, img)
164 | image = cv2.imread(im_path)
165 | image = cv2.resize(image, (320, 240))
166 | cv2.imwrite(os.path.join(save_path, img), image)
167 |
168 | data_root = '/mnt/workspace/Dataset/NTU-RGBD'
169 | Image_paths = glob.glob(os.path.join(data_root, 'nturgb+d_rgb/*.avi'))
170 | print('Total Images: {}'.format(len(Image_paths)))
171 | mask_paths = os.listdir(os.path.join(data_root, 'nturgb+d_depth_masked/'))
172 | print('Total Masks: {}'.format(len(mask_paths)))
173 |
174 | parser = argparse.ArgumentParser()
175 | parser.add_argument('--proto', default='@CS')
176 | args = parser.parse_args()
177 |
178 |
179 | #---------------------------------------------
180 | # Generate label .txt
181 | #---------------------------------------------
182 | trainTXT = os.path.join(data_root, 'dataset_splits', args.proto, 'train.txt')
183 | validTXT = os.path.join(data_root, 'dataset_splits', args.proto, 'valid.txt')
184 | args.trainTXT = trainTXT
185 | args.validTXT = validTXT
186 | if os.path.isfile(args.trainTXT):
187 | os.system('rm {}'.format(args.trainTXT))
188 | if os.path.isfile(args.validTXT):
189 | os.system('rm {}'.format(args.validTXT))
190 |
191 | with Pool(20) as pool:
192 | for a in tqdm(pool.imap_unordered(GeneratLabel, Image_paths), total=len(Image_paths), desc='Processes'):
193 | if a is not None:
194 | pass
195 | print('Write file list done'.center(80, '*'))
196 |
197 | #---------------------------------------------
198 | # video --> Images
199 | #---------------------------------------------
200 | print(len(Image_paths))
201 | with Pool(20) as pool:
202 | for a in tqdm(pool.imap_unordered(video2image, Image_paths), total=len(Image_paths), desc='Processes'):
203 | if a is not None:
204 | pass
205 | print('Write Image done'.center(80, '*'))
206 |
207 | #---------------------------------------------
208 | # Images size to (320, 240)
209 | #---------------------------------------------
210 | trainTXT = '/mnt/workspace/Dataset/NTU-RGBD/dataset_splits/@CS/train.txt'
211 | validTXT = '/mnt/workspace/Dataset/NTU-RGBD/dataset_splits/@CS/valid.txt'
212 | Image_paths = ['./Images/'+ l.split()[0] for l in open(validTXT, 'r').readlines()]
213 | with Pool(40) as pool:
214 | for a in tqdm(pool.imap_unordered(ResizeImage, Image_paths), total=len(Image_paths), desc='Processes'):
215 | if a is not None:
216 | pass
217 | print('Write Image done'.center(80, '*'))
218 |
219 | # data_root = '/mnt/workspace/Dataset/UCF-101/'
220 | # label_dict = dict([(lambda x: (x[1], int(x[0])-1))(l.strip().split(' ')) for l in open(data_root + 'dataset_splits/lableind.txt').readlines()])
221 | # print(label_dict)
222 |
223 | # def split_func(file_list):
224 | # class_list = []
225 | # fl = open(file_list).readlines()
226 | # for d in tqdm(fl):
227 | # path = d.strip().split()[0][:-4]
228 | # label = label_dict[path.split('/')[0]]
229 | # frame_num = len(os.listdir(os.path.join(data_root, 'UCF-101-images', path)))
230 | # class_list.append([path, str(frame_num), str(label), '\n'])
231 | # return class_list
232 |
233 | # def save_list(file_list, file_name):
234 | # with open(file_name, 'w') as f:
235 | # class_list = split_func(file_list)
236 | # for l in class_list:
237 | # f.writelines(' '.join(l))
238 |
239 | # prot = '@3'
240 | # data_train_split = data_root + f'dataset_splits/{prot}/trainlist.txt'
241 | # data_test_split = data_root + f'dataset_splits/{prot}/testlist.txt'
242 |
243 | # train_file_name = data_root + f'dataset_splits/{prot}/train.txt'
244 | # test_file_name = data_root + f'dataset_splits/{prot}/valid.txt'
245 | # save_list(data_train_split, train_file_name)
246 | # save_list(data_test_split, test_file_name)
247 |
248 |
249 |
--------------------------------------------------------------------------------
/demo/decouple_recouple.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/damo-cv/MotionRGBD/b13673a10e3f259ddef4911a2a91b6eedaf104a1/demo/decouple_recouple.jpg
--------------------------------------------------------------------------------
/demo/pipline.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/damo-cv/MotionRGBD/b13673a10e3f259ddef4911a2a91b6eedaf104a1/demo/pipline.jpg
--------------------------------------------------------------------------------
/demo/readme.md:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/lib/__init__.py:
--------------------------------------------------------------------------------
1 | '''
2 | Copyright (C) 2010-2021 Alibaba Group Holding Limited.
3 | '''
4 | from .datasets import *
5 | from .model import *
--------------------------------------------------------------------------------
/lib/datasets/IsoGD.py:
--------------------------------------------------------------------------------
1 | '''
2 | Copyright (C) 2010-2021 Alibaba Group Holding Limited.
3 | '''
4 |
5 | import torch
6 | from .base import Datasets
7 | from torchvision import transforms, set_image_backend
8 | import random, os
9 | from PIL import Image
10 | import numpy as np
11 |
12 | class IsoGDData(Datasets):
13 | def __init__(self, args, ground_truth, modality, phase='train'):
14 | super(IsoGDData, self).__init__(args, ground_truth, modality, phase)
15 | def __getitem__(self, index):
16 | """
17 | Args:
18 | index (int): Index
19 | Returns:
20 | tuple: (image, target) where target is class_index of the target class.
21 | """
22 | sl = self.get_sl(self.inputs[index][1])
23 | self.data_path = os.path.join(self.dataset_root, self.typ, self.inputs[index][0])
24 | if self.typ == 'depth':
25 | self.data_path = self.data_path.replace('M_', 'K_')
26 |
27 | if self.args.Network == 'FusionNet':
28 | assert self.typ == 'rgb'
29 | self.data_path1 = self.data_path.replace('rgb', 'depth')
30 | self.data_path1 = self.data_path1.replace('M', 'K')
31 |
32 | self.clip, skgmaparr = self.image_propose(self.data_path, sl)
33 | self.clip1, skgmaparr1 = self.image_propose(self.data_path1, sl)
34 | return (self.clip.permute(0, 3, 1, 2), skgmaparr), (self.clip1.permute(0, 3, 1, 2), skgmaparr1), self.inputs[index][2], self.inputs[index][0]
35 |
36 | else:
37 | self.clip, skgmaparr = self.image_propose(self.data_path, sl)
38 | return self.clip.permute(0, 3, 1, 2), skgmaparr, self.inputs[index][2], self.inputs[index][0]
39 |
40 | def __len__(self):
41 | return len(self.inputs)
42 |
--------------------------------------------------------------------------------
/lib/datasets/Jester.py:
--------------------------------------------------------------------------------
1 | '''
2 | Copyright (C) 2010-2021 Alibaba Group Holding Limited.
3 | '''
4 |
5 | import torch
6 | from .base import Datasets
7 | from torchvision import transforms, set_image_backend
8 | import random, os
9 | from PIL import Image
10 | import numpy as np
11 | import logging
12 | np.random.seed(123)
13 |
14 | class JesterData(Datasets):
15 | def __init__(self, args, ground_truth, modality, phase='train'):
16 | super(JesterData, self).__init__(args, ground_truth, modality, phase)
17 |
18 | def LoadKeypoints(self):
19 | if self.phase == 'train':
20 | kpt_file = os.path.join(self.dataset_root, self.args.splits, 'train_kp.data')
21 | else:
22 | kpt_file = os.path.join(self.dataset_root, self.args.splits, 'valid_kp.data')
23 | with open(kpt_file, 'r') as f:
24 | kpt_data = [(lambda arr: (os.path.join(self.dataset_root, self.typ, self.phase, arr[0]), list(map(lambda x: int(float(x)), arr[1:]))))(l[:-1].split()) for l in f.readlines()]
25 | kpt_data = dict(kpt_data)
26 |
27 | for k, v in kpt_data.items():
28 | pose = v[:18*2]
29 | r_hand = v[18*2: 18*2+21*2]
30 | l_hand = v[18*2+21*2: 18*2+21*2+21*2]
31 | kpt_data[k] = {'people': [{'pose_keypoints_2d': pose, 'hand_right_keypoints_2d': r_hand, 'hand_left_keypoints_2d': l_hand}]}
32 |
33 | logging.info('Load Keypoints files Done, Total: {}'.format(len(kpt_data)))
34 | return kpt_data
35 | def get_path(self, imgs_path, a):
36 | return os.path.join(imgs_path, "%05d.jpg" % int(a + 1))
37 | def __getitem__(self, index):
38 | """
39 | Args:
40 | index (int): Index
41 | Returns:
42 | tuple: (image, target) where target is class_index of the target class.
43 | """
44 | sl = self.get_sl(self.inputs[index][1])
45 | self.data_path = os.path.join(self.dataset_root, self.inputs[index][0])
46 | # self.clip = self.image_propose(self.data_path, sl)
47 | self.clip, skgmaparr = self.image_propose(self.data_path, sl)
48 |
49 | return self.clip.permute(0, 3, 1, 2), skgmaparr, self.inputs[index][2], self.data_path
50 |
51 | def __len__(self):
52 | return len(self.inputs)
53 |
--------------------------------------------------------------------------------
/lib/datasets/NTU.py:
--------------------------------------------------------------------------------
1 | '''
2 | Copyright (C) 2010-2021 Alibaba Group Holding Limited.
3 | '''
4 |
5 | import torch
6 | from .base import Datasets
7 | from torchvision import transforms, set_image_backend
8 | import random, os
9 | from PIL import Image
10 | import numpy as np
11 |
12 | class NTUData(Datasets):
13 | def __init__(self, args, ground_truth, modality, phase='train'):
14 | super(NTUData, self).__init__(args, ground_truth, modality, phase)
15 |
16 | def __getitem__(self, index):
17 | """
18 | Args:
19 | index (int): Index
20 | Returns:
21 | tuple: (image, target) where target is class_index of the target class.
22 | """
23 | sl = self.get_sl(self.inputs[index][1])
24 |
25 | if self.typ == 'rgb':
26 | self.data_path = os.path.join(self.dataset_root, 'Images', self.inputs[index][0])
27 |
28 | if self.typ == 'depth':
29 | self.data_path = os.path.join(self.dataset_root, 'nturgb+d_depth_masked', self.inputs[index][0][:-4])
30 |
31 | self.clip, skgmaparr = self.image_propose(self.data_path, sl)
32 |
33 | if self.args.Network == 'FusionNet':
34 | assert self.typ == 'rgb'
35 | self.data_path = os.path.join(self.dataset_root, 'nturgb+d_depth_masked', self.inputs[index][0][:-4])
36 | self.clip1, skgmaparr1 = self.image_propose(self.data_path, sl)
37 | return (self.clip.permute(0, 3, 1, 2), self.clip1.permute(0, 3, 1, 2)), (skgmaparr, skgmaparr1), \
38 | self.inputs[index][2], self.data_path
39 |
40 | return self.clip.permute(0, 3, 1, 2), skgmaparr, self.inputs[index][2], self.inputs[index][0]
41 |
42 | def get_path(self, imgs_path, a):
43 |
44 | if self.typ == 'rgb':
45 | return os.path.join(imgs_path, "%06d.jpg" % int(a + 1))
46 | else:
47 | return os.path.join(imgs_path, "MDepth-%08d.png" % int(a + 1))
48 |
49 | def __len__(self):
50 | return len(self.inputs)
51 |
--------------------------------------------------------------------------------
/lib/datasets/NvGesture.py:
--------------------------------------------------------------------------------
1 | '''
2 | Copyright (C) 2010-2021 Alibaba Group Holding Limited.
3 | '''
4 |
5 | import torch
6 | from .base import Datasets
7 | from torchvision import transforms, set_image_backend
8 | import random, os
9 | from PIL import Image
10 | import numpy as np
11 | import logging
12 | set_image_backend('accimage')
13 | np.random.seed(123)
14 |
15 | class NvData(Datasets):
16 | def __init__(self, args, ground_truth, modality, phase='train'):
17 | super(NvData, self).__init__(args, ground_truth, modality, phase)
18 | def transform_params(self, resize=(320, 240), crop_size=224, flip=0.5):
19 | if self.phase == 'train':
20 | left, top = random.randint(10, resize[0] - crop_size), random.randint(10, resize[1] - crop_size)
21 | is_flip = True if random.uniform(0, 1) < flip else False
22 | else:
23 | left, top = 32, 32
24 | is_flip = False
25 | return (left, top, left + crop_size, top + crop_size), is_flip
26 |
27 | def __getitem__(self, index):
28 | """
29 | Args:
30 | index (int): Index
31 | Returns:
32 | tuple: (image, target) where target is class_index of the target class.
33 | """
34 | sl = self.get_sl(self.inputs[index][1])
35 | self.data_path = os.path.join(self.dataset_root, self.typ, self.inputs[index][0])
36 | self.clip, skgmaparr = self.image_propose(self.data_path, sl)
37 |
38 | if self.args.Network == 'FusionNet':
39 | assert self.typ == 'rgb'
40 | self.data_path = self.data_path.replace('rgb', 'depth')
41 | self.clip1, skgmaparr1 = self.image_propose(self.data_path, sl)
42 |
43 | return (self.clip.permute(0, 3, 1, 2), self.clip1.permute(0, 3, 1, 2)), (skgmaparr, skgmaparr1), self.inputs[index][2], self.data_path
44 |
45 | return self.clip.permute(0, 3, 1, 2), skgmaparr, self.inputs[index][2], self.data_path
46 |
47 | def __len__(self):
48 | return len(self.inputs)
49 |
--------------------------------------------------------------------------------
/lib/datasets/THU_READ.py:
--------------------------------------------------------------------------------
1 | '''
2 | Copyright (C) 2010-2021 Alibaba Group Holding Limited.
3 | '''
4 |
5 | import torch
6 | from .base import Datasets
7 | from torchvision import transforms, set_image_backend
8 | import random, os
9 | from PIL import Image
10 | import numpy as np
11 | import logging
12 |
13 | np.random.seed(123)
14 |
15 |
16 | class THUREAD(Datasets):
17 | def __init__(self, args, ground_truth, modality, phase='train'):
18 | super(THUREAD, self).__init__(args, ground_truth, modality, phase)
19 |
20 | def __getitem__(self, index):
21 | """
22 | Args:
23 | index (int): Index
24 | Returns:
25 | tuple: (image, target) where target is class_index of the target class.
26 | """
27 | sl = self.get_sl(self.inputs[index][1])
28 | self.data_path = os.path.join(self.dataset_root, self.inputs[index][0])
29 | self.clip, skgmaparr = self.image_propose(self.data_path, sl)
30 |
31 | if self.args.Network == 'FusionNet':
32 | assert self.typ == 'rgb'
33 | self.data_path1 = self.data_path.replace('RGB', 'Depth')
34 | self.data_path1 = '/'.join(self.data_path1.split('/')[:-1]) + '/{}'.format(
35 | self.data_path1.split('/')[-1].replace('Depth', 'D'))
36 |
37 | self.clip1, skgmaparr1 = self.image_propose(self.data_path1, sl)
38 |
39 | return (self.clip.permute(0, 3, 1, 2), self.clip1.permute(0, 3, 1, 2)), (skgmaparr, skgmaparr1), \
40 | self.inputs[index][2], self.data_path
41 |
42 | return self.clip.permute(0, 3, 1, 2), skgmaparr, self.inputs[index][2], self.data_path
43 |
44 | def __len__(self):
45 | return len(self.inputs)
46 |
--------------------------------------------------------------------------------
/lib/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | '''
2 | Copyright (C) 2010-2021 Alibaba Group Holding Limited.
3 | '''
4 |
5 | from .build import *
6 |
--------------------------------------------------------------------------------
/lib/datasets/base.py:
--------------------------------------------------------------------------------
1 | '''
2 | This file is modified from:
3 | https://github.com/zhoubenjia/RAAR3DNet/blob/master/Network_Train/lib/datasets/base.py
4 | '''
5 |
6 | import torch
7 | from torch.utils.data import Dataset, DataLoader
8 | from torchvision import transforms, set_image_backend
9 | import torch.nn.functional as F
10 |
11 | from PIL import Image
12 | from PIL import ImageFilter, ImageOps
13 | import os, glob
14 | import math, random
15 | import numpy as np
16 | import logging
17 | from tqdm import tqdm as tqdm
18 | import pandas as pd
19 | from multiprocessing import Pool, cpu_count
20 | import multiprocessing as mp
21 | import cv2
22 | import json
23 | from scipy.ndimage.filters import gaussian_filter
24 |
25 | # import functools
26 | import matplotlib.pyplot as plt # For graphics
27 | np.random.seed(123)
28 |
29 | class GaussianBlur(object):
30 | """
31 | Apply Gaussian Blur to the PIL image.
32 | """
33 | def __init__(self, p=0.5, radius_min=0.1, radius_max=2.):
34 | self.prob = p
35 | self.radius_min = radius_min
36 | self.radius_max = radius_max
37 |
38 | def __call__(self, img):
39 | do_it = random.random() <= self.prob
40 | if not do_it:
41 | return img
42 |
43 | return img.filter(
44 | ImageFilter.GaussianBlur(
45 | radius=random.uniform(self.radius_min, self.radius_max)
46 | )
47 | )
48 | class Normaliztion(object):
49 | """
50 | same as mxnet, normalize into [-1, 1]
51 | image = (image - 127.5)/128
52 | """
53 |
54 | def __call__(self, Image):
55 | new_video_x = (Image - 127.5) / 128
56 | return new_video_x
57 |
58 | class Datasets(Dataset):
59 | global kpt_dict
60 | def __init__(self, args, ground_truth, modality, phase='train'):
61 |
62 | def get_data_list_and_label(data_df):
63 | return [(lambda arr: (arr[0], int(arr[1]), int(arr[2])))(i[:-1].split(' '))
64 | for i in open(data_df).readlines()]
65 |
66 | self.dataset_root = args.data
67 | self.sample_duration = args.sample_duration
68 | self.sample_size = args.sample_size
69 | self.phase = phase
70 | args.phase = phase
71 | self.typ = modality
72 | self.args = args
73 | self._w = args.w
74 |
75 | self.transform = transforms.Compose([Normaliztion(), transforms.ToTensor()])
76 |
77 | self.inputs = list(filter(lambda x: x[1] > 16, get_data_list_and_label(ground_truth)))
78 | self.inputs = list(self.inputs)
79 | if phase == 'train':
80 | while len(self.inputs) % (args.batch_size * args.nprocs) != 0:
81 | sample = random.choice(self.inputs)
82 | self.inputs.append(sample)
83 | logging.info('Training Data Size is: {}'.format(len(self.inputs)))
84 | frames = [n[1] for n in self.inputs]
85 | logging.info('Average Train Data frames are: {}, max frames: {}, min frames: {}'.format(sum(frames)//len(self.inputs), max(frames), min(frames)))
86 | else:
87 | logging.info('Validation Data Size is: {} '.format(len(self.inputs)))
88 | frames = [n[1] for n in self.inputs]
89 |             logging.info('Average Validation Data frames are: {}, max frames: {}, min frames: {}'.format(
90 |                 sum(frames) // len(self.inputs), max(frames), min(frames)))
91 |
92 | def transform_params(self, resize=(320, 240), crop_size=224, flip=0.5):
93 | if self.phase == 'train':
94 | left, top = np.random.randint(0, resize[0] - crop_size), np.random.randint(0, resize[1] - crop_size)
95 | is_flip = True if np.random.uniform(0, 1) < flip else False
96 | else:
97 | left, top = (resize[0] - crop_size) // 2, (resize[1] - crop_size) // 2
98 |
99 | is_flip = False
100 | return (left, top, left + crop_size, top + crop_size), is_flip
101 |
102 | def rotate(self, image, angle, center=None, scale=1.0):
103 | (h, w) = image.shape[:2]
104 | if center is None:
105 | center = (w / 2, h / 2)
106 | M = cv2.getRotationMatrix2D(center, angle, scale)
107 | rotated = cv2.warpAffine(image, M, (w, h))
108 | return rotated
109 |
110 | def get_path(self, imgs_path, a):
111 | return os.path.join(imgs_path, "%06d.jpg" % a)
112 |
113 | def depthProposess(self, img):
114 | h2, w2 = img.shape
115 |
116 | mask = img.copy()
117 | mask = cv2.erode(mask, np.ones((3, 3), np.uint8))
118 | mask = cv2.dilate(mask, np.ones((10, 10), np.uint8))
119 | contours, hierarchy = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
120 |         # Find sufficiently large contours
121 | Idx = []
122 | for i in range(len(contours)):
123 | Area = cv2.contourArea(contours[i])
124 | if Area > 500:
125 | Idx.append(i)
126 | centers = []
127 |
128 | for i in Idx:
129 | rect = cv2.minAreaRect(contours[i])
130 | center, (h, w), degree = rect
131 | centers.append(center)
132 |
133 | finall_center = np.int0(np.array(centers))
134 | c_x = min(finall_center[:, 0])
135 | c_y = min(finall_center[:, 1])
136 | center = (c_x, c_y)
137 |
138 | crop_x, crop_y = 320, 240
139 | left = center[0] - crop_x // 2 if center[0] - crop_x // 2 > 0 else 0
140 | top = center[1] - crop_y // 2 if center[1] - crop_y // 2 > 0 else 0
141 | crop_w = left + crop_x if left + crop_x < w2 else w2
142 | crop_h = top + crop_y if top + crop_y < h2 else h2
143 | rect = (left, top, crop_w, crop_h)
144 | image = Image.fromarray(img)
145 | image = image.crop(rect)
146 | return image
147 |
148 | def image_propose(self, data_path, sl):
149 | sample_size = self.sample_size
150 | resize = eval(self.args.resize)
151 | crop_rect, is_flip = self.transform_params(resize=resize, crop_size=self.args.crop_size, flip=self.args.flip) # no flip
152 | if np.random.uniform(0, 1) < self.args.rotated and self.phase == 'train':
153 | r, l = eval(self.args.angle)
154 | rotated = np.random.randint(r, l)
155 | else:
156 | rotated = 0
157 |
158 | def transform(img):
159 | img = np.asarray(img)
160 | if img.shape[-1] != 3:
161 | img = np.uint8(255 * img)
162 | img = self.depthProposess(img)
163 | img = cv2.applyColorMap(np.asarray(img), cv2.COLORMAP_JET)
164 | img = self.rotate(np.asarray(img), rotated)
165 | img = Image.fromarray(img)
166 | img = img.resize(resize)
167 | img = img.crop(crop_rect)
168 | if self.args.Blur and self.args.phase == 'train':
169 | img = GaussianBlur()(img)
170 | if is_flip:
171 | img = img.transpose(Image.FLIP_LEFT_RIGHT)
172 | return np.array(img.resize((sample_size, sample_size)))
173 |
174 | def Sample_Image(imgs_path, sl):
175 | frams = []
176 | for a in sl:
177 | try:
178 | ori_image = Image.open(self.get_path(imgs_path, a))
179 | except:
180 | ori_image = Image.open(os.path.join(imgs_path, "MDepth-%08d.png" % int(a+1))) # For NTU fusion
181 | img = transform(ori_image)
182 | frams.append(self.transform(img).view(3, sample_size, sample_size, 1))
183 | skgmaparr = DynamicImage(frams, dynamic_only=False)
184 | return torch.cat(frams, dim=3).type(torch.FloatTensor), skgmaparr.unsqueeze(0)
185 |
186 | def DynamicImage(frames, dynamic_only): # frames: [[3, 224, 224, 1], ]
187 | def tensor_arr_rp(arr):
188 | l = len(arr)
189 | statics = []
190 | def tensor_rankpooling(video_arr, lamb=1.):
191 | def get_w(N):
192 | return [float(i) * 2 - N - 1 for i in range(1, N + 1)]
193 |
194 | # re = torch.zeros(video_arr[0].size(0), 1, video_arr[0].size(2), video_arr[0].size(3)).cuda()
195 | re = torch.zeros(video_arr[0].size())
196 | for a, b in zip(video_arr, get_w(len(video_arr))):
197 | # a = transforms.Grayscale(1)(a)
198 | re += a * b
199 | re = F.relu(re) * lamb
200 | re -= torch.min(re)
201 | re = re / torch.max(re) if torch.max(re) != 0 else re / (torch.max(re) + 0.00001)
202 |
203 | re = transforms.Grayscale(1)(re.squeeze())
204 | # Static Attention
205 | static = torch.where(re > torch.mean(re), re, torch.full_like(re, 0))
206 | static = np.asarray(static.squeeze())
207 | # static = cv2.morphologyEx(static, cv2.MORPH_OPEN, kernel=np.ones((3, 3), np.uint8))
208 | kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (2, 2))
209 | static = cv2.erode(static, kernel)
210 | kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
211 | static = cv2.dilate(static, kernel)
212 | static -= np.min(static)
213 | static = static / np.max(static) if np.max(static) != 0 else static / (np.max(static) + 0.00001)
214 | statics.append(torch.from_numpy(static).unsqueeze(0))
215 | return re
216 |
217 | return [tensor_rankpooling(arr[i:i + self._w]) for i in range(l)], statics
218 | arrrp, statics = tensor_arr_rp(frames)
219 | arrrp = torch.cat(arrrp, dim=0) # torch.Size([64, 224, 224])
220 | t, h, w = arrrp.shape
221 | mask = torch.zeros(self._w - 1, h, w)
222 | garrs = torch.cat((mask, arrrp), dim=0)[:t, :]
223 | statics = torch.cat(statics)
224 | statics = torch.cat((mask, statics))[:t, :]
225 | if dynamic_only:
226 | return garrs
227 | return (garrs + statics) * statics
228 | return Sample_Image(data_path, sl)
229 |
230 |     def get_sl(self, clip):
231 |         # Uniformly divide the `clip` frames into `sample_duration` segments and
232 |         # pick one frame index per segment: a random index within the segment at
233 |         # training time, and (roughly) the middle index at validation/test time.
234 |         sn = self.sample_duration
235 |         if self.phase == 'train':
236 |             f = lambda n: [(lambda n, arr: n if arr == [] else np.random.choice(arr))(
237 |                 n * i / sn,
238 |                 range(int(n * i / sn), max(int(n * i / sn) + 1, int(n * (i + 1) / sn))))
239 |                 for i in range(sn)]
240 |         else:
241 |             f = lambda n: [(lambda n, arr: n if arr == [] else int(np.mean(arr)))(
242 |                 n * i / sn,
243 |                 range(int(n * i / sn), max(int(n * i / sn) + 1, int(n * (i + 1) / sn))))
244 |                 for i in range(sn)]
245 |         return f(int(clip))
247 | def __getitem__(self, index):
248 | """
249 | Args:
250 | index (int): Index
251 | Returns:
252 | tuple: (image, target) where target is class_index of the target class.
253 | """
254 | sl = self.get_sl(self.inputs[index][1])
255 | self.data_path = os.path.join(self.dataset_root, self.inputs[index][0])
256 | self.clip = self.image_propose(self.data_path, sl)
257 | return self.clip.permute(0, 3, 1, 2), self.inputs[index][2]
258 | def __len__(self):
259 | return len(self.inputs)
260 |
261 | if __name__ == '__main__':
262 | import argparse
263 | from config import Config
264 | from lib import *
265 | parser = argparse.ArgumentParser()
266 | parser.add_argument('--config', default='', help='Place config Congfile!')
267 | parser.add_argument('--eval_only', action='store_true', help='Eval only. True or False?')
268 | parser.add_argument('--local_rank', type=int, default=0)
269 | parser.add_argument('--nprocs', type=int, default=1)
270 |
271 | parser.add_argument('--save_grid_image', action='store_true', help='Save grid images of input samples')
272 | parser.add_argument('--save_output', action='store_true', help='Save output logits')
273 | parser.add_argument('--demo_dir', type=str, default='./demo', help='Directory for saving all demo outputs')
274 |
275 | parser.add_argument('--drop_path_prob', type=float, default=0.5, help='drop path probability')
276 | parser.add_argument('--save', type=str, default='Checkpoints/', help='experiment name')
277 | parser.add_argument('--seed', type=int, default=123, help='random seed')
278 | args = parser.parse_args()
279 | args = Config(args)
280 | np.random.seed(args.seed)
281 | torch.manual_seed(args.seed)
282 | args.dist = False
283 | args.eval_only = True
284 | args.test_batch_size = 1
285 |
286 | valid_queue, valid_sampler = build_dataset(args, phase='val')
287 | for step, (inputs, heatmap, target, _) in enumerate(valid_queue):
288 | print(inputs.shape)
289 | input()
--------------------------------------------------------------------------------
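The `get_sl` helper above splits a clip of `n` frames into `sample_duration` equal segments and draws one index per segment: a random frame within the segment during training, the segment centre at evaluation. Below is a minimal standalone sketch of the same segment-sampling idea; the name `sample_indices` is illustrative and not part of the repo.

import numpy as np

def sample_indices(num_frames, num_samples, train=True, seed=0):
    """Split [0, num_frames) into num_samples segments and pick one index per
    segment: a random frame when training, the segment centre otherwise."""
    rng = np.random.default_rng(seed)
    indices = []
    for i in range(num_samples):
        start = int(num_frames * i / num_samples)
        end = max(start + 1, int(num_frames * (i + 1) / num_samples))
        segment = list(range(start, end))
        indices.append(int(rng.choice(segment)) if train else int(np.mean(segment)))
    return indices

print(sample_indices(40, 8, train=False))  # evenly spaced segment centres
print(sample_indices(40, 8, train=True))   # one random frame per segment
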
/lib/datasets/build.py:
--------------------------------------------------------------------------------
1 | '''
2 | Copyright (C) 2010-2021 Alibaba Group Holding Limited.
3 | '''
4 |
5 | import torch
6 | from .distributed_sampler import DistributedSampler
7 | from .IsoGD import IsoGDData
8 | from .NvGesture import NvData
9 | from .THU_READ import THUREAD
10 | from .Jester import JesterData
11 | from .NTU import NTUData
12 | import logging
13 |
14 | def build_dataset(args, phase):
15 | modality = dict(
16 | M='rgb',
17 | K='depth',
18 | F='Flow'
19 | )
20 | assert args.type in modality, 'Error in modality!'
21 | Datasets_func = dict(
22 | NvGesture=NvData,
23 | IsoGD=IsoGDData,
24 | THUREAD=THUREAD,
25 | Jester=JesterData,
26 | NTU=NTUData
27 | )
28 | assert args.dataset in Datasets_func, 'Error in dataset Function!'
29 | if args.local_rank == 0:
30 | logging.info('Dataset:{}, Modality:{}'.format(args.dataset, modality[args.type]))
31 |
32 | if args.dataset in ['THUREAD'] and args.type == 'K':
33 | splits = args.splits + '/depth_{}_lst.txt'.format(phase)
34 | else:
35 | splits = args.splits + '/{}.txt'.format(phase)
36 | dataset = Datasets_func[args.dataset](args, splits, modality[args.type], phase=phase)
37 | if args.dist:
38 | data_sampler = DistributedSampler(dataset)
39 | else:
40 | data_sampler = None
41 |
42 | if phase == 'train':
43 | return torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, num_workers=args.num_workers,
44 | shuffle=(data_sampler is None),
45 | sampler=data_sampler, pin_memory=True), data_sampler
46 | else:
47 | # if args.eval_only and args.nprocs == 1:
48 | # args.test_batch_size = 8
49 | return torch.utils.data.DataLoader(dataset, batch_size=args.test_batch_size, num_workers=args.num_workers,
50 | shuffle=False,
51 | sampler=data_sampler, pin_memory=True, drop_last=False if args.eval_only else True), data_sampler
--------------------------------------------------------------------------------
/lib/datasets/distributed_sampler.py:
--------------------------------------------------------------------------------
1 | '''
2 | This file is modified from:
3 | https://github.com/open-mmlab/mmdetection/blob/master/mmdet/datasets/samplers/distributed_sampler.py
4 | '''
5 |
6 | import math
7 | import torch
8 | from torch.utils.data import DistributedSampler as _DistributedSampler
9 |
10 | class DistributedSampler(_DistributedSampler):
11 |
12 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
13 | super().__init__(dataset, num_replicas=num_replicas, rank=rank)
14 | self.shuffle = shuffle
15 |
16 | def __iter__(self):
17 | # deterministically shuffle based on epoch
18 | if self.shuffle:
19 | g = torch.Generator()
20 | g.manual_seed(self.epoch)
21 | indices = torch.randperm(len(self.dataset), generator=g).tolist()
22 | else:
23 | indices = torch.arange(len(self.dataset)).tolist()
24 |
25 | # add extra samples to make it evenly divisible
26 | # in case that indices is shorter than half of total_size
27 | indices = (indices *
28 | math.ceil(self.total_size / len(indices)))[:self.total_size]
29 | assert len(indices) == self.total_size
30 |
31 | # subsample
32 | indices = indices[self.rank:self.total_size:self.num_replicas]
33 | assert len(indices) == self.num_samples
34 |
35 | return iter(indices)
36 |
--------------------------------------------------------------------------------
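`__iter__` above pads the (optionally shuffled) index list by repetition until it reaches `total_size = num_samples * num_replicas` (both set by the parent `DistributedSampler`), then hands every `num_replicas`-th index to each rank. A small sketch of that padding-and-striding step, with shuffling omitted for clarity:

import math

def partition(num_items, num_replicas, rank):
    """Pad the index list by repetition up to total_size, then give every
    num_replicas-th index to this rank."""
    num_samples = math.ceil(num_items / num_replicas)     # set by the parent class
    total_size = num_samples * num_replicas               # set by the parent class
    indices = list(range(num_items))
    indices = (indices * math.ceil(total_size / len(indices)))[:total_size]
    return indices[rank:total_size:num_replicas]

# 10 items across 4 replicas: each rank gets 3 indices; items 0 and 1 are
# re-used so that every rank sees the same number of samples.
for rank in range(4):
    print(rank, partition(10, 4, rank))
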
/lib/model/DSN.py:
--------------------------------------------------------------------------------
1 | '''
2 | This file is modified from:
3 | https://github.com/deepmind/kinetics-i3d/i3d.py
4 | '''
5 |
6 | import torch
7 | import torch.nn as nn
8 | from einops.layers.torch import Rearrange
9 | import torch.nn.functional as F
10 | from torch.autograd import Variable
11 | import numpy as np
12 | import cv2
13 | import os, math
14 | import sys
15 | from .DTN import DTNNet
16 | from .FRP import FRP_Module
17 | from .utils import *
18 |
19 | import os, math
20 | import sys
21 | sys.path.append('../../')
22 | from collections import OrderedDict
23 | from utils import load_pretrained_checkpoint
24 | import logging
25 |
26 | class DSNNet(nn.Module):
27 | VALID_ENDPOINTS = (
28 | 'Conv3d_1a_7x7',
29 | 'MaxPool3d_2a_3x3',
30 | 'Conv3d_2b_1x1',
31 | 'Conv3d_2c_3x3',
32 | 'MaxPool3d_3a_3x3',
33 |
34 | 'Mixed_3b',
35 | 'Mixed_3c',
36 | 'MaxPool3d_4a_3x3',
37 | 'Mixed_4b',
38 | 'Mixed_4c',
39 | 'MaxPool3d_5a_2x2',
40 | 'Mixed_5b',
41 | 'Mixed_5c'
42 | )
43 |
44 | def __init__(self, args, num_classes=400, spatial_squeeze=True, name='inception_i3d', in_channels=3, dropout_keep_prob=0.5,
45 | pretrained: str = False,
46 | dropout_spatial: float = 0.0):
47 |
48 | super(DSNNet, self).__init__()
49 | self._num_classes = num_classes
50 | self._spatial_squeeze = spatial_squeeze
51 | self.logits = None
52 | self.args = args
53 |
54 | self.end_points = {}
55 |
56 | '''
57 | Low Level Features Extraction
58 | '''
59 | end_point = 'Conv3d_1a_7x7'
60 | self.end_points[end_point] = Unit3D(in_channels=in_channels, output_channels=64, kernel_shape=[1, 7, 7],
61 | stride=(1, 2, 2), padding=(0, 3, 3), name=name + end_point)
62 |
63 | end_point = 'MaxPool3d_2a_3x3'
64 | self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[1, 3, 3], stride=(1, 2, 2),
65 | padding=0)
66 |
67 | end_point = 'Conv3d_2b_1x1'
68 | self.end_points[end_point] = Unit3D(in_channels=64, output_channels=64, kernel_shape=[1, 1, 1], padding=0,
69 | name=name + end_point)
70 |
71 | end_point = 'Conv3d_2c_3x3'
72 | self.end_points[end_point] = Unit3D(in_channels=64, output_channels=192, kernel_shape=[1, 3, 3],
73 | padding=(0, 1, 1),
74 | name=name + end_point)
75 |
76 | end_point = 'MaxPool3d_3a_3x3'
77 | self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[1, 3, 3], stride=(1, 2, 2),
78 | padding=0)
79 |
80 | '''
81 | Spatial Multi-scale Features Learning
82 | '''
83 | end_point = 'Mixed_3b'
84 | self.end_points[end_point] = SpatialInceptionModule(192, [64, 96, 128, 16, 32, 32], name + end_point)
85 |
86 | end_point = 'Mixed_3c'
87 | self.end_points[end_point] = SpatialInceptionModule(256, [128, 128, 192, 32, 96, 64], name + end_point)
88 |
89 | end_point = 'MaxPool3d_4a_3x3'
90 | self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[1, 3, 3], stride=(1, 2, 2),
91 | padding=0)
92 |
93 | end_point = 'Mixed_4b'
94 | self.end_points[end_point] = SpatialInceptionModule(128 + 192 + 96 + 64, [192, 96, 208, 16, 48, 64], name + end_point)
95 |
96 | end_point = 'Mixed_4c'
97 | self.end_points[end_point] = SpatialInceptionModule(192 + 208 + 48 + 64, [160, 112, 224, 24, 64, 64], name + end_point)
98 |
99 | end_point = 'MaxPool3d_5a_2x2'
100 | self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[1, 2, 2], stride=(1, 2, 2),
101 | padding=0)
102 |
103 | end_point = 'Mixed_5b'
104 | self.end_points[end_point] = SpatialInceptionModule(160 + 224 + 64 + 64, [256, 160, 320, 32, 128, 128],
105 | name + end_point)
106 |
107 | end_point = 'Mixed_5c'
108 | self.end_points[end_point] = SpatialInceptionModule(256 + 320 + 128 + 128, [384, 192, 384, 48, 128, 128],
109 | name + end_point)
110 |
111 | self.LinearMap = nn.Sequential(
112 | nn.LayerNorm(1024),
113 | nn.Linear(1024, 512),
114 | # nn.Dropout(dropout_spatial)
115 | )
116 |
117 | self.avg_pool = nn.AdaptiveAvgPool3d((None, 1, 1))
118 | self.dropout = nn.Dropout(dropout_keep_prob)
119 | self.build()
120 | self.dtn = DTNNet(args, num_classes=self._num_classes)
121 | self.rrange = Rearrange('b c t h w -> b t c h w')
122 |
123 | if args.frp:
124 | self.frp_module = FRP_Module(w=args.w, inplanes=64)
125 |
126 | if pretrained:
127 | load_pretrained_checkpoint(self, pretrained)
128 |
129 | def build(self):
130 | for k in self.end_points.keys():
131 | self.add_module(k, self.end_points[k])
132 |
133 | def forward(self, x, garr):
134 | inp = x
135 | for end_point in self.VALID_ENDPOINTS:
136 | if end_point in self.end_points:
137 | if end_point in ['Mixed_3b']:
138 | x = self._modules[end_point](x)
139 | if self.args.frp:
140 | x = self.frp_module(x, garr) + x
141 | elif end_point in ['Mixed_4b']:
142 | x = self._modules[end_point](x)
143 | if self.args.frp:
144 | x = self.frp_module(x, garr) + x
145 | f = x
146 | elif end_point in ['Mixed_5b']:
147 | x = self._modules[end_point](x)
148 | if self.args.frp:
149 | x = self.frp_module(x, garr) + x
150 | else:
151 | x = self._modules[end_point](x)
152 | feat = x
153 |
154 | x = self.avg_pool(x).view(x.size(0), x.size(1), -1).permute(0, 2, 1)
155 | x = self.LinearMap(x)
156 | cnn_vison = self.rrange(f.sum(dim=1, keepdim=True))
157 | logits, distillation_loss, (att_map, cosin_similar, MHAS, visweight) = self.dtn(x)
158 | # return logits, distillation_loss, (cnn_vison[0].detach(), att_map, inp[0, :],
159 | # cosin_similar, MHAS, (feat, logits[0]))
160 | return logits, distillation_loss, (cnn_vison[0], att_map, cosin_similar, visweight, MHAS, (feat, inp[0, :]))
161 |
--------------------------------------------------------------------------------
/lib/model/DSN_Fusion.py:
--------------------------------------------------------------------------------
1 | '''
2 | This file is modified from:
3 | https://github.com/deepmind/kinetics-i3d/i3d.py
4 | '''
5 |
6 | import torch
7 | import torch.nn as nn
8 | from einops.layers.torch import Rearrange
9 | import torch.nn.functional as F
10 | from torch.autograd import Variable
11 | import numpy as np
12 | import cv2
13 | import os, math
14 | import sys
15 | from .DTN import DTNNet
16 | from .FRP import FRP_Module
17 | from .utils import *
18 |
19 | import os, math
20 | import sys
21 | sys.path.append('../../')
22 | from collections import OrderedDict
23 | from utils import load_pretrained_checkpoint
24 | import logging
25 |
26 |
27 | class DSNNet(nn.Module):
28 | VALID_ENDPOINTS = (
29 | 'Conv3d_1a_7x7',
30 | 'MaxPool3d_2a_3x3',
31 | 'Conv3d_2b_1x1',
32 | 'Conv3d_2c_3x3',
33 | 'MaxPool3d_3a_3x3',
34 |
35 | 'Mixed_3b',
36 | 'Mixed_3c',
37 | 'MaxPool3d_4a_3x3',
38 | 'Mixed_4b',
39 | 'Mixed_4c',
40 | 'MaxPool3d_5a_2x2',
41 | 'Mixed_5b',
42 | 'Mixed_5c'
43 | )
44 |
45 | def __init__(self, args, num_classes=400, spatial_squeeze=True, name='inception_i3d', in_channels=3, dropout_keep_prob=0.5,
46 | pretrained: str = False):
47 |
48 | super(DSNNet, self).__init__()
49 | self._num_classes = num_classes
50 | self._spatial_squeeze = spatial_squeeze
51 | self.logits = None
52 | self.args = args
53 |
54 | self.end_points = {}
55 |
56 | '''
57 | Low Level Features Extraction
58 | '''
59 | end_point = 'Conv3d_1a_7x7'
60 | self.end_points[end_point] = Unit3D(in_channels=in_channels, output_channels=64, kernel_shape=[1, 7, 7],
61 | stride=(1, 2, 2), padding=(0, 3, 3), name=name + end_point)
62 |
63 | end_point = 'MaxPool3d_2a_3x3'
64 | self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[1, 3, 3], stride=(1, 2, 2),
65 | padding=0)
66 |
67 | end_point = 'Conv3d_2b_1x1'
68 | self.end_points[end_point] = Unit3D(in_channels=64, output_channels=64, kernel_shape=[1, 1, 1], padding=0,
69 | name=name + end_point)
70 |
71 | end_point = 'Conv3d_2c_3x3'
72 | self.end_points[end_point] = Unit3D(in_channels=64, output_channels=192, kernel_shape=[1, 3, 3],
73 | padding=(0, 1, 1),
74 | name=name + end_point)
75 |
76 | end_point = 'MaxPool3d_3a_3x3'
77 | self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[1, 3, 3], stride=(1, 2, 2),
78 | padding=0)
79 |
80 | '''
81 | Spatial Multi-scale Features Learning
82 | '''
83 | end_point = 'Mixed_3b'
84 | self.end_points[end_point] = SpatialInceptionModule(192, [64, 96, 128, 16, 32, 32], name + end_point)
85 |
86 | end_point = 'Mixed_3c'
87 | self.end_points[end_point] = SpatialInceptionModule(256, [128, 128, 192, 32, 96, 64], name + end_point)
88 |
89 | end_point = 'MaxPool3d_4a_3x3'
90 | self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[1, 3, 3], stride=(1, 2, 2),
91 | padding=0)
92 |
93 | end_point = 'Mixed_4b'
94 | self.end_points[end_point] = SpatialInceptionModule(128 + 192 + 96 + 64, [192, 96, 208, 16, 48, 64], name + end_point)
95 |
96 | end_point = 'Mixed_4c'
97 | self.end_points[end_point] = SpatialInceptionModule(192 + 208 + 48 + 64, [160, 112, 224, 24, 64, 64], name + end_point)
98 |
99 | end_point = 'MaxPool3d_5a_2x2'
100 | self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[1, 2, 2], stride=(1, 2, 2),
101 | padding=0)
102 |
103 | end_point = 'Mixed_5b'
104 | self.end_points[end_point] = SpatialInceptionModule(160 + 224 + 64 + 64, [256, 160, 320, 32, 128, 128],
105 | name + end_point)
106 |
107 | end_point = 'Mixed_5c'
108 | self.end_points[end_point] = SpatialInceptionModule(256 + 320 + 128 + 128, [384, 192, 384, 48, 128, 128],
109 | name + end_point)
110 |
111 | self.LinearMap = nn.Sequential(
112 | nn.LayerNorm(1024),
113 | nn.Linear(1024, 512),
114 |
115 | )
116 |
117 | self.avg_pool = nn.AdaptiveAvgPool3d((None, 1, 1))
118 | self.dropout = nn.Dropout(dropout_keep_prob)
119 | self.build()
120 | self.dtn = DTNNet(args, num_classes=self._num_classes)
121 | self.rrange = Rearrange('b c t h w -> b t c h w')
122 |
123 | if args.frp:
124 | self.frp_module = FRP_Module(w=args.w, inplanes=64)
125 |
126 | if pretrained:
127 | load_pretrained_checkpoint(self, pretrained)
128 |
129 | def build(self):
130 | for k in self.end_points.keys():
131 | self.add_module(k, self.end_points[k])
132 |
133 | def forward(self, x=None, garr=None, endpoint=None):
134 | if endpoint == 'spatial':
135 | for end_point in self.VALID_ENDPOINTS:
136 | if end_point in self.end_points:
137 | if end_point in ['Mixed_3b']:
138 | x = self._modules[end_point](x)
139 | if self.args.frp:
140 | x = self.frp_module(x, garr) + x
141 | elif end_point in ['Mixed_4b']:
142 | x = self._modules[end_point](x)
143 | if self.args.frp:
144 | x = self.frp_module(x, garr) + x
145 | f = x
146 | elif end_point in ['Mixed_5b']:
147 | x = self._modules[end_point](x)
148 | if self.args.frp:
149 | x = self.frp_module(x, garr) + x
150 | else:
151 | x = self._modules[end_point](x)
152 |
153 | x = self.avg_pool(x).view(x.size(0), x.size(1), -1).permute(0, 2, 1)
154 | x = self.LinearMap(x)
155 | return x
156 | else:
157 | logits, distillation_loss, (att_map, cosin_similar, MHAS, visweight) = self.dtn(x)
158 | return logits, distillation_loss, (att_map, cosin_similar, MHAS, visweight)
159 |
--------------------------------------------------------------------------------
/lib/model/DTN.py:
--------------------------------------------------------------------------------
1 | '''
2 | Copyright (C) 2010-2021 Alibaba Group Holding Limited.
3 | '''
4 |
5 | import torch
6 | from torch.autograd import Variable
7 | from torch import nn, einsum
8 | import torch.nn.functional as F
9 | from torch.nn import init
10 |
11 | from einops import rearrange, repeat
12 | from einops.layers.torch import Rearrange
13 | import numpy as np
14 | import random, math
15 | from .utils import *
16 | from .trans_module import *
17 |
18 | np.random.seed(123)
19 | random.seed(123)
20 |
21 |
22 | class Transformer(nn.Module):
23 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0., apply_transform=False, knn_attention=0.7):
24 | super().__init__()
25 | self.layers = nn.ModuleList([])
26 | for _ in range(depth):
27 | self.layers.append(nn.ModuleList([
28 | PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout,
29 | apply_transform=apply_transform, knn_attention=knn_attention)),
30 | PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))
31 | ]))
32 |
33 | def forward(self, x):
34 | for attn, ff in self.layers:
35 | x = attn(x) + x
36 | x = ff(x) + x
37 | return x
38 |
39 |
40 | class MultiScaleTransformerEncoder(nn.Module):
41 |
42 | def __init__(self, args, small_dim=1024, small_depth=4, small_heads=8, small_dim_head=64, hidden_dim_small=768,
43 | media_dim=1024, media_depth=4, media_heads=8, media_dim_head=64, hidden_dim_media=768,
44 | large_dim=1024, large_depth=4, large_heads=8, large_dim_head=64, hidden_dim_large=768,
45 | dropout=0., Local_flag=True):
46 | super().__init__()
47 |
48 | self.transformer_enc_small = Transformer(small_dim, small_depth, small_heads, small_dim_head,
49 | mlp_dim=hidden_dim_small, dropout=dropout, knn_attention=args.knn_attention)
50 | self.transformer_enc_media = Transformer(media_dim, media_depth, media_heads, media_dim_head,
51 | mlp_dim=hidden_dim_media, dropout=dropout, knn_attention=args.knn_attention)
52 | self.transformer_enc_large = Transformer(large_dim, large_depth, large_heads, large_dim_head,
53 | mlp_dim=hidden_dim_large, dropout=dropout, knn_attention=args.knn_attention)
54 | if Local_flag:
55 | self.Mixed_small = TemporalInceptionModule(512, [160,112,224,24,64,64], 'Mixed_small')
56 | self.Mixed_media = TemporalInceptionModule(512, [160,112,224,24,64,64], 'Mixed_media')
57 | self.Mixed_large = TemporalInceptionModule(512, [160, 112, 224, 24, 64, 64], 'Mixed_large')
58 | self.MaxPool = MaxPool3dSamePadding(kernel_size=[3, 1, 1], stride=(1, 1, 1), padding=0)
59 |
60 | def forward(self, xs, xm, xl, Local_flag=False):
61 | # Local Modeling
62 | if Local_flag:
63 | cls_small = xs[:, 0]
64 | xs = self.Mixed_small(xs[:, 1:, :].permute(0, 2, 1).view(xs.size(0), xs.size(-1), -1, 1, 1))
65 | xs = self.MaxPool(xs)
66 | xs = torch.cat((cls_small.unsqueeze(1), xs.view(xs.size(0), xs.size(1), -1).permute(0, 2, 1)), dim=1)
67 |
68 | cls_media = xm[:, 0]
69 | xm = self.Mixed_media(xm[:, 1:, :].permute(0, 2, 1).view(xm.size(0), xm.size(-1), -1, 1, 1))
70 | xm = self.MaxPool(xm)
71 | xm = torch.cat((cls_media.unsqueeze(1), xm.view(xm.size(0), xm.size(1), -1).permute(0, 2, 1)), dim=1)
72 |
73 | cls_large = xl[:, 0]
74 | xl = self.Mixed_large(xl[:, 1:, :].permute(0, 2, 1).view(xl.size(0), xl.size(-1), -1, 1, 1))
75 | xl = self.MaxPool(xl)
76 | xl = torch.cat((cls_large.unsqueeze(1), xl.view(xl.size(0), xl.size(1), -1).permute(0, 2, 1)), dim=1)
77 |
78 | # Global Modeling
79 | xs = self.transformer_enc_small(xs)
80 | xm = self.transformer_enc_media(xm)
81 | xl = self.transformer_enc_large(xl)
82 |
83 | return xs, xm, xl
84 |
85 |
86 | class RCMModule(nn.Module):
87 | def __init__(self, args, dim_head=64, method='New', merge='GAP'):
88 | super(RCMModule, self).__init__()
89 | self.merge = merge
90 | self.heads = args.SEHeads
91 | self.avg_pool = nn.AdaptiveAvgPool1d(1)
92 | self.avg_pool3d = nn.AdaptiveAvgPool3d((None, 1, None))
93 |
94 | # Self Attention Layers
95 | self.q = nn.Linear(64, dim_head * self.heads, bias=False)
96 | self.k = nn.Linear(64, dim_head * self.heads, bias=False)
97 | self.scale = dim_head ** -0.5
98 |
99 | self.method = method
100 | if method == 'Ori':
101 | self.norm = nn.LayerNorm(128)
102 | self.project = nn.Sequential(
103 | nn.Linear(64, 512, bias=False),
104 | nn.GELU(),
105 | nn.Linear(512, 512, bias=False),
106 | nn.LayerNorm(512)
107 | )
108 | elif method == 'New':
109 | if args.dataset == 'THU':
110 | hidden_dim = 128
111 | else:
112 | hidden_dim = 256
113 | self.project = nn.Sequential(
114 | nn.Linear(64, hidden_dim, bias=False),
115 | nn.GELU(),
116 | nn.Linear(hidden_dim, 64, bias=False),
117 | nn.LayerNorm(64),
118 | )
119 | self.linear = nn.Linear(64, 512)
120 | # init.kaiming_uniform_(self.linear, a=math.sqrt(5))
121 |
122 | if self.heads > 1:
123 | self.mergefc = nn.Sequential(
124 | nn.Dropout(0.4),
125 | nn.Linear(512 * self.heads, 512, bias=False),
126 | nn.LayerNorm(512)
127 | )
128 |
129 | def forward(self, x):
130 | b, c, t = x.shape
131 | inp = x.clone()
132 |
133 | # Sequence (Y) direction
134 | xd_weight = self.project(self.avg_pool(inp.permute(0, 2, 1)).view(b, -1))
135 | xd_weight = torch.sigmoid(xd_weight).view(b, -1, 1)
136 |
137 | # Feature (X) direction
138 | q, k = self.q(x), self.k(x)
139 | q, k = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), [q, k])
140 | dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
141 | if self.merge == 'mean':
142 | dots = dots.mean(dim=2)
143 | elif self.merge == 'GAP':
144 | dots = self.avg_pool3d(dots).squeeze()
145 |
146 | if self.heads > 1:
147 | dots = dots.view(b, -1)
148 | dots = self.mergefc(dots)
149 | else:
150 | dots = dots.squeeze()
151 | y = torch.sigmoid(dots).view(b, c, 1)
152 |
153 | if self.method == 'Ori':
154 | out = x * (y.expand_as(x) + xd_weight.expand_as(x))
155 | visweight = xd_weight # for visualization
156 | return out, xd_weight, visweight
157 |
158 | elif self.method == 'New':
159 | weight = einsum('b i d, b j d -> b i j', xd_weight, y)
160 | out = x * weight.permute(0, 2, 1)
161 | visweight = weight # for visualization
162 | return out, self.linear(xd_weight.squeeze()), visweight
163 |
164 | class DTNNet(nn.Module):
165 | def __init__(self, args, num_classes=249, small_dim=512, media_dim=512, large_dim=512,
166 | small_depth=1, media_depth=1, large_depth=1,
167 | heads=8, pool='cls', dropout=0.1, emb_dropout=0.0, branch_merge='pool',
168 | init: bool = False,
169 | warmup_temp_epochs: int = 30):
170 | super().__init__()
171 | self.low_frames = args.low_frames
172 | self.media_frames = args.media_frames
173 | self.high_frames = args.high_frames
174 | self.branch_merge = branch_merge
175 | self._args = args
176 | warmup_temp, temp = map(float, args.temp)
177 |
178 | multi_scale_enc_depth = args.N
179 | num_patches_small = self.low_frames
180 | num_patches_media = self.media_frames
181 | num_patches_large = self.high_frames
182 |
183 | self.pos_embedding_small = nn.Parameter(torch.randn(1, num_patches_small + 1, small_dim))
184 | self.cls_token_small = nn.Parameter(torch.randn(1, 1, small_dim))
185 | self.dropout_small = nn.Dropout(emb_dropout)
186 |
187 | self.pos_embedding_media = nn.Parameter(torch.randn(1, num_patches_media + 1, media_dim))
188 | self.cls_token_media = nn.Parameter(torch.randn(1, 1, media_dim))
189 | self.dropout_media = nn.Dropout(emb_dropout)
190 |
191 | self.pos_embedding_large = nn.Parameter(torch.randn(1, num_patches_large + 1, large_dim))
192 | self.cls_token_large = nn.Parameter(torch.randn(1, 1, large_dim))
193 | self.dropout_large = nn.Dropout(emb_dropout)
194 |
195 | self.multi_scale_transformers = nn.ModuleList([])
196 | Local_flag = True
197 | for _ in range(multi_scale_enc_depth):
198 | self.multi_scale_transformers.append(
199 | MultiScaleTransformerEncoder(args, small_dim=small_dim, small_depth=small_depth,
200 | small_heads=heads,
201 |
202 | media_dim=media_dim, media_depth=media_depth,
203 | media_heads=heads,
204 |
205 | large_dim=large_dim, large_depth=large_depth,
206 | large_heads=heads,
207 | dropout=dropout,
208 | Local_flag=Local_flag))
209 | Local_flag = False
210 | self.pool = pool
211 | # self.to_latent = nn.Identity()
212 | self.avg_pool = nn.AdaptiveAvgPool1d(1)
213 | self.max_pool = nn.AdaptiveMaxPool1d(1)
214 |
215 | if self._args.recoupling:
216 | self.rcm = RCMModule(args)
217 |
218 | if args.Network != 'FusionNet':
219 | self.mlp_head_small = nn.Sequential(
220 | nn.LayerNorm(small_dim),
221 | nn.Linear(small_dim, num_classes),
222 | # nn.Dropout(0.4)
223 | )
224 | self.mlp_head_media = nn.Sequential(
225 | nn.LayerNorm(media_dim),
226 | nn.Linear(media_dim, num_classes),
227 | # nn.Dropout(0.4)
228 |
229 | )
230 |
231 | self.mlp_head_large = nn.Sequential(
232 | nn.LayerNorm(large_dim),
233 | nn.Linear(large_dim, num_classes),
234 | # nn.Dropout(0.4)
235 | )
236 |
237 | self.show_res = Rearrange('b t (c p1 p2) -> b t c p1 p2', p1=int(small_dim ** 0.5), p2=int(small_dim ** 0.5))
238 | self.temp_schedule = np.concatenate((
239 | np.linspace(warmup_temp,
240 | temp, warmup_temp_epochs),
241 | np.ones(args.epochs - warmup_temp_epochs) * temp
242 | ))
243 |
244 | if init:
245 | self.init_weights()
246 |
247 | @torch.no_grad()
248 | def init_weights(self):
249 | def _init(m):
250 | if isinstance(m, nn.Linear):
251 | nn.init.xavier_uniform_(
252 | m.weight) # _trunc_normal(m.weight, std=0.02) # from .initialization import _trunc_normal
253 | if hasattr(m, 'bias') and m.bias is not None:
254 | nn.init.normal_(m.bias, std=1e-6) # nn.init.constant(m.bias, 0)
255 |
256 | self.apply(_init)
257 |
258 | # ----------------------------------
259 | # frames simple function
260 | # ----------------------------------
261 | def f(self, n, sn):
262 | SL = lambda n, sn: [(lambda n, arr: n if arr == [] else random.choice(arr))(n * i / sn,
263 | range(int(n * i / sn),
264 | max(int(n * i / sn) + 1,
265 | int(n * (
266 | i + 1) / sn))))
267 | for i in range(sn)]
268 | return SL(n, sn)
269 |
270 | def forward(self, img): # img size: [2, 64, 1024]
271 | # ----------------------------------
272 | # Recoupling:
273 | # ----------------------------------
274 | if self._args.recoupling:
275 | img, spatial_weights, visweight = self.rcm(img.permute(0, 2, 1))
276 | img = img.permute(0, 2, 1)
277 | else:
278 | visweight = img
279 |
280 | # ----------------------------------
281 | sl_low = self.f(img.size(1), self.low_frames)
282 | xs = img[:, sl_low, :]
283 | b, n, _ = xs.shape
284 |
285 | cls_token_small = repeat(self.cls_token_small, '() n d -> b n d', b=b)
286 | xs = torch.cat((cls_token_small, xs), dim=1)
287 | xs += self.pos_embedding_small[:, :(n + 1)]
288 | xs = self.dropout_small(xs)
289 |
290 | # ----------------------------------
291 | sl_media = self.f(img.size(1), self.media_frames)
292 | xm = img[:, sl_media, :]
293 | b, n, _ = xm.shape
294 |
295 | cls_token_media = repeat(self.cls_token_media, '() n d -> b n d', b=b)
296 | xm = torch.cat((cls_token_media, xm), dim=1)
297 | xm += self.pos_embedding_media[:, :(n + 1)]
298 | xm = self.dropout_media(xm)
299 |
300 | # ----------------------------------
301 | sl_high = self.f(img.size(1), self.high_frames)
302 | xl = img[:, sl_high, :]
303 | b, n, _ = xl.shape
304 |
305 | cls_token_large = repeat(self.cls_token_large, '() n d -> b n d', b=b)
306 | xl = torch.cat((cls_token_large, xl), dim=1)
307 | xl += self.pos_embedding_large[:, :(n + 1)]
308 | xl = self.dropout_large(xl)
309 |
310 | # ----------------------------------
311 | # Temporal Multi-scale features learning
312 | # ----------------------------------
313 | Local_flag = True
314 | for multi_scale_transformer in self.multi_scale_transformers:
315 | xs, xm, xl = multi_scale_transformer(xs, xm, xl, Local_flag)
316 | Local_flag = False
317 |
318 | xs = xs.mean(dim=1) if self.pool == 'mean' else xs[:, 0]
319 | xm = xm.mean(dim=1) if self.pool == 'mean' else xm[:, 0]
320 | xl = xl.mean(dim=1) if self.pool == 'mean' else xl[:, 0]
321 |
322 | if self._args.recoupling:
323 | T = self._args.temper
324 | distillation_loss = F.kl_div(F.log_softmax(spatial_weights.squeeze() / T, dim=-1),
325 | F.softmax(((xs + xm + xl) / 3.).detach() / T, dim=-1),
326 | reduction='sum')
327 | else:
328 | distillation_loss = torch.zeros(1).cuda()
329 |
330 | if self._args.Network != 'FusionNet':
331 | if self._args.sharpness:
332 | temp = self.temp_schedule[self._args.epoch]
333 | xs = self.mlp_head_small(xs) / temp
334 | xm = self.mlp_head_media(xm) / temp
335 | xl = self.mlp_head_large(xl) / temp
336 | else:
337 | xs = self.mlp_head_small(xs)
338 | xm = self.mlp_head_media(xm)
339 | xl = self.mlp_head_large(xl)
340 |
341 | if self.branch_merge == 'sum':
342 | x = xs + xm + xl
343 | elif self.branch_merge == 'pool':
344 | x = self.max_pool(torch.cat((xs.unsqueeze(2), xm.unsqueeze(2), xl.unsqueeze(2)), dim=-1)).squeeze()
345 |
346 | # ---------------------------------
347 | # Get score from multi-branch Trans for visualization
348 | # ---------------------------------
349 | scores_small = self.multi_scale_transformers[2].transformer_enc_small.layers[-1][0].fn.scores
350 | scores_media = self.multi_scale_transformers[2].transformer_enc_media.layers[-1][0].fn.scores
351 | scores_large = self.multi_scale_transformers[2].transformer_enc_large.layers[-1][0].fn.scores
352 |
353 | # resize attn
354 | attn_media = scores_media.detach().clone()
355 | attn_media.resize_(*scores_small.size())
356 |
357 | attn_large = scores_large.detach().clone()
358 | attn_large.resize_(*scores_small.size())
359 |
360 | att_small = scores_small.detach().clone()
361 |
362 | scores = torch.cat((att_small, attn_media, attn_large), dim=1) # [2, 24, 17, 17]
363 | att_map = torch.zeros(scores.size(0), scores.size(1), scores.size(1), dtype=torch.float)
364 | for b in range(scores.size(0)):
365 | for i, s1 in enumerate(scores[b]):
366 | for j, s2 in enumerate(scores[b]):
367 | cosin_simil = torch.cosine_similarity(s1.view(1, -1), s2.view(1, -1))
368 | att_map[b][i][j] = cosin_simil
369 |
370 | # --------------------------------
371 |         # Measure pairwise cosine similarity among xs, xm and xl
372 | # --------------------------------
373 | cosin_similar_xs_xm = torch.cosine_similarity(xs[0], xm[0], dim=-1)
374 | cosin_similar_xs_xl = torch.cosine_similarity(xs[0], xl[0], dim=-1)
375 | cosin_similar_xm_xl = torch.cosine_similarity(xm[0], xl[0], dim=-1)
376 | cosin_similar_sum = cosin_similar_xs_xm + cosin_similar_xs_xl + cosin_similar_xm_xl
377 |
378 | return (x, xs, xm, xl), distillation_loss, (att_map, cosin_similar_sum.cpu(),
379 | (scores_small[0], scores_media[0], scores_large[0]), visweight[0])
--------------------------------------------------------------------------------
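`temp_schedule` in `DTNNet.__init__` is a per-epoch temperature used to sharpen the branch logits when `args.sharpness` is set: it ramps linearly from `warmup_temp` to `temp` over `warmup_temp_epochs` epochs and stays constant afterwards. A quick sketch with hypothetical values (the real ones come from `args.temp` and the epoch counts in the config):

import numpy as np

warmup_temp, temp = 0.04, 0.07          # hypothetical values
warmup_temp_epochs, epochs = 30, 100    # hypothetical values

temp_schedule = np.concatenate((
    np.linspace(warmup_temp, temp, warmup_temp_epochs),   # linear warm-up
    np.ones(epochs - warmup_temp_epochs) * temp           # constant afterwards
))
print(temp_schedule[0], temp_schedule[29], temp_schedule[-1])  # 0.04 0.07 0.07
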
/lib/model/FRP.py:
--------------------------------------------------------------------------------
1 | '''
2 | This file is modified from:
3 | https://github.com/zhoubenjia/RAAR3DNet/blob/master/Network_Train/lib/model/RAAR3DNet.py
4 | '''
5 |
6 | import torch
7 | import torch.nn as nn
8 | from einops.layers.torch import Rearrange
9 | import torch.nn.functional as F
10 | from torch.autograd import Variable
11 | from torchvision import transforms
12 | import numpy as np
13 | import cv2
14 | from torchvision.utils import save_image, make_grid
15 |
16 | def tensor_split(t):
17 | arr = torch.split(t, 1, dim=2)
18 | arr = [x.squeeze(2) for x in arr]
19 | return arr
20 |
21 | def tensor_merge(arr):
22 | arr = [x.unsqueeze(1) for x in arr]
23 | t = torch.cat(arr, dim=1)
24 | return t.permute(0, 2, 1, 3, 4)
25 |
26 | class FRP_Module(nn.Module):
27 | def __init__(self, w, inplanes):
28 | super(FRP_Module, self).__init__()
29 | self._w = w
30 | self.rpconv1d = nn.Conv1d(2, 1, 1, bias=False) # Rank Pooling Conv1d, Kernel Size 2x1x1
31 | self.rpconv1d.weight.data = torch.FloatTensor([[[1.0], [0.0]]])
32 | # self.bnrp = nn.BatchNorm3d(inplanes) # BatchNorm Rank Pooling
33 | # self.relu = nn.ReLU(inplace=True)
34 | self.hapooling = nn.MaxPool2d(kernel_size=2)
35 |
36 | def forward(self, x, datt=None):
37 | inp = x
38 | if self._w < 1:
39 | return x
40 | def run_layer_on_arr(arr, l):
41 | return [l(x) for x in arr]
42 | def oneconv(a, b):
43 | s = a.size()
44 | c = torch.cat([a.contiguous().view(s[0], -1, 1), b.contiguous().view(s[0], -1, 1)], dim=2)
45 | c = self.rpconv1d(c.permute(0, 2, 1)).permute(0, 2, 1)
46 | return c.view(s)
47 | if datt is not None:
48 | tarr = tensor_split(x)
49 | garr = tensor_split(datt)
50 | while tarr[0].size()[3] < garr[0].size()[3]: # keep feature map and heatmap the same size
51 | garr = run_layer_on_arr(garr, self.hapooling)
52 |
53 | attarr = [a * (b + torch.ones(a.size()).cuda()) for a, b in zip(tarr, garr)]
54 | datt = [oneconv(a, b) for a, b in zip(tarr, attarr)]
55 | return tensor_merge(datt)
56 |
57 | def tensor_arr_rp(arr):
58 | l = len(arr)
59 | def tensor_rankpooling(video_arr):
60 | def get_w(N):
61 | return [float(i) * 2 - N - 1 for i in range(1, N + 1)]
62 |
63 | # re = torch.zeros(video_arr[0].size(0), 1, video_arr[0].size(2), video_arr[0].size(3)).cuda()
64 | re = torch.zeros(video_arr[0].size()).cuda()
65 | for a, b in zip(video_arr, get_w(len(video_arr))):
66 | # a = transforms.Grayscale(1)(a)
67 | re += a * b
68 | re = F.gelu(re)
69 | re -= torch.min(re)
70 | re = re / torch.max(re) if torch.max(re) != 0 else re / (torch.max(re) + 0.00001)
71 | return transforms.Grayscale(1)(re)
72 |
73 | return [tensor_rankpooling(arr[i:i + self._w]) for i in range(l)]
74 |
75 | arrrp = tensor_arr_rp(tensor_split(x))
76 |
77 | b, c, t, h, w = tensor_merge(arrrp).shape
78 | mask = torch.zeros(b, c, self._w-1, h, w, device=tensor_merge(arrrp).device)
79 | garrs = torch.cat((mask, tensor_merge(arrrp)), dim=2)
80 | return garrs
81 |
82 | if __name__ == '__main__':
83 |     model = FRP_Module(w=4, inplanes=64).cuda()  # w=4 is an arbitrary example window size
84 | inp = torch.randn(2, 3, 64, 224, 224).cuda()
85 | out = model(inp)
86 | print(out.shape)
87 |
--------------------------------------------------------------------------------
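Both `FRP_Module` above and the `DynamicImage` routine in the dataset code weight frames with `get_w(N) = [2i - N - 1 for i = 1..N]`, a zero-sum linear ramp, so the weighted sum emphasises late frames over early ones. A toy sketch of the weighting on a synthetic clip:

import torch

def get_w(N):
    # Zero-sum linear ramp: 2*i - N - 1 for i = 1..N.
    return [float(i) * 2 - N - 1 for i in range(1, N + 1)]

print(get_w(5))  # [-4.0, -2.0, 0.0, 2.0, 4.0]

# Toy clip of 5 single-channel "frames" with increasing intensity: the weighted
# sum favours late frames over early ones and so highlights temporal change.
frames = [torch.full((1, 4, 4), float(t)) for t in range(5)]
dyn = sum(w * f for w, f in zip(get_w(len(frames)), frames))
print(dyn.mean())  # tensor(20.)
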
/lib/model/__init__.py:
--------------------------------------------------------------------------------
1 | '''
2 | Copyright (C) 2010-2021 Alibaba Group Holding Limited.
3 | '''
4 |
5 | from .build import *
--------------------------------------------------------------------------------
/lib/model/build.py:
--------------------------------------------------------------------------------
1 | '''
2 | Copyright (C) 2010-2021 Alibaba Group Holding Limited.
3 | '''
4 |
5 | from .DSN import DSNNet
6 | from .fusion_Net import CrossFusionNet
7 |
8 | import logging
9 |
10 | def build_model(args):
11 | num_classes = dict(
12 | IsoGD=249,
13 | NvGesture=25,
14 | Jester=27,
15 | THUREAD=40,
16 | NTU=60
17 | )
18 | func_dict = dict(
19 | I3DWTrans=DSNNet,
20 | FusionNet=CrossFusionNet
21 | )
22 | assert args.dataset in num_classes, 'Error in load dataset !'
23 | assert args.Network in func_dict, 'Error in Network function !'
24 | args.num_classes = num_classes[args.dataset]
25 | if args.local_rank == 0:
26 | logging.info('Model:{}, Total Categories:{}'.format(args.Network, args.num_classes))
27 | return func_dict[args.Network](args, num_classes=args.num_classes, pretrained=args.pretrained)
28 |
--------------------------------------------------------------------------------
/lib/model/fusion_Net.py:
--------------------------------------------------------------------------------
1 | '''
2 | Copyright (C) 2010-2021 Alibaba Group Holding Limited.
3 | '''
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from torch.autograd import Variable
9 | from collections import OrderedDict
10 |
11 | import numpy as np
12 |
13 | import os
14 | import sys
15 | from collections import OrderedDict
16 |
17 | sys.path.extend(['../../', '../'])
18 | from utils import load_pretrained_checkpoint, load_checkpoint
19 | import logging
20 | from .DSN_Fusion import DSNNet
21 |
22 | class LabelSmoothingCrossEntropy(torch.nn.Module):
23 | def __init__(self, smoothing: float = 0.1,
24 | reduction="mean", weight=None):
25 | super(LabelSmoothingCrossEntropy, self).__init__()
26 | self.smoothing = smoothing
27 | self.reduction = reduction
28 | self.weight = weight
29 |
30 | def reduce_loss(self, loss):
31 | return loss.mean() if self.reduction == 'mean' else loss.sum() \
32 | if self.reduction == 'sum' else loss
33 |
34 | def linear_combination(self, x, y):
35 | return self.smoothing * x + (1 - self.smoothing) * y
36 |
37 | def forward(self, preds, target):
38 | assert 0 <= self.smoothing < 1
39 |
40 | if self.weight is not None:
41 | self.weight = self.weight.to(preds.device)
42 |
43 | n = preds.size(-1)
44 | log_preds = F.log_softmax(preds, dim=-1)
45 | loss = self.reduce_loss(-log_preds.sum(dim=-1))
46 | nll = F.nll_loss(
47 | log_preds, target, reduction=self.reduction, weight=self.weight
48 | )
49 | return self.linear_combination(loss / n, nll)
50 |
51 |
52 | class Encoder(nn.Module):
53 | def __init__(self, C_in, C_out, dilation=2):
54 | super(Encoder, self).__init__()
55 | self.enconv = nn.Sequential(
56 | nn.Conv2d(C_in, C_in, kernel_size=1, stride=1, padding=0, bias=False),
57 | nn.BatchNorm2d(C_in),
58 | nn.ReLU(inplace=False),
59 |
60 | nn.Conv2d(C_in, C_in // 2, kernel_size=1, stride=1, padding=0, bias=False),
61 | nn.BatchNorm2d(C_in // 2),
62 | nn.ReLU(inplace=False),
63 |
64 | nn.Conv2d(C_in // 2, C_in // 4, kernel_size=1, stride=1, padding=0, bias=False),
65 | nn.BatchNorm2d(C_in // 4),
66 | nn.ReLU(inplace=False),
67 |
68 | nn.Conv2d(C_in // 4, C_out, kernel_size=1, stride=1, padding=0, bias=False),
69 | )
70 |
71 | def forward(self, x1, x2):
72 | b, c = x1.shape
73 | x = torch.cat((x1, x2), dim=1).view(b, -1, 1, 1)
74 | x = self.enconv(x)
75 | return x
76 |
77 |
78 | class Decoder(nn.Module):
79 | def __init__(self, C_in, C_out, dilation=2):
80 | super(Decoder, self).__init__()
81 | self.deconv = nn.Sequential(
82 | nn.Conv2d(C_in, C_out // 4, kernel_size=1, padding=0, bias=False),
83 | nn.BatchNorm2d(C_out // 4),
84 | nn.ReLU(),
85 |
86 | nn.Conv2d(C_out // 4, C_out // 2, kernel_size=1, padding=0, bias=False),
87 | nn.BatchNorm2d(C_out // 2),
88 | nn.ReLU(),
89 | )
90 |
91 | def forward(self, x):
92 | x = self.deconv(x)
93 | return x
94 |
95 |
96 | class FusionModule(nn.Module):
97 | def __init__(self, channel_in=1024, channel_out=256, num_classes=60):
98 | super(FusionModule, self).__init__()
99 | self.encoder = Encoder(channel_in, channel_out)
100 | self.decoder = Decoder(channel_out, channel_in)
101 | self.efc = nn.Conv2d(channel_out, num_classes, kernel_size=1, padding=0, bias=False)
102 |
103 | def forward(self, r, d):
104 | en_x = self.encoder(r, d) # [4, 256, 1, 1]
105 | de_x = self.decoder(en_x)
106 | en_x = self.efc(en_x)
107 | return en_x.squeeze(), de_x
108 |
109 | class CrossFusionNet(nn.Module):
110 | def __init__(self, args, num_classes, pretrained, spatial_interact=True, temporal_interact=True):
111 | super(CrossFusionNet, self).__init__()
112 | self._MES = torch.nn.MSELoss()
113 | self._BCE = torch.nn.BCELoss()
114 | self._CE = LabelSmoothingCrossEntropy()
115 | self.spatial_interact = spatial_interact
116 | self.temporal_interact = temporal_interact
117 |
118 | self.fusion_model = FusionModule(channel_out=256, num_classes=num_classes)
119 | self.avg_pool = nn.AdaptiveAvgPool3d(1)
120 | self.fc = nn.Conv2d(512, 1, kernel_size=1, padding=0, bias=False)
121 | self.dropout = nn.Dropout(0.5)
122 |
123 | assert args.rgb_checkpoint and args.depth_checkpoint
124 | self.Modalit_rgb = DSNNet(args, num_classes=num_classes,
125 | pretrained=args.rgb_checkpoint)
126 |
127 | self.Modalit_depth = DSNNet(args, num_classes=num_classes,
128 | pretrained=args.depth_checkpoint)
129 |
130 | if self.spatial_interact:
131 | self.crossFusion = nn.Sequential(
132 | nn.Conv2d(512 * 2, 512, kernel_size=1, stride=1, padding=0, bias=False),
133 | nn.BatchNorm2d(512),
134 | nn.ReLU(),
135 | nn.Conv2d(512, 512, kernel_size=1, stride=1, padding=0, bias=False),
136 | nn.Dropout(0.4)
137 |
138 | )
139 | if self.temporal_interact:
140 | self.crossFusionT = nn.Sequential(
141 | nn.Conv2d(512 * 2, 512, kernel_size=1, stride=1, padding=0, bias=False),
142 | nn.BatchNorm2d(512),
143 | nn.ReLU(),
144 | nn.Conv2d(512, 512, kernel_size=1, stride=1, padding=0, bias=False),
145 | nn.Dropout(0.4)
146 | )
147 |
148 | self.classifier1 = nn.Sequential(
149 | nn.LayerNorm(512),
150 | nn.Linear(512, num_classes)
151 | )
152 | self.classifier2 = nn.Sequential(
153 | nn.LayerNorm(512),
154 | nn.Linear(512, num_classes)
155 | )
156 |
157 | if pretrained:
158 | load_pretrained_checkpoint(self, pretrained)
159 | logging.info("Load Pre-trained model state_dict Done !")
160 |
161 | def forward(self, inputs, garrs, target):
162 | rgb, depth = inputs
163 | rgb_garr, depth_garr = garrs
164 |
165 | spatial_M = self.Modalit_rgb(rgb, rgb_garr, endpoint='spatial')
166 | spatial_K = self.Modalit_depth(depth, depth_garr, endpoint='spatial')
167 |
168 | if self.spatial_interact:
169 | b, t, c = spatial_M.shape
170 | spatial_fusion_features = self.crossFusion(F.normalize(torch.cat((spatial_M, spatial_K), dim=-1), p = 2, dim=-1).view(b, c*2, t, 1)).squeeze()
171 |
172 | (temporal_M, M_xs, M_xm, M_xl), distillationM, _ = self.Modalit_rgb(x=spatial_M + spatial_fusion_features.view(spatial_M.shape) if self.spatial_interact else spatial_M,
173 | endpoint='temporal') # size[4, 512]
174 | (temporal_K, K_xs, K_xm, K_xl), distillationK, _ = self.Modalit_depth(x=spatial_K + spatial_fusion_features.view(spatial_M.shape) if self.spatial_interact else spatial_K,
175 | endpoint='temporal')
176 | logit_r = self.classifier1(temporal_M)
177 | logit_d = self.classifier2(temporal_K)
178 |
179 | if self.temporal_interact:
180 | b, c = temporal_M.shape
181 | temporal_fusion_features = self.crossFusionT(F.normalize(torch.cat((temporal_M, temporal_K), dim=-1), p = 2, dim=-1).view(b, c*2, 1, 1)).squeeze()
182 | temporal_M, temporal_K = temporal_M+temporal_fusion_features, temporal_K+temporal_fusion_features
183 |
184 | en_x, de_x = self.fusion_model(temporal_M, temporal_K)
185 | b, c = temporal_M.shape
186 | bce_r = torch.sigmoid(self.fc(self.dropout(temporal_M).view(b, c, 1, 1))).view(b, -1)
187 | bce_d = torch.sigmoid(self.fc(self.dropout(temporal_K).view(b, c, 1, 1))).view(b, -1)
188 |
189 | BCE_loss = self._BCE(bce_r, torch.ones(bce_r.size(0), 1).cuda()) + self._BCE(bce_d, torch.zeros(bce_d.size(0),
190 | 1).cuda())
191 | MSE_loss = self._MES(de_x.view(b, c), temporal_M) + self._MES(de_x.view(b, c), temporal_K)
192 | CE_loss = self._CE(en_x, target) + self._CE(logit_r, target) + self._CE(logit_d, target)
193 | distillation = distillationM + distillationK
194 |
195 | return (en_x, logit_r, logit_d), (CE_loss, BCE_loss, MSE_loss, distillation)
196 |
--------------------------------------------------------------------------------
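For reference, `LabelSmoothingCrossEntropy` above follows the common formulation `smoothing * (uniform cross-entropy) + (1 - smoothing) * NLL`, which is numerically the same as cross-entropy against targets smoothed towards the uniform distribution. A small check of that equivalence with arbitrary values:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
eps, n = 0.1, 5
preds = torch.randn(3, n)
target = torch.tensor([0, 2, 4])
log_p = F.log_softmax(preds, dim=-1)

# Formulation used above: smoothing * (uniform CE) + (1 - smoothing) * NLL.
fastai_style = eps * (-log_p.sum(dim=-1).mean() / n) + (1 - eps) * F.nll_loss(log_p, target)

# Cross-entropy against explicitly smoothed targets.
q = (1 - eps) * F.one_hot(target, n).float() + eps / n
smoothed_ce = -(q * log_p).sum(dim=-1).mean()

print(torch.allclose(fastai_style, smoothed_ce, atol=1e-6))  # True
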
/lib/model/trans_module.py:
--------------------------------------------------------------------------------
1 | '''
2 | This file is modified from:
3 | https://github.com/rishikksh20/CrossViT-pytorch/blob/master/crossvit.py
4 | '''
5 |
6 | import torch
7 | from torch import nn, einsum
8 | import torch.nn.functional as F
9 |
10 | import math
11 |
12 | from einops import rearrange, repeat
13 | from einops.layers.torch import Rearrange
14 |
15 | class Residual(nn.Module):
16 | def __init__(self, fn):
17 | super().__init__()
18 | self.fn = fn
19 |
20 | def forward(self, x, **kwargs):
21 | return self.fn(x, **kwargs) + x
22 |
23 |
24 | class PreNorm(nn.Module):
25 | def __init__(self, dim, fn):
26 | super().__init__()
27 | self.norm = nn.LayerNorm(dim)
28 | self.fn = fn
29 |
30 | def forward(self, x, **kwargs):
31 | return self.fn(self.norm(x), **kwargs)
32 |
33 |
34 |
35 | # class FeedForward(nn.Module):
36 | # def __init__(self, dim, hidden_dim, dropout=0.):
37 | # super().__init__()
38 | # self.net = nn.Sequential(
39 | # nn.Linear(dim, hidden_dim),
40 | # nn.GELU(),
41 | # nn.Dropout(dropout),
42 | # nn.Linear(hidden_dim, dim),
43 | # nn.Dropout(dropout)
44 | # )
45 |
46 | # def forward(self, x):
47 | # return self.net(x)
48 |
49 | class FeedForward(nn.Module):
50 | """FeedForward Neural Networks for each position"""
51 | def __init__(self, dim, hidden_dim, dropout=0.):
52 | super().__init__()
53 | self.fc1 = nn.Linear(dim, hidden_dim)
54 | self.fc2 = nn.Linear(hidden_dim, dim)
55 | self.dropout = nn.Dropout(dropout)
56 |
57 | def forward(self, x):
58 | # (B, S, D) -> (B, S, D_ff) -> (B, S, D)
59 | return self.dropout(self.fc2(self.dropout(F.gelu(self.fc1(x)))))
60 |
61 | class Attention(nn.Module):
62 | def __init__(self, dim, heads=8, dim_head=64, dropout=0., apply_transform=False, transform_scale=True, knn_attention=0.7):
63 | super().__init__()
64 | inner_dim = dim_head * heads
65 | project_out = not (heads == 1 and dim_head == dim)
66 |
67 | self.heads = heads
68 | self.scale = dim_head ** -0.5
69 | self.apply_transform = apply_transform
70 | self.knn_attention = bool(knn_attention)
71 | self.topk = knn_attention
72 |
73 | if apply_transform:
74 | self.reatten_matrix = torch.nn.Conv2d(heads, heads, 1, 1)
75 | self.var_norm = torch.nn.BatchNorm2d(heads)
76 | self.reatten_scale = self.scale if transform_scale else 1.0
77 |
78 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
79 |
80 | self.to_out = nn.Sequential(
81 | nn.Linear(inner_dim, dim),
82 | nn.Dropout(dropout)
83 | ) if project_out else nn.Identity()
84 | self.scores = None
85 |
86 | def forward(self, x):
87 | b, n, _, h = *x.shape, self.heads
88 | qkv = self.to_qkv(x).chunk(3, dim=-1)
89 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), qkv)
90 | dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
91 |
92 | if self.knn_attention:
93 | mask = torch.zeros(b, self.heads, n, n, device=x.device, requires_grad=False)
94 | index = torch.topk(dots, k=int(dots.size(-1)*self.topk), dim=-1, largest=True)[1]
95 | mask.scatter_(-1, index, 1.)
96 | dots = torch.where(mask > 0, dots, torch.full_like(dots, float('-inf')))
97 | attn = dots.softmax(dim=-1)
98 | if self.apply_transform:
99 | attn = self.var_norm(self.reatten_matrix(attn)) * self.reatten_scale
100 |
101 | self.scores = attn
102 | out = einsum('b h i j, b h j d -> b h i d', attn, v)
103 |
104 |
105 | out = rearrange(out, 'b h n d -> b n (h d)')
106 | out = self.to_out(out)
107 | return out
108 |
109 |
110 | class CrossAttention(nn.Module):
111 | def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
112 | super().__init__()
113 | inner_dim = dim_head * heads
114 | project_out = not (heads == 1 and dim_head == dim)
115 |
116 | self.heads = heads
117 | self.scale = dim_head ** -0.5
118 |
119 | self.to_k = nn.Linear(dim, inner_dim, bias=False)
120 | self.to_v = nn.Linear(dim, inner_dim, bias=False)
121 | self.to_q = nn.Linear(dim, inner_dim, bias=False)
122 |
123 | self.to_out = nn.Sequential(
124 | nn.Linear(inner_dim, dim),
125 | nn.Dropout(dropout)
126 | ) if project_out else nn.Identity()
127 |
128 | def forward(self, x_qkv):
129 | b, n, _, h = *x_qkv.shape, self.heads
130 |
131 | k = self.to_k(x_qkv)
132 | k = rearrange(k, 'b n (h d) -> b h n d', h=h)
133 |
134 | v = self.to_v(x_qkv)
135 | v = rearrange(v, 'b n (h d) -> b h n d', h=h)
136 |
137 | q = self.to_q(x_qkv[:, 0].unsqueeze(1))
138 | q = rearrange(q, 'b n (h d) -> b h n d', h=h)
139 |
140 | dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
141 |
142 | attn = dots.softmax(dim=-1)
143 |
144 | out = einsum('b h i j, b h j d -> b h i d', attn, v)
145 | out = rearrange(out, 'b h n d -> b n (h d)')
146 | out = self.to_out(out)
147 | return out
148 |
--------------------------------------------------------------------------------
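`Attention` above implements a k-NN attention variant: when `knn_attention` is set (e.g. 0.7), only the top `knn_attention * n` logits per query are kept and the rest are masked to `-inf` before the softmax, so each query attends only to its nearest keys. A toy sketch of the masking step on random logits:

import torch

torch.manual_seed(0)
b, h, n = 1, 2, 6
topk_ratio = 0.7                                   # plays the role of knn_attention

dots = torch.randn(b, h, n, n)                     # raw attention logits
k = int(n * topk_ratio)                            # keep 4 of the 6 keys per query
index = torch.topk(dots, k=k, dim=-1, largest=True)[1]
mask = torch.zeros_like(dots).scatter_(-1, index, 1.)
dots = torch.where(mask > 0, dots, torch.full_like(dots, float('-inf')))
attn = dots.softmax(dim=-1)

print((attn > 0).sum(dim=-1))  # every query row keeps exactly k non-zero weights
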
/lib/model/utils.py:
--------------------------------------------------------------------------------
1 | '''
2 | This file is modified from:
3 | https://github.com/deepmind/kinetics-i3d/blob/master/i3d.py
4 | '''
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | from torch.autograd import Variable
10 | import numpy as np
11 | import os
12 | import sys
13 |
14 | class MaxPool3dSamePadding(nn.MaxPool3d):
15 | def compute_pad(self, dim, s):
16 | if s % self.stride[dim] == 0:
17 | return max(self.kernel_size[dim] - self.stride[dim], 0)
18 | else:
19 | return max(self.kernel_size[dim] - (s % self.stride[dim]), 0)
20 |
21 | def forward(self, x):
22 | (batch, channel, t, h, w) = x.size()
23 | pad_t = self.compute_pad(0, t)
24 | pad_h = self.compute_pad(1, h)
25 | pad_w = self.compute_pad(2, w)
26 | pad_t_f = pad_t // 2
27 | pad_t_b = pad_t - pad_t_f
28 | pad_h_f = pad_h // 2
29 | pad_h_b = pad_h - pad_h_f
30 | pad_w_f = pad_w // 2
31 | pad_w_b = pad_w - pad_w_f
32 |
33 | pad = (pad_w_f, pad_w_b, pad_h_f, pad_h_b, pad_t_f, pad_t_b)
34 | x = F.pad(x, pad)
35 | return super(MaxPool3dSamePadding, self).forward(x)
36 |
37 | class Unit3D(nn.Module):
38 |
39 | def __init__(self, in_channels,
40 | output_channels,
41 | kernel_shape=(1, 1, 1),
42 | stride=(1, 1, 1),
43 | padding=0,
44 | activation_fn=F.relu,
45 | use_batch_norm=True,
46 | use_bias=False,
47 | name='unit_3d'):
48 |
49 | """Initializes Unit3D module."""
50 | super(Unit3D, self).__init__()
51 |
52 | self._output_channels = output_channels
53 | self._kernel_shape = kernel_shape
54 | self._stride = stride
55 | self._use_batch_norm = use_batch_norm
56 | self._activation_fn = activation_fn
57 | self._use_bias = use_bias
58 | self.name = name
59 | self.padding = padding
60 |
61 | self.conv3d = nn.Conv3d(in_channels=in_channels,
62 | out_channels=self._output_channels,
63 | kernel_size=self._kernel_shape,
64 | stride=self._stride,
65 | padding=0,
66 | bias=self._use_bias)
67 |
68 | if self._use_batch_norm:
69 | self.bn = nn.BatchNorm3d(self._output_channels, eps=0.001, momentum=0.01)
70 |
71 | def compute_pad(self, dim, s):
72 | if s % self._stride[dim] == 0:
73 | return max(self._kernel_shape[dim] - self._stride[dim], 0)
74 | else:
75 | return max(self._kernel_shape[dim] - (s % self._stride[dim]), 0)
76 |
77 |
78 | def forward(self, x):
79 | (batch, channel, t, h, w) = x.size()
80 | pad_t = self.compute_pad(0, t)
81 | pad_h = self.compute_pad(1, h)
82 | pad_w = self.compute_pad(2, w)
83 | pad_t_f = pad_t // 2
84 | pad_t_b = pad_t - pad_t_f
85 | pad_h_f = pad_h // 2
86 | pad_h_b = pad_h - pad_h_f
87 | pad_w_f = pad_w // 2
88 | pad_w_b = pad_w - pad_w_f
89 |
90 | pad = (pad_w_f, pad_w_b, pad_h_f, pad_h_b, pad_t_f, pad_t_b)
91 | x = F.pad(x, pad)
92 | x = self.conv3d(x)
93 | if self._use_batch_norm:
94 | x = self.bn(x)
95 | if self._activation_fn is not None:
96 | x = self._activation_fn(x)
97 | return x
98 |
99 | class TemporalInceptionModule(nn.Module):
100 | def __init__(self, in_channels, out_channels, name):
101 | super(TemporalInceptionModule, self).__init__()
102 |
103 | self.b0 = Unit3D(in_channels=in_channels, output_channels=out_channels[0], kernel_shape=[1, 1, 1], padding=0,
104 | name=name+'/Branch_0/Conv3d_0a_1x1')
105 | self.b1a = Unit3D(in_channels=in_channels, output_channels=out_channels[1], kernel_shape=[1, 1, 1], padding=0,
106 | name=name+'/Branch_1/Conv3d_0a_1x1')
107 | self.b1b = Unit3D(in_channels=out_channels[1], output_channels=out_channels[2], kernel_shape=[3, 1, 1],
108 | name=name+'/Branch_1/Conv3d_0b_3x3')
109 | self.b2a = Unit3D(in_channels=in_channels, output_channels=out_channels[3], kernel_shape=[1, 1, 1], padding=0,
110 | name=name+'/Branch_2/Conv3d_0a_1x1')
111 | self.b2b = Unit3D(in_channels=out_channels[3], output_channels=out_channels[4], kernel_shape=[3, 1, 1],
112 | name=name+'/Branch_2/Conv3d_0b_3x3')
113 | self.b3a = MaxPool3dSamePadding(kernel_size=[3, 1, 1],
114 | stride=(1, 1, 1), padding=0)
115 | self.b3b = Unit3D(in_channels=in_channels, output_channels=out_channels[5], kernel_shape=[1, 1, 1], padding=0,
116 | name=name+'/Branch_3/Conv3d_0b_1x1')
117 | self.name = name
118 |
119 | def forward(self, x):
120 | b0 = self.b0(x)
121 | b1 = self.b1b(self.b1a(x))
122 | b2 = self.b2b(self.b2a(x))
123 | b3 = self.b3b(self.b3a(x))
124 | return torch.cat([b0,b1,b2,b3], dim=1)
125 |
126 |
127 | class SpatialInceptionModule(nn.Module):
128 | def __init__(self, in_channels, out_channels, name):
129 | super(SpatialInceptionModule, self).__init__()
130 |
131 | self.b0 = Unit3D(in_channels=in_channels, output_channels=out_channels[0], kernel_shape=[1, 1, 1], padding=0,
132 | name=name + '/Branch_0/Conv3d_0a_1x1')
133 | self.b1a = Unit3D(in_channels=in_channels, output_channels=out_channels[1], kernel_shape=[1, 1, 1], padding=0,
134 | name=name + '/Branch_1/Conv3d_0a_1x1')
135 | self.b1b = Unit3D(in_channels=out_channels[1], output_channels=out_channels[2], kernel_shape=[1, 3, 3],
136 | name=name + '/Branch_1/Conv3d_0b_3x3')
137 | self.b2a = Unit3D(in_channels=in_channels, output_channels=out_channels[3], kernel_shape=[1, 1, 1], padding=0,
138 | name=name + '/Branch_2/Conv3d_0a_1x1')
139 | self.b2b = Unit3D(in_channels=out_channels[3], output_channels=out_channels[4], kernel_shape=[1, 3, 3],
140 | name=name + '/Branch_2/Conv3d_0b_3x3')
141 | self.b3a = MaxPool3dSamePadding(kernel_size=[3, 3, 3],
142 | stride=(1, 1, 1), padding=0)
143 | self.b3b = Unit3D(in_channels=in_channels, output_channels=out_channels[5], kernel_shape=[1, 1, 1], padding=0,
144 | name=name + '/Branch_3/Conv3d_0b_1x1')
145 | self.name = name
146 |
147 | def forward(self, x):
148 | b0 = self.b0(x)
149 | b1 = self.b1b(self.b1a(x))
150 | b2 = self.b2b(self.b2a(x))
151 | b3 = self.b3b(self.b3a(x))
152 | return torch.cat([b0, b1, b2, b3], dim=1)
--------------------------------------------------------------------------------
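`MaxPool3dSamePadding` and `Unit3D` above emulate TensorFlow-style 'SAME' padding: each dimension is padded so the output size is `ceil(input / stride)` regardless of kernel size. A 2-D sketch of the per-dimension padding rule; `same_pad` is an illustrative helper, not part of the repo:

import torch
import torch.nn.functional as F

def same_pad(size, kernel, stride):
    """Per-dimension 'SAME' padding, mirroring compute_pad above."""
    if size % stride == 0:
        return max(kernel - stride, 0)
    return max(kernel - (size % stride), 0)

# A 7x7 kernel with stride 2 on a 224x224 input needs 5 pixels of padding per
# spatial dimension, giving an output of ceil(224 / 2) = 112.
pad = same_pad(224, kernel=7, stride=2)
x = F.pad(torch.randn(1, 3, 224, 224), (pad // 2, pad - pad // 2, pad // 2, pad - pad // 2))
y = F.conv2d(x, torch.randn(64, 3, 7, 7), stride=2)
print(pad, y.shape)  # 5 torch.Size([1, 64, 112, 112])
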
/run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | TRAIN=$1
3 | CONFIG=$2
4 | GPUID=$3
5 | GPUNUM=$4
6 | PORT=${PORT:-29509}
7 | CUDA_VISIBLE_DEVICES=$GPUID python -m torch.distributed.launch --nproc_per_node=$GPUNUM --master_port=$PORT $TRAIN --config $CONFIG --nprocs $GPUNUM --save_output
8 |
--------------------------------------------------------------------------------
/tools/fusion.py:
--------------------------------------------------------------------------------
1 | '''
2 | Copyright (C) 2010-2021 Alibaba Group Holding Limited.
3 | '''
4 | import os, random, math
5 | import time
6 | import glob
7 | import numpy as np
8 | import shutil
9 |
10 | import torch
11 |
12 | import logging
13 | import argparse
14 | import traceback
15 | import torch.nn as nn
16 | import torch.utils
17 | import torchvision.datasets as dset
18 | import torch.backends.cudnn as cudnn
19 |
20 | import sys
21 | sys.path.append(os.path.abspath(os.path.join("..", os.getcwd())))
22 | from config import Config
23 | from lib import *
24 | import torch.distributed as dist
25 | from utils import *
26 | from utils.build import *
27 |
28 |
29 | parser = argparse.ArgumentParser()
30 | parser.add_argument('--config', help='Path to the config file')
31 | parser.add_argument('--eval_only', action='store_true', help='Run evaluation only')
32 | parser.add_argument('--local_rank', type=int, default=0)
33 | parser.add_argument('--nprocs', type=int, default=1)
34 |
35 | parser.add_argument('--save_grid_image', action='store_true', help='Save grid images of input samples')
36 | parser.add_argument('--save_output', action='store_true', help='Save output logits')
37 | parser.add_argument('--fp16', action='store_true', help='Train with fp16 mixed precision')
38 | parser.add_argument('--demo_dir', type=str, default='./demo', help='Directory for saving all demo outputs')
39 | parser.add_argument('--resume', type=str, default='', help='Path to a checkpoint to resume from')
40 |
41 | parser.add_argument('--drop_path_prob', type=float, default=0.5, help='drop path probability')
42 | parser.add_argument('--save', type=str, default='Checkpoints/', help='experiment name')
43 | parser.add_argument('--seed', type=int, default=123, help='random seed')
44 | args = parser.parse_args()
45 | args = Config(args)
46 |
47 | #====================================================
48 | # Some configuration
49 | #====================================================
50 |
51 | try:
52 | if args.resume:
53 | args.save = os.path.split(args.resume)[0]
54 | else:
55 | args.save = '{}/{}-{}-{}-{}'.format(args.save, args.Network, args.dataset, args.type, time.strftime("%Y%m%d-%H%M%S"))
56 | utils.create_exp_dir(args.save, scripts_to_save=[args.config] + glob.glob('./tools/*.py') + glob.glob('./lib/*'))
57 | except:
58 | pass
59 | log_format = '%(asctime)s %(message)s'
60 | logging.basicConfig(stream=sys.stdout, level=logging.INFO,
61 | format=log_format, datefmt='%m/%d %I:%M:%S %p')
62 | fh = logging.FileHandler(os.path.join(args.save, 'log{}.txt'.format(time.strftime("%Y%m%d-%H%M%S"))))
63 | fh.setFormatter(logging.Formatter(log_format))
64 | logging.getLogger().addHandler(fh)
65 |
66 | #---------------------------------
67 | # Fusion Net Training
68 | #---------------------------------
69 | def reduce_mean(tensor, nprocs):
70 | rt = tensor.clone()
71 | dist.all_reduce(rt, op=dist.ReduceOp.SUM)
72 | rt /= nprocs
73 | return rt.item()
74 |
75 |
76 | def main(local_rank, nprocs, args):
77 | if not torch.cuda.is_available():
78 | logging.info('no gpu device available')
79 | sys.exit(1)
80 |
81 | np.random.seed(args.seed)
82 | cudnn.benchmark = True
83 | torch.manual_seed(args.seed)
84 | cudnn.enabled = True
85 | torch.cuda.manual_seed(args.seed)
86 | logging.info('gpu device = %d' % local_rank)
87 |
88 | # ---------------------------
89 | # Init distribution
90 | # ---------------------------
91 | torch.cuda.set_device(local_rank)
92 | torch.distributed.init_process_group(backend='nccl')
93 |
94 | # ----------------------------
95 | # build function
96 | # ----------------------------
97 | model = build_model(args)
98 | model = model.cuda(local_rank)
99 |
100 | criterion = build_loss(args)
101 | optimizer = build_optim(args, model)
102 | scheduler = build_scheduler(args, optimizer)
103 |
104 | train_queue, train_sampler = build_dataset(args, phase='train')
105 | valid_queue, valid_sampler = build_dataset(args, phase='valid')
106 |
107 | if args.resume:
108 | model, optimizer, strat_epoch, best_acc = load_checkpoint(model, args.resume, optimizer)
109 | logging.info("The network will resume training.")
110 | logging.info("Start Epoch: {}, Learning rate: {}, Best accuracy: {}".format(strat_epoch, [g['lr'] for g in
111 | optimizer.param_groups],
112 | round(best_acc, 4)))
113 | if args.resumelr:
114 | for g in optimizer.param_groups: g['lr'] = args.resumelr
115 | args.resume_scheduler = cosine_scheduler(args.resumelr, 1e-5, args.epochs - strat_epoch, len(train_queue))
116 | args.resume_epoch = strat_epoch - 1
117 |
118 | else:
119 | strat_epoch = 0
120 | best_acc = 0.0
121 | args.resume_epoch = 0
122 | scheduler[0].last_epoch = strat_epoch
123 |
124 |
125 | if args.SYNC_BN and args.nprocs > 1:
126 | model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
127 | model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], find_unused_parameters=True)
128 | if local_rank == 0:
129 | logging.info("param size = %fMB", utils.count_parameters_in_MB(model))
130 |
131 |
132 | train_results = dict(
133 | train_score=[],
134 | train_loss=[],
135 | valid_score=[],
136 | valid_loss=[],
137 | best_score=0.0
138 | )
139 | if args.eval_only:
140 | valid_acc, _, _, meter_dict = infer(valid_queue, model, criterion, local_rank, 0)
141 | valid_acc = max(meter_dict['Acc_all'].avg, meter_dict['Acc'].avg, meter_dict['Acc_3'].avg)
142 | logging.info('valid_acc: {}, Acc_1: {}, Acc_2: {}, Acc_3: {}'.format(valid_acc, meter_dict['Acc_1'].avg, meter_dict['Acc_2'].avg, meter_dict['Acc_3'].avg))
143 | return
144 |
145 | #---------------------------
146 | # Mixed Precision Training
147 | # --------------------------
148 | if args.fp16:
149 | scaler = torch.cuda.amp.GradScaler()
150 | else:
151 | scaler = None
152 | for epoch in range(strat_epoch, args.epochs):
153 | train_sampler.set_epoch(epoch)
154 | model.drop_path_prob = args.drop_path_prob * epoch / args.epochs
155 |
156 | if epoch < args.scheduler['warm_up_epochs']:
157 | for g in optimizer.param_groups:
158 | g['lr'] = scheduler[-1](epoch)
159 |
160 | args.epoch = epoch
161 | train_acc, train_obj, meter_dict_train = train(train_queue, model, criterion, optimizer, epoch, local_rank, scaler)
162 | valid_acc, valid_obj, valid_dict, meter_dict_val = infer(valid_queue, model, criterion, local_rank, epoch)
163 | valid_acc = max(meter_dict_val['Acc_all'].avg, meter_dict_val['Acc'].avg, meter_dict_val['Acc_3'].avg)
164 | if epoch >= args.scheduler['warm_up_epochs']:
165 | if args.scheduler['name'] == 'ReduceLR':
166 | scheduler[0].step(valid_acc)
167 | else:
168 | scheduler[0].step()
169 |
170 | if local_rank == 0:
171 | if valid_acc > best_acc:
172 | best_acc = valid_acc
173 | isbest = True
174 | else:
175 | isbest = False
176 | logging.info('train_acc %f', train_acc)
177 | logging.info('valid_acc: {}, Acc_1: {}, Acc_2: {}, Acc_3: {}, best acc: {}'.format(meter_dict_val['Acc'].avg, meter_dict_val['Acc_1'].avg,
178 | meter_dict_val['Acc_2'].avg,
179 | meter_dict_val['Acc_3'].avg, best_acc))
180 |
181 | state = {'model': model.module.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch + 1, 'bestacc': best_acc}
182 | save_checkpoint(state, isbest, args.save)
183 |
184 | train_results['train_score'].append(train_acc)
185 | train_results['train_loss'].append(train_obj)
186 | train_results['valid_score'].append(valid_acc)
187 | train_results['valid_loss'].append(valid_obj)
188 | train_results['best_score'] = best_acc
189 | train_results.update(valid_dict)
190 | train_results['categories'] = np.unique(valid_dict['grounds'])
191 |
192 | if isbest:
193 | EvaluateMetric(PREDICTIONS_PATH=args.save, train_results=train_results, idx=epoch)
194 | for k, v in train_results.items():
195 | if isinstance(v, list):
196 | v.clear()
197 |
198 | def train(train_queue, model, criterion, optimizer, epoch, local_rank, scaler):
199 | model.train()
200 | meter_dict = dict(
201 | Total_loss=AverageMeter(),
202 | MSE_loss=AverageMeter(),
203 | CE_loss=AverageMeter(),
204 | BCE_loss=AverageMeter(),
205 | Distill_loss = AverageMeter()
206 |
207 | )
208 | meter_dict['Data_Time'] = AverageMeter()
209 | meter_dict.update(dict(
210 | Acc_1=AverageMeter(),
211 | Acc_2=AverageMeter(),
212 | Acc_3=AverageMeter(),
213 | Acc=AverageMeter()
214 | ))
215 |
216 | end = time.time()
217 | for step, (inputs, heatmap, target, _) in enumerate(train_queue):
218 | meter_dict['Data_Time'].update((time.time() - end)/args.batch_size)
219 | inputs, target, heatmap = map(lambda x: [d.cuda(local_rank, non_blocking=True) for d in x] if isinstance(x, list) else x.cuda(local_rank, non_blocking=True), [inputs, target, heatmap])
220 |
221 | if args.resumelr:
222 | for g in optimizer.param_groups:
223 | g['lr'] = args.resume_scheduler[len(train_queue) * args.resume_epoch + step]
224 | # ---------------------------
225 | # Mixed Precision Training
226 | # --------------------------
227 | if args.fp16:
228 | print('Train with FP16')
229 | optimizer.zero_grad()
230 | # Runs the forward pass with autocasting.
231 | with torch.cuda.amp.autocast():
232 | (logits, logit_r, logit_d), (CE_loss, BCE_loss, MSE_loss, distillation) = model(inputs, heatmap, target)
233 | globals()['CE_loss'] = CE_loss
234 | globals()['MSE_loss'] = MSE_loss
235 | globals()['BCE_loss'] = BCE_loss
236 | globals()['Distill_loss'] = distillation
237 | globals()['Total_loss'] = CE_loss + MSE_loss + BCE_loss + distillation
238 |
239 | scaler.scale(Total_loss).backward()
240 | # Unscales the gradients of optimizer's assigned params in-place
241 | scaler.unscale_(optimizer)
242 | nn.utils.clip_grad_norm_(model.module.parameters(), args.grad_clip)
243 | scaler.step(optimizer)
244 | scaler.update()
245 | else:
246 | # ---------------------------
247 | # Fp32 Precision Training
248 | # --------------------------
249 | (logits, logit_r, logit_d), (CE_loss, BCE_loss, MSE_loss, distillation) = model(inputs, heatmap, target)
250 | globals()['CE_loss'] = CE_loss
251 | globals()['MSE_loss'] = MSE_loss
252 | globals()['BCE_loss'] = BCE_loss
253 | globals()['Distill_loss'] = distillation
254 | globals()['Total_loss'] = CE_loss + MSE_loss + BCE_loss + distillation
255 |
256 | optimizer.zero_grad()
257 | Total_loss.backward()
258 | nn.utils.clip_grad_norm_(model.module.parameters(), args.grad_clip)
259 | optimizer.step()
260 |
261 | #---------------------
262 | # Meter performance
263 | #---------------------
264 | torch.distributed.barrier()
265 | globals()['Acc'] = calculate_accuracy(logits, target)
266 | globals()['Acc_1'] = calculate_accuracy(logit_r, target)
267 | globals()['Acc_2'] = calculate_accuracy(logit_d, target)
268 | globals()['Acc_3'] = calculate_accuracy(logit_r+logit_d, target)
269 |
270 | for name in meter_dict:
271 | if 'loss' in name:
272 | meter_dict[name].update(reduce_mean(globals()[name], args.nprocs))
273 | if 'Acc' in name:
274 | meter_dict[name].update(reduce_mean(globals()[name], args.nprocs))
275 |
276 | if step % args.report_freq == 0 and local_rank == 0:
277 |
278 | log_info = {
279 | 'Epoch': '{}/{}'.format(epoch + 1, args.epochs),
280 | 'Mini-Batch': '{:0>5d}/{:0>5d}'.format(step + 1,
281 | len(train_queue.dataset) // (args.batch_size * args.nprocs)),
282 | 'Lr': ['{:.4f}'.format(g['lr']) for g in optimizer.param_groups],
283 | }
284 | log_info.update(dict((name, '{:.4f}'.format(value.avg)) for name, value in meter_dict.items()))
285 | print_func(log_info)
286 | end = time.time()
287 | args.resume_epoch += 1
288 | return meter_dict['Acc'].avg, meter_dict['Total_loss'].avg, meter_dict
289 |
290 | @torch.no_grad()
291 | def concat_all_gather(tensor):
292 | """
293 | Performs all_gather operation on the provided tensors.
294 | *** Warning ***: torch.distributed.all_gather has no gradient.
295 | """
296 | tensors_gather = [torch.ones_like(tensor)
297 | for _ in range(torch.distributed.get_world_size())]
298 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
299 |
300 | output = torch.cat(tensors_gather, dim=0)
301 | return output
302 |
303 | @torch.no_grad()
304 | def infer(valid_queue, model, criterion, local_rank, epoch):
305 | model.eval()
306 |
307 | meter_dict = dict(
308 | Total_loss=AverageMeter(),
309 | MSE_loss=AverageMeter(),
310 | CE_loss=AverageMeter(),
311 | Distill_loss=AverageMeter()
312 | )
313 | meter_dict.update(dict(
314 | Acc_1=AverageMeter(),
315 | Acc_2=AverageMeter(),
316 | Acc_3=AverageMeter(),
317 | Acc = AverageMeter(),
318 | Acc_all=AverageMeter(),
319 | ))
320 |
321 | meter_dict['Infer_Time'] = AverageMeter()
322 | grounds, preds, v_paths = [], [], []
323 | for step, (inputs, heatmap, target, v_path) in enumerate(valid_queue):
324 | end = time.time()
325 | inputs, target, heatmap = map(
326 | lambda x: [d.cuda(local_rank, non_blocking=True) for d in x] if isinstance(x, list) else x.cuda(local_rank,
327 | non_blocking=True),
328 | [inputs, target, heatmap])
329 | if args.fp16:
330 | with torch.cuda.amp.autocast():
331 | (logits, logit_r, logit_d), (CE_loss, BCE_loss, MSE_loss, distillation) = model(inputs, heatmap, target)
332 | else:
333 | (logits, logit_r, logit_d), (CE_loss, BCE_loss, MSE_loss, distillation) = model(inputs, heatmap, target)
334 | meter_dict['Infer_Time'].update((time.time() - end) / args.test_batch_size)
335 | globals()['CE_loss'] = CE_loss
336 | globals()['MSE_loss'] = MSE_loss
337 | globals()['BCE_loss'] = BCE_loss
338 | globals()['Distill_loss'] = distillation
339 | globals()['Total_loss'] = CE_loss + MSE_loss + BCE_loss + distillation
340 |
341 | torch.distributed.barrier()
342 | globals()['Acc'] = calculate_accuracy(logits, target)
343 | globals()['Acc_1'] = calculate_accuracy(logit_r, target)
344 | globals()['Acc_2'] = calculate_accuracy(logit_d, target)
345 | globals()['Acc_3'] = calculate_accuracy(logit_r+logit_d, target)
346 | globals()['Acc_all'] = calculate_accuracy(logit_r+logit_d+logits, target)
347 |
348 |
349 | grounds += target.cpu().tolist()
350 | preds += torch.argmax(logits, dim=1).cpu().tolist()
351 | v_paths += v_path
352 | for name in meter_dict:
353 | if 'loss' in name:
354 | meter_dict[name].update(reduce_mean(globals()[name], args.nprocs))
355 | if 'Acc' in name:
356 | meter_dict[name].update(reduce_mean(globals()[name], args.nprocs))
357 |
358 | if step % args.report_freq == 0 and local_rank == 0:
359 | log_info = {
360 | 'Epoch': epoch + 1,
361 | 'Mini-Batch': '{:0>4d}/{:0>4d}'.format(step + 1, len(valid_queue.dataset) // (
362 | args.test_batch_size * args.nprocs)),
363 | }
364 | log_info.update(dict((name, '{:.4f}'.format(value.avg)) for name, value in meter_dict.items()))
365 | print_func(log_info)
366 |
367 | torch.distributed.barrier()
368 | grounds_gather = concat_all_gather(torch.tensor(grounds).cuda(local_rank))
369 | preds_gather = concat_all_gather(torch.tensor(preds).cuda(local_rank))
370 | grounds_gather, preds_gather = list(map(lambda x: x.cpu().numpy(), [grounds_gather, preds_gather]))
371 |
372 | if local_rank == 0:
373 | v_paths = np.array(v_paths)
374 | grounds = np.array(grounds)
375 | preds = np.array(preds)
376 | wrong_idx = np.where(grounds != preds)
377 | v_paths = v_paths[wrong_idx[0]]
378 | grounds = grounds[wrong_idx[0]]
379 | preds = preds[wrong_idx[0]]
380 | return meter_dict['Acc'].avg, meter_dict['Total_loss'].avg, dict(grounds=grounds_gather, preds=preds_gather, valid_images=(v_paths, grounds, preds)), meter_dict
381 |
382 | if __name__ == '__main__':
383 | try:
384 | main(args.local_rank, args.nprocs, args)
385 | except KeyboardInterrupt:
386 | torch.cuda.empty_cache()
387 | if os.path.exists(args.save) and len(os.listdir(args.save)) < 3:
388 | print(f'remove {args.save}: Directory')
389 | shutil.rmtree(args.save, ignore_errors=True)
390 | os._exit(0)
391 | except Exception:
392 | traceback.print_exc()
393 | if os.path.exists(args.save) and len(os.listdir(args.save)) < 3:
394 | print(f'remove {args.save}: Directory')
395 | shutil.rmtree(args.save, ignore_errors=True)
396 | os._exit(0)
397 | finally:
398 | torch.cuda.empty_cache()
399 |
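400 | # Launch sketch (an assumption -- see run.sh for the command actually used):
401 | #   python -m torch.distributed.launch --nproc_per_node=2 tools/fusion.py \
402 | #       --config <path/to/fusion_config.yml> --nprocs 2 [--fp16] [--resume <checkpoint.pth.tar>]
403 | # torch.distributed.launch supplies --local_rank; the remaining options come from the YAML config.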
--------------------------------------------------------------------------------
/tools/readme.md:
--------------------------------------------------------------------------------
1 | Entry scripts for this repo: `train.py` trains and evaluates a single network, and `fusion.py` trains/evaluates the fusion network.
2 |
--------------------------------------------------------------------------------
/tools/train.py:
--------------------------------------------------------------------------------
1 | '''
2 | Copyright (C) 2010-2021 Alibaba Group Holding Limited.
3 | '''
4 |
5 | import time
6 | import glob
7 | import numpy as np
8 | import shutil
9 | import cv2
10 | import os, random, math
11 | import sys
12 | sys.path.append(os.getcwd())  # make the project root importable (assumes the script is launched from the repo root)
13 |
14 | import torch
15 | import utils
16 | import logging
17 | import argparse
18 | import traceback
19 | import torch.nn as nn
20 | import torch.utils
21 | import torchvision.datasets as dset
22 | import torch.backends.cudnn as cudnn
23 | import torch.distributed as dist
24 |
25 | # import flops_benchmark
26 | from utils.visualizer import Visualizer
27 | from config import Config
28 | from lib import *
29 | from utils import *
30 |
31 | #------------------------
32 | # evaluation metrics
33 | #------------------------
34 | from sklearn.decomposition import PCA
35 | from sklearn import manifold
36 | import pandas as pd
37 | import matplotlib.pyplot as plt # For graphics
38 | import seaborn as sns
39 | from torchvision.utils import save_image, make_grid
40 |
41 | parser = argparse.ArgumentParser()
42 | parser.add_argument('--config', help='Path to the config file.')
43 | parser.add_argument('--eval_only', action='store_true', help='Eval only. True or False?')
44 | parser.add_argument('--local_rank', type=int, default=0)
45 | parser.add_argument('--nprocs', type=int, default=1)
46 |
47 | parser.add_argument('--save_grid_image', action='store_true', help='Save samples?')
48 | parser.add_argument('--save_output', action='store_true', help='Save logits?')
49 | parser.add_argument('--demo_dir', type=str, default='./demo', help='The dir for save all the demo')
50 | parser.add_argument('--resume', type=str, default='', help='resume model path.')
51 |
52 | parser.add_argument('--distill-lamdb', type=float, default=0.0, help='initial distillation loss weight')
53 |
54 | parser.add_argument('--drop_path_prob', type=float, default=0.5, help='drop path probability')
55 | parser.add_argument('--save', type=str, default='Checkpoints/', help='experiment dir')
56 | parser.add_argument('--seed', type=int, default=123, help='random seed')
57 | args = parser.parse_args()
58 | args = Config(args)
59 |
60 | try:
61 | if args.resume:
62 | args.save = os.path.split(args.resume)[0]
63 | else:
64 | args.save = '{}/{}-{}-{}-{}'.format(args.save, args.Network, args.dataset, args.type, time.strftime("%Y%m%d-%H%M%S"))
65 | utils.create_exp_dir(args.save, scripts_to_save=[args.config] + glob.glob('./tools/*.py') + glob.glob('./lib/*'))
66 | except:
67 | pass
68 | log_format = '%(asctime)s %(message)s'
69 | logging.basicConfig(stream=sys.stdout, level=logging.INFO,
70 | format=log_format, datefmt='%m/%d %I:%M:%S %p')
71 | fh = logging.FileHandler(os.path.join(args.save, 'log{}.txt'.format(time.strftime("%Y%m%d-%H%M%S"))))
72 | fh.setFormatter(logging.Formatter(log_format))
73 | logging.getLogger().addHandler(fh)
74 |
75 |
76 | def reduce_mean(tensor, nprocs):
77 | rt = tensor.clone()
78 | dist.all_reduce(rt, op=dist.ReduceOp.SUM)
79 | rt /= nprocs
80 | return rt.item()
81 |
82 | def main(local_rank, nprocs, args):
83 | if not torch.cuda.is_available():
84 | logging.info('no gpu device available')
85 | sys.exit(1)
86 |
87 | np.random.seed(args.seed)
88 | cudnn.benchmark = True
89 | torch.manual_seed(args.seed)
90 | cudnn.enabled = True
91 | torch.cuda.manual_seed(args.seed)
92 | logging.info('gpu device = %d' % local_rank)
93 |
94 | #---------------------------
95 | # Init distribution
96 | #---------------------------
97 | torch.cuda.set_device(local_rank)
98 | torch.distributed.init_process_group(backend='nccl')
99 |
100 | #----------------------------
101 | # build function
102 | #----------------------------
103 | model = build_model(args)
104 | model = model.cuda(local_rank)
105 |
106 | criterion = build_loss(args)
107 | optimizer = build_optim(args, model)
108 | scheduler = build_scheduler(args, optimizer)
109 |
110 | train_queue, train_sampler = build_dataset(args, phase='train')
111 | valid_queue, valid_sampler = build_dataset(args, phase='valid')
112 |
113 |
114 | if args.resume:
115 | model, optimizer, strat_epoch, best_acc = load_checkpoint(model, args.resume, optimizer)
116 | logging.info("Start Epoch: {}, Learning rate: {}, Best accuracy: {}".format(strat_epoch, [g['lr'] for g in
117 | optimizer.param_groups],
118 | round(best_acc, 4)))
119 | if args.resumelr:
120 | for g in optimizer.param_groups:
121 | args.resumelr = g['lr'] if not isinstance(args.resumelr, float) else args.resumelr
122 | g['lr'] = args.resumelr
123 | #resume_scheduler = np.linspace(args.resumelr, 1e-5, args.epochs - strat_epoch)
124 | resume_scheduler = cosine_scheduler(args.resumelr, 1e-5, args.epochs - strat_epoch + 1, niter_per_ep=1).tolist()
125 | resume_scheduler.pop(0)
126 |
127 | args.epoch = strat_epoch - 1
128 | else:
129 | strat_epoch = 0
130 | best_acc = 0.0
131 | args.epoch = strat_epoch
132 |
133 | scheduler[0].last_epoch = strat_epoch
134 |
135 | if args.SYNC_BN and args.nprocs > 1:
136 | model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
137 | model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], find_unused_parameters=False)
138 | if local_rank == 0:
139 | logging.info("param size = %fMB", utils.count_parameters_in_MB(model))
140 | # logging.info('FLOPs: {}'.format(flops_benchmark.count_flops(model)))
141 |
142 | train_results = dict(
143 | train_score=[],
144 | train_loss=[],
145 | valid_score=[],
146 | valid_loss=[],
147 | best_score=0.0
148 | )
149 | if args.eval_only:
150 | valid_acc, _, _, meter_dict, output = infer(valid_queue, model, criterion, local_rank, strat_epoch)
151 | logging.info('valid_acc: {}, Acc_1: {}, Acc_2: {}, Acc_3: {}'.format(valid_acc, meter_dict['Acc_1'].avg, meter_dict['Acc_2'].avg, meter_dict['Acc_3'].avg))
152 | if args.save_output:
153 | torch.save(output, os.path.join(args.save, '{}-output.pth'.format(args.type)))
154 | return
155 |
156 | for epoch in range(strat_epoch, args.epochs):
157 | train_sampler.set_epoch(epoch)
158 | model.drop_path_prob = args.drop_path_prob * epoch / args.epochs
159 |
160 | if epoch < args.scheduler['warm_up_epochs']-1:
161 | for g in optimizer.param_groups:
162 | g['lr'] = scheduler[-1](epoch)
163 | else:
164 | args.distill_lamdb = args.distill
165 |
166 | args.epoch = epoch
167 | train_acc, train_obj, meter_dict_train = train(train_queue, model, criterion, optimizer, epoch, local_rank)
168 | valid_acc, valid_obj, valid_dict, meter_dict_val, output = infer(valid_queue, model, criterion, local_rank, epoch)
169 |
170 | # scheduler_func.step(scheduler, valid_acc)
171 | if epoch >= args.scheduler['warm_up_epochs']:
172 | if args.resume and args.resumelr:
173 | for g in optimizer.param_groups:
174 | g['lr'] = resume_scheduler[0]
175 | resume_scheduler.pop(0)
176 | elif args.scheduler['name'] == 'ReduceLR':
177 | scheduler[0].step(valid_acc)
178 | else:
179 | scheduler[0].step()
180 |
181 | if local_rank == 0:
182 | if valid_acc > best_acc:
183 | best_acc = valid_acc
184 | isbest = True
185 | else:
186 | isbest = False
187 | logging.info('train_acc %f', train_acc)
188 | logging.info('valid_acc %f, best_acc %f', valid_acc, best_acc)
189 | state = {'model': model.module.state_dict(),'optimizer': optimizer.state_dict(), 'epoch': epoch + 1, 'bestacc': best_acc}
190 | save_checkpoint(state, isbest, args.save)
191 |
192 | train_results['train_score'].append(train_acc)
193 | train_results['train_loss'].append(train_obj)
194 | train_results['valid_score'].append(valid_acc)
195 | train_results['valid_loss'].append(valid_obj)
196 | train_results['best_score'] = best_acc
197 | train_results.update(valid_dict)
198 | train_results['categories'] = np.unique(valid_dict['grounds'])
199 |
200 | if args.visdom['enable']:
201 | vis.plot_many({'train_acc': train_acc, 'loss': train_obj,
202 | 'cosin_similar': meter_dict_train['cosin_similar'].avg}, 'Train-' + args.type, epoch)
203 | vis.plot_many({'valid_acc': valid_acc, 'loss': valid_obj,
204 | 'cosin_similar': meter_dict_val['cosin_similar'].avg}, 'Valid-' + args.type, epoch)
205 |
206 | if isbest:
207 | if args.save_output:
208 | torch.save(output, os.path.join(args.save, '{}-output.pth'.format(args.type)))
209 | EvaluateMetric(PREDICTIONS_PATH=args.save, train_results=train_results, idx=epoch)
210 | for k, v in train_results.items():
211 | if isinstance(v, list):
212 | v.clear()
213 |
214 | def Visfeature(inputs, feature, v_path=None):
215 | if args.visdom['enable']:
216 | vis.featuremap('CNNVision',
217 | torch.sum(make_grid(feature[0].detach(), nrow=int(feature[0].size(0) ** 0.5), padding=2), dim=0).flipud())
218 | vis.featuremap('Attention Maps Similarity',
219 | make_grid(feature[1], nrow=int(feature[1].detach().cpu().size(0) ** 0.5), padding=2)[0].flipud())
220 |
221 | vis.featuremap('Enhancement Weights', feature[3].flipud())
222 | else:
223 | fig = plt.figure()
224 | ax = fig.add_subplot()
225 | sns.heatmap(
226 | torch.sum(make_grid(feature[0].detach(), nrow=int(feature[0].size(0) ** 0.5), padding=2), dim=0).cpu().numpy(),
227 | annot=False, fmt='g', ax=ax)
228 | ax.set_title('CNNVision', fontsize=10)
229 | fig.savefig(os.path.join(args.save, 'CNNVision.jpg'), dpi=fig.dpi)
230 | plt.close()
231 |
232 | fig = plt.figure()
233 | ax = fig.add_subplot()
234 | sns.heatmap(make_grid(feature[1].detach(), nrow=int(feature[1].size(0) ** 0.5), padding=2)[0].cpu().numpy(), annot=False,
235 | fmt='g', ax=ax)
236 | ax.set_title('Attention Maps Similarity', fontsize=10)
237 | fig.savefig(os.path.join(args.save, 'AttMapSimilarity.jpg'), dpi=fig.dpi)
238 | plt.close()
239 |
240 | fig = plt.figure()
241 | ax = fig.add_subplot()
242 | sns.heatmap(feature[3].detach().cpu().numpy(), annot=False, fmt='g', ax=ax)
243 | ax.set_title('Enhancement Weights', fontsize=10)
244 | fig.savefig(os.path.join(args.save, 'EnhancementWeights.jpg'), dpi=fig.dpi)
245 | plt.close()
246 |
247 | #------------------------------------------
248 | # Spatial feature visualization
249 | #------------------------------------------
250 | headmap = feature[-1][0][0,:].detach().cpu().numpy()
251 | headmap = np.mean(headmap, axis=0)
252 | headmap /= np.max(headmap) # torch.Size([64, 7, 7])
253 | headmap = torch.from_numpy(headmap)
254 | img = feature[-1][1]
255 |
256 | result = []
257 | for map, mg in zip(headmap.unsqueeze(1), img.permute(1,2,3,0)):
258 | map = cv2.resize(map.squeeze().cpu().numpy(), (mg.shape[0]//2, mg.shape[1]//2))
259 | map = np.uint8(255 * map)
260 | map = cv2.applyColorMap(map, cv2.COLORMAP_JET)
261 |
262 | mg = np.uint8(mg.cpu().numpy() * 128 + 127.5)
263 | mg = cv2.resize(mg, (mg.shape[0]//2, mg.shape[1]//2))
264 | superimposed_img = cv2.addWeighted(mg, 0.4, map, 0.6, 0)
265 |
266 | result.append(torch.from_numpy(superimposed_img).unsqueeze(0))
267 | superimposed_imgs = torch.cat(result).permute(0, 3, 1, 2)
268 | # save_image(superimposed_imgs, os.path.join(args.save, 'CAM-Features.png'), nrow=int(superimposed_imgs.size(0) ** 0.5), padding=2).permute(1,2,0)
269 | superimposed_imgs = make_grid(superimposed_imgs, nrow=int(superimposed_imgs.size(0) ** 0.5), padding=2).permute(1,2,0)
270 | cv2.imwrite(os.path.join(args.save, 'CAM-Features.png'), superimposed_imgs.numpy())
271 |
272 | if args.eval_only:
273 | MHAS_s, MHAS_m, MHAS_l = feature[4]
274 | MHAS_s, MHAS_m, MHAS_l = MHAS_s.detach().cpu(), MHAS_m.detach().cpu(), MHAS_l.detach().cpu()
275 | # Normalize
276 | att_max, index_max = torch.max(MHAS_s.view(MHAS_s.size(0), -1), dim=-1)
277 | att_min, index_min = torch.min(MHAS_s.view(MHAS_s.size(0), -1), dim=-1)
278 | MHAS_s = (MHAS_s - att_min.view(-1, 1, 1))/(att_max.view(-1, 1, 1) - att_min.view(-1, 1, 1))
279 |
280 | att_max, index_max = torch.max(MHAS_m.view(MHAS_m.size(0), -1), dim=-1)
281 | att_min, index_min = torch.min(MHAS_m.view(MHAS_m.size(0), -1), dim=-1)
282 | MHAS_m = (MHAS_m - att_min.view(-1, 1, 1))/(att_max.view(-1, 1, 1) - att_min.view(-1, 1, 1))
283 |
284 | att_max, index_max = torch.max(MHAS_l.view(MHAS_l.size(0), -1), dim=-1)
285 | att_min, index_min = torch.min(MHAS_l.view(MHAS_l.size(0), -1), dim=-1)
286 | MHAS_l = (MHAS_l - att_min.view(-1, 1, 1))/(att_max.view(-1, 1, 1) - att_min.view(-1, 1, 1))
287 |
288 | mhas_s = make_grid(MHAS_s.unsqueeze(1), nrow=int(MHAS_s.size(0) ** 0.5), padding=2)[0]
289 | mhas_m = make_grid(MHAS_m.unsqueeze(1), nrow=int(MHAS_m.size(0) ** 0.5), padding=2)[0]
290 | mhas_l = make_grid(MHAS_l.unsqueeze(1), nrow=int(MHAS_l.size(0) ** 0.5), padding=2)[0]
291 | if args.visdom['enable']:
292 | vis.featuremap('MHAS Map', mhas_l)
293 |
294 | fig = plt.figure(figsize=(20, 10))
295 | ax = fig.add_subplot(131)
296 | sns.heatmap(mhas_s.squeeze(), annot=False, fmt='g', ax=ax)
297 | ax.set_title('\nMHSA Small', fontsize=10)
298 |
299 | ax = fig.add_subplot(132)
300 | sns.heatmap(mhas_m.squeeze(), annot=False, fmt='g', ax=ax)
301 | ax.set_title('\nMHSA Medium', fontsize=10)
302 |
303 | ax = fig.add_subplot(133)
304 | sns.heatmap(mhas_l.squeeze(), annot=False, fmt='g', ax=ax)
305 | ax.set_title('\nMHSA Large', fontsize=10)
306 | plt.suptitle('{}'.format(v_path[0].split('/')[-1]), fontsize=20)
307 | fig.savefig('demo/{}-MHAS.jpg'.format(args.save.split('/')[-1]), dpi=fig.dpi)
308 | plt.close()
309 |
310 | def train(train_queue, model, criterion, optimizer, epoch, local_rank):
311 | model.train()
312 |
313 | meter_dict = dict(
314 | Total_loss=AverageMeter(),
315 | CE_loss=AverageMeter(),
316 | Distil_loss=AverageMeter()
317 | )
318 | meter_dict.update(dict(
319 | cosin_similar=AverageMeter()
320 | ))
321 | meter_dict['Data_Time'] = AverageMeter()
322 | meter_dict.update(dict(
323 | Acc_1=AverageMeter(),
324 | Acc_2=AverageMeter(),
325 | Acc_3=AverageMeter(),
326 | Acc=AverageMeter()
327 | ))
328 |
329 | end = time.time()
330 | CE = criterion
331 | for step, (inputs, heatmap, target, _) in enumerate(train_queue):
332 | meter_dict['Data_Time'].update((time.time() - end)/args.batch_size)
333 | inputs, target, heatmap = map(lambda x: x.cuda(local_rank, non_blocking=True), [inputs, target, heatmap])
334 |
335 | (logits, xs, xm, xl), distillation_loss, feature = model(inputs, heatmap)
336 | if args.MultiLoss:
337 | lamd1, lamd2, lamd3, lamd4 = map(float, args.loss_lamdb)
338 | globals()['CE_loss'] = lamd1*CE(logits, target) + lamd2*CE(xs, target) + lamd3*CE(xm, target) + lamd4*CE(xl, target)
339 | else:
340 | globals()['CE_loss'] = CE(logits, target)
341 | globals()['Distil_loss'] = distillation_loss * args.distill_lamdb
342 | globals()['Total_loss'] = CE_loss + Distil_loss
343 |
344 | optimizer.zero_grad()
345 | Total_loss.backward()
346 | nn.utils.clip_grad_norm_(model.module.parameters(), args.grad_clip)
347 | optimizer.step()
348 |
349 | #---------------------
350 | # Meter performance
351 | #---------------------
352 | torch.distributed.barrier()
353 | globals()['Acc'] = calculate_accuracy(logits, target)
354 | globals()['Acc_1'] = calculate_accuracy(xs, target)
355 | globals()['Acc_2'] = calculate_accuracy(xm, target)
356 | globals()['Acc_3'] = calculate_accuracy(xl, target)
357 |
358 | for name in meter_dict:
359 | if 'loss' in name:
360 | meter_dict[name].update(reduce_mean(globals()[name], args.nprocs))
361 | if 'cosin' in name:
362 | meter_dict[name].update(float(feature[2]))
363 | if 'Acc' in name:
364 | meter_dict[name].update(reduce_mean(globals()[name], args.nprocs))
365 |
366 | if (step+1) % args.report_freq == 0 and local_rank == 0:
367 | log_info = {
368 | 'Epoch': '{}/{}'.format(epoch + 1, args.epochs),
369 | 'Mini-Batch': '{:0>5d}/{:0>5d}'.format(step + 1,
370 | len(train_queue.dataset) // (args.batch_size * args.nprocs)),
371 | 'Lr': [round(float(g['lr']), 7) for g in optimizer.param_groups],
372 | }
373 | log_info.update(dict((name, '{:.4f}'.format(value.avg)) for name, value in meter_dict.items()))
374 | print_func(log_info)
375 |
376 | if args.vis_feature:
377 | Visfeature(inputs, feature)
378 | end = time.time()
379 |
380 | return meter_dict['Acc'].avg, meter_dict['Total_loss'].avg, meter_dict
381 |
382 | @torch.no_grad()
383 | def concat_all_gather(tensor):
384 | """
385 | Performs all_gather operation on the provided tensors.
386 | *** Warning ***: torch.distributed.all_gather has no gradient.
387 | """
388 | tensors_gather = [torch.ones_like(tensor)
389 | for _ in range(torch.distributed.get_world_size())]
390 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
391 | output = torch.cat(tensors_gather, dim=0)
392 | return output
393 |
394 | @torch.no_grad()
395 | def infer(valid_queue, model, criterion, local_rank, epoch):
396 | model.eval()
397 |
398 | meter_dict = dict(
399 | Total_loss=AverageMeter(),
400 | CE_loss=AverageMeter(),
401 | Distil_loss=AverageMeter()
402 | )
403 | meter_dict.update(dict(
404 | cosin_similar=AverageMeter(),
405 | ))
406 | meter_dict.update(dict(
407 | Acc_1=AverageMeter(),
408 | Acc_2=AverageMeter(),
409 | Acc_3=AverageMeter(),
410 | Acc=AverageMeter()
411 | ))
412 |
413 | meter_dict['Infer_Time'] = AverageMeter()
414 | CE = criterion
415 | grounds, preds, v_paths = [], [], []
416 | output = {}
417 | for step, (inputs, heatmap, target, v_path) in enumerate(valid_queue):
418 | n = inputs.size(0)
419 | end = time.time()
420 | inputs, target, heatmap = map(lambda x: x.cuda(local_rank, non_blocking=True), [inputs, target, heatmap])
421 |
422 | (xs, xm, xl, logits), distillation_loss, feature = model(inputs, heatmap)
423 | meter_dict['Infer_Time'].update((time.time() - end) / n)
424 |
425 | if args.MultiLoss:
426 | lamd1, lamd2, lamd3, lamd4 = map(float, args.loss_lamdb)
427 | globals()['CE_loss'] = (lamd1 * CE(logits, target) + lamd2 * CE(xs, target)
428 | + lamd3 * CE(xm, target) + lamd4 * CE(xl, target))
429 |
430 | else:
431 | globals()['CE_loss'] = CE(logits, target)
432 | globals()['Distil_loss'] = distillation_loss * args.distill_lamdb
433 | globals()['Total_loss'] = CE_loss + Distil_loss
434 |
435 | grounds += target.cpu().tolist()
436 | preds += torch.argmax(logits, dim=1).cpu().tolist()
437 | v_paths += v_path
438 | torch.distributed.barrier()
439 | globals()['Acc'] = calculate_accuracy(logits, target)
440 | globals()['Acc_1'] = calculate_accuracy(xs+xm, target)
441 | globals()['Acc_2'] = calculate_accuracy(xs+xl, target)
442 | globals()['Acc_3'] = calculate_accuracy(xl+xm, target)
443 |
444 | for name in meter_dict:
445 | if 'loss' in name:
446 | meter_dict[name].update(reduce_mean(globals()[name], args.nprocs))
447 | if 'cosin' in name:
448 | meter_dict[name].update(float(feature[2]))
449 | if 'Acc' in name:
450 | meter_dict[name].update(reduce_mean(globals()[name], args.nprocs))
451 |
452 | if step % args.report_freq == 0 and local_rank == 0:
453 | log_info = {
454 | 'Epoch': epoch + 1,
455 | 'Mini-Batch': '{:0>4d}/{:0>4d}'.format(step + 1, len(valid_queue.dataset) // (
456 | args.test_batch_size * args.nprocs)),
457 | }
458 | log_info.update(dict((name, '{:.4f}'.format(value.avg)) for name, value in meter_dict.items()))
459 | print_func(log_info)
460 | if args.vis_feature:
461 | Visfeature(inputs, feature, v_path)
462 |
463 | if args.save_output:
464 | for t, logit in zip(v_path, logits):
465 | output[t] = logit
466 | torch.distributed.barrier()
467 | grounds_gather = concat_all_gather(torch.tensor(grounds).cuda(local_rank))
468 | preds_gather = concat_all_gather(torch.tensor(preds).cuda(local_rank))
469 | grounds_gather, preds_gather = list(map(lambda x: x.cpu().numpy(), [grounds_gather, preds_gather]))
470 |
471 | if local_rank == 0:
472 | # v_paths = np.array(v_paths)[random.sample(list(wrong), 10)]
473 | v_paths = np.array(v_paths)
474 | grounds = np.array(grounds)
475 | preds = np.array(preds)
476 | wrong_idx = np.where(grounds != preds)
477 | v_paths = v_paths[wrong_idx[0]]
478 | grounds = grounds[wrong_idx[0]]
479 | preds = preds[wrong_idx[0]]
480 | return max(meter_dict['Acc'].avg, meter_dict['Acc_1'].avg, meter_dict['Acc_2'].avg, meter_dict['Acc_3'].avg), meter_dict['Total_loss'].avg, dict(grounds=grounds_gather, preds=preds_gather, valid_images=(v_paths, grounds, preds)), meter_dict, output
481 |
482 | if __name__ == '__main__':
483 | if args.visdom['enable']:
484 | vis = Visualizer(args.visdom['visname'])
485 | try:
486 | main(args.local_rank, args.nprocs, args)
487 | except KeyboardInterrupt:
488 | torch.cuda.empty_cache()
489 | if os.path.exists(args.save) and len(os.listdir(args.save)) < 3:
490 | print('remove {}: Directory'.format(args.save))
491 | shutil.rmtree(args.save, ignore_errors=True)
492 | os._exit(0)
493 | except Exception:
494 | traceback.print_exc()
495 | if os.path.exists(args.save) and len(os.listdir(args.save)) < 3:
496 | print('remove {}: Directory'.format(args.save))
497 | shutil.rmtree(args.save, ignore_errors=True)
498 | os._exit(0)
499 | finally:
500 | torch.cuda.empty_cache()
501 |
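502 | # Launch sketch (an assumption -- see run.sh for the command actually used):
503 | #   python -m torch.distributed.launch --nproc_per_node=2 tools/train.py \
504 | #       --config <path/to/dataset_config.yml> --nprocs 2
505 | # torch.distributed.launch supplies --local_rank; the dataset, model and optimizer
506 | # options are read from the YAML config via config.Config.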
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | '''
2 | Copyright (C) 2010-2021 Alibaba Group Holding Limited.
3 | '''
4 |
5 | from .print_function import print_func
6 | from .build import *
7 | from .evaluate_metric import EvaluateMetric
8 | from .utils import *
9 |
--------------------------------------------------------------------------------
/utils/build.py:
--------------------------------------------------------------------------------
1 | '''
2 | Copyright (C) 2010-2021 Alibaba Group Holding Limited.
3 | '''
4 |
5 | import torch
6 | import math
7 | import torch.nn.functional as F
8 | from .utils import cosine_scheduler
9 | import matplotlib.pyplot as plt
10 | import numpy as np
11 |
12 |
13 | class LabelSmoothingCrossEntropy(torch.nn.Module):
14 | def __init__(self, smoothing: float = 0.1,
15 | reduction="mean", weight=None):
16 | super(LabelSmoothingCrossEntropy, self).__init__()
17 | self.smoothing = smoothing
18 | self.reduction = reduction
19 | self.weight = weight
20 |
21 | def reduce_loss(self, loss):
22 | return loss.mean() if self.reduction == 'mean' else loss.sum() \
23 | if self.reduction == 'sum' else loss
24 |
25 | def linear_combination(self, x, y):
26 | return self.smoothing * x + (1 - self.smoothing) * y
27 |
28 | def forward(self, preds, target):
29 | assert 0 <= self.smoothing < 1
30 |
31 | if self.weight is not None:
32 | self.weight = self.weight.to(preds.device)
33 |
34 | n = preds.size(-1)
35 | log_preds = F.log_softmax(preds, dim=-1)
36 | loss = self.reduce_loss(-log_preds.sum(dim=-1))
37 | nll = F.nll_loss(
38 | log_preds, target, reduction=self.reduction, weight=self.weight
39 | )
40 | return self.linear_combination(loss / n, nll)
41 |
42 | def build_optim(args, model):
43 | if args.optim == 'SGD':
44 | optimizer = torch.optim.SGD(
45 | model.parameters(),
46 | args.learning_rate,
47 | momentum=args.momentum,
48 | weight_decay=args.weight_decay
49 | )
50 | elif args.optim == 'Adam':
51 | optimizer = torch.optim.Adam(
52 | model.parameters(),
53 | lr=args.learning_rate
54 | )
55 | elif args.optim == 'AdamW':
56 | optimizer = torch.optim.AdamW(
57 | model.parameters(),
58 | lr=args.learning_rate
59 | )
60 | return optimizer
61 | #
62 | def build_scheduler(args, optimizer):
63 | if args.scheduler['name'] == 'cosin':
64 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
65 | optimizer, float(args.epochs-args.scheduler['warm_up_epochs']), eta_min=args.learning_rate_min)
66 | elif args.scheduler['name'] == 'ReduceLR':
67 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.1,
68 | patience=args.scheduler['patience'], verbose=True,
69 | threshold=0.0001,
70 | threshold_mode='rel', cooldown=3, min_lr=0.00001,
71 | eps=1e-08)
72 | else:
73 | raise NameError('build scheduler error!')
74 |
75 | if args.scheduler['warm_up_epochs'] > 0:
76 | warmup_schedule = lambda epoch: np.linspace(1e-8, args.learning_rate, args.scheduler['warm_up_epochs'])[epoch]
77 | return (scheduler, warmup_schedule)
78 | return (scheduler,)
79 |
80 | def build_loss(args):
81 | loss_Function=dict(
82 | CE_smooth = LabelSmoothingCrossEntropy(),
83 | CE = torch.nn.CrossEntropyLoss(),
84 | MSE = torch.nn.MSELoss(),
85 | BCE = torch.nn.BCELoss(),
86 | )
87 | if args.loss['name'] == 'CE' and args.loss['labelsmooth']:
88 | return loss_Function['CE_smooth']
89 | return loss_Function[args.loss['name']]
90 |
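91 |
92 | if __name__ == '__main__':
93 | # Sanity sketch (assumption: run as `python -m utils.build` from the repo root, since this
94 | # module uses a relative import). With smoothing=0.0 the result matches nn.CrossEntropyLoss.
95 | _logits, _target = torch.randn(4, 10), torch.randint(0, 10, (4,))
96 | print('smooth=0.0 :', LabelSmoothingCrossEntropy(smoothing=0.0)(_logits, _target).item())
97 | print('plain CE   :', torch.nn.CrossEntropyLoss()(_logits, _target).item())
98 | print('smooth=0.1 :', LabelSmoothingCrossEntropy(smoothing=0.1)(_logits, _target).item())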
--------------------------------------------------------------------------------
/utils/evaluate_metric.py:
--------------------------------------------------------------------------------
1 | '''
2 | Copyright (C) 2010-2021 Alibaba Group Holding Limited.
3 | '''
4 |
5 | # -------------------
6 | # import modules
7 | # -------------------
8 | import random, os
9 | import numpy as np
10 | import cv2
11 | import heapq
12 | import shutil
13 | from textwrap import wrap
14 |
15 | import matplotlib
16 | import matplotlib.pyplot as plt
17 | import matplotlib.image as mpimage
18 |
19 | from sklearn.model_selection import train_test_split
20 | from sklearn.metrics import confusion_matrix, auc, roc_curve, roc_auc_score
21 | import seaborn as sns
22 | from torchvision import transforms
23 | from PIL import Image
24 | import torch
25 | from torchvision.utils import save_image, make_grid
26 | # ---------------------------------------
27 | # Plot accuracy and loss
28 | # ---------------------------------------
29 | def get_error_bar(best_score, valid_examples):
30 | print("--------------------------------------------")
31 | print("Standard Error") # best_score: Average al of scores, valid_examples: num of all samples
32 | print("--------------------------------------------")
33 |
34 | err = np.sqrt((best_score * (1 - best_score)) / valid_examples)
35 | err_rounded_68 = round(err, 2)
36 | err_rounded_95 = round((err_rounded_68 * 2), 2)
37 |
38 | print('Error (68% CI): +- ' + str(err_rounded_68))
39 | print('Error (95% CI): +- ' + str(err_rounded_95))
40 | print()
41 | return err_rounded_68
42 |
43 | def plot_train_results(PREDICTIONS_PATH, train_results, idx):
44 | '''
45 |
46 | :param PREDICTIONS_PATH: plot image save path
47 | :param train_results: {'valid_score':[...], 'valid_loss':[...], 'train_score': [...], 'train_loss':[...]}
48 | :param best_score: validation best acc
49 | :param idx: epoch index
50 | :return: None
51 | '''
52 |
53 | # best_score = sum(train_results['valid_score']) / len(train_results['valid_score'])
54 | valid_examples = len(train_results['valid_score'])
55 | super_category = str(idx)
56 |
57 | best_score = train_results["best_score"]
58 | standard_error = get_error_bar(best_score, valid_examples)
59 | y_upper = np.array(train_results["valid_score"]) + standard_error
60 | y_lower = np.array(train_results["valid_score"]) - standard_error
61 |
62 | print("--------------------------------------------")
63 | print("Results")
64 | print("--------------------------------------------")
65 |
66 | fig = plt.figure(figsize=(15, 5))
67 |
68 | plt.subplot(1, 2, 1)
69 | plt.plot(range(0, len(train_results["train_score"])), train_results["train_score"], label='train')
70 |
71 | plt.plot(range(0, len(train_results["valid_score"])), train_results["valid_score"], label='valid')
72 |
73 | kwargs = {'color': 'black', 'linewidth': 1, 'linestyle': '--', 'dashes': (5, 5)}
74 | plt.plot(range(0, len(train_results["valid_score"])), y_lower, **kwargs)
75 | plt.plot(range(0, len(train_results["valid_score"])), y_upper, **kwargs, label='validation SE (68% CI)')
76 |
77 | plt.title('Accuracy Plot - ' + super_category, fontsize=20)
78 | plt.ylabel('Accuracy', fontsize=16)
79 | plt.xlabel('Training Epochs', fontsize=16)
80 | plt.ylim(0, 1)
81 | plt.legend()
82 |
83 | plt.subplot(1, 2, 2)
84 | plt.plot(range(0, len(train_results["train_loss"])), train_results["train_loss"], label='train')
85 | plt.plot(range(0, len(train_results["valid_loss"])), train_results["valid_loss"], label='valid')
86 |
87 | plt.title('Loss Plot - ' + super_category, fontsize=20)
88 | plt.ylabel('Loss', fontsize=16)
89 | plt.xlabel('Training Epochs', fontsize=16)
90 | max_train_loss = max(train_results["train_loss"])
91 | max_valid_loss = max(train_results["valid_loss"])
92 | y_max_t_v = max_valid_loss if max_valid_loss > max_train_loss else max_train_loss
93 | ylim_loss = y_max_t_v if y_max_t_v > 1 else 1
94 | plt.ylim(0, ylim_loss)
95 | plt.legend()
96 |
97 | plt.show()
98 |
99 | fig.savefig(os.path.join(PREDICTIONS_PATH, "train_results_{}.png".format(idx)), dpi=fig.dpi)
100 |
101 |
102 | # ---------------------------------------
103 | # Plot Confusion Matrix
104 | # ---------------------------------------
105 | def plot_confusion_matrix(PREDICTIONS_PATH, grounds, preds, categories, idx, top=20):
106 | print("--------------------------------------------")
107 | print("Confusion Matrix")
108 | print("--------------------------------------------")
109 |
110 | super_category = str(idx)
111 | num_cat = []
112 | for ind, cat in enumerate(categories):
113 | print("Class {0} : {1}".format(ind, cat))
114 | num_cat.append(ind)
115 | print()
116 | numclass = len(num_cat)
117 |
118 | cm = confusion_matrix(grounds, preds, labels=num_cat)
119 | fig = plt.figure(figsize=(10, 8))
120 | ax = fig.add_subplot()
121 | sns.heatmap(cm, annot=False, fmt='g', ax=ax); # annot=True to annotate cells, ftm='g' to disable scientific notation
122 |
123 | # labels, title and ticks
124 | ax.set_title('Confusion Matrix - ' + super_category, fontsize=20)
125 | ax.set_xlabel('Predicted labels', fontsize=16)
126 | ax.set_ylabel('True labels', fontsize=16)
127 |
128 | ax.set_xticks(range(0,len(num_cat), 1))
129 | ax.set_yticks(range(0,len(num_cat), 1))
130 | ax.xaxis.set_ticklabels(num_cat)
131 | ax.yaxis.set_ticklabels(num_cat)
132 |
133 | plt.pause(0.1)
134 | fig.savefig(os.path.join(PREDICTIONS_PATH, "confusion_matrix"), dpi=fig.dpi)
135 |
136 | # -------------------------------------------------
137 | # Plot Accuracy and Precision
138 | # -------------------------------------------------
139 | Accuracy = [(cm[i, i] / sum(cm[i, :])) * 100 if sum(cm[i, :]) != 0 else 0.000001 for i in range(cm.shape[0])]
140 | Precision = [(cm[i, i] / sum(cm[:, i])) * 100 if sum(cm[:, i]) != 0 else 0.000001 for i in range(cm.shape[1])]
141 |
142 | fig = plt.figure(figsize=(min(numclass * 3, 300), 8))  # cap the figure width for large class counts
143 | ax = fig.add_subplot()
144 |
145 | bar_width = 0.4
146 | x = np.arange(len(Accuracy))
147 | b1 = ax.bar(x, Accuracy, width=bar_width, label='Accuracy', color=sns.xkcd_rgb["pale red"], tick_label=x)
148 |
149 | ax2 = ax.twinx()
150 | b2 = ax2.bar(x + bar_width, Precision, width=bar_width, label='Precision', color=sns.xkcd_rgb["denim blue"])
151 |
152 | average_acc = sum(Accuracy)/len(Accuracy)
153 | average_prec = sum(Precision)/len(Precision)
154 | b3 = plt.hlines(y=average_acc, xmin=-bar_width, xmax=numclass - 1 + bar_width * 2, linewidth=2, linestyles='--', color='r',
155 | label='Average Acc : %0.2f' % average_acc)
156 | b4 = plt.hlines(y=average_prec, xmin=-bar_width, xmax=numclass - 1 + bar_width * 2, linewidth=2, linestyles='--', color='b',
157 | label='Average Prec : %0.2f' % average_prec)
158 | plt.xticks(np.arange(numclass) + bar_width / 2, np.arange(numclass))
159 |
160 | # labels, title and ticks
161 | ax.set_title('Accuracy and Precision Epoch #{}'.format(idx), fontsize=20)
162 | ax.set_xlabel('labels', fontsize=16)
163 | ax.set_ylabel('Acc(%)', fontsize=16)
164 | ax2.set_ylabel('Prec(%)', fontsize=16)
165 | ax.set_xticklabels(ax.get_xticklabels(), rotation=45)
166 |
167 | ax.tick_params(axis='y', colors=b1[0].get_facecolor())
168 | ax2.tick_params(axis='y', colors=b2[0].get_facecolor())
169 |
170 | plt.legend(handles=[b1, b2, b3, b4])
171 | # fig.savefig(os.path.join(PREDICTIONS_PATH, "Accuracy-Precision_{}.png".format(idx)), dpi=fig.dpi)
172 | fig.savefig(os.path.join(PREDICTIONS_PATH, "Accuracy-Precision.png"), dpi=fig.dpi)
173 |
174 | plt.close()
175 |
176 | TopK_idx_acc = heapq.nlargest(top, range(len(Accuracy)), Accuracy.__getitem__)
177 | TopK_idx_prec = heapq.nlargest(top, range(len(Precision)), Precision.__getitem__)
178 |
179 | TopK_low_idx = heapq.nsmallest(top, range(len(Precision)), Precision.__getitem__)
180 |
181 |
182 | print('=' * 80)
183 | print('Accuracy Tok {0}: \n'.format(top))
184 | print('| Class ID \t Accuracy(%) \t Precision(%) |')
185 | for i in TopK_idx_acc:
186 | print('| {0} \t {1} \t {2} |'.format(i, round(Accuracy[i], 2), round(Precision[i], 2)))
187 | print('-' * 80)
188 | print('Precision Tok {0}: \n'.format(top))
189 | print('| Class ID \t Accuracy(%) \t Precision(%) |')
190 | for i in TopK_idx_prec:
191 | print('| {0} \t {1} \t {2} |'.format(i, round(Accuracy[i], 2), round(Precision[i], 2)))
192 | print('=' * 80)
193 |
194 | return TopK_low_idx
195 |
196 |
197 | # Fast Rank Pooling
198 | sample_size = 128
199 | def GenerateRPImage(imgs_path, sl):
200 | def get_DDI(video_arr):
201 | def get_w(N):
202 | return [float(i) * 2 - N - 1 for i in range(1, N + 1)]
203 |
204 | w_arr = get_w(len(video_arr))
205 | re = np.zeros((sample_size, sample_size, 3))
206 | for a, b in zip(video_arr, w_arr):
207 | img = cv2.imread(os.path.join(imgs_path, "%06d.jpg" % a))
208 | img = cv2.resize(img, (sample_size, sample_size))
209 | re += img * b
210 | re -= np.min(re)
211 | re = 255.0 * re / np.max(re) if np.max(re) != 0 else 255.0 * re / (np.max(re) + 0.00001)
212 |
213 | return re.astype('uint8')
214 |
215 | return get_DDI(sl)
216 |
217 | # ---------------------------------------
218 | # Wrongly Classified Images
219 | # ---------------------------------------
220 | def plot_wrongly_classified_images(PREDICTIONS_PATH, TopK_low_idx, valid_images, idx):
221 | print("--------------------------------------------")
222 | print("Wrongly Classified Images")
223 | print("--------------------------------------------")
224 |
225 | v_paths, grounds, preds = valid_images
226 | # evenly pick `sn` frame indices from `n` frames (the mean index of each of `sn` segments)
227 | f = lambda n, sn: [
228 | (lambda m, arr: m if arr == [] else int(np.mean(arr)))(
229 | n * i / sn, range(int(n * i / sn), max(int(n * i / sn) + 1, int(n * (i + 1) / sn))))
230 | for i in range(sn)]
231 |
232 |
233 | train_images = []
234 | ground, pred, pred_lbl_file = [], [], []
235 | for g, p, v in zip(grounds, preds, v_paths):
236 | assert p != g, 'Pred: {} equ to ground-truth: {}'.format(p, g)
237 | if g in TopK_low_idx[:10]:
238 | imgs = [transforms.ToTensor()(Image.open(os.path.join(v, "%06d.jpg" % a)).resize((200, 200))).unsqueeze(0) for a in f(len(os.listdir(v))//2, 10)]
239 | train_images.append(make_grid(torch.cat(imgs), nrow=10, padding=2).permute(1, 2, 0))
240 | ground.append(g)
241 | pred.append(p)
242 | pred_lbl_file.append(v)
243 | if len(train_images) > 9:
244 | break
245 |
246 | fig = plt.figure(figsize=(30, 20))
247 | k = 0
248 | for i in range(0, len(train_images)):
249 | fig.add_subplot(10, 1, k + 1)
250 | plt.axis('off')
251 | if i == 0:
252 | title = "Orig lbl: " + str(ground[i]) + " Pred lbl: " + str(pred[i]) + " " + pred_lbl_file[i]
253 | else:
254 | title = '\n'*10 + "Orig lbl: " + str(ground[i]) + " Pred lbl: " + str(pred[i]) + " " + pred_lbl_file[i]
255 | plt.title(title)
256 | plt.imshow(train_images[i])
257 | k += 1
258 |
259 | plt.pause(0.1)
260 | print()
261 | fig.savefig(os.path.join(PREDICTIONS_PATH, "wrongly_classified_images.png".format(idx)), dpi=fig.dpi)
262 | plt.close()
263 |
264 | def EvaluateMetric(PREDICTIONS_PATH, train_results, idx):
265 | TopK_low_idx = plot_confusion_matrix(PREDICTIONS_PATH, train_results['grounds'], train_results['preds'], train_results['categories'], idx)
266 |
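267 | # Note: get_error_bar() reports the binomial standard error sqrt(p * (1 - p) / n); e.g.
268 | # best_score = 0.9 over 1000 validation clips gives ~0.0095, i.e. about +/- 0.01 at a 68% CI
269 | # and +/- 0.02 at a 95% CI (matching the rounding done above).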
--------------------------------------------------------------------------------
/utils/print_function.py:
--------------------------------------------------------------------------------
1 | '''
2 | Copyright (C) 2010-2021 Alibaba Group Holding Limited.
3 | '''
4 |
5 | import logging
6 |
7 | def print_func(info):
8 | '''
9 | :param info: {name: value}
10 | :return:
11 | '''
12 | txts = []
13 | for name, value in info.items():
14 | txts.append('{}: {}'.format(name, value))
15 | logging.info('\t'.join(txts))
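16 |
17 | # Example: print_func({'Epoch': '1/100', 'Acc': '0.8750'}) logs the single
18 | # tab-separated line "Epoch: 1/100\tAcc: 0.8750".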
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | '''
2 | This file is modified from:
3 | https://github.com/yuhuixu1993/PC-DARTS/blob/master/utils.py
4 | '''
5 |
6 | import os
7 | import numpy as np
8 | import torch
9 | import shutil
10 | import torchvision.transforms as transforms
11 | from torch.autograd import Variable
12 | from collections import OrderedDict
13 |
14 | class ClassAcc():
15 | def __init__(self, GESTURE_CLASSES):
16 | self.class_acc = dict(zip([i for i in range(GESTURE_CLASSES)], [0]*GESTURE_CLASSES))
17 | self.single_class_num = [0]*GESTURE_CLASSES
18 | def update(self, logits, target):
19 | pred = torch.argmax(logits, dim=1)
20 | for p, t in zip(pred.cpu().numpy(), target.cpu().numpy()):
21 | if p == t:
22 | self.class_acc[t] += 1
23 | self.single_class_num[t] += 1
24 | def result(self):
25 | return [round(v / (self.single_class_num[k]+0.000000001), 4) for k, v in self.class_acc.items()]
26 | class AverageMeter(object):
27 |
28 | def __init__(self):
29 | self.reset()
30 |
31 | def reset(self):
32 | self.avg = 0
33 | self.sum = 0
34 | self.cnt = 0
35 |
36 | def update(self, val, n=1):
37 | self.sum += val * n
38 | self.cnt += n
39 | self.avg = self.sum / self.cnt
40 |
41 | def adjust_learning_rate(optimizer, step, lr):
42 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
43 | df = 0.7
44 | ds = 40000.0
45 | lr = lr * np.power(df, step / ds)
46 | # lr = args.lr * (0.1**(epoch // 30))
47 | for param_group in optimizer.param_groups:
48 | param_group['lr'] = lr
49 | return lr
50 |
51 | def accuracy(output, target, topk=(1,)):
52 | maxk = max(topk)
53 | batch_size = target.size(0)
54 |
55 | _, pred = output.topk(maxk, 1, True, True)
56 | pred = pred.t()
57 | correct = pred.eq(target.view(1, -1).expand_as(pred))
58 |
59 | res = []
60 | for k in topk:
61 | correct_k = correct[:k].view(-1).float().sum(0)
62 | res.append(correct_k.mul_(100.0/batch_size))
63 | return res
64 |
65 | def calculate_accuracy(outputs, targets):
66 | with torch.no_grad():
67 | batch_size = targets.size(0)
68 | _, pred = outputs.topk(1, 1, True)
69 | pred = pred.t()
70 | correct = pred.eq(targets.view(1, -1))
71 | correct_k = correct.view(-1).float().sum(0, keepdim=True)
72 | #n_correct_elems = correct.float().sum().data[0]
73 | # n_correct_elems = correct.float().sum().item()
74 | # return n_correct_elems / batch_size
75 | return correct_k.mul_(1.0 / batch_size)
76 |
77 | def count_parameters_in_MB(model):
78 | return np.sum(np.prod(v.size()) for name, v in model.named_parameters() if "auxiliary" not in name)/1e6
79 |
80 |
81 | def save_checkpoint(state, is_best=False, save='./', filename='checkpoint.pth.tar'):
82 | filename = os.path.join(save, filename)
83 | torch.save(state, filename)
84 | if is_best:
85 | best_filename = os.path.join(save, 'model_best.pth.tar')
86 | shutil.copyfile(filename, best_filename)
87 |
88 | def load_checkpoint(model, model_path, optimizer=None):
89 | # checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage.cuda(4))
90 | checkpoint = torch.load(model_path, map_location='cpu')
91 | model.load_state_dict(checkpoint['model'])
92 | if optimizer:
93 | optimizer.load_state_dict(checkpoint['optimizer'])
94 | epoch = checkpoint['epoch']
95 | bestacc = checkpoint['bestacc']
96 | return model, optimizer, epoch, bestacc
97 |
98 | def load_pretrained_checkpoint(model, model_path):
99 | # params = torch.load(model_path, map_location=lambda storage, loc: storage.cuda(local_rank))['model']
100 | params = torch.load(model_path, map_location='cpu')['model']
101 | new_state_dict = OrderedDict()
102 |
103 | for k, v in params.items():
104 | name = k[7:] if k[:7] == 'module.' else k
105 | try:
106 | if v.shape == model.state_dict()[name].shape:
107 | if name not in ['dtn.mlp_head_small.1.bias', "dtn.mlp_head_small.1.weight",
108 | 'dtn.mlp_head_media.1.bias', "dtn.mlp_head_media.1.weight",
109 | 'dtn.mlp_head_large.1.bias', "dtn.mlp_head_large.1.weight"]:
110 | new_state_dict[name] = v
111 | except:
112 | continue
113 | ret = model.load_state_dict(new_state_dict, strict=False)
114 | print('Missing keys: \n', ret.missing_keys)
115 | return model
116 |
117 | def drop_path(x, drop_prob):
118 | if drop_prob > 0.:
119 | keep_prob = 1.-drop_prob
120 | mask = Variable(torch.cuda.FloatTensor(x.size(0), 1, 1, 1).bernoulli_(keep_prob))
121 | x.div_(keep_prob)
122 | x.mul_(mask)
123 | return x
124 |
125 |
126 | def create_exp_dir(path, scripts_to_save=None):
127 | if not os.path.exists(path):
128 | os.mkdir(path)
129 | print('Experiment dir : {}'.format(path))
130 |
131 | if scripts_to_save is not None:
132 | os.mkdir(os.path.join(path, 'scripts'))
133 | for script in scripts_to_save:
134 | if os.path.isdir(script) and script != '__pycache__':
135 | dst_file = os.path.join(path, 'scripts', script)
136 | shutil.copytree(script, dst_file)
137 | else:
138 | dst_file = os.path.join(path, 'scripts', os.path.basename(script))
139 | shutil.copyfile(script, dst_file)
140 |
141 | def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0):
142 | warmup_schedule = np.array([])
143 | warmup_iters = warmup_epochs * niter_per_ep
144 | if warmup_epochs > 0:
145 | warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
146 |
147 | iters = np.arange(epochs * niter_per_ep - warmup_iters)
148 | schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters)))
149 |
150 | schedule = np.concatenate((warmup_schedule, schedule))
151 | assert len(schedule) == epochs * niter_per_ep
152 | return schedule
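153 |
154 | if __name__ == '__main__':
155 | # Sanity sketch: a per-iteration schedule for 10 epochs x 100 iterations with 1 warmup
156 | # epoch, decaying from 1e-3 to 1e-5; it is indexed as schedule[epoch * niter_per_ep + step]
157 | # (this is how the training scripts consume cosine_scheduler when resuming).
158 | sched = cosine_scheduler(1e-3, 1e-5, epochs=10, niter_per_ep=100, warmup_epochs=1)
159 | assert len(sched) == 10 * 100 and abs(sched[-1] - 1e-5) < 1e-6
160 | print('first/last LR:', sched[0], sched[-1])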
--------------------------------------------------------------------------------
/utils/visualizer.py:
--------------------------------------------------------------------------------
1 | '''
2 | This file is modified from:
3 | https://github.com/zhoubenjia/RAAR3DNet/blob/master/Network_Train/utils/visualizer.py
4 | '''
5 |
6 |
7 | #coding: utf8
8 |
9 | import numpy as np
10 | import time
11 |
12 |
13 | class Visualizer():
14 | def __init__(self, env='default', **kwargs):
15 | import visdom
16 | self.vis = visdom.Visdom(env=env, use_incoming_socket=False, **kwargs)
17 |
18 | self.index = {}
19 | self.log_text = ''
20 |
21 | def reinit(self, env='default', **kwargs):
22 | import visdom  # lazy import, mirroring __init__
23 | self.vis = visdom.Visdom(env=env, use_incoming_socket=False, **kwargs)
24 | return self
25 | def plot_many(self, d, modality, epoch=None):
26 | colmu_stac = []
27 | for k, v in d.items():
28 | colmu_stac.append(np.array(v))
29 | if epoch:
30 | x = epoch
31 | else:
32 | x = self.index.get(modality, 0)
33 | # self.vis.line(Y=np.column_stack((np.array(dicts['loss1']), np.array(dicts['loss2']))),
34 | self.vis.line(Y=np.column_stack(tuple(colmu_stac)),
35 | X=np.array([x]),
36 | win=(modality),
37 | # opts=dict(title=modality,legend=['loss1', 'loss2'], ylabel='loss value'),
38 | opts=dict(title=modality, legend=list(d.keys()), ylabel='Value', xlabel='Iteration'),
39 | update=None if x == 0 else 'append')
40 | if not epoch:
41 | self.index[modality] = x + 1
42 |
43 | def plot(self, name, y):
44 | """
45 | self.plot('loss',1.00)
46 | """
47 | x = self.index.get(name, 0)
48 | self.vis.line(Y=np.array([y]), X=np.array([x]),
49 | win=(name),
50 | opts=dict(title=name),
51 | update=None if x == 0 else 'append'
52 | )
53 | self.index[name] = x + 1
54 |
55 | def log(self, info, win='log_text'):
56 | """
57 | self.log({'loss':1,'lr':0.0001})
58 | """
59 |
60 | self.log_text += ('[{time}] {info} <br>'.format(
61 | time=time.strftime('%m.%d %H:%M:%S'),
62 | info=info))
63 | self.vis.text(self.log_text, win=win)
64 |
65 | def img_grid(self, name, input_3d, heatmap=False):
66 | self.vis.images(
67 | # np.random.randn(20, 3, 64, 64),
68 | show_image_grid(input_3d, name, heatmap),
69 | win=name,
70 | opts=dict(title=name, caption='img_grid.')
71 | )
72 | def img(self, name, input):
73 | self.vis.images(
74 | input,
75 | win=name,
76 | opts=dict(title=name, caption='RGB Images.')
77 | )
78 |
79 | def draw_curve(self, name, data):
80 | self.vis.line(Y=np.array(data), X=np.array(range(len(data))),
81 | win=(name),
82 | opts=dict(title=name),
83 | update=None
84 | )
85 |
86 | def featuremap(self, name, input):
87 | self.vis.heatmap(input, win=name, opts=dict(title=name))
88 |
89 | def draw_bar(self, name, inp):
90 | self.vis.bar(
91 | X=np.abs(np.array(inp)),
92 | win=name,
93 | opts=dict(
94 | stacked=True,
95 | legend=list(map(str, range(inp.shape[-1]))),
96 | rownames=list(map(str, range(inp.shape[0])))
97 | )
98 | )
99 |
100 |
101 | def __getattr__(self, name):
102 | return getattr(self.vis, name)
103 |
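104 | # Usage sketch (assumes a visdom server is running, e.g. started with `python -m visdom.server`):
105 | #   vis = Visualizer(env='DSN-demo')
106 | #   vis.plot_many({'train_acc': 0.91, 'loss': 0.35}, 'Train-M', epoch=10)
107 | #   vis.log({'loss': 0.35, 'lr': 1e-4})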
--------------------------------------------------------------------------------