├── .gitignore
├── .idea
│   ├── .gitignore
│   ├── Angular.iml
│   ├── inspectionProfiles
│   │   ├── Project_Default.xml
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   └── vcs.xml
├── README.md
├── __init__.py
├── __pycache__
│   └── utils.cpython-36.pyc
├── ang_adjs.py
├── class_labels.py
├── config
│   ├── test.yaml
│   └── train.yaml
├── encoding
│   ├── __pycache__
│   │   └── data_encoder.cpython-36.pyc
│   └── data_encoder.py
├── feeders
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-36.pyc
│   │   ├── __init__.cpython-37.pyc
│   │   ├── feeder.cpython-36.pyc
│   │   ├── feeder.cpython-37.pyc
│   │   ├── feeder_as_gcn.cpython-36.pyc
│   │   ├── feeder_as_gcn.cpython-37.pyc
│   │   ├── feeder_dgnn.cpython-36.pyc
│   │   ├── feeder_dgnn.cpython-37.pyc
│   │   ├── tools.cpython-36.pyc
│   │   └── tools.cpython-37.pyc
│   ├── feeder.py
│   ├── feeder_as_gcn.py
│   ├── feeder_dgnn.py
│   └── tools.py
├── figures
│   ├── Architecture.png
│   ├── angle.png
│   └── skeletons.png
├── graph
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-36.pyc
│   │   ├── __init__.cpython-37.pyc
│   │   ├── ang_adjs.cpython-36.pyc
│   │   ├── ang_adjs.cpython-37.pyc
│   │   ├── azure_kinect.cpython-36.pyc
│   │   ├── azure_kinect.cpython-37.pyc
│   │   ├── directed_ntu_rgb_d.cpython-36.pyc
│   │   ├── directed_ntu_rgb_d.cpython-37.pyc
│   │   ├── hyper_graphs.cpython-36.pyc
│   │   ├── hyper_graphs.cpython-37.pyc
│   │   ├── kinetics.cpython-36.pyc
│   │   ├── kinetics.cpython-37.pyc
│   │   ├── ntu_rgb_d.cpython-36.pyc
│   │   ├── ntu_rgb_d.cpython-37.pyc
│   │   ├── tools.cpython-36.pyc
│   │   └── tools.cpython-37.pyc
│   ├── ang_adjs.py
│   ├── azure_kinect.py
│   ├── directed_ntu_rgb_d.py
│   ├── hyper_graphs.py
│   ├── kinetics.py
│   ├── ntu_rgb_d.py
│   └── tools.py
├── main.py
├── model
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-36.pyc
│   │   ├── activation.cpython-36.pyc
│   │   ├── att_gcn.cpython-36.pyc
│   │   ├── hyper_gcn.cpython-36.pyc
│   │   ├── mlp.cpython-36.pyc
│   │   ├── modules.cpython-36.pyc
│   │   ├── ms_gcn.cpython-36.pyc
│   │   ├── ms_gtcn.cpython-36.pyc
│   │   ├── ms_tcn.cpython-36.pyc
│   │   └── network.cpython-36.pyc
│   ├── activation.py
│   ├── att_gcn.py
│   ├── hyper_gcn.py
│   ├── mlp.py
│   ├── modules.py
│   ├── ms_gcn.py
│   ├── ms_gtcn.py
│   ├── ms_tcn.py
│   └── network.py
├── notification
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-36.pyc
│   │   ├── __init__.cpython-37.pyc
│   │   ├── email_config.cpython-36.pyc
│   │   ├── email_config.cpython-37.pyc
│   │   ├── email_sender.cpython-36.pyc
│   │   ├── email_sender.cpython-37.pyc
│   │   ├── email_templates.cpython-36.pyc
│   │   ├── email_templates.cpython-37.pyc
│   │   ├── html_templates.cpython-36.pyc
│   │   └── html_templates.cpython-37.pyc
│   ├── email_config.py
│   ├── email_sender.py
│   ├── email_templates.py
│   ├── email_test.py
│   └── html_templates.py
├── processor
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-36.pyc
│   │   └── args.cpython-36.pyc
│   ├── args.py
│   ├── io.py
│   ├── processor.py
│   ├── recognition.py
│   └── torchlight_io.py
├── train.sh
├── utils.py
└── utils_dir
    ├── __init__.py
    ├── __pycache__
    │   ├── __init__.cpython-36.pyc
    │   ├── utils_cam.cpython-36.pyc
    │   ├── utils_io.cpython-36.pyc
    │   ├── utils_math.cpython-36.pyc
    │   ├── utils_result.cpython-36.pyc
    │   └── utils_visual.cpython-36.pyc
    ├── utils_cam.py
    ├── utils_io.py
    ├── utils_math.py
    ├── utils_result.py
    └── utils_visual.py
/.gitignore:
--------------------------------------------------------------------------------
1 | work_dir
2 | apex
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Datasource local storage ignored files
5 | /dataSources/
6 | /dataSources.local.xml
7 | # Editor-based HTTP Client requests
8 | /httpRequests/
9 |
--------------------------------------------------------------------------------
/.idea/Angular.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/Project_Default.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
161 |
162 |
163 |
169 |
170 |
171 |
172 |
173 |
174 |
175 |
176 |
177 |
178 |
179 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | # Angular Encoding for Skeleton-Based Action Recognition
5 |
6 |
7 |
8 |
9 | ## Overview
10 |
11 |
12 |
13 |
14 | PyTorch implementation of "TNNLS 2022: Fusing Higher-Order Features in Graph Neural Networks for Skeleton-Based Action Recognition".
15 | (https://arxiv.org/pdf/2105.01563.pdf).
16 |
17 |
18 |
19 | ## Angular Features
20 |
21 | ![Angular features](figures/angle.png)
22 |
23 | ## Network Architecture
24 |
25 | ![Network architecture](figures/Architecture.png)
26 |
27 | ## Dependencies
28 |
29 | - Python >= 3.6
30 | - PyTorch >= 1.2.0
31 | - [NVIDIA Apex](https://github.com/NVIDIA/apex) (for automatic mixed precision training)
32 | - PyYAML, tqdm, tensorboardX, matplotlib, seaborn (an install sketch follows below)
33 |
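A minimal environment setup sketch (assuming pip and a CUDA-enabled PyTorch install; Apex is built from source, so check the Apex README for its current build instructions):

```bash
pip install pyyaml tqdm tensorboardX matplotlib seaborn

# Apex build flags follow the Apex README and may change over time
git clone https://github.com/NVIDIA/apex
cd apex
pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
```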
34 | ## Data Preparation
35 |
36 | ### Download Datasets
37 |
38 | There are 2 datasets to download:
39 | - NTU RGB+D 60 Skeleton
40 | - NTU RGB+D 120 Skeleton
41 |
42 | Request the datasets here: http://rose1.ntu.edu.sg/Datasets/actionRecognition.asp
43 |
44 | ### Data Preprocessing
45 |
46 | #### Directory Structure
47 |
48 | Put downloaded data into the following directory structure:
49 |
50 | ```
51 | - data/
52 | - nturgbd_raw/
53 | - nturgb+d_skeletons/ # from `nturgbd_skeletons_s001_to_s017.zip`
54 | ...
55 | - nturgb+d_skeletons120/ # from `nturgbd_skeletons_s018_to_s032.zip`
56 | ```
57 |
58 | #### Generating Data
59 |
60 | - `cd data_gen`
61 | - `python3 ntu_gendata.py`
62 | - `python3 ntu120_gendata.py`
63 | - This can take several hours; a faster CPU substantially reduces the processing time. The expected output layout is sketched below.
64 |
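The generation scripts are expected to produce the processed arrays at the paths referenced by `config/train.yaml` and `config/test.yaml`. The layout below is inferred from those configs (the NTU RGB+D 60 splits follow the same pattern under `data/ntu/`, as referenced in `feeders/feeder.py`):

```
data/
└── ntu120/
    └── xsub/
        ├── train_data_joint.npy
        ├── train_label.pkl
        ├── val_data_joint.npy
        └── val_label.pkl
```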
65 | ## Training
66 | ```
67 | bash train.sh
68 | ```
69 |
70 | ## Testing
71 | ```
72 | bash test.sh
73 | ```
74 |
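`config/test.yaml` already sets `phase: test` and `save_score: True`; before testing, point its `weights` entry at the trained checkpoint you want to evaluate:

```yaml
# config/test.yaml (excerpt)
phase: test
weights: "PATH TO THE PRETRAINED MODEL"   # replace with the path to your checkpoint
save_score: True
```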
75 | ## Acknowledgements
76 |
77 | This repo is based on
78 | - [MS-G3D](https://github.com/kenziyuliu/ms-g3d)
79 | - [2s-AGCN](https://github.com/lshiwjx/2s-AGCN)
80 | - [ST-GCN](https://github.com/yysijie/st-gcn)
81 |
82 | Thanks to the original authors for their work!
83 |
84 | The flat icon is from [Freepik](https://www.freepik.com/).
85 |
86 | ## Citation
87 |
88 | Please cite this work if you find it useful:
89 |
90 | ```
91 | @article{DBLP:journals/corr/abs-2105-01563,
92 | author = {Zhenyue Qin and Yang Liu and Pan Ji and Dongwoo Kim and Lei Wang and
93 | Bob McKay and Saeed Anwar and Tom Gedeon},
94 | title = {Fusing Higher-Order Features in Graph Neural Networks for Skeleton-based Action Recognition},
95 | journal = {IEEE Transactions on Neural Networks and Learning Systems (TNNLS)},
96 | year = {2022}
97 | }
98 | ```
99 |
100 |
101 | ## Contact
102 | If you have further questions, please email `zhenyue.qin@anu.edu.au` or `yang.liu3@anu.edu.au`.
103 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | from . import tools
2 | from . import ntu_rgb_d
3 | from . import kinetics
4 | from . import azure_kinect
5 | from . import directed_ntu_rgb_d
6 |
--------------------------------------------------------------------------------
/__pycache__/utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/__pycache__/utils.cpython-36.pyc
--------------------------------------------------------------------------------
/ang_adjs.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 | from graph.tools import normalize_adjacency_matrix
5 |
6 |
7 | def get_ang_adjs(data_type):
8 | rtn_adjs = []
9 |
10 | if data_type == 'ntu':
11 |         node_num = 25  # the NTU skeleton has 25 joints
12 |         sym_pairs = ((21, 2), (11, 7), (18, 14), (20, 16), (24, 25), (22, 23))  # 1-based NTU joint indices
13 | for a_sym_pair in sym_pairs:
14 | a_adj = np.eye(node_num)
15 | a_adj[:, a_sym_pair[0]-1] = 1
16 | a_adj[:, a_sym_pair[1]-1] = 1
17 | a_adj = torch.tensor(normalize_adjacency_matrix(a_adj))
18 | rtn_adjs.append(a_adj)
19 |
20 | return torch.cat(rtn_adjs, dim=0)
21 |
--------------------------------------------------------------------------------
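A minimal usage sketch for `get_ang_adjs` above (illustrative only; it assumes `graph/tools.py` provides `normalize_adjacency_matrix`, which is imported above but not shown in this listing):

```python
from ang_adjs import get_ang_adjs

# Six symmetric joint pairs, each contributing a normalized 25x25 adjacency,
# are stacked along dim 0.
adjs = get_ang_adjs('ntu')
print(adjs.shape)  # torch.Size([150, 25])
```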
/class_labels.py:
--------------------------------------------------------------------------------
1 | anu_bullying_pair_labels = {
2 | 1: 'G1A1: hit with knees',
3 | 2: 'G2A1: hit with head',
4 | 3: 'G3A1: punch to face',
5 | 4: 'G4A1: punch to body',
6 | 5: 'G5A1: cover mouth',
7 | 6: 'G6A1: pinch neck',
8 | 7: 'G7A1: slap',
9 | 8: 'G8A1: kicking',
10 | 9: 'G9A1: pushing',
11 | 10: 'G10A1: pierce others',
12 | 11: 'G11A1: pull hairs',
13 | 12: 'G12A1: drag other person',
14 | 13: 'G13A1: pull collar',
15 | 14: 'G14A1: swing others',
16 | 15: 'G15A1: beat with elbow',
17 | 16: 'G16A1: knoch over',
18 | 17: 'G17A1: hit with object',
19 | 18: 'G18A1: point to person',
20 | 19: 'G19A1: cuff ear',
21 | 20: 'G20A1: pinch arms',
22 | 21: 'G21A1: use cigarette to burn',
23 | 22: 'G22A1: sidekick person',
24 | 23: 'G23A1: cast to person',
25 | 24: 'G24A1: shoot person',
26 | 25: 'G25A1: stab person',
27 | 26: 'G26A1: wave knife to others',
28 | 27: 'G27A1: splash liquid on person',
29 | 28: 'G28A1: stumble person',
30 | 29: 'G29A1: step on foot',
31 | 30: 'G30A1: touch pocket',
32 | 31: 'G31A1: bite person',
33 | 32: 'G32A1: take picture for others',
34 | 33: 'G33A1: spiting to person',
35 | 34: 'G34A1: chop person',
36 | 35: 'G35A1: take chair while other sitting',
37 | 36: 'G36A1: pat on head',
38 | 37: 'G37A1: pinch face',
39 | 38: 'G38A1: pinch body',
40 | 39: 'G39A1: follow person',
41 | 40: 'G40A1: belt person',
42 | 41: 'G1A2: nod head',
43 | 42: 'G2A2: bow',
44 | 43: 'G3A2: shake hands',
45 | 44: 'G4A2: rock-paper-scissors',
46 | 45: 'G5A2: touch elbows',
47 | 46: 'G6A2: wave hand',
48 | 47: 'G7A2: fist bumping',
49 | 48: 'G8A2: pat on back',
50 | 49: 'G9A2: giving object',
51 | 50: 'G10A2: exchange object',
52 | 51: 'G11A2: clapping; hushing',
53 | 52: 'G12A2: drink water; brush teeth',
54 | 53: 'G13A2: stand up; jump up',
55 | 54: 'G14A2: take off a hat; play a phone',
56 | 55: 'G15A2: take a selfie; wipe face',
57 | 56: 'G16A2: cross hands in front; throat-slitting',
58 | 57: 'G17A2: crawling; open bottle',
59 | 58: 'G18A2: sneeze; yawn',
60 | 59: 'G19A2: self-cutting with knife; take off headphone',
61 | 60: 'G20A2: stretch oneself; flick hair',
62 | 61: 'G21A2: thumb up; thumb down',
63 | 62: 'G22A2: make ok sign; make victory sign',
64 | 63: 'G23A2: cutting nails; cutting paper',
65 | 64: 'G24A2: squat down; toss a coin',
66 | 65: 'G25A2: fold paper; ball up paper',
67 | 66: 'G26A2: play magic cube; surrender',
68 | 67: 'G27A2: apply cream on face; apply cream on hand',
69 | 68: 'G28A2: put on bag; take off bag',
70 | 69: 'G29A2: put object into bag; take object out of bag',
71 | 70: 'G30A2: open a box; yelling',
72 | 71: 'G31A2: arm circles; arm swings',
73 | 72: 'G32A2: whisper',
74 | 73: 'G33A2: clapping each other',
75 | 74: 'G34A2: running; vomiting',
76 | 75: 'G35A2: walk apart',
77 | 76: 'G36A2: headache; back pain',
78 | 77: 'G37A2: walk toward',
79 | 78: 'G38A2: falling down; chest pain',
80 | 79: 'G39A2: walk and hold person',
81 | 80: 'G40A2: cheers and drink',
82 | }
83 |
84 | ntu120_code_labels = {
85 | 1: 'drink water',
86 |     2: 'eat meal/snack',
87 | 3: "brushing teeth",
88 | 4: "brushing hair",
89 | 5: "drop",
90 | 6: "pickup",
91 | 7: "throw",
92 | 8: "sitting down",
93 | 9: "standing up (from sitting position)",
94 | 10: "clapping",
95 | 11: "reading",
96 | 12: "writing",
97 | 13: "tear up paper",
98 | 14: "wear jacket",
99 | 15: "take off jacket",
100 | 16: "wear a shoe",
101 | 17: "take off a shoe",
102 | 18: "wear on glasses",
103 | 19: "take off glasses",
104 | 20: "put on a hat/cap",
105 | 21: "take off a hat/cap",
106 | 22: "cheer up",
107 | 23: "hand waving",
108 | 24: "kicking something",
109 | 25: "reach into pocket",
110 | 26: "hopping (one foot jumping)",
111 | 27: "jump up",
112 | 28: "make a phone call/answer phone",
113 | 29: "playing with phone/tablet",
114 | 30: "typing on a keyboard",
115 | 31: "pointing to something with finger",
116 | 32: "taking a selfie",
117 | 33: "check time (from watch)",
118 | 34: "rub two hands together",
119 | 35: "nod head/bow",
120 | 36: "shake head",
121 | 37: "wipe face",
122 | 38: "salute",
123 | 39: "put the palms together",
124 | 40: "cross hands in front (say stop)",
125 | 41: "sneeze/cough",
126 | 42: "staggering",
127 | 43: "falling",
128 | 44: "touch head (headache)",
129 | 45: "touch chest (stomachache/heart pain)",
130 | 46: "touch back (backache)",
131 | 47: "touch neck (neckache)",
132 | 48: "nausea or vomiting condition",
133 | 49: "use a fan (with hand or paper)/feeling warm",
134 | 50: "punching/slapping other person",
135 | 51: "kicking other person",
136 | 52: "pushing other person",
137 | 53: "pat on back of other person",
138 | 54: "point finger at the other person",
139 | 55: "hugging other person",
140 | 56: "giving something to other person",
141 | 57: "touch other person's pocket",
142 | 58: "handshaking",
143 | 59: "walking towards each other",
144 | 60: "walking apart from each other",
145 | 61: "put on headphone",
146 | 62: "take off headphone",
147 | 63: "shoot at the basket",
148 | 64: "bounce ball",
149 | 65: "tennis bat swing",
150 | 66: "juggling table tennis balls",
151 | 67: "hush (quite)",
152 | 68: "flick hair",
153 | 69: "thumb up",
154 | 70: "thumb down",
155 | 71: "make ok sign",
156 | 72: "make victory sign",
157 | 73: "staple book",
158 | 74: "counting money",
159 | 75: "cutting nails",
160 | 76: "cutting paper (using scissors)",
161 | 77: "snapping fingers",
162 | 78: "open bottle",
163 | 79: "sniff (smell)",
164 | 80: "squat down",
165 | 81: "toss a coin",
166 | 82: "fold paper",
167 | 83: "ball up paper",
168 | 84: "play magic cube",
169 | 85: "apply cream on face",
170 | 86: "apply cream on hand back",
171 | 87: "put on bag",
172 | 88: "take off bag",
173 | 89: "put something into a bag",
174 | 90: "take something out of a bag",
175 | 91: "open a box",
176 | 92: "move heavy objects",
177 | 93: "shake fist",
178 | 94: "throw up cap/hat",
179 | 95: "hands up (both hands)",
180 | 96: "cross arms",
181 | 97: "arm circles",
182 | 98: "arm swings",
183 | 99: "running on the spot",
184 | 100: "butt kicks (kick backward)",
185 | 101: "cross toe touch",
186 | 102: "side kick",
187 | 103: "yawn",
188 | 104: "stretch oneself",
189 | 105: "blow nose",
190 | 106: "hit other person with something",
191 | 107: "wield knife towards other person",
192 | 108: "knock over other person (hit with body)",
193 | 109: "grab other person’s stuff",
194 | 110: "shoot at other person with a gun",
195 | 111: "step on foot",
196 | 112: "high-five",
197 | 113: "cheers and drink",
198 | 114: "carry something with other person",
199 | 115: "take a photo of other person",
200 | 116: "follow other person",
201 | 117: "whisper in other person’s ear",
202 | 118: "exchange things with other person",
203 | 119: "support somebody with hand",
204 | 120: "finger-guessing game (playing rock-paper-scissors)",
205 | }
206 |
207 | bly_labels = {
208 | 1: 'G1A1: hit with knees',
209 | 2: 'G2A1: hit with head',
210 | 3: 'G3A1: punch to face',
211 | 4: 'G4A1: punch to body',
212 | 5: 'G5A1: cover mouth',
213 | 6: 'G6A1: pinch neck',
214 | 7: 'G7A1: slap',
215 | 8: 'G8A1: kicking',
216 | 9: 'G9A1: pushing',
217 | 10: 'G10A1: pierce others',
218 | 11: 'G11A1: pull hairs',
219 | 12: 'G12A1: drag other person',
220 | 13: 'G13A1: pull collar',
221 | 14: 'G14A1: swing others',
222 | 15: 'G15A1: beat with elbow',
223 | 16: 'G16A1: knoch over',
224 | 17: 'G17A1: hit with object',
225 | 18: 'G18A1: point to person',
226 | 19: 'G19A1: cuff ear',
227 | 20: 'G20A1: pinch arms',
228 | 21: 'G21A1: use cigarette to burn',
229 | 22: 'G22A1: sidekick person',
230 | 23: 'G23A1: cast to person',
231 | 24: 'G24A1: shoot person',
232 | 25: 'G25A1: stab person',
233 | 26: 'G26A1: wave knife to others',
234 | 27: 'G27A1: splash liquid on person',
235 | 28: 'G28A1: stumble person',
236 | 29: 'G29A1: step on foot',
237 | 30: 'G30A1: touch pocket',
238 | 31: 'G31A1: bite person',
239 | 32: 'G33A1: spiting to person',
240 | 33: 'G34A1: chop person',
241 | 34: 'G35A1: take chair while other sitting',
242 | 35: 'G36A1: pat on head',
243 | 36: 'G37A1: pinch face',
244 | 37: 'G38A1: pinch body',
245 | 38: 'G40A1: belt person',
246 | }
247 |
248 | anubis_ind_actions = {
249 | 0: 'hit with knees',
250 | 1: 'hit with head',
251 | 2: 'punch to face',
252 | 3: 'punch to body',
253 | 4: 'cover mouth',
254 | 5: 'strangling neck',
255 | 6: 'slap',
256 | 7: 'kicking',
257 | 8: 'pushing',
258 | 9: 'pierce arms',
259 | 10: 'pull hairs',
260 | 11: 'drag other person (other resist)',
261 | 12: 'pull collar',
262 | 13: 'shake others',
263 | 14: 'beat with elbow',
264 | 15: 'knock over',
265 | 16: 'hit with object',
266 | 17: 'point to person',
267 | 18: 'lift ear',
268 | 19: 'pinch arms',
269 | 20: 'use cigarette to burn',
270 | 21: 'sidekick person',
271 | 22: 'pick and throw an object to person',
272 | 23: 'shoot person',
273 | 24: 'stab person',
274 | 25: 'wave knife to others',
275 | 26: 'splash liquid on person',
276 | 27: 'stumble person',
277 | 28: 'step on foot',
278 | 29: 'pickpocketing',
279 | 30: 'bite person',
280 | 31: 'take picture for others (sneakily)',
281 | 32: 'spiting to person',
282 | 33: 'chop (cut) person',
283 | 34: 'take chair while other sitting',
284 | 35: 'hit the head with hand',
285 | 36: 'pinch face with two hands',
286 | 37: 'pinch body (not arm)',
287 | 38: 'follow person',
288 | 39: 'belt person',
289 | 40: 'nod head',
290 | 41: 'bow',
291 | 42: 'shake hands',
292 | 43: 'rock-paper-scissors',
293 | 44: 'touch elbows',
294 | 45: 'wave hand',
295 | 46: 'fist bumping',
296 | 47: 'pat on shoulders',
297 | 48: 'giving object',
298 | 49: 'exchange object',
299 | 50: 'clapping and hushing',
300 | 51: 'drink water',
301 | 52: 'brush teeth',
302 | 53: 'stand up',
303 | 54: 'jump up',
304 | 55: 'take off a hat',
305 | 56: 'play a phone',
306 | 57: 'take a selfie',
307 | 58: 'wipe face',
308 | 59: 'cross hands in front',
309 | 60: 'throat-slitting',
310 | 61: 'crawling',
311 | 62: 'open bottle',
312 | 63: 'sneeze',
313 | 64: 'yawn',
314 | 65: 'self-cutting with knife',
315 | 66: 'take off headphone',
316 | 67: 'stretch oneself',
317 | 68: 'flick hair',
318 | 69: 'thumb up',
319 | 70: 'thumb down',
320 | 71: 'make ok sign',
321 | 72: 'make victory sign',
322 | 73: 'cutting nails',
323 | 74: 'cutting paper',
324 | 75: 'squat down',
325 | 76: 'toss a coin',
326 | 77: 'fold paper',
327 | 78: 'ball up paper',
328 | 79: 'play magic cube',
329 | 80: 'surrender',
330 | 81: 'apply cream on face',
331 | 82: 'apply cream on hand',
332 | 83: 'put on bag',
333 | 84: 'take off bag',
334 | 85: 'put object into bag',
335 | 86: 'take object out of bag',
336 | 87: 'open a box and yelling',
337 | 88: 'arm circles',
338 | 89: 'arm swings',
339 | 90: 'whisper',
340 | 91: 'clapping each other',
341 | 92: 'running',
342 | 93: 'vomiting',
343 | 94: 'walk apart',
344 | 95: 'headache',
345 | 96: 'back pain',
346 | 97: 'walk form apart to together',
347 | 98: 'falling down',
348 | 99: 'chest pain',
349 | 100: 'support with arms for old people walking',
350 | 101: 'cheers and drink'}
351 |
--------------------------------------------------------------------------------
/config/test.yaml:
--------------------------------------------------------------------------------
1 |
2 | work_dir: ./work_dir/ntu120_xsub_2021_angular_test
3 |
4 | # feeder
5 | feeder: feeders.feeder.Feeder
6 |
7 | test_feeder_args:
8 | data_path: ./data/ntu120/xsub/val_data_joint.npy
9 | label_path: ./data/ntu120/xsub/val_label.pkl
10 |
11 | # model
12 | model: model.network.Model
13 | model_args:
14 | in_channels: 15
15 | num_class: 120
16 | num_point: 25
17 | num_person: 2
18 | num_gcn_scales: 13
19 | num_g3d_scales: 6
20 | graph: graph.ntu_rgb_d.AdjMatrixGraph
21 |
22 | # ablation
23 | ablation: sgcn_only
24 |
25 |
26 | # optim
27 | weight_decay: 0.0005
28 | base_lr: 0.05
29 | step: [30,40,50]
30 |
31 | num_epoch: 60
32 | device: [0,1]
33 | batch_size: 40
34 | forward_batch_size: 40
35 | test_batch_size: 40
36 | nesterov: True
37 |
38 | optimizer: SGD
39 |
40 | eval_start: 5
41 | eval_interval: 5
42 |
43 | phase: test
44 | weights: "PATH TO THE PRETRAINED MODEL"
45 | save_score: True
46 |
--------------------------------------------------------------------------------
/config/train.yaml:
--------------------------------------------------------------------------------
1 |
2 | work_dir: ./work_dir/ntu120_xsub_2021_angular_train
3 |
4 | # feeder
5 | feeder: feeders.feeder.Feeder
6 | train_feeder_args:
7 | data_path: ./data/ntu120/xsub/train_data_joint.npy
8 | label_path: ./data/ntu120/xsub/train_label.pkl
9 | debug: False
10 | random_choose: False
11 | random_shift: False
12 | random_move: False
13 | window_size: -1
14 | normalization: False
15 |
16 | test_feeder_args:
17 | data_path: ./data/ntu120/xsub/val_data_joint.npy
18 | label_path: ./data/ntu120/xsub/val_label.pkl
19 |
20 | # model
21 | model: model.network.Model
22 | model_args:
23 | in_channels: 15
24 | num_class: 120
25 | num_point: 25
26 | num_person: 2
27 | num_gcn_scales: 13
28 | num_g3d_scales: 6
29 | graph: graph.ntu_rgb_d.AdjMatrixGraph
30 |
31 | # ablation
32 | ablation: sgcn_only
33 |
34 |
35 | # optim
36 | weight_decay: 0.0005
37 | base_lr: 0.05
38 | step: [30,40,50]
39 |
40 | # training
41 | num_epoch: 60
42 | device: [0,1]
43 | batch_size: 40
44 | forward_batch_size: 40
45 | test_batch_size: 80
46 | nesterov: True
47 |
48 | # Extra options
49 | to_add_onehot: False
50 | optimizer: SGD
51 |
52 | # Pretrained models
53 | # # This is to load the pretrained model, uncomment to use.
54 | # weights: ""
55 | # checkpoint: ""
56 | # resume: True
57 |
58 | eval_start: 1
59 | eval_interval: 5
60 |
--------------------------------------------------------------------------------
/encoding/__pycache__/data_encoder.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/encoding/__pycache__/data_encoder.cpython-36.pyc
--------------------------------------------------------------------------------
/encoding/data_encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | import torch.nn.functional as functional
4 |
5 |
6 | class DataRepeatEncoder:
7 | def __init__(self, rep_num):
8 | self.rep_num = rep_num
9 |
10 | def encode_data(self, a_bch, **kwargs):
11 | rtn = a_bch.repeat(1, self.rep_num, 1, 1, 1)
12 | return rtn
13 |
14 |
15 | class DataInterpolatingEncoder:
16 | def __init__(self, new_length=None):
17 | self.new_length = new_length
18 |
19 | def encode_data(self, a_bch, **kwargs):
20 | N, C, T, V, M = a_bch.size()
21 | a_bch = a_bch.permute(0, 4, 1, 3, 2).contiguous().view(N*M*C, V, T)
22 | a_bch = functional.interpolate(a_bch, size=self.new_length, mode='linear')
23 | a_bch = a_bch.view(N, M, C, V, self.new_length).permute(0, 2, 4, 3, 1)
24 | return a_bch
25 |
26 | class TrigonometricTemporalEncoder:
27 | def __init__(self, inc_type, freq_num, seq_len, is_with_orig=True):
28 | self.inc_func = inc_type
29 | self.periodic_fns = [torch.cos]
30 | self.is_with_orig = is_with_orig
31 | self.K = freq_num
32 | self.T = seq_len
33 | self.prepare_period_fns()
34 |
35 | def prepare_period_fns(self):
36 | assert self.inc_func is not None
37 | assert self.periodic_fns is not None
38 |
39 | # Get frequency values
40 | self.temp_freq_bands = []
41 | for k in range(1, self.K + 1):
42 | if self.inc_func == 'linear':
43 | a_freq = k
44 | elif self.inc_func == 'exp':
45 | a_freq = 2 ** (k - 1)
46 | elif self.inc_func == 'pow':
47 | a_freq = k ** 2
48 | else:
49 | raise NotImplementedError('Unsupported inc_func.')
50 | self.temp_freq_bands.append(math.pi / self.T * a_freq)
51 |
52 | self.temp_freq_bands = torch.tensor(self.temp_freq_bands)
53 | print('Temporal frequency components: ', self.temp_freq_bands)
54 |
55 | # Get embed functions
56 | self.temp_embed_fns = []
57 | if self.is_with_orig:
58 | self.temp_embed_fns.append(lambda x, frm_idx: x)
59 |
60 | for freq_t in self.temp_freq_bands:
61 | for p_fn in self.periodic_fns:
62 | self.temp_embed_fns.append(
63 | lambda x, frm_idx, p_fn=p_fn, freq=freq_t: (x * p_fn(freq * (frm_idx + 1 / 2)))
64 | ) # TTE
65 |
66 | def encode_data(self, a_bch, dim):
67 | t_len_all = a_bch.shape[2]
68 | time_list = []
69 | for t_idx in range(t_len_all):
70 | a_series = a_bch[:, :, t_idx, :, :].unsqueeze(2)
71 |
72 | new_time_list = []
73 | for fn in self.temp_embed_fns:
74 | a_new_one = fn(a_series, t_idx)
75 | new_time_list.append(a_new_one)
76 | new_time_list = torch.cat(new_time_list, dim)
77 |
78 | time_list.append(new_time_list)
79 | rtn = torch.cat(time_list, 2)
80 | return rtn
81 |
--------------------------------------------------------------------------------
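A minimal usage sketch for `TrigonometricTemporalEncoder` above (illustrative only; shapes follow the N, C, T, V, M convention used by the feeders):

```python
import torch
from encoding.data_encoder import TrigonometricTemporalEncoder

x = torch.randn(2, 3, 64, 25, 2)  # N, C, T, V, M
tte = TrigonometricTemporalEncoder(inc_type='linear', freq_num=2, seq_len=64)

# With is_with_orig=True and a single periodic function (cos), each frame's
# channels are expanded to C * (1 + freq_num) along the chosen dimension.
y = tte.encode_data(x, dim=1)
print(y.shape)  # torch.Size([2, 9, 64, 25, 2])
```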
/feeders/__init__.py:
--------------------------------------------------------------------------------
1 | from . import tools
2 | from . import feeder
3 | from . import feeder_as_gcn
4 | from . import feeder_dgnn
5 |
--------------------------------------------------------------------------------
/feeders/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/feeders/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/feeders/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/feeders/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/feeders/__pycache__/feeder.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/feeders/__pycache__/feeder.cpython-36.pyc
--------------------------------------------------------------------------------
/feeders/__pycache__/feeder.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/feeders/__pycache__/feeder.cpython-37.pyc
--------------------------------------------------------------------------------
/feeders/__pycache__/feeder_as_gcn.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/feeders/__pycache__/feeder_as_gcn.cpython-36.pyc
--------------------------------------------------------------------------------
/feeders/__pycache__/feeder_as_gcn.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/feeders/__pycache__/feeder_as_gcn.cpython-37.pyc
--------------------------------------------------------------------------------
/feeders/__pycache__/feeder_dgnn.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/feeders/__pycache__/feeder_dgnn.cpython-36.pyc
--------------------------------------------------------------------------------
/feeders/__pycache__/feeder_dgnn.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/feeders/__pycache__/feeder_dgnn.cpython-37.pyc
--------------------------------------------------------------------------------
/feeders/__pycache__/tools.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/feeders/__pycache__/tools.cpython-36.pyc
--------------------------------------------------------------------------------
/feeders/__pycache__/tools.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/feeders/__pycache__/tools.cpython-37.pyc
--------------------------------------------------------------------------------
/feeders/feeder.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.extend(['../'])
3 |
4 | import torch
5 | import pickle
6 | import numpy as np
7 | from torch.utils.data import Dataset, TensorDataset
8 | from numpy import inf
9 |
10 | import scipy.fftpack
11 |
12 | from feeders import tools
13 |
14 |
15 | class Feeder(Dataset):
16 | def __init__(self, data_path, label_path,
17 | random_choose=False, random_shift=False, random_move=False,
18 | window_size=-1, normalization=False, debug=False, use_mmap=True,
19 | tgt_labels=None, frame_len=300, **kwargs):
20 | """
21 | :param data_path:
22 | :param label_path:
23 | :param random_choose: If true, randomly choose a portion of the input sequence
24 |         :param random_shift: If true, randomly pad zeros at the beginning or end of the sequence
25 | :param random_move:
26 | :param window_size: The length of the output sequence
27 | :param normalization: If true, normalize input sequence
28 |         :param debug: If true, only use the first 1000 samples (matching the slicing in load_data)
29 | :param use_mmap: If true, use mmap mode to load data, which can save the running memory
30 | """
31 |
32 | self.debug = debug
33 | self.data_path = data_path
34 | self.label_path = label_path
35 | self.random_choose = random_choose
36 | self.random_shift = random_shift
37 | self.random_move = random_move
38 | self.window_size = window_size
39 | self.normalization = normalization
40 | self.use_mmap = use_mmap
41 | self.tgt_labels = tgt_labels
42 | self.kwargs = kwargs
43 |
44 | # other parameters
45 | self.frame_len = frame_len
46 |
47 | self.load_data()
48 |
49 | # Internal dataloader parameters
50 | # self.load_bch_sz = 2000
51 | # self.internal_dataloader = self.get_a_dataloader()
52 |
53 | if normalization:
54 | self.get_mean_map()
55 |
56 | def get_a_dataloader(self):
57 | a_dataset = TensorDataset(torch.tensor(self.data))
58 | a_dataloader = torch.utils.data.DataLoader(
59 | dataset=a_dataset,
60 | batch_size=self.load_bch_sz,
61 | shuffle=False,
62 | num_workers=4,
63 | drop_last=False
64 | )
65 | return a_dataloader
66 |
67 | def load_data(self):
68 | # data: N C T V M
69 | try:
70 | with open(self.label_path) as f:
71 | self.sample_name, self.label = pickle.load(f)
72 | except:
73 | # for pickle file from python2
74 | with open(self.label_path, 'rb') as f:
75 | self.sample_name, self.label = pickle.load(f, encoding='latin1')
76 | # self.label = np.array(self.label)
77 |
78 | # load data
79 | if self.use_mmap:
80 | self.data = np.load(self.data_path, mmap_mode='r')
81 | else:
82 | self.data = np.load(self.data_path)
83 |
84 | # Use tgt labels
85 | if self.tgt_labels is not None:
86 | self.label = np.array(self.label)
87 | tmp_data = None
88 | tmp_label = None
89 | for a_tgt_label in self.tgt_labels:
90 | selected_idxes = np.array(self.label) == a_tgt_label
91 | if tmp_data is None:
92 | tmp_data = self.data[selected_idxes]
93 | tmp_label = self.label[selected_idxes]
94 | else:
95 | tmp_data = np.concatenate((tmp_data, self.data[selected_idxes]), axis=0)
96 | tmp_label = np.concatenate((tmp_label, self.label[selected_idxes]), axis=0)
97 | self.data = tmp_data
98 | self.label = tmp_label
99 |
100 | if 'process_type' in self.kwargs:
101 | self.process_data(process_type=self.kwargs['process_type'])
102 |
103 | if self.debug:
104 | self.label = self.label[0:1000]
105 | self.data = self.data[0:1000]
106 | self.sample_name = self.sample_name[0:1000]
107 | # debug_tgt = 4206 # 4206 #
108 | # hard_samples = [9566, 13297, 15239, 13351, 11670, 8935, 2815, 9329, 15238, 8896]
109 | # self.label = list(self.label[i] for i in hard_samples)
110 | # self.data = list(self.data[i] for i in hard_samples)
111 | # self.sample_name = list(self.sample_name[i] for i in hard_samples)
112 |
113 | # Discrete cosine transform
114 | if 'dct' in self.kwargs:
115 | self.dct_data(self.kwargs['dct'])
116 | print('Discrete cosine transform completed. DCT type: ', self.kwargs['dct'])
117 |
118 | def dct_data(self, dct_op):
119 | dct_out = scipy.fftpack.dct(self.data, axis=2)
120 | if dct_op == 'overwrite':
121 | self.data = dct_out
122 | elif dct_op == 'concat':
123 | self.data = np.concatenate((self.data, dct_out), axis=1)
124 | elif dct_op == 'lengthen':
125 | dct_out = dct_out[:, :, :(self.frame_len // 2), :, :]
126 | self.data = np.concatenate((self.data, dct_out), axis=2)
127 |
128 | def process_data(self, process_type):
129 | rtn_data = []
130 | rtn_label = []
131 | if process_type == 'single_person':
132 | data_idx = 0
133 | for a_data in self.data:
134 | for ppl_id in range(self.data.shape[-1]):
135 | a_ppl = a_data[:, :, :, ppl_id]
136 | # comment the below one
137 | # if np.max(a_ppl) > 0.01:
138 |
139 | # Keep all data (some mutual data also contain zero)
140 | if np.max(a_ppl) > -1:
141 | rtn_data.append(np.expand_dims(a_ppl, axis=-1))
142 | rtn_label.append(self.label[data_idx])
143 | data_idx += 1
144 | else:
145 | raise NotImplementedError
146 | rtn_data = np.stack(rtn_data, axis=0)
147 | rtn_label = np.stack(rtn_label, axis=0)
148 |
149 | # relabel data to consider actor and receiver
150 | rtn_label = self.relabel_by_energy()
151 |
152 | self.data = rtn_data
153 | self.label = rtn_label
154 |
155 | def get_energy(self, s): # ctv
156 | index = s.sum(-1).sum(0) != 0 # select valid frames
157 | s = s[:, index, :]
158 | if len(s) != 0:
159 | s = s[0, :, :].std() + s[1, :, :].std() + s[2, :, :].std() # three channels
160 | else:
161 | s = 0
162 | return s
163 |
164 | def relabel_by_energy(self):
165 | rtn_label = []
166 | for a_idx, a_data in enumerate(self.data):
167 | person_1 = a_data[:, :, :, 0] # C,T,V
168 | person_2 = a_data[:, :, :, 1] # C,T,V
169 | energy_1 = self.get_energy(person_1)
170 | energy_2 = self.get_energy(person_2)
171 | if energy_1 > energy_2: # first kicks the second
172 | rtn_label.append((1, 0))
173 | else:
174 | rtn_label.append((0, 1))
175 | return np.concatenate(rtn_label, axis=0)
176 |
177 | def get_mean_map(self):
178 | data = self.data
179 | N, C, T, V, M = data.shape
180 | self.mean_map = data.mean(axis=2, keepdims=True).mean(axis=4, keepdims=True).mean(axis=0)
181 | self.std_map = data.transpose((0, 2, 4, 1, 3)).reshape((N * T * M, C * V)).std(axis=0).reshape((C, 1, V, 1))
182 |
183 | def __len__(self):
184 | return len(self.label)
185 |
186 | def __iter__(self):
187 | return self
188 |
189 | def __getitem__(self, index):
190 | data_numpy = self.data[index]
191 | label = self.label[index]
192 | data_numpy = np.array(data_numpy)
193 |
194 | if self.normalization:
195 | data_numpy = (data_numpy - self.mean_map) / self.std_map
196 | if self.random_shift:
197 | data_numpy = tools.random_shift(data_numpy)
198 | if self.random_choose:
199 | data_numpy = tools.random_choose(data_numpy, self.window_size)
200 | elif self.window_size > 0:
201 | data_numpy = tools.auto_pading(data_numpy, self.window_size)
202 | if self.random_move:
203 | data_numpy = tools.random_move(data_numpy)
204 |
205 | # Remove NAN
206 | data_numpy = np.nan_to_num(data_numpy)
207 | data_numpy[data_numpy == -inf] = 0
208 |
209 | return data_numpy, label, index
210 |
211 | def top_k(self, score, top_k):
212 |         # If each label is a single value
213 | the_label = np.array(self.label)
214 | if len(the_label.shape) == 1 or the_label.shape[1] == 1:
215 | rank = score.argsort()
216 | hit_top_k = [l in rank[i, -top_k:] for i, l in enumerate(self.label)]
217 | return sum(hit_top_k) * 1.0 / len(hit_top_k)
218 |         # The label also contains a fine-grained label
219 | else:
220 | act_label = the_label[:, 0]
221 | fgr_label = the_label[:, 1]
222 | act_score, fgr_score = score
223 | rank_act = act_score.argsort()
224 | rank_fgr = fgr_score.argsort()
225 | hit_top_k_acc = [l in rank_act[i, -top_k:] for i, l in enumerate(act_label)]
226 | hit_top_k_fgr = [l in rank_fgr[i, -top_k:] for i, l in enumerate(fgr_label)]
227 | return sum(hit_top_k_acc) * 1.0 / len(hit_top_k_acc), \
228 | sum(hit_top_k_fgr) * 1.0 / len(hit_top_k_fgr)
229 |
230 |
231 | def import_class(name):
232 | components = name.split('.')
233 | mod = __import__(components[0])
234 | for comp in components[1:]:
235 | mod = getattr(mod, comp)
236 | return mod
237 |
238 |
239 | def test(data_path, label_path, vid=None, graph=None, is_3d=False):
240 | '''
241 | vis the samples using matplotlib
242 | :param data_path:
243 | :param label_path:
244 | :param vid: the id of sample
245 | :param graph:
246 | :param is_3d: when vis NTU, set it True
247 | :return:
248 | '''
249 | import matplotlib.pyplot as plt
250 | loader = torch.utils.data.DataLoader(
251 | dataset=Feeder(data_path, label_path),
252 | batch_size=64,
253 | shuffle=False,
254 | num_workers=2)
255 |
256 | if vid is not None:
257 | sample_name = loader.dataset.sample_name
258 | sample_id = [name.split('.')[0] for name in sample_name]
259 | index = sample_id.index(vid)
260 | data, label, index = loader.dataset[index]
261 | data = data.reshape((1,) + data.shape)
262 |
263 | # for batch_idx, (data, label) in enumerate(loader):
264 | N, C, T, V, M = data.shape
265 |
266 | plt.ion()
267 | fig = plt.figure()
268 | if is_3d:
269 | from mpl_toolkits.mplot3d import Axes3D
270 | ax = fig.add_subplot(111, projection='3d')
271 | else:
272 | ax = fig.add_subplot(111)
273 |
274 | if graph is None:
275 | p_type = ['b.', 'g.', 'r.', 'c.', 'm.', 'y.', 'k.', 'k.', 'k.', 'k.']
276 | pose = [
277 | ax.plot(np.zeros(V), np.zeros(V), p_type[m])[0] for m in range(M)
278 | ]
279 | ax.axis([-1, 1, -1, 1])
280 | for t in range(T):
281 | for m in range(M):
282 | pose[m].set_xdata(data[0, 0, t, :, m])
283 | pose[m].set_ydata(data[0, 1, t, :, m])
284 | fig.canvas.draw()
285 | plt.pause(0.001)
286 | else:
287 | p_type = ['b-', 'g-', 'r-', 'c-', 'm-', 'y-', 'k-', 'k-', 'k-', 'k-']
288 | import sys
289 | from os import path
290 | sys.path.append(
291 | path.dirname(path.dirname(path.dirname(path.abspath(__file__)))))
292 | G = import_class(graph)()
293 | edge = G.inward
294 | pose = []
295 | for m in range(M):
296 | a = []
297 | for i in range(len(edge)):
298 | if is_3d:
299 | a.append(ax.plot(np.zeros(3), np.zeros(3), p_type[m])[0])
300 | else:
301 | a.append(ax.plot(np.zeros(2), np.zeros(2), p_type[m])[0])
302 | pose.append(a)
303 | ax.axis([-1, 1, -1, 1])
304 | if is_3d:
305 | ax.set_zlim3d(-1, 1)
306 | for t in range(T):
307 | for m in range(M):
308 | for i, (v1, v2) in enumerate(edge):
309 | x1 = data[0, :2, t, v1, m]
310 | x2 = data[0, :2, t, v2, m]
311 | if (x1.sum() != 0 and x2.sum() != 0) or v1 == 1 or v2 == 1:
312 | pose[m][i].set_xdata(data[0, 0, t, [v1, v2], m])
313 | pose[m][i].set_ydata(data[0, 1, t, [v1, v2], m])
314 | if is_3d:
315 | pose[m][i].set_3d_properties(data[0, 2, t, [v1, v2], m])
316 | fig.canvas.draw()
317 | # plt.savefig('/home/lshi/Desktop/skeleton_sequence/' + str(t) + '.jpg')
318 | plt.pause(0.01)
319 |
320 |
321 | if __name__ == '__main__':
322 | import os
323 | os.environ['DISPLAY'] = 'localhost:10.0'
324 | data_path = "../data/ntu/xview/val_data_joint.npy"
325 | label_path = "../data/ntu/xview/val_label.pkl"
326 | graph = 'graph.ntu_rgb_d.Graph'
327 | test(data_path, label_path, vid='S004C001P003R001A032', graph=graph, is_3d=True)
328 | # data_path = "../data/kinetics/val_data.npy"
329 | # label_path = "../data/kinetics/val_label.pkl"
330 | # graph = 'graph.Kinetics'
331 | # test(data_path, label_path, vid='UOD7oll3Kqo', graph=graph)
332 |
--------------------------------------------------------------------------------
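A minimal sketch of wrapping the `Feeder` above in a PyTorch `DataLoader` (illustrative only; the paths follow `config/train.yaml` and assume the processed arrays already exist):

```python
from torch.utils.data import DataLoader
from feeders.feeder import Feeder

dataset = Feeder(
    data_path='./data/ntu120/xsub/train_data_joint.npy',
    label_path='./data/ntu120/xsub/train_label.pkl',
    window_size=-1,
    normalization=False,
)
loader = DataLoader(dataset, batch_size=40, shuffle=True, num_workers=4)
data, label, index = next(iter(loader))  # data: tensor of shape (N, C, T, V, M)
```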
/feeders/feeder_as_gcn.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.extend(['../'])
3 |
4 | import numpy as np
5 | import random
6 | import pickle
7 | import time
8 | import copy
9 |
10 | import torch
11 | import torch.nn as nn
12 | import torch.optim as optim
13 | import torch.nn.functional as F
14 | from torchvision import datasets, transforms
15 |
16 | from . import tools
17 |
18 |
19 | class Feeder(torch.utils.data.Dataset):
20 | def __init__(self,
21 | data_path, label_path,
22 | repeat_pad=False,
23 | random_choose=False,
24 | random_move=False,
25 | window_size=-1,
26 | debug=False,
27 | down_sample = False,
28 | mmap=True):
29 | self.debug = debug
30 | self.data_path = data_path
31 | self.label_path = label_path
32 | self.repeat_pad = repeat_pad
33 | self.random_choose = random_choose
34 | self.random_move = random_move
35 | self.window_size = window_size
36 | self.down_sample = down_sample
37 |
38 | self.load_data(mmap)
39 |
40 | def load_data(self, mmap):
41 |
42 | with open(self.label_path, 'rb') as f:
43 | self.sample_name, self.label = pickle.load(f)
44 |
45 | if mmap:
46 | self.data = np.load(self.data_path, mmap_mode='r')
47 | else:
48 | self.data = np.load(self.data_path)
49 |
50 | if self.debug:
51 | self.label = self.label[0:100]
52 | self.data = self.data[0:100]
53 | self.sample_name = self.sample_name[0:100]
54 |
55 | self.N, self.C, self.T, self.V, self.M = self.data.shape
56 |
57 | def __len__(self):
58 | return len(self.label)
59 |
60 | def __getitem__(self, index):
61 | data_numpy = np.array(self.data[index]).astype(np.float32)
62 | label = self.label[index]
63 |
64 | valid_frame = (data_numpy!=0).sum(axis=3).sum(axis=2).sum(axis=0)>0
65 | begin, end = valid_frame.argmax(), len(valid_frame)-valid_frame[::-1].argmax()
66 | length = end-begin
67 |
68 | if self.repeat_pad:
69 | data_numpy = tools.repeat_pading(data_numpy)
70 | if self.random_choose:
71 | data_numpy = tools.random_choose(data_numpy, self.window_size)
72 | elif self.window_size > 0:
73 | data_numpy = tools.auto_pading(data_numpy, self.window_size)
74 | if self.random_move:
75 | data_numpy = tools.random_move(data_numpy)
76 |
77 | data_last = copy.copy(data_numpy[:,-11:-10,:,:])
78 | target_data = copy.copy(data_numpy[:,-10:,:,:])
79 | input_data = copy.copy(data_numpy[:,:-10,:,:])
80 |         input_data_dnsp = input_data  # fallback so the return below is always defined when down_sample is False
81 |         if self.down_sample:
82 | if length<=60:
83 | input_data_dnsp = input_data[:,:50,:,:]
84 | else:
85 | rs = int(np.random.uniform(low=0, high=np.ceil((length-10)/50)))
86 | input_data_dnsp = [input_data[:,int(i)+rs,:,:] for i in [np.floor(j*((length-10)/50)) for j in range(50)]]
87 | input_data_dnsp = np.array(input_data_dnsp).astype(np.float32)
88 | input_data_dnsp = np.transpose(input_data_dnsp, axes=(1,0,2,3))
89 |
90 | return input_data, input_data_dnsp, target_data, data_last, label
--------------------------------------------------------------------------------
/feeders/feeder_dgnn.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pickle
3 | import torch
4 | from torch.utils.data import Dataset
5 | import sys
6 |
7 | sys.path.extend(['../'])
8 | from feeders import tools
9 |
10 |
11 | class BaseDataset(Dataset):
12 | def __init__(self, data_path, label_path,
13 | random_choose=False, random_shift=False, random_move=False,
14 | window_size=-1, normalization=False, debug=False, use_mmap=True):
15 | """
16 | :param data_path:
17 | :param label_path:
18 | :param random_choose: If true, randomly choose a portion of the input sequence
19 |         :param random_shift: If true, randomly pad zeros at the beginning or end of the sequence
20 | :param random_move:
21 | :param window_size: The length of the output sequence
22 | :param normalization: If true, normalize input sequence
23 | :param debug: If true, only use the first 100 samples
24 | :param use_mmap: If true, use mmap mode to load data, which can save the running memory
25 | """
26 | self.debug = debug
27 | self.data_path = data_path
28 | self.label_path = label_path
29 | self.random_choose = random_choose
30 | self.random_shift = random_shift
31 | self.random_move = random_move
32 | self.window_size = window_size
33 | self.normalization = normalization
34 | self.use_mmap = use_mmap
35 | self.load_data()
36 | if normalization:
37 | self.get_mean_map()
38 |
39 | def load_data(self):
40 | # data: (N,C,V,T,M)
41 | try:
42 | with open(self.label_path) as f:
43 | self.sample_name, self.label = pickle.load(f)
44 | except:
45 | # for pickle file from python2
46 | with open(self.label_path, 'rb') as f:
47 | self.sample_name, self.label = pickle.load(f, encoding='latin1')
48 |
49 | # load data
50 | if self.use_mmap:
51 | self.data = np.load(self.data_path, mmap_mode='r')
52 | else:
53 | self.data = np.load(self.data_path)
54 | if self.debug:
55 | self.label = self.label[0:100]
56 | self.data = self.data[0:100]
57 | self.sample_name = self.sample_name[0:100]
58 |
59 | def get_mean_map(self):
60 | """Computes the mean and standard deviation of the dataset"""
61 | data = self.data
62 | N, C, T, V, M = data.shape
63 | self.mean_map = data.mean(axis=2, keepdims=True).mean(axis=4, keepdims=True).mean(axis=0)
64 | self.std_map = data.transpose((0, 2, 4, 1, 3)).reshape((N * T * M, C * V)).std(axis=0).reshape((C, 1, V, 1))
65 |
66 | def __len__(self):
67 | return len(self.label)
68 |
69 | def __iter__(self):
70 | return self
71 |
72 | def __getitem__(self, index):
73 | data_numpy = self.data[index]
74 | label = self.label[index]
75 | data_numpy = np.array(data_numpy)
76 |
77 | if self.normalization:
78 | data_numpy = (data_numpy - self.mean_map) / self.std_map
79 | if self.random_shift:
80 | data_numpy = tools.random_shift(data_numpy)
81 | if self.random_choose:
82 | data_numpy = tools.random_choose(data_numpy, self.window_size)
83 | elif self.window_size > 0:
84 | data_numpy = tools.auto_pading(data_numpy, self.window_size)
85 | if self.random_move:
86 | data_numpy = tools.random_move(data_numpy)
87 |
88 | return data_numpy, label, index
89 |
90 | def top_k(self, score, top_k):
91 | rank = score.argsort()
92 | hit_top_k = [l in rank[i, -top_k:] for i, l in enumerate(self.label)]
93 | return sum(hit_top_k) * 1.0 / len(hit_top_k)
94 |
95 |
96 | class Feeder(Dataset):
97 | def __init__(self, joint_data_path, bone_data_path, label_path,
98 | random_choose=False, random_shift=False, random_move=False,
99 | window_size=-1, normalization=False, debug=False, use_mmap=True):
100 |
101 | self.joint_dataset = BaseDataset(joint_data_path, label_path, random_choose, random_shift, random_move, window_size, normalization, debug, use_mmap)
102 | self.bone_dataset = BaseDataset(bone_data_path, label_path, random_choose, random_shift, random_move, window_size, normalization, debug, use_mmap)
103 | self.sample_name = self.joint_dataset.sample_name
104 |
105 | def __len__(self):
106 | return min(len(self.joint_dataset), len(self.bone_dataset))
107 |
108 | def __iter__(self):
109 | return self
110 |
111 | def __getitem__(self, index):
112 | joint_data, label, index = self.joint_dataset[index]
113 | bone_data, label, index = self.bone_dataset[index]
114 | # Either label is fine
115 | return joint_data, bone_data, label, index
116 |
117 | def top_k(self, score, top_k):
118 |         # Either dataset can serve as the delegate
119 | return self.joint_dataset.top_k(score, top_k)
120 |
121 |
122 | # def import_class(name):
123 | # components = name.split('.')
124 | # mod = __import__(components[0])
125 | # for comp in components[1:]:
126 | # mod = getattr(mod, comp)
127 | # return mod
128 |
129 |
130 | # def test(data_path, label_path, vid=None, graph=None, is_3d=False):
131 | # '''
132 | # vis the samples using matplotlib
133 | # :param data_path:
134 | # :param label_path:
135 | # :param vid: the id of sample
136 | # :param graph:
137 | # :param is_3d: when vis NTU, set it True
138 | # :return:
139 | # '''
140 | # import matplotlib.pyplot as plt
141 | # loader = torch.utils.data.DataLoader(
142 | # dataset=Feeder(data_path, label_path),
143 | # batch_size=64,
144 | # shuffle=False,
145 | # num_workers=2)
146 |
147 | # if vid is not None:
148 | # sample_name = loader.dataset.sample_name
149 | # sample_id = [name.split('.')[0] for name in sample_name]
150 | # index = sample_id.index(vid)
151 | # data, label, index = loader.dataset[index]
152 | # data = data.reshape((1,) + data.shape)
153 |
154 | # # for batch_idx, (data, label) in enumerate(loader):
155 | # N, C, T, V, M = data.shape
156 |
157 | # plt.ion()
158 | # fig = plt.figure()
159 | # if is_3d:
160 | # from mpl_toolkits.mplot3d import Axes3D
161 | # ax = fig.add_subplot(111, projection='3d')
162 | # else:
163 | # ax = fig.add_subplot(111)
164 |
165 | # if graph is None:
166 | # p_type = ['b.', 'g.', 'r.', 'c.', 'm.', 'y.', 'k.', 'k.', 'k.', 'k.']
167 | # pose = [
168 | # ax.plot(np.zeros(V), np.zeros(V), p_type[m])[0] for m in range(M)
169 | # ]
170 | # ax.axis([-1, 1, -1, 1])
171 | # for t in range(T):
172 | # for m in range(M):
173 | # pose[m].set_xdata(data[0, 0, t, :, m])
174 | # pose[m].set_ydata(data[0, 1, t, :, m])
175 | # fig.canvas.draw()
176 | # plt.pause(0.001)
177 | # else:
178 | # p_type = ['b-', 'g-', 'r-', 'c-', 'm-', 'y-', 'k-', 'k-', 'k-', 'k-']
179 | # import sys
180 | # from os import path
181 | # sys.path.append(
182 | # path.dirname(path.dirname(path.dirname(path.abspath(__file__)))))
183 | # G = import_class(graph)()
184 | # edge = G.inward
185 | # pose = []
186 | # for m in range(M):
187 | # a = []
188 | # for i in range(len(edge)):
189 | # if is_3d:
190 | # a.append(ax.plot(np.zeros(3), np.zeros(3), p_type[m])[0])
191 | # else:
192 | # a.append(ax.plot(np.zeros(2), np.zeros(2), p_type[m])[0])
193 | # pose.append(a)
194 | # ax.axis([-1, 1, -1, 1])
195 | # if is_3d:
196 | # ax.set_zlim3d(-1, 1)
197 | # for t in range(T):
198 | # for m in range(M):
199 | # for i, (v1, v2) in enumerate(edge):
200 | # x1 = data[0, :2, t, v1, m]
201 | # x2 = data[0, :2, t, v2, m]
202 | # if (x1.sum() != 0 and x2.sum() != 0) or v1 == 1 or v2 == 1:
203 | # pose[m][i].set_xdata(data[0, 0, t, [v1, v2], m])
204 | # pose[m][i].set_ydata(data[0, 1, t, [v1, v2], m])
205 | # if is_3d:
206 | # pose[m][i].set_3d_properties(data[0, 2, t, [v1, v2], m])
207 | # fig.canvas.draw()
208 | # # plt.savefig('/home/lshi/Desktop/skeleton_sequence/' + str(t) + '.jpg')
209 | # plt.pause(0.01)
210 |
211 |
212 | if __name__ == '__main__':
213 | pass
214 | # import os
215 | # os.environ['DISPLAY'] = 'localhost:10.0'
216 | # data_path = "../data/ntu/xview/val_data_joint.npy"
217 | # label_path = "../data/ntu/xview/val_label.pkl"
218 | # graph = 'graph.ntu_rgb_d.Graph'
219 | # test(data_path, label_path, vid='S004C001P003R001A032', graph=graph, is_3d=True)
220 | # data_path = "../data/kinetics/val_data.npy"
221 | # label_path = "../data/kinetics/val_label.pkl"
222 | # graph = 'graph.Kinetics'
223 | # test(data_path, label_path, vid='UOD7oll3Kqo', graph=graph)
224 |
--------------------------------------------------------------------------------
/feeders/tools.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import numpy as np
4 |
5 |
6 | def downsample(data_numpy, step, random_sample=True):
7 | # input: C,T,V,M
8 | begin = np.random.randint(step) if random_sample else 0
9 | return data_numpy[:, begin::step, :, :]
10 |
11 |
12 | def temporal_slice(data_numpy, step):
13 | # input: C,T,V,M
14 | C, T, V, M = data_numpy.shape
15 |     return data_numpy.reshape(C, T // step, step, V, M).transpose(
16 |         (0, 1, 3, 2, 4)).reshape(C, T // step, V, step * M)
17 |
18 |
19 | def mean_subtractor(data_numpy, mean):
20 | # input: C,T,V,M
21 | # naive version
22 | if mean == 0:
23 |         return data_numpy  # nothing to subtract; return the input unchanged
24 | C, T, V, M = data_numpy.shape
25 | valid_frame = (data_numpy != 0).sum(axis=3).sum(axis=2).sum(axis=0) > 0
26 | begin = valid_frame.argmax()
27 | end = len(valid_frame) - valid_frame[::-1].argmax()
28 | data_numpy[:, :end, :, :] = data_numpy[:, :end, :, :] - mean
29 | return data_numpy
30 |
31 |
32 | def auto_pading(data_numpy, size, random_pad=False):
33 | C, T, V, M = data_numpy.shape
34 | if T < size:
35 | begin = random.randint(0, size - T) if random_pad else 0
36 | data_numpy_paded = np.zeros((C, size, V, M))
37 | data_numpy_paded[:, begin:begin + T, :, :] = data_numpy
38 | return data_numpy_paded
39 | else:
40 | return data_numpy
41 |
42 |
43 | def random_choose(data_numpy, size, auto_pad=True):
44 | C, T, V, M = data_numpy.shape
45 | if T == size:
46 | return data_numpy
47 | elif T < size:
48 | if auto_pad:
49 | return auto_pading(data_numpy, size, random_pad=True)
50 | else:
51 | return data_numpy
52 | else:
53 | begin = random.randint(0, T - size)
54 | return data_numpy[:, begin:begin + size, :, :]
55 |
56 |
57 | def random_move(data_numpy,
58 | angle_candidate=[-10., -5., 0., 5., 10.],
59 | scale_candidate=[0.9, 1.0, 1.1],
60 | transform_candidate=[-0.2, -0.1, 0.0, 0.1, 0.2],
61 | move_time_candidate=[1]):
62 | # input: C,T,V,M
63 | C, T, V, M = data_numpy.shape
64 | move_time = random.choice(move_time_candidate)
65 | node = np.arange(0, T, T * 1.0 / move_time).round().astype(int)
66 | node = np.append(node, T)
67 | num_node = len(node)
68 |
69 | A = np.random.choice(angle_candidate, num_node)
70 | S = np.random.choice(scale_candidate, num_node)
71 | T_x = np.random.choice(transform_candidate, num_node)
72 | T_y = np.random.choice(transform_candidate, num_node)
73 |
74 | a = np.zeros(T)
75 | s = np.zeros(T)
76 | t_x = np.zeros(T)
77 | t_y = np.zeros(T)
78 |
79 | # linspace
80 | for i in range(num_node - 1):
81 | a[node[i]:node[i + 1]] = np.linspace(
82 | A[i], A[i + 1], node[i + 1] - node[i]) * np.pi / 180
83 | s[node[i]:node[i + 1]] = np.linspace(S[i], S[i + 1],
84 | node[i + 1] - node[i])
85 | t_x[node[i]:node[i + 1]] = np.linspace(T_x[i], T_x[i + 1],
86 | node[i + 1] - node[i])
87 | t_y[node[i]:node[i + 1]] = np.linspace(T_y[i], T_y[i + 1],
88 | node[i + 1] - node[i])
89 |
90 | theta = np.array([[np.cos(a) * s, -np.sin(a) * s],
91 | [np.sin(a) * s, np.cos(a) * s]])
92 |
93 | # perform transformation
94 | for i_frame in range(T):
95 | xy = data_numpy[0:2, i_frame, :, :]
96 | new_xy = np.dot(theta[:, :, i_frame], xy.reshape(2, -1))
97 | new_xy[0] += t_x[i_frame]
98 | new_xy[1] += t_y[i_frame]
99 | data_numpy[0:2, i_frame, :, :] = new_xy.reshape(2, V, M)
100 |
101 | return data_numpy
102 |
103 |
104 | def random_shift(data_numpy):
105 | C, T, V, M = data_numpy.shape
106 | data_shift = np.zeros(data_numpy.shape)
107 | valid_frame = (data_numpy != 0).sum(axis=3).sum(axis=2).sum(axis=0) > 0
108 | begin = valid_frame.argmax()
109 | end = len(valid_frame) - valid_frame[::-1].argmax()
110 |
111 | size = end - begin
112 | bias = random.randint(0, T - size)
113 | data_shift[:, bias:bias + size, :, :] = data_numpy[:, begin:end, :, :]
114 |
115 | return data_shift
116 |
117 |
118 | def openpose_match(data_numpy):
119 | C, T, V, M = data_numpy.shape
120 | assert (C == 3)
121 | score = data_numpy[2, :, :, :].sum(axis=1)
122 | # the rank of body confidence in each frame (shape: T-1, M)
123 | rank = (-score[0:T - 1]).argsort(axis=1).reshape(T - 1, M)
124 |
125 | # data of frame 1
126 | xy1 = data_numpy[0:2, 0:T - 1, :, :].reshape(2, T - 1, V, M, 1)
127 | # data of frame 2
128 | xy2 = data_numpy[0:2, 1:T, :, :].reshape(2, T - 1, V, 1, M)
129 | # square of distance between frame 1&2 (shape: T-1, M, M)
130 | distance = ((xy2 - xy1) ** 2).sum(axis=2).sum(axis=0)
131 |
132 | # match pose
133 | forward_map = np.zeros((T, M), dtype=int) - 1
134 | forward_map[0] = range(M)
135 | for m in range(M):
136 | choose = (rank == m)
137 | forward = distance[choose].argmin(axis=1)
138 | for t in range(T - 1):
139 | distance[t, :, forward[t]] = np.inf
140 | forward_map[1:][choose] = forward
141 | assert (np.all(forward_map >= 0))
142 |
143 |     # string the per-frame matches together into consistent tracks
144 | for t in range(T - 1):
145 | forward_map[t + 1] = forward_map[t + 1][forward_map[t]]
146 |
147 | # generate data
148 | new_data_numpy = np.zeros(data_numpy.shape)
149 | for t in range(T):
150 | new_data_numpy[:, t, :, :] = data_numpy[:, t, :, forward_map[
151 | t]].transpose(1, 2, 0)
152 | data_numpy = new_data_numpy
153 |
154 | # score sort
155 | trace_score = data_numpy[2, :, :, :].sum(axis=1).sum(axis=0)
156 | rank = (-trace_score).argsort()
157 | data_numpy = data_numpy[:, :, :, rank]
158 |
159 | return data_numpy
160 |
161 |
162 | def repeat_pading(data_numpy):
163 |     data_tmp = np.transpose(data_numpy, [3,1,2,0])  # (M, T, V, C), e.g. [2, 300, 25, 3]
164 | for i_p, person in enumerate(data_tmp):
165 | if person.sum()==0:
166 | continue
167 | if person[0].sum()==0:
168 | index = (person.sum(-1).sum(-1)!=0)
169 | tmp = person[index].copy()
170 | person*=0
171 | person[:len(tmp)] = tmp
172 | for i_f, frame in enumerate(person):
173 | if frame.sum()==0:
174 | if person[i_f:].sum()==0:
175 | rest = len(person)-i_f
176 | num = int(np.ceil(rest/i_f))
177 | pad = np.concatenate([person[0:i_f] for _ in range(num)], 0)[:rest]
178 | data_tmp[i_p,i_f:] = pad
179 | break
180 | data_numpy = np.transpose(data_tmp, [3,1,2,0])
181 | return data_numpy
--------------------------------------------------------------------------------
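A minimal usage sketch, not part of the repository, for the augmentation helpers in feeders/tools.py above; the sample shape (3 channels, 300 frames, 25 joints, 2 bodies) is assumed purely for illustration:

import numpy as np
from feeders.tools import random_choose, random_move, random_shift

data = np.random.randn(3, 300, 25, 2)   # hypothetical (C, T, V, M) skeleton sample

data = random_choose(data, size=150)    # crop (or pad) to 150 frames -> (3, 150, 25, 2)
data = random_move(data)                # per-frame random rotation, scaling and translation
data = random_shift(data)               # re-place the valid frames at a random temporal offset
--------------------------------------------------------------------------------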
/figures/Architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/figures/Architecture.png
--------------------------------------------------------------------------------
/figures/angle.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/figures/angle.png
--------------------------------------------------------------------------------
/figures/skeletons.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/figures/skeletons.png
--------------------------------------------------------------------------------
/graph/__init__.py:
--------------------------------------------------------------------------------
1 | from . import tools
2 | from . import ntu_rgb_d
3 | from . import kinetics
4 | from . import azure_kinect
5 | from . import directed_ntu_rgb_d
6 |
--------------------------------------------------------------------------------
/graph/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/graph/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/graph/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/graph/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/graph/__pycache__/ang_adjs.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/graph/__pycache__/ang_adjs.cpython-36.pyc
--------------------------------------------------------------------------------
/graph/__pycache__/ang_adjs.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/graph/__pycache__/ang_adjs.cpython-37.pyc
--------------------------------------------------------------------------------
/graph/__pycache__/azure_kinect.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/graph/__pycache__/azure_kinect.cpython-36.pyc
--------------------------------------------------------------------------------
/graph/__pycache__/azure_kinect.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/graph/__pycache__/azure_kinect.cpython-37.pyc
--------------------------------------------------------------------------------
/graph/__pycache__/directed_ntu_rgb_d.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/graph/__pycache__/directed_ntu_rgb_d.cpython-36.pyc
--------------------------------------------------------------------------------
/graph/__pycache__/directed_ntu_rgb_d.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/graph/__pycache__/directed_ntu_rgb_d.cpython-37.pyc
--------------------------------------------------------------------------------
/graph/__pycache__/hyper_graphs.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/graph/__pycache__/hyper_graphs.cpython-36.pyc
--------------------------------------------------------------------------------
/graph/__pycache__/hyper_graphs.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/graph/__pycache__/hyper_graphs.cpython-37.pyc
--------------------------------------------------------------------------------
/graph/__pycache__/kinetics.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/graph/__pycache__/kinetics.cpython-36.pyc
--------------------------------------------------------------------------------
/graph/__pycache__/kinetics.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/graph/__pycache__/kinetics.cpython-37.pyc
--------------------------------------------------------------------------------
/graph/__pycache__/ntu_rgb_d.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/graph/__pycache__/ntu_rgb_d.cpython-36.pyc
--------------------------------------------------------------------------------
/graph/__pycache__/ntu_rgb_d.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/graph/__pycache__/ntu_rgb_d.cpython-37.pyc
--------------------------------------------------------------------------------
/graph/__pycache__/tools.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/graph/__pycache__/tools.cpython-36.pyc
--------------------------------------------------------------------------------
/graph/__pycache__/tools.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/graph/__pycache__/tools.cpython-37.pyc
--------------------------------------------------------------------------------
/graph/ang_adjs.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 | from graph.tools import normalize_adjacency_matrix
5 |
6 |
7 | def get_ang_adjs(data_type):
8 | rtn_adjs = []
9 |
10 | if data_type == 'ntu':
11 | node_num = 25
12 | sym_pairs = ((21, 2), (11, 7), (18, 14), (20, 16), (24, 25), (22, 23))
13 | for a_sym_pair in sym_pairs:
14 | a_adj = np.eye(node_num)
15 | a_adj[:, a_sym_pair[0]-1] = 1
16 | a_adj[:, a_sym_pair[1]-1] = 1
17 | a_adj = torch.tensor(normalize_adjacency_matrix(a_adj))
18 | rtn_adjs.append(a_adj)
19 |
20 | return torch.cat(rtn_adjs, dim=0)
21 |
--------------------------------------------------------------------------------
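A quick shape check, not part of the repository, for get_ang_adjs above, assuming the 25-joint NTU layout: one normalized 25x25 adjacency is built per symmetric joint pair and the six blocks are stacked along dim 0, which is what MultiScale_GraphConv later concatenates onto its A_powers:

import torch
from graph.ang_adjs import get_ang_adjs

adjs = get_ang_adjs('ntu')
print(adjs.shape)   # torch.Size([150, 25]) -- six 25x25 blocks stacked along dim 0
--------------------------------------------------------------------------------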
/graph/azure_kinect.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.insert(0, '')
3 | sys.path.extend(['../'])
4 |
5 | import numpy as np
6 |
7 | from graph import tools
8 |
9 | num_node = 32
10 | self_link = [(i, i) for i in range(num_node)]
11 | inward_ori_index = [
12 | (1, 0),
13 | (2, 1),
14 | (3, 2),
15 | (4, 2),
16 | (5, 4),
17 | (6, 5),
18 | (7, 6),
19 | (8, 7),
20 | (9, 8),
21 | (10, 7),
22 | (11, 2),
23 | (12, 11),
24 | (13, 12),
25 | (14, 13),
26 | (15, 14),
27 | (16, 15),
28 | (17, 14),
29 | (18, 0),
30 | (19, 18),
31 | (20, 19),
32 | (21, 20),
33 | (22, 0),
34 | (23, 22),
35 | (24, 23),
36 | (25, 24),
37 | (26, 3),
38 | (27, 26),
39 | (28, 27),
40 | (29, 28),
41 | (30, 27),
42 | (31, 30),
43 | ]
44 | inward = [(i, j) for (i, j) in inward_ori_index]
45 | outward = [(j, i) for (i, j) in inward]
46 | neighbor = inward + outward
47 |
48 |
49 | class Graph:
50 | def __init__(self, labeling_mode='spatial'):
51 | self.A = self.get_adjacency_matrix(labeling_mode)
52 | self.num_node = num_node
53 | self.self_link = self_link
54 | self.inward = inward
55 | self.outward = outward
56 | self.neighbor = neighbor
57 |
58 | def get_adjacency_matrix(self, labeling_mode=None):
59 | if labeling_mode is None:
60 | return self.A
61 | if labeling_mode == 'spatial':
62 | A = tools.get_spatial_graph(num_node, self_link, inward, outward)
63 | else:
64 | raise ValueError()
65 | return A
66 |
67 |
68 | class AdjMatrixGraph:
69 | def __init__(self, *args, **kwargs):
70 | self.edges = neighbor
71 | self.num_nodes = num_node
72 | self.self_loops = [(i, i) for i in range(self.num_nodes)]
73 | self.A_binary = tools.get_adjacency_matrix(self.edges, self.num_nodes)
74 | self.A_binary_with_I = tools.get_adjacency_matrix(self.edges + self.self_loops, self.num_nodes)
75 | self.A = tools.normalize_adjacency_matrix(self.A_binary)
76 |
77 |
78 | if __name__ == '__main__':
79 | import matplotlib.pyplot as plt
80 | graph = AdjMatrixGraph()
81 | A, A_binary, A_binary_with_I = graph.A, graph.A_binary, graph.A_binary_with_I
82 | f, ax = plt.subplots(1, 3)
83 | ax[0].imshow(A_binary_with_I, cmap='gray')
84 | ax[1].imshow(A_binary, cmap='gray')
85 | ax[2].imshow(A, cmap='gray')
86 | plt.show()
87 | print(A_binary_with_I.shape, A_binary.shape, A.shape)
88 |
--------------------------------------------------------------------------------
/graph/directed_ntu_rgb_d.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple, List
2 | from collections import defaultdict
3 |
4 | import numpy as np
5 | from scipy import special
6 |
7 | # For NTU RGB+D, assume node 21 (centre of chest)
8 | # is the "centre of gravity" mentioned in the paper
9 |
10 | num_nodes = 25
11 | epsilon = 1e-6
12 |
13 | # Directed edges: (source, target), see
14 | # https://www.cv-foundation.org/openaccess/content_cvpr_2016/papers/Shahroudy_NTU_RGBD_A_CVPR_2016_paper.pdf
15 | # for node IDs, and reduce index to 0-based
16 | directed_edges = [(i-1, j-1) for i, j in [
17 | (1, 13), (1, 17), (2, 1), (3, 4), (5, 6),
18 | (6, 7), (7, 8), (8, 22), (8, 23), (9, 10),
19 | (10, 11), (11, 12), (12, 24), (12, 25), (13, 14),
20 | (14, 15), (15, 16), (17, 18), (18, 19), (19, 20),
21 | (21, 2), (21, 3), (21, 5), (21, 9),
22 | (21, 21) # Add self loop for Node 21 (the centre) to avoid singular matrices
23 | ]]
24 |
25 | # NOTE: for now, let's not add self loops since the paper didn't mention this
26 | # self_loops = [(i, i) for i in range(num_nodes)]
27 |
28 |
29 | def build_digraph_adj_list(edges: List[Tuple]) -> defaultdict:
30 | graph = defaultdict(list)
31 | for source, target in edges:
32 | graph[source].append(target)
33 | return graph
34 |
35 |
36 | def normalize_incidence_matrix(im: np.ndarray, full_im: np.ndarray) -> np.ndarray:
37 | # NOTE:
38 | # 1. The paper assumes that the Incidence matrix is square,
39 | # so that the normalized form A @ (D ** -1) is viable.
40 | # However, if the incidence matrix is non-square, then
41 | # the above normalization won't work.
42 | # For now, move the term (D ** -1) to the front
43 | # 2. It's not too clear whether the degree matrix of the FULL incidence matrix
44 | # should be calculated, or just the target/source IMs.
45 | # However, target/source IMs are SINGULAR matrices since not all nodes
46 | # have incoming/outgoing edges, but the full IM as described by the paper
47 | # is also singular, since ±1 is used for target/source nodes.
48 | # For now, we'll stick with adding target/source IMs.
49 | degree_mat = full_im.sum(-1) * np.eye(len(full_im))
50 | # Since all nodes should have at least some edge, degree matrix is invertible
51 | inv_degree_mat = np.linalg.inv(degree_mat)
52 | return (inv_degree_mat @ im) + epsilon
53 |
54 |
55 | def build_digraph_incidence_matrix(num_nodes: int, edges: List[Tuple]) -> Tuple[np.ndarray, np.ndarray]:
56 | # NOTE: For now, we won't consider all possible edges
57 | # max_edges = int(special.comb(num_nodes, 2))
58 | max_edges = len(edges)
59 | source_graph = np.zeros((num_nodes, max_edges), dtype='float32')
60 | target_graph = np.zeros((num_nodes, max_edges), dtype='float32')
61 | for edge_id, (source_node, target_node) in enumerate(edges):
62 | source_graph[source_node, edge_id] = 1.
63 | target_graph[target_node, edge_id] = 1.
64 | full_graph = source_graph + target_graph
65 | source_graph = normalize_incidence_matrix(source_graph, full_graph)
66 | target_graph = normalize_incidence_matrix(target_graph, full_graph)
67 | return source_graph, target_graph
68 |
69 |
70 | def build_digraph_adj_matrix(edges: List[Tuple]) -> np.ndarray:
71 | graph = np.zeros((num_nodes, num_nodes), dtype='float32')
72 | for edge in edges:
73 | graph[edge] = 1
74 | return graph
75 |
76 |
77 | class Graph:
78 | def __init__(self):
79 | super().__init__()
80 | self.num_nodes = num_nodes
81 | self.edges = directed_edges
82 | # Incidence matrices
83 | self.source_M, self.target_M = \
84 | build_digraph_incidence_matrix(self.num_nodes, self.edges)
85 |
86 |
87 | # TODO:
88 | # Check whether self loop should be added inside the graph
89 | # Check incidence matrix size
90 |
91 |
92 | if __name__ == "__main__":
93 | import matplotlib.pyplot as plt
94 | graph = Graph()
95 | source_M = graph.source_M
96 | target_M = graph.target_M
97 | plt.imshow(source_M, cmap='gray')
98 | plt.show()
99 | plt.imshow(target_M, cmap='gray')
100 | plt.show()
101 | print(source_M)
102 |
--------------------------------------------------------------------------------
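A small sketch, not part of the repository, of the incidence matrices built above: with the 25 directed edges listed (including the self-loop on joint 21), the source and target matrices are both 25x25 and are normalized as (D ** -1) @ IM + epsilon, as discussed in the NOTE inside normalize_incidence_matrix:

from graph.directed_ntu_rgb_d import Graph

g = Graph()
print(g.source_M.shape, g.target_M.shape)   # (25, 25) (25, 25): joints x directed edges
--------------------------------------------------------------------------------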
/graph/hyper_graphs.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import numpy as np
3 | from graph import tools
4 | import torch
5 |
6 | sys.path.insert(0, '')
7 | sys.path.extend(['../'])
8 |
9 |
10 | local_bone_hyper_edge_dict = {
11 | 'ntu': (
12 | (1, 17, 13), (2, 21, 1), (3, 4, 21), (4, 4, 4), (5, 21, 6), (6, 5, 7), (7, 6, 8), (8, 23, 22),
13 | (9, 10, 21), (10, 11, 9), (11, 12, 10), (12, 24, 25), (13, 1, 14), (14, 13, 15), (15, 14, 16),
14 | (16, 16, 16), (17, 18, 1), (18, 19, 17), (19, 20, 18), (20, 20, 20), (21, 9, 5), (22, 8, 23),
15 | (23, 8, 22), (24, 25, 12), (25, 24, 12)
16 | )
17 | }
18 |
19 |
20 | center_hyper_edge_dict = {
21 | 'ntu': (
22 | (1, 2, 21), (2, 2, 21), (3, 2, 21), (4, 2, 21), (5, 2, 21), (6, 2, 21), (7, 2, 21), (8, 2, 21),
23 | (9, 2, 21), (10, 2, 21), (11, 2, 21), (12, 2, 21), (13, 2, 21), (14, 2, 21), (15, 2, 21),
24 | (16, 2, 21), (17, 2, 21), (18, 2, 21), (19, 2, 21), (20, 2, 21), (21, 2, 21), (22, 2, 21),
25 | (23, 2, 21), (24, 2, 21), (25, 2, 21),
26 | )
27 | }
28 |
29 |
30 | figure_l_hyper_edge_dict = {
31 | 'ntu': (
32 | (1, 24, 25), (2, 24, 25), (3, 24, 25), (4, 24, 25), (5, 24, 25), (6, 24, 25), (7, 24, 25), (8, 24, 25),
33 | (9, 24, 25), (10, 24, 25), (11, 24, 25), (12, 24, 25), (13, 24, 25), (14, 24, 25), (15, 24, 25),
34 | (16, 24, 25), (17, 24, 25), (18, 24, 25), (19, 24, 25), (20, 24, 25), (21, 24, 25), (22, 24, 25),
35 | (23, 24, 25), (24, 24, 25), (25, 24, 25),
36 | )
37 | }
38 |
39 |
40 | figure_r_hyper_edge_dict = {
41 | 'ntu': (
42 | (1, 22, 23), (2, 22, 23), (3, 22, 23), (4, 22, 23), (5, 22, 23), (6, 22, 23), (7, 22, 23), (8, 22, 23),
43 | (9, 22, 23), (10, 22, 23), (11, 22, 23), (12, 22, 23), (13, 22, 23), (14, 22, 23), (15, 22, 23),
44 | (16, 22, 23), (17, 22, 23), (18, 22, 23), (19, 22, 23), (20, 22, 23), (21, 22, 23), (22, 22, 23),
45 | (23, 22, 23), (24, 22, 23), (25, 22, 23),
46 | )
47 | }
48 |
49 |
50 | hand_hyper_edge_dict = {
51 | 'ntu': (
52 | (1, 24, 22), (2, 24, 22), (3, 24, 22), (4, 24, 22), (5, 24, 22), (6, 24, 22), (7, 24, 22), (8, 24, 22),
53 | (9, 24, 22), (10, 24, 22), (11, 24, 22), (12, 24, 22), (13, 24, 22), (14, 24, 22), (15, 24, 22),
54 | (16, 24, 22), (17, 24, 22), (18, 24, 22), (19, 24, 22), (20, 24, 22), (21, 24, 22), (22, 24, 22),
55 | (23, 24, 22), (24, 24, 22), (25, 24, 22),
56 | )
57 | }
58 |
59 |
60 | elbow_hyper_edge = {
61 | 'ntu': (
62 | (1, 10, 6), (2, 10, 6), (3, 10, 6), (4, 10, 6), (5, 10, 6), (6, 10, 6), (7, 10, 6), (8, 10, 6),
63 | (9, 10, 6), (10, 10, 6), (11, 10, 6), (12, 10, 6), (13, 10, 6), (14, 10, 6), (15, 10, 6),
64 | (16, 10, 6), (17, 10, 6), (18, 10, 6), (19, 10, 6), (20, 10, 6), (21, 10, 6), (22, 10, 6),
65 | (23, 10, 6), (24, 10, 6), (25, 10, 6),
66 | )
67 | }
68 |
69 |
70 | foot_hyper_edge = {
71 | 'ntu': (
72 | (1, 20, 16), (2, 20, 16), (3, 20, 16), (4, 20, 16), (5, 20, 16), (6, 20, 16), (7, 20, 16), (8, 20, 16),
73 | (9, 20, 16), (10, 20, 16), (11, 20, 16), (12, 20, 16), (13, 20, 16), (14, 20, 16), (15, 20, 16),
74 | (16, 20, 16), (17, 20, 16), (18, 20, 16), (19, 20, 16), (20, 20, 16), (21, 20, 16), (22, 20, 16),
75 | (23, 20, 16), (24, 20, 16), (25, 20, 16),
76 | )
77 | }
78 |
79 |
80 | def get_hyper_edge(dataset, edge_type):
81 | if edge_type == 'local_bone':
82 | tgt_dict = local_bone_hyper_edge_dict
83 | elif edge_type == 'center':
84 | tgt_dict = center_hyper_edge_dict
85 | elif edge_type == 'figure_l':
86 | tgt_dict = figure_l_hyper_edge_dict
87 | elif edge_type == 'figure_r':
88 | tgt_dict = figure_r_hyper_edge_dict
89 | elif edge_type == 'hand':
90 | tgt_dict = hand_hyper_edge_dict
91 | elif edge_type == 'elbow':
92 | tgt_dict = elbow_hyper_edge
93 | elif edge_type == 'foot':
94 | tgt_dict = foot_hyper_edge
95 | else:
96 | raise NotImplementedError
97 | tgt_hyper_edge = tgt_dict[dataset]
98 | if 'ntu' in dataset:
99 | node_num = 25
100 | hyper_edge_adj = torch.zeros((node_num, node_num))
101 | else:
102 | raise NotImplementedError
103 | for i in range(node_num):
104 | edge_idx = 0
105 | for a_hyper_edge in tgt_hyper_edge:
106 | if (i+1) in a_hyper_edge:
107 | hyper_edge_adj[i][edge_idx] = 1
108 | edge_idx += 1
109 |
110 | return hyper_edge_adj
111 |
112 |
113 |
114 | def get_local_bone_hyper_adj(dataset):
115 | if 'ntu' in dataset:
116 | the_hyper_edge = [
117 | (17, 13),
118 | (21, 1),
119 | (4, 21),
120 | (21, 6),
121 | (5, 7),
122 | (6, 8),
123 | (23, 22),
124 | (10, 21),
125 | (11, 9),
126 | (12, 10),
127 | (24, 25),
128 | (1, 14),
129 | (13, 15),
130 | (14, 16),
131 | (18, 1),
132 | (19, 17),
133 | (20, 18),
134 | (20, 20),
135 | (9, 5),
136 | (8, 23),
137 | (8, 22),
138 | (25, 12),
139 | (24, 12)
140 | ]
141 | else:
142 | raise NotImplementedError
143 | return np.array(the_hyper_edge)
144 |
145 |
146 | def get_ntu_local_bone_neighbor():
147 | local_bone_inward_ori_index = [(25, 24), (25, 12), (24, 25), (24, 12), (12, 24), (12, 25),
148 | (11, 12), (11, 10), (10, 11), (10, 9), (9, 10), (9, 21),
149 | (21, 9), (21, 5), (5, 21), (5, 6), (6, 5), (6, 7), (7, 6), (7, 8),
150 | (8, 23), (8, 22), (22, 8), (22, 23), (23, 8), (23, 22), (3, 4),
151 | (3, 21), (2, 21), (2, 1), (1, 17), (1, 13), (17, 18), (17, 1),
152 | (18, 19), (18, 17), (19, 20), (19, 18), (13, 1), (13, 14),
153 | (14, 13), (14, 15), (15, 14), (15, 16)
154 | ]
155 | inward = [(i - 1, j - 1) for (i, j) in local_bone_inward_ori_index]
156 | outward = [(j, i) for (i, j) in inward]
157 | neighbor = inward + outward
158 | return neighbor
159 |
160 |
161 | class LocalBoneAdj:
162 | def __init__(self, dataset):
163 | if 'ntu' in dataset:
164 | num_node = 25
165 | self.edges = get_ntu_local_bone_neighbor()
166 | elif 'kinetics' in dataset:
167 | num_node = 18
168 | raise NotImplementedError
169 | else:
170 | raise NotImplementedError
171 | self.num_nodes = num_node
172 | self.self_loops = [(i, i) for i in range(self.num_nodes)]
173 | self.A_binary = tools.get_adjacency_matrix(self.edges, self.num_nodes)
174 | self.A_binary_with_I = tools.get_adjacency_matrix(self.edges + self.self_loops, self.num_nodes)
175 |
176 |
177 | if __name__ == '__main__':
178 |     rtn = get_local_bone_hyper_adj('ntu')
179 | print('rtn: ', rtn)
180 |
--------------------------------------------------------------------------------
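A small sanity check, not part of the repository, for get_hyper_edge above, assuming the 25-joint NTU layout; rows index joints and columns index hyper-edges:

from graph.hyper_graphs import get_hyper_edge

H = get_hyper_edge('ntu', 'center')   # incidence-style matrix: rows are joints, columns are hyper-edges
print(H.shape)                        # torch.Size([25, 25])
--------------------------------------------------------------------------------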
/graph/kinetics.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.insert(0, '')
3 | sys.path.extend(['../'])
4 |
5 | import numpy as np
6 |
7 | from graph import tools
8 |
9 | # Joint index:
10 | # {0, "Nose"}
11 | # {1, "Neck"},
12 | # {2, "RShoulder"},
13 | # {3, "RElbow"},
14 | # {4, "RWrist"},
15 | # {5, "LShoulder"},
16 | # {6, "LElbow"},
17 | # {7, "LWrist"},
18 | # {8, "RHip"},
19 | # {9, "RKnee"},
20 | # {10, "RAnkle"},
21 | # {11, "LHip"},
22 | # {12, "LKnee"},
23 | # {13, "LAnkle"},
24 | # {14, "REye"},
25 | # {15, "LEye"},
26 | # {16, "REar"},
27 | # {17, "LEar"},
28 |
29 | num_node = 18
30 | self_link = [(i, i) for i in range(num_node)]
31 | inward = [(4, 3), (3, 2), (7, 6), (6, 5), (13, 12), (12, 11), (10, 9), (9, 8),
32 | (11, 5), (8, 2), (5, 1), (2, 1), (0, 1), (15, 0), (14, 0), (17, 15),
33 | (16, 14)]
34 | outward = [(j, i) for (i, j) in inward]
35 | neighbor = inward + outward
36 |
37 |
38 | class AdjMatrixGraph:
39 | def __init__(self, *args, **kwargs):
40 | self.num_nodes = num_node
41 | self.edges = neighbor
42 | self.self_loops = [(i, i) for i in range(self.num_nodes)]
43 | self.A_binary = tools.get_adjacency_matrix(self.edges, self.num_nodes)
44 | self.A_binary_with_I = tools.get_adjacency_matrix(self.edges + self.self_loops, self.num_nodes)
45 |
46 |
47 | if __name__ == '__main__':
48 | graph = AdjMatrixGraph()
49 | A_binary = graph.A_binary
50 | import matplotlib.pyplot as plt
51 | print(A_binary)
52 | plt.matshow(A_binary)
53 | plt.show()
54 |
--------------------------------------------------------------------------------
/graph/ntu_rgb_d.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.insert(0, '')
3 | sys.path.extend(['../'])
4 |
5 | import numpy as np
6 |
7 | from graph import tools
8 |
9 | num_node = 25
10 | self_link = [(i, i) for i in range(num_node)]
11 | inward_ori_index = [(1, 2), (2, 21), (3, 21), (4, 3), (5, 21), (6, 5), (7, 6),
12 | (8, 7), (9, 21), (10, 9), (11, 10), (12, 11), (13, 1),
13 | (14, 13), (15, 14), (16, 15), (17, 1), (18, 17), (19, 18),
14 | (20, 19), (22, 23), (23, 8), (24, 25), (25, 12)]
15 | inward = [(i - 1, j - 1) for (i, j) in inward_ori_index]
16 | outward = [(j, i) for (i, j) in inward]
17 | neighbor = inward + outward
18 |
19 |
20 | class AdjMatrixGraph:
21 | def __init__(self, *args, **kwargs):
22 | self.edges = neighbor
23 | self.num_nodes = num_node
24 | self.self_loops = [(i, i) for i in range(self.num_nodes)]
25 | self.A_binary = tools.get_adjacency_matrix(self.edges, self.num_nodes)
26 | self.A_binary_with_I = tools.get_adjacency_matrix(self.edges + self.self_loops, self.num_nodes)
27 | self.A = tools.normalize_adjacency_matrix(self.A_binary)
28 |
29 |
30 | if __name__ == '__main__':
31 | import matplotlib.pyplot as plt
32 | graph = AdjMatrixGraph()
33 | A, A_binary, A_binary_with_I = graph.A, graph.A_binary, graph.A_binary_with_I
34 | f, ax = plt.subplots(1, 3)
35 | ax[0].imshow(A_binary_with_I, cmap='gray')
36 | ax[1].imshow(A_binary, cmap='gray')
37 | ax[2].imshow(A, cmap='gray')
38 | plt.show()
39 | print(A_binary_with_I.shape, A_binary.shape, A.shape)
40 |
41 |
42 | class Graph:
43 | def __init__(self, labeling_mode='spatial'):
44 | self.A = self.get_adjacency_matrix(labeling_mode)
45 | self.num_node = num_node
46 | self.self_link = self_link
47 | self.inward = inward
48 | self.outward = outward
49 | self.neighbor = neighbor
50 |
51 | def get_adjacency_matrix(self, labeling_mode=None):
52 | if labeling_mode is None:
53 | return self.A
54 | if labeling_mode == 'spatial':
55 | A = tools.get_spatial_graph(num_node, self_link, inward, outward)
56 | else:
57 | raise ValueError()
58 | return A
59 |
--------------------------------------------------------------------------------
/graph/tools.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def edge2mat(link, num_node):
5 | A = np.zeros((num_node, num_node))
6 | for i, j in link:
7 | A[j, i] = 1
8 | return A
9 |
10 |
11 | def normalize_digraph(A):
12 | Dl = np.sum(A, 0)
13 | h, w = A.shape
14 | Dn = np.zeros((w, w))
15 | for i in range(w):
16 | if Dl[i] > 0:
17 | Dn[i, i] = Dl[i] ** (-1)
18 | AD = np.dot(A, Dn)
19 | return AD
20 |
21 |
22 | def get_spatial_graph(num_node, self_link, inward, outward):
23 | I = edge2mat(self_link, num_node)
24 | In = normalize_digraph(edge2mat(inward, num_node))
25 | Out = normalize_digraph(edge2mat(outward, num_node))
26 | A = np.stack((I, In, Out))
27 | return A
28 |
29 |
30 | def k_adjacency(A, k, with_self=False, self_factor=1):
31 | assert isinstance(A, np.ndarray)
32 | I = np.eye(len(A), dtype=A.dtype)
33 | if k == 0:
34 | return I
35 | Ak = np.minimum(np.linalg.matrix_power(A + I, k), 1) \
36 | - np.minimum(np.linalg.matrix_power(A + I, k - 1), 1)
37 | if with_self:
38 | Ak += (self_factor * I)
39 | return Ak
40 |
41 |
42 | def normalize_adjacency_matrix(A):
43 | node_degrees = A.sum(-1)
44 | degs_inv_sqrt = np.power(node_degrees, -0.5)
45 | norm_degs_matrix = np.eye(len(node_degrees)) * degs_inv_sqrt
46 | return (norm_degs_matrix @ A @ norm_degs_matrix).astype(np.float32)
47 |
48 |
49 | def get_adjacency_matrix(edges, num_nodes):
50 | A = np.zeros((num_nodes, num_nodes), dtype=np.float32)
51 | for edge in edges:
52 | A[edge] = 1.
53 | return A
--------------------------------------------------------------------------------
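A usage sketch, not part of the repository, for the helpers above, built on the NTU graph: get_spatial_graph returns the identity/inward/outward stack, while k_adjacency plus normalize_adjacency_matrix produce the disentangled hop-wise blocks that MultiScale_GraphConv concatenates into A_powers. Four scales are assumed here purely for illustration:

import numpy as np
from graph.ntu_rgb_d import AdjMatrixGraph, num_node, self_link, inward, outward
from graph.tools import get_spatial_graph, k_adjacency, normalize_adjacency_matrix

A_spatial = get_spatial_graph(num_node, self_link, inward, outward)
print(A_spatial.shape)   # (3, 25, 25): identity, normalized inward, normalized outward

A_binary = AdjMatrixGraph().A_binary
A_powers = np.concatenate([normalize_adjacency_matrix(k_adjacency(A_binary, k, with_self=True))
                           for k in range(4)])
print(A_powers.shape)    # (100, 25): four 25x25 hop-wise blocks stacked along axis 0
--------------------------------------------------------------------------------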
/model/__init__.py:
--------------------------------------------------------------------------------
1 | from . import network
2 |
--------------------------------------------------------------------------------
/model/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/model/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/model/__pycache__/activation.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/model/__pycache__/activation.cpython-36.pyc
--------------------------------------------------------------------------------
/model/__pycache__/att_gcn.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/model/__pycache__/att_gcn.cpython-36.pyc
--------------------------------------------------------------------------------
/model/__pycache__/hyper_gcn.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/model/__pycache__/hyper_gcn.cpython-36.pyc
--------------------------------------------------------------------------------
/model/__pycache__/mlp.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/model/__pycache__/mlp.cpython-36.pyc
--------------------------------------------------------------------------------
/model/__pycache__/modules.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/model/__pycache__/modules.cpython-36.pyc
--------------------------------------------------------------------------------
/model/__pycache__/ms_gcn.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/model/__pycache__/ms_gcn.cpython-36.pyc
--------------------------------------------------------------------------------
/model/__pycache__/ms_gtcn.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/model/__pycache__/ms_gtcn.cpython-36.pyc
--------------------------------------------------------------------------------
/model/__pycache__/ms_tcn.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/model/__pycache__/ms_tcn.cpython-36.pyc
--------------------------------------------------------------------------------
/model/__pycache__/network.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/model/__pycache__/network.cpython-36.pyc
--------------------------------------------------------------------------------
/model/activation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from model.modules import Sine
5 |
6 |
7 | def activation_factory(name, inplace=True):
8 | if name == 'relu':
9 | return nn.ReLU(inplace=inplace)
10 | elif name == 'leakyrelu':
11 | return nn.LeakyReLU(0.2, inplace=inplace)
12 | elif name == 'tanh':
13 | return nn.Tanh()
14 | elif name == 'linear' or name is None:
15 | return nn.Identity()
16 | elif name == 'sine':
17 | return Sine()
18 | else:
19 |         raise ValueError(f'Unsupported activation: {name}')
--------------------------------------------------------------------------------
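A short usage sketch, not part of the repository, for activation_factory above:

import torch
from model.activation import activation_factory

act = activation_factory('leakyrelu')     # nn.LeakyReLU(0.2, inplace=True)
print(act(torch.tensor([-1.0, 2.0])))     # tensor([-0.2000,  2.0000])
print(activation_factory('linear'))       # Identity()
--------------------------------------------------------------------------------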
/model/att_gcn.py:
--------------------------------------------------------------------------------
1 | import math
2 | import sys
3 |
4 | from torch.nn import TransformerEncoderLayer, TransformerEncoder
5 |
6 | from graph.ang_adjs import get_ang_adjs
7 | from graph.hyper_graphs import get_hyper_edge
8 |
9 | sys.path.insert(0, '')
10 |
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 | import numpy as np
15 |
16 | from graph.tools import k_adjacency, normalize_adjacency_matrix
17 | from model.mlp import MLP
18 | from model.activation import activation_factory
19 |
20 |
21 | class Att_GraphConv(nn.Module):
22 | def __init__(self,
23 | in_channels,
24 | out_channels,
25 | dropout=0,
26 | activation='relu',
27 | **kwargs):
28 | super().__init__()
29 |
30 | self.local_bone_hyper_edges = get_hyper_edge('ntu', 'local_bone')
31 | self.center_hyper_edges = get_hyper_edge('ntu', 'center')
32 | self.figure_l_hyper_edges = get_hyper_edge('ntu', 'figure_l')
33 | self.figure_r_hyper_edges = get_hyper_edge('ntu', 'figure_r')
34 | self.hand_hyper_edges = get_hyper_edge('ntu', 'hand')
35 | # self.foot_hyper_edges = get_hyper_edge('ntu', 'foot')
36 |
37 | self.hyper_edge_num = 8
38 | self.in_fea_mlp = MLP(self.hyper_edge_num, [50, out_channels], dropout=dropout, activation=activation)
39 | self.in_fea_mlp_last = nn.Conv2d(out_channels, out_channels, kernel_size=1)
40 |
41 | def process_hyper_edge_w(self, he_w, device):
42 | he_w = he_w.repeat(1, 1, 1, he_w.shape[-2])
43 | for i in range(he_w.shape[0]):
44 | for j in range(he_w.shape[1]):
45 | he_w[i][j] *= torch.eye(he_w.shape[-2]).to(device)
46 | return he_w
47 |
48 | def normalized_aggregate(self, w, h):
49 | degree_v = torch.einsum('ve,bte->btv', h, w)
50 | degree_e = torch.sum(h, dim=0)
51 | degree_v = torch.pow(degree_v, -0.5)
52 | # degree_v = torch.pow(degree_v, -1)
53 | degree_e = torch.pow(degree_e, -1)
54 | degree_v[degree_v == float("Inf")] = 0
55 | degree_v[degree_v != degree_v] = 0
56 | degree_e[degree_e == float("Inf")] = 0
57 | degree_e[degree_e != degree_e] = 0
58 | dh = torch.einsum('btv,ve->btve', degree_v, h)
59 | dhw = torch.einsum('btve,bte->btve', dh, w)
60 | dhwb = torch.einsum('btve,e->btve', dhw, degree_e)
61 | dhwbht = torch.einsum('btve,eu->btvu', dhwb, torch.transpose(h, 0, 1))
62 | dhwbhtd = torch.einsum('btvu,btu->btvu', dhwbht, degree_v)
63 | if torch.max(dhwbhtd).item() != torch.max(dhwbhtd).item():
64 | print('max h: ', torch.max(h).item(), 'min h: ', torch.min(h).item())
65 | print('max w: ', torch.max(w).item(), 'min w: ', torch.min(w).item())
66 | print('max degree v: ', torch.max(degree_v).item(), 'min degree v: ', torch.min(degree_v).item())
67 | print('max degree e: ', torch.max(degree_e).item(), 'min degree e: ', torch.min(degree_e).item())
68 | print('max dh: ', torch.max(dh).item(), 'min dh: ', torch.min(dh).item())
69 | print('max dhw: ', torch.max(dhw).item(), 'min dhw: ', torch.min(dhw).item())
70 | print('max dhwb: ', torch.max(dhwb).item(), 'min dhwb: ', torch.min(dhwb).item())
71 | print('max dhwbht: ', torch.max(dhwbht).item(), 'min dhwbht: ', torch.min(dhwbht).item())
72 | print('max dhwbhtd: ', torch.max(dhwbhtd).item(), 'min dhwbhtd: ', torch.min(dhwbhtd).item())
73 | assert 0
74 |
75 | # dhwbhtd[dhwbhtd != dhwbhtd] = 0
76 | return dhwbhtd
77 |
78 | def att_convolve(self, x):
79 | cor_w = x[:, :3, :, :]
80 | local_bone_w = x[:, 6, :, :].unsqueeze(1) # 6
81 | center_w = x[:, 7, :, :].unsqueeze(1) # 7
82 | figure_l_w = x[:, 9, :, :].unsqueeze(1) # 9
83 | figure_r_w = x[:, 10, :, :].unsqueeze(1) # 10
84 | hand_w = x[:, 11, :, :].unsqueeze(1) # 11
85 |
86 |         # Expand to a larger number of channels
87 | in_fea = torch.cat((cor_w, local_bone_w, center_w, figure_l_w, figure_r_w, hand_w), dim=1)
88 | in_fea = self.in_fea_mlp(in_fea)
89 | in_fea = self.in_fea_mlp_last(in_fea)
90 | in_fea = in_fea.permute(0, 2, 3, 1)
91 |
92 | in_fea = torch.einsum('btvm,btmu->btvu', in_fea, in_fea.permute(0, 1, 3, 2))
93 | in_fea = torch.softmax(in_fea, dim=-1)
94 | return in_fea
95 |
96 | def forward(self, x):
97 | return self.att_convolve(x)
98 |
99 |
100 | if __name__ == "__main__":
101 |     # Quick sanity check for Att_GraphConv defined above; the 12-channel input is an
102 |     # assumption: att_convolve reads coordinates from channels 0-2 and angular
103 |     # features from channels 6, 7, 9, 10 and 11.
104 |     att_gc = Att_GraphConv(in_channels=3, out_channels=64)
105 |     att_map = att_gc(torch.randn(16, 12, 30, 25))
106 |     print(att_map.shape)  # expected: torch.Size([16, 30, 25, 25])
107 |
--------------------------------------------------------------------------------
/model/hyper_gcn.py:
--------------------------------------------------------------------------------
1 | import math
2 | import sys
3 |
4 | from torch.nn import TransformerEncoderLayer, TransformerEncoder
5 |
6 | from graph.ang_adjs import get_ang_adjs
7 | from graph.hyper_graphs import get_hyper_edge
8 |
9 | sys.path.insert(0, '')
10 |
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 | import numpy as np
15 |
16 | from graph.tools import k_adjacency, normalize_adjacency_matrix
17 | from model.mlp import MLP
18 | from model.activation import activation_factory
19 |
20 |
21 | class Hyper_GraphConv(nn.Module):
22 | def __init__(self,
23 | in_channels,
24 | out_channels,
25 | dropout=0,
26 | activation='relu',
27 | **kwargs):
28 | super().__init__()
29 |
30 | self.local_bone_hyper_edges = get_hyper_edge('ntu', 'local_bone')
31 | self.center_hyper_edges = get_hyper_edge('ntu', 'center')
32 | self.figure_l_hyper_edges = get_hyper_edge('ntu', 'figure_l')
33 | self.figure_r_hyper_edges = get_hyper_edge('ntu', 'figure_r')
34 | self.hand_hyper_edges = get_hyper_edge('ntu', 'hand')
35 | # self.foot_hyper_edges = get_hyper_edge('ntu', 'foot')
36 |
37 | self.hyper_edge_num = 5
38 |
39 | self.mlp = MLP(in_channels * self.hyper_edge_num, [out_channels], dropout=dropout, activation=activation)
40 |
41 | self.fea_mlp = MLP(self.hyper_edge_num, [50, 50, self.hyper_edge_num], dropout=dropout, activation=activation)
42 |
43 | def process_hyper_edge_w(self, he_w, device):
44 | he_w = he_w.repeat(1, 1, 1, he_w.shape[-2])
45 | for i in range(he_w.shape[0]):
46 | for j in range(he_w.shape[1]):
47 | he_w[i][j] *= torch.eye(he_w.shape[-2]).to(device)
48 | return he_w
49 |
50 | def normalized_aggregate(self, w, h):
51 | degree_v = torch.einsum('ve,bte->btv', h, w)
52 | degree_e = torch.sum(h, dim=0)
53 | degree_v = torch.pow(degree_v, -0.5)
54 | # degree_v = torch.pow(degree_v, -1)
55 | degree_e = torch.pow(degree_e, -1)
56 | degree_v[degree_v == float("Inf")] = 0
57 | degree_v[degree_v != degree_v] = 0
58 | degree_e[degree_e == float("Inf")] = 0
59 | degree_e[degree_e != degree_e] = 0
60 | dh = torch.einsum('btv,ve->btve', degree_v, h)
61 | dhw = torch.einsum('btve,bte->btve', dh, w)
62 | dhwb = torch.einsum('btve,e->btve', dhw, degree_e)
63 | dhwbht = torch.einsum('btve,eu->btvu', dhwb, torch.transpose(h, 0, 1))
64 | dhwbhtd = torch.einsum('btvu,btu->btvu', dhwbht, degree_v)
65 | if torch.max(dhwbhtd).item() != torch.max(dhwbhtd).item():
66 | print('max h: ', torch.max(h).item(), 'min h: ', torch.min(h).item())
67 | print('max w: ', torch.max(w).item(), 'min w: ', torch.min(w).item())
68 | print('max degree v: ', torch.max(degree_v).item(), 'min degree v: ', torch.min(degree_v).item())
69 | print('max degree e: ', torch.max(degree_e).item(), 'min degree e: ', torch.min(degree_e).item())
70 | print('max dh: ', torch.max(dh).item(), 'min dh: ', torch.min(dh).item())
71 | print('max dhw: ', torch.max(dhw).item(), 'min dhw: ', torch.min(dhw).item())
72 | print('max dhwb: ', torch.max(dhwb).item(), 'min dhwb: ', torch.min(dhwb).item())
73 | print('max dhwbht: ', torch.max(dhwbht).item(), 'min dhwbht: ', torch.min(dhwbht).item())
74 | print('max dhwbhtd: ', torch.max(dhwbhtd).item(), 'min dhwbhtd: ', torch.min(dhwbhtd).item())
75 | assert 0
76 |
77 | # dhwbhtd[dhwbhtd != dhwbhtd] = 0
78 | return dhwbhtd
79 |
80 | def hyper_edge_convolve(self, x):
81 | self.local_bone_hyper_edges = self.local_bone_hyper_edges.to(x.device)
82 | self.center_hyper_edges = self.center_hyper_edges.to(x.device)
83 | self.figure_l_hyper_edges = self.figure_l_hyper_edges.to(x.device)
84 | self.figure_r_hyper_edges = self.figure_r_hyper_edges.to(x.device)
85 | self.hand_hyper_edges = self.hand_hyper_edges.to(x.device)
86 | # self.foot_hyper_edges = self.foot_hyper_edges.to(x.device)
87 |
88 |         # Do not make the angular features learnable
89 | # print('max x: ', torch.max(x).item(), 'min x: ', torch.min(x).item())
90 | local_bone_w = x[:, 6, :, :] # 6
91 | center_w = x[:, 7, :, :] # 7
92 | figure_l_w = x[:, 9, :, :] # 9
93 | figure_r_w = x[:, 10, :, :] # 10
94 | hand_w = x[:, 11, :, :] # 11
95 | # foot_w = x[:, 14, :, :].unsqueeze(1) # 14
96 |
97 |         # Make the angular features learnable
98 | # local_bone_w = x[:, 0, :, :].unsqueeze(1) # 6
99 | # center_w = x[:, 0, :, :].unsqueeze(1) # 7
100 | # figure_l_w = x[:, 0, :, :].unsqueeze(1) # 9
101 | # figure_r_w = x[:, 0, :, :].unsqueeze(1) # 10
102 | # hand_w = x[:, 0, :, :].unsqueeze(1) # 11
103 | # # foot_w = x[:, 14, :, :].unsqueeze(1) # 14
104 | #
105 | # fea_w_cat = torch.cat((local_bone_w, center_w, figure_l_w, figure_r_w,
106 | # hand_w), dim=1)
107 | # fea_w_cat = self.fea_mlp(fea_w_cat)
108 | # local_bone_w = fea_w_cat[:, 0, :, :]
109 | # center_w = fea_w_cat[:, 1, :, :]
110 | # figure_l_w = fea_w_cat[:, 2, :, :]
111 | # figure_r_w = fea_w_cat[:, 3, :, :]
112 | # hand_w = fea_w_cat[:, 4, :, :]
113 | # # foot_w = fea_w_cat[:, 5, :, :]
114 |
115 | # local bone angle
116 | # local_bone_hwh = torch.einsum('ve,bte->btve', self.local_bone_hyper_edges, local_bone_w)
117 | # local_bone_hwh = torch.einsum('btvu,un->btvn', local_bone_hwh,
118 | # torch.transpose(self.local_bone_hyper_edges, 0, 1))
119 | local_bone_hwh = self.normalized_aggregate(local_bone_w, self.local_bone_hyper_edges)
120 |
121 | # center angle
122 | # center_hwh = torch.einsum('ve,bte->btve', self.center_hyper_edges, center_w)
123 | # center_hwh = torch.einsum('btvu,un->btvn', center_hwh,
124 | # torch.transpose(self.center_hyper_edges, 0, 1))
125 | center_hwh = self.normalized_aggregate(center_w, self.center_hyper_edges)
126 |
127 | # figure left angle
128 | # figure_l_hwh = torch.einsum('ve,bte->btve', self.figure_l_hyper_edges, figure_l_w)
129 | # figure_l_hwh = torch.einsum('btvu,un->btvn', figure_l_hwh,
130 | # torch.transpose(self.figure_l_hyper_edges, 0, 1))
131 | figure_l_hwh = self.normalized_aggregate(figure_l_w, self.figure_l_hyper_edges)
132 |
133 | # figure right angle
134 | # figure_r_hwh = torch.einsum('ve,bte->btve', self.figure_r_hyper_edges, figure_r_w)
135 | # figure_r_hwh = torch.einsum('btvu,un->btvn', figure_r_hwh,
136 | # torch.transpose(self.figure_r_hyper_edges, 0, 1))
137 | figure_r_hwh = self.normalized_aggregate(figure_r_w, self.figure_r_hyper_edges)
138 |
139 | # hand angle
140 | # hand_hwh = torch.einsum('ve,bte->btve', self.hand_hyper_edges, hand_w)
141 | # hand_hwh = torch.einsum('btvu,un->btvn', hand_hwh,
142 | # torch.transpose(self.hand_hyper_edges, 0, 1))
143 | hand_hwh = self.normalized_aggregate(hand_w, self.hand_hyper_edges)
144 |
145 | # foot angle
146 | # foot_hwh = torch.einsum('ve,bte->btve', self.foot_hyper_edges, foot_w)
147 | # foot_hwh = torch.einsum('btvu,un->btvn', foot_hwh,
148 | # torch.transpose(self.foot_hyper_edges, 0, 1))
149 |
150 | hwh_cat = torch.cat((local_bone_hwh, center_hwh, figure_l_hwh, figure_r_hwh,
151 | hand_hwh), dim=-2)
152 |
153 | # Softmax normalization
154 | # hwh_cat = torch.softmax(hwh_cat, dim=-1)
155 |
156 | # Use partial features
157 | x = x[:, :3, :, :]
158 | N, C, T, V = x.shape
159 |
160 | support = torch.einsum('btvu,bctu->bctv', hwh_cat, x)
161 | support = support.view(N, C, T, self.hyper_edge_num, V)
162 | support = support.permute(0, 3, 1, 2, 4).contiguous().view(N, self.hyper_edge_num * C, T, V)
163 | out = self.mlp(support)
164 |
165 | return out
166 |
167 | def forward(self, x):
168 | return self.hyper_edge_convolve(x)
169 |
170 |
171 | if __name__ == "__main__":
172 |     # Quick sanity check for Hyper_GraphConv defined above; the 12-channel, non-negative
173 |     # input is an assumption: hyper_edge_convolve reads the angular features from
174 |     # channels 6, 7, 9, 10 and 11 and the coordinates from channels 0-2.
175 |     hyper_gc = Hyper_GraphConv(in_channels=3, out_channels=64)
176 |     out = hyper_gc(torch.rand(16, 12, 30, 25))
177 |     print(out.shape)  # expected: torch.Size([16, 64, 30, 25])
178 |
--------------------------------------------------------------------------------
/model/mlp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from model.activation import activation_factory
6 | from utils_dir.utils_visual import plot_multiple_lines
7 |
8 |
9 | class MLP(nn.Module):
10 | def __init__(self, in_channels, out_channels, activation='relu', dropout=0):
11 | super().__init__()
12 | channels = [in_channels] + out_channels
13 | self.layers = nn.ModuleList()
14 | for i in range(1, len(channels)):
15 | if dropout > 0.001:
16 | self.layers.append(nn.Dropout(p=dropout))
17 | self.layers.append(nn.Conv2d(channels[i-1], channels[i], kernel_size=1))
18 | self.layers.append(nn.BatchNorm2d(channels[i]))
19 | self.layers.append(activation_factory(activation, inplace=False))
20 |
21 | def forward(self, x):
22 | # Input shape: (N,C,T,V)
23 |         # This is where information from different scales of the same joint is learned.
24 | # the_mlp_w = torch.sum(self.layers[0].weight, dim=0).squeeze().view(13, -1).sum(0).cpu().numpy()
25 | # print('the_mlp_w: ', the_mlp_w)
26 | # plot_multiple_lines([the_mlp_w])
27 | # assert 0
28 | for layer in self.layers:
29 | x = layer(x)
30 | return x
31 |
32 |
--------------------------------------------------------------------------------
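A usage sketch, not part of the repository, for the MLP above; it is a stack of 1x1 convolutions over the channel dimension of an (N, C, T, V) tensor, so the batch, temporal and joint sizes below are assumed purely for illustration:

import torch
from model.mlp import MLP

mlp = MLP(in_channels=3, out_channels=[64, 128], activation='relu', dropout=0)
x = torch.randn(8, 3, 30, 25)    # hypothetical (N, C, T, V) batch
print(mlp(x).shape)              # torch.Size([8, 128, 30, 25])
--------------------------------------------------------------------------------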
/model/modules.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import numpy as np
4 | from collections import OrderedDict
5 | import math
6 | import torch.nn.functional as F
7 |
8 |
9 | class Sine(nn.Module):
10 |     def __init__(self):
11 | super().__init__()
12 |
13 | def forward(self, input):
14 | # See paper sec. 3.2, final paragraph, and supplement Sec. 1.5 for discussion of factor 30
15 | # return torch.sin(30 * input)
16 | return torch.sin(input)
17 |
18 |
19 | def sine_init(m):
20 | with torch.no_grad():
21 | if hasattr(m, 'weight'):
22 | num_input = m.weight.size(-1)
23 | # See supplement Sec. 1.5 for discussion of factor 30
24 | m.weight.uniform_(-np.sqrt(6 / num_input) / 30, np.sqrt(6 / num_input) / 30)
25 |
26 |
27 | def first_layer_sine_init(m):
28 | with torch.no_grad():
29 | if hasattr(m, 'weight'):
30 | num_input = m.weight.size(-1)
31 | # See paper sec. 3.2, final paragraph, and supplement Sec. 1.5 for discussion of factor 30
32 | m.weight.uniform_(-1 / num_input, 1 / num_input)
33 |
34 |
35 | ###################
36 | # Complex operators
37 | def compl_conj(x):
38 | y = x.clone()
39 | y[..., 1::2] = -1 * y[..., 1::2]
40 | return y
41 |
42 |
43 | def compl_div(x, y):
44 | ''' x / y '''
45 | a = x[..., ::2]
46 | b = x[..., 1::2]
47 | c = y[..., ::2]
48 | d = y[..., 1::2]
49 |
50 | outr = (a * c + b * d) / (c ** 2 + d ** 2)
51 | outi = (b * c - a * d) / (c ** 2 + d ** 2)
52 | out = torch.zeros_like(x)
53 | out[..., ::2] = outr
54 | out[..., 1::2] = outi
55 | return out
56 |
57 |
58 | def compl_mul(x, y):
59 | ''' x * y '''
60 | a = x[..., ::2]
61 | b = x[..., 1::2]
62 | c = y[..., ::2]
63 | d = y[..., 1::2]
64 |
65 | outr = a * c - b * d
66 | outi = (a + b) * (c + d) - a * c - b * d
67 | out = torch.zeros_like(x)
68 | out[..., ::2] = outr
69 | out[..., 1::2] = outi
70 | return out
71 |
--------------------------------------------------------------------------------
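A tiny example, not part of the repository, of the interleaved complex helpers above; real and imaginary parts are stored in alternating positions along the last axis:

import torch
from model.modules import compl_mul, compl_div

x = torch.tensor([1., 2.])            # 1 + 2i
y = torch.tensor([3., 4.])            # 3 + 4i
print(compl_mul(x, y))                # tensor([-5., 10.])  i.e. -5 + 10i
print(compl_div(compl_mul(x, y), y))  # recovers x: tensor([1., 2.])
--------------------------------------------------------------------------------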
/model/ms_gcn.py:
--------------------------------------------------------------------------------
1 | import math
2 | import sys
3 |
4 | from torch.nn import TransformerEncoderLayer, TransformerEncoder
5 |
6 | from graph.ang_adjs import get_ang_adjs
7 | from model.hyper_gcn import Hyper_GraphConv
8 |
9 | sys.path.insert(0, '')
10 |
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 | import numpy as np
15 |
16 | from graph.tools import k_adjacency, normalize_adjacency_matrix
17 | from model.mlp import MLP
18 | from model.activation import activation_factory
19 |
20 |
21 | class MultiScale_GraphConv(nn.Module):
22 | def __init__(self,
23 | num_scales,
24 | in_channels,
25 | out_channels,
26 | A_binary,
27 | disentangled_agg=True,
28 | use_mask=True,
29 | dropout=0,
30 | activation='relu',
31 | to_use_hyper_conv=False,
32 | **kwargs):
33 | super().__init__()
34 | self.num_scales = num_scales
35 |
36 | if disentangled_agg:
37 | A_powers = [k_adjacency(A_binary, k, with_self=True) for k in range(num_scales)]
38 | A_powers = np.concatenate([normalize_adjacency_matrix(g) for g in A_powers])
39 | else:
40 | A_powers = [A_binary + np.eye(len(A_binary)) for k in range(num_scales)]
41 | A_powers = [normalize_adjacency_matrix(g) for g in A_powers]
42 | A_powers = [np.linalg.matrix_power(g, k) for k, g in enumerate(A_powers)]
43 | A_powers = np.concatenate(A_powers)
44 |
45 | self.A_powers = torch.Tensor(A_powers)
46 |
47 | if 'hyper_conv' in kwargs and kwargs['hyper_conv'] == 'ntu':
48 | hyper_adjs = get_ang_adjs('ntu')
49 | self.A_powers = torch.cat((self.A_powers, hyper_adjs), dim=0)
50 | if kwargs['hyper_conv'] == 'ntu':
51 | self.num_scales += 6
52 | elif kwargs['hyper_conv'] == 'kinetics':
53 | self.num_scales += 4
54 |
55 | # self.A_powers_param = torch.nn.Parameter(self.A_powers)
56 |
57 | self.use_mask = use_mask
58 | if use_mask:
59 | # NOTE: the inclusion of residual mask appears to slow down training noticeably
60 | self.A_res = nn.init.uniform_(nn.Parameter(torch.Tensor(self.A_powers.shape)), -1e-6, 1e-6)
61 |
62 |         # This "MLP" is not really an MLP; it is a 1x1 convolution that plays an MLP-like role.
63 | self.mlp = MLP(in_channels * self.num_scales, [out_channels], dropout=dropout, activation=activation)
64 |
65 | # Spatial Transformer Attention
66 | if 'to_use_spatial_transformer' in kwargs and kwargs['to_use_spatial_transformer']:
67 | self.to_use_spatial_trans = True
68 | self.trans_conv = nn.Conv2d(out_channels, 1, (1, 1), (1, 1))
69 | self.temporal_len = kwargs['temporal_len']
70 | nhead = 5
71 | nlayers = 2
72 | trans_dropout = 0.5
73 | encoder_layers = nn.TransformerEncoderLayer(self.temporal_len,
74 | nhead=nhead, dropout=trans_dropout)
75 | self.trans_enc = nn.TransformerEncoder(encoder_layers, nlayers)
76 |
77 | # spatial point normalization
78 | self.point_norm_layer = nn.Sigmoid()
79 |
80 | else:
81 | self.to_use_spatial_trans = False
82 |
83 | if 'to_use_sptl_trans_feature' in kwargs and kwargs['to_use_sptl_trans_feature']:
84 | self.to_use_sptl_trans_feature = True
85 | self.fea_dim = kwargs['fea_dim']
86 | encoder_layers = nn.TransformerEncoderLayer(self.fea_dim,
87 | nhead=kwargs['sptl_trans_feature_n_head'],
88 | dropout=0.5)
89 | self.trans_enc_fea = nn.TransformerEncoder(encoder_layers,
90 | kwargs['sptl_trans_feature_n_layer'])
91 | else:
92 | self.to_use_sptl_trans_feature = False
93 |
94 | def forward(self, x):
95 | N, C, T, V = x.shape
96 | self.A_powers = self.A_powers.to(x.device)
97 |
98 | A = self.A_powers.to(x.dtype)
99 | if self.use_mask:
100 | A = A + self.A_res.to(x.dtype)
101 |
102 | support = torch.einsum('vu,nctu->nctv', A, x)
103 |
104 | support = support.view(N, C, T, self.num_scales, V)
105 | support = support.permute(0, 3, 1, 2, 4).contiguous().view(N, self.num_scales * C, T, V)
106 |
107 | out = self.mlp(support)
108 |
109 |         # Kernel-style formulation; only half of it is implemented (see the commented-out outer product below).
110 | # out = torch.einsum('nijtv,njktv->niktv', out.unsqueeze(2), out.unsqueeze(1)).view(
111 | # N, self.out_channels * self.out_channels, T, V
112 | # )
113 |
114 | if self.to_use_spatial_trans:
115 | out_mean = self.trans_conv(out).squeeze()
116 | out_mean = out_mean.permute(0, 2, 1)
117 | out_mean = self.trans_enc(out_mean)
118 | out_mean = self.point_norm_layer(out_mean)
119 | out_mean = out_mean.permute(0, 2, 1)
120 | out_mean = torch.unsqueeze(out_mean, dim=1).repeat(1, out.shape[1], 1, 1)
121 | out = out_mean * out
122 |
123 | if self.to_use_sptl_trans_feature:
124 | out = out.permute(2, 3, 0, 1)
125 | for a_out_idx in range(len(out)):
126 | a_out = out[a_out_idx]
127 | a_out = self.trans_enc_fea(a_out)
128 | out[a_out_idx] = a_out
129 | out = out.permute(2, 3, 0, 1)
130 |
131 | return out
132 |
133 |
134 | if __name__ == "__main__":
135 | from graph.ntu_rgb_d import AdjMatrixGraph
136 |
137 | graph = AdjMatrixGraph()
138 | A_binary = graph.A_binary
139 | msgcn = MultiScale_GraphConv(num_scales=15, in_channels=3, out_channels=64, A_binary=A_binary)
140 | msgcn.forward(torch.randn(16, 3, 30, 25))
141 |
--------------------------------------------------------------------------------
/model/ms_gtcn.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.insert(0, '')
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | import numpy as np
8 |
9 | from model.ms_tcn import MultiScale_TemporalConv as MS_TCN
10 | from model.mlp import MLP
11 | from model.activation import activation_factory
12 | from graph.tools import k_adjacency, normalize_adjacency_matrix
13 |
14 |
15 | class UnfoldTemporalWindows(nn.Module):
16 | def __init__(self, window_size, window_stride, window_dilation=1):
17 | super().__init__()
18 | self.window_size = window_size
19 | self.window_stride = window_stride
20 | self.window_dilation = window_dilation
21 |
22 | self.padding = (window_size + (window_size-1) * (window_dilation-1) - 1) // 2
23 |         # After unfold, the result is: batch size x receptive-field size x number of receptive fields
24 | self.unfold = nn.Unfold(kernel_size=(self.window_size, 1),
25 | dilation=(self.window_dilation, 1),
26 | stride=(self.window_stride, 1),
27 | padding=(self.padding, 0))
28 |
29 | def forward(self, x):
30 | # Input shape: (N,C,T,V), out: (N,C,T,V*window_size)
31 | N, C, T, V = x.shape
32 | x = self.unfold(x)
33 | # Permute extra channels from window size to the graph dimension; -1 for number of windows
34 | x = x.view(N, C, self.window_size, -1, V).permute(0,1,3,2,4).contiguous()
35 | x = x.view(N, C, -1, self.window_size * V)
36 | return x
37 |
38 |
39 | class SpatialTemporal_MS_GCN(nn.Module):
40 | def __init__(self,
41 | in_channels,
42 | out_channels,
43 | A_binary,
44 | num_scales,
45 | window_size,
46 | disentangled_agg=True,
47 | use_Ares=True,
48 | residual=False,
49 | dropout=0,
50 | activation='relu'):
51 |
52 | super().__init__()
53 | self.num_scales = num_scales
54 | self.window_size = window_size
55 | self.use_Ares = use_Ares
56 | A = self.build_spatial_temporal_graph(A_binary, window_size)
57 |
58 | if disentangled_agg:
59 | A_scales = [k_adjacency(A, k, with_self=True) for k in range(num_scales)]
60 | A_scales = np.concatenate([normalize_adjacency_matrix(g) for g in A_scales])
61 | else:
62 | # Self-loops have already been included in A
63 | A_scales = [normalize_adjacency_matrix(A) for k in range(num_scales)]
64 | A_scales = [np.linalg.matrix_power(g, k) for k, g in enumerate(A_scales)]
65 | A_scales = np.concatenate(A_scales)
66 |
67 | self.A_scales = torch.Tensor(A_scales)
68 | self.V = len(A_binary)
69 |
70 | if use_Ares:
71 | self.A_res = nn.init.uniform_(nn.Parameter(torch.randn(self.A_scales.shape)), -1e-6, 1e-6)
72 | else:
73 | self.A_res = torch.tensor(0)
74 |
75 | self.mlp = MLP(in_channels * num_scales, [out_channels], dropout=dropout, activation='linear')
76 |
77 | # Residual connection
78 | if not residual:
79 | self.residual = lambda x: 0
80 | elif (in_channels == out_channels):
81 | self.residual = lambda x: x
82 | else:
83 | self.residual = MLP(in_channels, [out_channels], activation='linear')
84 |
85 | self.act = activation_factory(activation)
86 |
87 | def build_spatial_temporal_graph(self, A_binary, window_size):
88 | assert isinstance(A_binary, np.ndarray), 'A_binary should be of type `np.ndarray`'
89 | V = len(A_binary)
90 | V_large = V * window_size
91 | A_binary_with_I = A_binary + np.eye(len(A_binary), dtype=A_binary.dtype)
92 | # Build spatial-temporal graph
93 | A_large = np.tile(A_binary_with_I, (window_size, window_size)).copy()
94 | return A_large
95 |
96 | def forward(self, x):
97 | N, C, T, V = x.shape # T = number of windows
98 |
99 | # Build graphs
100 | A = self.A_scales.to(x.dtype).to(x.device) + self.A_res.to(x.dtype).to(x.device)
101 |
102 | # Perform Graph Convolution
103 | res = self.residual(x)
104 | agg = torch.einsum('vu,nctu->nctv', A, x)
105 | agg = agg.view(N, C, T, self.num_scales, V)
106 | agg = agg.permute(0,3,1,2,4).contiguous().view(N, self.num_scales*C, T, V)
107 | out = self.mlp(agg)
108 | out += res
109 | return self.act(out)
110 |
111 |
--------------------------------------------------------------------------------
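A shape sketch, not part of the repository, for UnfoldTemporalWindows above, with assumed sizes: a 3-frame window is folded into the joint dimension, so the downstream spatial-temporal graph convolution sees window_size * V vertices per window:

import torch
from model.ms_gtcn import UnfoldTemporalWindows

unfold = UnfoldTemporalWindows(window_size=3, window_stride=1, window_dilation=1)
x = torch.randn(8, 64, 30, 25)   # (N, C, T, V)
print(unfold(x).shape)           # torch.Size([8, 64, 30, 75]): 30 windows of 3 * 25 joints
--------------------------------------------------------------------------------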
/model/ms_tcn.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.insert(0, '')
3 |
4 | import torch
5 | import torch.nn as nn
6 |
7 | from model.activation import activation_factory
8 |
9 |
10 | class TemporalConv(nn.Module):
11 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1):
12 | super(TemporalConv, self).__init__()
13 | pad = (kernel_size + (kernel_size-1) * (dilation-1) - 1) // 2
14 | self.conv = nn.Conv2d(
15 | in_channels,
16 | out_channels,
17 | kernel_size=(kernel_size, 1),
18 | padding=(pad, 0),
19 | stride=(stride, 1),
20 | dilation=(dilation, 1))
21 |
22 | self.bn = nn.BatchNorm2d(out_channels)
23 |
24 | def forward(self, x):
25 | x = self.conv(x)
26 | x = self.bn(x)
27 | return x
28 |
29 |
30 | class MultiScale_TemporalConv(nn.Module):
31 | def __init__(self,
32 | in_channels,
33 | out_channels,
34 | kernel_size=3,
35 | stride=1,
36 | dilations=[1,2,3,4],
37 | residual=True,
38 | residual_kernel_size=1,
39 | activation='relu',
40 | **kwargs):
41 |
42 | super().__init__()
43 |         assert out_channels % (len(dilations) + 2) == 0, '# out channels should be a multiple of # branches'
44 |
45 | # Multiple branches of temporal convolution
46 | self.num_branches = len(dilations) + 2
47 | branch_channels = out_channels // self.num_branches
48 |
49 | # Temporal Convolution branches
50 | self.branches = nn.ModuleList([
51 | nn.Sequential(
52 | nn.Conv2d(
53 | in_channels,
54 | branch_channels,
55 | kernel_size=1,
56 | padding=0),
57 | nn.BatchNorm2d(branch_channels),
58 | activation_factory(activation),
59 |                 # Convolve over the temporal axis for each joint
60 | TemporalConv(
61 | branch_channels,
62 | branch_channels,
63 | kernel_size=kernel_size,
64 | stride=stride,
65 | dilation=dilation),
66 | )
67 | for dilation in dilations
68 | ])
69 |
70 | # Additional Max & 1x1 branch
71 | self.branches.append(nn.Sequential(
72 | nn.Conv2d(in_channels, branch_channels, kernel_size=1, padding=0),
73 | nn.BatchNorm2d(branch_channels),
74 | activation_factory(activation),
75 | nn.MaxPool2d(kernel_size=(3,1), stride=(stride,1), padding=(1,0)),
76 | nn.BatchNorm2d(branch_channels)
77 | ))
78 |
79 | self.branches.append(nn.Sequential(
80 | nn.Conv2d(in_channels, branch_channels, kernel_size=1, padding=0, stride=(stride,1)),
81 | nn.BatchNorm2d(branch_channels)
82 | ))
83 |
84 | # Residual connection
85 | if not residual:
86 | self.residual = lambda x: 0
87 | elif (in_channels == out_channels) and (stride == 1):
88 | self.residual = lambda x: x
89 | else:
90 | self.residual = TemporalConv(in_channels, out_channels, kernel_size=residual_kernel_size, stride=stride)
91 |
92 | self.act = activation_factory(activation)
93 |
94 | # Transformer attention
95 | if 'to_use_temporal_transformer' in kwargs and kwargs['to_use_temporal_transformer']:
96 | self.to_use_temporal_trans = True
97 | self.section_size = kwargs['section_size']
98 | self.num_point = kwargs['num_point']
99 | self.trans_conv = nn.Conv2d(1, 1, (self.section_size, 1), (self.section_size, 1))
100 | nhead = 5
101 | nlayers = 2
102 | trans_dropout = 0.5
103 | encoder_layers = nn.TransformerEncoderLayer(self.num_point,
104 | nhead=nhead, dropout=trans_dropout)
105 | self.trans_enc = nn.TransformerEncoder(encoder_layers, nlayers)
106 |
107 | # frame normalization
108 | self.frame_norm_layer = nn.Softmax(dim=1)
109 | if 'frame_norm' in kwargs:
110 | if kwargs['frame_norm'] == 'sigmoid':
111 | self.frame_norm_layer = nn.Sigmoid()
112 |
113 | else:
114 | self.to_use_temporal_trans = False
115 |
116 | # Transformer feature
117 | if 'to_use_temp_trans_feature' in kwargs and kwargs['to_use_temp_trans_feature']:
118 | self.to_use_temp_trans_feature = True
119 | self.fea_dim = kwargs['fea_dim']
120 | nhead = kwargs['temp_trans_feature_n_head']
121 | nlayers = kwargs['temp_trans_feature_n_layer']
122 | trans_dropout = 0.5
123 | encoder_layers = nn.TransformerEncoderLayer(self.fea_dim,
124 | nhead=nhead, dropout=trans_dropout)
125 | self.trans_enc_fea = nn.TransformerEncoder(encoder_layers, nlayers)
126 | else:
127 | self.to_use_temp_trans_feature = False
128 |
129 | def forward(self, x):
130 | # Input dim: (N,C,T,V)
131 | res = self.residual(x)
132 | branch_outs = []
133 | tempconv_idx = 0
134 | for tempconv in self.branches:
135 | x_in = x
136 | if self.to_use_temporal_trans:
137 | x_mean = torch.mean(x, dim=1)
138 | x_mean = x_mean.unsqueeze(1)
139 | x_mean = self.trans_conv(x_mean).squeeze(1)
140 | x_mean = self.trans_enc(x_mean)
141 | x_mean = self.frame_norm_layer(x_mean)
142 | x_mean = torch.repeat_interleave(x_mean, self.section_size, dim=1)
143 | x_mean = torch.unsqueeze(x_mean, dim=1).repeat(1, x.shape[1], 1, 1)
144 | x_in = x * x_mean
145 | out = tempconv(x_in)
146 | branch_outs.append(out)
147 |
148 | out = torch.cat(branch_outs, dim=1)
149 |
150 | if self.to_use_temp_trans_feature:
151 | out = out.permute(3, 2, 0, 1)
152 | for a_out_idx in range(len(out)):
153 | a_out = out[a_out_idx]
154 | a_out = self.trans_enc_fea(a_out)
155 | out[a_out_idx] = a_out
156 | out = out.permute(2, 3, 1, 0)
157 |
158 | out += res
159 | out = self.act(out)
160 |
161 | return out
162 |
163 |
164 | if __name__ == "__main__":
165 | mstcn = MultiScale_TemporalConv(288, 288)
166 | x = torch.randn(32, 288, 100, 20)
167 | mstcn.forward(x)
168 | for name, param in mstcn.named_parameters():
169 | print(f'{name}: {param.numel()}')
170 | print(sum(p.numel() for p in mstcn.parameters() if p.requires_grad))
--------------------------------------------------------------------------------
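MultiScale_TemporalConv requires out_channels to be divisible by the number of branches, i.e. len(dilations) + 2 (the dilated temporal-conv branches plus the max-pool and 1x1 branches). A small sketch of that arithmetic with the default dilations and the 288 channels used in the __main__ block above (the concrete numbers are only illustrative):

    dilations = [1, 2, 3, 4]
    num_branches = len(dilations) + 2      # 4 dilated branches + max-pool branch + 1x1 branch = 6
    out_channels = 288
    assert out_channels % num_branches == 0
    branch_channels = out_channels // num_branches
    print(num_branches, branch_channels)   # 6 48

--------------------------------------------------------------------------------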
/model/network.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import sys
3 |
4 | from model.att_gcn import Att_GraphConv
5 | from model.hyper_gcn import Hyper_GraphConv
6 | # from model.transformers import get_pretrained_transformer
7 |
8 | sys.path.insert(0, '')
9 |
10 | import math
11 | import numpy as np
12 | import torch
13 | import torch.nn as nn
14 | import torch.nn.functional as F
15 |
16 | from utils import import_class, count_params
17 | from model.ms_gcn import MultiScale_GraphConv as MS_GCN
18 | from model.ms_tcn import MultiScale_TemporalConv as MS_TCN
19 | from model.ms_gtcn import SpatialTemporal_MS_GCN, UnfoldTemporalWindows
20 | from model.mlp import MLP
21 | from model.activation import activation_factory
22 |
23 |
24 | class MS_G3D(nn.Module):
25 | def __init__(self,
26 | in_channels,
27 | out_channels,
28 | A_binary,
29 | num_scales,
30 | window_size,
31 | window_stride,
32 | window_dilation,
33 | embed_factor=1,
34 | nonlinear='relu'):
35 | super().__init__()
36 |
37 | self.window_size = window_size
38 | self.out_channels = out_channels
39 | self.embed_channels_in = self.embed_channels_out = out_channels // embed_factor
40 | if embed_factor == 1:
41 | self.in1x1 = nn.Identity()
42 | self.embed_channels_in = self.embed_channels_out = in_channels
43 | # The first STGC block changes channels right away; others change at collapse
44 | if in_channels == 3:
45 | self.embed_channels_out = out_channels
46 | else:
47 | self.in1x1 = MLP(in_channels, [self.embed_channels_in])
48 |
49 | self.gcn3d = nn.Sequential(
50 | UnfoldTemporalWindows(window_size, window_stride, window_dilation),
51 | SpatialTemporal_MS_GCN(
52 | in_channels=self.embed_channels_in,
53 | out_channels=self.embed_channels_out,
54 | A_binary=A_binary,
55 | num_scales=num_scales,
56 | window_size=window_size,
57 | use_Ares=True,
58 | activation=nonlinear
59 | )
60 | )
61 |
62 | self.out_conv = nn.Conv3d(self.embed_channels_out, out_channels, kernel_size=(1, self.window_size, 1))
63 | self.out_bn = nn.BatchNorm2d(out_channels)
64 |
65 | def forward(self, x):
66 | N, _, T, V = x.shape
67 | x = self.in1x1(x)
68 | # Construct temporal windows and apply MS-GCN
69 | x = self.gcn3d(x)
70 |
71 | # Collapse the window dimension
72 | x = x.view(N, self.embed_channels_out, -1, self.window_size, V)
73 | x = self.out_conv(x).squeeze(dim=3)
74 | x = self.out_bn(x)
75 |
76 | # no activation
77 | return x
78 |
79 |
80 | class MultiWindow_MS_G3D(nn.Module):
81 | def __init__(self,
82 | in_channels,
83 | out_channels,
84 | A_binary,
85 | num_scales,
86 | window_sizes=[3, 5],
87 | window_stride=1,
88 | window_dilations=[1, 1]):
89 | super().__init__()
90 | self.gcn3d = nn.ModuleList([
91 | MS_G3D(
92 | in_channels,
93 | out_channels,
94 | A_binary,
95 | num_scales,
96 | window_size,
97 | window_stride,
98 | window_dilation
99 | )
100 | for window_size, window_dilation in zip(window_sizes, window_dilations)
101 | ])
102 |
103 | def forward(self, x):
104 | # Input shape: (N, C, T, V)
105 | out_sum = 0
106 | for gcn3d in self.gcn3d:
107 | out_sum += gcn3d(x)
108 | # no activation
109 | return out_sum
110 |
111 | ntu_bone_angle_pairs = {
112 | 25: (24, 12),
113 | 24: (25, 12),
114 | 12: (24, 25),
115 | 11: (12, 10),
116 | 10: (11, 9),
117 | 9: (10, 21),
118 | 21: (9, 5),
119 | 5: (21, 6),
120 | 6: (5, 7),
121 | 7: (6, 8),
122 | 8: (23, 22),
123 | 22: (8, 23),
124 | 23: (8, 22),
125 | 3: (4, 21),
126 | 4: (4, 4),
127 | 2: (21, 1),
128 | 1: (17, 13),
129 | 17: (18, 1),
130 | 18: (19, 17),
131 | 19: (20, 18),
132 | 20: (20, 20),
133 | 13: (1, 14),
134 | 14: (13, 15),
135 | 15: (14, 16),
136 | 16: (16, 16)
137 | }
138 |
139 | ntu_bone_adj = {
140 | 25: 12,
141 | 24: 12,
142 | 12: 11,
143 | 11: 10,
144 | 10: 9,
145 | 9: 21,
146 | 21: 21,
147 | 5: 21,
148 | 6: 5,
149 | 7: 6,
150 | 8: 7,
151 | 22: 8,
152 | 23: 8,
153 | 3: 21,
154 | 4: 3,
155 | 2: 21,
156 | 1: 2,
157 | 17: 1,
158 | 18: 17,
159 | 19: 18,
160 | 20: 19,
161 | 13: 1,
162 | 14: 13,
163 | 15: 14,
164 | 16: 15
165 | }
166 |
167 |
168 | class Model(nn.Module):
169 | def __init__(self,
170 | num_class,
171 | num_point,
172 | num_person,
173 | num_gcn_scales,
174 | num_g3d_scales,
175 | graph,
176 | in_channels=3,
177 | ablation='original',
178 | to_use_final_fc=True,
179 | to_fc_last=True,
180 | frame_len=300,
181 | nonlinear='relu',
182 | **kwargs):
183 | super(Model, self).__init__()
184 |
185 | # cosine
186 | self.cos = nn.CosineSimilarity(dim=1, eps=0)
187 |
188 | # Activation function
189 | self.nonlinear_f = activation_factory(nonlinear)
190 |
191 | # ZQ ablation studies
192 | self.ablation = ablation
193 |
194 | Graph = import_class(graph)
195 | A_binary = Graph().A_binary
196 |
197 | self.data_bn = nn.BatchNorm1d(num_person * in_channels * num_point)
198 |
199 | # channels
200 | # c1 = 96
201 | # c2 = c1 * 2 # 192
202 | # c3 = c2 * 2 # 384
203 |
204 | c1 = 96
205 | self.c1 = c1
206 | c2 = c1 * 2 # 192 # Original implementation
207 | self.c2 = c2
208 | c3 = c2 * 2 # 384 # Original implementation
209 | self.c3 = c3
210 |
211 | # r=3 STGC blocks
212 |
213 | # MSG3D
214 | self.gcn3d1 = MultiWindow_MS_G3D(in_channels, c1, A_binary, num_g3d_scales, window_stride=1)
215 | self.sgcn1_msgcn = MS_GCN(num_gcn_scales, in_channels, c1, A_binary, disentangled_agg=True,
216 | **kwargs, temporal_len=frame_len, fea_dim=c1, to_use_hyper_conv=True,
217 | activation=nonlinear)
218 | self.sgcn1_ms_tcn_1 = MS_TCN(c1, c1, activation=nonlinear)
219 | self.sgcn1_ms_tcn_2 = MS_TCN(c1, c1, activation=nonlinear)
220 | self.sgcn1_ms_tcn_2.act = nn.Identity()
221 |
222 | if 'to_use_temporal_transformer' in kwargs and kwargs['to_use_temporal_transformer']:
223 | self.tcn1 = MS_TCN(c1, c1, **kwargs,
224 | section_size=kwargs['section_sizes'][0], num_point=num_point,
225 | fea_dim=c1, activation=nonlinear)
226 | else:
227 | self.tcn1 = MS_TCN(c1, c1, **kwargs, fea_dim=c1, activation=nonlinear)
228 |
229 | # MSG3D
230 | self.gcn3d2 = MultiWindow_MS_G3D(c1, c2, A_binary, num_g3d_scales, window_stride=2)
231 | self.sgcn2_msgcn = MS_GCN(num_gcn_scales, c1, c1, A_binary, disentangled_agg=True,
232 | **kwargs, temporal_len=frame_len, fea_dim=c1, activation=nonlinear)
233 | self.sgcn2_ms_tcn_1 = MS_TCN(c1, c2, stride=2, activation=nonlinear)
234 | # self.sgcn2_ms_tcn_1 = MS_TCN(c1, c2, activation=nonlinear)
235 | self.sgcn2_ms_tcn_2 = MS_TCN(c2, c2, activation=nonlinear)
236 | self.sgcn2_ms_tcn_2.act = nn.Identity()
237 |
238 | if 'to_use_temporal_transformer' in kwargs and kwargs['to_use_temporal_transformer']:
239 | self.tcn2 = MS_TCN(c2, c2, **kwargs,
240 | section_size=kwargs['section_sizes'][1], num_point=num_point,
241 | fea_dim=c2, activation=nonlinear)
242 | else:
243 | self.tcn2 = MS_TCN(c2, c2, **kwargs, fea_dim=c2, activation=nonlinear)
244 |
245 | # MSG3D
246 | self.gcn3d3 = MultiWindow_MS_G3D(c2, c3, A_binary, num_g3d_scales, window_stride=2)
247 | self.sgcn3_msgcn = MS_GCN(num_gcn_scales, c2, c2, A_binary, disentangled_agg=True,
248 | **kwargs, temporal_len=frame_len // 2, fea_dim=c2,
249 | activation=nonlinear)
250 | self.sgcn3_ms_tcn_1 = MS_TCN(c2, c3, stride=2, activation=nonlinear)
251 | # self.sgcn3_ms_tcn_1 = MS_TCN(c2, c3, activation=nonlinear)
252 | self.sgcn3_ms_tcn_2 = MS_TCN(c3, c3, activation=nonlinear)
253 | self.sgcn3_ms_tcn_2.act = nn.Identity()
254 |
255 | if 'to_use_temporal_transformer' in kwargs and kwargs['to_use_temporal_transformer']:
256 | self.tcn3 = MS_TCN(c3, c3, **kwargs,
257 | section_size=kwargs['section_sizes'][2], num_point=num_point,
258 | fea_dim=c3, activation=nonlinear)
259 | else:
260 | self.tcn3 = MS_TCN(c3, c3, **kwargs, fea_dim=c3, activation=nonlinear)
261 |
262 | self.use_temporal_transformer = False
263 |
264 | self.to_use_final_fc = to_use_final_fc
265 | if self.to_use_final_fc:
266 | self.fc = nn.Linear(c3, num_class)
267 |
268 | def forward(self, x, set_to_fc_last=True):
269 | # Select channels
270 | x = x[:, :3, :, :]
271 | x = self.preprocessing(x)
272 | # assert 0
273 | N, C, T, V, M = x.size()
274 |
275 | x = x.permute(0, 4, 3, 1, 2).contiguous().view(N, M * V * C, T)
276 | x = self.data_bn(x)
277 | x = x.view(N * M, V, C, T).permute(0, 2, 3, 1).contiguous()
278 |
279 | ###### First Component ######
280 | x = self.sgcn1_msgcn(x)
281 | x = self.sgcn1_ms_tcn_1(x)
282 | x = self.sgcn1_ms_tcn_2(x)
283 | x = self.nonlinear_f(x)
284 | x = self.tcn1(x)
285 | ###### End First Component ######
286 |
287 | ###### Second Component ######
288 | x = self.sgcn2_msgcn(x)
289 | x = self.sgcn2_ms_tcn_1(x)
290 | x = self.sgcn2_ms_tcn_2(x)
291 | x = self.nonlinear_f(x)
292 | x = self.tcn2(x)
293 | ###### End Second Component ######
294 |
295 | ###### Third Component ######
296 | x = self.sgcn3_msgcn(x)
297 | x = self.sgcn3_ms_tcn_1(x)
298 | x = self.sgcn3_ms_tcn_2(x)
299 | x = self.nonlinear_f(x)
300 | x = self.tcn3(x)
301 | ###### End Third Component ######
302 |
303 | out = x
304 |
305 | out_channels = out.size(1)
306 |
307 | t_dim = out.shape[2]
308 | out = out.view(N, M, out_channels, t_dim, -1)
309 | out = out.permute(0, 1, 3, 4, 2) # N, M, T, V, C
310 | out = out.mean(3) # Global Average Pooling (Spatial)
311 |
312 | out = out.mean(2) # Global Average Pooling (Temporal)
313 | out = out.mean(1) # Average pool number of bodies in the sequence
314 |
315 | if set_to_fc_last:
316 | if self.to_use_final_fc:
317 | out = self.fc(out)
318 |
319 | other_outs = {}
320 | return out, other_outs
321 |
322 | def preprocessing(self, x):
323 | # Extract Bone and Angular Features
324 | fp_sp_joint_list_bone = []
325 | fp_sp_joint_list_bone_angle = []
326 | fp_sp_joint_list_body_center_angle_1 = []
327 | fp_sp_joint_list_body_center_angle_2 = []
328 | fp_sp_left_hand_angle = []
329 | fp_sp_right_hand_angle = []
330 | fp_sp_two_hand_angle = []
331 | fp_sp_two_elbow_angle = []
332 | fp_sp_two_knee_angle = []
333 | fp_sp_two_feet_angle = []
334 |
335 | all_list = [
336 | fp_sp_joint_list_bone, fp_sp_joint_list_bone_angle, fp_sp_joint_list_body_center_angle_1,
337 | fp_sp_joint_list_body_center_angle_2, fp_sp_left_hand_angle, fp_sp_right_hand_angle,
338 | fp_sp_two_hand_angle, fp_sp_two_elbow_angle, fp_sp_two_knee_angle,
339 | fp_sp_two_feet_angle
340 | ]
341 |
342 | for a_key in ntu_bone_angle_pairs:
343 | a_angle_value = ntu_bone_angle_pairs[a_key]
344 | a_bone_value = ntu_bone_adj[a_key]
345 | the_joint = a_key - 1
346 | a_adj = a_bone_value - 1
347 | bone_diff = (x[:, :3, :, the_joint, :] -
348 | x[:, :3, :, a_adj, :]).unsqueeze(3).cpu()
349 | fp_sp_joint_list_bone.append(bone_diff)
350 |
351 | # bone angles
352 | v1 = a_angle_value[0] - 1
353 | v2 = a_angle_value[1] - 1
354 | vec1 = x[:, :3, :, v1, :] - x[:, :3, :, the_joint, :]
355 | vec2 = x[:, :3, :, v2, :] - x[:, :3, :, the_joint, :]
356 | angular_feature = (1.0 - self.cos(vec1, vec2))
357 | angular_feature[angular_feature != angular_feature] = 0
358 | fp_sp_joint_list_bone_angle.append(angular_feature.unsqueeze(2).unsqueeze(1).cpu())
359 |
360 | # body angles 1
361 | vec1 = x[:, :3, :, 2 - 1, :] - x[:, :3, :, the_joint, :]
362 | vec2 = x[:, :3, :, 21 - 1, :] - x[:, :3, :, the_joint, :]
363 | angular_feature = (1.0 - self.cos(vec1, vec2))
364 | angular_feature[angular_feature != angular_feature] = 0
365 | fp_sp_joint_list_body_center_angle_1.append(angular_feature.unsqueeze(2).unsqueeze(1).cpu())
366 |
367 | # body angles 2
368 | vec1 = x[:, :3, :, the_joint, :] - x[:, :3, :, 21 - 1, :]
369 | vec2 = x[:, :3, :, 2 - 1, :] - x[:, :3, :, 21 - 1, :]
370 | angular_feature = (1.0 - self.cos(vec1, vec2))
371 | angular_feature[angular_feature != angular_feature] = 0
372 | fp_sp_joint_list_body_center_angle_2.append(angular_feature.unsqueeze(2).unsqueeze(1).cpu())
373 |
374 | # left hand angle
375 | vec1 = x[:, :3, :, 24 - 1, :] - x[:, :3, :, the_joint, :]
376 | vec2 = x[:, :3, :, 25 - 1, :] - x[:, :3, :, the_joint, :]
377 | angular_feature = (1.0 - self.cos(vec1, vec2))
378 | angular_feature[angular_feature != angular_feature] = 0
379 | fp_sp_left_hand_angle.append(angular_feature.unsqueeze(2).unsqueeze(1).cpu())
380 |
381 | # right hand angle
382 | vec1 = x[:, :3, :, 22 - 1, :] - x[:, :3, :, the_joint, :]
383 | vec2 = x[:, :3, :, 23 - 1, :] - x[:, :3, :, the_joint, :]
384 | angular_feature = (1.0 - self.cos(vec1, vec2))
385 | angular_feature[angular_feature != angular_feature] = 0
386 | fp_sp_right_hand_angle.append(angular_feature.unsqueeze(2).unsqueeze(1).cpu())
387 |
388 | # two hand angle
389 | vec1 = x[:, :3, :, 24 - 1, :] - x[:, :3, :, the_joint, :]
390 | vec2 = x[:, :3, :, 22 - 1, :] - x[:, :3, :, the_joint, :]
391 | angular_feature = (1.0 - self.cos(vec1, vec2))
392 | angular_feature[angular_feature != angular_feature] = 0
393 | fp_sp_two_hand_angle.append(angular_feature.unsqueeze(2).unsqueeze(1).cpu())
394 |
395 | # two elbow angle
396 | vec1 = x[:, :3, :, 10 - 1, :] - x[:, :3, :, the_joint, :]
397 | vec2 = x[:, :3, :, 6 - 1, :] - x[:, :3, :, the_joint, :]
398 | angular_feature = (1.0 - self.cos(vec1, vec2))
399 | angular_feature[angular_feature != angular_feature] = 0
400 | fp_sp_two_elbow_angle.append(angular_feature.unsqueeze(2).unsqueeze(1).cpu())
401 |
402 | # two knee angle
403 | vec1 = x[:, :3, :, 18 - 1, :] - x[:, :3, :, the_joint, :]
404 | vec2 = x[:, :3, :, 14 - 1, :] - x[:, :3, :, the_joint, :]
405 | angular_feature = (1.0 - self.cos(vec1, vec2))
406 | angular_feature[angular_feature != angular_feature] = 0
407 | fp_sp_two_knee_angle.append(angular_feature.unsqueeze(2).unsqueeze(1).cpu())
408 |
409 | # two feet angle
410 | vec1 = x[:, :3, :, 20 - 1, :] - x[:, :3, :, the_joint, :]
411 | vec2 = x[:, :3, :, 16 - 1, :] - x[:, :3, :, the_joint, :]
412 | angular_feature = (1.0 - self.cos(vec1, vec2))
413 | angular_feature[angular_feature != angular_feature] = 0
414 | fp_sp_two_feet_angle.append(angular_feature.unsqueeze(2).unsqueeze(1).cpu())
415 |
416 | for a_list_id in range(len(all_list)):
417 | all_list[a_list_id] = torch.cat(all_list[a_list_id], dim=3)
418 |
419 | all_list = torch.cat(all_list, dim=1)
420 | # print('All_list:', all_list.shape)
421 |
422 | features = torch.cat((x, all_list.cuda()), dim=1)
423 | # print('features:', features.shape)
424 | return features
425 |
426 |
427 | if __name__ == "__main__":
428 | # For debugging purposes
429 | import sys
430 |
431 | sys.path.append('..')
432 |
433 | model = Model(
434 | num_class=60,
435 | num_point=25,
436 | num_person=2,
437 | num_gcn_scales=13,
438 | num_g3d_scales=6,
439 | graph='graph.ntu_rgb_d.AdjMatrixGraph'
440 | )
441 |
442 | N, C, T, V, M = 6, 3, 50, 25, 2
443 | x = torch.randn(N, C, T, V, M)
444 | model.forward(x)
445 |
446 | print('Model total # params:', count_params(model))
447 |
--------------------------------------------------------------------------------
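Model.preprocessing above encodes each angular feature as 1 - cos(vec1, vec2), where vec1 and vec2 are the two bone vectors meeting at a joint, and NaNs (from zero-length bones) are set to 0. A self-contained sketch of that computation for a single joint (joint indices and tensor sizes are assumed, purely for illustration):

    import torch

    cos = torch.nn.CosineSimilarity(dim=1, eps=0)
    x = torch.randn(4, 3, 50, 25, 2)                 # (N, C=3, T, V=25, M), assumed sizes
    the_joint, v1, v2 = 11, 23, 24                   # 0-based joint indices, illustrative only

    vec1 = x[:, :3, :, v1, :] - x[:, :3, :, the_joint, :]
    vec2 = x[:, :3, :, v2, :] - x[:, :3, :, the_joint, :]
    angular_feature = 1.0 - cos(vec1, vec2)          # 0 when the two bones point the same way
    angular_feature[angular_feature != angular_feature] = 0   # NaN -> 0
    print(angular_feature.shape)                     # torch.Size([4, 50, 2]) = (N, T, M)

--------------------------------------------------------------------------------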
/notification/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/notification/__init__.py
--------------------------------------------------------------------------------
/notification/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/notification/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/notification/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/notification/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/notification/__pycache__/email_config.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/notification/__pycache__/email_config.cpython-36.pyc
--------------------------------------------------------------------------------
/notification/__pycache__/email_config.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/notification/__pycache__/email_config.cpython-37.pyc
--------------------------------------------------------------------------------
/notification/__pycache__/email_sender.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/notification/__pycache__/email_sender.cpython-36.pyc
--------------------------------------------------------------------------------
/notification/__pycache__/email_sender.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/notification/__pycache__/email_sender.cpython-37.pyc
--------------------------------------------------------------------------------
/notification/__pycache__/email_templates.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/notification/__pycache__/email_templates.cpython-36.pyc
--------------------------------------------------------------------------------
/notification/__pycache__/email_templates.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/notification/__pycache__/email_templates.cpython-37.pyc
--------------------------------------------------------------------------------
/notification/__pycache__/html_templates.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/notification/__pycache__/html_templates.cpython-36.pyc
--------------------------------------------------------------------------------
/notification/__pycache__/html_templates.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/notification/__pycache__/html_templates.cpython-37.pyc
--------------------------------------------------------------------------------
/notification/email_config.py:
--------------------------------------------------------------------------------
1 | EMAIL_S_ADDRESS = 'YOUR SENDER EMAIL ADDRESS'
2 | PASSWORD = 'YOUR PASSWORD'
3 | EMAIL_R_ADDRESS = 'YOUR RECEIVER EMAIL ADDRESS'
4 |
--------------------------------------------------------------------------------
/notification/email_sender.py:
--------------------------------------------------------------------------------
1 | import smtplib
2 |
3 | from notification.email_templates import *
4 | from .email_config import *
5 | import traceback
6 | from email.mime.text import MIMEText
7 | from email.mime.multipart import MIMEMultipart
8 |
9 | from .html_templates import *
10 |
11 |
12 | def send_email(receivers, email_type, msg_content):
13 |     return  # email notifications are disabled; remove this early return to enable sending
14 | template_1 = None
15 | sub_email = None
16 | if email_type == 'exp_end':
17 | template_1 = exp_complete_1
18 | template_2 = exp_complete_2
19 | sub_email = 'Watchtower at Warrumbul: Experiment Complete'
20 | elif email_type == 'test_end':
21 | template_1 = exp_progress_1
22 | template_2 = exp_progress_2
23 | sub_email = 'Watchtower at Warrumbul: Experiment Progress'
24 | elif email_type == 'error':
25 | template_1 = emergency_1
26 | template_2 = emergency_2
27 | sub_email = 'Watchtower at Warrumbul: Emergency'
28 | else:
29 | raise NotImplementedError
30 |
31 | for a_person in receivers:
32 |
33 | def send_email(msg):
34 | try:
35 | server = smtplib.SMTP('smtp-relay.sendinblue.com', port=587)
36 | server.ehlo()
37 | server.starttls()
38 | server.login(EMAIL_S_ADDRESS, PASSWORD)
39 | message = msg
40 | server.sendmail(EMAIL_S_ADDRESS, a_person, message)
41 | server.quit()
42 | except Exception:
43 | traceback.print_exc()
44 |
45 | message = MIMEMultipart("alternative")
46 | if sub_email is not None:
47 | message["Subject"] = sub_email
48 | else:
49 |             message["Subject"] = "实验进展: 加急公文 御赐金牌 马上飞递"  # "Experiment progress: urgent dispatch, deliver immediately"
50 | message["From"] = EMAIL_S_ADDRESS
51 | message["To"] = a_person
52 |
53 | html = template_1 + msg_content + template_2
54 |
55 | part2 = MIMEText(html, 'html', 'utf-8')
56 |
57 | # message.attach(part1)
58 | message.attach(part2)
59 |
60 | send_email(message.as_string())
61 |
62 | print(f'Success: Email sent to {a_person}')
63 |
64 |
65 | if __name__ == '__main__':
66 |     test_msg = '这是一个测试. '  # "This is a test."
67 | send_email(receivers=['zhenyue.qin@anu.edu.au'],
68 | email_type='error',
69 | msg_content=test_msg)
--------------------------------------------------------------------------------
/notification/email_test.py:
--------------------------------------------------------------------------------
1 | from notification.email_sender import send_email
2 |
3 |
4 | test_msg = '这是一个测试. '  # "This is a test."
5 | send_email(receivers=['zhenyue.qin@anu.edu.au'],
6 | email_type='exp_end',
7 | msg_content=test_msg)
8 |
--------------------------------------------------------------------------------
/processor/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/processor/__init__.py
--------------------------------------------------------------------------------
/processor/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/processor/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/processor/__pycache__/args.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/processor/__pycache__/args.cpython-36.pyc
--------------------------------------------------------------------------------
/processor/args.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import random
3 | import os
4 |
5 |
6 | def str2bool(v):
7 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
8 | return True
9 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
10 | return False
11 | else:
12 | raise argparse.ArgumentTypeError('Boolean value expected.')
13 |
14 |
15 | def get_parser():
16 | # parameter priority: command line > config file > default
17 | parser = argparse.ArgumentParser(description='MS-G3D')
18 |
19 | parser.add_argument(
20 | '--work-dir',
21 | type=str,
22 | required=False,
23 | default='.work_dir/unknown_path',
24 | help='the work folder for storing results')
25 | parser.add_argument('--model_saved_name', default='')
26 | parser.add_argument(
27 | '--config',
28 | default='./config/nturgbd-cross1-view/test_bone.yaml',
29 | help='path to the configuration file')
30 | parser.add_argument(
31 | '--assume-yes',
32 | action='store_true',
33 | help='Say yes to every prompt')
34 |
35 | parser.add_argument(
36 | '--phase',
37 | default='train',
38 | help='must be train or test')
39 | parser.add_argument(
40 | '--save-score',
41 | type=str2bool,
42 | default=False,
43 |         help='if true, the classification score will be stored')
44 |
45 | parser.add_argument(
46 | '--seed',
47 | type=int,
48 | default=0,
49 | help='random seed')
50 | parser.add_argument(
51 | '--log-interval',
52 | type=int,
53 | default=1,
54 | help='the interval for printing messages (#iteration)')
55 | parser.add_argument(
56 | '--save-interval',
57 | type=int,
58 | default=1,
59 | help='the interval for storing models (#iteration)')
60 | parser.add_argument(
61 | '--eval-interval',
62 | type=int,
63 | default=5,
64 | help='the interval for evaluating models (#iteration)')
65 | parser.add_argument(
66 | '--eval-start',
67 | type=int,
68 | default=1,
69 | help='The epoch number to start evaluating models')
70 | parser.add_argument(
71 | '--print-log',
72 | type=str2bool,
73 | default=True,
74 | help='print logging or not')
75 | parser.add_argument(
76 | '--show-topk',
77 | type=int,
78 | default=[1, 5],
79 | nargs='+',
80 | help='which Top K accuracy will be shown')
81 |
82 | parser.add_argument(
83 | '--feeder',
84 | default='feeder.feeder',
85 | help='data loader will be used')
86 | parser.add_argument(
87 | '--num-worker',
88 | type=int,
89 | default=os.cpu_count(),
90 |         help='the number of workers for the data loader')
91 | parser.add_argument(
92 | '--train-feeder-args',
94 | default=dict(),
95 | help='the arguments of data loader for training')
96 | parser.add_argument(
97 | '--test-feeder-args',
98 | default=dict(),
99 | help='the arguments of data loader for test')
100 |
101 | parser.add_argument(
102 | '--model',
103 | default=None,
104 | help='the model will be used')
105 | parser.add_argument(
106 | '--model-args',
107 | type=dict,
108 | default=dict(),
109 | help='the arguments of model')
110 | parser.add_argument(
111 | '--weights',
112 | default=None,
113 | help='the weights for network initialization')
114 | parser.add_argument(
115 | '--ignore-weights',
116 | type=str,
117 | default=[],
118 | nargs='+',
119 | help='the name of weights which will be ignored in the initialization')
120 | parser.add_argument(
121 | '--half',
122 | action='store_true',
123 | help='Use half-precision (FP16) training')
124 | parser.add_argument(
125 | '--amp-opt-level',
126 | type=int,
127 | default=1,
128 | help='NVIDIA Apex AMP optimization level')
129 |
130 | parser.add_argument(
131 | '--base-lr',
132 | type=float,
133 | default=0.01,
134 | help='initial learning rate')
135 | parser.add_argument(
136 | '--step',
137 | type=int,
138 | default=[20, 40, 60],
139 | nargs='+',
140 |         help='the epochs at which the optimizer reduces the learning rate')
141 | parser.add_argument(
142 | '--lr_decay',
143 | type=float,
144 | default=0.1,
145 |         help='learning rate decay factor'
146 | )
147 | parser.add_argument(
148 | '--device',
149 | type=int,
150 | default=0,
151 | nargs='+',
152 | help='the indexes of GPUs for training or testing')
153 | parser.add_argument(
154 | '--optimizer',
155 | default='SGD',
156 | help='type of optimizer')
157 | parser.add_argument(
158 | '--nesterov',
159 | type=str2bool,
160 | default=False,
161 | help='use nesterov or not')
162 | parser.add_argument(
163 | '--batch-size',
164 | type=int,
165 | default=32,
166 | help='training batch size')
167 | parser.add_argument(
168 | '--test-batch-size',
169 | type=int,
170 | default=256,
171 | help='test batch size')
172 | parser.add_argument(
173 | '--forward-batch-size',
174 | type=int,
175 | default=16,
176 |         help='Batch size during forward pass; must be a factor of --batch-size')
177 | parser.add_argument(
178 | '--start-epoch',
179 | type=int,
180 | default=0,
181 | help='start training from which epoch')
182 | parser.add_argument(
183 | '--num-epoch',
184 | type=int,
185 | default=80,
186 | help='stop training in which epoch')
187 | parser.add_argument(
188 | '--weight-decay',
189 | type=float,
190 | default=0.0005,
191 | help='weight decay for optimizer')
192 | parser.add_argument(
193 | '--optimizer-states',
194 | type=str,
195 | help='path of previously saved optimizer states')
196 | parser.add_argument(
197 | '--checkpoint',
198 | type=str,
199 | help='path of previously saved training checkpoint')
200 | parser.add_argument(
201 | '--debug',
202 | type=str2bool,
203 | default=False,
204 | help='Debug mode; default false')
205 |
206 | # ZQ resume training
207 | parser.add_argument(
208 | '--resume',
209 | type=str2bool,
210 | default=False,
211 | help='resume previous training'
212 | )
213 |
214 | parser.add_argument(
215 | '--tbatch',
216 | type=int,
217 | default=128,
218 | help='batch size for transformers'
219 | )
220 |
221 | parser.add_argument(
222 | '--train_print_freq',
223 | type=int,
224 | default=100,
225 | help='training printing frequency'
226 | )
227 |
228 | # one hot
229 | parser.add_argument(
230 | '--to_add_onehot',
231 |         type=str2bool,
232 | default=False,
233 | help='to add one hot in the input data'
234 | )
235 |
236 | # feature selection
237 | parser.add_argument(
238 | '--feature_combo',
239 | type=str,
240 | default='',
241 | help='what features to use'
242 | )
243 |
244 | parser.add_argument(
245 | '--additional_loss',
246 | type=dict,
247 | default=dict(),
248 | )
249 |
250 | parser.add_argument( # encode data
251 | '--encoding_args',
252 | default=dict(),
253 | )
254 |
255 | return parser
256 |
--------------------------------------------------------------------------------
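get_parser() only defines defaults; as in processor/io.py below, the intended priority is command line > config file > default, realised by loading the YAML config and pushing it through set_defaults before the final parse. A hedged sketch of that pattern (the config path and merge details are only an example, not the repo's main.py verbatim):

    import yaml
    from processor.args import get_parser

    parser = get_parser()
    p = parser.parse_args()                      # first pass picks up --config plus CLI overrides
    if p.config is not None:
        with open(p.config, 'r') as f:
            default_arg = yaml.safe_load(f)      # YAML keys must match parser argument names
        parser.set_defaults(**default_arg)
    args = parser.parse_args()                   # CLI > config > built-in defaults

--------------------------------------------------------------------------------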
/processor/io.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import argparse
4 | import yaml
5 | import numpy as np
6 |
7 | import torch
8 | import torch.nn as nn
9 |
10 | import torchlight
11 | from torchlight.torchlight.gpu import visible_gpu, occupy_gpu
12 | from torchlight.torchlight.io import str2bool
13 | from torchlight.torchlight.io import DictAction
14 | from torchlight.torchlight.io import import_class
15 |
16 |
17 | class IO():
18 |
19 | def __init__(self, argv=None):
20 |
21 | self.load_arg(argv)
22 | self.init_environment()
23 | self.load_model()
24 | self.load_weights()
25 | self.gpu()
26 |
27 | def load_arg(self, argv=None):
28 | parser = self.get_parser()
29 |
30 |         # load arg from config file
31 | p = parser.parse_args(argv)
32 | if p.config is not None:
33 | # load config file
34 | with open(p.config, 'r') as f:
35 |                 default_arg = yaml.load(f, Loader=yaml.FullLoader)
36 |
37 | # update parser from config file
38 | key = vars(p).keys()
39 | for k in default_arg.keys():
40 | if k not in key:
41 | print('Unknown Arguments: {}'.format(k))
42 | assert k in key
43 |
44 | parser.set_defaults(**default_arg)
45 |
46 | self.arg = parser.parse_args(argv)
47 |
48 | def init_environment(self):
49 | self.save_dir = os.path.join(self.arg.work_dir,
50 | self.arg.max_hop_dir,
51 | self.arg.lamda_act_dir)
52 | self.io = torchlight.torchlight.io.IO(self.save_dir, save_log=self.arg.save_log, print_log=self.arg.print_log)
53 | self.io.save_arg(self.arg)
54 |
55 | # gpu
56 | if self.arg.use_gpu:
57 | gpus = visible_gpu(self.arg.device)
58 | occupy_gpu(gpus)
59 | self.gpus = gpus
60 | self.dev = "cuda:0"
61 | else:
62 | self.dev = "cpu"
63 |
64 | def load_model(self):
65 | self.model1 = self.io.load_model(self.arg.model1, **(self.arg.model1_args))
66 | self.model2 = self.io.load_model(self.arg.model2, **(self.arg.model2_args))
67 |
68 | def load_weights(self):
69 | if self.arg.weights1:
70 | self.model1 = self.io.load_weights(self.model1, self.arg.weights1, self.arg.ignore_weights)
71 | self.model2 = self.io.load_weights(self.model2, self.arg.weights2, self.arg.ignore_weights)
72 |
73 | def gpu(self):
74 | # move modules to gpu
75 | self.model1 = self.model1.to(self.dev)
76 | self.model2 = self.model2.to(self.dev)
77 | for name, value in vars(self).items():
78 | cls_name = str(value.__class__)
79 | if cls_name.find('torch.nn.modules') != -1:
80 | setattr(self, name, value.to(self.dev))
81 |
82 | # model parallel
83 | if self.arg.use_gpu and len(self.gpus) > 1:
84 | self.model1 = nn.DataParallel(self.model1, device_ids=self.gpus)
85 | self.model2 = nn.DataParallel(self.model2, device_ids=self.gpus)
86 |
87 | def start(self):
88 | self.io.print_log('Parameters:\n{}\n'.format(str(vars(self.arg))))
89 |
90 | @staticmethod
91 | def get_parser(add_help=False):
92 |
93 | #region arguments yapf: disable
94 | # parameter priority: command line > config > default
95 | parser = argparse.ArgumentParser( add_help=add_help, description='IO Processor')
96 |
97 | parser.add_argument('-w', '--work_dir', default='./work_dir/tmp', help='the work folder for storing results')
98 | parser.add_argument('-c', '--config', default=None, help='path to the configuration file')
99 |
100 | # processor
101 | parser.add_argument('--use_gpu', type=str2bool, default=True, help='use GPUs or not')
102 | parser.add_argument('--device', type=int, default=0, nargs='+', help='the indexes of GPUs for training or testing')
103 |
104 |         # visualize and debug
105 | parser.add_argument('--print_log', type=str2bool, default=True, help='print logging or not')
106 | parser.add_argument('--save_log', type=str2bool, default=True, help='save logging or not')
107 |
108 | # model
109 | parser.add_argument('--model1', default=None, help='the model will be used')
110 | parser.add_argument('--model2', default=None, help='the model will be used')
111 | parser.add_argument('--model1_args', action=DictAction, default=dict(), help='the arguments of model')
112 | parser.add_argument('--model2_args', action=DictAction, default=dict(), help='the arguments of model')
113 | parser.add_argument('--weights', default=None, help='the weights for network initialization')
114 | parser.add_argument('--ignore_weights', type=str, default=[], nargs='+', help='the name of weights which will be ignored in the initialization')
115 | #endregion yapf: enable
116 |
117 | return parser
118 |
--------------------------------------------------------------------------------
/processor/processor.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import argparse
3 | import yaml
4 | import numpy as np
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.optim as optim
9 |
10 | import torchlight
11 | from torchlight.torchlight.io import str2bool
12 | from torchlight.torchlight.io import DictAction
13 | from torchlight.torchlight.io import import_class
14 |
15 | from .io import IO
16 |
17 |
18 | class Processor(IO):
19 |
20 | def __init__(self, argv=None):
21 |
22 | self.load_arg(argv)
23 | self.init_environment()
24 | self.load_model()
25 | self.load_weights()
26 | self.gpu()
27 | self.load_data()
28 |
29 | def init_environment(self):
30 |
31 | super().init_environment()
32 | self.result = dict()
33 | self.iter_info = dict()
34 | self.epoch_info = dict()
35 | self.meta_info = dict(epoch=0, iter=0)
36 |
37 |
38 | def load_data(self):
39 | Feeder = import_class(self.arg.feeder)
40 | if 'debug' not in self.arg.train_feeder_args:
41 | self.arg.train_feeder_args['debug'] = self.arg.debug
42 | self.data_loader = dict()
43 | if self.arg.phase == 'train':
44 | self.data_loader['train'] = torch.utils.data.DataLoader(dataset=Feeder(**self.arg.train_feeder_args),
45 | batch_size=self.arg.batch_size,
46 | shuffle=True,
47 | num_workers=self.arg.num_worker,
48 | drop_last=True)
49 | if self.arg.test_feeder_args:
50 | self.data_loader['test'] = torch.utils.data.DataLoader(dataset=Feeder(**self.arg.test_feeder_args),
51 | batch_size=self.arg.test_batch_size,
52 | shuffle=False,
53 | num_workers=self.arg.num_worker)
54 |
55 | def show_epoch_info(self):
56 | for k, v in self.epoch_info.items():
57 | self.io.print_log('\t{}: {}'.format(k, v))
58 | if self.arg.pavi_log:
59 | self.io.log('train', self.meta_info['iter'], self.epoch_info)
60 |
61 | def show_iter_info(self):
62 | if self.meta_info['iter'] % self.arg.log_interval == 0:
63 | info ='\tIter {} Done.'.format(self.meta_info['iter'])
64 | for k, v in self.iter_info.items():
65 | if isinstance(v, float):
66 | info = info + ' | {}: {:.4f}'.format(k, v)
67 | else:
68 | info = info + ' | {}: {}'.format(k, v)
69 |
70 | self.io.print_log(info)
71 |
72 | if self.arg.pavi_log:
73 | self.io.log('train', self.meta_info['iter'], self.iter_info)
74 |
75 |     def train(self, training_A=False):
76 | for _ in range(100):
77 | self.iter_info['loss'] = 0
78 | self.iter_info['loss_class'] = 0
79 | self.iter_info['loss_recon'] = 0
80 | self.show_iter_info()
81 | self.meta_info['iter'] += 1
82 | self.epoch_info['mean_loss'] = 0
83 | self.epoch_info['mean_loss_class'] = 0
84 | self.epoch_info['mean_loss_recon'] = 0
85 | self.show_epoch_info()
86 |
87 |     def test(self, testing_A=False, save_feature=False):
88 | for _ in range(100):
89 | self.iter_info['loss'] = 1
90 | self.iter_info['loss_class'] = 1
91 | self.iter_info['loss_recon'] = 1
92 | self.show_iter_info()
93 | self.epoch_info['mean_loss'] = 1
94 | self.epoch_info['mean_loss_class'] = 1
95 | self.epoch_info['mean_loss_recon'] = 1
96 | self.show_epoch_info()
97 |
98 | def start(self):
99 | self.arg.eval_interval = 5
100 |
101 | self.io.print_log('Parameters:\n{}\n'.format(str(vars(self.arg))))
102 |
103 | if self.arg.phase == 'train':
104 | for epoch in range(self.arg.start_epoch, self.arg.num_epoch):
105 | self.meta_info['epoch'] = epoch
106 |
107 | if epoch < 10:
108 | self.io.print_log('Training epoch: {}'.format(epoch))
109 | self.train(training_A=True)
110 | self.io.print_log('Done.')
111 | else:
112 | self.io.print_log('Training epoch: {}'.format(epoch))
113 | self.train(training_A=False)
114 | self.io.print_log('Done.')
115 |
116 | # save model
117 | if ((epoch + 1) % self.arg.save_interval == 0) or (epoch + 1 == self.arg.num_epoch):
118 | filename1 = 'epoch{}_model1.pt'.format(epoch)
119 | self.io.save_model(self.model1, filename1)
120 | filename2 = 'epoch{}_model2.pt'.format(epoch)
121 | self.io.save_model(self.model2, filename2)
122 |
123 | # evaluation
124 | if ((epoch + 1) % self.arg.eval_interval == 0) or (epoch + 1 == self.arg.num_epoch):
125 | self.io.print_log('Eval epoch: {}'.format(epoch))
126 | if epoch <= 10:
127 | self.test(testing_A=True)
128 | else:
129 | self.test(testing_A=False)
130 | self.io.print_log('Done.')
131 |
132 |
133 | elif self.arg.phase == 'test':
134 | if self.arg.weights2 is None:
135 |                 raise ValueError('Please specify --weights2.')
136 | self.io.print_log('Model: {}.'.format(self.arg.model2))
137 | self.io.print_log('Weights: {}.'.format(self.arg.weights2))
138 |
139 | self.io.print_log('Evaluation Start:')
140 | self.test(testing_A=False, save_feature=True)
141 | self.io.print_log('Done.\n')
142 |
143 | if self.arg.save_result:
144 | result_dict = dict(
145 | zip(self.data_loader['test'].dataset.sample_name,
146 | self.result))
147 | self.io.save_pkl(result_dict, 'test_result.pkl')
148 |
149 |
150 | @staticmethod
151 | def get_parser(add_help=False):
152 |
153 | parser = argparse.ArgumentParser( add_help=add_help, description='Base Processor')
154 |
155 | parser.add_argument('-w', '--work_dir', default='./work_dir/tmp', help='the work folder for storing results')
156 | parser.add_argument('-c', '--config', default=None, help='path to the configuration file')
157 |
158 | parser.add_argument('--phase', default='train', help='must be train or test')
159 |         parser.add_argument('--save_result', type=str2bool, default=False, help='if true, the output of the model will be stored')
160 | parser.add_argument('--start_epoch', type=int, default=0, help='start training from which epoch')
161 | parser.add_argument('--num_epoch', type=int, default=80, help='stop training in which epoch')
162 | parser.add_argument('--use_gpu', type=str2bool, default=True, help='use GPUs or not')
163 | parser.add_argument('--device', type=int, default=0, nargs='+', help='the indexes of GPUs for training or testing')
164 |
165 | parser.add_argument('--log_interval', type=int, default=100, help='the interval for printing messages (#iteration)')
166 | parser.add_argument('--save_interval', type=int, default=1, help='the interval for storing models (#iteration)')
167 | parser.add_argument('--eval_interval', type=int, default=5, help='the interval for evaluating models (#iteration)')
168 | parser.add_argument('--save_log', type=str2bool, default=True, help='save logging or not')
169 | parser.add_argument('--print_log', type=str2bool, default=True, help='print logging or not')
170 | parser.add_argument('--pavi_log', type=str2bool, default=False, help='logging on pavi or not')
171 |
172 | parser.add_argument('--feeder', default='feeder.feeder', help='data loader will be used')
173 |         parser.add_argument('--num_worker', type=int, default=4, help='the number of workers per GPU for the data loader')
174 | parser.add_argument('--train_feeder_args', action=DictAction, default=dict(), help='the arguments of data loader for training')
175 | parser.add_argument('--test_feeder_args', action=DictAction, default=dict(), help='the arguments of data loader for test')
176 | parser.add_argument('--batch_size', type=int, default=256, help='training batch size')
177 | parser.add_argument('--test_batch_size', type=int, default=256, help='test batch size')
178 | parser.add_argument('--debug', action="store_true", help='less data, faster loading')
179 |
180 | parser.add_argument('--model1', default=None, help='the model will be used')
181 | parser.add_argument('--model2', default=None, help='the model will be used')
182 | parser.add_argument('--model1_args', action=DictAction, default=dict(), help='the arguments of model')
183 | parser.add_argument('--model2_args', action=DictAction, default=dict(), help='the arguments of model')
184 | parser.add_argument('--weights1', default=None, help='the weights for network initialization')
185 | parser.add_argument('--weights2', default=None, help='the weights for network initialization')
186 | parser.add_argument('--ignore_weights', type=str, default=[], nargs='+', help='the name of weights which will be ignored in the initialization')
187 |
188 | return parser
189 |
--------------------------------------------------------------------------------
/processor/recognition.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import numpy as np
4 |
5 | import matplotlib
6 | matplotlib.use('Agg')
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.optim as optim
11 |
12 | from processor.torchlight_io import str2bool
13 |
14 | from .processor import Processor
15 |
16 |
17 | def weights_init(m):
18 | classname = m.__class__.__name__
19 | if classname.find('Conv1d') != -1:
20 | m.weight.data.normal_(0.0, 0.02)
21 | if m.bias is not None:
22 | m.bias.data.fill_(0)
23 | elif classname.find('Conv2d') != -1:
24 | m.weight.data.normal_(0.0, 0.02)
25 | if m.bias is not None:
26 | m.bias.data.fill_(0)
27 | elif classname.find('BatchNorm') != -1:
28 | m.weight.data.normal_(1.0, 0.02)
29 | m.bias.data.fill_(0)
30 |
31 |
32 | class REC_Processor(Processor):
33 |
34 | def load_model(self):
35 | self.model1 = self.io.load_model(self.arg.model1, **(self.arg.model1_args))
36 | self.model1.apply(weights_init)
37 | self.model2 = self.io.load_model(self.arg.model2, **(self.arg.model2_args))
38 |
39 | self.loss_class = nn.CrossEntropyLoss().cuda()
40 | self.loss_pred = nn.MSELoss()
41 | self.w_pred = 0.01
42 |
43 | prior = np.array([0.95, 0.05/2, 0.05/2])
44 | self.log_prior = torch.FloatTensor(np.log(prior))
45 | self.log_prior = torch.unsqueeze(torch.unsqueeze(self.log_prior, 0), 0)
46 |
47 | self.load_optimizer()
48 |
49 | def load_optimizer(self):
50 | if self.arg.optimizer == 'SGD':
51 | self.optimizer1 = optim.SGD(params=self.model1.parameters(),
52 | lr=self.arg.base_lr1,
53 | momentum=0.9,
54 | nesterov=self.arg.nesterov,
55 | weight_decay=self.arg.weight_decay)
56 | elif self.arg.optimizer == 'Adam':
57 | self.optimizer1 = optim.Adam(params=self.model1.parameters(),
58 | lr=self.arg.base_lr1,
59 | weight_decay=self.arg.weight_decay)
60 | else:
61 | raise ValueError()
62 | self.optimizer2 = optim.Adam(params=self.model2.parameters(),
63 | lr=self.arg.base_lr2)
64 |
65 | def adjust_lr(self):
66 | if self.arg.optimizer == 'SGD' and self.arg.step:
67 | lr = self.arg.base_lr1 * (0.1**np.sum(self.meta_info['epoch']>= np.array(self.arg.step)))
68 | for param_group in self.optimizer1.param_groups:
69 | param_group['lr'] = lr
70 | self.lr = lr
71 | else:
72 | self.lr = self.arg.base_lr1
73 | self.lr2 = self.arg.base_lr2
74 |
75 | def nll_gaussian(self, preds, target, variance, add_const=False):
76 | neg_log_p = ((preds-target)**2/(2*variance))
77 | if add_const:
78 | const = 0.5*np.log(2*np.pi*variance)
79 | neg_log_p += const
80 | return neg_log_p.sum() / (target.size(0) * target.size(1))
81 |
82 | def kl_categorical(self, preds, log_prior, num_node, eps=1e-16):
83 | kl_div = preds*(torch.log(preds+eps)-log_prior)
84 | return kl_div.sum()/(num_node*preds.size(0))
85 |
86 |
87 | def train(self, training_A=False):
88 | self.model1.train()
89 | self.model2.train()
90 | self.adjust_lr()
91 | loader = self.data_loader['train']
92 | loss1_value = []
93 | loss_class_value = []
94 | loss_recon_value = []
95 | loss2_value = []
96 | loss_nll_value = []
97 | loss_kl_value = []
98 |
99 | if training_A:
100 | for param1 in self.model1.parameters():
101 | param1.requires_grad = False
102 | for param2 in self.model2.parameters():
103 | param2.requires_grad = True
104 | self.iter_info.clear()
105 | self.epoch_info.clear()
106 |
107 | for data, data_downsample, target_data, data_last, label in loader:
108 | data = data.float().to(self.dev)
109 | data_downsample = data_downsample.float().to(self.dev)
110 | label = label.long().to(self.dev)
111 |
112 | gpu_id = data.get_device()
113 | self.log_prior = self.log_prior.cuda(gpu_id)
114 | A_batch, prob, outputs, data_target = self.model2(data_downsample)
115 | loss_nll = self.nll_gaussian(outputs, data_target[:,:,1:,:], variance=5e-4)
116 | loss_kl = self.kl_categorical(prob, self.log_prior, num_node=25)
117 | loss2 = loss_nll + loss_kl
118 |
119 | self.optimizer2.zero_grad()
120 | loss2.backward()
121 | self.optimizer2.step()
122 |
123 | self.iter_info['loss2'] = loss2.data.item()
124 | self.iter_info['loss_nll'] = loss_nll.data.item()
125 | self.iter_info['loss_kl'] = loss_kl.data.item()
126 | self.iter_info['lr'] = '{:.6f}'.format(self.lr2)
127 |
128 | loss2_value.append(self.iter_info['loss2'])
129 | loss_nll_value.append(self.iter_info['loss_nll'])
130 | loss_kl_value.append(self.iter_info['loss_kl'])
131 | self.show_iter_info()
132 | self.meta_info['iter'] += 1
133 | self.epoch_info['mean_loss2'] = np.mean(loss2_value)
134 | self.epoch_info['mean_loss_nll'] = np.mean(loss_nll_value)
135 | self.epoch_info['mean_loss_kl'] = np.mean(loss_kl_value)
136 | self.show_epoch_info()
137 | self.io.print_timer()
138 |
139 | else:
140 | for param1 in self.model1.parameters():
141 | param1.requires_grad = True
142 | for param2 in self.model2.parameters():
143 | param2.requires_grad = True
144 | self.iter_info.clear()
145 | self.epoch_info.clear()
146 | for data, data_downsample, target_data, data_last, label in loader:
147 | data = data.float().to(self.dev)
148 | data_downsample = data_downsample.float().to(self.dev)
149 | target_data = target_data.float().to(self.dev)
150 | data_last = data_last.float().to(self.dev)
151 | label = label.long().to(self.dev)
152 |
153 | A_batch, prob, outputs, _ = self.model2(data_downsample)
154 | x_class, pred, target = self.model1(data, target_data, data_last, A_batch, self.arg.lamda_act)
155 | loss_class = self.loss_class(x_class, label)
156 | loss_recon = self.loss_pred(pred, target)
157 | loss1 = loss_class + self.w_pred*loss_recon
158 |
159 | self.optimizer1.zero_grad()
160 | loss1.backward()
161 | self.optimizer1.step()
162 |
163 | self.iter_info['loss1'] = loss1.data.item()
164 | self.iter_info['loss_class'] = loss_class.data.item()
165 | self.iter_info['loss_recon'] = loss_recon.data.item()*self.w_pred
166 | self.iter_info['lr'] = '{:.6f}'.format(self.lr)
167 |
168 | loss1_value.append(self.iter_info['loss1'])
169 | loss_class_value.append(self.iter_info['loss_class'])
170 | loss_recon_value.append(self.iter_info['loss_recon'])
171 | self.show_iter_info()
172 | self.meta_info['iter'] += 1
173 |
174 | self.epoch_info['mean_loss1']= np.mean(loss1_value)
175 | self.epoch_info['mean_loss_class'] = np.mean(loss_class_value)
176 | self.epoch_info['mean_loss_recon'] = np.mean(loss_recon_value)
177 | self.show_epoch_info()
178 | self.io.print_timer()
179 |
180 |
181 | def test(self, evaluation=True, testing_A=False, save=False, save_feature=False):
182 |
183 | self.model1.eval()
184 | self.model2.eval()
185 | loader = self.data_loader['test']
186 | loss1_value = []
187 | loss_class_value = []
188 | loss_recon_value = []
189 | loss2_value = []
190 | loss_nll_value = []
191 | loss_kl_value = []
192 | result_frag = []
193 | label_frag = []
194 |
195 | if testing_A:
196 | A_all = []
197 | self.epoch_info.clear()
198 | for data, data_downsample, target_data, data_last, label in loader:
199 | data = data.float().to(self.dev)
200 | data_downsample = data_downsample.float().to(self.dev)
201 | label = label.long().to(self.dev)
202 |
203 | with torch.no_grad():
204 | A_batch, prob, outputs, data_bn = self.model2(data_downsample)
205 |
206 | if save:
207 | n = A_batch.size(0)
208 | a = A_batch[:int(n/2),:,:,:].cpu().numpy()
209 | A_all.extend(a)
210 |
211 | if evaluation:
212 | gpu_id = data.get_device()
213 | self.log_prior = self.log_prior.cuda(gpu_id)
214 | loss_nll = self.nll_gaussian(outputs, data_bn[:,:,1:,:], variance=5e-4)
215 | loss_kl = self.kl_categorical(prob, self.log_prior, num_node=25)
216 | loss2 = loss_nll + loss_kl
217 |
218 | loss2_value.append(loss2.item())
219 | loss_nll_value.append(loss_nll.item())
220 | loss_kl_value.append(loss_kl.item())
221 |
222 | if save:
223 | A_all = np.array(A_all)
224 | np.save(os.path.join(self.arg.work_dir, 'test_adj.npy'), A_all)
225 |
226 | if evaluation:
227 | self.epoch_info['mean_loss2'] = np.mean(loss2_value)
228 | self.epoch_info['mean_loss_nll'] = np.mean(loss_nll_value)
229 | self.epoch_info['mean_loss_kl'] = np.mean(loss_kl_value)
230 | self.show_epoch_info()
231 |
232 | else:
233 | recon_data = []
234 | feature_map = []
235 | self.epoch_info.clear()
236 | for data, data_downsample, target_data, data_last, label in loader:
237 | data = data.float().to(self.dev)
238 | data_downsample = data_downsample.float().to(self.dev)
239 | target_data = target_data.float().to(self.dev)
240 | data_last = data_last.float().to(self.dev)
241 | label = label.long().to(self.dev)
242 |
243 | with torch.no_grad():
244 | A_batch, prob, outputs, _ = self.model2(data_downsample)
245 | x_class, pred, target = self.model1(data, target_data, data_last, A_batch, self.arg.lamda_act)
246 | result_frag.append(x_class.data.cpu().numpy())
247 |
248 | if save:
249 | n = pred.size(0)
250 | p = pred[::2,:,:,:].cpu().numpy()
251 | recon_data.extend(p)
252 |
253 | if evaluation:
254 | loss_class = self.loss_class(x_class, label)
255 | loss_recon = self.loss_pred(pred, target)
256 | loss1 = loss_class + self.w_pred*loss_recon
257 |
258 | loss1_value.append(loss1.item())
259 | loss_class_value.append(loss_class.item())
260 | loss_recon_value.append(loss_recon.item())
261 | label_frag.append(label.data.cpu().numpy())
262 |
263 | if save:
264 | recon_data = np.array(recon_data)
265 | np.save(os.path.join(self.arg.work_dir, 'recon_data.npy'), recon_data)
266 |
267 |
268 | self.result = np.concatenate(result_frag)
269 | if evaluation:
270 | self.label = np.concatenate(label_frag)
271 | self.epoch_info['mean_loss1'] = np.mean(loss1_value)
272 | self.epoch_info['mean_loss_class'] = np.mean(loss_class_value)
273 | self.epoch_info['mean_loss_recon'] = np.mean(loss_recon_value)
274 | self.show_epoch_info()
275 |
276 | for k in self.arg.show_topk:
277 | hit_top_k = []
278 | rank = self.result.argsort()
279 | for i,l in enumerate(self.label):
280 | hit_top_k.append(l in rank[i, -k:])
281 | self.io.print_log('\n')
282 | accuracy = sum(hit_top_k)*1.0/len(hit_top_k)
283 | self.io.print_log('\tTop{}: {:.2f}%'.format(k, 100 * accuracy))
284 |
285 |
286 |
287 | @staticmethod
288 | def get_parser(add_help=False):
289 |
290 | parent_parser = Processor.get_parser(add_help=False)
291 | parser = argparse.ArgumentParser(
292 | add_help=add_help,
293 | parents=[parent_parser],
294 | description='Spatial Temporal Graph Convolution Network')
295 |
296 | parser.add_argument('--show_topk', type=int, default=[1, 5], nargs='+', help='which Top K accuracy will be shown')
297 | parser.add_argument('--base_lr1', type=float, default=0.1, help='initial learning rate')
298 | parser.add_argument('--base_lr2', type=float, default=0.0005, help='initial learning rate')
299 |         parser.add_argument('--step', type=int, default=[], nargs='+', help='the epochs at which the optimizer reduces the learning rate')
300 | parser.add_argument('--optimizer', default='SGD', help='type of optimizer')
301 | parser.add_argument('--nesterov', type=str2bool, default=True, help='use nesterov or not')
302 | parser.add_argument('--weight_decay', type=float, default=0.0001, help='weight decay for optimizer')
303 |
304 | parser.add_argument('--max_hop_dir', type=str, default='max_hop_4')
305 | parser.add_argument('--lamda_act', type=float, default=0.5)
306 | parser.add_argument('--lamda_act_dir', type=str, default='lamda_05')
307 |
308 | return parser
309 |
--------------------------------------------------------------------------------
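For reference, the Top-K accuracy loop at the end of test() above is just an argsort over the class scores followed by a membership check on the last k columns. A self-contained toy recreation of that step (the scores and labels below are made up, not model output):

    import numpy as np

    scores = np.array([[0.1, 0.7, 0.2],    # toy class scores for 3 samples, 3 classes
                       [0.5, 0.3, 0.2],
                       [0.2, 0.3, 0.5]])
    labels = np.array([1, 1, 2])

    rank = scores.argsort()                # ascending, so the last k columns are the top-k classes
    for k in (1, 2):
        hit_top_k = [l in rank[i, -k:] for i, l in enumerate(labels)]
        print('Top{}: {:.2f}%'.format(k, 100.0 * sum(hit_top_k) / len(hit_top_k)))
    # Top1: 66.67%, Top2: 100.00%
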
/processor/torchlight_io.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | import argparse
3 | import os
4 | import sys
5 | import traceback
6 | import time
7 | import warnings
8 | import pickle
9 | from collections import OrderedDict
10 | import yaml
11 | import numpy as np
12 | # torch
13 | import torch
14 | import torch.nn as nn
15 | import torch.optim as optim
16 | from torch.autograd import Variable
17 |
18 | with warnings.catch_warnings():
19 | warnings.filterwarnings("ignore", category=FutureWarning)
20 | import h5py
21 |
22 |
23 | class IO():
24 | def __init__(self, work_dir, save_log=True, print_log=True):
25 | self.work_dir = work_dir
26 | self.save_log = save_log
27 | self.print_to_screen = print_log
28 | self.cur_time = time.time()
29 | self.split_timer = {}
30 | self.pavi_logger = None
31 | self.session_file = None
32 | self.model_text = ''
33 |
34 | # PaviLogger is removed in this version
35 | def log(self, *args, **kwargs):
36 | pass
37 |
38 | # try:
39 | # if self.pavi_logger is None:
40 | # from torchpack.runner.hooks import PaviLogger
41 | # url = 'http://pavi.parrotsdnn.org/log'
42 | # with open(self.session_file, 'r') as f:
43 | # info = dict(
44 | # session_file=self.session_file,
45 | # session_text=f.read(),
46 | # model_text=self.model_text)
47 | # self.pavi_logger = PaviLogger(url)
48 | # self.pavi_logger.connect(self.work_dir, info=info)
49 | # self.pavi_logger.log(*args, **kwargs)
50 | # except: #pylint: disable=W0702
51 | # pass
52 |
53 | def load_model(self, model, **model_args):
54 | Model = import_class(model)
55 | model = Model(**model_args)
56 | self.model_text += '\n\n' + str(model)
57 | return model
58 |
59 | def load_weights(self, model, weights_path, ignore_weights=None):
60 | if ignore_weights is None:
61 | ignore_weights = []
62 | if isinstance(ignore_weights, str):
63 | ignore_weights = [ignore_weights]
64 |
65 | self.print_log('Load weights from {}.'.format(weights_path))
66 | weights = torch.load(weights_path)
67 | weights = OrderedDict([[k.split('module.')[-1],
68 | v.cpu()] for k, v in weights.items()])
69 |
70 | # filter weights
71 | for i in ignore_weights:
72 | ignore_name = list()
73 | for w in weights:
74 | if w.find(i) == 0:
75 | ignore_name.append(w)
76 | for n in ignore_name:
77 | weights.pop(n)
78 | self.print_log('Filter [{}] remove weights [{}].'.format(i, n))
79 |
80 | for w in weights:
81 | self.print_log('Load weights [{}].'.format(w))
82 |
83 | try:
84 | model.load_state_dict(weights)
85 | except (KeyError, RuntimeError):
86 | state = model.state_dict()
87 | diff = list(set(state.keys()).difference(set(weights.keys())))
88 | for d in diff:
89 | self.print_log('Can not find weights [{}].'.format(d))
90 | state.update(weights)
91 | model.load_state_dict(state)
92 | return model
93 |
94 | def save_pkl(self, result, filename):
95 | with open('{}/{}'.format(self.work_dir, filename), 'wb') as f:
96 | pickle.dump(result, f)
97 |
98 | def save_h5(self, result, filename):
99 | with h5py.File('{}/{}'.format(self.work_dir, filename), 'w') as f:
100 | for k in result.keys():
101 | f[k] = result[k]
102 |
103 | def save_model(self, model, name):
104 | model_path = '{}/{}'.format(self.work_dir, name)
105 | state_dict = model.state_dict()
106 | weights = OrderedDict([[''.join(k.split('module.')),
107 | v.cpu()] for k, v in state_dict.items()])
108 | torch.save(weights, model_path)
109 | self.print_log('The model has been saved as {}.'.format(model_path))
110 |
111 | def save_arg(self, arg):
112 |
113 | self.session_file = '{}/config.yaml'.format(self.work_dir)
114 |
115 | # save arg
116 | arg_dict = vars(arg)
117 | if not os.path.exists(self.work_dir):
118 | os.makedirs(self.work_dir)
119 | with open(self.session_file, 'w') as f:
120 | f.write('# command line: {}\n\n'.format(' '.join(sys.argv)))
121 | yaml.dump(arg_dict, f, default_flow_style=False, indent=4)
122 |
123 | def print_log(self, str, print_time=True):
124 | if print_time:
125 | # localtime = time.asctime(time.localtime(time.time()))
126 | str = time.strftime("[%m.%d.%y|%X] ", time.localtime()) + str
127 |
128 | if self.print_to_screen:
129 | print(str)
130 | if self.save_log:
131 | with open('{}/log.txt'.format(self.work_dir), 'a') as f:
132 | print(str, file=f)
133 |
134 | def init_timer(self, *name):
135 | self.record_time()
136 | self.split_timer = {k: 0.0000001 for k in name}
137 |
138 | def check_time(self, name):
139 | self.split_timer[name] += self.split_time()
140 |
141 | def record_time(self):
142 | self.cur_time = time.time()
143 | return self.cur_time
144 |
145 | def split_time(self):
146 | split_time = time.time() - self.cur_time
147 | self.record_time()
148 | return split_time
149 |
150 | def print_timer(self):
151 | proportion = {
152 | k: '{:02d}%'.format(int(round(v * 100 / sum(self.split_timer.values()))))
153 | for k, v in self.split_timer.items()
154 | }
155 | self.print_log('Time consumption:')
156 | for k in proportion:
157 | self.print_log(
158 | '\t[{}][{}]: {:.4f}'.format(k, proportion[k], self.split_timer[k])
159 | )
160 |
161 |
162 | def str2bool(v):
163 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
164 | return True
165 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
166 | return False
167 | else:
168 | raise argparse.ArgumentTypeError('Boolean value expected.')
169 |
170 |
171 | def str2dict(v):
172 | return eval('dict({})'.format(v)) # pylint: disable=W0123
173 |
174 |
175 | def _import_class_0(name):
176 | components = name.split('.')
177 | mod = __import__(components[0])
178 | for comp in components[1:]:
179 | mod = getattr(mod, comp)
180 | return mod
181 |
182 |
183 | def import_class(import_str):
184 | mod_str, _sep, class_str = import_str.rpartition('.')
185 | __import__(mod_str)
186 | try:
187 | return getattr(sys.modules[mod_str], class_str)
188 | except AttributeError:
189 | raise ImportError('Class %s cannot be found (%s)' %
190 | (class_str,
191 | traceback.format_exception(*sys.exc_info())))
192 |
193 |
194 | class DictAction(argparse.Action):
195 | def __init__(self, option_strings, dest, nargs=None, **kwargs):
196 | if nargs is not None:
197 | raise ValueError("nargs not allowed")
198 | super(DictAction, self).__init__(option_strings, dest, **kwargs)
199 |
200 | def __call__(self, parser, namespace, values, option_string=None):
201 | input_dict = eval('dict({})'.format(values)) # pylint: disable=W0123
202 | output_dict = getattr(namespace, self.dest)
203 | for k in input_dict:
204 | output_dict[k] = input_dict[k]
205 | setattr(namespace, self.dest, output_dict)
--------------------------------------------------------------------------------
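The helpers at the bottom of torchlight_io.py are meant to be wired into argparse. A small usage sketch, assuming the repository root is on PYTHONPATH (the flag names below are illustrative, not the ones registered in processor/args.py):

    import argparse
    from processor.torchlight_io import str2bool, DictAction, import_class

    parser = argparse.ArgumentParser()
    parser.add_argument('--nesterov', type=str2bool, default=True)           # accepts yes/no, true/false, 1/0
    parser.add_argument('--model_args', action=DictAction, default=dict())   # parses "key=value, key=value"

    args = parser.parse_args(['--nesterov', 'false',
                              '--model_args', 'num_class=60, dropout=0.5'])
    print(args.nesterov)     # False
    print(args.model_args)   # {'num_class': 60, 'dropout': 0.5}

    # import_class resolves a dotted "module.Class" path to the class object
    OrderedDict = import_class('collections.OrderedDict')
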
/train.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | CUDA_VISIBLE_DEVICES=0,1 python main.py \
4 | --config config/train.yaml \
5 | # >> outs_files/22_03_11-ntu120xsub-angular.log 2>&1 &
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import sys
3 | import traceback
4 |
5 |
6 | def import_class(name):
7 | components = name.split('.')
8 | mod = __import__(components[0])
9 | for comp in components[1:]:
10 | mod = getattr(mod, comp)
11 | return mod
12 |
13 |
14 | def import_class_2(import_str):
15 | mod_str, _sep, class_str = import_str.rpartition('.')
16 | __import__(mod_str)
17 | try:
18 | return getattr(sys.modules[mod_str], class_str)
19 | except AttributeError:
20 | raise ImportError('Class %s cannot be found (%s)' %
21 | (class_str,
22 | traceback.format_exception(*sys.exc_info())))
23 |
24 |
25 |
26 | def count_params(model):
27 | return sum(p.numel() for p in model.parameters() if p.requires_grad)
28 |
29 |
30 | def get_current_time():
31 | currentDT = datetime.datetime.now()
32 | return str(currentDT.strftime("%Y-%m-%dT%H-%M-%S"))
33 |
--------------------------------------------------------------------------------
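count_params simply sums the element counts of the trainable parameters, so any nn.Module can be passed in. A quick sanity check with a made-up two-layer network (the layer sizes are arbitrary):

    import torch.nn as nn
    from utils import count_params

    net = nn.Sequential(nn.Linear(75, 128), nn.ReLU(), nn.Linear(128, 60))
    print(count_params(net))   # 17468 = (75*128 + 128) + (128*60 + 60)
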
/utils_dir/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/utils_dir/__init__.py
--------------------------------------------------------------------------------
/utils_dir/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/utils_dir/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/utils_dir/__pycache__/utils_cam.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/utils_dir/__pycache__/utils_cam.cpython-36.pyc
--------------------------------------------------------------------------------
/utils_dir/__pycache__/utils_io.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/utils_dir/__pycache__/utils_io.cpython-36.pyc
--------------------------------------------------------------------------------
/utils_dir/__pycache__/utils_math.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/utils_dir/__pycache__/utils_math.cpython-36.pyc
--------------------------------------------------------------------------------
/utils_dir/__pycache__/utils_result.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/utils_dir/__pycache__/utils_result.cpython-36.pyc
--------------------------------------------------------------------------------
/utils_dir/__pycache__/utils_visual.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/utils_dir/__pycache__/utils_visual.cpython-36.pyc
--------------------------------------------------------------------------------
/utils_dir/utils_cam.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import matplotlib.pyplot as plt
3 | import seaborn as sns
4 |
5 |
6 | def plot_freq(cam_dict):
7 | total_num = 0
8 | for a_key in cam_dict:
9 | a_action_cam = cam_dict[a_key].cpu().numpy()
10 | print('action: ', a_key, 'cam shape: ', a_action_cam.shape)
11 | total_num += a_action_cam.shape[0]
12 | ax = sns.heatmap(a_action_cam)
13 |
14 | plt.savefig(f'test_fields/action_heatmaps/action_{a_key}_heatmap.png', dpi=200,
15 | bbox_inches='tight')
16 | # plt.show()
17 | plt.close()
18 | print('total num: ', total_num)
19 |
--------------------------------------------------------------------------------
/utils_dir/utils_io.py:
--------------------------------------------------------------------------------
1 | import shutil
2 | import os
3 |
4 |
5 | def mv_py_files_to_dir(a_dir, tgt_dir=None):
6 | if tgt_dir is None:
7 | py_dir = os.path.join(a_dir, 'py_dir')
8 | else:
9 | py_dir = tgt_dir
10 | if not os.path.exists(py_dir):
11 | os.makedirs(py_dir)
12 |
13 | # for root, dirs, files in os.walk(a_dir): # copy files
14 | # for file in files:
15 | # if file.endswith(".py"):
16 | # cp_tgt = os.path.join(root, file)
17 | # shutil.copy2(cp_tgt, py_dir)
18 |
19 | for item in os.listdir(a_dir):
20 | s = os.path.join(a_dir, item)
21 | d = os.path.join(py_dir, item)
22 | if os.path.isdir(s):
23 | shutil.copytree(s, d)
24 | else:
25 | shutil.copy2(s, d)
26 |
27 |
--------------------------------------------------------------------------------
/utils_dir/utils_math.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 |
4 |
5 | class Embedder_DCT:
6 | def __init__(self, frm_len, multires, inc_input=True, inc_func='linear'):
7 | self.frm_len = frm_len
8 | self.multires = multires
9 | self.inc_input = inc_input
10 | self.inc_func = inc_func
11 | self.periodic_fns = [torch.cos]
12 |
13 | self.create_embedding_fn()
14 |
15 | def create_embedding_fn(self):
16 | embed_fns = []
17 | if self.inc_input:
18 | embed_fns.append(lambda x, y: x) # with x
19 |
20 | N_freqs = self.multires
21 |
22 | freq_bands = []
23 | for k in range(1, N_freqs+1):
24 | if self.inc_func == 'linear':
25 | a_freq = k
26 | elif self.inc_func == 'exp':
27 | a_freq = 2 ** (k-1)
28 | elif self.inc_func == 'pow':
29 | a_freq = k ** 2
30 | else:
31 | raise NotImplementedError('Unsupported inc_func.')
32 |
33 | freq_bands.append(math.pi / self.frm_len * a_freq) # This is DCT
34 |
35 | freq_bands = torch.tensor(freq_bands)
36 |
37 | for freq in freq_bands:
38 | for p_fn in self.periodic_fns:
39 | embed_fns.append(lambda x, frm_idx, p_fn=p_fn, freq=freq: (x * p_fn(freq * (frm_idx + 1/2)))) # this is DCT
40 |
41 | self.embed_fns = embed_fns
42 |
43 | def embed(self, inputs, dim):
44 | t_len_all = inputs.shape[2]
45 | time_list = []
46 | for t_idx in range(t_len_all):
47 | a_series = inputs[:, :, t_idx, :, :].unsqueeze(2)
48 | # new_time_list = torch.cat([fn(a_series, t_idx) for fn in self.embed_fns], dim) # DCT
49 |
50 | # To try positional encoding
51 | new_time_list = []
52 | for fn in self.embed_fns:
53 | a_new_one = fn(a_series, t_idx)
54 | new_time_list.append(a_new_one)
55 | new_time_list = torch.cat(new_time_list, dim)
56 |
57 | # To sum encodes
58 | # new_time_list = None
59 | # for fn in self.embed_fns:
60 | # if new_time_list is None:
61 | # new_time_list = fn(a_series, t_idx)
62 | # else:
63 | # new_time_list += fn(a_series, t_idx)
64 |
65 | # print('new_time_list: ', new_time_list.squeeze())
66 | time_list.append(new_time_list)
67 | rtn = torch.cat(time_list, 2)
68 | return rtn
69 |
70 |
71 | an_embed = Embedder_DCT(300, 8)
72 |
73 |
74 | def gen_dct_on_the_fly(the_data, K=None):
75 | N, C, T, V, M = the_data.shape
76 | if K is None:
77 | K = T
78 |
79 | the_data = an_embed.embed(the_data, dim=1)
80 | return the_data
81 |
82 |
83 | def dct_2_no_sum_parallel(bch_seq, K0=0, K1=None):
84 | bch_seq_rsp = bch_seq.view(-1, bch_seq.shape[1])
85 | N = bch_seq_rsp.shape[1]
86 | if K1 is None:
87 | K1 = N
88 | basis_list = []
89 | for k in range(K0, K1):
90 | a_basis_list = []
91 | for i in range(N):
92 | a_basis_list.append(math.cos(math.pi / N * (i + 0.5) * k))
93 | basis_list.append(a_basis_list)
94 | basis_list = torch.tensor(basis_list).to(bch_seq_rsp.device)
95 | bch_seq_rsp = bch_seq_rsp.unsqueeze(1).repeat(1, K1 - K0, 1)
96 | dot_prod = torch.einsum('abc,bc->abc', bch_seq_rsp, basis_list)
97 | return dot_prod.view(-1, K1 - K0)
98 |
--------------------------------------------------------------------------------
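Embedder_DCT multiplies every frame t by the DCT-II basis cos(pi / frm_len * k * (t + 0.5)) for k = 1..multires (plus the untouched input when inc_input=True) and concatenates the results along the channel axis. A shape-only sketch, assuming the repository root is importable; the tensor sizes are made up apart from frm_len=300 and multires=8, which mirror the module-level an_embed:

    import torch
    from utils_dir.utils_math import Embedder_DCT

    emb = Embedder_DCT(frm_len=300, multires=8, inc_func='linear')   # identity + 8 cosine bands = 9 embed fns
    x = torch.randn(2, 3, 300, 25, 2)                                # (N, C, T, V, M) skeleton tensor
    y = emb.embed(x, dim=1)
    print(y.shape)                                                   # torch.Size([2, 27, 300, 25, 2]); C grows 3 -> 3*9
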
/utils_dir/utils_result.py:
--------------------------------------------------------------------------------
1 | import json
2 | from collections import defaultdict
3 |
4 | from sklearn.metrics import confusion_matrix
5 |
6 | from metadata.class_labels import ntu120_code_labels, anu_bullying_pair_labels, bly_labels, anubis_ind_actions
7 | # from test_fields.kinetics_analysis import get_kinetics_dict
8 | import numpy as np
9 |
10 |
11 | def get_result_confusion_jsons(gt, pred, data_type, acc_f_name_prefix=None):
12 | if 'ntu' in data_type:
13 | code_labels = ntu120_code_labels
14 | elif 'bly' in data_type:
15 | code_labels = bly_labels
16 | gt = np.array(gt)[:, 0]
17 | elif 'front' in data_type:
18 | code_labels = {}
19 | tmp = {}
20 | for key in anubis_ind_actions:
21 | tmp[anubis_ind_actions[key]] = key
22 | for key in tmp:
23 | code_labels[tmp[key] + 1] = key
24 | # elif 'kinetics' in data_type:
25 | # code_labels = get_kinetics_dict()
26 | else:
27 | raise NotImplementedError
28 |
29 | correct_dict = defaultdict(list)
30 | for idx in range(len(gt)):
31 | correct_dict[gt[idx]].append(int(pred[idx] == gt[idx]))
32 | correct_dict_ = correct_dict.copy()
33 |
34 | for a_key in correct_dict:
35 | correct_dict[a_key] = '{:.6f}'.format(sum(correct_dict[a_key]) / len(correct_dict[a_key]))
36 |
37 | label_acc = {}
38 | for a_key in correct_dict:
39 | label_acc[code_labels[int(a_key) + 1]] = float(correct_dict[a_key])
40 |
41 | label_acc = dict(sorted(label_acc.items(), key=lambda item: item[1]))
42 | label_acc_keys = list(label_acc.keys())
43 |
44 | conf_mat = confusion_matrix(gt, pred)
45 |
46 | most_confused = {}
47 | for i in correct_dict.keys():
48 | confusion_0 = np.argsort(conf_mat[int(i)])[::-1][0]
49 | confusion_1 = np.argsort(conf_mat[int(i)])[::-1][1]
50 |
51 | most_confused[code_labels[int(i) + 1]] = [
52 | "{} {}".format(code_labels[confusion_0 + 1], conf_mat[int(i)][confusion_0]),
53 | "{} {}".format(code_labels[confusion_1 + 1], conf_mat[int(i)][confusion_1]),
54 | "{}".format(len(correct_dict_[i]))
55 | ]
56 | most_confused_ = {}
57 | for i in label_acc_keys:
58 | most_confused_[i] = most_confused[i]
59 |
60 | if acc_f_name_prefix is not None:
61 | with open('{}_confusion_matrix.json'.format(acc_f_name_prefix), 'w') as f:
62 | json.dump(most_confused_, f, indent=4)
63 | with open('{}_accuracy_per_class.json'.format(acc_f_name_prefix), 'w') as f:
64 | json.dump(label_acc, f, indent=4)
65 |
66 | return label_acc, most_confused_
67 |
68 |
--------------------------------------------------------------------------------
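The per-class "most confused" report in get_result_confusion_jsons boils down to sorting each confusion-matrix row. A label-dict-free toy version of that step (the gt/pred lists are invented):

    import numpy as np
    from sklearn.metrics import confusion_matrix

    gt   = [0, 0, 0, 1, 1, 2, 2, 2]
    pred = [0, 1, 1, 1, 0, 2, 2, 0]
    conf_mat = confusion_matrix(gt, pred)

    for c in sorted(set(gt)):
        order = np.argsort(conf_mat[c])[::-1]   # predicted classes for class c, most frequent first
        top1, top2 = order[0], order[1]
        print('class {}: predicted as {} ({}x), then {} ({}x)'.format(
            c, top1, conf_mat[c][top1], top2, conf_mat[c][top2]))
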
/utils_dir/utils_visual.py:
--------------------------------------------------------------------------------
1 | import os
2 | import matplotlib.pyplot as plt
3 | from matplotlib.animation import FuncAnimation
4 | import numpy as np
5 | import torch
6 |
7 | azure_kinect_bone_pairs = (
8 | (1, 0), (2, 1), (3, 2), (4, 2), (5, 4), (6, 5), (7, 6), (8, 7), (9, 8), (10, 7), (11, 2), (12, 11),
9 | (13, 12), (14, 13), (15, 14), (16, 15), (17, 14), (18, 0), (19, 18), (20, 19), (21, 20), (22, 0),
10 | (23, 22), (24, 23), (25, 24), (26, 3), (27, 26), (28, 27), (29, 28), (30, 27), (31, 30)
11 | )
12 |
13 | kinect_v2_bone_pairs = tuple((i - 1, j - 1) for (i, j) in (
14 | (1, 2), (2, 21), (3, 21), (4, 3), (5, 21), (6, 5),
15 | (7, 6), (8, 7), (9, 21), (10, 9), (11, 10), (12, 11),
16 | (13, 1), (14, 13), (15, 14), (16, 15), (17, 1), (18, 17),
17 | (19, 18), (20, 19), (22, 23), (21, 21), (23, 8), (24, 25), (25, 12)
18 | ))
19 |
20 | bone_pair_dict = {
21 | 'azure_kinect': azure_kinect_bone_pairs,
22 | 'kinect_v2': kinect_v2_bone_pairs
23 | }
24 |
25 |
26 | def azure_kinect_post_visualize(frames, save_name=None, sklt_type='azure_kinect'):
27 | min_x, max_x = torch.min(frames[0][0]).item(), torch.max(frames[0][0]).item()
28 | min_y, max_y = torch.min(frames[0][1]).item(), torch.max(frames[0][1]).item()
29 | min_z, max_z = torch.min(frames[0][2]).item(), torch.max(frames[0][2]).item()
30 |
31 | bones = bone_pair_dict[sklt_type]
32 | def animate(skeletons):
33 | # Skeleton shape is 3*25. 3 corresponds to the 3D coordinates. 25 is the number of joints.
34 | ax.clear()
35 |
36 | # ax.set_xlim([-3000, 1000])
37 | # ax.set_ylim([-3000, 1000])
38 | # ax.set_zlim([-3000, 1000])
39 |
40 | ax.set_xlim([min_x, max_x])
41 | ax.set_ylim([min_y, max_y])
42 | ax.set_zlim([min_z, max_z])
43 |
44 | # ax.set_xlim([-2, -1])
45 | # ax.set_ylim([2, 3])
46 | # ax.set_zlim([-2, 0])
47 |
48 | # ax.set_xticklabels([])
49 | # ax.set_yticklabels([])
50 | # ax.set_zticklabels([])
51 |
52 | # person 1
53 | k = 0
54 | color_list = ('blue', 'orange', 'cyan', 'purple')
55 | # color_list = ('blue', 'blue', 'cyan', 'purple')
56 | color_idx = 0
57 |
58 | while k < skeletons.shape[0]:
59 | for i, j in bones:
60 | joint_locs = skeletons[:, [i, j]]
61 | # plot them
62 | ax.plot(joint_locs[k+0], joint_locs[k+1], joint_locs[k+2], color=color_list[color_idx])
63 | # ax.plot(-joint_locs[k+0], -joint_locs[k+2], -joint_locs[k+1], color=color_list[color_idx])
64 |
65 | k += 3
66 | color_idx = (color_idx + 1) % len(color_list)
67 |
68 | # Rotate
69 | # X, Y, Z = axes3d.get_test_data(0.1)
70 | # ax.plot_wireframe(X, Y, Z, rstride=5, cstride=5)
71 | #
72 | # # rotate the axes and update
73 | # for angle in range(0, 360):
74 | # ax.view_init(30, angle)
75 |
76 | if save_name is None:
77 | title = 'Action Visualization'
78 | else:
79 | title = os.path.split(save_name)[-1]
80 | plt.title(title)
81 | skeleton_index[0] += 1
82 | return ax
83 |
84 | for an_entry in range(1):
85 | if isinstance(an_entry, tuple) and len(an_entry) == 2:
86 | index = int(an_entry[0])
87 | pred_idx = int(an_entry[1])
88 | else:
89 | index = an_entry
90 | # get data
91 | skeletons = np.copy(frames[index])
92 |
93 | fig = plt.figure()
94 |         ax = fig.add_subplot(111, projection='3d')  # fig.gca(projection=...) was removed in newer Matplotlib
95 | # ax.set_xlim([-1, 1])
96 | # ax.set_ylim([-1, 1])
97 | # ax.set_zlim([-1, 1])
98 |
99 | # print(f'Sample index: {index}\nAction: {action_class}-{action_name}\n') # (C,T,V,M)
100 |
101 | # Pick the first body to visualize
102 | skeleton1 = skeletons[..., 0] # out (C,T,V)
103 | # make it shorter
104 | shorter_frame_start = 0
105 | shorter_frame_end = 300
106 | if skeletons.shape[-1] > 1:
107 | skeleton2 = np.copy(skeletons[..., 1]) # out (C,T,V)
108 | # make it shorter
109 | skeleton2 = skeleton2[:, shorter_frame_start:shorter_frame_end, :]
110 | # print('max of skeleton 2: ', np.max(skeleton2))
111 | skeleton_frames_2 = skeleton2.transpose(1, 0, 2)
112 | else:
113 | skeleton_frames_2 = None
114 |
115 | skeleton_index = [0]
116 | skeleton_frames_1 = skeleton1[:, shorter_frame_start:shorter_frame_end, :].transpose(1, 0, 2)
117 |
118 | if skeleton_frames_2 is None:
119 | # skeleton_frames_1 = center_normalize_skeleton(skeleton_frames_1)
120 | ani = FuncAnimation(fig, animate,
121 | skeleton_frames_1,
122 | interval=150)
123 | else:
124 | # skeleton_frames_1 = center_normalize_skeleton(skeleton_frames_1)
125 | # skeleton_frames_2 = center_normalize_skeleton(skeleton_frames_2)
126 | ani = FuncAnimation(fig, animate,
127 | np.concatenate((skeleton_frames_1, skeleton_frames_2), axis=1),
128 | interval=150)
129 |
130 | if save_name is None:
131 | save_name = 'tmp_skeleton_video_2.mp4'
132 | print('save name: ', save_name)
133 | ani.save(save_name, dpi=200, writer='ffmpeg')
134 | plt.close('all')
135 |
136 |
137 | def plot_multiple_lines(lines, save_name=None, labels=None, every_n=1):
138 | font = {'size': 30}
139 | import matplotlib
140 | matplotlib.rc('font', **font)
141 |
142 | if labels is not None:
143 | assert len(lines) == len(labels)
144 | markers = ['^', '.', '*']
145 | # colors = ['#dd0100', '#225095', '#fac901']
146 | colors = ['#dd0100']
147 |
148 | # plt.xlim(0, len(lines[0])) # Chronological loss value
149 | # plt.ylim(-0.25, 1.25) # Chronological loss value
150 |
151 | plt.xlim(0, len(lines[0]))
152 | plt.ylim(0, 9)
153 |
154 | x_axis_list = list([x for x in range(0, len(lines[0]), every_n)])
155 | for line_idx, a_line in enumerate(lines):
156 | a_line_plot = list([a_line[i] for i in range(0, len(lines[0]), every_n)])
157 | plt.plot(x_axis_list, a_line_plot,
158 | color=colors[line_idx % len(colors)],
159 | marker=markers[line_idx % len(markers)],
160 | markersize=30,
161 | label=labels[line_idx] if labels is not None else None)
162 |
163 | # plt.xticks(x_axis_list, a_line_plot)
164 | plt.grid()
165 | # plt.legend(loc=4) # Chronological loss value
166 |
167 | fig = matplotlib.pyplot.gcf()
168 |
169 | # fig.set_size_inches(7, 10) # Chronological loss value
170 | fig.set_size_inches(10, 8)
171 |
172 | if save_name is None:
173 | plt.show()
174 | else:
175 | plt.savefig(save_name, bbox_inches='tight')
176 | plt.close()
177 | plt.show()
178 |
--------------------------------------------------------------------------------
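plot_multiple_lines hard-codes its y-range to (0, 9) and cycles through the marker/colour lists, so it is essentially a loss-curve plotter. A minimal call, assuming matplotlib is available and the repository root is importable (the curve and file name are made up):

    import numpy as np
    from utils_dir.utils_visual import plot_multiple_lines

    loss = 9.0 * np.exp(-np.arange(50) / 15.0)    # toy decaying loss, stays inside the hard-coded (0, 9) y-range
    plot_multiple_lines([loss], save_name='loss_curve.png',
                        labels=['train loss'], every_n=5)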