├── .gitignore ├── .idea ├── .gitignore ├── Angular.iml ├── inspectionProfiles │ ├── Project_Default.xml │ └── profiles_settings.xml ├── misc.xml ├── modules.xml └── vcs.xml ├── README.md ├── __init__.py ├── __pycache__ └── utils.cpython-36.pyc ├── ang_adjs.py ├── class_labels.py ├── config ├── test.yaml └── train.yaml ├── encoding ├── __pycache__ │ └── data_encoder.cpython-36.pyc └── data_encoder.py ├── feeders ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-37.pyc │ ├── feeder.cpython-36.pyc │ ├── feeder.cpython-37.pyc │ ├── feeder_as_gcn.cpython-36.pyc │ ├── feeder_as_gcn.cpython-37.pyc │ ├── feeder_dgnn.cpython-36.pyc │ ├── feeder_dgnn.cpython-37.pyc │ ├── tools.cpython-36.pyc │ └── tools.cpython-37.pyc ├── feeder.py ├── feeder_as_gcn.py ├── feeder_dgnn.py └── tools.py ├── figures ├── Architecture.png ├── angle.png └── skeletons.png ├── graph ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-37.pyc │ ├── ang_adjs.cpython-36.pyc │ ├── ang_adjs.cpython-37.pyc │ ├── azure_kinect.cpython-36.pyc │ ├── azure_kinect.cpython-37.pyc │ ├── directed_ntu_rgb_d.cpython-36.pyc │ ├── directed_ntu_rgb_d.cpython-37.pyc │ ├── hyper_graphs.cpython-36.pyc │ ├── hyper_graphs.cpython-37.pyc │ ├── kinetics.cpython-36.pyc │ ├── kinetics.cpython-37.pyc │ ├── ntu_rgb_d.cpython-36.pyc │ ├── ntu_rgb_d.cpython-37.pyc │ ├── tools.cpython-36.pyc │ └── tools.cpython-37.pyc ├── ang_adjs.py ├── azure_kinect.py ├── directed_ntu_rgb_d.py ├── hyper_graphs.py ├── kinetics.py ├── ntu_rgb_d.py └── tools.py ├── main.py ├── model ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── activation.cpython-36.pyc │ ├── att_gcn.cpython-36.pyc │ ├── hyper_gcn.cpython-36.pyc │ ├── mlp.cpython-36.pyc │ ├── modules.cpython-36.pyc │ ├── ms_gcn.cpython-36.pyc │ ├── ms_gtcn.cpython-36.pyc │ ├── ms_tcn.cpython-36.pyc │ └── network.cpython-36.pyc ├── activation.py ├── att_gcn.py ├── hyper_gcn.py ├── mlp.py ├── modules.py ├── ms_gcn.py ├── ms_gtcn.py ├── ms_tcn.py └── network.py ├── notification ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-37.pyc │ ├── email_config.cpython-36.pyc │ ├── email_config.cpython-37.pyc │ ├── email_sender.cpython-36.pyc │ ├── email_sender.cpython-37.pyc │ ├── email_templates.cpython-36.pyc │ ├── email_templates.cpython-37.pyc │ ├── html_templates.cpython-36.pyc │ └── html_templates.cpython-37.pyc ├── email_config.py ├── email_sender.py ├── email_templates.py ├── email_test.py └── html_templates.py ├── processor ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ └── args.cpython-36.pyc ├── args.py ├── io.py ├── processor.py ├── recognition.py └── torchlight_io.py ├── train.sh ├── utils.py └── utils_dir ├── __init__.py ├── __pycache__ ├── __init__.cpython-36.pyc ├── utils_cam.cpython-36.pyc ├── utils_io.cpython-36.pyc ├── utils_math.cpython-36.pyc ├── utils_result.cpython-36.pyc └── utils_visual.cpython-36.pyc ├── utils_cam.py ├── utils_io.py ├── utils_math.py ├── utils_result.py └── utils_visual.py /.gitignore: -------------------------------------------------------------------------------- 1 | work_dir 2 | apex -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | # Datasource local storage ignored files 5 | /dataSources/ 6 | /dataSources.local.xml 7 | # Editor-based HTTP Client requests 8 | 
/httpRequests/ 9 | -------------------------------------------------------------------------------- /.idea/Angular.iml: -------------------------------------------------------------------------------- /.idea/inspectionProfiles/Project_Default.xml: -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | 3 | 4 | # Angular Encoding for Skeleton-Based Action Recognition 5 | 6 | 7 | 8 | 9 | ## Overview 10 | 11 | 12 |

13 | 14 | PyTorch implementation of "TNNLS 2022: Fusing Higher-Order Features in Graph Neural Networks for Skeleton-Based Action Recognition". 15 | (https://arxiv.org/pdf/2105.01563.pdf). 16 | 17 |
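The higher-order features fused into the network are angles spanned by skeleton joints (see "Angular Features" below). As a minimal illustrative sketch of that idea — not the repository's exact feature definitions, and with a hypothetical helper name and joint indices — the cosine of the angle at a joint can be computed from two bone vectors and stacked as extra input channels:

```python
import numpy as np

def joint_angle_cosine(skeleton, center, end1, end2):
    # skeleton: (V, 3) array of 3D joint coordinates for one frame.
    # Returns the cosine of the angle at `center` spanned by the bones
    # center->end1 and center->end2.
    b1 = skeleton[end1] - skeleton[center]
    b2 = skeleton[end2] - skeleton[center]
    denom = np.linalg.norm(b1) * np.linalg.norm(b2) + 1e-6  # guard against zero-length bones
    return float(np.dot(b1, b2) / denom)

# Hypothetical example: the right-elbow angle between upper arm and forearm
# (0-based indices assuming the 25-joint NTU RGB+D layout).
frame = np.random.randn(25, 3)
cos_elbow = joint_angle_cosine(frame, center=9, end1=8, end2=10)
```

Channels encoded this way are concatenated with the raw 3D coordinates before being fed to the graph network, which is presumably why `in_channels` in the configs below is larger than 3.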

18 | 19 | ## Angular Features 20 | 21 | 22 | 23 | ## Network Architecture 24 | 25 | 26 | 27 | ## Dependencies 28 | 29 | - Python >= 3.6 30 | - PyTorch >= 1.2.0 31 | - [NVIDIA Apex](https://github.com/NVIDIA/apex) (auto mixed precision training) 32 | - PyYAML, tqdm, tensorboardX, matplotlib, seaborn 33 | 34 | ## Data Preparation 35 | 36 | ### Download Datasets 37 | 38 | There are 2 datasets to download: 39 | - NTU RGB+D 60 Skeleton 40 | - NTU RGB+D 120 Skeleton 41 | 42 | Request the datasets here: http://rose1.ntu.edu.sg/Datasets/actionRecognition.asp 43 | 44 | ### Data Preprocessing 45 | 46 | #### Directory Structure 47 | 48 | Put downloaded data into the following directory structure: 49 | 50 | ``` 51 | - data/ 52 | - nturgbd_raw/ 53 | - nturgb+d_skeletons/ # from `nturgbd_skeletons_s001_to_s017.zip` 54 | ... 55 | - nturgb+d_skeletons120/ # from `nturgbd_skeletons_s018_to_s032.zip` 56 | ``` 57 | 58 | #### Generating Data 59 | 60 | - `cd data_gen` 61 | - `python3 ntu_gendata.py` 62 | - `python3 ntu120_gendata.py` 63 | - This can take hours. Better CPUs lead to much faster processing. 64 | 65 | ## Training 66 | ``` 67 | bash train.sh 68 | ``` 69 | 70 | ## Testing 71 | ``` 72 | bash test.sh 73 | ``` 74 | 75 | ## Acknowledgements 76 | 77 | This repo is based on 78 | - [MS-G3D](https://github.com/kenziyuliu/ms-g3d) 79 | - [2s-AGCN](https://github.com/lshiwjx/2s-AGCN) 80 | - [ST-GCN](https://github.com/yysijie/st-gcn) 81 | 82 | Thanks to the original authors for their work! 83 | 84 | The flat icon is from [Freepik](https://www.freepik.com/). 85 | 86 | ## Citation 87 | 88 | Please cite this work if you find it useful: 89 | 90 | ``` 91 | @article{DBLP:journals/corr/abs-2105-01563, 92 | author = {Zhenyue Qin and Yang Liu and Pan Ji and Dongwoo Kim and Lei Wang and 93 | Bob McKay and Saeed Anwar and Tom Gedeon}, 94 | title = {Fusing Higher-Order Features in Graph Neural Networks for Skeleton-based Action Recognition}, 95 | journal = {IEEE Transactions on Neural Networks and Learning Systems (TNNLS)}, 96 | year = {2022} 97 | } 98 | ``` 99 | 100 | 101 | ## Contact 102 | If you have further question, please email `zhenyue.qin@anu.edu.au` or `yang.liu3@anu.edu.au`. 103 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | from . import tools 2 | from . import ntu_rgb_d 3 | from . import kinetics 4 | from . import azure_kinect 5 | from . 
import directed_ntu_rgb_d 6 | -------------------------------------------------------------------------------- /__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /ang_adjs.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from graph.tools import normalize_adjacency_matrix 5 | 6 | 7 | def get_ang_adjs(data_type): 8 | rtn_adjs = [] 9 | 10 | if data_type == 'ntu': 11 | node_num = 25 12 | sym_pairs = ((21, 2), (11, 7), (18, 14), (20, 16), (24, 25), (22, 23)) 13 | for a_sym_pair in sym_pairs: 14 | a_adj = np.eye(node_num) 15 | a_adj[:, a_sym_pair[0]-1] = 1 16 | a_adj[:, a_sym_pair[1]-1] = 1 17 | a_adj = torch.tensor(normalize_adjacency_matrix(a_adj)) 18 | rtn_adjs.append(a_adj) 19 | 20 | return torch.cat(rtn_adjs, dim=0) 21 | -------------------------------------------------------------------------------- /class_labels.py: -------------------------------------------------------------------------------- 1 | anu_bullying_pair_labels = { 2 | 1: 'G1A1: hit with knees', 3 | 2: 'G2A1: hit with head', 4 | 3: 'G3A1: punch to face', 5 | 4: 'G4A1: punch to body', 6 | 5: 'G5A1: cover mouth', 7 | 6: 'G6A1: pinch neck', 8 | 7: 'G7A1: slap', 9 | 8: 'G8A1: kicking', 10 | 9: 'G9A1: pushing', 11 | 10: 'G10A1: pierce others', 12 | 11: 'G11A1: pull hairs', 13 | 12: 'G12A1: drag other person', 14 | 13: 'G13A1: pull collar', 15 | 14: 'G14A1: swing others', 16 | 15: 'G15A1: beat with elbow', 17 | 16: 'G16A1: knoch over', 18 | 17: 'G17A1: hit with object', 19 | 18: 'G18A1: point to person', 20 | 19: 'G19A1: cuff ear', 21 | 20: 'G20A1: pinch arms', 22 | 21: 'G21A1: use cigarette to burn', 23 | 22: 'G22A1: sidekick person', 24 | 23: 'G23A1: cast to person', 25 | 24: 'G24A1: shoot person', 26 | 25: 'G25A1: stab person', 27 | 26: 'G26A1: wave knife to others', 28 | 27: 'G27A1: splash liquid on person', 29 | 28: 'G28A1: stumble person', 30 | 29: 'G29A1: step on foot', 31 | 30: 'G30A1: touch pocket', 32 | 31: 'G31A1: bite person', 33 | 32: 'G32A1: take picture for others', 34 | 33: 'G33A1: spiting to person', 35 | 34: 'G34A1: chop person', 36 | 35: 'G35A1: take chair while other sitting', 37 | 36: 'G36A1: pat on head', 38 | 37: 'G37A1: pinch face', 39 | 38: 'G38A1: pinch body', 40 | 39: 'G39A1: follow person', 41 | 40: 'G40A1: belt person', 42 | 41: 'G1A2: nod head', 43 | 42: 'G2A2: bow', 44 | 43: 'G3A2: shake hands', 45 | 44: 'G4A2: rock-paper-scissors', 46 | 45: 'G5A2: touch elbows', 47 | 46: 'G6A2: wave hand', 48 | 47: 'G7A2: fist bumping', 49 | 48: 'G8A2: pat on back', 50 | 49: 'G9A2: giving object', 51 | 50: 'G10A2: exchange object', 52 | 51: 'G11A2: clapping; hushing', 53 | 52: 'G12A2: drink water; brush teeth', 54 | 53: 'G13A2: stand up; jump up', 55 | 54: 'G14A2: take off a hat; play a phone', 56 | 55: 'G15A2: take a selfie; wipe face', 57 | 56: 'G16A2: cross hands in front; throat-slitting', 58 | 57: 'G17A2: crawling; open bottle', 59 | 58: 'G18A2: sneeze; yawn', 60 | 59: 'G19A2: self-cutting with knife; take off headphone', 61 | 60: 'G20A2: stretch oneself; flick hair', 62 | 61: 'G21A2: thumb up; thumb down', 63 | 62: 'G22A2: make ok sign; make victory sign', 64 | 63: 'G23A2: cutting nails; cutting paper', 65 | 64: 'G24A2: squat down; toss a coin', 
66 | 65: 'G25A2: fold paper; ball up paper', 67 | 66: 'G26A2: play magic cube; surrender', 68 | 67: 'G27A2: apply cream on face; apply cream on hand', 69 | 68: 'G28A2: put on bag; take off bag', 70 | 69: 'G29A2: put object into bag; take object out of bag', 71 | 70: 'G30A2: open a box; yelling', 72 | 71: 'G31A2: arm circles; arm swings', 73 | 72: 'G32A2: whisper', 74 | 73: 'G33A2: clapping each other', 75 | 74: 'G34A2: running; vomiting', 76 | 75: 'G35A2: walk apart', 77 | 76: 'G36A2: headache; back pain', 78 | 77: 'G37A2: walk toward', 79 | 78: 'G38A2: falling down; chest pain', 80 | 79: 'G39A2: walk and hold person', 81 | 80: 'G40A2: cheers and drink', 82 | } 83 | 84 | ntu120_code_labels = { 85 | 1: 'drink water', 86 | 2: 'eat meal/snack"', 87 | 3: "brushing teeth", 88 | 4: "brushing hair", 89 | 5: "drop", 90 | 6: "pickup", 91 | 7: "throw", 92 | 8: "sitting down", 93 | 9: "standing up (from sitting position)", 94 | 10: "clapping", 95 | 11: "reading", 96 | 12: "writing", 97 | 13: "tear up paper", 98 | 14: "wear jacket", 99 | 15: "take off jacket", 100 | 16: "wear a shoe", 101 | 17: "take off a shoe", 102 | 18: "wear on glasses", 103 | 19: "take off glasses", 104 | 20: "put on a hat/cap", 105 | 21: "take off a hat/cap", 106 | 22: "cheer up", 107 | 23: "hand waving", 108 | 24: "kicking something", 109 | 25: "reach into pocket", 110 | 26: "hopping (one foot jumping)", 111 | 27: "jump up", 112 | 28: "make a phone call/answer phone", 113 | 29: "playing with phone/tablet", 114 | 30: "typing on a keyboard", 115 | 31: "pointing to something with finger", 116 | 32: "taking a selfie", 117 | 33: "check time (from watch)", 118 | 34: "rub two hands together", 119 | 35: "nod head/bow", 120 | 36: "shake head", 121 | 37: "wipe face", 122 | 38: "salute", 123 | 39: "put the palms together", 124 | 40: "cross hands in front (say stop)", 125 | 41: "sneeze/cough", 126 | 42: "staggering", 127 | 43: "falling", 128 | 44: "touch head (headache)", 129 | 45: "touch chest (stomachache/heart pain)", 130 | 46: "touch back (backache)", 131 | 47: "touch neck (neckache)", 132 | 48: "nausea or vomiting condition", 133 | 49: "use a fan (with hand or paper)/feeling warm", 134 | 50: "punching/slapping other person", 135 | 51: "kicking other person", 136 | 52: "pushing other person", 137 | 53: "pat on back of other person", 138 | 54: "point finger at the other person", 139 | 55: "hugging other person", 140 | 56: "giving something to other person", 141 | 57: "touch other person's pocket", 142 | 58: "handshaking", 143 | 59: "walking towards each other", 144 | 60: "walking apart from each other", 145 | 61: "put on headphone", 146 | 62: "take off headphone", 147 | 63: "shoot at the basket", 148 | 64: "bounce ball", 149 | 65: "tennis bat swing", 150 | 66: "juggling table tennis balls", 151 | 67: "hush (quite)", 152 | 68: "flick hair", 153 | 69: "thumb up", 154 | 70: "thumb down", 155 | 71: "make ok sign", 156 | 72: "make victory sign", 157 | 73: "staple book", 158 | 74: "counting money", 159 | 75: "cutting nails", 160 | 76: "cutting paper (using scissors)", 161 | 77: "snapping fingers", 162 | 78: "open bottle", 163 | 79: "sniff (smell)", 164 | 80: "squat down", 165 | 81: "toss a coin", 166 | 82: "fold paper", 167 | 83: "ball up paper", 168 | 84: "play magic cube", 169 | 85: "apply cream on face", 170 | 86: "apply cream on hand back", 171 | 87: "put on bag", 172 | 88: "take off bag", 173 | 89: "put something into a bag", 174 | 90: "take something out of a bag", 175 | 91: "open a box", 176 | 92: "move heavy objects", 177 | 93: "shake 
fist", 178 | 94: "throw up cap/hat", 179 | 95: "hands up (both hands)", 180 | 96: "cross arms", 181 | 97: "arm circles", 182 | 98: "arm swings", 183 | 99: "running on the spot", 184 | 100: "butt kicks (kick backward)", 185 | 101: "cross toe touch", 186 | 102: "side kick", 187 | 103: "yawn", 188 | 104: "stretch oneself", 189 | 105: "blow nose", 190 | 106: "hit other person with something", 191 | 107: "wield knife towards other person", 192 | 108: "knock over other person (hit with body)", 193 | 109: "grab other person’s stuff", 194 | 110: "shoot at other person with a gun", 195 | 111: "step on foot", 196 | 112: "high-five", 197 | 113: "cheers and drink", 198 | 114: "carry something with other person", 199 | 115: "take a photo of other person", 200 | 116: "follow other person", 201 | 117: "whisper in other person’s ear", 202 | 118: "exchange things with other person", 203 | 119: "support somebody with hand", 204 | 120: "finger-guessing game (playing rock-paper-scissors)", 205 | } 206 | 207 | bly_labels = { 208 | 1: 'G1A1: hit with knees', 209 | 2: 'G2A1: hit with head', 210 | 3: 'G3A1: punch to face', 211 | 4: 'G4A1: punch to body', 212 | 5: 'G5A1: cover mouth', 213 | 6: 'G6A1: pinch neck', 214 | 7: 'G7A1: slap', 215 | 8: 'G8A1: kicking', 216 | 9: 'G9A1: pushing', 217 | 10: 'G10A1: pierce others', 218 | 11: 'G11A1: pull hairs', 219 | 12: 'G12A1: drag other person', 220 | 13: 'G13A1: pull collar', 221 | 14: 'G14A1: swing others', 222 | 15: 'G15A1: beat with elbow', 223 | 16: 'G16A1: knoch over', 224 | 17: 'G17A1: hit with object', 225 | 18: 'G18A1: point to person', 226 | 19: 'G19A1: cuff ear', 227 | 20: 'G20A1: pinch arms', 228 | 21: 'G21A1: use cigarette to burn', 229 | 22: 'G22A1: sidekick person', 230 | 23: 'G23A1: cast to person', 231 | 24: 'G24A1: shoot person', 232 | 25: 'G25A1: stab person', 233 | 26: 'G26A1: wave knife to others', 234 | 27: 'G27A1: splash liquid on person', 235 | 28: 'G28A1: stumble person', 236 | 29: 'G29A1: step on foot', 237 | 30: 'G30A1: touch pocket', 238 | 31: 'G31A1: bite person', 239 | 32: 'G33A1: spiting to person', 240 | 33: 'G34A1: chop person', 241 | 34: 'G35A1: take chair while other sitting', 242 | 35: 'G36A1: pat on head', 243 | 36: 'G37A1: pinch face', 244 | 37: 'G38A1: pinch body', 245 | 38: 'G40A1: belt person', 246 | } 247 | 248 | anubis_ind_actions = { 249 | 0: 'hit with knees', 250 | 1: 'hit with head', 251 | 2: 'punch to face', 252 | 3: 'punch to body', 253 | 4: 'cover mouth', 254 | 5: 'strangling neck', 255 | 6: 'slap', 256 | 7: 'kicking', 257 | 8: 'pushing', 258 | 9: 'pierce arms', 259 | 10: 'pull hairs', 260 | 11: 'drag other person (other resist)', 261 | 12: 'pull collar', 262 | 13: 'shake others', 263 | 14: 'beat with elbow', 264 | 15: 'knock over', 265 | 16: 'hit with object', 266 | 17: 'point to person', 267 | 18: 'lift ear', 268 | 19: 'pinch arms', 269 | 20: 'use cigarette to burn', 270 | 21: 'sidekick person', 271 | 22: 'pick and throw an object to person', 272 | 23: 'shoot person', 273 | 24: 'stab person', 274 | 25: 'wave knife to others', 275 | 26: 'splash liquid on person', 276 | 27: 'stumble person', 277 | 28: 'step on foot', 278 | 29: 'pickpocketing', 279 | 30: 'bite person', 280 | 31: 'take picture for others (sneakily)', 281 | 32: 'spiting to person', 282 | 33: 'chop (cut) person', 283 | 34: 'take chair while other sitting', 284 | 35: 'hit the head with hand', 285 | 36: 'pinch face with two hands', 286 | 37: 'pinch body (not arm)', 287 | 38: 'follow person', 288 | 39: 'belt person', 289 | 40: 'nod head', 290 | 41: 'bow', 291 | 
42: 'shake hands', 292 | 43: 'rock-paper-scissors', 293 | 44: 'touch elbows', 294 | 45: 'wave hand', 295 | 46: 'fist bumping', 296 | 47: 'pat on shoulders', 297 | 48: 'giving object', 298 | 49: 'exchange object', 299 | 50: 'clapping and hushing', 300 | 51: 'drink water', 301 | 52: 'brush teeth', 302 | 53: 'stand up', 303 | 54: 'jump up', 304 | 55: 'take off a hat', 305 | 56: 'play a phone', 306 | 57: 'take a selfie', 307 | 58: 'wipe face', 308 | 59: 'cross hands in front', 309 | 60: 'throat-slitting', 310 | 61: 'crawling', 311 | 62: 'open bottle', 312 | 63: 'sneeze', 313 | 64: 'yawn', 314 | 65: 'self-cutting with knife', 315 | 66: 'take off headphone', 316 | 67: 'stretch oneself', 317 | 68: 'flick hair', 318 | 69: 'thumb up', 319 | 70: 'thumb down', 320 | 71: 'make ok sign', 321 | 72: 'make victory sign', 322 | 73: 'cutting nails', 323 | 74: 'cutting paper', 324 | 75: 'squat down', 325 | 76: 'toss a coin', 326 | 77: 'fold paper', 327 | 78: 'ball up paper', 328 | 79: 'play magic cube', 329 | 80: 'surrender', 330 | 81: 'apply cream on face', 331 | 82: 'apply cream on hand', 332 | 83: 'put on bag', 333 | 84: 'take off bag', 334 | 85: 'put object into bag', 335 | 86: 'take object out of bag', 336 | 87: 'open a box and yelling', 337 | 88: 'arm circles', 338 | 89: 'arm swings', 339 | 90: 'whisper', 340 | 91: 'clapping each other', 341 | 92: 'running', 342 | 93: 'vomiting', 343 | 94: 'walk apart', 344 | 95: 'headache', 345 | 96: 'back pain', 346 | 97: 'walk form apart to together', 347 | 98: 'falling down', 348 | 99: 'chest pain', 349 | 100: 'support with arms for old people walking', 350 | 101: 'cheers and drink'} 351 | -------------------------------------------------------------------------------- /config/test.yaml: -------------------------------------------------------------------------------- 1 | 2 | work_dir: ./work_dir/ntu120_xsub_2021_angular_test 3 | 4 | # feeder 5 | feeder: feeders.feeder.Feeder 6 | 7 | test_feeder_args: 8 | data_path: ./data/ntu120/xsub/val_data_joint.npy 9 | label_path: ./data/ntu120/xsub/val_label.pkl 10 | 11 | # model 12 | model: model.network.Model 13 | model_args: 14 | in_channels: 15 15 | num_class: 120 16 | num_point: 25 17 | num_person: 2 18 | num_gcn_scales: 13 19 | num_g3d_scales: 6 20 | graph: graph.ntu_rgb_d.AdjMatrixGraph 21 | 22 | # ablation 23 | ablation: sgcn_only 24 | 25 | 26 | # optim 27 | weight_decay: 0.0005 28 | base_lr: 0.05 29 | step: [30,40,50] 30 | 31 | num_epoch: 60 32 | device: [0,1] 33 | batch_size: 40 34 | forward_batch_size: 40 35 | test_batch_size: 40 36 | nesterov: True 37 | 38 | optimizer: SGD 39 | 40 | eval_start: 5 41 | eval_interval: 5 42 | 43 | phase: test 44 | weights: "PATH TO THE PRETRAINED MODEL" 45 | save_score: True 46 | -------------------------------------------------------------------------------- /config/train.yaml: -------------------------------------------------------------------------------- 1 | 2 | work_dir: ./work_dir/ntu120_xsub_2021_angular_train 3 | 4 | # feeder 5 | feeder: feeders.feeder.Feeder 6 | train_feeder_args: 7 | data_path: ./data/ntu120/xsub/train_data_joint.npy 8 | label_path: ./data/ntu120/xsub/train_label.pkl 9 | debug: False 10 | random_choose: False 11 | random_shift: False 12 | random_move: False 13 | window_size: -1 14 | normalization: False 15 | 16 | test_feeder_args: 17 | data_path: ./data/ntu120/xsub/val_data_joint.npy 18 | label_path: ./data/ntu120/xsub/val_label.pkl 19 | 20 | # model 21 | model: model.network.Model 22 | model_args: 23 | in_channels: 15 24 | num_class: 120 25 | num_point: 25 
26 | num_person: 2 27 | num_gcn_scales: 13 28 | num_g3d_scales: 6 29 | graph: graph.ntu_rgb_d.AdjMatrixGraph 30 | 31 | # ablation 32 | ablation: sgcn_only 33 | 34 | 35 | # optim 36 | weight_decay: 0.0005 37 | base_lr: 0.05 38 | step: [30,40,50] 39 | 40 | # training 41 | num_epoch: 60 42 | device: [0,1] 43 | batch_size: 40 44 | forward_batch_size: 40 45 | test_batch_size: 80 46 | nesterov: True 47 | 48 | # Extra options 49 | to_add_onehot: False 50 | optimizer: SGD 51 | 52 | # Pretrained models 53 | # # This is to load the pretrained model, uncomment to use. 54 | # weights: "" 55 | # checkpoint: "" 56 | # resume: True 57 | 58 | eval_start: 1 59 | eval_interval: 5 60 | -------------------------------------------------------------------------------- /encoding/__pycache__/data_encoder.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/encoding/__pycache__/data_encoder.cpython-36.pyc -------------------------------------------------------------------------------- /encoding/data_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import torch.nn.functional as functional 4 | 5 | 6 | class DataRepeatEncoder: 7 | def __init__(self, rep_num): 8 | self.rep_num = rep_num 9 | 10 | def encode_data(self, a_bch, **kwargs): 11 | rtn = a_bch.repeat(1, self.rep_num, 1, 1, 1) 12 | return rtn 13 | 14 | 15 | class DataInterpolatingEncoder: 16 | def __init__(self, new_length=None): 17 | self.new_length = new_length 18 | 19 | def encode_data(self, a_bch, **kwargs): 20 | N, C, T, V, M = a_bch.size() 21 | a_bch = a_bch.permute(0, 4, 1, 3, 2).contiguous().view(N*M*C, V, T) 22 | a_bch = functional.interpolate(a_bch, size=self.new_length, mode='linear') 23 | a_bch = a_bch.view(N, M, C, V, self.new_length).permute(0, 2, 4, 3, 1) 24 | return a_bch 25 | 26 | class TrigonometricTemporalEncoder: 27 | def __init__(self, inc_type, freq_num, seq_len, is_with_orig=True): 28 | self.inc_func = inc_type 29 | self.periodic_fns = [torch.cos] 30 | self.is_with_orig = is_with_orig 31 | self.K = freq_num 32 | self.T = seq_len 33 | self.prepare_period_fns() 34 | 35 | def prepare_period_fns(self): 36 | assert self.inc_func is not None 37 | assert self.periodic_fns is not None 38 | 39 | # Get frequency values 40 | self.temp_freq_bands = [] 41 | for k in range(1, self.K + 1): 42 | if self.inc_func == 'linear': 43 | a_freq = k 44 | elif self.inc_func == 'exp': 45 | a_freq = 2 ** (k - 1) 46 | elif self.inc_func == 'pow': 47 | a_freq = k ** 2 48 | else: 49 | raise NotImplementedError('Unsupported inc_func.') 50 | self.temp_freq_bands.append(math.pi / self.T * a_freq) 51 | 52 | self.temp_freq_bands = torch.tensor(self.temp_freq_bands) 53 | print('Temporal frequency components: ', self.temp_freq_bands) 54 | 55 | # Get embed functions 56 | self.temp_embed_fns = [] 57 | if self.is_with_orig: 58 | self.temp_embed_fns.append(lambda x, frm_idx: x) 59 | 60 | for freq_t in self.temp_freq_bands: 61 | for p_fn in self.periodic_fns: 62 | self.temp_embed_fns.append( 63 | lambda x, frm_idx, p_fn=p_fn, freq=freq_t: (x * p_fn(freq * (frm_idx + 1 / 2))) 64 | ) # TTE 65 | 66 | def encode_data(self, a_bch, dim): 67 | t_len_all = a_bch.shape[2] 68 | time_list = [] 69 | for t_idx in range(t_len_all): 70 | a_series = a_bch[:, :, t_idx, :, :].unsqueeze(2) 71 | 72 | new_time_list = [] 73 | for fn in self.temp_embed_fns: 74 | a_new_one =
fn(a_series, t_idx) 75 | new_time_list.append(a_new_one) 76 | new_time_list = torch.cat(new_time_list, dim) 77 | 78 | time_list.append(new_time_list) 79 | rtn = torch.cat(time_list, 2) 80 | return rtn 81 | -------------------------------------------------------------------------------- /feeders/__init__.py: -------------------------------------------------------------------------------- 1 | from . import tools 2 | from . import feeder 3 | from . import feeder_as_gcn 4 | from . import feeder_dgnn 5 | -------------------------------------------------------------------------------- /feeders/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/feeders/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /feeders/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/feeders/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /feeders/__pycache__/feeder.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/feeders/__pycache__/feeder.cpython-36.pyc -------------------------------------------------------------------------------- /feeders/__pycache__/feeder.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/feeders/__pycache__/feeder.cpython-37.pyc -------------------------------------------------------------------------------- /feeders/__pycache__/feeder_as_gcn.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/feeders/__pycache__/feeder_as_gcn.cpython-36.pyc -------------------------------------------------------------------------------- /feeders/__pycache__/feeder_as_gcn.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/feeders/__pycache__/feeder_as_gcn.cpython-37.pyc -------------------------------------------------------------------------------- /feeders/__pycache__/feeder_dgnn.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/feeders/__pycache__/feeder_dgnn.cpython-36.pyc -------------------------------------------------------------------------------- /feeders/__pycache__/feeder_dgnn.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/feeders/__pycache__/feeder_dgnn.cpython-37.pyc -------------------------------------------------------------------------------- /feeders/__pycache__/tools.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/feeders/__pycache__/tools.cpython-36.pyc -------------------------------------------------------------------------------- /feeders/__pycache__/tools.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/feeders/__pycache__/tools.cpython-37.pyc -------------------------------------------------------------------------------- /feeders/feeder.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.extend(['../']) 3 | 4 | import torch 5 | import pickle 6 | import numpy as np 7 | from torch.utils.data import Dataset, TensorDataset 8 | from numpy import inf 9 | 10 | import scipy.fftpack 11 | 12 | from feeders import tools 13 | 14 | 15 | class Feeder(Dataset): 16 | def __init__(self, data_path, label_path, 17 | random_choose=False, random_shift=False, random_move=False, 18 | window_size=-1, normalization=False, debug=False, use_mmap=True, 19 | tgt_labels=None, frame_len=300, **kwargs): 20 | """ 21 | :param data_path: 22 | :param label_path: 23 | :param random_choose: If true, randomly choose a portion of the input sequence 24 | :param random_shift: If true, randomly pad zeros at the begining or end of sequence 25 | :param random_move: 26 | :param window_size: The length of the output sequence 27 | :param normalization: If true, normalize input sequence 28 | :param debug: If true, only use the first 100 samples 29 | :param use_mmap: If true, use mmap mode to load data, which can save the running memory 30 | """ 31 | 32 | self.debug = debug 33 | self.data_path = data_path 34 | self.label_path = label_path 35 | self.random_choose = random_choose 36 | self.random_shift = random_shift 37 | self.random_move = random_move 38 | self.window_size = window_size 39 | self.normalization = normalization 40 | self.use_mmap = use_mmap 41 | self.tgt_labels = tgt_labels 42 | self.kwargs = kwargs 43 | 44 | # other parameters 45 | self.frame_len = frame_len 46 | 47 | self.load_data() 48 | 49 | # Internal dataloader parameters 50 | # self.load_bch_sz = 2000 51 | # self.internal_dataloader = self.get_a_dataloader() 52 | 53 | if normalization: 54 | self.get_mean_map() 55 | 56 | def get_a_dataloader(self): 57 | a_dataset = TensorDataset(torch.tensor(self.data)) 58 | a_dataloader = torch.utils.data.DataLoader( 59 | dataset=a_dataset, 60 | batch_size=self.load_bch_sz, 61 | shuffle=False, 62 | num_workers=4, 63 | drop_last=False 64 | ) 65 | return a_dataloader 66 | 67 | def load_data(self): 68 | # data: N C T V M 69 | try: 70 | with open(self.label_path) as f: 71 | self.sample_name, self.label = pickle.load(f) 72 | except: 73 | # for pickle file from python2 74 | with open(self.label_path, 'rb') as f: 75 | self.sample_name, self.label = pickle.load(f, encoding='latin1') 76 | # self.label = np.array(self.label) 77 | 78 | # load data 79 | if self.use_mmap: 80 | self.data = np.load(self.data_path, mmap_mode='r') 81 | else: 82 | self.data = np.load(self.data_path) 83 | 84 | # Use tgt labels 85 | if self.tgt_labels is not None: 86 | self.label = np.array(self.label) 87 | tmp_data = None 88 | tmp_label = None 89 | for a_tgt_label in self.tgt_labels: 90 | selected_idxes = np.array(self.label) == a_tgt_label 91 | if tmp_data is None: 92 | 
tmp_data = self.data[selected_idxes] 93 | tmp_label = self.label[selected_idxes] 94 | else: 95 | tmp_data = np.concatenate((tmp_data, self.data[selected_idxes]), axis=0) 96 | tmp_label = np.concatenate((tmp_label, self.label[selected_idxes]), axis=0) 97 | self.data = tmp_data 98 | self.label = tmp_label 99 | 100 | if 'process_type' in self.kwargs: 101 | self.process_data(process_type=self.kwargs['process_type']) 102 | 103 | if self.debug: 104 | self.label = self.label[0:1000] 105 | self.data = self.data[0:1000] 106 | self.sample_name = self.sample_name[0:1000] 107 | # debug_tgt = 4206 # 4206 # 108 | # hard_samples = [9566, 13297, 15239, 13351, 11670, 8935, 2815, 9329, 15238, 8896] 109 | # self.label = list(self.label[i] for i in hard_samples) 110 | # self.data = list(self.data[i] for i in hard_samples) 111 | # self.sample_name = list(self.sample_name[i] for i in hard_samples) 112 | 113 | # Discrete cosine transform 114 | if 'dct' in self.kwargs: 115 | self.dct_data(self.kwargs['dct']) 116 | print('Discrete cosine transform completed. DCT type: ', self.kwargs['dct']) 117 | 118 | def dct_data(self, dct_op): 119 | dct_out = scipy.fftpack.dct(self.data, axis=2) 120 | if dct_op == 'overwrite': 121 | self.data = dct_out 122 | elif dct_op == 'concat': 123 | self.data = np.concatenate((self.data, dct_out), axis=1) 124 | elif dct_op == 'lengthen': 125 | dct_out = dct_out[:, :, :(self.frame_len // 2), :, :] 126 | self.data = np.concatenate((self.data, dct_out), axis=2) 127 | 128 | def process_data(self, process_type): 129 | rtn_data = [] 130 | rtn_label = [] 131 | if process_type == 'single_person': 132 | data_idx = 0 133 | for a_data in self.data: 134 | for ppl_id in range(self.data.shape[-1]): 135 | a_ppl = a_data[:, :, :, ppl_id] 136 | # comment the below one 137 | # if np.max(a_ppl) > 0.01: 138 | 139 | # Keep all data (some mutual data also contain zero) 140 | if np.max(a_ppl) > -1: 141 | rtn_data.append(np.expand_dims(a_ppl, axis=-1)) 142 | rtn_label.append(self.label[data_idx]) 143 | data_idx += 1 144 | else: 145 | raise NotImplementedError 146 | rtn_data = np.stack(rtn_data, axis=0) 147 | rtn_label = np.stack(rtn_label, axis=0) 148 | 149 | # relabel data to consider actor and receiver 150 | rtn_label = self.relabel_by_energy() 151 | 152 | self.data = rtn_data 153 | self.label = rtn_label 154 | 155 | def get_energy(self, s): # ctv 156 | index = s.sum(-1).sum(0) != 0 # select valid frames 157 | s = s[:, index, :] 158 | if len(s) != 0: 159 | s = s[0, :, :].std() + s[1, :, :].std() + s[2, :, :].std() # three channels 160 | else: 161 | s = 0 162 | return s 163 | 164 | def relabel_by_energy(self): 165 | rtn_label = [] 166 | for a_idx, a_data in enumerate(self.data): 167 | person_1 = a_data[:, :, :, 0] # C,T,V 168 | person_2 = a_data[:, :, :, 1] # C,T,V 169 | energy_1 = self.get_energy(person_1) 170 | energy_2 = self.get_energy(person_2) 171 | if energy_1 > energy_2: # first kicks the second 172 | rtn_label.append((1, 0)) 173 | else: 174 | rtn_label.append((0, 1)) 175 | return np.concatenate(rtn_label, axis=0) 176 | 177 | def get_mean_map(self): 178 | data = self.data 179 | N, C, T, V, M = data.shape 180 | self.mean_map = data.mean(axis=2, keepdims=True).mean(axis=4, keepdims=True).mean(axis=0) 181 | self.std_map = data.transpose((0, 2, 4, 1, 3)).reshape((N * T * M, C * V)).std(axis=0).reshape((C, 1, V, 1)) 182 | 183 | def __len__(self): 184 | return len(self.label) 185 | 186 | def __iter__(self): 187 | return self 188 | 189 | def __getitem__(self, index): 190 | data_numpy = self.data[index] 191 | 
label = self.label[index] 192 | data_numpy = np.array(data_numpy) 193 | 194 | if self.normalization: 195 | data_numpy = (data_numpy - self.mean_map) / self.std_map 196 | if self.random_shift: 197 | data_numpy = tools.random_shift(data_numpy) 198 | if self.random_choose: 199 | data_numpy = tools.random_choose(data_numpy, self.window_size) 200 | elif self.window_size > 0: 201 | data_numpy = tools.auto_pading(data_numpy, self.window_size) 202 | if self.random_move: 203 | data_numpy = tools.random_move(data_numpy) 204 | 205 | # Remove NAN 206 | data_numpy = np.nan_to_num(data_numpy) 207 | data_numpy[data_numpy == -inf] = 0 208 | 209 | return data_numpy, label, index 210 | 211 | def top_k(self, score, top_k): 212 | # If each label holds only a single value 213 | the_label = np.array(self.label) 214 | if len(the_label.shape) == 1 or the_label.shape[1] == 1: 215 | rank = score.argsort() 216 | hit_top_k = [l in rank[i, -top_k:] for i, l in enumerate(self.label)] 217 | return sum(hit_top_k) * 1.0 / len(hit_top_k) 218 | # Labels also contain a fine-grained label 219 | else: 220 | act_label = the_label[:, 0] 221 | fgr_label = the_label[:, 1] 222 | act_score, fgr_score = score 223 | rank_act = act_score.argsort() 224 | rank_fgr = fgr_score.argsort() 225 | hit_top_k_acc = [l in rank_act[i, -top_k:] for i, l in enumerate(act_label)] 226 | hit_top_k_fgr = [l in rank_fgr[i, -top_k:] for i, l in enumerate(fgr_label)] 227 | return sum(hit_top_k_acc) * 1.0 / len(hit_top_k_acc), \ 228 | sum(hit_top_k_fgr) * 1.0 / len(hit_top_k_fgr) 229 | 230 | 231 | def import_class(name): 232 | components = name.split('.') 233 | mod = __import__(components[0]) 234 | for comp in components[1:]: 235 | mod = getattr(mod, comp) 236 | return mod 237 | 238 | 239 | def test(data_path, label_path, vid=None, graph=None, is_3d=False): 240 | ''' 241 | visualize the samples using matplotlib 242 | :param data_path: 243 | :param label_path: 244 | :param vid: the id of the sample 245 | :param graph: 246 | :param is_3d: set it to True when visualizing NTU 247 | :return: 248 | ''' 249 | import matplotlib.pyplot as plt 250 | loader = torch.utils.data.DataLoader( 251 | dataset=Feeder(data_path, label_path), 252 | batch_size=64, 253 | shuffle=False, 254 | num_workers=2) 255 | 256 | if vid is not None: 257 | sample_name = loader.dataset.sample_name 258 | sample_id = [name.split('.')[0] for name in sample_name] 259 | index = sample_id.index(vid) 260 | data, label, index = loader.dataset[index] 261 | data = data.reshape((1,) + data.shape) 262 | 263 | # for batch_idx, (data, label) in enumerate(loader): 264 | N, C, T, V, M = data.shape 265 | 266 | plt.ion() 267 | fig = plt.figure() 268 | if is_3d: 269 | from mpl_toolkits.mplot3d import Axes3D 270 | ax = fig.add_subplot(111, projection='3d') 271 | else: 272 | ax = fig.add_subplot(111) 273 | 274 | if graph is None: 275 | p_type = ['b.', 'g.', 'r.', 'c.', 'm.', 'y.', 'k.', 'k.', 'k.', 'k.'] 276 | pose = [ 277 | ax.plot(np.zeros(V), np.zeros(V), p_type[m])[0] for m in range(M) 278 | ] 279 | ax.axis([-1, 1, -1, 1]) 280 | for t in range(T): 281 | for m in range(M): 282 | pose[m].set_xdata(data[0, 0, t, :, m]) 283 | pose[m].set_ydata(data[0, 1, t, :, m]) 284 | fig.canvas.draw() 285 | plt.pause(0.001) 286 | else: 287 | p_type = ['b-', 'g-', 'r-', 'c-', 'm-', 'y-', 'k-', 'k-', 'k-', 'k-'] 288 | import sys 289 | from os import path 290 | sys.path.append( 291 | path.dirname(path.dirname(path.dirname(path.abspath(__file__))))) 292 | G = import_class(graph)() 293 | edge = G.inward 294 | pose = [] 295 | for m in range(M): 296 | a = [] 297 | for i in range(len(edge)): 298
| if is_3d: 299 | a.append(ax.plot(np.zeros(3), np.zeros(3), p_type[m])[0]) 300 | else: 301 | a.append(ax.plot(np.zeros(2), np.zeros(2), p_type[m])[0]) 302 | pose.append(a) 303 | ax.axis([-1, 1, -1, 1]) 304 | if is_3d: 305 | ax.set_zlim3d(-1, 1) 306 | for t in range(T): 307 | for m in range(M): 308 | for i, (v1, v2) in enumerate(edge): 309 | x1 = data[0, :2, t, v1, m] 310 | x2 = data[0, :2, t, v2, m] 311 | if (x1.sum() != 0 and x2.sum() != 0) or v1 == 1 or v2 == 1: 312 | pose[m][i].set_xdata(data[0, 0, t, [v1, v2], m]) 313 | pose[m][i].set_ydata(data[0, 1, t, [v1, v2], m]) 314 | if is_3d: 315 | pose[m][i].set_3d_properties(data[0, 2, t, [v1, v2], m]) 316 | fig.canvas.draw() 317 | # plt.savefig('/home/lshi/Desktop/skeleton_sequence/' + str(t) + '.jpg') 318 | plt.pause(0.01) 319 | 320 | 321 | if __name__ == '__main__': 322 | import os 323 | os.environ['DISPLAY'] = 'localhost:10.0' 324 | data_path = "../data/ntu/xview/val_data_joint.npy" 325 | label_path = "../data/ntu/xview/val_label.pkl" 326 | graph = 'graph.ntu_rgb_d.Graph' 327 | test(data_path, label_path, vid='S004C001P003R001A032', graph=graph, is_3d=True) 328 | # data_path = "../data/kinetics/val_data.npy" 329 | # label_path = "../data/kinetics/val_label.pkl" 330 | # graph = 'graph.Kinetics' 331 | # test(data_path, label_path, vid='UOD7oll3Kqo', graph=graph) 332 | -------------------------------------------------------------------------------- /feeders/feeder_as_gcn.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.extend(['../']) 3 | 4 | import numpy as np 5 | import random 6 | import pickle 7 | import time 8 | import copy 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.optim as optim 13 | import torch.nn.functional as F 14 | from torchvision import datasets, transforms 15 | 16 | from . 
import tools 17 | 18 | 19 | class Feeder(torch.utils.data.Dataset): 20 | def __init__(self, 21 | data_path, label_path, 22 | repeat_pad=False, 23 | random_choose=False, 24 | random_move=False, 25 | window_size=-1, 26 | debug=False, 27 | down_sample = False, 28 | mmap=True): 29 | self.debug = debug 30 | self.data_path = data_path 31 | self.label_path = label_path 32 | self.repeat_pad = repeat_pad 33 | self.random_choose = random_choose 34 | self.random_move = random_move 35 | self.window_size = window_size 36 | self.down_sample = down_sample 37 | 38 | self.load_data(mmap) 39 | 40 | def load_data(self, mmap): 41 | 42 | with open(self.label_path, 'rb') as f: 43 | self.sample_name, self.label = pickle.load(f) 44 | 45 | if mmap: 46 | self.data = np.load(self.data_path, mmap_mode='r') 47 | else: 48 | self.data = np.load(self.data_path) 49 | 50 | if self.debug: 51 | self.label = self.label[0:100] 52 | self.data = self.data[0:100] 53 | self.sample_name = self.sample_name[0:100] 54 | 55 | self.N, self.C, self.T, self.V, self.M = self.data.shape 56 | 57 | def __len__(self): 58 | return len(self.label) 59 | 60 | def __getitem__(self, index): 61 | data_numpy = np.array(self.data[index]).astype(np.float32) 62 | label = self.label[index] 63 | 64 | valid_frame = (data_numpy!=0).sum(axis=3).sum(axis=2).sum(axis=0)>0 65 | begin, end = valid_frame.argmax(), len(valid_frame)-valid_frame[::-1].argmax() 66 | length = end-begin 67 | 68 | if self.repeat_pad: 69 | data_numpy = tools.repeat_pading(data_numpy) 70 | if self.random_choose: 71 | data_numpy = tools.random_choose(data_numpy, self.window_size) 72 | elif self.window_size > 0: 73 | data_numpy = tools.auto_pading(data_numpy, self.window_size) 74 | if self.random_move: 75 | data_numpy = tools.random_move(data_numpy) 76 | 77 | data_last = copy.copy(data_numpy[:,-11:-10,:,:]) 78 | target_data = copy.copy(data_numpy[:,-10:,:,:]) 79 | input_data = copy.copy(data_numpy[:,:-10,:,:]) 80 | 81 | if self.down_sample: 82 | if length<=60: 83 | input_data_dnsp = input_data[:,:50,:,:] 84 | else: 85 | rs = int(np.random.uniform(low=0, high=np.ceil((length-10)/50))) 86 | input_data_dnsp = [input_data[:,int(i)+rs,:,:] for i in [np.floor(j*((length-10)/50)) for j in range(50)]] 87 | input_data_dnsp = np.array(input_data_dnsp).astype(np.float32) 88 | input_data_dnsp = np.transpose(input_data_dnsp, axes=(1,0,2,3)) 89 | 90 | return input_data, input_data_dnsp, target_data, data_last, label -------------------------------------------------------------------------------- /feeders/feeder_dgnn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pickle 3 | import torch 4 | from torch.utils.data import Dataset 5 | import sys 6 | 7 | sys.path.extend(['../']) 8 | from feeders import tools 9 | 10 | 11 | class BaseDataset(Dataset): 12 | def __init__(self, data_path, label_path, 13 | random_choose=False, random_shift=False, random_move=False, 14 | window_size=-1, normalization=False, debug=False, use_mmap=True): 15 | """ 16 | :param data_path: 17 | :param label_path: 18 | :param random_choose: If true, randomly choose a portion of the input sequence 19 | :param random_shift: If true, randomly pad zeros at the begining or end of sequence 20 | :param random_move: 21 | :param window_size: The length of the output sequence 22 | :param normalization: If true, normalize input sequence 23 | :param debug: If true, only use the first 100 samples 24 | :param use_mmap: If true, use mmap mode to load data, which can save the running 
memory 25 | """ 26 | self.debug = debug 27 | self.data_path = data_path 28 | self.label_path = label_path 29 | self.random_choose = random_choose 30 | self.random_shift = random_shift 31 | self.random_move = random_move 32 | self.window_size = window_size 33 | self.normalization = normalization 34 | self.use_mmap = use_mmap 35 | self.load_data() 36 | if normalization: 37 | self.get_mean_map() 38 | 39 | def load_data(self): 40 | # data: (N,C,V,T,M) 41 | try: 42 | with open(self.label_path) as f: 43 | self.sample_name, self.label = pickle.load(f) 44 | except: 45 | # for pickle file from python2 46 | with open(self.label_path, 'rb') as f: 47 | self.sample_name, self.label = pickle.load(f, encoding='latin1') 48 | 49 | # load data 50 | if self.use_mmap: 51 | self.data = np.load(self.data_path, mmap_mode='r') 52 | else: 53 | self.data = np.load(self.data_path) 54 | if self.debug: 55 | self.label = self.label[0:100] 56 | self.data = self.data[0:100] 57 | self.sample_name = self.sample_name[0:100] 58 | 59 | def get_mean_map(self): 60 | """Computes the mean and standard deviation of the dataset""" 61 | data = self.data 62 | N, C, T, V, M = data.shape 63 | self.mean_map = data.mean(axis=2, keepdims=True).mean(axis=4, keepdims=True).mean(axis=0) 64 | self.std_map = data.transpose((0, 2, 4, 1, 3)).reshape((N * T * M, C * V)).std(axis=0).reshape((C, 1, V, 1)) 65 | 66 | def __len__(self): 67 | return len(self.label) 68 | 69 | def __iter__(self): 70 | return self 71 | 72 | def __getitem__(self, index): 73 | data_numpy = self.data[index] 74 | label = self.label[index] 75 | data_numpy = np.array(data_numpy) 76 | 77 | if self.normalization: 78 | data_numpy = (data_numpy - self.mean_map) / self.std_map 79 | if self.random_shift: 80 | data_numpy = tools.random_shift(data_numpy) 81 | if self.random_choose: 82 | data_numpy = tools.random_choose(data_numpy, self.window_size) 83 | elif self.window_size > 0: 84 | data_numpy = tools.auto_pading(data_numpy, self.window_size) 85 | if self.random_move: 86 | data_numpy = tools.random_move(data_numpy) 87 | 88 | return data_numpy, label, index 89 | 90 | def top_k(self, score, top_k): 91 | rank = score.argsort() 92 | hit_top_k = [l in rank[i, -top_k:] for i, l in enumerate(self.label)] 93 | return sum(hit_top_k) * 1.0 / len(hit_top_k) 94 | 95 | 96 | class Feeder(Dataset): 97 | def __init__(self, joint_data_path, bone_data_path, label_path, 98 | random_choose=False, random_shift=False, random_move=False, 99 | window_size=-1, normalization=False, debug=False, use_mmap=True): 100 | 101 | self.joint_dataset = BaseDataset(joint_data_path, label_path, random_choose, random_shift, random_move, window_size, normalization, debug, use_mmap) 102 | self.bone_dataset = BaseDataset(bone_data_path, label_path, random_choose, random_shift, random_move, window_size, normalization, debug, use_mmap) 103 | self.sample_name = self.joint_dataset.sample_name 104 | 105 | def __len__(self): 106 | return min(len(self.joint_dataset), len(self.bone_dataset)) 107 | 108 | def __iter__(self): 109 | return self 110 | 111 | def __getitem__(self, index): 112 | joint_data, label, index = self.joint_dataset[index] 113 | bone_data, label, index = self.bone_dataset[index] 114 | # Either label is fine 115 | return joint_data, bone_data, label, index 116 | 117 | def top_k(self, score, top_k): 118 | # Either dataset can be delegate 119 | return self.joint_dataset.top_k(score, top_k) 120 | 121 | 122 | # def import_class(name): 123 | # components = name.split('.') 124 | # mod = __import__(components[0]) 125 | # 
for comp in components[1:]: 126 | # mod = getattr(mod, comp) 127 | # return mod 128 | 129 | 130 | # def test(data_path, label_path, vid=None, graph=None, is_3d=False): 131 | # ''' 132 | # vis the samples using matplotlib 133 | # :param data_path: 134 | # :param label_path: 135 | # :param vid: the id of sample 136 | # :param graph: 137 | # :param is_3d: when vis NTU, set it True 138 | # :return: 139 | # ''' 140 | # import matplotlib.pyplot as plt 141 | # loader = torch.utils.data.DataLoader( 142 | # dataset=Feeder(data_path, label_path), 143 | # batch_size=64, 144 | # shuffle=False, 145 | # num_workers=2) 146 | 147 | # if vid is not None: 148 | # sample_name = loader.dataset.sample_name 149 | # sample_id = [name.split('.')[0] for name in sample_name] 150 | # index = sample_id.index(vid) 151 | # data, label, index = loader.dataset[index] 152 | # data = data.reshape((1,) + data.shape) 153 | 154 | # # for batch_idx, (data, label) in enumerate(loader): 155 | # N, C, T, V, M = data.shape 156 | 157 | # plt.ion() 158 | # fig = plt.figure() 159 | # if is_3d: 160 | # from mpl_toolkits.mplot3d import Axes3D 161 | # ax = fig.add_subplot(111, projection='3d') 162 | # else: 163 | # ax = fig.add_subplot(111) 164 | 165 | # if graph is None: 166 | # p_type = ['b.', 'g.', 'r.', 'c.', 'm.', 'y.', 'k.', 'k.', 'k.', 'k.'] 167 | # pose = [ 168 | # ax.plot(np.zeros(V), np.zeros(V), p_type[m])[0] for m in range(M) 169 | # ] 170 | # ax.axis([-1, 1, -1, 1]) 171 | # for t in range(T): 172 | # for m in range(M): 173 | # pose[m].set_xdata(data[0, 0, t, :, m]) 174 | # pose[m].set_ydata(data[0, 1, t, :, m]) 175 | # fig.canvas.draw() 176 | # plt.pause(0.001) 177 | # else: 178 | # p_type = ['b-', 'g-', 'r-', 'c-', 'm-', 'y-', 'k-', 'k-', 'k-', 'k-'] 179 | # import sys 180 | # from os import path 181 | # sys.path.append( 182 | # path.dirname(path.dirname(path.dirname(path.abspath(__file__))))) 183 | # G = import_class(graph)() 184 | # edge = G.inward 185 | # pose = [] 186 | # for m in range(M): 187 | # a = [] 188 | # for i in range(len(edge)): 189 | # if is_3d: 190 | # a.append(ax.plot(np.zeros(3), np.zeros(3), p_type[m])[0]) 191 | # else: 192 | # a.append(ax.plot(np.zeros(2), np.zeros(2), p_type[m])[0]) 193 | # pose.append(a) 194 | # ax.axis([-1, 1, -1, 1]) 195 | # if is_3d: 196 | # ax.set_zlim3d(-1, 1) 197 | # for t in range(T): 198 | # for m in range(M): 199 | # for i, (v1, v2) in enumerate(edge): 200 | # x1 = data[0, :2, t, v1, m] 201 | # x2 = data[0, :2, t, v2, m] 202 | # if (x1.sum() != 0 and x2.sum() != 0) or v1 == 1 or v2 == 1: 203 | # pose[m][i].set_xdata(data[0, 0, t, [v1, v2], m]) 204 | # pose[m][i].set_ydata(data[0, 1, t, [v1, v2], m]) 205 | # if is_3d: 206 | # pose[m][i].set_3d_properties(data[0, 2, t, [v1, v2], m]) 207 | # fig.canvas.draw() 208 | # # plt.savefig('/home/lshi/Desktop/skeleton_sequence/' + str(t) + '.jpg') 209 | # plt.pause(0.01) 210 | 211 | 212 | if __name__ == '__main__': 213 | pass 214 | # import os 215 | # os.environ['DISPLAY'] = 'localhost:10.0' 216 | # data_path = "../data/ntu/xview/val_data_joint.npy" 217 | # label_path = "../data/ntu/xview/val_label.pkl" 218 | # graph = 'graph.ntu_rgb_d.Graph' 219 | # test(data_path, label_path, vid='S004C001P003R001A032', graph=graph, is_3d=True) 220 | # data_path = "../data/kinetics/val_data.npy" 221 | # label_path = "../data/kinetics/val_label.pkl" 222 | # graph = 'graph.Kinetics' 223 | # test(data_path, label_path, vid='UOD7oll3Kqo', graph=graph) 224 | -------------------------------------------------------------------------------- /feeders/tools.py: 
-------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | 5 | 6 | def downsample(data_numpy, step, random_sample=True): 7 | # input: C,T,V,M 8 | begin = np.random.randint(step) if random_sample else 0 9 | return data_numpy[:, begin::step, :, :] 10 | 11 | 12 | def temporal_slice(data_numpy, step): 13 | # input: C,T,V,M 14 | C, T, V, M = data_numpy.shape 15 | return data_numpy.reshape(C, T / step, step, V, M).transpose( 16 | (0, 1, 3, 2, 4)).reshape(C, T / step, V, step * M) 17 | 18 | 19 | def mean_subtractor(data_numpy, mean): 20 | # input: C,T,V,M 21 | # naive version 22 | if mean == 0: 23 | return 24 | C, T, V, M = data_numpy.shape 25 | valid_frame = (data_numpy != 0).sum(axis=3).sum(axis=2).sum(axis=0) > 0 26 | begin = valid_frame.argmax() 27 | end = len(valid_frame) - valid_frame[::-1].argmax() 28 | data_numpy[:, :end, :, :] = data_numpy[:, :end, :, :] - mean 29 | return data_numpy 30 | 31 | 32 | def auto_pading(data_numpy, size, random_pad=False): 33 | C, T, V, M = data_numpy.shape 34 | if T < size: 35 | begin = random.randint(0, size - T) if random_pad else 0 36 | data_numpy_paded = np.zeros((C, size, V, M)) 37 | data_numpy_paded[:, begin:begin + T, :, :] = data_numpy 38 | return data_numpy_paded 39 | else: 40 | return data_numpy 41 | 42 | 43 | def random_choose(data_numpy, size, auto_pad=True): 44 | C, T, V, M = data_numpy.shape 45 | if T == size: 46 | return data_numpy 47 | elif T < size: 48 | if auto_pad: 49 | return auto_pading(data_numpy, size, random_pad=True) 50 | else: 51 | return data_numpy 52 | else: 53 | begin = random.randint(0, T - size) 54 | return data_numpy[:, begin:begin + size, :, :] 55 | 56 | 57 | def random_move(data_numpy, 58 | angle_candidate=[-10., -5., 0., 5., 10.], 59 | scale_candidate=[0.9, 1.0, 1.1], 60 | transform_candidate=[-0.2, -0.1, 0.0, 0.1, 0.2], 61 | move_time_candidate=[1]): 62 | # input: C,T,V,M 63 | C, T, V, M = data_numpy.shape 64 | move_time = random.choice(move_time_candidate) 65 | node = np.arange(0, T, T * 1.0 / move_time).round().astype(int) 66 | node = np.append(node, T) 67 | num_node = len(node) 68 | 69 | A = np.random.choice(angle_candidate, num_node) 70 | S = np.random.choice(scale_candidate, num_node) 71 | T_x = np.random.choice(transform_candidate, num_node) 72 | T_y = np.random.choice(transform_candidate, num_node) 73 | 74 | a = np.zeros(T) 75 | s = np.zeros(T) 76 | t_x = np.zeros(T) 77 | t_y = np.zeros(T) 78 | 79 | # linspace 80 | for i in range(num_node - 1): 81 | a[node[i]:node[i + 1]] = np.linspace( 82 | A[i], A[i + 1], node[i + 1] - node[i]) * np.pi / 180 83 | s[node[i]:node[i + 1]] = np.linspace(S[i], S[i + 1], 84 | node[i + 1] - node[i]) 85 | t_x[node[i]:node[i + 1]] = np.linspace(T_x[i], T_x[i + 1], 86 | node[i + 1] - node[i]) 87 | t_y[node[i]:node[i + 1]] = np.linspace(T_y[i], T_y[i + 1], 88 | node[i + 1] - node[i]) 89 | 90 | theta = np.array([[np.cos(a) * s, -np.sin(a) * s], 91 | [np.sin(a) * s, np.cos(a) * s]]) 92 | 93 | # perform transformation 94 | for i_frame in range(T): 95 | xy = data_numpy[0:2, i_frame, :, :] 96 | new_xy = np.dot(theta[:, :, i_frame], xy.reshape(2, -1)) 97 | new_xy[0] += t_x[i_frame] 98 | new_xy[1] += t_y[i_frame] 99 | data_numpy[0:2, i_frame, :, :] = new_xy.reshape(2, V, M) 100 | 101 | return data_numpy 102 | 103 | 104 | def random_shift(data_numpy): 105 | C, T, V, M = data_numpy.shape 106 | data_shift = np.zeros(data_numpy.shape) 107 | valid_frame = (data_numpy != 0).sum(axis=3).sum(axis=2).sum(axis=0) > 0 108 | begin = 
valid_frame.argmax() 109 | end = len(valid_frame) - valid_frame[::-1].argmax() 110 | 111 | size = end - begin 112 | bias = random.randint(0, T - size) 113 | data_shift[:, bias:bias + size, :, :] = data_numpy[:, begin:end, :, :] 114 | 115 | return data_shift 116 | 117 | 118 | def openpose_match(data_numpy): 119 | C, T, V, M = data_numpy.shape 120 | assert (C == 3) 121 | score = data_numpy[2, :, :, :].sum(axis=1) 122 | # the rank of body confidence in each frame (shape: T-1, M) 123 | rank = (-score[0:T - 1]).argsort(axis=1).reshape(T - 1, M) 124 | 125 | # data of frame 1 126 | xy1 = data_numpy[0:2, 0:T - 1, :, :].reshape(2, T - 1, V, M, 1) 127 | # data of frame 2 128 | xy2 = data_numpy[0:2, 1:T, :, :].reshape(2, T - 1, V, 1, M) 129 | # square of distance between frame 1&2 (shape: T-1, M, M) 130 | distance = ((xy2 - xy1) ** 2).sum(axis=2).sum(axis=0) 131 | 132 | # match pose 133 | forward_map = np.zeros((T, M), dtype=int) - 1 134 | forward_map[0] = range(M) 135 | for m in range(M): 136 | choose = (rank == m) 137 | forward = distance[choose].argmin(axis=1) 138 | for t in range(T - 1): 139 | distance[t, :, forward[t]] = np.inf 140 | forward_map[1:][choose] = forward 141 | assert (np.all(forward_map >= 0)) 142 | 143 | # string data 144 | for t in range(T - 1): 145 | forward_map[t + 1] = forward_map[t + 1][forward_map[t]] 146 | 147 | # generate data 148 | new_data_numpy = np.zeros(data_numpy.shape) 149 | for t in range(T): 150 | new_data_numpy[:, t, :, :] = data_numpy[:, t, :, forward_map[ 151 | t]].transpose(1, 2, 0) 152 | data_numpy = new_data_numpy 153 | 154 | # score sort 155 | trace_score = data_numpy[2, :, :, :].sum(axis=1).sum(axis=0) 156 | rank = (-trace_score).argsort() 157 | data_numpy = data_numpy[:, :, :, rank] 158 | 159 | return data_numpy 160 | 161 | 162 | def repeat_pading(data_numpy): 163 | data_tmp = np.transpose(data_numpy, [3,1,2,0]) # [2,300,25,3] 164 | for i_p, person in enumerate(data_tmp): 165 | if person.sum()==0: 166 | continue 167 | if person[0].sum()==0: 168 | index = (person.sum(-1).sum(-1)!=0) 169 | tmp = person[index].copy() 170 | person*=0 171 | person[:len(tmp)] = tmp 172 | for i_f, frame in enumerate(person): 173 | if frame.sum()==0: 174 | if person[i_f:].sum()==0: 175 | rest = len(person)-i_f 176 | num = int(np.ceil(rest/i_f)) 177 | pad = np.concatenate([person[0:i_f] for _ in range(num)], 0)[:rest] 178 | data_tmp[i_p,i_f:] = pad 179 | break 180 | data_numpy = np.transpose(data_tmp, [3,1,2,0]) 181 | return data_numpy -------------------------------------------------------------------------------- /figures/Architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/figures/Architecture.png -------------------------------------------------------------------------------- /figures/angle.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/figures/angle.png -------------------------------------------------------------------------------- /figures/skeletons.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/figures/skeletons.png -------------------------------------------------------------------------------- /graph/__init__.py: 
-------------------------------------------------------------------------------- 1 | from . import tools 2 | from . import ntu_rgb_d 3 | from . import kinetics 4 | from . import azure_kinect 5 | from . import directed_ntu_rgb_d 6 | -------------------------------------------------------------------------------- /graph/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/graph/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /graph/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/graph/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /graph/__pycache__/ang_adjs.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/graph/__pycache__/ang_adjs.cpython-36.pyc -------------------------------------------------------------------------------- /graph/__pycache__/ang_adjs.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/graph/__pycache__/ang_adjs.cpython-37.pyc -------------------------------------------------------------------------------- /graph/__pycache__/azure_kinect.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/graph/__pycache__/azure_kinect.cpython-36.pyc -------------------------------------------------------------------------------- /graph/__pycache__/azure_kinect.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/graph/__pycache__/azure_kinect.cpython-37.pyc -------------------------------------------------------------------------------- /graph/__pycache__/directed_ntu_rgb_d.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/graph/__pycache__/directed_ntu_rgb_d.cpython-36.pyc -------------------------------------------------------------------------------- /graph/__pycache__/directed_ntu_rgb_d.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/graph/__pycache__/directed_ntu_rgb_d.cpython-37.pyc -------------------------------------------------------------------------------- /graph/__pycache__/hyper_graphs.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/graph/__pycache__/hyper_graphs.cpython-36.pyc 
-------------------------------------------------------------------------------- /graph/__pycache__/hyper_graphs.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/graph/__pycache__/hyper_graphs.cpython-37.pyc -------------------------------------------------------------------------------- /graph/__pycache__/kinetics.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/graph/__pycache__/kinetics.cpython-36.pyc -------------------------------------------------------------------------------- /graph/__pycache__/kinetics.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/graph/__pycache__/kinetics.cpython-37.pyc -------------------------------------------------------------------------------- /graph/__pycache__/ntu_rgb_d.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/graph/__pycache__/ntu_rgb_d.cpython-36.pyc -------------------------------------------------------------------------------- /graph/__pycache__/ntu_rgb_d.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/graph/__pycache__/ntu_rgb_d.cpython-37.pyc -------------------------------------------------------------------------------- /graph/__pycache__/tools.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/graph/__pycache__/tools.cpython-36.pyc -------------------------------------------------------------------------------- /graph/__pycache__/tools.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/graph/__pycache__/tools.cpython-37.pyc -------------------------------------------------------------------------------- /graph/ang_adjs.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from graph.tools import normalize_adjacency_matrix 5 | 6 | 7 | def get_ang_adjs(data_type): 8 | rtn_adjs = [] 9 | 10 | if data_type == 'ntu': 11 | node_num = 25 12 | sym_pairs = ((21, 2), (11, 7), (18, 14), (20, 16), (24, 25), (22, 23)) 13 | for a_sym_pair in sym_pairs: 14 | a_adj = np.eye(node_num) 15 | a_adj[:, a_sym_pair[0]-1] = 1 16 | a_adj[:, a_sym_pair[1]-1] = 1 17 | a_adj = torch.tensor(normalize_adjacency_matrix(a_adj)) 18 | rtn_adjs.append(a_adj) 19 | 20 | return torch.cat(rtn_adjs, dim=0) 21 | -------------------------------------------------------------------------------- /graph/azure_kinect.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.insert(0, '') 3 | sys.path.extend(['../']) 4 | 5 | import numpy as np 6 | 7 | from graph import tools 
8 | 9 | num_node = 32 10 | self_link = [(i, i) for i in range(num_node)] 11 | inward_ori_index = [ 12 | (1, 0), 13 | (2, 1), 14 | (3, 2), 15 | (4, 2), 16 | (5, 4), 17 | (6, 5), 18 | (7, 6), 19 | (8, 7), 20 | (9, 8), 21 | (10, 7), 22 | (11, 2), 23 | (12, 11), 24 | (13, 12), 25 | (14, 13), 26 | (15, 14), 27 | (16, 15), 28 | (17, 14), 29 | (18, 0), 30 | (19, 18), 31 | (20, 19), 32 | (21, 20), 33 | (22, 0), 34 | (23, 22), 35 | (24, 23), 36 | (25, 24), 37 | (26, 3), 38 | (27, 26), 39 | (28, 27), 40 | (29, 28), 41 | (30, 27), 42 | (31, 30), 43 | ] 44 | inward = [(i, j) for (i, j) in inward_ori_index] 45 | outward = [(j, i) for (i, j) in inward] 46 | neighbor = inward + outward 47 | 48 | 49 | class Graph: 50 | def __init__(self, labeling_mode='spatial'): 51 | self.A = self.get_adjacency_matrix(labeling_mode) 52 | self.num_node = num_node 53 | self.self_link = self_link 54 | self.inward = inward 55 | self.outward = outward 56 | self.neighbor = neighbor 57 | 58 | def get_adjacency_matrix(self, labeling_mode=None): 59 | if labeling_mode is None: 60 | return self.A 61 | if labeling_mode == 'spatial': 62 | A = tools.get_spatial_graph(num_node, self_link, inward, outward) 63 | else: 64 | raise ValueError() 65 | return A 66 | 67 | 68 | class AdjMatrixGraph: 69 | def __init__(self, *args, **kwargs): 70 | self.edges = neighbor 71 | self.num_nodes = num_node 72 | self.self_loops = [(i, i) for i in range(self.num_nodes)] 73 | self.A_binary = tools.get_adjacency_matrix(self.edges, self.num_nodes) 74 | self.A_binary_with_I = tools.get_adjacency_matrix(self.edges + self.self_loops, self.num_nodes) 75 | self.A = tools.normalize_adjacency_matrix(self.A_binary) 76 | 77 | 78 | if __name__ == '__main__': 79 | import matplotlib.pyplot as plt 80 | graph = AdjMatrixGraph() 81 | A, A_binary, A_binary_with_I = graph.A, graph.A_binary, graph.A_binary_with_I 82 | f, ax = plt.subplots(1, 3) 83 | ax[0].imshow(A_binary_with_I, cmap='gray') 84 | ax[1].imshow(A_binary, cmap='gray') 85 | ax[2].imshow(A, cmap='gray') 86 | plt.show() 87 | print(A_binary_with_I.shape, A_binary.shape, A.shape) 88 | -------------------------------------------------------------------------------- /graph/directed_ntu_rgb_d.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, List 2 | from collections import defaultdict 3 | 4 | import numpy as np 5 | from scipy import special 6 | 7 | # For NTU RGB+D, assume node 21 (centre of chest) 8 | # is the "centre of gravity" mentioned in the paper 9 | 10 | num_nodes = 25 11 | epsilon = 1e-6 12 | 13 | # Directed edges: (source, target), see 14 | # https://www.cv-foundation.org/openaccess/content_cvpr_2016/papers/Shahroudy_NTU_RGBD_A_CVPR_2016_paper.pdf 15 | # for node IDs, and reduce index to 0-based 16 | directed_edges = [(i-1, j-1) for i, j in [ 17 | (1, 13), (1, 17), (2, 1), (3, 4), (5, 6), 18 | (6, 7), (7, 8), (8, 22), (8, 23), (9, 10), 19 | (10, 11), (11, 12), (12, 24), (12, 25), (13, 14), 20 | (14, 15), (15, 16), (17, 18), (18, 19), (19, 20), 21 | (21, 2), (21, 3), (21, 5), (21, 9), 22 | (21, 21) # Add self loop for Node 21 (the centre) to avoid singular matrices 23 | ]] 24 | 25 | # NOTE: for now, let's not add self loops since the paper didn't mention this 26 | # self_loops = [(i, i) for i in range(num_nodes)] 27 | 28 | 29 | def build_digraph_adj_list(edges: List[Tuple]) -> np.ndarray: 30 | graph = defaultdict(list) 31 | for source, target in edges: 32 | graph[source].append(target) 33 | return graph 34 | 35 | 36 | def 
normalize_incidence_matrix(im: np.ndarray, full_im: np.ndarray) -> np.ndarray: 37 | # NOTE: 38 | # 1. The paper assumes that the Incidence matrix is square, 39 | # so that the normalized form A @ (D ** -1) is viable. 40 | # However, if the incidence matrix is non-square, then 41 | # the above normalization won't work. 42 | # For now, move the term (D ** -1) to the front 43 | # 2. It's not too clear whether the degree matrix of the FULL incidence matrix 44 | # should be calculated, or just the target/source IMs. 45 | # However, target/source IMs are SINGULAR matrices since not all nodes 46 | # have incoming/outgoing edges, but the full IM as described by the paper 47 | # is also singular, since ±1 is used for target/source nodes. 48 | # For now, we'll stick with adding target/source IMs. 49 | degree_mat = full_im.sum(-1) * np.eye(len(full_im)) 50 | # Since all nodes should have at least some edge, degree matrix is invertible 51 | inv_degree_mat = np.linalg.inv(degree_mat) 52 | return (inv_degree_mat @ im) + epsilon 53 | 54 | 55 | def build_digraph_incidence_matrix(num_nodes: int, edges: List[Tuple]) -> np.ndarray: 56 | # NOTE: For now, we won't consider all possible edges 57 | # max_edges = int(special.comb(num_nodes, 2)) 58 | max_edges = len(edges) 59 | source_graph = np.zeros((num_nodes, max_edges), dtype='float32') 60 | target_graph = np.zeros((num_nodes, max_edges), dtype='float32') 61 | for edge_id, (source_node, target_node) in enumerate(edges): 62 | source_graph[source_node, edge_id] = 1. 63 | target_graph[target_node, edge_id] = 1. 64 | full_graph = source_graph + target_graph 65 | source_graph = normalize_incidence_matrix(source_graph, full_graph) 66 | target_graph = normalize_incidence_matrix(target_graph, full_graph) 67 | return source_graph, target_graph 68 | 69 | 70 | def build_digraph_adj_matrix(edges: List[Tuple]) -> np.ndarray: 71 | graph = np.zeros((num_nodes, num_nodes), dtype='float32') 72 | for edge in edges: 73 | graph[edge] = 1 74 | return graph 75 | 76 | 77 | class Graph: 78 | def __init__(self): 79 | super().__init__() 80 | self.num_nodes = num_nodes 81 | self.edges = directed_edges 82 | # Incidence matrices 83 | self.source_M, self.target_M = \ 84 | build_digraph_incidence_matrix(self.num_nodes, self.edges) 85 | 86 | 87 | # TODO: 88 | # Check whether self loop should be added inside the graph 89 | # Check incidence matrix size 90 | 91 | 92 | if __name__ == "__main__": 93 | import matplotlib.pyplot as plt 94 | graph = Graph() 95 | source_M = graph.source_M 96 | target_M = graph.target_M 97 | plt.imshow(source_M, cmap='gray') 98 | plt.show() 99 | plt.imshow(target_M, cmap='gray') 100 | plt.show() 101 | print(source_M) 102 | -------------------------------------------------------------------------------- /graph/hyper_graphs.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import numpy as np 3 | from graph import tools 4 | import torch 5 | 6 | sys.path.insert(0, '') 7 | sys.path.extend(['../']) 8 | 9 | 10 | local_bone_hyper_edge_dict = { 11 | 'ntu': ( 12 | (1, 17, 13), (2, 21, 1), (3, 4, 21), (4, 4, 4), (5, 21, 6), (6, 5, 7), (7, 6, 8), (8, 23, 22), 13 | (9, 10, 21), (10, 11, 9), (11, 12, 10), (12, 24, 25), (13, 1, 14), (14, 13, 15), (15, 14, 16), 14 | (16, 16, 16), (17, 18, 1), (18, 19, 17), (19, 20, 18), (20, 20, 20), (21, 9, 5), (22, 8, 23), 15 | (23, 8, 22), (24, 25, 12), (25, 24, 12) 16 | ) 17 | } 18 | 19 | 20 | center_hyper_edge_dict = { 21 | 'ntu': ( 22 | (1, 2, 21), (2, 2, 21), (3, 2, 21), (4, 2, 21), (5, 2, 
21), (6, 2, 21), (7, 2, 21), (8, 2, 21), 23 | (9, 2, 21), (10, 2, 21), (11, 2, 21), (12, 2, 21), (13, 2, 21), (14, 2, 21), (15, 2, 21), 24 | (16, 2, 21), (17, 2, 21), (18, 2, 21), (19, 2, 21), (20, 2, 21), (21, 2, 21), (22, 2, 21), 25 | (23, 2, 21), (24, 2, 21), (25, 2, 21), 26 | ) 27 | } 28 | 29 | 30 | figure_l_hyper_edge_dict = { 31 | 'ntu': ( 32 | (1, 24, 25), (2, 24, 25), (3, 24, 25), (4, 24, 25), (5, 24, 25), (6, 24, 25), (7, 24, 25), (8, 24, 25), 33 | (9, 24, 25), (10, 24, 25), (11, 24, 25), (12, 24, 25), (13, 24, 25), (14, 24, 25), (15, 24, 25), 34 | (16, 24, 25), (17, 24, 25), (18, 24, 25), (19, 24, 25), (20, 24, 25), (21, 24, 25), (22, 24, 25), 35 | (23, 24, 25), (24, 24, 25), (25, 24, 25), 36 | ) 37 | } 38 | 39 | 40 | figure_r_hyper_edge_dict = { 41 | 'ntu': ( 42 | (1, 22, 23), (2, 22, 23), (3, 22, 23), (4, 22, 23), (5, 22, 23), (6, 22, 23), (7, 22, 23), (8, 22, 23), 43 | (9, 22, 23), (10, 22, 23), (11, 22, 23), (12, 22, 23), (13, 22, 23), (14, 22, 23), (15, 22, 23), 44 | (16, 22, 23), (17, 22, 23), (18, 22, 23), (19, 22, 23), (20, 22, 23), (21, 22, 23), (22, 22, 23), 45 | (23, 22, 23), (24, 22, 23), (25, 22, 23), 46 | ) 47 | } 48 | 49 | 50 | hand_hyper_edge_dict = { 51 | 'ntu': ( 52 | (1, 24, 22), (2, 24, 22), (3, 24, 22), (4, 24, 22), (5, 24, 22), (6, 24, 22), (7, 24, 22), (8, 24, 22), 53 | (9, 24, 22), (10, 24, 22), (11, 24, 22), (12, 24, 22), (13, 24, 22), (14, 24, 22), (15, 24, 22), 54 | (16, 24, 22), (17, 24, 22), (18, 24, 22), (19, 24, 22), (20, 24, 22), (21, 24, 22), (22, 24, 22), 55 | (23, 24, 22), (24, 24, 22), (25, 24, 22), 56 | ) 57 | } 58 | 59 | 60 | elbow_hyper_edge = { 61 | 'ntu': ( 62 | (1, 10, 6), (2, 10, 6), (3, 10, 6), (4, 10, 6), (5, 10, 6), (6, 10, 6), (7, 10, 6), (8, 10, 6), 63 | (9, 10, 6), (10, 10, 6), (11, 10, 6), (12, 10, 6), (13, 10, 6), (14, 10, 6), (15, 10, 6), 64 | (16, 10, 6), (17, 10, 6), (18, 10, 6), (19, 10, 6), (20, 10, 6), (21, 10, 6), (22, 10, 6), 65 | (23, 10, 6), (24, 10, 6), (25, 10, 6), 66 | ) 67 | } 68 | 69 | 70 | foot_hyper_edge = { 71 | 'ntu': ( 72 | (1, 20, 16), (2, 20, 16), (3, 20, 16), (4, 20, 16), (5, 20, 16), (6, 20, 16), (7, 20, 16), (8, 20, 16), 73 | (9, 20, 16), (10, 20, 16), (11, 20, 16), (12, 20, 16), (13, 20, 16), (14, 20, 16), (15, 20, 16), 74 | (16, 20, 16), (17, 20, 16), (18, 20, 16), (19, 20, 16), (20, 20, 16), (21, 20, 16), (22, 20, 16), 75 | (23, 20, 16), (24, 20, 16), (25, 20, 16), 76 | ) 77 | } 78 | 79 | 80 | def get_hyper_edge(dataset, edge_type): 81 | if edge_type == 'local_bone': 82 | tgt_dict = local_bone_hyper_edge_dict 83 | elif edge_type == 'center': 84 | tgt_dict = center_hyper_edge_dict 85 | elif edge_type == 'figure_l': 86 | tgt_dict = figure_l_hyper_edge_dict 87 | elif edge_type == 'figure_r': 88 | tgt_dict = figure_r_hyper_edge_dict 89 | elif edge_type == 'hand': 90 | tgt_dict = hand_hyper_edge_dict 91 | elif edge_type == 'elbow': 92 | tgt_dict = elbow_hyper_edge 93 | elif edge_type == 'foot': 94 | tgt_dict = foot_hyper_edge 95 | else: 96 | raise NotImplementedError 97 | tgt_hyper_edge = tgt_dict[dataset] 98 | if 'ntu' in dataset: 99 | node_num = 25 100 | hyper_edge_adj = torch.zeros((node_num, node_num)) 101 | else: 102 | raise NotImplementedError 103 | for i in range(node_num): 104 | edge_idx = 0 105 | for a_hyper_edge in tgt_hyper_edge: 106 | if (i+1) in a_hyper_edge: 107 | hyper_edge_adj[i][edge_idx] = 1 108 | edge_idx += 1 109 | 110 | return hyper_edge_adj 111 | 112 | 113 | 114 | def get_local_bone_hyper_adj(dataset): 115 | if 'ntu' in dataset: 116 | the_hyper_edge = [ 117 | (17, 13), 118 | (21, 1), 
119 | (4, 21), 120 | (21, 6), 121 | (5, 7), 122 | (6, 8), 123 | (23, 22), 124 | (10, 21), 125 | (11, 9), 126 | (12, 10), 127 | (24, 25), 128 | (1, 14), 129 | (13, 15), 130 | (14, 16), 131 | (18, 1), 132 | (19, 17), 133 | (20, 18), 134 | (20, 20), 135 | (9, 5), 136 | (8, 23), 137 | (8, 22), 138 | (25, 12), 139 | (24, 12) 140 | ] 141 | else: 142 | raise NotImplementedError 143 | return np.array(the_hyper_edge) 144 | 145 | 146 | def get_ntu_local_bone_neighbor(): 147 | local_bone_inward_ori_index = [(25, 24), (25, 12), (24, 25), (24, 12), (12, 24), (12, 25), 148 | (11, 12), (11, 10), (10, 11), (10, 9), (9, 10), (9, 21), 149 | (21, 9), (21, 5), (5, 21), (5, 6), (6, 5), (6, 7), (7, 6), (7, 8), 150 | (8, 23), (8, 22), (22, 8), (22, 23), (23, 8), (23, 22), (3, 4), 151 | (3, 21), (2, 21), (2, 1), (1, 17), (1, 13), (17, 18), (17, 1), 152 | (18, 19), (18, 17), (19, 20), (19, 18), (13, 1), (13, 14), 153 | (14, 13), (14, 15), (15, 14), (15, 16) 154 | ] 155 | inward = [(i - 1, j - 1) for (i, j) in local_bone_inward_ori_index] 156 | outward = [(j, i) for (i, j) in inward] 157 | neighbor = inward + outward 158 | return neighbor 159 | 160 | 161 | class LocalBoneAdj: 162 | def __init__(self, dataset): 163 | if 'ntu' in dataset: 164 | num_node = 25 165 | self.edges = get_ntu_local_bone_neighbor() 166 | elif 'kinetics' in dataset: 167 | num_node = 18 168 | raise NotImplementedError 169 | else: 170 | raise NotImplementedError 171 | self.num_nodes = num_node 172 | self.self_loops = [(i, i) for i in range(self.num_nodes)] 173 | self.A_binary = tools.get_adjacency_matrix(self.edges, self.num_nodes) 174 | self.A_binary_with_I = tools.get_adjacency_matrix(self.edges + self.self_loops, self.num_nodes) 175 | 176 | 177 | if __name__ == '__main__': 178 | rtn = get_local_bone_hyper_edges('ntu') 179 | print('rtn: ', rtn) 180 | -------------------------------------------------------------------------------- /graph/kinetics.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.insert(0, '') 3 | sys.path.extend(['../']) 4 | 5 | import numpy as np 6 | 7 | from graph import tools 8 | 9 | # Joint index: 10 | # {0, "Nose"} 11 | # {1, "Neck"}, 12 | # {2, "RShoulder"}, 13 | # {3, "RElbow"}, 14 | # {4, "RWrist"}, 15 | # {5, "LShoulder"}, 16 | # {6, "LElbow"}, 17 | # {7, "LWrist"}, 18 | # {8, "RHip"}, 19 | # {9, "RKnee"}, 20 | # {10, "RAnkle"}, 21 | # {11, "LHip"}, 22 | # {12, "LKnee"}, 23 | # {13, "LAnkle"}, 24 | # {14, "REye"}, 25 | # {15, "LEye"}, 26 | # {16, "REar"}, 27 | # {17, "LEar"}, 28 | 29 | num_node = 18 30 | self_link = [(i, i) for i in range(num_node)] 31 | inward = [(4, 3), (3, 2), (7, 6), (6, 5), (13, 12), (12, 11), (10, 9), (9, 8), 32 | (11, 5), (8, 2), (5, 1), (2, 1), (0, 1), (15, 0), (14, 0), (17, 15), 33 | (16, 14)] 34 | outward = [(j, i) for (i, j) in inward] 35 | neighbor = inward + outward 36 | 37 | 38 | class AdjMatrixGraph: 39 | def __init__(self, *args, **kwargs): 40 | self.num_nodes = num_node 41 | self.edges = neighbor 42 | self.self_loops = [(i, i) for i in range(self.num_nodes)] 43 | self.A_binary = tools.get_adjacency_matrix(self.edges, self.num_nodes) 44 | self.A_binary_with_I = tools.get_adjacency_matrix(self.edges + self.self_loops, self.num_nodes) 45 | 46 | 47 | if __name__ == '__main__': 48 | graph = AdjMatrixGraph() 49 | A_binary = graph.A_binary 50 | import matplotlib.pyplot as plt 51 | print(A_binary) 52 | plt.matshow(A_binary) 53 | plt.show() 54 | -------------------------------------------------------------------------------- 
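A minimal usage sketch (illustrative only, not a file in this repository) tying the graph package together: `AdjMatrixGraph` exposes the binary skeleton adjacency with and without self-loops plus its symmetrically normalized form, and the model code further below stacks normalized k-hop versions of it via `k_adjacency` and `normalize_adjacency_matrix`. The `num_scales` value here is an arbitrary example.

import numpy as np
from graph.ntu_rgb_d import AdjMatrixGraph
from graph.tools import k_adjacency, normalize_adjacency_matrix

graph = AdjMatrixGraph()
A_binary = graph.A_binary                  # (25, 25) binary NTU adjacency, no self-loops
print(A_binary.shape, graph.A_binary_with_I.shape, graph.A.shape)

# Disentangled multi-scale aggregation (as in MultiScale_GraphConv): one normalized
# k-hop adjacency per scale, stacked along the first axis.
num_scales = 4                             # example value; the configs choose this
A_powers = [k_adjacency(A_binary, k, with_self=True) for k in range(num_scales)]
A_powers = np.concatenate([normalize_adjacency_matrix(g) for g in A_powers])
print(A_powers.shape)                      # (num_scales * 25, 25)

Stacking the per-scale adjacencies row-wise is what lets the model apply every neighborhood scale with a single `torch.einsum('vu,nctu->nctv', A, x)`, as done in `MultiScale_GraphConv.forward`.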
/graph/ntu_rgb_d.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.insert(0, '') 3 | sys.path.extend(['../']) 4 | 5 | import numpy as np 6 | 7 | from graph import tools 8 | 9 | num_node = 25 10 | self_link = [(i, i) for i in range(num_node)] 11 | inward_ori_index = [(1, 2), (2, 21), (3, 21), (4, 3), (5, 21), (6, 5), (7, 6), 12 | (8, 7), (9, 21), (10, 9), (11, 10), (12, 11), (13, 1), 13 | (14, 13), (15, 14), (16, 15), (17, 1), (18, 17), (19, 18), 14 | (20, 19), (22, 23), (23, 8), (24, 25), (25, 12)] 15 | inward = [(i - 1, j - 1) for (i, j) in inward_ori_index] 16 | outward = [(j, i) for (i, j) in inward] 17 | neighbor = inward + outward 18 | 19 | 20 | class AdjMatrixGraph: 21 | def __init__(self, *args, **kwargs): 22 | self.edges = neighbor 23 | self.num_nodes = num_node 24 | self.self_loops = [(i, i) for i in range(self.num_nodes)] 25 | self.A_binary = tools.get_adjacency_matrix(self.edges, self.num_nodes) 26 | self.A_binary_with_I = tools.get_adjacency_matrix(self.edges + self.self_loops, self.num_nodes) 27 | self.A = tools.normalize_adjacency_matrix(self.A_binary) 28 | 29 | 30 | if __name__ == '__main__': 31 | import matplotlib.pyplot as plt 32 | graph = AdjMatrixGraph() 33 | A, A_binary, A_binary_with_I = graph.A, graph.A_binary, graph.A_binary_with_I 34 | f, ax = plt.subplots(1, 3) 35 | ax[0].imshow(A_binary_with_I, cmap='gray') 36 | ax[1].imshow(A_binary, cmap='gray') 37 | ax[2].imshow(A, cmap='gray') 38 | plt.show() 39 | print(A_binary_with_I.shape, A_binary.shape, A.shape) 40 | 41 | 42 | class Graph: 43 | def __init__(self, labeling_mode='spatial'): 44 | self.A = self.get_adjacency_matrix(labeling_mode) 45 | self.num_node = num_node 46 | self.self_link = self_link 47 | self.inward = inward 48 | self.outward = outward 49 | self.neighbor = neighbor 50 | 51 | def get_adjacency_matrix(self, labeling_mode=None): 52 | if labeling_mode is None: 53 | return self.A 54 | if labeling_mode == 'spatial': 55 | A = tools.get_spatial_graph(num_node, self_link, inward, outward) 56 | else: 57 | raise ValueError() 58 | return A 59 | -------------------------------------------------------------------------------- /graph/tools.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def edge2mat(link, num_node): 5 | A = np.zeros((num_node, num_node)) 6 | for i, j in link: 7 | A[j, i] = 1 8 | return A 9 | 10 | 11 | def normalize_digraph(A): 12 | Dl = np.sum(A, 0) 13 | h, w = A.shape 14 | Dn = np.zeros((w, w)) 15 | for i in range(w): 16 | if Dl[i] > 0: 17 | Dn[i, i] = Dl[i] ** (-1) 18 | AD = np.dot(A, Dn) 19 | return AD 20 | 21 | 22 | def get_spatial_graph(num_node, self_link, inward, outward): 23 | I = edge2mat(self_link, num_node) 24 | In = normalize_digraph(edge2mat(inward, num_node)) 25 | Out = normalize_digraph(edge2mat(outward, num_node)) 26 | A = np.stack((I, In, Out)) 27 | return A 28 | 29 | 30 | def k_adjacency(A, k, with_self=False, self_factor=1): 31 | assert isinstance(A, np.ndarray) 32 | I = np.eye(len(A), dtype=A.dtype) 33 | if k == 0: 34 | return I 35 | Ak = np.minimum(np.linalg.matrix_power(A + I, k), 1) \ 36 | - np.minimum(np.linalg.matrix_power(A + I, k - 1), 1) 37 | if with_self: 38 | Ak += (self_factor * I) 39 | return Ak 40 | 41 | 42 | def normalize_adjacency_matrix(A): 43 | node_degrees = A.sum(-1) 44 | degs_inv_sqrt = np.power(node_degrees, -0.5) 45 | norm_degs_matrix = np.eye(len(node_degrees)) * degs_inv_sqrt 46 | return (norm_degs_matrix @ A @ 
norm_degs_matrix).astype(np.float32) 47 | 48 | 49 | def get_adjacency_matrix(edges, num_nodes): 50 | A = np.zeros((num_nodes, num_nodes), dtype=np.float32) 51 | for edge in edges: 52 | A[edge] = 1. 53 | return A -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from . import network 2 | -------------------------------------------------------------------------------- /model/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/model/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/activation.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/model/__pycache__/activation.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/att_gcn.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/model/__pycache__/att_gcn.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/hyper_gcn.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/model/__pycache__/hyper_gcn.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/mlp.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/model/__pycache__/mlp.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/modules.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/model/__pycache__/modules.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/ms_gcn.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/model/__pycache__/ms_gcn.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/ms_gtcn.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/model/__pycache__/ms_gtcn.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/ms_tcn.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/model/__pycache__/ms_tcn.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/network.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/model/__pycache__/network.cpython-36.pyc -------------------------------------------------------------------------------- /model/activation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from model.modules import Sine 5 | 6 | 7 | def activation_factory(name, inplace=True): 8 | if name == 'relu': 9 | return nn.ReLU(inplace=inplace) 10 | elif name == 'leakyrelu': 11 | return nn.LeakyReLU(0.2, inplace=inplace) 12 | elif name == 'tanh': 13 | return nn.Tanh() 14 | elif name == 'linear' or name is None: 15 | return nn.Identity() 16 | elif name == 'sine': 17 | return Sine() 18 | else: 19 | raise ValueError('Not supported activation:', name) -------------------------------------------------------------------------------- /model/att_gcn.py: -------------------------------------------------------------------------------- 1 | import math 2 | import sys 3 | 4 | from torch.nn import TransformerEncoderLayer, TransformerEncoder 5 | 6 | from graph.ang_adjs import get_ang_adjs 7 | from graph.hyper_graphs import get_hyper_edge 8 | 9 | sys.path.insert(0, '') 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | import numpy as np 15 | 16 | from graph.tools import k_adjacency, normalize_adjacency_matrix 17 | from model.mlp import MLP 18 | from model.activation import activation_factory 19 | 20 | 21 | class Att_GraphConv(nn.Module): 22 | def __init__(self, 23 | in_channels, 24 | out_channels, 25 | dropout=0, 26 | activation='relu', 27 | **kwargs): 28 | super().__init__() 29 | 30 | self.local_bone_hyper_edges = get_hyper_edge('ntu', 'local_bone') 31 | self.center_hyper_edges = get_hyper_edge('ntu', 'center') 32 | self.figure_l_hyper_edges = get_hyper_edge('ntu', 'figure_l') 33 | self.figure_r_hyper_edges = get_hyper_edge('ntu', 'figure_r') 34 | self.hand_hyper_edges = get_hyper_edge('ntu', 'hand') 35 | # self.foot_hyper_edges = get_hyper_edge('ntu', 'foot') 36 | 37 | self.hyper_edge_num = 8 38 | self.in_fea_mlp = MLP(self.hyper_edge_num, [50, out_channels], dropout=dropout, activation=activation) 39 | self.in_fea_mlp_last = nn.Conv2d(out_channels, out_channels, kernel_size=1) 40 | 41 | def process_hyper_edge_w(self, he_w, device): 42 | he_w = he_w.repeat(1, 1, 1, he_w.shape[-2]) 43 | for i in range(he_w.shape[0]): 44 | for j in range(he_w.shape[1]): 45 | he_w[i][j] *= torch.eye(he_w.shape[-2]).to(device) 46 | return he_w 47 | 48 | def normalized_aggregate(self, w, h): 49 | degree_v = torch.einsum('ve,bte->btv', h, w) 50 | degree_e = torch.sum(h, dim=0) 51 | degree_v = torch.pow(degree_v, -0.5) 52 | # degree_v = torch.pow(degree_v, -1) 53 | degree_e = torch.pow(degree_e, -1) 54 | degree_v[degree_v == float("Inf")] = 0 55 | degree_v[degree_v != degree_v] = 0 56 | degree_e[degree_e == float("Inf")] = 0 57 | degree_e[degree_e != degree_e] = 0 58 | dh = torch.einsum('btv,ve->btve', degree_v, h) 59 | dhw = torch.einsum('btve,bte->btve', dh, w) 60 | dhwb = torch.einsum('btve,e->btve', dhw, degree_e) 61 | 
dhwbht = torch.einsum('btve,eu->btvu', dhwb, torch.transpose(h, 0, 1)) 62 | dhwbhtd = torch.einsum('btvu,btu->btvu', dhwbht, degree_v) 63 | if torch.max(dhwbhtd).item() != torch.max(dhwbhtd).item(): 64 | print('max h: ', torch.max(h).item(), 'min h: ', torch.min(h).item()) 65 | print('max w: ', torch.max(w).item(), 'min w: ', torch.min(w).item()) 66 | print('max degree v: ', torch.max(degree_v).item(), 'min degree v: ', torch.min(degree_v).item()) 67 | print('max degree e: ', torch.max(degree_e).item(), 'min degree e: ', torch.min(degree_e).item()) 68 | print('max dh: ', torch.max(dh).item(), 'min dh: ', torch.min(dh).item()) 69 | print('max dhw: ', torch.max(dhw).item(), 'min dhw: ', torch.min(dhw).item()) 70 | print('max dhwb: ', torch.max(dhwb).item(), 'min dhwb: ', torch.min(dhwb).item()) 71 | print('max dhwbht: ', torch.max(dhwbht).item(), 'min dhwbht: ', torch.min(dhwbht).item()) 72 | print('max dhwbhtd: ', torch.max(dhwbhtd).item(), 'min dhwbhtd: ', torch.min(dhwbhtd).item()) 73 | assert 0 74 | 75 | # dhwbhtd[dhwbhtd != dhwbhtd] = 0 76 | return dhwbhtd 77 | 78 | def att_convolve(self, x): 79 | cor_w = x[:, :3, :, :] 80 | local_bone_w = x[:, 6, :, :].unsqueeze(1) # 6 81 | center_w = x[:, 7, :, :].unsqueeze(1) # 7 82 | figure_l_w = x[:, 9, :, :].unsqueeze(1) # 9 83 | figure_r_w = x[:, 10, :, :].unsqueeze(1) # 10 84 | hand_w = x[:, 11, :, :].unsqueeze(1) # 11 85 | 86 | # Make channels more 87 | in_fea = torch.cat((cor_w, local_bone_w, center_w, figure_l_w, figure_r_w, hand_w), dim=1) 88 | in_fea = self.in_fea_mlp(in_fea) 89 | in_fea = self.in_fea_mlp_last(in_fea) 90 | in_fea = in_fea.permute(0, 2, 3, 1) 91 | 92 | in_fea = torch.einsum('btvm,btmu->btvu', in_fea, in_fea.permute(0, 1, 3, 2)) 93 | in_fea = torch.softmax(in_fea, dim=-1) 94 | return in_fea 95 | 96 | def forward(self, x): 97 | return self.att_convolve(x) 98 | 99 | 100 | if __name__ == "__main__": 101 | from graph.ntu_rgb_d import AdjMatrixGraph 102 | 103 | graph = AdjMatrixGraph() 104 | A_binary = graph.A_binary 105 | msgcn = MultiScale_GraphConv(num_scales=15, in_channels=3, out_channels=64, A_binary=A_binary) 106 | msgcn.forward(torch.randn(16, 3, 30, 25)) 107 | -------------------------------------------------------------------------------- /model/hyper_gcn.py: -------------------------------------------------------------------------------- 1 | import math 2 | import sys 3 | 4 | from torch.nn import TransformerEncoderLayer, TransformerEncoder 5 | 6 | from graph.ang_adjs import get_ang_adjs 7 | from graph.hyper_graphs import get_hyper_edge 8 | 9 | sys.path.insert(0, '') 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | import numpy as np 15 | 16 | from graph.tools import k_adjacency, normalize_adjacency_matrix 17 | from model.mlp import MLP 18 | from model.activation import activation_factory 19 | 20 | 21 | class Hyper_GraphConv(nn.Module): 22 | def __init__(self, 23 | in_channels, 24 | out_channels, 25 | dropout=0, 26 | activation='relu', 27 | **kwargs): 28 | super().__init__() 29 | 30 | self.local_bone_hyper_edges = get_hyper_edge('ntu', 'local_bone') 31 | self.center_hyper_edges = get_hyper_edge('ntu', 'center') 32 | self.figure_l_hyper_edges = get_hyper_edge('ntu', 'figure_l') 33 | self.figure_r_hyper_edges = get_hyper_edge('ntu', 'figure_r') 34 | self.hand_hyper_edges = get_hyper_edge('ntu', 'hand') 35 | # self.foot_hyper_edges = get_hyper_edge('ntu', 'foot') 36 | 37 | self.hyper_edge_num = 5 38 | 39 | self.mlp = MLP(in_channels * self.hyper_edge_num, [out_channels], 
dropout=dropout, activation=activation) 40 | 41 | self.fea_mlp = MLP(self.hyper_edge_num, [50, 50, self.hyper_edge_num], dropout=dropout, activation=activation) 42 | 43 | def process_hyper_edge_w(self, he_w, device): 44 | he_w = he_w.repeat(1, 1, 1, he_w.shape[-2]) 45 | for i in range(he_w.shape[0]): 46 | for j in range(he_w.shape[1]): 47 | he_w[i][j] *= torch.eye(he_w.shape[-2]).to(device) 48 | return he_w 49 | 50 | def normalized_aggregate(self, w, h): 51 | degree_v = torch.einsum('ve,bte->btv', h, w) 52 | degree_e = torch.sum(h, dim=0) 53 | degree_v = torch.pow(degree_v, -0.5) 54 | # degree_v = torch.pow(degree_v, -1) 55 | degree_e = torch.pow(degree_e, -1) 56 | degree_v[degree_v == float("Inf")] = 0 57 | degree_v[degree_v != degree_v] = 0 58 | degree_e[degree_e == float("Inf")] = 0 59 | degree_e[degree_e != degree_e] = 0 60 | dh = torch.einsum('btv,ve->btve', degree_v, h) 61 | dhw = torch.einsum('btve,bte->btve', dh, w) 62 | dhwb = torch.einsum('btve,e->btve', dhw, degree_e) 63 | dhwbht = torch.einsum('btve,eu->btvu', dhwb, torch.transpose(h, 0, 1)) 64 | dhwbhtd = torch.einsum('btvu,btu->btvu', dhwbht, degree_v) 65 | if torch.max(dhwbhtd).item() != torch.max(dhwbhtd).item(): 66 | print('max h: ', torch.max(h).item(), 'min h: ', torch.min(h).item()) 67 | print('max w: ', torch.max(w).item(), 'min w: ', torch.min(w).item()) 68 | print('max degree v: ', torch.max(degree_v).item(), 'min degree v: ', torch.min(degree_v).item()) 69 | print('max degree e: ', torch.max(degree_e).item(), 'min degree e: ', torch.min(degree_e).item()) 70 | print('max dh: ', torch.max(dh).item(), 'min dh: ', torch.min(dh).item()) 71 | print('max dhw: ', torch.max(dhw).item(), 'min dhw: ', torch.min(dhw).item()) 72 | print('max dhwb: ', torch.max(dhwb).item(), 'min dhwb: ', torch.min(dhwb).item()) 73 | print('max dhwbht: ', torch.max(dhwbht).item(), 'min dhwbht: ', torch.min(dhwbht).item()) 74 | print('max dhwbhtd: ', torch.max(dhwbhtd).item(), 'min dhwbhtd: ', torch.min(dhwbhtd).item()) 75 | assert 0 76 | 77 | # dhwbhtd[dhwbhtd != dhwbhtd] = 0 78 | return dhwbhtd 79 | 80 | def hyper_edge_convolve(self, x): 81 | self.local_bone_hyper_edges = self.local_bone_hyper_edges.to(x.device) 82 | self.center_hyper_edges = self.center_hyper_edges.to(x.device) 83 | self.figure_l_hyper_edges = self.figure_l_hyper_edges.to(x.device) 84 | self.figure_r_hyper_edges = self.figure_r_hyper_edges.to(x.device) 85 | self.hand_hyper_edges = self.hand_hyper_edges.to(x.device) 86 | # self.foot_hyper_edges = self.foot_hyper_edges.to(x.device) 87 | 88 | # Not make the angular feature learnable 89 | # print('max x: ', torch.max(x).item(), 'min x: ', torch.min(x).item()) 90 | local_bone_w = x[:, 6, :, :] # 6 91 | center_w = x[:, 7, :, :] # 7 92 | figure_l_w = x[:, 9, :, :] # 9 93 | figure_r_w = x[:, 10, :, :] # 10 94 | hand_w = x[:, 11, :, :] # 11 95 | # foot_w = x[:, 14, :, :].unsqueeze(1) # 14 96 | 97 | # Makes the angular feature learnble 98 | # local_bone_w = x[:, 0, :, :].unsqueeze(1) # 6 99 | # center_w = x[:, 0, :, :].unsqueeze(1) # 7 100 | # figure_l_w = x[:, 0, :, :].unsqueeze(1) # 9 101 | # figure_r_w = x[:, 0, :, :].unsqueeze(1) # 10 102 | # hand_w = x[:, 0, :, :].unsqueeze(1) # 11 103 | # # foot_w = x[:, 14, :, :].unsqueeze(1) # 14 104 | # 105 | # fea_w_cat = torch.cat((local_bone_w, center_w, figure_l_w, figure_r_w, 106 | # hand_w), dim=1) 107 | # fea_w_cat = self.fea_mlp(fea_w_cat) 108 | # local_bone_w = fea_w_cat[:, 0, :, :] 109 | # center_w = fea_w_cat[:, 1, :, :] 110 | # figure_l_w = fea_w_cat[:, 2, :, :] 111 | # figure_r_w = 
fea_w_cat[:, 3, :, :] 112 | # hand_w = fea_w_cat[:, 4, :, :] 113 | # # foot_w = fea_w_cat[:, 5, :, :] 114 | 115 | # local bone angle 116 | # local_bone_hwh = torch.einsum('ve,bte->btve', self.local_bone_hyper_edges, local_bone_w) 117 | # local_bone_hwh = torch.einsum('btvu,un->btvn', local_bone_hwh, 118 | # torch.transpose(self.local_bone_hyper_edges, 0, 1)) 119 | local_bone_hwh = self.normalized_aggregate(local_bone_w, self.local_bone_hyper_edges) 120 | 121 | # center angle 122 | # center_hwh = torch.einsum('ve,bte->btve', self.center_hyper_edges, center_w) 123 | # center_hwh = torch.einsum('btvu,un->btvn', center_hwh, 124 | # torch.transpose(self.center_hyper_edges, 0, 1)) 125 | center_hwh = self.normalized_aggregate(center_w, self.center_hyper_edges) 126 | 127 | # figure left angle 128 | # figure_l_hwh = torch.einsum('ve,bte->btve', self.figure_l_hyper_edges, figure_l_w) 129 | # figure_l_hwh = torch.einsum('btvu,un->btvn', figure_l_hwh, 130 | # torch.transpose(self.figure_l_hyper_edges, 0, 1)) 131 | figure_l_hwh = self.normalized_aggregate(figure_l_w, self.figure_l_hyper_edges) 132 | 133 | # figure right angle 134 | # figure_r_hwh = torch.einsum('ve,bte->btve', self.figure_r_hyper_edges, figure_r_w) 135 | # figure_r_hwh = torch.einsum('btvu,un->btvn', figure_r_hwh, 136 | # torch.transpose(self.figure_r_hyper_edges, 0, 1)) 137 | figure_r_hwh = self.normalized_aggregate(figure_r_w, self.figure_r_hyper_edges) 138 | 139 | # hand angle 140 | # hand_hwh = torch.einsum('ve,bte->btve', self.hand_hyper_edges, hand_w) 141 | # hand_hwh = torch.einsum('btvu,un->btvn', hand_hwh, 142 | # torch.transpose(self.hand_hyper_edges, 0, 1)) 143 | hand_hwh = self.normalized_aggregate(hand_w, self.hand_hyper_edges) 144 | 145 | # foot angle 146 | # foot_hwh = torch.einsum('ve,bte->btve', self.foot_hyper_edges, foot_w) 147 | # foot_hwh = torch.einsum('btvu,un->btvn', foot_hwh, 148 | # torch.transpose(self.foot_hyper_edges, 0, 1)) 149 | 150 | hwh_cat = torch.cat((local_bone_hwh, center_hwh, figure_l_hwh, figure_r_hwh, 151 | hand_hwh), dim=-2) 152 | 153 | # Softmax normalization 154 | # hwh_cat = torch.softmax(hwh_cat, dim=-1) 155 | 156 | # Use partial features 157 | x = x[:, :3, :, :] 158 | N, C, T, V = x.shape 159 | 160 | support = torch.einsum('btvu,bctu->bctv', hwh_cat, x) 161 | support = support.view(N, C, T, self.hyper_edge_num, V) 162 | support = support.permute(0, 3, 1, 2, 4).contiguous().view(N, self.hyper_edge_num * C, T, V) 163 | out = self.mlp(support) 164 | 165 | return out 166 | 167 | def forward(self, x): 168 | return self.hyper_edge_convolve(x) 169 | 170 | 171 | if __name__ == "__main__": 172 | from graph.ntu_rgb_d import AdjMatrixGraph 173 | 174 | graph = AdjMatrixGraph() 175 | A_binary = graph.A_binary 176 | msgcn = MultiScale_GraphConv(num_scales=15, in_channels=3, out_channels=64, A_binary=A_binary) 177 | msgcn.forward(torch.randn(16, 3, 30, 25)) 178 | -------------------------------------------------------------------------------- /model/mlp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from model.activation import activation_factory 6 | from utils_dir.utils_visual import plot_multiple_lines 7 | 8 | 9 | class MLP(nn.Module): 10 | def __init__(self, in_channels, out_channels, activation='relu', dropout=0): 11 | super().__init__() 12 | channels = [in_channels] + out_channels 13 | self.layers = nn.ModuleList() 14 | for i in range(1, len(channels)): 15 | if dropout > 0.001: 
16 | self.layers.append(nn.Dropout(p=dropout)) 17 | self.layers.append(nn.Conv2d(channels[i-1], channels[i], kernel_size=1)) 18 | self.layers.append(nn.BatchNorm2d(channels[i])) 19 | self.layers.append(activation_factory(activation, inplace=False)) 20 | 21 | def forward(self, x): 22 | # Input shape: (N,C,T,V) 23 | # 这里是学习同一个joint的不同尺度的信息. 24 | # the_mlp_w = torch.sum(self.layers[0].weight, dim=0).squeeze().view(13, -1).sum(0).cpu().numpy() 25 | # print('the_mlp_w: ', the_mlp_w) 26 | # plot_multiple_lines([the_mlp_w]) 27 | # assert 0 28 | for layer in self.layers: 29 | x = layer(x) 30 | return x 31 | 32 | -------------------------------------------------------------------------------- /model/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import numpy as np 4 | from collections import OrderedDict 5 | import math 6 | import torch.nn.functional as F 7 | 8 | 9 | class Sine(nn.Module): 10 | def __init(self): 11 | super().__init__() 12 | 13 | def forward(self, input): 14 | # See paper sec. 3.2, final paragraph, and supplement Sec. 1.5 for discussion of factor 30 15 | # return torch.sin(30 * input) 16 | return torch.sin(input) 17 | 18 | 19 | def sine_init(m): 20 | with torch.no_grad(): 21 | if hasattr(m, 'weight'): 22 | num_input = m.weight.size(-1) 23 | # See supplement Sec. 1.5 for discussion of factor 30 24 | m.weight.uniform_(-np.sqrt(6 / num_input) / 30, np.sqrt(6 / num_input) / 30) 25 | 26 | 27 | def first_layer_sine_init(m): 28 | with torch.no_grad(): 29 | if hasattr(m, 'weight'): 30 | num_input = m.weight.size(-1) 31 | # See paper sec. 3.2, final paragraph, and supplement Sec. 1.5 for discussion of factor 30 32 | m.weight.uniform_(-1 / num_input, 1 / num_input) 33 | 34 | 35 | ################### 36 | # Complex operators 37 | def compl_conj(x): 38 | y = x.clone() 39 | y[..., 1::2] = -1 * y[..., 1::2] 40 | return y 41 | 42 | 43 | def compl_div(x, y): 44 | ''' x / y ''' 45 | a = x[..., ::2] 46 | b = x[..., 1::2] 47 | c = y[..., ::2] 48 | d = y[..., 1::2] 49 | 50 | outr = (a * c + b * d) / (c ** 2 + d ** 2) 51 | outi = (b * c - a * d) / (c ** 2 + d ** 2) 52 | out = torch.zeros_like(x) 53 | out[..., ::2] = outr 54 | out[..., 1::2] = outi 55 | return out 56 | 57 | 58 | def compl_mul(x, y): 59 | ''' x * y ''' 60 | a = x[..., ::2] 61 | b = x[..., 1::2] 62 | c = y[..., ::2] 63 | d = y[..., 1::2] 64 | 65 | outr = a * c - b * d 66 | outi = (a + b) * (c + d) - a * c - b * d 67 | out = torch.zeros_like(x) 68 | out[..., ::2] = outr 69 | out[..., 1::2] = outi 70 | return out 71 | -------------------------------------------------------------------------------- /model/ms_gcn.py: -------------------------------------------------------------------------------- 1 | import math 2 | import sys 3 | 4 | from torch.nn import TransformerEncoderLayer, TransformerEncoder 5 | 6 | from graph.ang_adjs import get_ang_adjs 7 | from model.hyper_gcn import Hyper_GraphConv 8 | 9 | sys.path.insert(0, '') 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | import numpy as np 15 | 16 | from graph.tools import k_adjacency, normalize_adjacency_matrix 17 | from model.mlp import MLP 18 | from model.activation import activation_factory 19 | 20 | 21 | class MultiScale_GraphConv(nn.Module): 22 | def __init__(self, 23 | num_scales, 24 | in_channels, 25 | out_channels, 26 | A_binary, 27 | disentangled_agg=True, 28 | use_mask=True, 29 | dropout=0, 30 | activation='relu', 31 | to_use_hyper_conv=False, 32 | 
**kwargs): 33 | super().__init__() 34 | self.num_scales = num_scales 35 | 36 | if disentangled_agg: 37 | A_powers = [k_adjacency(A_binary, k, with_self=True) for k in range(num_scales)] 38 | A_powers = np.concatenate([normalize_adjacency_matrix(g) for g in A_powers]) 39 | else: 40 | A_powers = [A_binary + np.eye(len(A_binary)) for k in range(num_scales)] 41 | A_powers = [normalize_adjacency_matrix(g) for g in A_powers] 42 | A_powers = [np.linalg.matrix_power(g, k) for k, g in enumerate(A_powers)] 43 | A_powers = np.concatenate(A_powers) 44 | 45 | self.A_powers = torch.Tensor(A_powers) 46 | 47 | if 'hyper_conv' in kwargs and kwargs['hyper_conv'] == 'ntu': 48 | hyper_adjs = get_ang_adjs('ntu') 49 | self.A_powers = torch.cat((self.A_powers, hyper_adjs), dim=0) 50 | if kwargs['hyper_conv'] == 'ntu': 51 | self.num_scales += 6 52 | elif kwargs['hyper_conv'] == 'kinetics': 53 | self.num_scales += 4 54 | 55 | # self.A_powers_param = torch.nn.Parameter(self.A_powers) 56 | 57 | self.use_mask = use_mask 58 | if use_mask: 59 | # NOTE: the inclusion of residual mask appears to slow down training noticeably 60 | self.A_res = nn.init.uniform_(nn.Parameter(torch.Tensor(self.A_powers.shape)), -1e-6, 1e-6) 61 | 62 | # 这个MLP根本就不是MLP, 这是卷积, 只是类似于MLP的功能. 63 | self.mlp = MLP(in_channels * self.num_scales, [out_channels], dropout=dropout, activation=activation) 64 | 65 | # Spatial Transformer Attention 66 | if 'to_use_spatial_transformer' in kwargs and kwargs['to_use_spatial_transformer']: 67 | self.to_use_spatial_trans = True 68 | self.trans_conv = nn.Conv2d(out_channels, 1, (1, 1), (1, 1)) 69 | self.temporal_len = kwargs['temporal_len'] 70 | nhead = 5 71 | nlayers = 2 72 | trans_dropout = 0.5 73 | encoder_layers = nn.TransformerEncoderLayer(self.temporal_len, 74 | nhead=nhead, dropout=trans_dropout) 75 | self.trans_enc = nn.TransformerEncoder(encoder_layers, nlayers) 76 | 77 | # spatial point normalization 78 | self.point_norm_layer = nn.Sigmoid() 79 | 80 | else: 81 | self.to_use_spatial_trans = False 82 | 83 | if 'to_use_sptl_trans_feature' in kwargs and kwargs['to_use_sptl_trans_feature']: 84 | self.to_use_sptl_trans_feature = True 85 | self.fea_dim = kwargs['fea_dim'] 86 | encoder_layers = nn.TransformerEncoderLayer(self.fea_dim, 87 | nhead=kwargs['sptl_trans_feature_n_head'], 88 | dropout=0.5) 89 | self.trans_enc_fea = nn.TransformerEncoder(encoder_layers, 90 | kwargs['sptl_trans_feature_n_layer']) 91 | else: 92 | self.to_use_sptl_trans_feature = False 93 | 94 | def forward(self, x): 95 | N, C, T, V = x.shape 96 | self.A_powers = self.A_powers.to(x.device) 97 | 98 | A = self.A_powers.to(x.dtype) 99 | if self.use_mask: 100 | A = A + self.A_res.to(x.dtype) 101 | 102 | support = torch.einsum('vu,nctu->nctv', A, x) 103 | 104 | support = support.view(N, C, T, self.num_scales, V) 105 | support = support.permute(0, 3, 1, 2, 4).contiguous().view(N, self.num_scales * C, T, V) 106 | 107 | out = self.mlp(support) 108 | 109 | # 实现kernel中, 只实现了一半. 
110 | # out = torch.einsum('nijtv,njktv->niktv', out.unsqueeze(2), out.unsqueeze(1)).view( 111 | # N, self.out_channels * self.out_channels, T, V 112 | # ) 113 | 114 | if self.to_use_spatial_trans: 115 | out_mean = self.trans_conv(out).squeeze() 116 | out_mean = out_mean.permute(0, 2, 1) 117 | out_mean = self.trans_enc(out_mean) 118 | out_mean = self.point_norm_layer(out_mean) 119 | out_mean = out_mean.permute(0, 2, 1) 120 | out_mean = torch.unsqueeze(out_mean, dim=1).repeat(1, out.shape[1], 1, 1) 121 | out = out_mean * out 122 | 123 | if self.to_use_sptl_trans_feature: 124 | out = out.permute(2, 3, 0, 1) 125 | for a_out_idx in range(len(out)): 126 | a_out = out[a_out_idx] 127 | a_out = self.trans_enc_fea(a_out) 128 | out[a_out_idx] = a_out 129 | out = out.permute(2, 3, 0, 1) 130 | 131 | return out 132 | 133 | 134 | if __name__ == "__main__": 135 | from graph.ntu_rgb_d import AdjMatrixGraph 136 | 137 | graph = AdjMatrixGraph() 138 | A_binary = graph.A_binary 139 | msgcn = MultiScale_GraphConv(num_scales=15, in_channels=3, out_channels=64, A_binary=A_binary) 140 | msgcn.forward(torch.randn(16, 3, 30, 25)) 141 | -------------------------------------------------------------------------------- /model/ms_gtcn.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.insert(0, '') 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import numpy as np 8 | 9 | from model.ms_tcn import MultiScale_TemporalConv as MS_TCN 10 | from model.mlp import MLP 11 | from model.activation import activation_factory 12 | from graph.tools import k_adjacency, normalize_adjacency_matrix 13 | 14 | 15 | class UnfoldTemporalWindows(nn.Module): 16 | def __init__(self, window_size, window_stride, window_dilation=1): 17 | super().__init__() 18 | self.window_size = window_size 19 | self.window_stride = window_stride 20 | self.window_dilation = window_dilation 21 | 22 | self.padding = (window_size + (window_size-1) * (window_dilation-1) - 1) // 2 23 | # unfold之后的结果就是: batch size x receptive filed size x 多少个receptive field 24 | self.unfold = nn.Unfold(kernel_size=(self.window_size, 1), 25 | dilation=(self.window_dilation, 1), 26 | stride=(self.window_stride, 1), 27 | padding=(self.padding, 0)) 28 | 29 | def forward(self, x): 30 | # Input shape: (N,C,T,V), out: (N,C,T,V*window_size) 31 | N, C, T, V = x.shape 32 | x = self.unfold(x) 33 | # Permute extra channels from window size to the graph dimension; -1 for number of windows 34 | x = x.view(N, C, self.window_size, -1, V).permute(0,1,3,2,4).contiguous() 35 | x = x.view(N, C, -1, self.window_size * V) 36 | return x 37 | 38 | 39 | class SpatialTemporal_MS_GCN(nn.Module): 40 | def __init__(self, 41 | in_channels, 42 | out_channels, 43 | A_binary, 44 | num_scales, 45 | window_size, 46 | disentangled_agg=True, 47 | use_Ares=True, 48 | residual=False, 49 | dropout=0, 50 | activation='relu'): 51 | 52 | super().__init__() 53 | self.num_scales = num_scales 54 | self.window_size = window_size 55 | self.use_Ares = use_Ares 56 | A = self.build_spatial_temporal_graph(A_binary, window_size) 57 | 58 | if disentangled_agg: 59 | A_scales = [k_adjacency(A, k, with_self=True) for k in range(num_scales)] 60 | A_scales = np.concatenate([normalize_adjacency_matrix(g) for g in A_scales]) 61 | else: 62 | # Self-loops have already been included in A 63 | A_scales = [normalize_adjacency_matrix(A) for k in range(num_scales)] 64 | A_scales = [np.linalg.matrix_power(g, k) for k, g in enumerate(A_scales)] 65 | 
A_scales = np.concatenate(A_scales) 66 | 67 | self.A_scales = torch.Tensor(A_scales) 68 | self.V = len(A_binary) 69 | 70 | if use_Ares: 71 | self.A_res = nn.init.uniform_(nn.Parameter(torch.randn(self.A_scales.shape)), -1e-6, 1e-6) 72 | else: 73 | self.A_res = torch.tensor(0) 74 | 75 | self.mlp = MLP(in_channels * num_scales, [out_channels], dropout=dropout, activation='linear') 76 | 77 | # Residual connection 78 | if not residual: 79 | self.residual = lambda x: 0 80 | elif (in_channels == out_channels): 81 | self.residual = lambda x: x 82 | else: 83 | self.residual = MLP(in_channels, [out_channels], activation='linear') 84 | 85 | self.act = activation_factory(activation) 86 | 87 | def build_spatial_temporal_graph(self, A_binary, window_size): 88 | assert isinstance(A_binary, np.ndarray), 'A_binary should be of type `np.ndarray`' 89 | V = len(A_binary) 90 | V_large = V * window_size 91 | A_binary_with_I = A_binary + np.eye(len(A_binary), dtype=A_binary.dtype) 92 | # Build spatial-temporal graph 93 | A_large = np.tile(A_binary_with_I, (window_size, window_size)).copy() 94 | return A_large 95 | 96 | def forward(self, x): 97 | N, C, T, V = x.shape # T = number of windows 98 | 99 | # Build graphs 100 | A = self.A_scales.to(x.dtype).to(x.device) + self.A_res.to(x.dtype).to(x.device) 101 | 102 | # Perform Graph Convolution 103 | res = self.residual(x) 104 | agg = torch.einsum('vu,nctu->nctv', A, x) 105 | agg = agg.view(N, C, T, self.num_scales, V) 106 | agg = agg.permute(0,3,1,2,4).contiguous().view(N, self.num_scales*C, T, V) 107 | out = self.mlp(agg) 108 | out += res 109 | return self.act(out) 110 | 111 | -------------------------------------------------------------------------------- /model/ms_tcn.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.insert(0, '') 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | from model.activation import activation_factory 8 | 9 | 10 | class TemporalConv(nn.Module): 11 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1): 12 | super(TemporalConv, self).__init__() 13 | pad = (kernel_size + (kernel_size-1) * (dilation-1) - 1) // 2 14 | self.conv = nn.Conv2d( 15 | in_channels, 16 | out_channels, 17 | kernel_size=(kernel_size, 1), 18 | padding=(pad, 0), 19 | stride=(stride, 1), 20 | dilation=(dilation, 1)) 21 | 22 | self.bn = nn.BatchNorm2d(out_channels) 23 | 24 | def forward(self, x): 25 | x = self.conv(x) 26 | x = self.bn(x) 27 | return x 28 | 29 | 30 | class MultiScale_TemporalConv(nn.Module): 31 | def __init__(self, 32 | in_channels, 33 | out_channels, 34 | kernel_size=3, 35 | stride=1, 36 | dilations=[1,2,3,4], 37 | residual=True, 38 | residual_kernel_size=1, 39 | activation='relu', 40 | **kwargs): 41 | 42 | super().__init__() 43 | assert out_channels % (len(dilations) + 2) == 0, '# out channels should be multiples of # branches' 44 | 45 | # Multiple branches of temporal convolution 46 | self.num_branches = len(dilations) + 2 47 | branch_channels = out_channels // self.num_branches 48 | 49 | # Temporal Convolution branches 50 | self.branches = nn.ModuleList([ 51 | nn.Sequential( 52 | nn.Conv2d( 53 | in_channels, 54 | branch_channels, 55 | kernel_size=1, 56 | padding=0), 57 | nn.BatchNorm2d(branch_channels), 58 | activation_factory(activation), 59 | # 在时间轴上对每一个joint做卷积 60 | TemporalConv( 61 | branch_channels, 62 | branch_channels, 63 | kernel_size=kernel_size, 64 | stride=stride, 65 | dilation=dilation), 66 | ) 67 | for dilation in dilations 68 | ]) 69 | 70 | # 
Additional Max & 1x1 branch 71 | self.branches.append(nn.Sequential( 72 | nn.Conv2d(in_channels, branch_channels, kernel_size=1, padding=0), 73 | nn.BatchNorm2d(branch_channels), 74 | activation_factory(activation), 75 | nn.MaxPool2d(kernel_size=(3,1), stride=(stride,1), padding=(1,0)), 76 | nn.BatchNorm2d(branch_channels) 77 | )) 78 | 79 | self.branches.append(nn.Sequential( 80 | nn.Conv2d(in_channels, branch_channels, kernel_size=1, padding=0, stride=(stride,1)), 81 | nn.BatchNorm2d(branch_channels) 82 | )) 83 | 84 | # Residual connection 85 | if not residual: 86 | self.residual = lambda x: 0 87 | elif (in_channels == out_channels) and (stride == 1): 88 | self.residual = lambda x: x 89 | else: 90 | self.residual = TemporalConv(in_channels, out_channels, kernel_size=residual_kernel_size, stride=stride) 91 | 92 | self.act = activation_factory(activation) 93 | 94 | # Transformer attention 95 | if 'to_use_temporal_transformer' in kwargs and kwargs['to_use_temporal_transformer']: 96 | self.to_use_temporal_trans = True 97 | self.section_size = kwargs['section_size'] 98 | self.num_point = kwargs['num_point'] 99 | self.trans_conv = nn.Conv2d(1, 1, (self.section_size, 1), (self.section_size, 1)) 100 | nhead = 5 101 | nlayers = 2 102 | trans_dropout = 0.5 103 | encoder_layers = nn.TransformerEncoderLayer(self.num_point, 104 | nhead=nhead, dropout=trans_dropout) 105 | self.trans_enc = nn.TransformerEncoder(encoder_layers, nlayers) 106 | 107 | # frame normalization 108 | self.frame_norm_layer = nn.Softmax(dim=1) 109 | if 'frame_norm' in kwargs: 110 | if kwargs['frame_norm'] == 'sigmoid': 111 | self.frame_norm_layer = nn.Sigmoid() 112 | 113 | else: 114 | self.to_use_temporal_trans = False 115 | 116 | # Transformer feature 117 | if 'to_use_temp_trans_feature' in kwargs and kwargs['to_use_temp_trans_feature']: 118 | self.to_use_temp_trans_feature = True 119 | self.fea_dim = kwargs['fea_dim'] 120 | nhead = kwargs['temp_trans_feature_n_head'] 121 | nlayers = kwargs['temp_trans_feature_n_layer'] 122 | trans_dropout = 0.5 123 | encoder_layers = nn.TransformerEncoderLayer(self.fea_dim, 124 | nhead=nhead, dropout=trans_dropout) 125 | self.trans_enc_fea = nn.TransformerEncoder(encoder_layers, nlayers) 126 | else: 127 | self.to_use_temp_trans_feature = False 128 | 129 | def forward(self, x): 130 | # Input dim: (N,C,T,V) 131 | res = self.residual(x) 132 | branch_outs = [] 133 | tempconv_idx = 0 134 | for tempconv in self.branches: 135 | x_in = x 136 | if self.to_use_temporal_trans: 137 | x_mean = torch.mean(x, dim=1) 138 | x_mean = x_mean.unsqueeze(1) 139 | x_mean = self.trans_conv(x_mean).squeeze(1) 140 | x_mean = self.trans_enc(x_mean) 141 | x_mean = self.frame_norm_layer(x_mean) 142 | x_mean = torch.repeat_interleave(x_mean, self.section_size, dim=1) 143 | x_mean = torch.unsqueeze(x_mean, dim=1).repeat(1, x.shape[1], 1, 1) 144 | x_in = x * x_mean 145 | out = tempconv(x_in) 146 | branch_outs.append(out) 147 | 148 | out = torch.cat(branch_outs, dim=1) 149 | 150 | if self.to_use_temp_trans_feature: 151 | out = out.permute(3, 2, 0, 1) 152 | for a_out_idx in range(len(out)): 153 | a_out = out[a_out_idx] 154 | a_out = self.trans_enc_fea(a_out) 155 | out[a_out_idx] = a_out 156 | out = out.permute(2, 3, 1, 0) 157 | 158 | out += res 159 | out = self.act(out) 160 | 161 | return out 162 | 163 | 164 | if __name__ == "__main__": 165 | mstcn = MultiScale_TemporalConv(288, 288) 166 | x = torch.randn(32, 288, 100, 20) 167 | mstcn.forward(x) 168 | for name, param in mstcn.named_parameters(): 169 | print(f'{name}: 
{param.numel()}') 170 | print(sum(p.numel() for p in mstcn.parameters() if p.requires_grad)) -------------------------------------------------------------------------------- /model/network.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import sys 3 | 4 | from model.att_gcn import Att_GraphConv 5 | from model.hyper_gcn import Hyper_GraphConv 6 | # from model.transformers import get_pretrained_transformer 7 | 8 | sys.path.insert(0, '') 9 | 10 | import math 11 | import numpy as np 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | 16 | from utils import import_class, count_params 17 | from model.ms_gcn import MultiScale_GraphConv as MS_GCN 18 | from model.ms_tcn import MultiScale_TemporalConv as MS_TCN 19 | from model.ms_gtcn import SpatialTemporal_MS_GCN, UnfoldTemporalWindows 20 | from model.mlp import MLP 21 | from model.activation import activation_factory 22 | 23 | 24 | class MS_G3D(nn.Module): 25 | def __init__(self, 26 | in_channels, 27 | out_channels, 28 | A_binary, 29 | num_scales, 30 | window_size, 31 | window_stride, 32 | window_dilation, 33 | embed_factor=1, 34 | nonlinear='relu'): 35 | super().__init__() 36 | 37 | self.window_size = window_size 38 | self.out_channels = out_channels 39 | self.embed_channels_in = self.embed_channels_out = out_channels // embed_factor 40 | if embed_factor == 1: 41 | self.in1x1 = nn.Identity() 42 | self.embed_channels_in = self.embed_channels_out = in_channels 43 | # The first STGC block changes channels right away; others change at collapse 44 | if in_channels == 3: 45 | self.embed_channels_out = out_channels 46 | else: 47 | self.in1x1 = MLP(in_channels, [self.embed_channels_in]) 48 | 49 | self.gcn3d = nn.Sequential( 50 | UnfoldTemporalWindows(window_size, window_stride, window_dilation), 51 | SpatialTemporal_MS_GCN( 52 | in_channels=self.embed_channels_in, 53 | out_channels=self.embed_channels_out, 54 | A_binary=A_binary, 55 | num_scales=num_scales, 56 | window_size=window_size, 57 | use_Ares=True, 58 | activation=nonlinear 59 | ) 60 | ) 61 | 62 | self.out_conv = nn.Conv3d(self.embed_channels_out, out_channels, kernel_size=(1, self.window_size, 1)) 63 | self.out_bn = nn.BatchNorm2d(out_channels) 64 | 65 | def forward(self, x): 66 | N, _, T, V = x.shape 67 | x = self.in1x1(x) 68 | # Construct temporal windows and apply MS-GCN 69 | x = self.gcn3d(x) 70 | 71 | # Collapse the window dimension 72 | x = x.view(N, self.embed_channels_out, -1, self.window_size, V) 73 | x = self.out_conv(x).squeeze(dim=3) 74 | x = self.out_bn(x) 75 | 76 | # no activation 77 | return x 78 | 79 | 80 | class MultiWindow_MS_G3D(nn.Module): 81 | def __init__(self, 82 | in_channels, 83 | out_channels, 84 | A_binary, 85 | num_scales, 86 | window_sizes=[3, 5], 87 | window_stride=1, 88 | window_dilations=[1, 1]): 89 | super().__init__() 90 | self.gcn3d = nn.ModuleList([ 91 | MS_G3D( 92 | in_channels, 93 | out_channels, 94 | A_binary, 95 | num_scales, 96 | window_size, 97 | window_stride, 98 | window_dilation 99 | ) 100 | for window_size, window_dilation in zip(window_sizes, window_dilations) 101 | ]) 102 | 103 | def forward(self, x): 104 | # Input shape: (N, C, T, V) 105 | out_sum = 0 106 | for gcn3d in self.gcn3d: 107 | out_sum += gcn3d(x) 108 | # no activation 109 | return out_sum 110 | 111 | ntu_bone_angle_pairs = { 112 | 25: (24, 12), 113 | 24: (25, 12), 114 | 12: (24, 25), 115 | 11: (12, 10), 116 | 10: (11, 9), 117 | 9: (10, 21), 118 | 21: (9, 5), 119 | 5: (21, 6), 120 | 6: (5, 7), 121 
| 7: (6, 8), 122 | 8: (23, 22), 123 | 22: (8, 23), 124 | 23: (8, 22), 125 | 3: (4, 21), 126 | 4: (4, 4), 127 | 2: (21, 1), 128 | 1: (17, 13), 129 | 17: (18, 1), 130 | 18: (19, 17), 131 | 19: (20, 18), 132 | 20: (20, 20), 133 | 13: (1, 14), 134 | 14: (13, 15), 135 | 15: (14, 16), 136 | 16: (16, 16) 137 | } 138 | 139 | ntu_bone_adj = { 140 | 25: 12, 141 | 24: 12, 142 | 12: 11, 143 | 11: 10, 144 | 10: 9, 145 | 9: 21, 146 | 21: 21, 147 | 5: 21, 148 | 6: 5, 149 | 7: 6, 150 | 8: 7, 151 | 22: 8, 152 | 23: 8, 153 | 3: 21, 154 | 4: 3, 155 | 2: 21, 156 | 1: 2, 157 | 17: 1, 158 | 18: 17, 159 | 19: 18, 160 | 20: 19, 161 | 13: 1, 162 | 14: 13, 163 | 15: 14, 164 | 16: 15 165 | } 166 | 167 | 168 | class Model(nn.Module): 169 | def __init__(self, 170 | num_class, 171 | num_point, 172 | num_person, 173 | num_gcn_scales, 174 | num_g3d_scales, 175 | graph, 176 | in_channels=3, 177 | ablation='original', 178 | to_use_final_fc=True, 179 | to_fc_last=True, 180 | frame_len=300, 181 | nonlinear='relu', 182 | **kwargs): 183 | super(Model, self).__init__() 184 | 185 | # cosine 186 | self.cos = nn.CosineSimilarity(dim=1, eps=0) 187 | 188 | # Activation function 189 | self.nonlinear_f = activation_factory(nonlinear) 190 | 191 | # ZQ ablation studies 192 | self.ablation = ablation 193 | 194 | Graph = import_class(graph) 195 | A_binary = Graph().A_binary 196 | 197 | self.data_bn = nn.BatchNorm1d(num_person * in_channels * num_point) 198 | 199 | # channels 200 | # c1 = 96 201 | # c2 = c1 * 2 # 192 202 | # c3 = c2 * 2 # 384 203 | 204 | c1 = 96 205 | self.c1 = c1 206 | c2 = c1 * 2 # 192 # Original implementation 207 | self.c2 = c2 208 | c3 = c2 * 2 # 384 # Original implementation 209 | self.c3 = c3 210 | 211 | # r=3 STGC blocks 212 | 213 | # MSG3D 214 | self.gcn3d1 = MultiWindow_MS_G3D(in_channels, c1, A_binary, num_g3d_scales, window_stride=1) 215 | self.sgcn1_msgcn = MS_GCN(num_gcn_scales, in_channels, c1, A_binary, disentangled_agg=True, 216 | **kwargs, temporal_len=frame_len, fea_dim=c1, to_use_hyper_conv=True, 217 | activation=nonlinear) 218 | self.sgcn1_ms_tcn_1 = MS_TCN(c1, c1, activation=nonlinear) 219 | self.sgcn1_ms_tcn_2 = MS_TCN(c1, c1, activation=nonlinear) 220 | self.sgcn1_ms_tcn_2.act = nn.Identity() 221 | 222 | if 'to_use_temporal_transformer' in kwargs and kwargs['to_use_temporal_transformer']: 223 | self.tcn1 = MS_TCN(c1, c1, **kwargs, 224 | section_size=kwargs['section_sizes'][0], num_point=num_point, 225 | fea_dim=c1, activation=nonlinear) 226 | else: 227 | self.tcn1 = MS_TCN(c1, c1, **kwargs, fea_dim=c1, activation=nonlinear) 228 | 229 | # MSG3D 230 | self.gcn3d2 = MultiWindow_MS_G3D(c1, c2, A_binary, num_g3d_scales, window_stride=2) 231 | self.sgcn2_msgcn = MS_GCN(num_gcn_scales, c1, c1, A_binary, disentangled_agg=True, 232 | **kwargs, temporal_len=frame_len, fea_dim=c1, activation=nonlinear) 233 | self.sgcn2_ms_tcn_1 = MS_TCN(c1, c2, stride=2, activation=nonlinear) 234 | # self.sgcn2_ms_tcn_1 = MS_TCN(c1, c2, activation=nonlinear) 235 | self.sgcn2_ms_tcn_2 = MS_TCN(c2, c2, activation=nonlinear) 236 | self.sgcn2_ms_tcn_2.act = nn.Identity() 237 | 238 | if 'to_use_temporal_transformer' in kwargs and kwargs['to_use_temporal_transformer']: 239 | self.tcn2 = MS_TCN(c2, c2, **kwargs, 240 | section_size=kwargs['section_sizes'][1], num_point=num_point, 241 | fea_dim=c2, activation=nonlinear) 242 | else: 243 | self.tcn2 = MS_TCN(c2, c2, **kwargs, fea_dim=c2, activation=nonlinear) 244 | 245 | # MSG3D 246 | self.gcn3d3 = MultiWindow_MS_G3D(c2, c3, A_binary, num_g3d_scales, window_stride=2) 247 | 
self.sgcn3_msgcn = MS_GCN(num_gcn_scales, c2, c2, A_binary, disentangled_agg=True, 248 | **kwargs, temporal_len=frame_len // 2, fea_dim=c2, 249 | activation=nonlinear) 250 | self.sgcn3_ms_tcn_1 = MS_TCN(c2, c3, stride=2, activation=nonlinear) 251 | # self.sgcn3_ms_tcn_1 = MS_TCN(c2, c3, activation=nonlinear) 252 | self.sgcn3_ms_tcn_2 = MS_TCN(c3, c3, activation=nonlinear) 253 | self.sgcn3_ms_tcn_2.act = nn.Identity() 254 | 255 | if 'to_use_temporal_transformer' in kwargs and kwargs['to_use_temporal_transformer']: 256 | self.tcn3 = MS_TCN(c3, c3, **kwargs, 257 | section_size=kwargs['section_sizes'][2], num_point=num_point, 258 | fea_dim=c3, activation=nonlinear) 259 | else: 260 | self.tcn3 = MS_TCN(c3, c3, **kwargs, fea_dim=c3, activation=nonlinear) 261 | 262 | self.use_temporal_transformer = False 263 | 264 | self.to_use_final_fc = to_use_final_fc 265 | if self.to_use_final_fc: 266 | self.fc = nn.Linear(c3, num_class) 267 | 268 | def forward(self, x, set_to_fc_last=True): 269 | # Select channels 270 | x = x[:, :3, :, :] 271 | x = self.preprocessing(x) 272 | # assert 0 273 | N, C, T, V, M = x.size() 274 | 275 | x = x.permute(0, 4, 3, 1, 2).contiguous().view(N, M * V * C, T) 276 | x = self.data_bn(x) 277 | x = x.view(N * M, V, C, T).permute(0, 2, 3, 1).contiguous() 278 | 279 | ###### First Component ###### 280 | x = self.sgcn1_msgcn(x) 281 | x = self.sgcn1_ms_tcn_1(x) 282 | x = self.sgcn1_ms_tcn_2(x) 283 | x = self.nonlinear_f(x) 284 | x = self.tcn1(x) 285 | ###### End First Component ###### 286 | 287 | ###### Second Component ###### 288 | x = self.sgcn2_msgcn(x) 289 | x = self.sgcn2_ms_tcn_1(x) 290 | x = self.sgcn2_ms_tcn_2(x) 291 | x = self.nonlinear_f(x) 292 | x = self.tcn2(x) 293 | ###### End Second Component ###### 294 | 295 | ###### Third Component ###### 296 | x = self.sgcn3_msgcn(x) 297 | x = self.sgcn3_ms_tcn_1(x) 298 | x = self.sgcn3_ms_tcn_2(x) 299 | x = self.nonlinear_f(x) 300 | x = self.tcn3(x) 301 | ###### End Third Component ###### 302 | 303 | out = x 304 | 305 | out_channels = out.size(1) 306 | 307 | t_dim = out.shape[2] 308 | out = out.view(N, M, out_channels, t_dim, -1) 309 | out = out.permute(0, 1, 3, 4, 2) # N, M, T, V, C 310 | out = out.mean(3) # Global Average Pooling (Spatial) 311 | 312 | out = out.mean(2) # Global Average Pooling (Temporal) 313 | out = out.mean(1) # Average pool number of bodies in the sequence 314 | 315 | if set_to_fc_last: 316 | if self.to_use_final_fc: 317 | out = self.fc(out) 318 | 319 | other_outs = {} 320 | return out, other_outs 321 | 322 | def preprocessing(self, x): 323 | # Extract Bone and Angular Features 324 | fp_sp_joint_list_bone = [] 325 | fp_sp_joint_list_bone_angle = [] 326 | fp_sp_joint_list_body_center_angle_1 = [] 327 | fp_sp_joint_list_body_center_angle_2 = [] 328 | fp_sp_left_hand_angle = [] 329 | fp_sp_right_hand_angle = [] 330 | fp_sp_two_hand_angle = [] 331 | fp_sp_two_elbow_angle = [] 332 | fp_sp_two_knee_angle = [] 333 | fp_sp_two_feet_angle = [] 334 | 335 | all_list = [ 336 | fp_sp_joint_list_bone, fp_sp_joint_list_bone_angle, fp_sp_joint_list_body_center_angle_1, 337 | fp_sp_joint_list_body_center_angle_2, fp_sp_left_hand_angle, fp_sp_right_hand_angle, 338 | fp_sp_two_hand_angle, fp_sp_two_elbow_angle, fp_sp_two_knee_angle, 339 | fp_sp_two_feet_angle 340 | ] 341 | 342 | for a_key in ntu_bone_angle_pairs: 343 | a_angle_value = ntu_bone_angle_pairs[a_key] 344 | a_bone_value = ntu_bone_adj[a_key] 345 | the_joint = a_key - 1 346 | a_adj = a_bone_value - 1 347 | bone_diff = (x[:, :3, :, the_joint, :] - 348 | x[:, :3, :, a_adj, 
:]).unsqueeze(3).cpu() 349 | fp_sp_joint_list_bone.append(bone_diff) 350 | 351 | # bone angles 352 | v1 = a_angle_value[0] - 1 353 | v2 = a_angle_value[1] - 1 354 | vec1 = x[:, :3, :, v1, :] - x[:, :3, :, the_joint, :] 355 | vec2 = x[:, :3, :, v2, :] - x[:, :3, :, the_joint, :] 356 | angular_feature = (1.0 - self.cos(vec1, vec2)) 357 | angular_feature[angular_feature != angular_feature] = 0 358 | fp_sp_joint_list_bone_angle.append(angular_feature.unsqueeze(2).unsqueeze(1).cpu()) 359 | 360 | # body angles 1 361 | vec1 = x[:, :3, :, 2 - 1, :] - x[:, :3, :, the_joint, :] 362 | vec2 = x[:, :3, :, 21 - 1, :] - x[:, :3, :, the_joint, :] 363 | angular_feature = (1.0 - self.cos(vec1, vec2)) 364 | angular_feature[angular_feature != angular_feature] = 0 365 | fp_sp_joint_list_body_center_angle_1.append(angular_feature.unsqueeze(2).unsqueeze(1).cpu()) 366 | 367 | # body angles 2 368 | vec1 = x[:, :3, :, the_joint, :] - x[:, :3, :, 21 - 1, :] 369 | vec2 = x[:, :3, :, 2 - 1, :] - x[:, :3, :, 21 - 1, :] 370 | angular_feature = (1.0 - self.cos(vec1, vec2)) 371 | angular_feature[angular_feature != angular_feature] = 0 372 | fp_sp_joint_list_body_center_angle_2.append(angular_feature.unsqueeze(2).unsqueeze(1).cpu()) 373 | 374 | # left hand angle 375 | vec1 = x[:, :3, :, 24 - 1, :] - x[:, :3, :, the_joint, :] 376 | vec2 = x[:, :3, :, 25 - 1, :] - x[:, :3, :, the_joint, :] 377 | angular_feature = (1.0 - self.cos(vec1, vec2)) 378 | angular_feature[angular_feature != angular_feature] = 0 379 | fp_sp_left_hand_angle.append(angular_feature.unsqueeze(2).unsqueeze(1).cpu()) 380 | 381 | # right hand angle 382 | vec1 = x[:, :3, :, 22 - 1, :] - x[:, :3, :, the_joint, :] 383 | vec2 = x[:, :3, :, 23 - 1, :] - x[:, :3, :, the_joint, :] 384 | angular_feature = (1.0 - self.cos(vec1, vec2)) 385 | angular_feature[angular_feature != angular_feature] = 0 386 | fp_sp_right_hand_angle.append(angular_feature.unsqueeze(2).unsqueeze(1).cpu()) 387 | 388 | # two hand angle 389 | vec1 = x[:, :3, :, 24 - 1, :] - x[:, :3, :, the_joint, :] 390 | vec2 = x[:, :3, :, 22 - 1, :] - x[:, :3, :, the_joint, :] 391 | angular_feature = (1.0 - self.cos(vec1, vec2)) 392 | angular_feature[angular_feature != angular_feature] = 0 393 | fp_sp_two_hand_angle.append(angular_feature.unsqueeze(2).unsqueeze(1).cpu()) 394 | 395 | # two elbow angle 396 | vec1 = x[:, :3, :, 10 - 1, :] - x[:, :3, :, the_joint, :] 397 | vec2 = x[:, :3, :, 6 - 1, :] - x[:, :3, :, the_joint, :] 398 | angular_feature = (1.0 - self.cos(vec1, vec2)) 399 | angular_feature[angular_feature != angular_feature] = 0 400 | fp_sp_two_elbow_angle.append(angular_feature.unsqueeze(2).unsqueeze(1).cpu()) 401 | 402 | # two knee angle 403 | vec1 = x[:, :3, :, 18 - 1, :] - x[:, :3, :, the_joint, :] 404 | vec2 = x[:, :3, :, 14 - 1, :] - x[:, :3, :, the_joint, :] 405 | angular_feature = (1.0 - self.cos(vec1, vec2)) 406 | angular_feature[angular_feature != angular_feature] = 0 407 | fp_sp_two_knee_angle.append(angular_feature.unsqueeze(2).unsqueeze(1).cpu()) 408 | 409 | # two feet angle 410 | vec1 = x[:, :3, :, 20 - 1, :] - x[:, :3, :, the_joint, :] 411 | vec2 = x[:, :3, :, 16 - 1, :] - x[:, :3, :, the_joint, :] 412 | angular_feature = (1.0 - self.cos(vec1, vec2)) 413 | angular_feature[angular_feature != angular_feature] = 0 414 | fp_sp_two_feet_angle.append(angular_feature.unsqueeze(2).unsqueeze(1).cpu()) 415 | 416 | for a_list_id in range(len(all_list)): 417 | all_list[a_list_id] = torch.cat(all_list[a_list_id], dim=3) 418 | 419 | all_list = torch.cat(all_list, dim=1) 420 | # print('All_list:', 
all_list.shape) 421 | 422 | features = torch.cat((x, all_list.cuda()), dim=1) 423 | # print('features:', features.shape) 424 | return features 425 | 426 | 427 | if __name__ == "__main__": 428 | # For debugging purposes 429 | import sys 430 | 431 | sys.path.append('..') 432 | 433 | model = Model( 434 | num_class=60, 435 | num_point=25, 436 | num_person=2, 437 | num_gcn_scales=13, 438 | num_g3d_scales=6, 439 | graph='graph.ntu_rgb_d.AdjMatrixGraph' 440 | ) 441 | 442 | N, C, T, V, M = 6, 3, 50, 25, 2 443 | x = torch.randn(N, C, T, V, M) 444 | model.forward(x) 445 | 446 | print('Model total # params:', count_params(model)) 447 | -------------------------------------------------------------------------------- /notification/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/notification/__init__.py -------------------------------------------------------------------------------- /notification/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/notification/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /notification/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/notification/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /notification/__pycache__/email_config.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/notification/__pycache__/email_config.cpython-36.pyc -------------------------------------------------------------------------------- /notification/__pycache__/email_config.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/notification/__pycache__/email_config.cpython-37.pyc -------------------------------------------------------------------------------- /notification/__pycache__/email_sender.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/notification/__pycache__/email_sender.cpython-36.pyc -------------------------------------------------------------------------------- /notification/__pycache__/email_sender.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/notification/__pycache__/email_sender.cpython-37.pyc -------------------------------------------------------------------------------- /notification/__pycache__/email_templates.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/notification/__pycache__/email_templates.cpython-36.pyc -------------------------------------------------------------------------------- /notification/__pycache__/email_templates.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/notification/__pycache__/email_templates.cpython-37.pyc -------------------------------------------------------------------------------- /notification/__pycache__/html_templates.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/notification/__pycache__/html_templates.cpython-36.pyc -------------------------------------------------------------------------------- /notification/__pycache__/html_templates.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/notification/__pycache__/html_templates.cpython-37.pyc -------------------------------------------------------------------------------- /notification/email_config.py: -------------------------------------------------------------------------------- 1 | EMAIL_S_ADDRESS = 'YOUR SENDER EMAIL ADDRESS' 2 | PASSWORD = 'YOUR PASSWORD' 3 | EMAIL_R_ADDRESS = 'YOUR RECEIVER EMAIL ADDRESS' 4 | -------------------------------------------------------------------------------- /notification/email_sender.py: -------------------------------------------------------------------------------- 1 | import smtplib 2 | 3 | from notification.email_templates import * 4 | from .email_config import * 5 | import traceback 6 | from email.mime.text import MIMEText 7 | from email.mime.multipart import MIMEMultipart 8 | 9 | from .html_templates import * 10 | 11 | 12 | def send_email(receivers, email_type, msg_content): 13 | return  # NOTE: notifications are disabled by this early return; remove it to actually send emails 14 | template_1 = None 15 | sub_email = None 16 | if email_type == 'exp_end': 17 | template_1 = exp_complete_1 18 | template_2 = exp_complete_2 19 | sub_email = 'Watchtower at Warrumbul: Experiment Complete' 20 | elif email_type == 'test_end': 21 | template_1 = exp_progress_1 22 | template_2 = exp_progress_2 23 | sub_email = 'Watchtower at Warrumbul: Experiment Progress' 24 | elif email_type == 'error': 25 | template_1 = emergency_1 26 | template_2 = emergency_2 27 | sub_email = 'Watchtower at Warrumbul: Emergency' 28 | else: 29 | raise NotImplementedError 30 | 31 | for a_person in receivers: 32 | 33 | def smtp_send(msg): 34 | try: 35 | server = smtplib.SMTP('smtp-relay.sendinblue.com', port=587) 36 | server.ehlo() 37 | server.starttls() 38 | server.login(EMAIL_S_ADDRESS, PASSWORD) 39 | message = msg 40 | server.sendmail(EMAIL_S_ADDRESS, a_person, message) 41 | server.quit() 42 | except Exception: 43 | traceback.print_exc() 44 | 45 | message = MIMEMultipart("alternative") 46 | if sub_email is not None: 47 | message["Subject"] = sub_email 48 | else: 49 | message["Subject"] = "Experiment Progress: Urgent Dispatch" 50 | message["From"] = EMAIL_S_ADDRESS 51 | message["To"] = a_person 52 | 53 | html = template_1 + msg_content + template_2 54 | 55 | part2 = MIMEText(html, 'html', 'utf-8') 56 | 57 | # message.attach(part1) 58 | message.attach(part2) 59 | 60 | 
smtp_send(message.as_string()) 61 | 62 | print(f'Success: Email sent to {a_person}') 63 | 64 | 65 | if __name__ == '__main__': 66 | test_msg = 'This is a test. ' 67 | send_email(receivers=['zhenyue.qin@anu.edu.au'], 68 | email_type='error', 69 | msg_content=test_msg) -------------------------------------------------------------------------------- /notification/email_test.py: -------------------------------------------------------------------------------- 1 | from notification.email_sender import send_email 2 | 3 | 4 | test_msg = 'This is a test. ' 5 | send_email(receivers=['zhenyue.qin@anu.edu.au'], 6 | email_type='exp_end', 7 | msg_content=test_msg) 8 | -------------------------------------------------------------------------------- /processor/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/processor/__init__.py -------------------------------------------------------------------------------- /processor/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/processor/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /processor/__pycache__/args.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/processor/__pycache__/args.cpython-36.pyc -------------------------------------------------------------------------------- /processor/args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | import os 4 | 5 | 6 | def str2bool(v): 7 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 8 | return True 9 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 10 | return False 11 | else: 12 | raise argparse.ArgumentTypeError('Boolean value expected.') 13 | 14 | 15 | def get_parser(): 16 | # parameter priority: command line > config file > default 17 | parser = argparse.ArgumentParser(description='MS-G3D') 18 | 19 | parser.add_argument( 20 | '--work-dir', 21 | type=str, 22 | required=False, 23 | default='.work_dir/unknown_path', 24 | help='the work folder for storing results') 25 | parser.add_argument('--model_saved_name', default='') 26 | parser.add_argument( 27 | '--config', 28 | default='./config/nturgbd-cross1-view/test_bone.yaml', 29 | help='path to the configuration file') 30 | parser.add_argument( 31 | '--assume-yes', 32 | action='store_true', 33 | help='Say yes to every prompt') 34 | 35 | parser.add_argument( 36 | '--phase', 37 | default='train', 38 | help='must be train or test') 39 | parser.add_argument( 40 | '--save-score', 41 | type=str2bool, 42 | default=False, 43 | help='if true, the classification score will be stored') 44 | 45 | parser.add_argument( 46 | '--seed', 47 | type=int, 48 | default=0, 49 | help='random seed') 50 | parser.add_argument( 51 | '--log-interval', 52 | type=int, 53 | default=1, 54 | help='the interval for printing messages (#iteration)') 55 | parser.add_argument( 56 | '--save-interval', 57 | type=int, 58 | default=1, 59 | help='the interval for storing models (#iteration)') 60 | parser.add_argument( 61 | '--eval-interval', 62 | type=int, 63 | default=5, 64 | help='the interval 
for evaluating models (#iteration)') 65 | parser.add_argument( 66 | '--eval-start', 67 | type=int, 68 | default=1, 69 | help='The epoch number to start evaluating models') 70 | parser.add_argument( 71 | '--print-log', 72 | type=str2bool, 73 | default=True, 74 | help='print logging or not') 75 | parser.add_argument( 76 | '--show-topk', 77 | type=int, 78 | default=[1, 5], 79 | nargs='+', 80 | help='which Top K accuracy will be shown') 81 | 82 | parser.add_argument( 83 | '--feeder', 84 | default='feeder.feeder', 85 | help='data loader will be used') 86 | parser.add_argument( 87 | '--num-worker', 88 | type=int, 89 | default=os.cpu_count(), 90 | help='the number of workers for the data loader') 91 | parser.add_argument( 92 | '--train-feeder-args', 93 | 94 | default=dict(), 95 | help='the arguments of data loader for training') 96 | parser.add_argument( 97 | '--test-feeder-args', 98 | default=dict(), 99 | help='the arguments of data loader for test') 100 | 101 | parser.add_argument( 102 | '--model', 103 | default=None, 104 | help='the model will be used') 105 | parser.add_argument( 106 | '--model-args', 107 | type=dict, 108 | default=dict(), 109 | help='the arguments of model') 110 | parser.add_argument( 111 | '--weights', 112 | default=None, 113 | help='the weights for network initialization') 114 | parser.add_argument( 115 | '--ignore-weights', 116 | type=str, 117 | default=[], 118 | nargs='+', 119 | help='the name of weights which will be ignored in the initialization') 120 | parser.add_argument( 121 | '--half', 122 | action='store_true', 123 | help='Use half-precision (FP16) training') 124 | parser.add_argument( 125 | '--amp-opt-level', 126 | type=int, 127 | default=1, 128 | help='NVIDIA Apex AMP optimization level') 129 | 130 | parser.add_argument( 131 | '--base-lr', 132 | type=float, 133 | default=0.01, 134 | help='initial learning rate') 135 | parser.add_argument( 136 | '--step', 137 | type=int, 138 | default=[20, 40, 60], 139 | nargs='+', 140 | help='the epoch where the optimizer reduces the learning rate') 141 | parser.add_argument( 142 | '--lr_decay', 143 | type=float, 144 | default=0.1, 145 | help='learning rate decay degree' 146 | ) 147 | parser.add_argument( 148 | '--device', 149 | type=int, 150 | default=0, 151 | nargs='+', 152 | help='the indexes of GPUs for training or testing') 153 | parser.add_argument( 154 | '--optimizer', 155 | default='SGD', 156 | help='type of optimizer') 157 | parser.add_argument( 158 | '--nesterov', 159 | type=str2bool, 160 | default=False, 161 | help='use nesterov or not') 162 | parser.add_argument( 163 | '--batch-size', 164 | type=int, 165 | default=32, 166 | help='training batch size') 167 | parser.add_argument( 168 | '--test-batch-size', 169 | type=int, 170 | default=256, 171 | help='test batch size') 172 | parser.add_argument( 173 | '--forward-batch-size', 174 | type=int, 175 | default=16, 176 | help='Batch size during forward pass, must be a factor of --batch-size') 177 | parser.add_argument( 178 | '--start-epoch', 179 | type=int, 180 | default=0, 181 | help='start training from which epoch') 182 | parser.add_argument( 183 | '--num-epoch', 184 | type=int, 185 | default=80, 186 | help='stop training in which epoch') 187 | parser.add_argument( 188 | '--weight-decay', 189 | type=float, 190 | default=0.0005, 191 | help='weight decay for optimizer') 192 | parser.add_argument( 193 | '--optimizer-states', 194 | type=str, 195 | help='path of previously saved optimizer states') 196 | parser.add_argument( 197 | '--checkpoint', 198 | type=str, 199 | 
help='path of previously saved training checkpoint') 200 | parser.add_argument( 201 | '--debug', 202 | type=str2bool, 203 | default=False, 204 | help='Debug mode; default false') 205 | 206 | # ZQ resume training 207 | parser.add_argument( 208 | '--resume', 209 | type=str2bool, 210 | default=False, 211 | help='resume previous training' 212 | ) 213 | 214 | parser.add_argument( 215 | '--tbatch', 216 | type=int, 217 | default=128, 218 | help='batch size for transformers' 219 | ) 220 | 221 | parser.add_argument( 222 | '--train_print_freq', 223 | type=int, 224 | default=100, 225 | help='training printing frequency' 226 | ) 227 | 228 | # one hot 229 | parser.add_argument( 230 | '--to_add_onehot', 231 | type=str2bool, 232 | default=False, 233 | help='to add one hot in the input data' 234 | ) 235 | 236 | # feature selection 237 | parser.add_argument( 238 | '--feature_combo', 239 | type=str, 240 | default='', 241 | help='what features to use' 242 | ) 243 | 244 | parser.add_argument( 245 | '--additional_loss', 246 | type=dict, 247 | default=dict(), 248 | ) 249 | 250 | parser.add_argument( # encode data 251 | '--encoding_args', 252 | default=dict(), 253 | ) 254 | 255 | return parser 256 | -------------------------------------------------------------------------------- /processor/io.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import argparse 4 | import yaml 5 | import numpy as np 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | import torchlight 11 | from torchlight.torchlight.gpu import visible_gpu, occupy_gpu 12 | from torchlight.torchlight.io import str2bool 13 | from torchlight.torchlight.io import DictAction 14 | from torchlight.torchlight.io import import_class 15 | 16 | 17 | class IO(): 18 | 19 | def __init__(self, argv=None): 20 | 21 | self.load_arg(argv) 22 | self.init_environment() 23 | self.load_model() 24 | self.load_weights() 25 | self.gpu() 26 | 27 | def load_arg(self, argv=None): 28 | parser = self.get_parser() 29 | 30 | # load arg from config file 31 | p = parser.parse_args(argv) 32 | if p.config is not None: 33 | # load config file 34 | with open(p.config, 'r') as f: 35 | default_arg = yaml.load(f, Loader=yaml.FullLoader) 36 | 37 | # update parser from config file 38 | key = vars(p).keys() 39 | for k in default_arg.keys(): 40 | if k not in key: 41 | print('Unknown Arguments: {}'.format(k)) 42 | assert k in key 43 | 44 | parser.set_defaults(**default_arg) 45 | 46 | self.arg = parser.parse_args(argv) 47 | 48 | def init_environment(self): 49 | self.save_dir = os.path.join(self.arg.work_dir, 50 | self.arg.max_hop_dir, 51 | self.arg.lamda_act_dir) 52 | self.io = torchlight.torchlight.io.IO(self.save_dir, save_log=self.arg.save_log, print_log=self.arg.print_log) 53 | self.io.save_arg(self.arg) 54 | 55 | # gpu 56 | if self.arg.use_gpu: 57 | gpus = visible_gpu(self.arg.device) 58 | occupy_gpu(gpus) 59 | self.gpus = gpus 60 | self.dev = "cuda:0" 61 | else: 62 | self.dev = "cpu" 63 | 64 | def load_model(self): 65 | self.model1 = self.io.load_model(self.arg.model1, **(self.arg.model1_args)) 66 | self.model2 = self.io.load_model(self.arg.model2, **(self.arg.model2_args)) 67 | 68 | def load_weights(self): 69 | if self.arg.weights1: 70 | self.model1 = self.io.load_weights(self.model1, self.arg.weights1, self.arg.ignore_weights) 71 | self.model2 = self.io.load_weights(self.model2, self.arg.weights2, self.arg.ignore_weights) 72 | 73 | def gpu(self): 74 | # move modules to gpu 75 | self.model1 = self.model1.to(self.dev) 76 | self.model2 = 
self.model2.to(self.dev) 77 | for name, value in vars(self).items(): 78 | cls_name = str(value.__class__) 79 | if cls_name.find('torch.nn.modules') != -1: 80 | setattr(self, name, value.to(self.dev)) 81 | 82 | # model parallel 83 | if self.arg.use_gpu and len(self.gpus) > 1: 84 | self.model1 = nn.DataParallel(self.model1, device_ids=self.gpus) 85 | self.model2 = nn.DataParallel(self.model2, device_ids=self.gpus) 86 | 87 | def start(self): 88 | self.io.print_log('Parameters:\n{}\n'.format(str(vars(self.arg)))) 89 | 90 | @staticmethod 91 | def get_parser(add_help=False): 92 | 93 | #region arguments yapf: disable 94 | # parameter priority: command line > config > default 95 | parser = argparse.ArgumentParser( add_help=add_help, description='IO Processor') 96 | 97 | parser.add_argument('-w', '--work_dir', default='./work_dir/tmp', help='the work folder for storing results') 98 | parser.add_argument('-c', '--config', default=None, help='path to the configuration file') 99 | 100 | # processor 101 | parser.add_argument('--use_gpu', type=str2bool, default=True, help='use GPUs or not') 102 | parser.add_argument('--device', type=int, default=0, nargs='+', help='the indexes of GPUs for training or testing') 103 | 104 | # visulize and debug 105 | parser.add_argument('--print_log', type=str2bool, default=True, help='print logging or not') 106 | parser.add_argument('--save_log', type=str2bool, default=True, help='save logging or not') 107 | 108 | # model 109 | parser.add_argument('--model1', default=None, help='the model will be used') 110 | parser.add_argument('--model2', default=None, help='the model will be used') 111 | parser.add_argument('--model1_args', action=DictAction, default=dict(), help='the arguments of model') 112 | parser.add_argument('--model2_args', action=DictAction, default=dict(), help='the arguments of model') 113 | parser.add_argument('--weights', default=None, help='the weights for network initialization') 114 | parser.add_argument('--ignore_weights', type=str, default=[], nargs='+', help='the name of weights which will be ignored in the initialization') 115 | #endregion yapf: enable 116 | 117 | return parser 118 | -------------------------------------------------------------------------------- /processor/processor.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import argparse 3 | import yaml 4 | import numpy as np 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.optim as optim 9 | 10 | import torchlight 11 | from torchlight.torchlight.io import str2bool 12 | from torchlight.torchlight.io import DictAction 13 | from torchlight.torchlight.io import import_class 14 | 15 | from .io import IO 16 | 17 | 18 | class Processor(IO): 19 | 20 | def __init__(self, argv=None): 21 | 22 | self.load_arg(argv) 23 | self.init_environment() 24 | self.load_model() 25 | self.load_weights() 26 | self.gpu() 27 | self.load_data() 28 | 29 | def init_environment(self): 30 | 31 | super().init_environment() 32 | self.result = dict() 33 | self.iter_info = dict() 34 | self.epoch_info = dict() 35 | self.meta_info = dict(epoch=0, iter=0) 36 | 37 | 38 | def load_data(self): 39 | Feeder = import_class(self.arg.feeder) 40 | if 'debug' not in self.arg.train_feeder_args: 41 | self.arg.train_feeder_args['debug'] = self.arg.debug 42 | self.data_loader = dict() 43 | if self.arg.phase == 'train': 44 | self.data_loader['train'] = torch.utils.data.DataLoader(dataset=Feeder(**self.arg.train_feeder_args), 45 | batch_size=self.arg.batch_size, 46 | 
shuffle=True, 47 | num_workers=self.arg.num_worker, 48 | drop_last=True) 49 | if self.arg.test_feeder_args: 50 | self.data_loader['test'] = torch.utils.data.DataLoader(dataset=Feeder(**self.arg.test_feeder_args), 51 | batch_size=self.arg.test_batch_size, 52 | shuffle=False, 53 | num_workers=self.arg.num_worker) 54 | 55 | def show_epoch_info(self): 56 | for k, v in self.epoch_info.items(): 57 | self.io.print_log('\t{}: {}'.format(k, v)) 58 | if self.arg.pavi_log: 59 | self.io.log('train', self.meta_info['iter'], self.epoch_info) 60 | 61 | def show_iter_info(self): 62 | if self.meta_info['iter'] % self.arg.log_interval == 0: 63 | info ='\tIter {} Done.'.format(self.meta_info['iter']) 64 | for k, v in self.iter_info.items(): 65 | if isinstance(v, float): 66 | info = info + ' | {}: {:.4f}'.format(k, v) 67 | else: 68 | info = info + ' | {}: {}'.format(k, v) 69 | 70 | self.io.print_log(info) 71 | 72 | if self.arg.pavi_log: 73 | self.io.log('train', self.meta_info['iter'], self.iter_info) 74 | 75 | def train(self): 76 | for _ in range(100): 77 | self.iter_info['loss'] = 0 78 | self.iter_info['loss_class'] = 0 79 | self.iter_info['loss_recon'] = 0 80 | self.show_iter_info() 81 | self.meta_info['iter'] += 1 82 | self.epoch_info['mean_loss'] = 0 83 | self.epoch_info['mean_loss_class'] = 0 84 | self.epoch_info['mean_loss_recon'] = 0 85 | self.show_epoch_info() 86 | 87 | def test(self): 88 | for _ in range(100): 89 | self.iter_info['loss'] = 1 90 | self.iter_info['loss_class'] = 1 91 | self.iter_info['loss_recon'] = 1 92 | self.show_iter_info() 93 | self.epoch_info['mean_loss'] = 1 94 | self.epoch_info['mean_loss_class'] = 1 95 | self.epoch_info['mean_loss_recon'] = 1 96 | self.show_epoch_info() 97 | 98 | def start(self): 99 | self.arg.eval_interval = 5 100 | 101 | self.io.print_log('Parameters:\n{}\n'.format(str(vars(self.arg)))) 102 | 103 | if self.arg.phase == 'train': 104 | for epoch in range(self.arg.start_epoch, self.arg.num_epoch): 105 | self.meta_info['epoch'] = epoch 106 | 107 | if epoch < 10: 108 | self.io.print_log('Training epoch: {}'.format(epoch)) 109 | self.train(training_A=True) 110 | self.io.print_log('Done.') 111 | else: 112 | self.io.print_log('Training epoch: {}'.format(epoch)) 113 | self.train(training_A=False) 114 | self.io.print_log('Done.') 115 | 116 | # save model 117 | if ((epoch + 1) % self.arg.save_interval == 0) or (epoch + 1 == self.arg.num_epoch): 118 | filename1 = 'epoch{}_model1.pt'.format(epoch) 119 | self.io.save_model(self.model1, filename1) 120 | filename2 = 'epoch{}_model2.pt'.format(epoch) 121 | self.io.save_model(self.model2, filename2) 122 | 123 | # evaluation 124 | if ((epoch + 1) % self.arg.eval_interval == 0) or (epoch + 1 == self.arg.num_epoch): 125 | self.io.print_log('Eval epoch: {}'.format(epoch)) 126 | if epoch <= 10: 127 | self.test(testing_A=True) 128 | else: 129 | self.test(testing_A=False) 130 | self.io.print_log('Done.') 131 | 132 | 133 | elif self.arg.phase == 'test': 134 | if self.arg.weights2 is None: 135 | raise ValueError('Please appoint --weights.') 136 | self.io.print_log('Model: {}.'.format(self.arg.model2)) 137 | self.io.print_log('Weights: {}.'.format(self.arg.weights2)) 138 | 139 | self.io.print_log('Evaluation Start:') 140 | self.test(testing_A=False, save_feature=True) 141 | self.io.print_log('Done.\n') 142 | 143 | if self.arg.save_result: 144 | result_dict = dict( 145 | zip(self.data_loader['test'].dataset.sample_name, 146 | self.result)) 147 | self.io.save_pkl(result_dict, 'test_result.pkl') 148 | 149 | 150 | @staticmethod 151 | def 
get_parser(add_help=False): 152 | 153 | parser = argparse.ArgumentParser( add_help=add_help, description='Base Processor') 154 | 155 | parser.add_argument('-w', '--work_dir', default='./work_dir/tmp', help='the work folder for storing results') 156 | parser.add_argument('-c', '--config', default=None, help='path to the configuration file') 157 | 158 | parser.add_argument('--phase', default='train', help='must be train or test') 159 | parser.add_argument('--save_result', type=str2bool, default=False, help='if ture, the output of the model will be stored') 160 | parser.add_argument('--start_epoch', type=int, default=0, help='start training from which epoch') 161 | parser.add_argument('--num_epoch', type=int, default=80, help='stop training in which epoch') 162 | parser.add_argument('--use_gpu', type=str2bool, default=True, help='use GPUs or not') 163 | parser.add_argument('--device', type=int, default=0, nargs='+', help='the indexes of GPUs for training or testing') 164 | 165 | parser.add_argument('--log_interval', type=int, default=100, help='the interval for printing messages (#iteration)') 166 | parser.add_argument('--save_interval', type=int, default=1, help='the interval for storing models (#iteration)') 167 | parser.add_argument('--eval_interval', type=int, default=5, help='the interval for evaluating models (#iteration)') 168 | parser.add_argument('--save_log', type=str2bool, default=True, help='save logging or not') 169 | parser.add_argument('--print_log', type=str2bool, default=True, help='print logging or not') 170 | parser.add_argument('--pavi_log', type=str2bool, default=False, help='logging on pavi or not') 171 | 172 | parser.add_argument('--feeder', default='feeder.feeder', help='data loader will be used') 173 | parser.add_argument('--num_worker', type=int, default=4, help='the number of worker per gpu for data loader') 174 | parser.add_argument('--train_feeder_args', action=DictAction, default=dict(), help='the arguments of data loader for training') 175 | parser.add_argument('--test_feeder_args', action=DictAction, default=dict(), help='the arguments of data loader for test') 176 | parser.add_argument('--batch_size', type=int, default=256, help='training batch size') 177 | parser.add_argument('--test_batch_size', type=int, default=256, help='test batch size') 178 | parser.add_argument('--debug', action="store_true", help='less data, faster loading') 179 | 180 | parser.add_argument('--model1', default=None, help='the model will be used') 181 | parser.add_argument('--model2', default=None, help='the model will be used') 182 | parser.add_argument('--model1_args', action=DictAction, default=dict(), help='the arguments of model') 183 | parser.add_argument('--model2_args', action=DictAction, default=dict(), help='the arguments of model') 184 | parser.add_argument('--weights1', default=None, help='the weights for network initialization') 185 | parser.add_argument('--weights2', default=None, help='the weights for network initialization') 186 | parser.add_argument('--ignore_weights', type=str, default=[], nargs='+', help='the name of weights which will be ignored in the initialization') 187 | 188 | return parser 189 | -------------------------------------------------------------------------------- /processor/recognition.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import numpy as np 4 | 5 | import matplotlib 6 | matplotlib.use('Agg') 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.optim as 
optim 11 | 12 | from processor.torchlight_io import str2bool 13 | 14 | from .processor import Processor 15 | 16 | 17 | def weights_init(m): 18 | classname = m.__class__.__name__ 19 | if classname.find('Conv1d') != -1: 20 | m.weight.data.normal_(0.0, 0.02) 21 | if m.bias is not None: 22 | m.bias.data.fill_(0) 23 | elif classname.find('Conv2d') != -1: 24 | m.weight.data.normal_(0.0, 0.02) 25 | if m.bias is not None: 26 | m.bias.data.fill_(0) 27 | elif classname.find('BatchNorm') != -1: 28 | m.weight.data.normal_(1.0, 0.02) 29 | m.bias.data.fill_(0) 30 | 31 | 32 | class REC_Processor(Processor): 33 | 34 | def load_model(self): 35 | self.model1 = self.io.load_model(self.arg.model1, **(self.arg.model1_args)) 36 | self.model1.apply(weights_init) 37 | self.model2 = self.io.load_model(self.arg.model2, **(self.arg.model2_args)) 38 | 39 | self.loss_class = nn.CrossEntropyLoss().cuda() 40 | self.loss_pred = nn.MSELoss() 41 | self.w_pred = 0.01 42 | 43 | prior = np.array([0.95, 0.05/2, 0.05/2]) 44 | self.log_prior = torch.FloatTensor(np.log(prior)) 45 | self.log_prior = torch.unsqueeze(torch.unsqueeze(self.log_prior, 0), 0) 46 | 47 | self.load_optimizer() 48 | 49 | def load_optimizer(self): 50 | if self.arg.optimizer == 'SGD': 51 | self.optimizer1 = optim.SGD(params=self.model1.parameters(), 52 | lr=self.arg.base_lr1, 53 | momentum=0.9, 54 | nesterov=self.arg.nesterov, 55 | weight_decay=self.arg.weight_decay) 56 | elif self.arg.optimizer == 'Adam': 57 | self.optimizer1 = optim.Adam(params=self.model1.parameters(), 58 | lr=self.arg.base_lr1, 59 | weight_decay=self.arg.weight_decay) 60 | else: 61 | raise ValueError() 62 | self.optimizer2 = optim.Adam(params=self.model2.parameters(), 63 | lr=self.arg.base_lr2) 64 | 65 | def adjust_lr(self): 66 | if self.arg.optimizer == 'SGD' and self.arg.step: 67 | lr = self.arg.base_lr1 * (0.1**np.sum(self.meta_info['epoch']>= np.array(self.arg.step))) 68 | for param_group in self.optimizer1.param_groups: 69 | param_group['lr'] = lr 70 | self.lr = lr 71 | else: 72 | self.lr = self.arg.base_lr1 73 | self.lr2 = self.arg.base_lr2 74 | 75 | def nll_gaussian(self, preds, target, variance, add_const=False): 76 | neg_log_p = ((preds-target)**2/(2*variance)) 77 | if add_const: 78 | const = 0.5*np.log(2*np.pi*variance) 79 | neg_log_p += const 80 | return neg_log_p.sum() / (target.size(0) * target.size(1)) 81 | 82 | def kl_categorical(self, preds, log_prior, num_node, eps=1e-16): 83 | kl_div = preds*(torch.log(preds+eps)-log_prior) 84 | return kl_div.sum()/(num_node*preds.size(0)) 85 | 86 | 87 | def train(self, training_A=False): 88 | self.model1.train() 89 | self.model2.train() 90 | self.adjust_lr() 91 | loader = self.data_loader['train'] 92 | loss1_value = [] 93 | loss_class_value = [] 94 | loss_recon_value = [] 95 | loss2_value = [] 96 | loss_nll_value = [] 97 | loss_kl_value = [] 98 | 99 | if training_A: 100 | for param1 in self.model1.parameters(): 101 | param1.requires_grad = False 102 | for param2 in self.model2.parameters(): 103 | param2.requires_grad = True 104 | self.iter_info.clear() 105 | self.epoch_info.clear() 106 | 107 | for data, data_downsample, target_data, data_last, label in loader: 108 | data = data.float().to(self.dev) 109 | data_downsample = data_downsample.float().to(self.dev) 110 | label = label.long().to(self.dev) 111 | 112 | gpu_id = data.get_device() 113 | self.log_prior = self.log_prior.cuda(gpu_id) 114 | A_batch, prob, outputs, data_target = self.model2(data_downsample) 115 | loss_nll = self.nll_gaussian(outputs, data_target[:,:,1:,:], 
variance=5e-4) 116 | loss_kl = self.kl_categorical(prob, self.log_prior, num_node=25) 117 | loss2 = loss_nll + loss_kl 118 | 119 | self.optimizer2.zero_grad() 120 | loss2.backward() 121 | self.optimizer2.step() 122 | 123 | self.iter_info['loss2'] = loss2.data.item() 124 | self.iter_info['loss_nll'] = loss_nll.data.item() 125 | self.iter_info['loss_kl'] = loss_kl.data.item() 126 | self.iter_info['lr'] = '{:.6f}'.format(self.lr2) 127 | 128 | loss2_value.append(self.iter_info['loss2']) 129 | loss_nll_value.append(self.iter_info['loss_nll']) 130 | loss_kl_value.append(self.iter_info['loss_kl']) 131 | self.show_iter_info() 132 | self.meta_info['iter'] += 1 133 | self.epoch_info['mean_loss2'] = np.mean(loss2_value) 134 | self.epoch_info['mean_loss_nll'] = np.mean(loss_nll_value) 135 | self.epoch_info['mean_loss_kl'] = np.mean(loss_kl_value) 136 | self.show_epoch_info() 137 | self.io.print_timer() 138 | 139 | else: 140 | for param1 in self.model1.parameters(): 141 | param1.requires_grad = True 142 | for param2 in self.model2.parameters(): 143 | param2.requires_grad = True 144 | self.iter_info.clear() 145 | self.epoch_info.clear() 146 | for data, data_downsample, target_data, data_last, label in loader: 147 | data = data.float().to(self.dev) 148 | data_downsample = data_downsample.float().to(self.dev) 149 | target_data = target_data.float().to(self.dev) 150 | data_last = data_last.float().to(self.dev) 151 | label = label.long().to(self.dev) 152 | 153 | A_batch, prob, outputs, _ = self.model2(data_downsample) 154 | x_class, pred, target = self.model1(data, target_data, data_last, A_batch, self.arg.lamda_act) 155 | loss_class = self.loss_class(x_class, label) 156 | loss_recon = self.loss_pred(pred, target) 157 | loss1 = loss_class + self.w_pred*loss_recon 158 | 159 | self.optimizer1.zero_grad() 160 | loss1.backward() 161 | self.optimizer1.step() 162 | 163 | self.iter_info['loss1'] = loss1.data.item() 164 | self.iter_info['loss_class'] = loss_class.data.item() 165 | self.iter_info['loss_recon'] = loss_recon.data.item()*self.w_pred 166 | self.iter_info['lr'] = '{:.6f}'.format(self.lr) 167 | 168 | loss1_value.append(self.iter_info['loss1']) 169 | loss_class_value.append(self.iter_info['loss_class']) 170 | loss_recon_value.append(self.iter_info['loss_recon']) 171 | self.show_iter_info() 172 | self.meta_info['iter'] += 1 173 | 174 | self.epoch_info['mean_loss1']= np.mean(loss1_value) 175 | self.epoch_info['mean_loss_class'] = np.mean(loss_class_value) 176 | self.epoch_info['mean_loss_recon'] = np.mean(loss_recon_value) 177 | self.show_epoch_info() 178 | self.io.print_timer() 179 | 180 | 181 | def test(self, evaluation=True, testing_A=False, save=False, save_feature=False): 182 | 183 | self.model1.eval() 184 | self.model2.eval() 185 | loader = self.data_loader['test'] 186 | loss1_value = [] 187 | loss_class_value = [] 188 | loss_recon_value = [] 189 | loss2_value = [] 190 | loss_nll_value = [] 191 | loss_kl_value = [] 192 | result_frag = [] 193 | label_frag = [] 194 | 195 | if testing_A: 196 | A_all = [] 197 | self.epoch_info.clear() 198 | for data, data_downsample, target_data, data_last, label in loader: 199 | data = data.float().to(self.dev) 200 | data_downsample = data_downsample.float().to(self.dev) 201 | label = label.long().to(self.dev) 202 | 203 | with torch.no_grad(): 204 | A_batch, prob, outputs, data_bn = self.model2(data_downsample) 205 | 206 | if save: 207 | n = A_batch.size(0) 208 | a = A_batch[:int(n/2),:,:,:].cpu().numpy() 209 | A_all.extend(a) 210 | 211 | if evaluation: 212 | gpu_id = 
data.get_device() 213 | self.log_prior = self.log_prior.cuda(gpu_id) 214 | loss_nll = self.nll_gaussian(outputs, data_bn[:,:,1:,:], variance=5e-4) 215 | loss_kl = self.kl_categorical(prob, self.log_prior, num_node=25) 216 | loss2 = loss_nll + loss_kl 217 | 218 | loss2_value.append(loss2.item()) 219 | loss_nll_value.append(loss_nll.item()) 220 | loss_kl_value.append(loss_kl.item()) 221 | 222 | if save: 223 | A_all = np.array(A_all) 224 | np.save(os.path.join(self.arg.work_dir, 'test_adj.npy'), A_all) 225 | 226 | if evaluation: 227 | self.epoch_info['mean_loss2'] = np.mean(loss2_value) 228 | self.epoch_info['mean_loss_nll'] = np.mean(loss_nll_value) 229 | self.epoch_info['mean_loss_kl'] = np.mean(loss_kl_value) 230 | self.show_epoch_info() 231 | 232 | else: 233 | recon_data = [] 234 | feature_map = [] 235 | self.epoch_info.clear() 236 | for data, data_downsample, target_data, data_last, label in loader: 237 | data = data.float().to(self.dev) 238 | data_downsample = data_downsample.float().to(self.dev) 239 | target_data = target_data.float().to(self.dev) 240 | data_last = data_last.float().to(self.dev) 241 | label = label.long().to(self.dev) 242 | 243 | with torch.no_grad(): 244 | A_batch, prob, outputs, _ = self.model2(data_downsample) 245 | x_class, pred, target = self.model1(data, target_data, data_last, A_batch, self.arg.lamda_act) 246 | result_frag.append(x_class.data.cpu().numpy()) 247 | 248 | if save: 249 | n = pred.size(0) 250 | p = pred[::2,:,:,:].cpu().numpy() 251 | recon_data.extend(p) 252 | 253 | if evaluation: 254 | loss_class = self.loss_class(x_class, label) 255 | loss_recon = self.loss_pred(pred, target) 256 | loss1 = loss_class + self.w_pred*loss_recon 257 | 258 | loss1_value.append(loss1.item()) 259 | loss_class_value.append(loss_class.item()) 260 | loss_recon_value.append(loss_recon.item()) 261 | label_frag.append(label.data.cpu().numpy()) 262 | 263 | if save: 264 | recon_data = np.array(recon_data) 265 | np.save(os.path.join(self.arg.work_dir, 'recon_data.npy'), recon_data) 266 | 267 | 268 | self.result = np.concatenate(result_frag) 269 | if evaluation: 270 | self.label = np.concatenate(label_frag) 271 | self.epoch_info['mean_loss1'] = np.mean(loss1_value) 272 | self.epoch_info['mean_loss_class'] = np.mean(loss_class_value) 273 | self.epoch_info['mean_loss_recon'] = np.mean(loss_recon_value) 274 | self.show_epoch_info() 275 | 276 | for k in self.arg.show_topk: 277 | hit_top_k = [] 278 | rank = self.result.argsort() 279 | for i,l in enumerate(self.label): 280 | hit_top_k.append(l in rank[i, -k:]) 281 | self.io.print_log('\n') 282 | accuracy = sum(hit_top_k)*1.0/len(hit_top_k) 283 | self.io.print_log('\tTop{}: {:.2f}%'.format(k, 100 * accuracy)) 284 | 285 | 286 | 287 | @staticmethod 288 | def get_parser(add_help=False): 289 | 290 | parent_parser = Processor.get_parser(add_help=False) 291 | parser = argparse.ArgumentParser( 292 | add_help=add_help, 293 | parents=[parent_parser], 294 | description='Spatial Temporal Graph Convolution Network') 295 | 296 | parser.add_argument('--show_topk', type=int, default=[1, 5], nargs='+', help='which Top K accuracy will be shown') 297 | parser.add_argument('--base_lr1', type=float, default=0.1, help='initial learning rate') 298 | parser.add_argument('--base_lr2', type=float, default=0.0005, help='initial learning rate') 299 | parser.add_argument('--step', type=int, default=[], nargs='+', help='the epoch where optimizer reduce the learning rate') 300 | parser.add_argument('--optimizer', default='SGD', help='type of optimizer') 301 | 
parser.add_argument('--nesterov', type=str2bool, default=True, help='use nesterov or not') 302 | parser.add_argument('--weight_decay', type=float, default=0.0001, help='weight decay for optimizer') 303 | 304 | parser.add_argument('--max_hop_dir', type=str, default='max_hop_4') 305 | parser.add_argument('--lamda_act', type=float, default=0.5) 306 | parser.add_argument('--lamda_act_dir', type=str, default='lamda_05') 307 | 308 | return parser 309 | -------------------------------------------------------------------------------- /processor/torchlight_io.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import argparse 3 | import os 4 | import sys 5 | import traceback 6 | import time 7 | import warnings 8 | import pickle 9 | from collections import OrderedDict 10 | import yaml 11 | import numpy as np 12 | # torch 13 | import torch 14 | import torch.nn as nn 15 | import torch.optim as optim 16 | from torch.autograd import Variable 17 | 18 | with warnings.catch_warnings(): 19 | warnings.filterwarnings("ignore", category=FutureWarning) 20 | import h5py 21 | 22 | 23 | class IO(): 24 | def __init__(self, work_dir, save_log=True, print_log=True): 25 | self.work_dir = work_dir 26 | self.save_log = save_log 27 | self.print_to_screen = print_log 28 | self.cur_time = time.time() 29 | self.split_timer = {} 30 | self.pavi_logger = None 31 | self.session_file = None 32 | self.model_text = '' 33 | 34 | # PaviLogger is removed in this version 35 | def log(self, *args, **kwargs): 36 | pass 37 | 38 | # try: 39 | # if self.pavi_logger is None: 40 | # from torchpack.runner.hooks import PaviLogger 41 | # url = 'http://pavi.parrotsdnn.org/log' 42 | # with open(self.session_file, 'r') as f: 43 | # info = dict( 44 | # session_file=self.session_file, 45 | # session_text=f.read(), 46 | # model_text=self.model_text) 47 | # self.pavi_logger = PaviLogger(url) 48 | # self.pavi_logger.connect(self.work_dir, info=info) 49 | # self.pavi_logger.log(*args, **kwargs) 50 | # except: #pylint: disable=W0702 51 | # pass 52 | 53 | def load_model(self, model, **model_args): 54 | Model = import_class(model) 55 | model = Model(**model_args) 56 | self.model_text += '\n\n' + str(model) 57 | return model 58 | 59 | def load_weights(self, model, weights_path, ignore_weights=None): 60 | if ignore_weights is None: 61 | ignore_weights = [] 62 | if isinstance(ignore_weights, str): 63 | ignore_weights = [ignore_weights] 64 | 65 | self.print_log('Load weights from {}.'.format(weights_path)) 66 | weights = torch.load(weights_path) 67 | weights = OrderedDict([[k.split('module.')[-1], 68 | v.cpu()] for k, v in weights.items()]) 69 | 70 | # filter weights 71 | for i in ignore_weights: 72 | ignore_name = list() 73 | for w in weights: 74 | if w.find(i) == 0: 75 | ignore_name.append(w) 76 | for n in ignore_name: 77 | weights.pop(n) 78 | self.print_log('Filter [{}] remove weights [{}].'.format(i, n)) 79 | 80 | for w in weights: 81 | self.print_log('Load weights [{}].'.format(w)) 82 | 83 | try: 84 | model.load_state_dict(weights) 85 | except (KeyError, RuntimeError): 86 | state = model.state_dict() 87 | diff = list(set(state.keys()).difference(set(weights.keys()))) 88 | for d in diff: 89 | self.print_log('Can not find weights [{}].'.format(d)) 90 | state.update(weights) 91 | model.load_state_dict(state) 92 | return model 93 | 94 | def save_pkl(self, result, filename): 95 | with open('{}/{}'.format(self.work_dir, filename), 'wb') as f: 96 | pickle.dump(result, f) 97 | 98 | def save_h5(self, 
result, filename): 99 | with h5py.File('{}/{}'.format(self.work_dir, filename), 'w') as f: 100 | for k in result.keys(): 101 | f[k] = result[k] 102 | 103 | def save_model(self, model, name): 104 | model_path = '{}/{}'.format(self.work_dir, name) 105 | state_dict = model.state_dict() 106 | weights = OrderedDict([[''.join(k.split('module.')), 107 | v.cpu()] for k, v in state_dict.items()]) 108 | torch.save(weights, model_path) 109 | self.print_log('The model has been saved as {}.'.format(model_path)) 110 | 111 | def save_arg(self, arg): 112 | 113 | self.session_file = '{}/config.yaml'.format(self.work_dir) 114 | 115 | # save arg 116 | arg_dict = vars(arg) 117 | if not os.path.exists(self.work_dir): 118 | os.makedirs(self.work_dir) 119 | with open(self.session_file, 'w') as f: 120 | f.write('# command line: {}\n\n'.format(' '.join(sys.argv))) 121 | yaml.dump(arg_dict, f, default_flow_style=False, indent=4) 122 | 123 | def print_log(self, str, print_time=True): 124 | if print_time: 125 | # localtime = time.asctime(time.localtime(time.time())) 126 | str = time.strftime("[%m.%d.%y|%X] ", time.localtime()) + str 127 | 128 | if self.print_to_screen: 129 | print(str) 130 | if self.save_log: 131 | with open('{}/log.txt'.format(self.work_dir), 'a') as f: 132 | print(str, file=f) 133 | 134 | def init_timer(self, *name): 135 | self.record_time() 136 | self.split_timer = {k: 0.0000001 for k in name} 137 | 138 | def check_time(self, name): 139 | self.split_timer[name] += self.split_time() 140 | 141 | def record_time(self): 142 | self.cur_time = time.time() 143 | return self.cur_time 144 | 145 | def split_time(self): 146 | split_time = time.time() - self.cur_time 147 | self.record_time() 148 | return split_time 149 | 150 | def print_timer(self): 151 | proportion = { 152 | k: '{:02d}%'.format(int(round(v * 100 / sum(self.split_timer.values())))) 153 | for k, v in self.split_timer.items() 154 | } 155 | self.print_log('Time consumption:') 156 | for k in proportion: 157 | self.print_log( 158 | '\t[{}][{}]: {:.4f}'.format(k, proportion[k], self.split_timer[k]) 159 | ) 160 | 161 | 162 | def str2bool(v): 163 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 164 | return True 165 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 166 | return False 167 | else: 168 | raise argparse.ArgumentTypeError('Boolean value expected.') 169 | 170 | 171 | def str2dict(v): 172 | return eval('dict({})'.format(v)) # pylint: disable=W0123 173 | 174 | 175 | def _import_class_0(name): 176 | components = name.split('.') 177 | mod = __import__(components[0]) 178 | for comp in components[1:]: 179 | mod = getattr(mod, comp) 180 | return mod 181 | 182 | 183 | def import_class(import_str): 184 | mod_str, _sep, class_str = import_str.rpartition('.') 185 | __import__(mod_str) 186 | try: 187 | return getattr(sys.modules[mod_str], class_str) 188 | except AttributeError: 189 | raise ImportError('Class %s cannot be found (%s)' % 190 | (class_str, 191 | traceback.format_exception(*sys.exc_info()))) 192 | 193 | 194 | class DictAction(argparse.Action): 195 | def __init__(self, option_strings, dest, nargs=None, **kwargs): 196 | if nargs is not None: 197 | raise ValueError("nargs not allowed") 198 | super(DictAction, self).__init__(option_strings, dest, **kwargs) 199 | 200 | def __call__(self, parser, namespace, values, option_string=None): 201 | input_dict = eval('dict({})'.format(values)) # pylint: disable=W0123 202 | output_dict = getattr(namespace, self.dest) 203 | for k in input_dict: 204 | output_dict[k] = input_dict[k] 205 | 
setattr(namespace, self.dest, output_dict) -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | CUDA_VISIBLE_DEVICES=0,1 python main.py \ 4 | --config config/train.yaml \ 5 | # >> outs_files/22_03_11-ntu120xsub-angular.log 2>&1 & -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import sys 3 | import traceback 4 | 5 | 6 | def import_class(name): 7 | components = name.split('.') 8 | mod = __import__(components[0]) 9 | for comp in components[1:]: 10 | mod = getattr(mod, comp) 11 | return mod 12 | 13 | 14 | def import_class_2(import_str): 15 | mod_str, _sep, class_str = import_str.rpartition('.') 16 | __import__(mod_str) 17 | try: 18 | return getattr(sys.modules[mod_str], class_str) 19 | except AttributeError: 20 | raise ImportError('Class %s cannot be found (%s)' % 21 | (class_str, 22 | traceback.format_exception(*sys.exc_info()))) 23 | 24 | 25 | 26 | def count_params(model): 27 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 28 | 29 | 30 | def get_current_time(): 31 | currentDT = datetime.datetime.now() 32 | return str(currentDT.strftime("%Y-%m-%dT%H-%M-%S")) 33 | -------------------------------------------------------------------------------- /utils_dir/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/utils_dir/__init__.py -------------------------------------------------------------------------------- /utils_dir/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/utils_dir/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /utils_dir/__pycache__/utils_cam.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/utils_dir/__pycache__/utils_cam.cpython-36.pyc -------------------------------------------------------------------------------- /utils_dir/__pycache__/utils_io.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/utils_dir/__pycache__/utils_io.cpython-36.pyc -------------------------------------------------------------------------------- /utils_dir/__pycache__/utils_math.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/utils_dir/__pycache__/utils_math.cpython-36.pyc -------------------------------------------------------------------------------- /utils_dir/__pycache__/utils_result.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/utils_dir/__pycache__/utils_result.cpython-36.pyc -------------------------------------------------------------------------------- /utils_dir/__pycache__/utils_visual.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kfzyqin/Angular-Skeleton-Encoding/346b87376e5d915471bbfe9ce13f8a8c4e053951/utils_dir/__pycache__/utils_visual.cpython-36.pyc -------------------------------------------------------------------------------- /utils_dir/utils_cam.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import matplotlib.pyplot as plt 3 | import seaborn as sns 4 | 5 | 6 | def plot_freq(cam_dict): 7 | total_num = 0 8 | for a_key in cam_dict: 9 | a_action_cam = cam_dict[a_key].cpu().numpy() 10 | print('action: ', a_key, 'cam shape: ', a_action_cam.shape) 11 | total_num += a_action_cam.shape[0] 12 | ax = sns.heatmap(a_action_cam) 13 | plt 14 | plt.savefig(f'test_fields/action_heatmaps/action_{a_key}_heatmap.png', dpi=200, 15 | bbox_inches='tight') 16 | # plt.show() 17 | plt.close() 18 | print('total num: ', total_num) 19 | -------------------------------------------------------------------------------- /utils_dir/utils_io.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | import os 3 | 4 | 5 | def mv_py_files_to_dir(a_dir, tgt_dir=None): 6 | if tgt_dir is None: 7 | py_dir = os.path.join(a_dir, 'py_dir') 8 | else: 9 | py_dir = tgt_dir 10 | if not os.path.exists(py_dir): 11 | os.makedirs(py_dir) 12 | 13 | # for root, dirs, files in os.walk(a_dir): # copy files 14 | # for file in files: 15 | # if file.endswith(".py"): 16 | # cp_tgt = os.path.join(root, file) 17 | # shutil.copy2(cp_tgt, py_dir) 18 | 19 | for item in os.listdir(a_dir): 20 | s = os.path.join(a_dir, item) 21 | d = os.path.join(py_dir, item) 22 | if os.path.isdir(s): 23 | shutil.copytree(s, d) 24 | else: 25 | shutil.copy2(s, d) 26 | 27 | -------------------------------------------------------------------------------- /utils_dir/utils_math.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | 4 | 5 | class Embedder_DCT: 6 | def __init__(self, frm_len, multires, inc_input=True, inc_func='linear'): 7 | self.frm_len = frm_len 8 | self.multires = multires 9 | self.inc_input = inc_input 10 | self.inc_func = inc_func 11 | self.periodic_fns = [torch.cos] 12 | 13 | self.create_embedding_fn() 14 | 15 | def create_embedding_fn(self): 16 | embed_fns = [] 17 | if self.inc_input: 18 | embed_fns.append(lambda x, y: x) # with x 19 | 20 | N_freqs = self.multires 21 | 22 | freq_bands = [] 23 | for k in range(1, N_freqs+1): 24 | if self.inc_func == 'linear': 25 | a_freq = k 26 | elif self.inc_func == 'exp': 27 | a_freq = 2 ** (k-1) 28 | elif self.inc_func == 'pow': 29 | a_freq = k ** 2 30 | else: 31 | raise NotImplementedError('Unsupported inc_func.') 32 | 33 | freq_bands.append(math.pi / self.frm_len * a_freq) # This is DCT 34 | 35 | freq_bands = torch.tensor(freq_bands) 36 | 37 | for freq in freq_bands: 38 | for p_fn in self.periodic_fns: 39 | embed_fns.append(lambda x, frm_idx, p_fn=p_fn, freq=freq: (x * p_fn(freq * (frm_idx + 1/2)))) # this is DCT 40 | 41 | self.embed_fns = embed_fns 42 | 43 | def embed(self, inputs, dim): 44 | t_len_all = inputs.shape[2] 45 | time_list = [] 46 | for t_idx in 
range(t_len_all): 47 | a_series = inputs[:, :, t_idx, :, :].unsqueeze(2) 48 | # new_time_list = torch.cat([fn(a_series, t_idx) for fn in self.embed_fns], dim) # DCT 49 | 50 | # To try positional encoding 51 | new_time_list = [] 52 | for fn in self.embed_fns: 53 | a_new_one = fn(a_series, t_idx) 54 | new_time_list.append(a_new_one) 55 | new_time_list = torch.cat(new_time_list, dim) 56 | 57 | # To sum encodes 58 | # new_time_list = None 59 | # for fn in self.embed_fns: 60 | # if new_time_list is None: 61 | # new_time_list = fn(a_series, t_idx) 62 | # else: 63 | # new_time_list += fn(a_series, t_idx) 64 | 65 | # print('new_time_list: ', new_time_list.squeeze()) 66 | time_list.append(new_time_list) 67 | rtn = torch.cat(time_list, 2) 68 | return rtn 69 | 70 | 71 | an_embed = Embedder_DCT(300, 8) 72 | 73 | 74 | def gen_dct_on_the_fly(the_data, K=None): 75 | N, C, T, V, M = the_data.shape 76 | if K is None: 77 | K = T 78 | 79 | the_data = an_embed.embed(the_data, dim=1) 80 | return the_data 81 | 82 | 83 | def dct_2_no_sum_parallel(bch_seq, K0=0, K1=None): 84 | bch_seq_rsp = bch_seq.view(-1, bch_seq.shape[1]) 85 | N = bch_seq_rsp.shape[1] 86 | if K1 is None: 87 | K1 = N 88 | basis_list = [] 89 | for k in range(K0, K1): 90 | a_basis_list = [] 91 | for i in range(N): 92 | a_basis_list.append(math.cos(math.pi / N * (i + 0.5) * k)) 93 | basis_list.append(a_basis_list) 94 | basis_list = torch.tensor(basis_list).to(bch_seq_rsp.device) 95 | bch_seq_rsp = bch_seq_rsp.unsqueeze(1).repeat(1, K1 - K0, 1) 96 | dot_prod = torch.einsum('abc,bc->abc', bch_seq_rsp, basis_list) 97 | return dot_prod.view(-1, K1 - K0) 98 | -------------------------------------------------------------------------------- /utils_dir/utils_result.py: -------------------------------------------------------------------------------- 1 | import json 2 | from collections import defaultdict 3 | 4 | from sklearn.metrics import confusion_matrix 5 | 6 | from metadata.class_labels import ntu120_code_labels, anu_bullying_pair_labels, bly_labels, anubis_ind_actions 7 | # from test_fields.kinetics_analysis import get_kinetics_dict 8 | import numpy as np 9 | 10 | 11 | def get_result_confusion_jsons(gt, pred, data_type, acc_f_name_prefix=None): 12 | if 'ntu' in data_type: 13 | code_labels = ntu120_code_labels 14 | elif 'bly' in data_type: 15 | code_labels = bly_labels 16 | gt = np.array(gt)[:, 0] 17 | elif 'front' in data_type: 18 | code_labels = {} 19 | tmp = {} 20 | for key in anubis_ind_actions: 21 | tmp[anubis_ind_actions[key]] = key 22 | for key in tmp: 23 | code_labels[tmp[key] + 1] = key 24 | # elif 'kinetics' in data_type: 25 | # code_labels = get_kinetics_dict() 26 | else: 27 | raise NotImplementedError 28 | 29 | correct_dict = defaultdict(list) 30 | for idx in range(len(gt)): 31 | correct_dict[gt[idx]].append(int(pred[idx] == gt[idx])) 32 | correct_dict_ = correct_dict.copy() 33 | 34 | for a_key in correct_dict: 35 | correct_dict[a_key] = '{:.6f}'.format(sum(correct_dict[a_key]) / len(correct_dict[a_key])) 36 | 37 | label_acc = {} 38 | for a_key in correct_dict: 39 | label_acc[code_labels[int(a_key) + 1]] = float(correct_dict[a_key]) 40 | 41 | label_acc = dict(sorted(label_acc.items(), key=lambda item: item[1])) 42 | label_acc_keys = list(label_acc.keys()) 43 | 44 | conf_mat = confusion_matrix(gt, pred) 45 | 46 | most_confused = {} 47 | for i in correct_dict.keys(): 48 | confusion_0 = np.argsort(conf_mat[int(i)])[::-1][0] 49 | confusion_1 = np.argsort(conf_mat[int(i)])[::-1][1] 50 | 51 | most_confused[code_labels[int(i) + 1]] = [ 52 | "{} 
{}".format(code_labels[confusion_0 + 1], conf_mat[int(i)][confusion_0]), 53 | "{} {}".format(code_labels[confusion_1 + 1], conf_mat[int(i)][confusion_1]), 54 | "{}".format(len(correct_dict_[i])) 55 | ] 56 | most_confused_ = {} 57 | for i in label_acc_keys: 58 | most_confused_[i] = most_confused[i] 59 | 60 | if acc_f_name_prefix is not None: 61 | with open('{}_confusion_matrix.json'.format(acc_f_name_prefix), 'w') as f: 62 | json.dump(most_confused_, f, indent=4) 63 | with open('{}_accuracy_per_class.json'.format(acc_f_name_prefix), 'w') as f: 64 | json.dump(label_acc, f, indent=4) 65 | 66 | return label_acc, most_confused_ 67 | 68 | -------------------------------------------------------------------------------- /utils_dir/utils_visual.py: -------------------------------------------------------------------------------- 1 | import os 2 | import matplotlib.pyplot as plt 3 | from matplotlib.animation import FuncAnimation 4 | import numpy as np 5 | import torch 6 | 7 | azure_kinect_bone_pairs = ( 8 | (1, 0), (2, 1), (3, 2), (4, 2), (5, 4), (6, 5), (7, 6), (8, 7), (9, 8), (10, 7), (11, 2), (12, 11), 9 | (13, 12), (14, 13), (15, 14), (16, 15), (17, 14), (18, 0), (19, 18), (20, 19), (21, 20), (22, 0), 10 | (23, 22), (24, 23), (25, 24), (26, 3), (27, 26), (28, 27), (29, 28), (30, 27), (31, 30) 11 | ) 12 | 13 | kinect_v2_bone_pairs = tuple((i - 1, j - 1) for (i, j) in ( 14 | (1, 2), (2, 21), (3, 21), (4, 3), (5, 21), (6, 5), 15 | (7, 6), (8, 7), (9, 21), (10, 9), (11, 10), (12, 11), 16 | (13, 1), (14, 13), (15, 14), (16, 15), (17, 1), (18, 17), 17 | (19, 18), (20, 19), (22, 23), (21, 21), (23, 8), (24, 25), (25, 12) 18 | )) 19 | 20 | bone_pair_dict = { 21 | 'azure_kinect': azure_kinect_bone_pairs, 22 | 'kinect_v2': kinect_v2_bone_pairs 23 | } 24 | 25 | 26 | def azure_kinect_post_visualize(frames, save_name=None, sklt_type='azure_kinect'): 27 | min_x, max_x = torch.min(frames[0][0]).item(), torch.max(frames[0][0]).item() 28 | min_y, max_y = torch.min(frames[0][1]).item(), torch.max(frames[0][1]).item() 29 | min_z, max_z = torch.min(frames[0][2]).item(), torch.max(frames[0][2]).item() 30 | 31 | bones = bone_pair_dict[sklt_type] 32 | def animate(skeletons): 33 | # Skeleton shape is 3*25. 3 corresponds to the 3D coordinates. 25 is the number of joints. 
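# When two bodies are tracked, the caller further down concatenates them along the
# coordinate axis, so this function may also receive a 6 x V array: rows 0-2 hold the
# (x, y, z) of body 1 and rows 3-5 those of body 2. The k-loop below steps through
# these 3-row groups and draws each body in its own colour.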
34 | ax.clear() 35 | 36 | # ax.set_xlim([-3000, 1000]) 37 | # ax.set_ylim([-3000, 1000]) 38 | # ax.set_zlim([-3000, 1000]) 39 | 40 | ax.set_xlim([min_x, max_x]) 41 | ax.set_ylim([min_y, max_y]) 42 | ax.set_zlim([min_z, max_z]) 43 | 44 | # ax.set_xlim([-2, -1]) 45 | # ax.set_ylim([2, 3]) 46 | # ax.set_zlim([-2, 0]) 47 | 48 | # ax.set_xticklabels([]) 49 | # ax.set_yticklabels([]) 50 | # ax.set_zticklabels([]) 51 | 52 | # person 1 53 | k = 0 54 | color_list = ('blue', 'orange', 'cyan', 'purple') 55 | # color_list = ('blue', 'blue', 'cyan', 'purple') 56 | color_idx = 0 57 | 58 | while k < skeletons.shape[0]: 59 | for i, j in bones: 60 | joint_locs = skeletons[:, [i, j]] 61 | # plot them 62 | ax.plot(joint_locs[k+0], joint_locs[k+1], joint_locs[k+2], color=color_list[color_idx]) 63 | # ax.plot(-joint_locs[k+0], -joint_locs[k+2], -joint_locs[k+1], color=color_list[color_idx]) 64 | 65 | k += 3 66 | color_idx = (color_idx + 1) % len(color_list) 67 | 68 | # Rotate 69 | # X, Y, Z = axes3d.get_test_data(0.1) 70 | # ax.plot_wireframe(X, Y, Z, rstride=5, cstride=5) 71 | # 72 | # # rotate the axes and update 73 | # for angle in range(0, 360): 74 | # ax.view_init(30, angle) 75 | 76 | if save_name is None: 77 | title = 'Action Visualization' 78 | else: 79 | title = os.path.split(save_name)[-1] 80 | plt.title(title) 81 | skeleton_index[0] += 1 82 | return ax 83 | 84 | for an_entry in range(1): 85 | if isinstance(an_entry, tuple) and len(an_entry) == 2: 86 | index = int(an_entry[0]) 87 | pred_idx = int(an_entry[1]) 88 | else: 89 | index = an_entry 90 | # get data 91 | skeletons = np.copy(frames[index]) 92 | 93 | fig = plt.figure() 94 | ax = fig.gca(projection='3d') 95 | # ax.set_xlim([-1, 1]) 96 | # ax.set_ylim([-1, 1]) 97 | # ax.set_zlim([-1, 1]) 98 | 99 | # print(f'Sample index: {index}\nAction: {action_class}-{action_name}\n') # (C,T,V,M) 100 | 101 | # Pick the first body to visualize 102 | skeleton1 = skeletons[..., 0] # out (C,T,V) 103 | # make it shorter 104 | shorter_frame_start = 0 105 | shorter_frame_end = 300 106 | if skeletons.shape[-1] > 1: 107 | skeleton2 = np.copy(skeletons[..., 1]) # out (C,T,V) 108 | # make it shorter 109 | skeleton2 = skeleton2[:, shorter_frame_start:shorter_frame_end, :] 110 | # print('max of skeleton 2: ', np.max(skeleton2)) 111 | skeleton_frames_2 = skeleton2.transpose(1, 0, 2) 112 | else: 113 | skeleton_frames_2 = None 114 | 115 | skeleton_index = [0] 116 | skeleton_frames_1 = skeleton1[:, shorter_frame_start:shorter_frame_end, :].transpose(1, 0, 2) 117 | 118 | if skeleton_frames_2 is None: 119 | # skeleton_frames_1 = center_normalize_skeleton(skeleton_frames_1) 120 | ani = FuncAnimation(fig, animate, 121 | skeleton_frames_1, 122 | interval=150) 123 | else: 124 | # skeleton_frames_1 = center_normalize_skeleton(skeleton_frames_1) 125 | # skeleton_frames_2 = center_normalize_skeleton(skeleton_frames_2) 126 | ani = FuncAnimation(fig, animate, 127 | np.concatenate((skeleton_frames_1, skeleton_frames_2), axis=1), 128 | interval=150) 129 | 130 | if save_name is None: 131 | save_name = 'tmp_skeleton_video_2.mp4' 132 | print('save name: ', save_name) 133 | ani.save(save_name, dpi=200, writer='ffmpeg') 134 | plt.close('all') 135 | 136 | 137 | def plot_multiple_lines(lines, save_name=None, labels=None, every_n=1): 138 | font = {'size': 30} 139 | import matplotlib 140 | matplotlib.rc('font', **font) 141 | 142 | if labels is not None: 143 | assert len(lines) == len(labels) 144 | markers = ['^', '.', '*'] 145 | # colors = ['#dd0100', '#225095', '#fac901'] 146 | colors = ['#dd0100'] 
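# With a single entry in `colors`, every series is drawn in the same red and lines are
# distinguished only by the cycling markers above; restoring the commented three-colour
# list re-enables per-line colours.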
147 | 148 | # plt.xlim(0, len(lines[0])) # Chronological loss value 149 | # plt.ylim(-0.25, 1.25) # Chronological loss value 150 | 151 | plt.xlim(0, len(lines[0])) 152 | plt.ylim(0, 9) 153 | 154 | x_axis_list = list([x for x in range(0, len(lines[0]), every_n)]) 155 | for line_idx, a_line in enumerate(lines): 156 | a_line_plot = list([a_line[i] for i in range(0, len(lines[0]), every_n)]) 157 | plt.plot(x_axis_list, a_line_plot, 158 | color=colors[line_idx % len(colors)], 159 | marker=markers[line_idx % len(markers)], 160 | markersize=30, 161 | label=labels[line_idx] if labels is not None else None) 162 | 163 | # plt.xticks(x_axis_list, a_line_plot) 164 | plt.grid() 165 | # plt.legend(loc=4) # Chronological loss value 166 | 167 | fig = matplotlib.pyplot.gcf() 168 | 169 | # fig.set_size_inches(7, 10) # Chronological loss value 170 | fig.set_size_inches(10, 8) 171 | 172 | if save_name is None: 173 | plt.show() 174 | else: 175 | plt.savefig(save_name, bbox_inches='tight') 176 | plt.close() 177 | plt.show() 178 | --------------------------------------------------------------------------------
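Usage note (not part of the repository): the sketch below illustrates how the DCT-style temporal embedder defined in utils_dir/utils_math.py expands the channel dimension of a skeleton batch. The batch size, joint count and body count are illustrative assumptions; only the 300-frame length is fixed by the module-level Embedder_DCT(300, 8). Run it from the repository root so that the utils_dir package is importable.

import torch

from utils_dir.utils_math import gen_dct_on_the_fly

# Illustrative batch: N=2 clips, C=3 coordinates, T=300 frames, V=25 joints, M=2 bodies.
x = torch.randn(2, 3, 300, 25, 2)

# With the module-level embedder's settings (inc_input=True, multires=8), every
# coordinate channel is mapped to one identity copy plus eight cosine bands,
# i.e. 9 channels, so C grows from 3 to 27 while N, T, V and M are unchanged.
y = gen_dct_on_the_fly(x)
print(y.shape)  # expected: torch.Size([2, 27, 300, 25, 2])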