├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── classification
│   ├── get_dataloader.py
│   ├── get_model.py
│   ├── model.py
│   └── run.py
├── data_pipeline
│   ├── __init__.py
│   ├── ntu_rgbd.py
│   ├── pku_mmd.py
│   └── ucf_101.py
├── detection
│   ├── evaluation
│   │   └── map.py
│   ├── get_dataloader.py
│   ├── get_model.py
│   ├── model.py
│   └── run.py
├── images
│   └── pull.png
├── nets
│   ├── get_distillation_kernel.py
│   ├── get_gru.py
│   └── get_tad.py
├── scripts
│   ├── download_models.sh
│   ├── test_ntu_rgbd.sh
│   ├── test_pku_mmd.sh
│   ├── train_ntu_rgbd.sh
│   ├── train_ntu_rgbd_distillation.sh
│   ├── train_pku_mmd.sh
│   └── train_pku_mmd_distillation.sh
├── third_party
│   ├── pku_mmd
│   │   ├── LICENSE
│   │   └── evaluate.py
│   ├── pytorch
│   │   ├── LICENSE
│   │   └── get_cnn.py
│   └── two_stream_pytorch
│       ├── LICENSE
│       └── video_transforms.py
└── utils
    ├── __init__.py
    ├── imgproc.py
    ├── logging.py
    ├── misc.py
    ├── skelproc.py
    └── visualize.py

--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 | 
3 | We'd love to accept your patches and contributions to this project. There are
4 | just a few small guidelines you need to follow.
5 | 
6 | ## Contributor License Agreement
7 | 
8 | Contributions to this project must be accompanied by a Contributor License
9 | Agreement. You (or your employer) retain the copyright to your contribution;
10 | this simply gives us permission to use and redistribute your contributions as
11 | part of the project. Head over to <https://cla.developers.google.com/> to see
12 | your current agreements on file or to sign a new one.
13 | 
14 | You generally only need to submit a CLA once, so if you've already submitted one
15 | (even if it was for a different project), you probably don't need to do it
16 | again.
17 | 
18 | ## Code reviews
19 | 
20 | All submissions, including submissions by project members, require review. We
21 | use GitHub pull requests for this purpose. Consult
22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
23 | information on using pull requests.
24 | 
25 | ## Community Guidelines
26 | 
27 | This project follows
28 | [Google's Open Source Community Guidelines](https://opensource.google.com/conduct/).
29 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | 
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 | 
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 | 
8 | 1. Definitions.
9 | 
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 | 
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 | 
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 | 
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. 
If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright 2018 Google LLC 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | https://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 
203 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Graph Distillation
2 | 
3 | 
4 | 
5 | This is the code for the paper
6 | **[Graph Distillation for Action Detection with Privileged Modalities](https://arxiv.org/abs/1712.00108)**,
7 | presented at [ECCV 2018](https://eccv2018.org/).
8 | 
9 | *Please note that this is not an officially supported Google product.*
10 | 
11 | In this work, we propose a method termed **graph distillation** that incorporates rich privileged information from a large-scale multimodal dataset in the source domain, and improves learning in the target domain where training data and modalities are scarce.
12 | 
13 | If you find this code useful in your research, please cite:
14 | 
15 | ```
16 | @inproceedings{luo2018graph,
17 |   title={Graph Distillation for Action Detection with Privileged Modalities},
18 |   author={Luo, Zelun and Hsieh, Jun-Ting and Jiang, Lu and Niebles, Juan Carlos and Fei-Fei, Li},
19 |   booktitle={ECCV},
20 |   year={2018}
21 | }
22 | ```
23 | 
24 | ## Setup
25 | All code was developed and tested on Ubuntu 16.04 with Python 3.6 and PyTorch 0.3.1.
26 | 
27 | 
28 | 
29 | 
30 | ## Pretrained Models
31 | You can download the pretrained models used in our paper by running the script:
32 | 
33 | ```
34 | sh scripts/download_models.sh
35 | ```
36 | 
37 | Alternatively, you can fetch them with the Google Cloud SDK:
38 | 
39 | 1. Install the Google Cloud SDK (https://cloud.google.com/sdk/install)
40 | 2. Copy the pretrained models using the following command:
41 | 
42 | ```
43 | gsutil -m cp -r gs://graph_distillation/ckpt .
44 | ```
45 | 
46 | 
47 | ## Running Models
48 | Use the scripts in `scripts/` to train models on different modalities.
49 | 
50 | 
51 | 
52 | ### Classification
53 | See `classification/run.py` for descriptions of the arguments.
54 | 
55 | `scripts/train_ntu_rgbd.sh` trains a model for a single modality.
56 | 
57 | `scripts/train_ntu_rgbd_distillation.sh` trains a model with graph distillation. The modality being trained is specified by the `xfer_to` argument, and the modalities to distill from are specified by the `modalities` argument.
58 | 
59 | ### Detection
60 | See `detection/run.py` for descriptions of the arguments. Note that the `visual_encoder_ckpt_path` argument is the pretrained visual encoder checkpoint, which should come from training the classification models.
61 | 
62 | `scripts/train_pku_mmd.sh` trains a model for a single modality.
63 | 
64 | `scripts/train_pku_mmd_distillation.sh` trains a model with graph distillation. The modality being trained is specified by the `xfer_to` argument, and the modalities to distill from are specified by the `modalities` argument.
65 | 
--------------------------------------------------------------------------------
/classification/get_dataloader.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | 
16 | """Get data loader for classification."""
17 | 
18 | import os
19 | from data_pipeline.ntu_rgbd import NTU_RGBD
20 | import third_party.two_stream_pytorch.video_transforms as vtransforms
21 | import torch.utils.data as data
22 | import torchvision.transforms as transforms
23 | 
24 | 
25 | CNN_MODALITIES = ['rgb', 'oflow', 'depth']
26 | GRU_MODALITIES = ['jjd', 'jjv', 'jld']
27 | 
28 | 
29 | def get_dataloader(opt):
30 |   """Constructs the dataset with transforms and wraps it in a data loader."""
31 |   idx_t = 0 if opt.split == 'train' else 1
32 | 
33 |   xforms = []
34 |   for modality in opt.modalities:
35 |     if opt.dset == 'ntu-rgbd':
36 |       mean, std = NTU_RGBD.MEAN_STD[modality]
37 |     else:
38 |       raise NotImplementedError
39 | 
40 |     if opt.split == 'train' and (modality == 'rgb' or modality == 'depth'):
41 |       xform = transforms.Compose([
42 |           vtransforms.RandomSizedCrop(224),
43 |           vtransforms.RandomHorizontalFlip(),
44 |           vtransforms.ToTensor(),
45 |           vtransforms.Normalize(mean, std)
46 |       ])
47 |     elif opt.split == 'train' and modality == 'oflow':
48 |       # Special handling when flipping optical flow
49 |       xform = transforms.Compose([
50 |           vtransforms.RandomSizedCrop(224, True),
51 |           vtransforms.RandomHorizontalFlip(True),
52 |           vtransforms.ToTensor(),
53 |           vtransforms.Normalize(mean, std)
54 |       ])
55 |     elif opt.split != 'train' and modality in CNN_MODALITIES:
56 |       xform = transforms.Compose([
57 |           vtransforms.Scale(256),
58 |           vtransforms.CenterCrop(224),
59 |           vtransforms.ToTensor(),
60 |           vtransforms.Normalize(mean, std)
61 |       ])
62 |     elif modality in GRU_MODALITIES:
63 |       xform = transforms.Compose([vtransforms.SkelNormalize(mean, std)])
64 |     else:
65 |       raise ValueError('unsupported split/modality: {}/{}'.format(opt.split, modality))
66 | 
67 |     xforms.append(xform)
68 | 
69 |   if opt.dset == 'ntu-rgbd':
70 |     root = os.path.join(opt.dset_path, opt.dset)
71 |     dset = NTU_RGBD(root, opt.split, 'cross-subject', opt.modalities,
72 |                     opt.n_samples[idx_t], opt.n_frames, opt.downsample, xforms,
73 |                     opt.subsample)
74 |   else:
75 |     raise NotImplementedError
76 | 
77 |   dataloader = data.DataLoader(
78 |       dset,
79 |       batch_size=opt.batch_sizes[idx_t],
80 |       shuffle=(opt.split == 'train'),
81 |       num_workers=opt.n_workers)
82 | 
83 |   return dataloader
84 | 
--------------------------------------------------------------------------------
/classification/get_model.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================== 15 | 16 | """Get classification model.""" 17 | 18 | from .model import GraphDistillation 19 | from .model import SingleStream 20 | 21 | ALL_MODALITIES = ['rgb', 'oflow', 'depth', 'jjd', 'jjv', 'jld'] 22 | 23 | 24 | def get_model(opt): 25 | """Get model given the dataset and modalities.""" 26 | if opt.dset == 'ntu-rgbd': 27 | n_classes = 60 28 | all_input_sizes = [-1, -1, -1, 276, 828, 836] 29 | all_n_channels = [3, 2, 1, -1, -1, -1] 30 | else: 31 | raise NotImplementedError 32 | 33 | n_channels = [all_n_channels[ALL_MODALITIES.index(m)] for m in opt.modalities] 34 | input_sizes = [ 35 | all_input_sizes[ALL_MODALITIES.index(m)] for m in opt.modalities 36 | ] 37 | 38 | if len(opt.modalities) == 1: 39 | # Single stream 40 | model = SingleStream(opt.modalities, n_classes, opt.n_frames, n_channels, 41 | input_sizes, opt.hidden_size, opt.n_layers, 42 | opt.dropout, opt.lr, opt.lr_decay_rate, opt.ckpt_path) 43 | else: 44 | model = GraphDistillation( 45 | opt.modalities, n_classes, opt.n_frames, n_channels, input_sizes, 46 | opt.hidden_size, opt.n_layers, opt.dropout, opt.lr, opt.lr_decay_rate, 47 | opt.ckpt_path, opt.w_losses, opt.w_modalities, opt.metric, opt.xfer_to, 48 | opt.gd_size, opt.gd_reg) 49 | 50 | return model 51 | -------------------------------------------------------------------------------- /classification/model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Model to train classification.""" 17 | 18 | from collections import OrderedDict 19 | import os 20 | from third_party.pytorch.get_cnn import * 21 | from nets.get_distillation_kernel import * 22 | from nets.get_gru import * 23 | from torch.autograd import Variable 24 | import torch.backends.cudnn as cudnn 25 | import torch.nn.functional as F 26 | import torch.optim as optim 27 | import utils 28 | 29 | 30 | CNN_MODALITIES = ['rgb', 'oflow', 'depth'] 31 | GRU_MODALITIES = ['jjd', 'jjv', 'jld'] 32 | 33 | 34 | class BaseModel: 35 | """Base class for the model.""" 36 | 37 | def __init__(self, modalities, n_classes, n_frames, n_channels, input_sizes, 38 | hidden_size, n_layers, dropout, lr, lr_decay_rate, ckpt_path): 39 | super(BaseModel, self).__init__() 40 | cudnn.benchmark = True 41 | 42 | self.embeds = [] 43 | for _, (modality, n_channels_m, input_size) in enumerate( 44 | zip(modalities, n_channels, input_sizes)): 45 | if modality in CNN_MODALITIES: 46 | self.embeds.append( 47 | nn.DataParallel( 48 | get_resnet(n_frames * n_channels_m, n_classes).cuda())) 49 | elif modality in GRU_MODALITIES: 50 | self.embeds.append( 51 | nn.DataParallel( 52 | get_gru(input_size, hidden_size, n_layers, dropout, 53 | n_classes).cuda())) 54 | else: 55 | raise NotImplementedError 56 | 57 | self.optimizer = None 58 | self.criterion_cls = nn.CrossEntropyLoss().cuda() 59 | 60 | self.modalities = modalities 61 | self.lr = lr 62 | self.lr_decay_rate = lr_decay_rate 63 | self.ckpt_path = ckpt_path 64 | 65 | def _forward(self, inputs): 66 | """Forward pass for all modalities. Return the representation and logits.""" 67 | logits, reprs = [], [] 68 | 69 | # Forward pass for all modalities 70 | for i, (input, embed, modality) in enumerate( 71 | zip(list(inputs), self.embeds, self.modalities)): 72 | if modality in CNN_MODALITIES: 73 | batch_size, n_samples, n_frames, n_channels, h, w = input.size() 74 | input = input.view(batch_size * n_samples, n_frames * n_channels, h, w) 75 | elif modality in GRU_MODALITIES: 76 | batch_size, n_samples, n_frames, input_size = input.size() 77 | input = input.view(batch_size * n_samples, n_frames, input_size) 78 | else: 79 | raise NotImplementedError 80 | 81 | logit, representation = embed(input) 82 | logit = logit.view(batch_size, n_samples, -1) 83 | representation = representation.view(batch_size, n_samples, -1) 84 | 85 | logits.append(logit.mean(1)) 86 | reprs.append(representation.mean(1)) 87 | 88 | logits = torch.stack(logits) 89 | reprs = torch.stack(reprs) 90 | return [logits, reprs] 91 | 92 | def _backward(self, results, label): 93 | raise NotImplementedError 94 | 95 | def train(self, inputs, label): 96 | """Train the model. 97 | 98 | Args: 99 | inputs: a list, each is batch_size x n_sample x n_frames x 100 | (n_channels x h x w) or (input_size). 101 | label: batch_size x n_samples. 102 | Returns: 103 | info: dictionary of results 104 | """ 105 | for embed in self.embeds: 106 | embed.train() 107 | 108 | for i in range(len(inputs)): 109 | inputs[i] = Variable(inputs[i].cuda(), requires_grad=False) 110 | label = Variable(label.cuda(), requires_grad=False) 111 | 112 | results = self._forward(inputs) 113 | info_loss = self._backward(results, label) 114 | info_acc = self._get_acc(results[0], label) 115 | 116 | return OrderedDict(info_loss + info_acc) 117 | 118 | def test(self, inputs, label): 119 | """Test the model. 
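    Runs the forward pass only; inputs are wrapped in volatile Variables below,
    so no autograd graph is retained during evaluation.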
120 | 121 | Args: 122 | inputs: a list, each is batch_size x n_sample x n_frames x 123 | (n_channels x h x w) or (input_size). 124 | label: batch_size x n_samples. 125 | Returns: 126 | info_acc: dictionary of results 127 | """ 128 | for embed in self.embeds: 129 | embed.eval() 130 | 131 | inputs = [Variable(x.cuda(), volatile=True) for x in inputs] 132 | label = Variable(label.cuda(), volatile=True) 133 | 134 | logits, _ = self._forward(inputs) 135 | info_acc = self._get_acc(logits, label) 136 | 137 | return OrderedDict(info_acc), logits 138 | 139 | def _get_acc(self, logits, label): 140 | info_acc = [] 141 | for _, (logit, modality) in enumerate(zip(logits, self.modalities)): 142 | acc, _, label = utils.get_stats(logit, label) 143 | info_acc.append(('acc_{}'.format(modality), acc)) 144 | return info_acc 145 | 146 | def lr_decay(self): 147 | lrs = [] 148 | for param_group in self.optimizer.param_groups: 149 | param_group['lr'] *= self.lr_decay_rate 150 | lrs.append(param_group['lr']) 151 | return lrs 152 | 153 | def save(self, epoch): 154 | path = os.path.join(self.ckpt_path, 'embed_{}.pth'.format(epoch)) 155 | torch.save(self.embeds[self.to_idx].state_dict(), path) 156 | 157 | def load(self, load_ckpt_paths, epoch=200): 158 | """Load trained models.""" 159 | assert len(load_ckpt_paths) == len(self.embeds) 160 | for i, ckpt_path in enumerate(load_ckpt_paths): 161 | if len(ckpt_path) > 0: 162 | path = os.path.join(ckpt_path, 'embed_{}.pth'.format(epoch)) 163 | self.embeds[i].load_state_dict(torch.load(path)) 164 | utils.info('{}: ckpt {} loaded'.format(self.modalities[i], path)) 165 | else: 166 | utils.info('{}: training from scratch'.format(self.modalities[i])) 167 | 168 | 169 | class SingleStream(BaseModel): 170 | """Model to train a single modality.""" 171 | 172 | def __init__(self, *args, **kwargs): 173 | super(SingleStream, self).__init__(*args, **kwargs) 174 | assert len(self.embeds) == 1 175 | self.optimizer = optim.SGD( 176 | self.embeds[0].parameters(), 177 | lr=self.lr, 178 | momentum=0.9, 179 | weight_decay=5e-4) 180 | self.to_idx = 0 181 | 182 | def _backward(self, results, label): 183 | logits, _ = results 184 | logits = logits.view(*logits.size()[1:]) 185 | loss = self.criterion_cls(logits, label) 186 | loss.backward() 187 | self.optimizer.step() 188 | self.optimizer.zero_grad() 189 | 190 | info_loss = [('loss', loss.data[0])] 191 | return info_loss 192 | 193 | 194 | class GraphDistillation(BaseModel): 195 | """Model to train with graph distillation. 196 | 197 | xfer_to is the modality to train. 
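  For example (an illustrative configuration, not one mandated by this class):
  modalities=['rgb', 'oflow', 'depth'] with xfer_to='depth' trains the depth
  stream while rgb and oflow serve as privileged modalities to distill from;
  the distillation kernel predicts per-example edge weights over those source
  modalities.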
198 | """ 199 | 200 | def __init__(self, modalities, n_classes, n_frames, n_channels, input_sizes, 201 | hidden_size, n_layers, dropout, lr, lr_decay_rate, ckpt_path, 202 | w_losses, w_modalities, metric, xfer_to, gd_size, gd_reg): 203 | super(GraphDistillation, self).__init__( \ 204 | modalities, n_classes, n_frames, n_channels, input_sizes, 205 | hidden_size, n_layers, dropout, lr, lr_decay_rate, ckpt_path) 206 | 207 | # Index of the modality to distill 208 | to_idx = self.modalities.index(xfer_to) 209 | from_idx = [x for x in range(len(self.modalities)) if x != to_idx] 210 | assert len(from_idx) >= 1 211 | 212 | # Prior 213 | w_modalities = [w_modalities[i] for i in from_idx 214 | ] # remove modality being transferred to 215 | gd_prior = utils.softmax(w_modalities, 0.25) 216 | # Distillation model 217 | self.distillation_kernel = get_distillation_kernel( 218 | n_classes, hidden_size, gd_size, to_idx, from_idx, gd_prior, gd_reg, 219 | w_losses, metric).cuda() 220 | 221 | params = list(self.embeds[to_idx].parameters()) + \ 222 | list(self.distillation_kernel.parameters()) 223 | self.optimizer = optim.SGD(params, lr=lr, momentum=0.9, weight_decay=5e-4) 224 | 225 | self.xfer_to = xfer_to 226 | self.to_idx = to_idx 227 | self.from_idx = from_idx 228 | 229 | def _forward(self, inputs): 230 | logits, reprs = super(GraphDistillation, self)._forward(inputs) 231 | # Get edge weights of the graph 232 | graph = self.distillation_kernel(logits, reprs) 233 | return logits, reprs, graph 234 | 235 | def _backward(self, results, label): 236 | logits, reprs, graph = results # graph: size = len(from_idx) x batch_size 237 | info_loss = [] 238 | 239 | # Classification loss 240 | loss_cls = self.criterion_cls(logits[self.to_idx], label) 241 | # Graph distillation loss 242 | loss_reg, loss_logit, loss_repr = \ 243 | self.distillation_kernel.distillation_loss(logits, reprs, graph) 244 | 245 | loss = loss_cls + loss_reg + loss_logit + loss_repr 246 | loss.backward() 247 | if self.xfer_to in GRU_MODALITIES: 248 | torch.nn.utils.clip_grad_norm(self.embeds[self.to_idx].parameters(), 5) 249 | self.optimizer.step() 250 | self.optimizer.zero_grad() 251 | 252 | info_loss = [('loss_cls', loss_cls.data[0]), ('loss_reg', loss_reg.data[0]), 253 | ('loss_logit', loss_logit.data[0]), ('loss_repr', 254 | loss_repr.data[0])] 255 | return info_loss 256 | -------------------------------------------------------------------------------- /classification/run.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Train and test classification.""" 17 | 18 | import argparse 19 | import os 20 | from .get_dataloader import * 21 | from .get_model import * 22 | import numpy as np 23 | from sklearn.metrics import average_precision_score 24 | import utils 25 | import utils.logging as logging 26 | 27 | parser = argparse.ArgumentParser() 28 | 29 | # experimental settings 30 | parser.add_argument('--n_workers', type=int, default=24) 31 | parser.add_argument('--gpus', type=str, default='0') 32 | parser.add_argument('--split', type=str, choices=['train', 'test']) 33 | 34 | # ckpt and logging 35 | parser.add_argument('--ckpt_path', type=str, default='./ckpt', 36 | help='directory path that stores all checkpoints') 37 | parser.add_argument('--ckpt_name', type=str, default='ckpt') 38 | parser.add_argument('--pretrained_ckpt_name', type=str, default='ckpt', 39 | help='prefix of checkpoints used for graph distillation') 40 | parser.add_argument('--load_ckpt_path', type=str, default='', 41 | help='checkpoint path to load for testing/initialization') 42 | parser.add_argument('--load_epoch', type=int, default=200, 43 | help='Checkpoint epoch to load for testing.') 44 | parser.add_argument('--print_every', type=int, default=50) 45 | parser.add_argument('--save_every', type=int, default=50) 46 | 47 | # hyperparameters 48 | parser.add_argument('--batch_sizes', type=int, nargs='+', default=[64, 8], 49 | help='batch sizes: [train, test]') 50 | parser.add_argument('--n_epochs', type=int, default=200) 51 | parser.add_argument('--lr', type=float, default=1e-2) 52 | parser.add_argument('--lr_decay_at', type=int, nargs='+', default=[125, 175]) 53 | parser.add_argument('--lr_decay_rate', type=float, default=0.1) 54 | 55 | # data pipeline 56 | parser.add_argument('--dset', type=str, default='ntu-rgbd') 57 | parser.add_argument('--dset_path', type=str, 58 | default=os.path.join(os.environ['HOME'], 'slowbro')) 59 | parser.add_argument('--modalities', type=str, nargs='+', 60 | choices=['rgb', 'oflow', 'depth', 'jjd', 'jjv', 'jld']) 61 | parser.add_argument('--n_samples', type=int, nargs='+', default=[1, 5], 62 | help='Number of samples clips per video: [train, test]') 63 | parser.add_argument('--step_size', type=int, default=10, 64 | help='step size between samples (after downsample)') 65 | parser.add_argument('--n_frames', type=int, default=10, 66 | help='num frames per sample') 67 | parser.add_argument('--downsample', type=int, default=3, 68 | help='fps /= downsample') 69 | parser.add_argument('--subsample', type=int, default=33, 70 | help='subsample the dataset. 
0: False, >0:' 71 | 'number of examples per class') 72 | 73 | # GRU 74 | parser.add_argument('--dropout', type=float, default=0.5) 75 | parser.add_argument('--hidden_size', type=int, default=512) 76 | parser.add_argument('--n_layers', type=int, default=3) 77 | 78 | # Graph Distillation parameters 79 | parser.add_argument('--metric', type=str, default='cosine', 80 | choices=['cosine', 'kl', 'l2', 'l1'], 81 | help='distance metric for distillation loss') 82 | parser.add_argument('--w_losses', type=float, nargs='+', default=[10, 1], 83 | help='weights for losses: [logit, repr]') 84 | parser.add_argument('--w_modalities', type=float, nargs='+', 85 | default=[1, 1, 1, 1, 1, 1], 86 | help='modality prior') 87 | parser.add_argument('--xfer_to', type=str, default='', 88 | help='modality to train with graph distillation') 89 | parser.add_argument('--gd_size', type=int, default=32, 90 | help='hidden size of graph distillation') 91 | parser.add_argument('--gd_reg', type=float, default=10, 92 | help='regularization for graph distillation') 93 | 94 | 95 | def single_stream(opt): 96 | """Train a single modality from scratch.""" 97 | # Checkpoint path example: ckpt_path/ntu-rgbd/rgb/ckpt 98 | opt.ckpt_path = os.path.join(opt.ckpt_path, opt.dset, opt.modalities[0], 99 | opt.ckpt_name) 100 | opt.load_ckpt_paths = [opt.load_ckpt_path] 101 | os.makedirs(opt.ckpt_path, exist_ok=True) 102 | 103 | # Data loader and model 104 | dataloader = get_dataloader(opt) 105 | model = get_model(opt) 106 | if opt.split == 'train': 107 | train(opt, model, dataloader) 108 | else: 109 | test(opt, model, dataloader) 110 | 111 | 112 | def multi_stream(opt): 113 | """Train a modality with graph distillation from other modalities. 114 | 115 | The modality is specified by opt.xfer_to 116 | """ 117 | assert opt.xfer_to in opt.modalities, 'opt.xfer_to must be in opt.modalities' 118 | # Checkpoints to load 119 | opt.load_ckpt_paths = [] 120 | for m in opt.modalities: 121 | if m != opt.xfer_to: 122 | # Checkpoint from single_stream 123 | path = os.path.join(opt.ckpt_path, opt.dset, m, opt.pretrained_ckpt_name) 124 | assert os.path.exists(path), '{} checkpoint does not exist.'.format(path) 125 | opt.load_ckpt_paths.append(path) 126 | else: 127 | opt.load_ckpt_paths.append(opt.load_ckpt_path) 128 | 129 | # Checkpoint path example: ckpt_path/ntu-rgbd/xfer_rgb/ckpt_rgb_depth 130 | opt.ckpt_path = os.path.join( 131 | opt.ckpt_path, opt.dset, 'xfer_{}'.format(opt.xfer_to), '{}_{}'.format( 132 | opt.ckpt_name, '_'.join([m for m in opt.modalities]))) 133 | os.makedirs(opt.ckpt_path, exist_ok=True) 134 | 135 | # Data loader and model 136 | dataloader = get_dataloader(opt) 137 | model = get_model(opt) 138 | train(opt, model, dataloader) 139 | 140 | 141 | def train(opt, model, dataloader): 142 | """Train the model.""" 143 | # Logging 144 | logger = logging.Logger(opt.ckpt_path, opt.split) 145 | stats = logging.Statistics(opt.ckpt_path, opt.split) 146 | logger.log(opt) 147 | 148 | model.load(opt.load_ckpt_paths, opt.load_epoch) 149 | for epoch in range(1, opt.n_epochs + 1): 150 | for step, data in enumerate(dataloader, 1): 151 | ret = model.train(*data) 152 | update = stats.update(data[-1].size(0), ret) 153 | if utils.is_due(step, opt.print_every): 154 | utils.info('epoch {}/{}, step {}/{}: {}'.format( 155 | epoch, opt.n_epochs, step, len(dataloader), update)) 156 | 157 | logger.log('[Summary] epoch {}/{}: {}'.format(epoch, opt.n_epochs, 158 | stats.summarize())) 159 | 160 | if utils.is_due(epoch, opt.n_epochs, opt.save_every): 161 | 
model.save(epoch) 162 | logger.log('***** saved *****') 163 | 164 | if utils.is_due(epoch, opt.lr_decay_at): 165 | lrs = model.lr_decay() 166 | logger.log('***** lr decay *****: {}'.format(lrs)) 167 | 168 | 169 | def test(opt, model, dataloader): 170 | '''Test model.''' 171 | # Logging 172 | logger = logging.Logger(opt.load_ckpt_path, opt.split) 173 | stats = logging.Statistics(opt.ckpt_path, opt.split) 174 | logger.log(opt) 175 | 176 | logits, labels = [], [] 177 | model.load(opt.load_ckpt_paths, opt.load_epoch) 178 | for step, data in enumerate(dataloader, 1): 179 | inputs, label = data 180 | info_acc, logit = model.test(inputs, label) 181 | logits.append(utils.to_numpy(logit.squeeze(0))) 182 | labels.append(utils.to_numpy(label)) 183 | update = stats.update(label.size(0), info_acc) 184 | if utils.is_due(step, opt.print_every): 185 | utils.info('step {}/{}: {}'.format(step, len(dataloader), update)) 186 | 187 | logits = np.concatenate(logits, axis=0) 188 | length, n_classes = logits.shape 189 | labels = np.concatenate(labels) 190 | scores = utils.softmax(logits, axis=1) 191 | 192 | # Accuracy 193 | preds = np.argmax(scores, axis=1) 194 | acc = np.sum(preds == labels) / length 195 | # Average precision 196 | y_true = np.zeros((length, n_classes)) 197 | y_true[np.arange(length), labels] = 1 198 | aps = average_precision_score(y_true, scores, average=None) 199 | aps = list(filter(lambda x: not np.isnan(x), aps)) 200 | mAP = np.mean(aps) 201 | 202 | logger.log('[Summary]: {}'.format(stats.summarize())) 203 | logger.log('Acc: {}, mAP: {}'.format(acc, mAP)) 204 | 205 | 206 | if __name__ == '__main__': 207 | opt = parser.parse_args() 208 | os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpus 209 | 210 | if opt.split == 'test': 211 | assert len(opt.modalities) == 1, 'specify only 1 modality for testing' 212 | assert len(opt.load_ckpt_path) > 0, 'specify load_ckpt_path for testing' 213 | 214 | if len(opt.modalities) == 1: 215 | single_stream(opt) 216 | else: 217 | multi_stream(opt) 218 | -------------------------------------------------------------------------------- /data_pipeline/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/graph_distillation/1a7ce5125098e7df869f08e15e2d6d8bb3189382/data_pipeline/__init__.py -------------------------------------------------------------------------------- /data_pipeline/ntu_rgbd.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """NTU RGB-D dataset.""" 17 | 18 | 19 | import glob 20 | import os 21 | import numpy as np 22 | import torch.utils.data as data 23 | 24 | import utils 25 | import utils.imgproc as imgproc 26 | 27 | 28 | rgb_folder_name = 'rgb' 29 | rgb_pattern = 'RGB-%08d.jpg' 30 | oflow_folder_name = 'oflow' 31 | oflow_pattern = 'OFlow-%08d.jpg' 32 | depth_folder_name = 'depth' 33 | depth_pattern = 'Depth-%08d.png' 34 | skel_folder_name = 'skeleton' 35 | feat_folder_name = 'feature' 36 | ALL_MODALITIES = ['rgb', 'oflow', 'depth', 'jjd', 'jjv', 'jld'] 37 | train_cross_subject = [ 38 | 1, 2, 4, 5, 8, 9, 13, 14, 15, 16, 17, 18, 19, 25, 27, 28, 31, 34, 35, 38 39 | ] 40 | train_cross_view = [2, 3] 41 | n_classes = 60 42 | 43 | 44 | def make_dataset(root, has_skel, evaluation, split, subsample): 45 | """Returns a list of (video_name, class).""" 46 | if has_skel or split == 'test' or subsample: 47 | vid_names = [os.path.splitext(n)[0] \ 48 | for n in glob.glob1(os.path.join(root, skel_folder_name, 'jjd') 49 | , '*.npy')] 50 | else: 51 | vid_names = os.listdir(os.path.join(root, rgb_folder_name)) 52 | vid_names = sorted(vid_names) 53 | 54 | if evaluation == 'cross-subject' and split != 'test': 55 | vid_names = [n for n in vid_names if int(n[9:12]) in train_cross_subject] 56 | elif evaluation == 'cross-subject' and split == 'test': 57 | vid_names = [ 58 | n for n in vid_names if int(n[9:12]) not in train_cross_subject 59 | ] 60 | elif evaluation == 'cross-view' and split != 'test': 61 | vid_names = [n for n in vid_names if int(n[5:8]) in train_cross_view] 62 | elif evaluation == 'cross-view' and split == 'test': 63 | vid_names = [n for n in vid_names if int(n[5:8]) not in train_cross_view] 64 | else: 65 | raise NotImplementedError 66 | 67 | if subsample: 68 | labels = np.array([int(n[-3:])-1 for n in vid_names]) 69 | vid_names_subsample = [] 70 | for i in range(n_classes): 71 | keep = np.where(labels == i)[0] 72 | keep = keep[np.linspace(0, len(keep)-1, subsample).astype(int)] 73 | vid_names_subsample += [vid_names[j] for j in keep] 74 | vid_names = vid_names_subsample 75 | 76 | elif has_skel and split == 'train' and not subsample: 77 | labels = np.array([int(n[-3:])-1 for n in vid_names]) 78 | vid_names_add = [] 79 | class_54 = np.where(labels == 54)[0].tolist() 80 | class_58 = np.where(labels == 58)[0].tolist() 81 | class_59 = np.where(labels == 59)[0].tolist() 82 | 83 | # deterministic oversampling for consistency 84 | for i in class_54[::7]: 85 | vid_names_add.append(vid_names[i]) 86 | for i in class_58+class_58[::3]: 87 | vid_names_add.append(vid_names[i]) 88 | for i in class_59+class_59[::2]+class_59[1::3]: 89 | vid_names_add.append(vid_names[i]) 90 | vid_names += vid_names_add 91 | 92 | dataset = [(n, int(n[-3:]) - 1) for n in vid_names] 93 | 94 | utils.info('NTU-RGBD: {}, {}, {} videos'.format(split, evaluation, 95 | len(dataset))) 96 | return dataset 97 | 98 | 99 | def rgb_loader(root, vid_name, frame_ids): 100 | """Loads the RGB data.""" 101 | vid = [] 102 | for frame_ids_s in frame_ids: 103 | vid_s = [] 104 | for frame_id in frame_ids_s: 105 | path = os.path.join(root, rgb_folder_name, vid_name, 106 | rgb_pattern % (frame_id + 1)) 107 | img = imgproc.imread_rgb('ntu-rgbd', path) 108 | vid_s.append(img) 109 | vid.append(vid_s) 110 | return np.array(vid) 111 | 112 | 113 | def oflow_loader(root, vid_name, frame_ids): 114 | """Loads the flow data.""" 115 | vid = [] 116 | for frame_ids_s in frame_ids: 117 | vid_s 
= [] 118 | for frame_id in frame_ids_s: 119 | path = os.path.join(root, oflow_folder_name, vid_name, 120 | oflow_pattern % (frame_id + 1)) 121 | img = imgproc.imread_oflow('ntu-rgbd', path) 122 | vid_s.append(img) 123 | vid.append(vid_s) 124 | return np.array(vid) 125 | 126 | 127 | def depth_loader(root, vid_name, frame_ids): 128 | """Loads the depth data.""" 129 | vid = [] 130 | for frame_ids_s in frame_ids: 131 | vid_s = [] 132 | for frame_id in frame_ids_s: 133 | path = os.path.join(root, depth_folder_name, vid_name, 134 | depth_pattern % (frame_id + 1)) 135 | img = imgproc.imread_depth('ntu-rgbd', path) 136 | vid_s.append(img) 137 | vid.append(vid_s) 138 | return np.array(vid) 139 | 140 | 141 | def jjd_loader(root, vid_name, frame_ids): 142 | """Loads the skeleton data JJD.""" 143 | path = os.path.join(root, skel_folder_name, 'jjd', vid_name + '.npy') 144 | skel = np.load(path).astype(np.float32) 145 | skel = skel[frame_ids] 146 | return skel 147 | 148 | 149 | def jjv_loader(root, vid_name, frame_ids): 150 | """Loads the skeleton data JJV.""" 151 | path = os.path.join(root, skel_folder_name, 'jjv', vid_name + '.npy') 152 | skel = np.load(path).astype(np.float32) 153 | skel = skel[frame_ids] 154 | return skel 155 | 156 | 157 | def jld_loader(root, vid_name, frame_ids): 158 | """Loads the skeleton data JLD.""" 159 | path = os.path.join(root, skel_folder_name, 'jld', vid_name + '.npy') 160 | skel = np.load(path).astype(np.float32) 161 | skel = skel[frame_ids] 162 | return skel 163 | 164 | 165 | class NTU_RGBD(data.Dataset): 166 | """Class for NTU RGBD Dataset""" 167 | MEAN_STD = { 168 | 'rgb': ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 169 | 'oflow': (0.5, 1. / 255), 170 | 'depth': (4084.1213735 / (255 * 256), 1008.31271366 / (255 * 256)), 171 | 'jjd': (0.53968146, 0.32319776), 172 | 'jjv': (0, 0.35953656), 173 | 'jld': (0.15982792, 0.12776225) 174 | } 175 | 176 | def __init__(self, 177 | root, 178 | split, 179 | evaluation, 180 | modalities, 181 | n_samples, 182 | n_frames, 183 | downsample, 184 | transforms=None, 185 | subsample=0): 186 | """NTU RGBD dataset. 187 | 188 | Args: 189 | root: dataset root 190 | split: train to randomly select n_samples samples; test to uniformly 191 | select n_samples spanning the whole video 192 | evaluation: one of ['cross-subject', 'cross-view'] 193 | modalities: subset of ['rgb', 'oflow', 'depth', 'jjd', 'jjv', 'jld'] 194 | n_samples: number of samples from the video 195 | n_frames: number of frames per sample 196 | downsample: fps /= downsample 197 | transforms: transformations to apply to data 198 | subsample: number of samples per class. 0 if using full dataset. 
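    Example (a hedged sketch: the root and transform are placeholders, and the
    numeric values mirror the classification defaults in classification/run.py):

      dset = NTU_RGBD(os.path.join(dset_path, 'ntu-rgbd'), 'train',
                      'cross-subject', ['rgb'], n_samples=1, n_frames=10,
                      downsample=3, transforms=[xform])
      inputs, label = dset[0]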
199 | """ 200 | modalities = utils.unsqueeze(modalities) 201 | transforms = utils.unsqueeze(transforms) 202 | 203 | # Loader functions 204 | loaders = { 205 | 'rgb': rgb_loader, 206 | 'oflow': oflow_loader, 207 | 'depth': depth_loader, 208 | 'jjd': jjd_loader, 209 | 'jjv': jjv_loader, 210 | 'jld': jld_loader 211 | } 212 | has_skel = any([m in ALL_MODALITIES[3:] for m in modalities]) 213 | dataset = make_dataset(root, has_skel, evaluation, split, subsample) 214 | 215 | self.root = root 216 | self.split = split 217 | self.modalities = modalities 218 | self.n_samples = n_samples 219 | self.n_frames = n_frames 220 | self.downsample = downsample 221 | self.transforms = transforms 222 | self.loaders = loaders 223 | self.dataset = dataset 224 | 225 | def __getitem__(self, idx): 226 | vid_name, label = self.dataset[idx] 227 | # -1 because len(oflow) = len(rgb)-1 228 | length = len( 229 | glob.glob( 230 | os.path.join(self.root, rgb_folder_name, vid_name, 231 | '*.' + rgb_pattern.split('.')[1]))) - 1 232 | 233 | length_ds = length // self.downsample 234 | if length_ds < self.n_frames: 235 | frame_ids_s = np.arange(0, length_ds, 1) # arange: exclusive 236 | frame_ids_s = np.concatenate( 237 | (frame_ids_s, 238 | np.array([frame_ids_s[-1]] * (self.n_frames - length_ds)))) 239 | frame_ids = np.repeat( 240 | frame_ids_s[np.newaxis, :], self.n_samples, 241 | axis=0).astype(int) * self.downsample 242 | else: 243 | if self.split == 'train': # randomly select n_samples samples 244 | starts = np.random.randint(0, length_ds - self.n_frames + 1, 245 | self.n_samples) # randint: exclusive 246 | # uniformly select n_samples spanning the whole video 247 | elif self.split == 'val' or self.split == 'test': 248 | starts = np.linspace( 249 | 0, length_ds - self.n_frames, self.n_samples, 250 | dtype=int) # linspace: inclusive 251 | else: 252 | starts = np.arange(0, 253 | length_ds - self.n_frames + 1) # arange: exclusive 254 | 255 | frame_ids = [] 256 | for start in starts: 257 | frame_ids_s = np.arange(start, start + self.n_frames, 258 | 1) * self.downsample # arange: exclusive 259 | frame_ids.append(frame_ids_s) 260 | frame_ids = np.stack(frame_ids) 261 | 262 | # load raw data 263 | inputs = [] 264 | for modality in self.modalities: 265 | vid = self.loaders[modality](self.root, vid_name, frame_ids) 266 | inputs.append(vid) 267 | 268 | # transform 269 | if self.transforms is not None: 270 | for i in range(len(self.transforms)): 271 | if self.transforms[i] is not None: 272 | inputs[i] = self.transforms[i](inputs[i]) 273 | 274 | return inputs, label 275 | 276 | def __len__(self): 277 | return len(self.dataset) 278 | -------------------------------------------------------------------------------- /data_pipeline/pku_mmd.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """PKU-MMD dataset.""" 17 | 18 | import glob 19 | import numpy as np 20 | import os 21 | import random 22 | import torch.utils.data as data 23 | 24 | import utils 25 | import utils.imgproc as imgproc 26 | 27 | rgb_folder_name = 'rgb' 28 | rgb_pattern = 'RGB-%08d.jpg' 29 | oflow_folder_name = 'oflow' 30 | oflow_pattern = 'OFlow-%08d.jpg' 31 | depth_folder_name = 'depth' 32 | depth_pattern = 'Depth-%08d.png' 33 | skel_folder_name = 'skeleton' 34 | feat_folder_name = 'feature' 35 | ALL_MODALITIES = ['rgb', 'oflow', 'depth', 'jjd', 'jjv', 'jld'] 36 | 37 | 38 | def make_dataset(root, evaluation, mode, folder_name=rgb_folder_name, subsample_rate=0): 39 | """ 40 | mode: train (0), test (1) 41 | subsample_rate: Rate of subsampling. 0 if using full dataset. 42 | """ 43 | vid_names = os.listdir(os.path.join(root, folder_name)) 44 | f = open(os.path.join(root, 'split', evaluation + '.txt')).read().split('\n') 45 | idx = 1 if mode == 0 else 3 46 | vid_names_split = f[idx].split(', ')[:-1] 47 | vid_names = sorted(list(set(vid_names).intersection(set(vid_names_split)))) 48 | 49 | if mode == 0 and subsample_rate: 50 | # Subsample for training 51 | vid_names = vid_names[::subsample_rate] 52 | 53 | dataset = [] 54 | for vid_name in vid_names: 55 | label = np.loadtxt( 56 | os.path.join(root, 'label', vid_name + '.txt'), 57 | delimiter=',', 58 | dtype=int) 59 | dataset.append((vid_name, label)) 60 | 61 | utils.info('PKU-MMD: {}, {} videos'.format(evaluation, len(dataset))) 62 | return dataset 63 | 64 | 65 | def rgb_loader(root, vid_name, frame_ids): 66 | vid = [] 67 | for frame_ids_s in frame_ids: 68 | vid_s = [] 69 | for frame_id in frame_ids_s: 70 | path = os.path.join(root, rgb_folder_name, vid_name, 71 | rgb_pattern % (frame_id + 1)) 72 | img = imgproc.imread_rgb('pku-mmd', path) 73 | vid_s.append(img) 74 | vid.append(vid_s) 75 | return np.array(vid) 76 | 77 | 78 | def oflow_loader(root, vid_name, frame_ids): 79 | vid = [] 80 | for frame_ids_s in frame_ids: 81 | vid_s = [] 82 | for frame_id in frame_ids_s: 83 | path = os.path.join(root, oflow_folder_name, vid_name, 84 | oflow_pattern % (frame_id + 1)) 85 | img = imgproc.imread_oflow('pku-mmd', path) 86 | vid_s.append(img) 87 | vid.append(vid_s) 88 | return np.array(vid) 89 | 90 | 91 | def depth_loader(root, vid_name, frame_ids): 92 | vid = [] 93 | for frame_ids_s in frame_ids: 94 | vid_s = [] 95 | for frame_id in frame_ids_s: 96 | path = os.path.join(root, depth_folder_name, vid_name, 97 | depth_pattern % (frame_id + 1)) 98 | img = imgproc.imread_depth('pku-mmd', path) 99 | vid_s.append(img) 100 | vid.append(vid_s) 101 | return np.array(vid) 102 | 103 | 104 | def jjd_loader(root, vid_name, frame_ids): 105 | path = os.path.join(root, skel_folder_name, 'jjd', vid_name + '.npy') 106 | skel = np.load(path).astype(np.float32) 107 | skel = skel[frame_ids] 108 | return skel 109 | 110 | 111 | def jjv_loader(root, vid_name, frame_ids): 112 | path = os.path.join(root, skel_folder_name, 'jjv', vid_name + '.npy') 113 | skel = np.load(path).astype(np.float32) 114 | skel = skel[frame_ids] 115 | return skel 116 | 117 | 118 | def jld_loader(root, vid_name, frame_ids): 119 | path = os.path.join(root, skel_folder_name, 'jld', vid_name + '.npy') 120 | skel = np.load(path).astype(np.float32) 121 | skel = skel[frame_ids] 122 | return skel 123 | 124 | 125 | def get_overlap(a, b): 126 | return max(0, min(a[1], b[1]) - max(a[0], b[0])) 127 | 128 | 129 | class PKU_MMD(data.Dataset): 
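  """PKU-MMD detection dataset.

  A hedged usage sketch (the root path, transform, and numeric values here are
  illustrative placeholders, not values prescribed by this file):

    dset = PKU_MMD('/path/to/pku-mmd', mode=0, evaluation='cross-subject',
                   modalities=['rgb'], step_size=10, n_frames=10,
                   downsample=3, timestep=8, transforms=[xform])
    inputs, targets, vid_name = dset[0]
  """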
130 |   MEAN_STD = {
131 |       'rgb': ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
132 |       'oflow': (0.5, 1. / 255),
133 |       'depth': (4084.1213735 / (255 * 256), 1008.31271366 / (255 * 256)),
134 |       'jjd': (0.53968146, 0.32319776),
135 |       'jjv': (0, 0.35953656),
136 |       'jld': (0.15982792, 0.12776225)
137 |   }
138 | 
139 |   def __init__(self,
140 |                root,
141 |                mode,
142 |                evaluation,
143 |                modalities,
144 |                step_size,
145 |                n_frames,
146 |                downsample,
147 |                timestep,
148 |                transforms=None,
149 |                subsample_rate=0):
150 |     """PKU-MMD dataset constructor.
151 | 
152 |     Constructs the PKU-MMD dataset.
153 | 
154 |     Args:
155 |       root: dataset root
156 |       mode: train (0), test (1)
157 |       evaluation: one of ['cross-subject', 'cross-view']
158 |       modalities: subset of ['rgb', 'oflow', 'depth', 'jjd', 'jjv', 'jld']
159 |       step_size: step size between clips
160 |       n_frames: number of frames per clip
161 |       downsample: fps /= downsample
162 |       timestep: number of clips in a sequence
163 |       transforms: transformations to apply to data
164 |       subsample_rate: rate of subsampling training videos. 0 if using full dataset.
165 |     """
166 |     modalities = utils.unsqueeze(modalities)
167 |     transforms = utils.unsqueeze(transforms)
168 | 
169 |     loaders = {
170 |         'rgb': rgb_loader,
171 |         'oflow': oflow_loader,
172 |         'depth': depth_loader,
173 |         'jjd': jjd_loader,
174 |         'jjv': jjv_loader,
175 |         'jld': jld_loader
176 |     }
177 |     self.loaders = loaders
178 |     self.dataset = make_dataset(
179 |         root, evaluation, mode, subsample_rate=subsample_rate)
180 | 
181 |     self.root = root
182 |     self.modalities = modalities
183 | 
184 |     self.step_size = step_size
185 |     self.n_frames = n_frames
186 |     self.downsample = downsample
187 |     self.timestep = timestep
188 |     self.all = mode != 0  # True if test mode, return the entire video
189 |     self.transforms = transforms
190 | 
191 |   def __getitem__(self, idx):
192 |     vid_name, label = self.dataset[idx]
193 |     # label: action_id, start_frame, end_frame, confidence
194 |     # -1 because len(oflow) = len(rgb)-1
195 |     length = len(
196 |         glob.glob(
197 |             os.path.join(self.root, rgb_folder_name, vid_name,
198 |                          '*.'
+ rgb_pattern.split('.')[1]))) - 1 199 | length_ds = length // self.downsample 200 | 201 | if self.all: 202 | # Return entire video 203 | starts = np.arange(0, length_ds - self.n_frames + 1, 204 | self.step_size) # arange: exclusive 205 | else: 206 | start = random.randint( 207 | 0, length_ds - ((self.timestep - 1) * self.step_size + self.n_frames)) 208 | starts = [start + i * self.step_size for i in range(self.timestep) 209 | ] # randint: inclusive 210 | 211 | frame_ids = [] 212 | for start in starts: 213 | frame_ids_s = np.arange(start, start + self.n_frames, 214 | 1) * self.downsample # arange: exclusive 215 | frame_ids.append(frame_ids_s) 216 | frame_ids = np.stack(frame_ids) 217 | 218 | targets = [] 219 | for frame_ids_s in frame_ids: 220 | target = 0 221 | max_ratio = 0.5 222 | for action_id, start_frame, end_frame, _ in label: 223 | overlap = get_overlap([frame_ids_s[0], frame_ids_s[-1] - 1], 224 | [start_frame, end_frame - 1]) 225 | ratio = overlap / (frame_ids_s[-1] - frame_ids_s[0]) 226 | if ratio > max_ratio: 227 | target = int(action_id) 228 | targets.append(target) 229 | targets = np.stack(targets) 230 | 231 | # load raw data 232 | inputs = [] 233 | for modality in self.modalities: 234 | vid = self.loaders[modality](self.root, vid_name, frame_ids) 235 | inputs.append(vid) 236 | 237 | # transform 238 | if self.transforms is not None: 239 | for i, transform in enumerate(self.transforms): 240 | if transform is not None: 241 | inputs[i] = transform(inputs[i]) 242 | 243 | return inputs, targets, vid_name 244 | 245 | def __len__(self): 246 | return len(self.dataset) 247 | -------------------------------------------------------------------------------- /data_pipeline/ucf_101.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
# ==============================================================================

"""UCF-101 dataset."""

import glob
import os

import numpy as np
import torch.utils.data as data

import utils
import utils.imgproc as imgproc

rgb_folder_name = 'jpegs_256'
rgb_pattern = 'frame%06d.jpg'
oflow_folder_name = 'tvl1_flow'
oflow_pattern = 'frame%06d.jpg'
train_split_subpath = 'ucfTrainTestlist/trainlist01.txt'
test_split_subpath = 'ucfTrainTestlist/testlist01.txt'
ALL_MODALITIES = ['rgb', 'oflow']


def find_classes(root):
  rgb_folder_path = os.path.join(root, rgb_folder_name)
  classes = [
      n.split('_')[1]
      for n in os.listdir(rgb_folder_path)
      if os.path.isdir(os.path.join(rgb_folder_path, n))
  ]
  classes = sorted(list(set(classes)))
  class_to_idx = {classes[i]: i for i in range(len(classes))}

  return classes, class_to_idx


def make_dataset(root, class_to_idx, train):
  dataset = []
  split_subpath = train_split_subpath if train else test_split_subpath
  split_path = os.path.join(root, split_subpath)

  with open(split_path) as split_file:
    split_lines = split_file.readlines()

  for line in split_lines:
    vid_name = line.split()[0].split('/')[1].replace('.avi', '')
    class_name = vid_name.split('_')[1]
    item = (vid_name, class_to_idx[class_name])
    dataset.append(item)

  return dataset


def rgb_loader(root, vid_name, frame_id):
  rgb_path = os.path.join(root, rgb_folder_name, vid_name,
                          rgb_pattern % frame_id)
  return imgproc.imread_rgb('ucf-101', rgb_path)


def oflow_loader(root, vid_name, frame_id):
  oflow_path_u = os.path.join(root, oflow_folder_name, 'u', vid_name,
                              oflow_pattern % frame_id)
  oflow_path_v = os.path.join(root, oflow_folder_name, 'v', vid_name,
                              oflow_pattern % frame_id)
  return imgproc.imread_oflow('ucf-101', oflow_path_u, oflow_path_v)


class UCF_101(data.Dataset):

  def __init__(self,
               root,
               train,
               modalities,
               n_samples,
               n_frames,
               transforms=None,
               target_transform=None):
    classes, class_to_idx = find_classes(root)
    dataset = make_dataset(root, class_to_idx, train)

    modalities = utils.unsqueeze(modalities)
    transforms = utils.unsqueeze(transforms)

    all_loaders = [rgb_loader, oflow_loader]
    all_modalities = ['rgb', 'oflow']
    loaders = [
        all_loaders[i]
        for i in range(len(all_loaders))
        if all_modalities[i] in modalities
    ]

    assert len(modalities) == len(loaders)

    self.root = root
    self.train = train
    self.modalities = modalities
    self.n_samples = n_samples
    self.n_frames = n_frames
    self.transforms = transforms
    self.target_transform = target_transform

    self.loaders = loaders

    self.classes = classes
    self.class_to_idx = class_to_idx
    self.dataset = dataset

  def __getitem__(self, idx):
    vid_name, target = self.dataset[idx]
    length = len(
        glob.glob(
            os.path.join(self.root, rgb_folder_name, vid_name,
                         '*.' + rgb_pattern.split('.')[1]))) - 1

    if self.train:
      samples = np.random.randint(0, length - self.n_frames + 1,
                                  self.n_samples)  # randint: exclusive
    else:
      if length > self.n_samples:
        samples = np.round(
            np.linspace(0, length - self.n_frames,
                        self.n_samples)).astype(int)  # linspace: inclusive
      else:
        samples = np.arange(0, length - self.n_frames + 1)  # arange: exclusive

    # load raw data
    inputs = []
    for loader in self.loaders:
      vid = []
      for s in samples:
        clip = []
        for t in range(self.n_frames):
          frame_id = s + t + 1
          image = loader(self.root, vid_name, frame_id)
          clip.append(image)
        vid.append(clip)
      inputs.append(np.array(vid))

    # transform
    if self.transforms is not None:
      for i, transform in enumerate(self.transforms):
        if transform is not None:
          inputs[i] = transform(inputs[i])

    return utils.squeeze(inputs), target

  def __len__(self):
    return len(self.dataset)
--------------------------------------------------------------------------------
/detection/evaluation/map.py:
--------------------------------------------------------------------------------
# Copyright 2018 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Calculate mAP @ IoU thresholds for detection."""

import numpy as np

from third_party.pku_mmd.evaluate import process


def get_segments(scores, activity_threshold):
  """Get prediction segments of a video."""
  # Each segment contains start, end, class, confidence score.
  # Sum of all probabilities (1 - probability of no-activity)
  activity_prob = 1 - scores[:, 0]
  # Binary vector indicating whether a clip is an activity or no-activity
  activity_tag = np.zeros(activity_prob.shape, dtype=np.int32)
  activity_tag[activity_prob >= activity_threshold] = 1
  assert activity_tag.ndim == 1
  # For each index, subtract the previous index, getting -1, 0, or 1
  # 1 indicates the start of a segment, and -1 indicates the end.
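  # For example, tag [0, 1, 1, 0] is padded to [0, 0, 1, 1, 0, 0] and
  # diffed to [0, 1, 0, -1, 0]: one segment with start 1 and end 3
  # (exclusive).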
  padded = np.pad(activity_tag, pad_width=1, mode='constant')
  diff = padded[1:] - padded[:-1]
  indexes = np.arange(diff.size)
  startings = indexes[diff == 1]
  endings = indexes[diff == -1]
  assert startings.size == endings.size

  segments = []
  for start, end in zip(startings, endings):
    segment_scores = scores[start:end, :]
    class_prob = np.mean(segment_scores, axis=0)
    segment_class_index = np.argmax(class_prob[1:]) + 1
    confidence = np.mean(segment_scores[:, segment_class_index])
    segments.append((start, end, segment_class_index, confidence))
  return segments


def calc_map(opt, video_scores, video_names, groundtruth_dir, iou_thresholds):
  """Get mAP (action) at the given IoU thresholds."""
  activity_threshold = 0.4
  num_videos = len(video_scores)
  video_files = [name + '.txt' for name in video_names]

  v_props = []
  for i in range(num_videos):
    scores = video_scores[i]
    segments = get_segments(scores, activity_threshold)

    prop = []
    for segment in segments:
      start, end, cls, score = segment
      # start, end are indices of clips. Transform to frame index.
      start_index = start * opt.step_size * opt.downsample
      end_index = (
          (end - 1) * opt.step_size + opt.n_frames) * opt.downsample - 1
      prop.append([cls, start_index, end_index, score, video_files[i]])
    v_props.append(prop)

  # Run evaluation on different IoU thresholds.
  mean_aps = []
  for iou in iou_thresholds:
    mean_ap = process(v_props, video_files, groundtruth_dir, iou)
    mean_aps.append(mean_ap)
  return mean_aps
--------------------------------------------------------------------------------
/detection/get_dataloader.py:
--------------------------------------------------------------------------------
# Copyright 2018 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Get data loader for detection."""

import os
import torch.utils.data as data
import torchvision.transforms as transforms

import third_party.two_stream_pytorch.video_transforms as vtransforms
from data_pipeline.pku_mmd import PKU_MMD

CNN_MODALITIES = ['rgb', 'oflow', 'depth']
GRU_MODALITIES = ['jjd', 'jjv', 'jld']


def get_dataloader(opt):
  idx_t = 0 if opt.split == 'train' else 1

  xforms = []
  for modality in opt.modalities:
    if opt.dset == 'pku-mmd':
      mean, std = PKU_MMD.MEAN_STD[modality]
    else:
      raise NotImplementedError

    if opt.split == 'train' and (modality == 'rgb' or modality == 'depth'):
      xform = transforms.Compose([
          vtransforms.RandomSizedCrop(224),
          vtransforms.RandomHorizontalFlip(),
          vtransforms.ToTensor(),
          vtransforms.Normalize(mean, std)
      ])
    elif opt.split == 'train' and modality == 'oflow':
      # Special handling when flipping optical flow
      xform = transforms.Compose([
          vtransforms.RandomSizedCrop(224, True),
          vtransforms.RandomHorizontalFlip(True),
          vtransforms.ToTensor(),
          vtransforms.Normalize(mean, std)
      ])
    elif opt.split != 'train' and modality in CNN_MODALITIES:
      xform = transforms.Compose([
          vtransforms.Scale(256),
          vtransforms.CenterCrop(224),
          vtransforms.ToTensor(),
          vtransforms.Normalize(mean, std)
      ])
    elif modality in GRU_MODALITIES:
      xform = transforms.Compose([vtransforms.SkelNormalize(mean, std)])
    else:
      raise NotImplementedError

    xforms.append(xform)

  if opt.dset == 'pku-mmd':
    root = os.path.join(opt.dset_path, opt.dset)
    dset = PKU_MMD(root, idx_t, 'cross-subject', opt.modalities, opt.step_size,
                   opt.n_frames, opt.downsample, opt.timestep, xforms,
                   opt.subsample_rate)
  else:
    raise NotImplementedError

  dataloader = data.DataLoader(
      dset,
      batch_size=opt.batch_sizes[idx_t],
      shuffle=(opt.split == 'train'),
      num_workers=opt.n_workers)

  return dataloader
--------------------------------------------------------------------------------
/detection/get_model.py:
--------------------------------------------------------------------------------
# Copyright 2018 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14 | # ============================================================================== 15 | 16 | """Get detection model.""" 17 | 18 | from .model import SingleStream 19 | from .model import GraphDistillation 20 | 21 | ALL_MODALITIES = ['rgb', 'oflow', 'depth', 'jjd', 'jjv', 'jld'] 22 | 23 | 24 | def get_model(opt): 25 | if opt.dset == 'pku-mmd': 26 | n_classes = 51 27 | all_input_sizes = [-1, -1, -1, 276, 828, 836] 28 | all_n_channels = [3, 2, 1, -1, -1, -1] 29 | else: 30 | raise NotImplementedError 31 | 32 | n_channels = [all_n_channels[ALL_MODALITIES.index(m)] for m in opt.modalities] 33 | input_sizes = [ 34 | all_input_sizes[ALL_MODALITIES.index(m)] for m in opt.modalities 35 | ] 36 | 37 | if len(opt.modalities) == 1: 38 | # Single stream 39 | index = 0 40 | model = SingleStream(opt.modalities, n_classes, opt.n_frames, n_channels, 41 | input_sizes, opt.hidden_size, opt.n_layers, 42 | opt.dropout, opt.hidden_size_seq, opt.n_layers_seq, 43 | opt.dropout_seq, opt.bg_w, opt.lr, opt.lr_decay_rate, 44 | index, opt.ckpt_path) 45 | else: 46 | index = opt.modalities.index(opt.xfer_to) 47 | model = GraphDistillation( 48 | opt.modalities, n_classes, opt.n_frames, n_channels, input_sizes, 49 | opt.hidden_size, opt.n_layers, opt.dropout, opt.hidden_size_seq, 50 | opt.n_layers_seq, opt.dropout_seq, opt.bg_w, opt.lr, opt.lr_decay_rate, 51 | index, opt.ckpt_path, opt.w_losses, opt.w_modalities, opt.metric, 52 | opt.xfer_to, opt.gd_size, opt.gd_reg) 53 | 54 | return model 55 | -------------------------------------------------------------------------------- /detection/model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
# ==============================================================================

"""Model to train detection."""

from collections import OrderedDict
import os

import numpy as np
import torch
from torch.autograd import Variable
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.optim as optim

import utils
from nets.get_tad import *
from nets.get_distillation_kernel import *

CNN_MODALITIES = ['rgb', 'oflow', 'depth']
GRU_MODALITIES = ['jjd', 'jjv', 'jld']


class BaseModel:
  def __init__(self, modalities, n_classes, n_frames, n_channels, input_sizes,
               hidden_size, n_layers, dropout, hidden_size_seq, n_layers_seq,
               dropout_seq, bg_w, lr, lr_decay_rate, to_idx, ckpt_path):
    super(BaseModel, self).__init__()
    cudnn.benchmark = True
    utils.info('{} modality'.format(modalities[to_idx]))

    self.embeds = []
    for i, m in enumerate(modalities):
      encoder_type = 'cnn' if m in CNN_MODALITIES else 'rnn'
      embed = nn.DataParallel(
          get_tad(n_classes, n_frames, n_channels[i], input_sizes[i],
                  hidden_size, n_layers, dropout, hidden_size_seq, n_layers_seq,
                  dropout_seq, encoder_type).cuda())
      self.embeds.append(embed)

    # Multiple optimizers
    self.optimizers = []
    self.lr_decay_rates = []
    # Visual encoder: SGD
    visual_params = list(self.embeds[to_idx].module.embed.parameters())
    visual_optimizer = optim.SGD(
        visual_params, lr=lr, momentum=0.9, weight_decay=5e-4)
    self.optimizers.append(visual_optimizer)
    self.lr_decay_rates.append(lr_decay_rate)
    # Sequence encoder: Adam
    sequence_params = list(self.embeds[to_idx].module.rnn.parameters()) + \
                      list(self.embeds[to_idx].module.fc.parameters())
    sequence_optimizer = optim.Adam(sequence_params, lr=1e-3)
    self.optimizers.append(sequence_optimizer)
    self.lr_decay_rates.append(1)  # No learning rate decay for Adam

    # Weighted cross-entropy loss
    self.criterion_cls = nn.CrossEntropyLoss(
        torch.FloatTensor([bg_w] + [1] * n_classes)).cuda()

    self.n_classes = n_classes
    self.modalities = modalities
    self.to_idx = to_idx
    self.ckpt_path = ckpt_path

  def _forward(self, inputs):
    """Forward pass for all modalities."""
    logits, reprs = [], []
    for i in range(len(inputs)):
      logit, repr = self.embeds[i](inputs[i])
      logits.append(logit)
      reprs.append(repr)

    logits = torch.stack(logits)
    reprs = torch.stack(reprs)
    return [logits, reprs]

  def _backward(self, results, label):
    raise NotImplementedError

  def train(self, inputs, label):
    """Train model.

    :param inputs: a list, each is batch_size x timestep x n_frames x
      (n_channels x h x w) or (input_size)
    :param label: batch_size x timestep
    """
    for embed in self.embeds:
      embed.train()

    for i in range(len(inputs)):
      inputs[i] = Variable(inputs[i].cuda(), requires_grad=False)
    label = Variable(label.cuda(), requires_grad=False)

    results = self._forward(inputs)
    info_loss = self._backward(results, label)
    info_acc = self._get_acc(results[0], label)
    return OrderedDict(info_loss + info_acc)

  def test(self, inputs, label, timestep):
    '''Test model.

    :param timestep: split into segments of length timestep.
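    :return: accuracy info (OrderedDict), averaged logits, and softmax scores.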
    '''
    for embed in self.embeds:
      embed.eval()

    input = Variable(inputs[0].cuda(), requires_grad=False)
    label = Variable(label.cuda(), requires_grad=False)
    length = input.size(1)

    # Split video into segments
    input, start_indices = utils.get_segments(input, timestep)
    inputs = [input]

    logits, _ = self._forward(inputs)
    logits = utils.to_numpy(logits).squeeze(0)
    all_logits = [[] for i in range(length)]
    for i in range(len(start_indices)):
      s = start_indices[i]
      for j in range(timestep):
        all_logits[s + j].append(logits[i][j])
    # Average logits for each time step.
    final_logits = np.zeros((length, self.n_classes + 1))
    for i in range(length):
      final_logits[i] = np.mean(all_logits[i], axis=0)
    logits = final_logits

    info_acc = self._get_acc([torch.Tensor(logits)], label)
    scores = utils.softmax(logits, axis=1)
    return OrderedDict(info_acc), logits, scores

  def _get_acc(self, logits, label):
    """Get detection statistics for each modality."""
    info_acc = []
    for i, m in enumerate(self.modalities):
      logit = logits[i].view(-1, self.n_classes + 1)
      label = label.view(-1)
      stats = utils.get_stats_detection(logit, label, self.n_classes + 1)
      info_acc.append(('ap_{}'.format(m), stats[0]))
      info_acc.append(('acc_{}'.format(m), stats[1]))
      info_acc.append(('acc_bg_{}'.format(m), stats[2]))
      info_acc.append(('acc_action_{}'.format(m), stats[3]))
    return info_acc

  def save(self, epoch):
    path = os.path.join(self.ckpt_path, 'embed_{}.pth'.format(epoch))
    torch.save(self.embeds[self.to_idx].state_dict(), path)

  def load(self, load_ckpt_paths, options, epoch=200):
    """Load checkpoints."""
    assert len(load_ckpt_paths) == len(self.embeds)
    for i in range(len(self.embeds)):
      ckpt_path = load_ckpt_paths[i]
      load_opt = options[i]
      if len(ckpt_path) == 0:
        utils.info('{}: training from scratch'.format(self.modalities[i]))
        continue

      if load_opt == 0:  # load teacher model (visual + sequence)
        path = os.path.join(ckpt_path, 'embed_{}.pth'.format(epoch))
        ckpt = torch.load(path)
        try:
          self.embeds[i].load_state_dict(ckpt)
        except RuntimeError:
          utils.warn('Check that the "modalities" argument is correct.')
          exit(1)
        utils.info('{}: ckpt {} loaded'.format(self.modalities[i], path))
      elif load_opt == 1:  # load pretrained visual encoder
        ckpt = torch.load(ckpt_path)
        # Change keys in the ckpt
        new_state_dict = OrderedDict()
        for key in list(ckpt.keys())[:-2]:  # exclude fc weights
          new_key = key[7:]  # Remove 'module.'
          new_state_dict[new_key] = ckpt[key]
        # update state_dict
        state_dict = self.embeds[i].module.embed.state_dict()
        state_dict.update(new_state_dict)
        self.embeds[i].module.embed.load_state_dict(state_dict)
        utils.info('{}: visual encoder from {} loaded'.format(
            self.modalities[i], ckpt_path))
      else:
        raise NotImplementedError

  def lr_decay(self):
    lrs = []
    for optimizer, decay_rate in zip(self.optimizers, self.lr_decay_rates):
      for param_group in optimizer.param_groups:
        param_group['lr'] *= decay_rate
        lrs.append(param_group['lr'])
    return lrs


class SingleStream(BaseModel):
  """Model to train a single modality.
204 | """ 205 | 206 | def __init__(self, *args, **kwargs): 207 | super(SingleStream, self).__init__(*args, **kwargs) 208 | assert len(self.embeds) == 1 209 | 210 | def _backward(self, results, label): 211 | logits, _ = results 212 | logits = logits.view(-1, logits.size(-1)) 213 | 214 | loss = self.criterion_cls(logits, label.view(-1)) 215 | loss.backward() 216 | torch.nn.utils.clip_grad_norm(self.embeds[self.to_idx].parameters(), 5) 217 | for optimizer in self.optimizers: 218 | optimizer.step() 219 | optimizer.zero_grad() 220 | 221 | info_loss = [('loss', loss.data[0])] 222 | return info_loss 223 | 224 | 225 | class GraphDistillation(BaseModel): 226 | """Model to train with graph distillation. 227 | 228 | xfer_to is the modality to train. 229 | """ 230 | 231 | def __init__(self, modalities, n_classes, n_frames, n_channels, input_sizes, 232 | hidden_size, n_layers, dropout, hidden_size_seq, n_layers_seq, 233 | dropout_seq, bg_w, lr, lr_decay_rate, to_idx, ckpt_path, 234 | w_losses, w_modalities, metric, xfer_to, gd_size, gd_reg): 235 | super(GraphDistillation, self).__init__(\ 236 | modalities, n_classes, n_frames, n_channels, input_sizes, 237 | hidden_size, n_layers, dropout, hidden_size_seq, n_layers_seq, dropout_seq, 238 | bg_w, lr, lr_decay_rate, to_idx, ckpt_path) 239 | 240 | # Index of the modality to distill 241 | to_idx = self.modalities.index(xfer_to) 242 | from_idx = [x for x in range(len(self.modalities)) if x != to_idx] 243 | assert len(from_idx) >= 1 244 | 245 | # Prior 246 | w_modalities = [w_modalities[i] for i in from_idx 247 | ] # remove modality being transferred to 248 | gd_prior = utils.softmax(w_modalities, 0.25) 249 | # Distillation model 250 | self.distillation_kernel = \ 251 | get_distillation_kernel(n_classes + 1, hidden_size, gd_size, to_idx, from_idx, 252 | gd_prior, gd_reg, w_losses, metric).cuda() 253 | 254 | # Add optimizer to self.optimizers 255 | gd_optimizer = optim.SGD( 256 | self.distillation_kernel.parameters(), 257 | lr=lr, 258 | momentum=0.9, 259 | weight_decay=5e-4) 260 | self.optimizers.append(gd_optimizer) 261 | self.lr_decay_rates.append(lr_decay_rate) 262 | 263 | self.xfer_to = xfer_to 264 | self.to_idx = to_idx 265 | self.from_idx = from_idx 266 | 267 | def _forward(self, inputs): 268 | logits, reprs = super(GraphDistillation, self)._forward(inputs) 269 | n_modalities, batch_size, length, _ = logits.size() 270 | logits = logits.view(n_modalities, batch_size * length, -1) 271 | reprs = reprs.view(n_modalities, batch_size * length, -1) 272 | # Get edge weights of the graph 273 | graph = self.distillation_kernel(logits, reprs) 274 | return logits, reprs, graph 275 | 276 | def _backward(self, results, label): 277 | logits, reprs, graph = results # graph: size (len(from_idx) x batch_size) 278 | label = label.view(-1) 279 | info_loss = [] 280 | 281 | # Classification loss 282 | loss_cls = self.criterion_cls(logits[self.to_idx], label) 283 | # Graph distillation loss 284 | loss_reg, loss_logit, loss_repr = \ 285 | self.distillation_kernel.distillation_loss(logits, reprs, graph) 286 | 287 | loss = loss_cls + loss_reg + loss_logit + loss_repr 288 | loss.backward() 289 | torch.nn.utils.clip_grad_norm(self.embeds[self.to_idx].parameters(), 5) 290 | for optimizer in self.optimizers: 291 | optimizer.step() 292 | optimizer.zero_grad() 293 | 294 | info_loss = [('loss_cls', loss_cls.data[0]), ('loss_reg', loss_reg.data[0]), 295 | ('loss_logit', loss_logit.data[0]), ('loss_repr', 296 | loss_repr.data[0])] 297 | return info_loss 298 | 
--------------------------------------------------------------------------------
/detection/run.py:
--------------------------------------------------------------------------------
# Copyright 2018 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Train detection."""

import argparse
import os

import utils
import utils.logging as logging
from .get_dataloader import *
from .get_model import *
from .evaluation.map import calc_map

ALL_MODALITIES = ['rgb', 'oflow', 'depth', 'jjd', 'jjv', 'jld']

parser = argparse.ArgumentParser()

# experimental settings
parser.add_argument('--n_workers', type=int, default=24)
parser.add_argument('--gpus', type=str, default='0')
parser.add_argument('--split', type=str, choices=['train', 'val', 'test'])

# ckpt and logging
parser.add_argument('--ckpt_path', type=str, default='./ckpt',
                    help='directory path that stores all checkpoints')
parser.add_argument('--ckpt_name', type=str, default='ckpt')
parser.add_argument('--pretrained_ckpt_name', type=str, default='ckpt',
                    help='name of the teacher detection models')
parser.add_argument('--load_epoch', type=int, default=400,
                    help='epoch to load teacher model')
parser.add_argument('--visual_encoder_ckpt_path', type=str, default='',
                    help='classification checkpoint to initialize '
                         'the visual encoder weights')
parser.add_argument('--load_ckpt_path', type=str, default='',
                    help='checkpoint path to load for testing')
parser.add_argument('--print_every', type=int, default=50)
parser.add_argument('--save_every', type=int, default=50)

# hyperparameters
parser.add_argument('--batch_sizes', type=int, nargs='+', default=[8, 1],
                    help='batch sizes: [train, test]')
parser.add_argument('--n_epochs', type=int, default=400)
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--lr_decay_at', type=int, nargs='+', default=[250, 350])
parser.add_argument('--lr_decay_rate', type=float, default=0.1)

# data pipeline
parser.add_argument('--dset', type=str, default='pku-mmd')
parser.add_argument('--dset_path', type=str,
                    default=os.path.join(os.environ['HOME'], 'slowbro'))
parser.add_argument('--modalities', type=str, nargs='+',
                    choices=['rgb', 'oflow', 'depth', 'jjd', 'jjv', 'jld'])
parser.add_argument('--step_size', type=int, default=10,
                    help='step size between samples (after downsample)')
parser.add_argument('--n_frames', type=int, default=10,
                    help='num frames per clip')
parser.add_argument('--downsample', type=int, default=3,
                    help='fps /= downsample')
parser.add_argument('--timestep', type=int, default=10,
                    help='number of clips in a sequence')
parser.add_argument('--bg_w',
type=float, default=0.5) 75 | parser.add_argument('--subsample_rate', type=int, default=20, 76 | help='rate to subsample the dataset. 0: False (use full dataset)') 77 | 78 | # visual encoder GRU (for modalities jjd, jjv, jld) 79 | parser.add_argument('--dropout', type=float, default=0.5) 80 | parser.add_argument('--hidden_size', type=int, default=512) 81 | parser.add_argument('--n_layers', type=int, default=3) 82 | # sequence encoder GRU 83 | parser.add_argument('--dropout_seq', type=float, default=0.5) 84 | parser.add_argument('--hidden_size_seq', type=int, default=512) 85 | parser.add_argument('--n_layers_seq', type=int, default=1) 86 | 87 | # Graph Distillation parameters 88 | parser.add_argument('--metric', type=str, default='cosine', 89 | choices=['cosine', 'kl', 'l2', 'l1'], 90 | help='distance metric for distillation loss') 91 | parser.add_argument('--w_losses', type=float, nargs='+', default=[10, 1], 92 | help='weights for losses: [logit, repr]') 93 | parser.add_argument('--w_modalities', type=float, nargs='+', 94 | default=[1, 1, 1, 1, 1, 1], 95 | help='modality prior') 96 | parser.add_argument('--xfer_to', type=str, default='', 97 | help='modality to train with graph distillation') 98 | parser.add_argument('--gd_size', type=int, default=32, 99 | help='hidden size of graph distillation') 100 | parser.add_argument('--gd_reg', type=float, default=10, 101 | help='regularization for graph distillation') 102 | 103 | 104 | def single_stream(opt): 105 | """Train a single modality from scratch.""" 106 | # Checkpoint path example: ckpt_path/pku-mmd/rgb/ckpt 107 | opt.ckpt_path = os.path.join(opt.ckpt_path, opt.dset, 108 | opt.modalities[0], opt.ckpt_name) 109 | os.makedirs(opt.ckpt_path, exist_ok=True) 110 | if opt.split == 'train': 111 | if opt.visual_encoder_ckpt_path != '': 112 | assert os.path.exists(opt.visual_encoder_ckpt_path), \ 113 | '{} does not exist'.format(opt.visual_encoder_ckpt_path) 114 | opt.load_ckpt_paths = [opt.visual_encoder_ckpt_path] 115 | opt.load_opts = [1] # load visual encoder 116 | else: 117 | opt.load_ckpt_paths = [opt.load_ckpt_path] 118 | opt.load_opts = [0] # load visual + sequence encoder 119 | 120 | # Data loader and model 121 | dataloader = get_dataloader(opt) 122 | model = get_model(opt) 123 | if opt.split == 'train': 124 | train(opt, model, dataloader) 125 | else: 126 | test(opt, model, dataloader) 127 | 128 | 129 | def multi_stream(opt): 130 | """Train a modality with graph distillation from other modalities.""" 131 | assert opt.xfer_to in opt.modalities, 'xfer_to must be in opt.modalities' 132 | # Checkpoints to load 133 | opt.load_ckpt_paths = [] 134 | opt.load_opts = [] 135 | for m in opt.modalities: 136 | if m != opt.xfer_to: 137 | # Checkpoint from single_stream 138 | path = os.path.join(opt.ckpt_path, opt.dset, m, opt.pretrained_ckpt_name) 139 | assert os.path.exists(path), '{} checkpoint does not exist.'.format(path) 140 | opt.load_ckpt_paths.append(path) 141 | opt.load_opts.append(0) 142 | else: 143 | opt.load_ckpt_paths.append(opt.visual_encoder_ckpt_path) 144 | opt.load_opts.append(1) 145 | 146 | # Checkpoint path example: ckpt_path/ntu-rgbd/xfer_rgb/ckpt_rgb_depth 147 | opt.ckpt_path = os.path.join( 148 | opt.ckpt_path, opt.dset, 'xfer_{}'.format(opt.xfer_to), '{}_{}'.format( 149 | opt.ckpt_name, '_'.join([m for m in opt.modalities]))) 150 | os.makedirs(opt.ckpt_path, exist_ok=True) 151 | 152 | # Data loader and model 153 | dataloader = get_dataloader(opt) 154 | model = get_model(opt) 155 | train(opt, model, dataloader) 156 | 157 | 158 | def 
train(opt, model, dataloader): 159 | # Logging 160 | logger = logging.Logger(opt.ckpt_path, opt.split) 161 | stats = logging.Statistics(opt.ckpt_path, opt.split) 162 | logger.log(opt) 163 | 164 | model.load(opt.load_ckpt_paths, opt.load_opts, opt.load_epoch) 165 | for epoch in range(1, opt.n_epochs + 1): 166 | for step, data in enumerate(dataloader, 1): 167 | # inputs is a list of input of each modality 168 | inputs, label, _ = data 169 | ret = model.train(inputs, label) 170 | update = stats.update(len(label), ret) 171 | if utils.is_due(step, opt.print_every): 172 | utils.info('epoch {}/{}, step {}/{}: {}'.format( 173 | epoch, opt.n_epochs, step, len(dataloader), update)) 174 | 175 | logger.log('[Summary] epoch {}/{}: {}'.format(epoch, opt.n_epochs, 176 | stats.summarize())) 177 | 178 | if utils.is_due(epoch, opt.n_epochs, opt.save_every): 179 | model.save(epoch) 180 | stats.save() 181 | logger.log('***** saved *****') 182 | 183 | if utils.is_due(epoch, opt.lr_decay_at): 184 | lrs = model.lr_decay() 185 | logger.log('***** lr decay *****: {}'.format(lrs)) 186 | 187 | 188 | def test(opt, model, dataloader): 189 | # Logging 190 | logger = logging.Logger(opt.ckpt_path, opt.split) 191 | stats = logging.Statistics(opt.ckpt_path, opt.split) 192 | logger.log(opt) 193 | 194 | model.load(opt.load_ckpt_paths, opt.load_opts, opt.load_epoch) 195 | all_scores = [] 196 | video_names = [] 197 | for step, data in enumerate(dataloader, 1): 198 | inputs, label, vid_name = data 199 | info_acc, logits, scores = model.test(inputs, label, opt.timestep) 200 | 201 | all_scores.append(scores) 202 | video_names.append(vid_name[0]) 203 | update = stats.update(logits.shape[0], info_acc) 204 | if utils.is_due(step, opt.print_every): 205 | utils.info('step {}/{}: {}'.format(step, len(dataloader), update)) 206 | 207 | logger.log('[Summary] {}'.format(stats.summarize())) 208 | 209 | # Evaluate 210 | iou_thresholds = [0.1, 0.3, 0.5] 211 | groundtruth_dir = os.path.join(opt.dset_path, opt.dset, 'groundtruth', 212 | 'validation/cross-subject') 213 | assert os.path.exists(groundtruth_dir), '{} does not exist'.format(groundtruth_dir) 214 | mean_aps = calc_map(opt, all_scores, video_names, groundtruth_dir, iou_thresholds) 215 | 216 | for i in range(len(iou_thresholds)): 217 | logger.log('IoU: {}, mAP: {}'.format(iou_thresholds[i], mean_aps[i])) 218 | 219 | 220 | if __name__ == '__main__': 221 | opt = parser.parse_args() 222 | os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpus 223 | 224 | if opt.split == 'test': 225 | assert len(opt.modalities) == 1, 'specify only 1 modality for testing' 226 | assert len(opt.load_ckpt_path) > 0, 'specify load_ckpt_path for testing' 227 | 228 | if len(opt.modalities) == 1: 229 | single_stream(opt) 230 | else: 231 | multi_stream(opt) 232 | -------------------------------------------------------------------------------- /images/pull.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/graph_distillation/1a7ce5125098e7df869f08e15e2d6d8bb3189382/images/pull.png -------------------------------------------------------------------------------- /nets/get_distillation_kernel.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Graph distillation kernel."""

import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import utils


class DistillationKernel(nn.Module):
  """Graph Distillation kernel.

  Calculate the edge weights e_{j->k} for each j. Modality k is specified by
  to_idx, and the other modalities are specified by from_idx.
  """

  def __init__(self, n_classes, hidden_size, gd_size, to_idx, from_idx,
               gd_prior, gd_reg, w_losses, metric, alpha):
    super(DistillationKernel, self).__init__()
    self.W_logit = nn.Linear(n_classes, gd_size)
    self.W_repr = nn.Linear(hidden_size, gd_size)
    self.W_edge = nn.Linear(gd_size * 4, 1)

    self.gd_size = gd_size
    self.to_idx = to_idx
    self.from_idx = from_idx
    self.alpha = alpha
    # For calculating distillation loss
    self.gd_prior = Variable(torch.FloatTensor(gd_prior).cuda())
    self.gd_reg = gd_reg
    self.w_losses = w_losses  # [logit weight, repr weight]
    self.metric = metric

  def forward(self, logits, reprs):
    """
    Args:
      logits: (n_modalities, batch_size, n_classes)
      reprs: (n_modalities, batch_size, hidden_size)
    Returns:
      edges: weights e_{j->k} (n_modalities_from, batch_size)
    """
    n_modalities, batch_size = logits.size()[:2]
    z_logits = self.W_logit(logits.view(n_modalities * batch_size, -1))
    z_reprs = self.W_repr(reprs.view(n_modalities * batch_size, -1))
    z = torch.cat(
        (z_logits, z_reprs), dim=1).view(n_modalities, batch_size,
                                         self.gd_size * 2)

    edges = []
    for i in self.from_idx:
      # To calculate e_{j->k}, concatenate z^j, z^k
      e = self.W_edge(torch.cat((z[self.to_idx], z[i]), dim=1))
      edges.append(e)
    edges = torch.cat(edges, dim=1)
    edges = F.softmax(edges * self.alpha, dim=1).transpose(0, 1)
    return edges

  def distillation_loss(self, logits, reprs, edges):
    """Calculate graph distillation losses, which include:
    regularization loss, loss for logits, and loss for representation.
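
    The regularization term pulls the average edge weights toward the prior:
      loss_reg = gd_reg * sum((edges.mean(1) - gd_prior) ** 2)
    The logit and representation losses are distance metrics between the
    target modality (to_idx) and each source modality, weighted per sample
    by the edge weight plus the prior.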
78 | """ 79 | # Regularization for graph distillation (average across batch) 80 | loss_reg = (edges.mean(1) - self.gd_prior).pow(2).sum() * self.gd_reg 81 | 82 | loss_logit, loss_repr = 0, 0 83 | for i, idx in enumerate(self.from_idx): 84 | w_distill = edges[i] + self.gd_prior[i] # add graph prior 85 | loss_logit += self.w_losses[0] * utils.distance_metric( 86 | logits[self.to_idx], logits[idx], self.metric, w_distill) 87 | loss_repr += self.w_losses[1] * utils.distance_metric( 88 | reprs[self.to_idx], reprs[idx], self.metric, w_distill) 89 | return loss_reg, loss_logit, loss_repr 90 | 91 | 92 | def get_distillation_kernel(n_classes, 93 | hidden_size, 94 | gd_size, 95 | to_idx, 96 | from_idx, 97 | gd_prior, 98 | gd_reg, 99 | w_losses, 100 | metric, 101 | alpha=1 / 8): 102 | return DistillationKernel(n_classes, hidden_size, gd_size, to_idx, from_idx, 103 | gd_prior, gd_reg, w_losses, metric, alpha) 104 | -------------------------------------------------------------------------------- /nets/get_gru.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """GRU module.""" 17 | 18 | import torch 19 | from torch.autograd import Variable 20 | import torch.nn as nn 21 | 22 | 23 | class GRU(nn.Module): 24 | """ GRU Class""" 25 | 26 | def __init__(self, input_size, hidden_size, n_layers, dropout, n_classes): 27 | super(GRU, self).__init__() 28 | self.gru = nn.GRU( 29 | input_size=input_size, 30 | hidden_size=hidden_size, 31 | num_layers=n_layers, 32 | batch_first=True, 33 | dropout=dropout) 34 | self.fc = nn.Linear(hidden_size, n_classes) 35 | 36 | self.hidden_size = hidden_size 37 | self.n_layers = n_layers 38 | 39 | def _get_states(self, batch_size): 40 | h0 = Variable( 41 | torch.zeros(self.n_layers, batch_size, self.hidden_size).cuda(), 42 | requires_grad=False) 43 | return h0 44 | 45 | def forward(self, x): 46 | """:param x: input of size batch_size' x n_frames x input_size (batch_size' = batch_size*n_samples) :return: 47 | """ 48 | batch_size = x.size(0) 49 | h0 = self._get_states(batch_size) 50 | x, _ = self.gru(x, h0) 51 | representation = x.mean(1) 52 | logit = self.fc(representation) 53 | return logit, representation 54 | 55 | 56 | def get_gru(input_size, hidden_size, n_layers, dropout, n_classes): 57 | return GRU(input_size, hidden_size, n_layers, dropout, n_classes) 58 | -------------------------------------------------------------------------------- /nets/get_tad.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Model for temporal action detection."""

import torch.nn as nn
from third_party.pytorch.get_cnn import *
from .get_gru import *


class TAD(nn.Module):
  """Temporal action detection model.

  Consists of visual encoder (specified by encoder_type) and sequence
  encoder.
  """

  def __init__(self, n_classes, n_frames, n_channels, input_size, hidden_size,
               n_layers, dropout, hidden_size_seq, n_layers_seq, dropout_seq,
               encoder_type):
    super(TAD, self).__init__()

    # Visual encoder
    if encoder_type == 'cnn':
      self.embed = get_resnet(n_frames * n_channels, n_classes)
    elif encoder_type == 'rnn':
      self.embed = get_gru(input_size, hidden_size, n_layers, dropout,
                           n_classes)
    else:
      raise NotImplementedError

    # Sequence encoder
    self.rnn = nn.GRU(
        input_size=hidden_size,
        hidden_size=hidden_size_seq,
        num_layers=n_layers_seq,
        batch_first=True,
        dropout=dropout_seq,
        bidirectional=True)
    # Classification layer
    self.fc = nn.Linear(hidden_size_seq,
                        n_classes + 1)  # plus 1 class for background

    self.n_classes = n_classes
    self.hidden_size = hidden_size
    self.hidden_size_seq = hidden_size_seq
    self.n_layers_seq = n_layers_seq
    self.encoder_type = encoder_type

  def forward(self, x):
    """
    :param x: if encoder_type == 'cnn', batch_size x timestep x n_frames x
      n_channels x h x w; if encoder_type == 'rnn', batch_size x timestep x
      n_frames x input_size
    :return: logits and representation
    """
    if self.encoder_type == 'cnn':
      batch_size, timestep, n_frames, n_channels, h, w = x.size()
      x = x.view(batch_size * timestep, n_frames * n_channels, h, w)
      _, x = self.embed(x)
      x = x.view(batch_size, timestep, self.hidden_size)
    elif self.encoder_type == 'rnn':
      batch_size, timestep, n_frames, input_size = x.size()
      x = x.view(batch_size * timestep, n_frames, input_size)
      _, x = self.embed(x)
      x = x.view(batch_size, timestep, self.hidden_size)
    else:
      raise NotImplementedError

    batch_size = x.size(0)
    x, _ = self.rnn(x)
    # Sum the forward and backward outputs of the bidirectional GRU.
    x = x.contiguous().view(x.size(0), x.size(1), 2, -1).sum(2)
    representation = x
    logit = self.fc(representation)

    return logit, representation


def get_tad(n_classes, n_frames, n_channels, input_size, hidden_size, n_layers,
            dropout, hidden_size_seq, n_layers_seq, dropout_seq, encoder_type):
  return TAD(n_classes, n_frames, n_channels, input_size, hidden_size, n_layers,
             dropout, hidden_size_seq, n_layers_seq, dropout_seq, encoder_type)
--------------------------------------------------------------------------------
/scripts/download_models.sh:
--------------------------------------------------------------------------------
wget https://storage.googleapis.com/graph_distillation/ckpt/ntu-rgbd/depth.zip
wget
https://storage.googleapis.com/graph_distillation/ckpt/ntu-rgbd/jjd.zip 3 | wget https://storage.googleapis.com/graph_distillation/ckpt/ntu-rgbd/jjv.zip 4 | wget https://storage.googleapis.com/graph_distillation/ckpt/ntu-rgbd/jld.zip 5 | wget https://storage.googleapis.com/graph_distillation/ckpt/ntu-rgbd/oflow.zip 6 | wget https://storage.googleapis.com/graph_distillation/ckpt/ntu-rgbd/rgb.zip 7 | wget https://storage.googleapis.com/graph_distillation/ckpt/ntu-rgbd/xfer_depth.zip 8 | wget https://storage.googleapis.com/graph_distillation/ckpt/ntu-rgbd/xfer_jjd.zip 9 | wget https://storage.googleapis.com/graph_distillation/ckpt/pku-mmd/depth.zip 10 | wget https://storage.googleapis.com/graph_distillation/ckpt/pku-mmd/jjd.zip 11 | wget https://storage.googleapis.com/graph_distillation/ckpt/pku-mmd/jjv.zip 12 | wget https://storage.googleapis.com/graph_distillation/ckpt/pku-mmd/jld.zip 13 | wget https://storage.googleapis.com/graph_distillation/ckpt/pku-mmd/oflow.zip 14 | wget https://storage.googleapis.com/graph_distillation/ckpt/pku-mmd/rgb.zip 15 | 16 | -------------------------------------------------------------------------------- /scripts/test_ntu_rgbd.sh: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | #!/bin/bash 17 | python -m classification.run \ 18 | --gpus 0 \ 19 | --split test \ 20 | --dset ntu-rgbd \ 21 | --load_ckpt_path ckpt/ntu-rgbd/rgb/sub \ 22 | --modalities rgb 23 | -------------------------------------------------------------------------------- /scripts/test_pku_mmd.sh: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | #!/bin/bash 17 | python -m detection.run \ 18 | --gpus 0,1,2,3 \ 19 | --split test \ 20 | --dset pku-mmd \ 21 | --load_ckpt_path ckpt/pku-mmd/depth/ckpt \ 22 | --modalities depth 23 | -------------------------------------------------------------------------------- /scripts/train_ntu_rgbd.sh: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google Inc. All Rights Reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | #!/bin/bash 17 | python -m classification.run \ 18 | --gpus 0 \ 19 | --split train \ 20 | --dset ntu-rgbd \ 21 | --subsample 33 \ 22 | --modalities rgb 23 | -------------------------------------------------------------------------------- /scripts/train_ntu_rgbd_distillation.sh: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | #!/bin/bash 17 | python -m classification.run \ 18 | --gpus 0,1 \ 19 | --split train \ 20 | --dset ntu-rgbd \ 21 | --subsample 33 \ 22 | --pretrained_ckpt_name sub \ 23 | --xfer_to jjd \ 24 | --modalities rgb oflow depth jjd jjv jld 25 | -------------------------------------------------------------------------------- /scripts/train_pku_mmd.sh: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | #!/bin/bash 17 | python -m detection.run \ 18 | --gpus 3 \ 19 | --split train \ 20 | --dset pku-mmd \ 21 | --visual_encoder_ckpt_path ./ckpt/ntu-rgbd/jld/ckpt/embed_200.pth \ 22 | --subsample_rate 20 \ 23 | --ckpt_name sub \ 24 | --modalities jld 25 | -------------------------------------------------------------------------------- /scripts/train_pku_mmd_distillation.sh: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google Inc. All Rights Reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | #!/bin/bash 17 | python -m detection.run \ 18 | --gpus 0 \ 19 | --split train \ 20 | --dset pku-mmd \ 21 | --visual_encoder_ckpt_path ./ckpt/ntu-rgbd/depth/ckpt/embed_200.pth \ 22 | --xfer_to depth \ 23 | --modalities depth jjd 24 | -------------------------------------------------------------------------------- /third_party/pku_mmd/LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [2017] [Institute of Computer Science and Technology, Peking University] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /third_party/pku_mmd/evaluate.py: -------------------------------------------------------------------------------- 1 | # Copyright [2017] [Institute of Computer Science and Technology, Peking University] 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | """
16 | Code adapted and modified from
17 | https://github.com/ECHO960/PKU-MMD/blob/master/evaluate.py
18 | """
19 | 
20 | import argparse
21 | import os
22 | import numpy as np
23 | 
24 | number_label = 52
25 | 
26 | 
27 | # calc_pr: calculate precision and recall
28 | # @positive: number of true-positive proposals
29 | # @proposal: total number of proposals
30 | # @ground: number of ground-truth instances
31 | def calc_pr(positive, proposal, ground):
32 |   if proposal == 0:
33 |     return 0, 0
34 |   if ground == 0:
35 |     return 0, 0
36 |   return (1.0 * positive) / proposal, (1.0 * positive) / ground
37 | 
38 | 
39 | # match: match proposals to ground truths
40 | # @lst: list of proposals (label, start, end, confidence, video_name)
41 | # @ratio: minimum overlap (IoU) ratio
42 | # @ground: list of ground truths (label, start, end, confidence, video_name)
43 | #
44 | # cos_map: records the matched ground truth (or -1) for each proposal
45 | # count_map: records how many proposals each ground truth is matched by
46 | # index_map: ground-truth indices grouped by class label, to speed up lookup
47 | def match(lst, ratio, ground):
48 | 
49 |   def overlap(prop, ground):
50 |     l_p, s_p, e_p, c_p, v_p = prop
51 |     l_g, s_g, e_g, c_g, v_g = ground
52 |     if int(l_p) != int(l_g):
53 |       return 0
54 |     if v_p != v_g:
55 |       return 0
56 |     return (min(e_p, e_g) - max(s_p, s_g)) / (max(e_p, e_g) - min(s_p, s_g))
57 | 
58 |   cos_map = [-1 for x in range(len(lst))]
59 |   count_map = [0 for x in range(len(ground))]
60 |   # Generate index_map to speed up matching.
61 |   index_map = [[] for x in range(number_label)]
62 |   for x in range(len(ground)):
63 |     index_map[int(ground[x][0])].append(x)
64 | 
65 |   for x in range(len(lst)):
66 |     for y in index_map[int(lst[x][0])]:
67 |       if overlap(lst[x], ground[y]) < ratio:
68 |         continue
69 |       if cos_map[x] != -1 and overlap(lst[x], ground[y]) < overlap(lst[x], ground[cos_map[x]]):
70 |         continue
71 |       cos_map[x] = y
72 |     if cos_map[x] != -1:
73 |       count_map[cos_map[x]] += 1
74 |   positive = sum([(x > 0) for x in count_map])
75 |   return cos_map, count_map, positive
76 | 
77 | 
78 | # Interpolated Average Precision:
79 | # @lst: list of proposals (label, start, end, confidence, video_name)
80 | # @ratio: overlap ratio
81 | # @ground: list of ground truths (label, start, end, confidence, video_name)
82 | #
83 | # score = sigma(precision(recall) * delta(recall))
84 | # Note that when the overlap ratio is < 0.5,
85 | # one ground truth may correspond to many proposals;
86 | # in that case, only one positive proposal is counted.
87 | def ap(lst, ratio, ground):
88 |   lst.sort(key=lambda x: x[3])  # sort by confidence, ascending
89 |   cos_map, count_map, positive = match(lst, ratio, ground)
90 |   score = 0
91 |   number_proposal = len(lst)
92 |   number_ground = len(ground)
93 |   old_precision, old_recall = calc_pr(positive, number_proposal, number_ground)
94 | 
95 |   for x in range(len(lst)):
96 |     number_proposal -= 1
97 |     if cos_map[x] == -1:
98 |       continue
99 |     count_map[cos_map[x]] -= 1
100 |     if count_map[cos_map[x]] == 0:
101 |       positive -= 1
102 | 
103 |     precision, recall = calc_pr(positive, number_proposal, number_ground)
104 |     if precision > old_precision:
105 |
old_precision = precision 106 | score += old_precision * (old_recall - recall) 107 | old_recall = recall 108 | return score 109 | 110 | 111 | def process(v_props, video_files, groundtruth_dir, theta): 112 | v_grounds = [] # ground-truth list separated by video 113 | 114 | #========== find all proposals separated by video======== 115 | for video in video_files: 116 | ground = open(os.path.join(groundtruth_dir, video), "r").readlines() 117 | ground = [ground[x].replace(",", " ") for x in range(len(ground))] 118 | ground = [[float(y) for y in ground[x].split()] for x in range(len(ground))] 119 | #append video name 120 | for x in ground: 121 | x.append(video) 122 | v_grounds.append(ground) 123 | 124 | assert len(v_props) == len(v_grounds), "{} != {}".format( 125 | len(v_props), len(v_grounds)) 126 | 127 | #========== find all proposals separated by action categories======== 128 | # proposal list separated by class 129 | a_props = [[] for x in range(number_label)] 130 | # ground-truth list separated by class 131 | a_grounds = [[] for x in range(number_label)] 132 | 133 | for x in range(len(v_props)): 134 | for y in range(len(v_props[x])): 135 | a_props[int(v_props[x][y][0])].append(v_props[x][y]) 136 | 137 | for x in range(len(v_grounds)): 138 | for y in range(len(v_grounds[x])): 139 | a_grounds[int(v_grounds[x][y][0])].append(v_grounds[x][y]) 140 | 141 | #========== find all proposals======== 142 | all_props = sum(a_props, []) 143 | all_grounds = sum(a_grounds, []) 144 | 145 | #========== calculate protocols======== 146 | # mAP_action 147 | aps_action = np.array([ 148 | ap(a_props[x + 1], theta, a_grounds[x + 1]) 149 | for x in range(number_label - 1) 150 | ]) 151 | # mAP_video 152 | aps_video = np.array( 153 | [ap(v_props[x], theta, v_grounds[x]) for x in range(len(v_props))]) 154 | # Return mAP action. 155 | return np.mean(aps_action) 156 | -------------------------------------------------------------------------------- /third_party/pytorch/LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) Soumith Chintala 2016, 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /third_party/pytorch/get_cnn.py: -------------------------------------------------------------------------------- 1 | # BSD 3-Clause License 2 | # 3 | # Copyright (c) Soumith Chintala 2016, 4 | # All rights reserved. 5 | # 6 | # Redistribution and use in source and binary forms, with or without 7 | # modification, are permitted provided that the following conditions are met: 8 | # 9 | # * Redistributions of source code must retain the above copyright notice, this 10 | # list of conditions and the following disclaimer. 11 | # 12 | # * Redistributions in binary form must reproduce the above copyright notice, 13 | # this list of conditions and the following disclaimer in the documentation 14 | # and/or other materials provided with the distribution. 15 | # 16 | # * Neither the name of the copyright holder nor the names of its 17 | # contributors may be used to endorse or promote products derived from 18 | # this software without specific prior written permission. 19 | 20 | """Code for ResNet. 21 | Adapted from 22 | https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 23 | """ 24 | 25 | import math 26 | 27 | import torch.nn as nn 28 | import torch.utils.model_zoo as model_zoo 29 | from torchvision.models.resnet import BasicBlock 30 | from torchvision.models.vgg import cfg, model_urls, VGG 31 | 32 | 33 | class ResNet(nn.Module): 34 | 35 | def __init__(self, block, layers, in_channels, n_classes): 36 | self.inplanes = 64 37 | super(ResNet, self).__init__() 38 | self.conv1 = nn.Conv2d( 39 | in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False) 40 | self.bn1 = nn.BatchNorm2d(64) 41 | self.relu = nn.ReLU(inplace=True) 42 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 43 | self.layer1 = self._make_layer(block, 64, layers[0]) 44 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 45 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 46 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 47 | self.avgpool = nn.AvgPool2d(7) 48 | self.fc = nn.Linear(512 * block.expansion, n_classes) 49 | 50 | for m in self.modules(): 51 | if isinstance(m, nn.Conv2d): 52 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 53 | m.weight.data.normal_(0, math.sqrt(2. 
/ n))
54 |       elif isinstance(m, nn.BatchNorm2d):
55 |         m.weight.data.fill_(1)
56 |         m.bias.data.zero_()
57 | 
58 |   def _make_layer(self, block, planes, blocks, stride=1):
59 |     downsample = None
60 |     if stride != 1 or self.inplanes != planes * block.expansion:
61 |       downsample = nn.Sequential(
62 |           nn.Conv2d(
63 |               self.inplanes,
64 |               planes * block.expansion,
65 |               kernel_size=1,
66 |               stride=stride,
67 |               bias=False),
68 |           nn.BatchNorm2d(planes * block.expansion),
69 |       )
70 | 
71 |     layers = []
72 |     layers.append(block(self.inplanes, planes, stride, downsample))
73 |     self.inplanes = planes * block.expansion
74 |     for _ in range(1, blocks):
75 |       layers.append(block(self.inplanes, planes))
76 | 
77 |     return nn.Sequential(*layers)
78 | 
79 |   def forward(self, x):
80 |     x = self.conv1(x)
81 |     x = self.bn1(x)
82 |     x = self.relu(x)
83 |     x = self.maxpool(x)
84 | 
85 |     x = self.layer1(x)
86 |     x = self.layer2(x)
87 |     x = self.layer3(x)
88 |     conv4 = self.layer4(x)
89 | 
90 |     feat = self.avgpool(conv4).view(conv4.size(0), -1)
91 |     logit = self.fc(feat)
92 | 
93 |     return logit, feat
94 | 
95 | 
96 | def resnet18(**kwargs):
97 |   """Constructs a ResNet-18 model.
98 |   Args:
99 |     **kwargs: forwarded to ResNet, e.g. in_channels and n_classes.
100 |   """
101 |   model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
102 |   return model
103 | 
104 | 
105 | def get_resnet(in_channels, n_classes):
106 |   return resnet18(in_channels=in_channels, n_classes=n_classes)
107 | 
108 | 
109 | def make_layers(cfg, in_channels=3, batch_norm=False):
110 |   layers = []
111 |   for v in cfg:
112 |     if v == 'M':
113 |       layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
114 |     else:
115 |       conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
116 |       if batch_norm:
117 |         layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
118 |       else:
119 |         layers += [conv2d, nn.ReLU(inplace=True)]
120 |       in_channels = v
121 |   return nn.Sequential(*layers)
122 | 
123 | 
124 | def get_vgg(in_channels=3, **kwargs):
125 |   model = VGG(make_layers(cfg['D'], in_channels), **kwargs)
126 |   return model
127 | 
-------------------------------------------------------------------------------- /third_party/two_stream_pytorch/LICENSE: --------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2017 Yi Zhu
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
-------------------------------------------------------------------------------- /third_party/two_stream_pytorch/video_transforms.py: --------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2017 Yi Zhu
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | 
15 | """Transformations for videos.
16 | 
17 | Code adapted from
18 | https://github.com/bryanyzhu/two-stream-pytorch/blob/master/video_transforms.py
19 | """
20 | 
21 | import collections
22 | import math
23 | import numbers
24 | import random
25 | 
26 | import numpy as np
27 | import torch
28 | import torch.nn.functional as F
29 | 
30 | import utils
31 | from utils import imgproc
32 | 
33 | 
34 | class ToTensor(object):
35 |   """Converts a numpy.ndarray (... x H x W x C) in the range
36 |   [0, 255] to a torch.FloatTensor of shape (... x C x H x W)
37 |   in the range [0.0, 1.0].
38 | 
39 |   Set scale=False to skip the division by 255.
40 |   """
41 | 
42 |   def __init__(self, scale=True, to_float=True):
43 |     self.scale = scale
44 |     self.to_float = to_float
45 | 
46 |   def __call__(self, arr):
47 |     if isinstance(arr, np.ndarray):
48 |       video = torch.from_numpy(np.rollaxis(arr, axis=-1, start=-3))
49 | 
50 |       if self.to_float:
51 |         video = video.float()
52 | 
53 |       if self.scale:
54 |         return video.div(255)
55 |       else:
56 |         return video
57 |     else:
58 |       raise NotImplementedError
59 | 
60 | 
61 | class Normalize(object):
62 |   """Given mean and std, normalizes each channel of the torch.*Tensor,
63 |   i.e. channel = (channel - mean) / std.
64 |   """
65 | 
66 | 
67 |   def __init__(self, mean, std):
68 |     if not isinstance(mean, list):
69 |       mean = [mean]
70 |     if not isinstance(std, list):
71 |       std = [std]
72 | 
73 |     self.mean = torch.FloatTensor(mean).unsqueeze(1).unsqueeze(2)
74 |     self.std = torch.FloatTensor(std).unsqueeze(1).unsqueeze(2)
75 | 
76 |   def __call__(self, tensor):
77 |     return tensor.sub_(self.mean).div_(self.std)
78 | 
79 | 
80 | class Scale(object):
81 |   """Rescale the input numpy.ndarray to the given size.
82 |   Args:
83 |     size (sequence or int): Desired output size. If size is a sequence like
84 |       (w, h), output size will be matched to this. If size is an int,
85 |       smaller edge of the image will be matched to this number.
86 |       i.e., if height > width, then image will be rescaled to
87 |       (size * height / width, size)
88 |     interpolation (str, optional): Desired interpolation. Default is
89 |       ``bilinear``
90 |   """
91 |   def __init__(self, size, transform_pixel=False, interpolation='bilinear'):
92 |     """size: output size; transform_pixel: adjust pixel values (for optical
93 |     flow); interpolation: 'bilinear' or 'nearest'."""
94 |     assert isinstance(size, int) or (isinstance(size, collections.abc.Iterable)
95 |                                      and len(size) == 2)
96 |     self.size = size
97 |     self.transform_pixel = transform_pixel
98 |     self.interpolation = interpolation
99 | 
100 |   def __call__(self, video):
101 |     """Args: video (numpy.ndarray): Video to be scaled.
102 | 103 | Returns: 104 | numpy.ndarray: Rescaled video. 105 | """ 106 | w, h = video.shape[-2], video.shape[-3] 107 | 108 | if isinstance(self.size, int): 109 | if (w <= h and w == self.size) or (h <= w and h == self.size): 110 | return video 111 | 112 | if w < h: 113 | ow = self.size 114 | oh = int(self.size * h / w) 115 | video = imgproc.resize(video, (oh, ow), self.interpolation) 116 | else: 117 | oh = self.size 118 | ow = int(self.size * w / h) 119 | video = imgproc.resize(video, (oh, ow), self.interpolation) 120 | 121 | if self.transform_pixel: 122 | video[..., 0] = (video[..., 0] - 128) * (ow / w) + 128 123 | video[..., 1] = (video[..., 1] - 128) * (oh / h) + 128 124 | else: 125 | video = imgproc.resize(video, self.size, self.interpolation) 126 | 127 | if self.transform_pixel: 128 | video[..., 0] = (video[..., 0] - 128) * (self.size / w) + 128 129 | video[..., 1] = (video[..., 1] - 128) * (self.size / h) + 128 130 | 131 | return video 132 | 133 | 134 | class CenterCrop(object): 135 | """Crops the given numpy.ndarray at the center to have a region of 136 | the given size. size can be a tuple (target_height, target_width) 137 | or an integer, in which case the target will be of a square shape (size, size) 138 | """ 139 | 140 | def __init__(self, size): 141 | if isinstance(size, numbers.Number): 142 | self.size = (int(size), int(size)) 143 | else: 144 | self.size = size 145 | 146 | def __call__(self, video): 147 | h, w = video.shape[-3:-1] 148 | th, tw = self.size 149 | x1 = int(round((w - tw) / 2.)) 150 | y1 = int(round((h - th) / 2.)) 151 | 152 | return video[..., y1:y1 + th, x1:x1 + tw, :] 153 | 154 | 155 | class Pad(object): 156 | """Pad the given np.ndarray on all sides with the given "pad" value. 157 | 158 | Args: padding (int or sequence): Padding on each border. If a sequence of 159 | length 4, it is used to pad left, top, right and bottom borders respectively. 160 | fill: Pixel fill value. Default is 0. 161 | """ 162 | 163 | def __init__(self, padding, fill=0): 164 | assert isinstance(padding, numbers.Number) 165 | assert isinstance(fill, numbers.Number) or isinstance( 166 | fill, str) or isinstance(fill, tuple) 167 | self.padding = padding 168 | self.fill = fill 169 | 170 | def __call__(self, video): 171 | """Args: video (np.ndarray): Video to be padded. 172 | 173 | Returns: 174 | np.ndarray: Padded video. 175 | """ 176 | pad_width = ((0, 0), (self.padding, self.padding), (self.padding, 177 | self.padding), (0, 0)) 178 | return np.pad( 179 | video, pad_width=pad_width, mode='constant', constant_values=self.fill) 180 | 181 | 182 | class RandomCrop(object): 183 | """Crop the given numpy.ndarray at a random location. 184 | Args: 185 | size (sequence or int): Desired output size of the crop. If size is an 186 | int instead of sequence like (h, w), a square crop (size, size) is 187 | made. 188 | padding (int or sequence, optional): Optional padding on each border 189 | of the image. Default is 0, i.e no padding. If a sequence of length 190 | 4 is provided, it is used to pad left, top, right, bottom borders 191 | respectively. 192 | """ 193 | 194 | def __init__(self, size, padding=0): 195 | if isinstance(size, numbers.Number): 196 | self.size = (int(size), int(size)) 197 | else: 198 | self.size = size 199 | self.padding = padding 200 | 201 | def __call__(self, video): 202 | """Args: video (np.ndarray): Video to be cropped. 203 | 204 | Returns: 205 | np.ndarray: Cropped video. 
206 | """ 207 | if self.padding > 0: 208 | pad = Pad(self.padding, 0) 209 | video = pad(video) 210 | 211 | w, h = video.shape[-2], video.shape[-3] 212 | th, tw = self.size 213 | if w == tw and h == th: 214 | return video 215 | 216 | x1 = random.randint(0, w - tw) 217 | y1 = random.randint(0, h - th) 218 | return video[..., y1:y1 + th, x1:x1 + tw, :] 219 | 220 | 221 | class RandomHorizontalFlip(object): 222 | """Randomly horizontally flips the given numpy.ndarray with a probability of 0.5 223 | 224 | """ 225 | 226 | def __init__(self, transform_pixel=False): 227 | """:param transform_pixel: transform pixel values for flow 228 | """ 229 | self.transform_pixel = transform_pixel if isinstance( 230 | transform_pixel, list) else [transform_pixel] 231 | 232 | def __call__(self, videos): 233 | """Support joint transform 234 | :param videos: np.ndarray or a list of np.ndarray 235 | :return: 236 | """ 237 | if random.random() < 0.5: 238 | videos = utils.unsqueeze(videos) 239 | ret = [] 240 | for tp, video in zip(self.transform_pixel, videos): 241 | video = video[..., ::-1, :] 242 | if tp: 243 | video[..., 0] = 255 - video[..., 0] 244 | ret.append(video.copy()) 245 | return utils.squeeze(ret) 246 | else: 247 | return videos 248 | 249 | 250 | class RandomSizedCrop(object): 251 | """Crop the given np.ndarray to random size and aspect ratio. 252 | A crop of random size of (0.4 to 1.0) of the original size and a random 253 | aspect ratio of 3/4 to 4/3 of the original aspect ratio is made. This crop 254 | is finally resized to given size. 255 | This is popularly used to train the Inception networks. 256 | """ 257 | 258 | def __init__(self, size, transform_pixel=False): 259 | """:param size: size of the smaller edge :param transform_pixel: transform pixel values for flow 260 | """ 261 | self.size = size 262 | self.transform_pixel = transform_pixel if isinstance( 263 | transform_pixel, list) else [transform_pixel] 264 | 265 | def __call__(self, videos): 266 | """Support joint transform 267 | :param videos: np.ndarray or a list of np.ndarray 268 | :return: 269 | """ 270 | videos = utils.unsqueeze(videos) 271 | h_orig, w_orig = videos[0].shape[-3:-1] 272 | 273 | for attempt in range(10): 274 | ret = [] 275 | 276 | area = h_orig * w_orig 277 | target_area = random.uniform(0.4, 1.0) * area 278 | aspect_ratio = random.uniform(3. / 4, 4. 
/ 3) 279 | 280 | w = int(round(math.sqrt(target_area * aspect_ratio))) 281 | h = int(round(math.sqrt(target_area / aspect_ratio))) 282 | 283 | if random.random() < 0.5: 284 | w, h = h, w 285 | 286 | if w <= w_orig and h <= h_orig: 287 | x1 = random.randint(0, w_orig - w) 288 | y1 = random.randint(0, h_orig - h) 289 | 290 | for tp, video in zip(self.transform_pixel, videos): 291 | video = video[..., y1:y1 + h, x1:x1 + w, :] 292 | video = imgproc.resize(video, (self.size, self.size), 'bilinear') 293 | if tp: 294 | video[..., 0] = (video[..., 0] - 128) * (self.size / w) + 128 295 | video[..., 1] = (video[..., 1] - 128) * (self.size / h) + 128 296 | 297 | ret.append(video) 298 | 299 | return utils.squeeze(ret) 300 | 301 | # Fallback 302 | ret = [] 303 | scales = [Scale(self.size, tp, 'bilinear') for tp in self.transform_pixel] 304 | crop = CenterCrop(self.size) 305 | for scale, video in zip(scales, videos): 306 | video = crop(scale(video)) 307 | ret.append(video) 308 | 309 | return utils.squeeze(ret) 310 | 311 | 312 | class SkelNormalize(object): 313 | """Given mean and std, will normalize the numpy array 314 | """ 315 | 316 | def __init__(self, mean, std): 317 | self.mean = mean 318 | self.std = std 319 | 320 | def __call__(self, skel): 321 | return (skel - self.mean) / self.std 322 | 323 | 324 | class AvgPool(object): 325 | """Rescale the input by performing average pooling The height and width are scaled down by a factor of kernel_size 326 | """ 327 | 328 | def __init__(self, kernel_size): 329 | self.kernel_size = kernel_size 330 | 331 | def __call__(self, x): 332 | batch_size, n_samples, n_channels, h, w = x.size() 333 | x = x.view(-1, n_channels, h, w) 334 | x = F.avg_pool2d(x, self.kernel_size, stride=self.kernel_size).data 335 | return x.view(batch_size, n_samples, *x.size()[-3:]) 336 | 337 | 338 | class Clip(object): 339 | """Clip values of the numpy array 340 | """ 341 | 342 | def __init__(self, lower, upper): 343 | self.lower = lower 344 | self.upper = upper 345 | 346 | def __call__(self, x): 347 | return np.clip(x, self.lower, self.upper) 348 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .misc import * 2 | -------------------------------------------------------------------------------- /utils/imgproc.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Utility functions for reading and processing data.""" 17 | 18 | import cv2 19 | import numpy as np 20 | from scipy import interpolate 21 | from scipy.misc import imresize 22 | 23 | 24 | def imread_rgb(dset, path): 25 | if dset == 'ucf-101': 26 | rgb = cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB) 27 | return rgb[:, :-1] # oflow is 1px smaller than rgb in ucf-101 28 | elif dset == 'ntu-rgbd' or dset == 'pku-mmd' or dset == 'cad-60': 29 | rgb = cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB) 30 | return rgb 31 | else: 32 | assert False 33 | 34 | 35 | def imread_oflow(dset, *paths): 36 | if dset == 'ucf-101': 37 | path_u, path_v = paths 38 | oflow_u = cv2.imread(path_u, cv2.IMREAD_GRAYSCALE) 39 | oflow_v = cv2.imread(path_v, cv2.IMREAD_GRAYSCALE) 40 | oflow = np.stack((oflow_u, oflow_v), axis=2) 41 | return oflow 42 | elif dset == 'ntu-rgbd' or dset == 'pku-mmd' or dset == 'cad-60': 43 | path = paths[0] 44 | oflow = cv2.imread(path)[..., ::-1][..., :2] 45 | return oflow 46 | else: 47 | assert False 48 | 49 | 50 | def imread_depth(dset, path): 51 | # dset == 'ntu-rgbd' or dset == 'pku-mmd' 52 | depth = cv2.imread(path, cv2.IMREAD_UNCHANGED)[:, :, np.newaxis] 53 | depth = np.clip(depth/256, 0, 255).astype(np.uint8) 54 | return depth 55 | 56 | 57 | def inpaint(img, threshold=1): 58 | h, w = img.shape[:2] 59 | 60 | if len(img.shape) == 3: # RGB 61 | mask = np.all(img == 0, axis=2).astype(np.uint8) 62 | img = cv2.inpaint(img, mask, inpaintRadius=3, flags=cv2.INPAINT_TELEA) 63 | 64 | else: # depth 65 | mask = np.where(img > threshold) 66 | xx, yy = np.meshgrid(np.arange(w), np.arange(h)) 67 | xym = np.vstack((np.ravel(xx[mask]), np.ravel(yy[mask]))).T 68 | img = np.ravel(img[mask]) 69 | interp = interpolate.NearestNDInterpolator(xym, img) 70 | img = interp(np.ravel(xx), np.ravel(yy)).reshape(xx.shape) 71 | 72 | return img 73 | 74 | 75 | def resize(video, size, interpolation): 76 | """ 77 | :param video: ... 
x h x w x num_channels 78 | :param size: (h, w) 79 | :param interpolation: 'bilinear', 'nearest' 80 | :return: 81 | """ 82 | shape = video.shape[:-3] 83 | num_channels = video.shape[-1] 84 | video = video.reshape((-1, *video.shape[-3:])) 85 | resized_video = np.zeros((video.shape[0], *size, video.shape[-1])) 86 | 87 | for i in range(video.shape[0]): 88 | if num_channels == 3: 89 | resized_video[i] = imresize(video[i], size, interpolation) 90 | elif num_channels == 2: 91 | resized_video[i, ..., 0] = imresize(video[i, ..., 0], size, interpolation) 92 | resized_video[i, ..., 1] = imresize(video[i, ..., 1], size, interpolation) 93 | elif num_channels == 1: 94 | resized_video[i, ..., 0] = imresize(video[i, ..., 0], size, interpolation) 95 | else: 96 | raise NotImplementedError 97 | 98 | return resized_video.reshape((*shape, *size, video.shape[-1])) 99 | 100 | 101 | def proc_oflow(images): 102 | h, w = images.shape[-3:-1] 103 | 104 | processed_images = [] 105 | for image in images: 106 | hsv = np.zeros((h, w, 3), dtype=np.uint8) 107 | hsv[:, :, 0] = 255 108 | hsv[:, :, 1] = 255 109 | 110 | mag, ang = cv2.cartToPolar(image[..., 0], image[..., 1]) 111 | hsv[..., 0] = ang*180/np.pi/2 112 | hsv[..., 2] = cv2.normalize(mag, None, 0, 255, cv2.NORM_MINMAX) 113 | 114 | processed_image = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR) 115 | processed_images.append(processed_image) 116 | 117 | return np.stack(processed_images) 118 | -------------------------------------------------------------------------------- /utils/logging.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Calculate result statistics, logging.""" 17 | 18 | from collections import OrderedDict 19 | import logging 20 | import os 21 | import sys 22 | import numpy as np 23 | 24 | 25 | class _AverageMeter(object): 26 | """ Average Meter Class.""" 27 | 28 | def __init__(self): 29 | self.val = 0 30 | self.avg = 0 31 | self.sum = 0 32 | self.count = 0 33 | 34 | def update(self, val, n=1): 35 | self.val = val 36 | self.sum += val*n 37 | self.count += n 38 | self.avg = self.sum/self.count 39 | 40 | 41 | class Statistics(object): 42 | """ Statistics Class.""" 43 | 44 | def __init__(self, ckpt_path=None, name='history'): 45 | self.meters = OrderedDict() 46 | self.history = OrderedDict() 47 | self.ckpt_path = ckpt_path 48 | self.name = name 49 | 50 | def update(self, n, ordered_dict): 51 | info = '' 52 | for key in ordered_dict: 53 | if key not in self.meters: 54 | self.meters.update({key: _AverageMeter()}) 55 | self.meters[key].update(ordered_dict[key], n) 56 | info += '{key}={var.val:.4f}, avg {key}={var.avg:.4f}, '.format( 57 | key=key, var=self.meters[key]) 58 | 59 | return info[:-2] 60 | 61 | def summarize(self, reset=True): 62 | info = '' 63 | for key in self.meters: 64 | info += '{key}={var:.4f}, '.format(key=key, var=self.meters[key].avg) 65 | 66 | if reset: 67 | self.reset() 68 | 69 | return info[:-2] 70 | 71 | def reset(self): 72 | for key in self.meters: 73 | if key in self.history: 74 | self.history[key].append(self.meters[key].avg) 75 | else: 76 | self.history.update({key: [self.meters[key].avg]}) 77 | 78 | self.meters = OrderedDict() 79 | 80 | def load(self): 81 | self.history = np.load( 82 | os.path.join(self.ckpt_path, '{}.npy'.format(self.name))).item() 83 | 84 | def save(self): 85 | np.save( 86 | os.path.join(self.ckpt_path, '{}.npy'.format(self.name)), self.history) 87 | 88 | 89 | class Logger(object): 90 | """ Logger Class.""" 91 | 92 | def __init__(self, path, name='debug'): 93 | self.logger = logging.getLogger() 94 | self.logger.setLevel(logging.INFO) 95 | formatter = logging.Formatter( 96 | '%(asctime)s %(message)s', datefmt='%Y-%m-%d,%H:%M:%S') 97 | 98 | fh = logging.FileHandler(os.path.join(path, '{}.log'.format(name)), 'w') 99 | fh.setLevel(logging.INFO) 100 | fh.setFormatter(formatter) 101 | self.logger.addHandler(fh) 102 | 103 | ch = logging.StreamHandler(sys.stdout) 104 | ch.setLevel(logging.INFO) 105 | ch.setFormatter(formatter) 106 | self.logger.addHandler(ch) 107 | 108 | def log(self, info): 109 | self.logger.info(info) 110 | -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ==============================================================================
15 | 
16 | """Miscellaneous utility functions."""
17 | 
18 | import numpy as np
19 | from sklearn.metrics import average_precision_score
20 | import torch
21 | import torch.nn.functional as F
22 | import torch.utils.data
23 | 
24 | 
25 | def to_numpy(array):
26 |   if isinstance(array, np.ndarray):
27 |     return array
28 |   if isinstance(array, torch.autograd.Variable):
29 |     array = array.data
30 |   if array.is_cuda:
31 |     array = array.cpu()
32 | 
33 |   return array.numpy()
34 | 
35 | 
36 | def squeeze(array):
37 |   if not isinstance(array, list) or len(array) > 1:
38 |     return array
39 |   else:  # len(array) == 1
40 |     return array[0]
41 | 
42 | 
43 | def unsqueeze(array):
44 |   if isinstance(array, list):
45 |     return array
46 |   else:
47 |     return [array]
48 | 
49 | 
50 | def is_due(*args):
51 |   """Determines whether to perform an action or not, depending on the epoch.
52 |   Used for logging, saving, learning rate decay, etc.
53 | 
54 |   Args:
55 |     *args: (epoch, due_at): due at the epochs listed in due_at;
56 |       (epoch, num_epochs, due_every): due every due_every epochs;
57 |       (step, due_every): due every due_every steps.
58 |   Returns:
59 |     due: boolean: perform action or not
60 |   """
61 |   if len(args) == 2 and isinstance(args[1], list):
62 |     epoch, due_at = args
63 |     due = epoch in due_at
64 |   elif len(args) == 3:
65 |     epoch, num_epochs, due_every = args
66 |     due = (due_every >= 0) and (epoch % due_every == 0 or epoch == num_epochs)
67 |   else:
68 |     step, due_every = args
69 |     due = (due_every > 0) and (step % due_every == 0)
70 | 
71 |   return due
72 | 
73 | 
74 | def softmax(w, t=1.0, axis=None):
75 |   w = np.array(w) / t
76 |   e = np.exp(w - np.amax(w, axis=axis, keepdims=True))
77 |   dist = e / np.sum(e, axis=axis, keepdims=True)
78 |   return dist
79 | 
80 | 
81 | def distance_metric(student, teacher, option, weights=None):
82 |   """Distance metric to calculate the imitation loss.
83 | 
84 |   Args:
85 |     student: batch_size x n_classes
86 |     teacher: batch_size x n_classes
87 |     option: one of [cosine, l1, l2, kl]
88 |     weights: batch_size or float
89 | 
90 |   Returns:
91 |     The computed distance metric.
92 |   """
93 |   if option == 'cosine':
94 |     dists = 1 - F.cosine_similarity(student, teacher.detach(), dim=1)
95 |   elif option == 'l2':
96 |     dists = (student - teacher.detach()).pow(2).sum(1)
97 |   elif option == 'l1':
98 |     dists = torch.abs(student - teacher.detach()).sum(1)
99 |   elif option == 'kl':
100 |     assert weights is None
101 |     T = 8
102 |     # averaged for each minibatch
103 |     dist = F.kl_div(
104 |         F.log_softmax(student / T, dim=1),
105 |         F.softmax(teacher.detach() / T, dim=1)) * (T * T)
106 |     return dist
107 |   else:
108 |     raise NotImplementedError
109 | 
110 |   if weights is None:
111 |     dist = dists.mean()
112 |   else:
113 |     dist = (dists * weights).mean()
114 | 
115 |   return dist
116 | 
117 | 
118 | def get_segments(input, timestep):
119 |   """Split entire input into segments of length timestep.
120 | 
121 |   Args:
122 |     input: 1 x total_length x n_frames x ...
123 |     timestep: length of each segment, in frames.
124 | 125 | Returns: 126 | input: concatenated video segments 127 | start_indices: indices of the segments 128 | """ 129 | assert input.size(0) == 1, 'Test time, batch_size must be 1' 130 | 131 | input.squeeze_(dim=0) 132 | # Find overlapping segments 133 | length = input.size()[0] 134 | step = timestep // 2 135 | num_segments = (length - timestep) // step + 1 136 | start_indices = (np.arange(num_segments) * step).tolist() 137 | if length % step > 0: 138 | start_indices.append(length - timestep) 139 | 140 | # Get the segments 141 | segments = [] 142 | for s in start_indices: 143 | segment = input[s: (s + timestep)].unsqueeze(0) 144 | segments.append(segment) 145 | input = torch.cat(segments, dim=0) 146 | return input, start_indices 147 | 148 | def get_stats(logit, label): 149 | ''' 150 | Calculate the accuracy. 151 | ''' 152 | logit = to_numpy(logit) 153 | label = to_numpy(label) 154 | 155 | pred = np.argmax(logit, 1) 156 | acc = np.sum(pred == label)/label.shape[0] 157 | 158 | return acc, pred, label 159 | 160 | 161 | def get_stats_detection(logit, label, n_classes=52): 162 | ''' 163 | Calculate the accuracy and average precisions. 164 | ''' 165 | logit = to_numpy(logit) 166 | label = to_numpy(label) 167 | scores = softmax(logit, axis=1) 168 | 169 | pred = np.argmax(logit, 1) 170 | length = label.shape[0] 171 | acc = np.sum(pred == label)/length 172 | 173 | keep_bg = label == 0 174 | acc_bg = np.sum(pred[keep_bg] == label[keep_bg])/label[keep_bg].shape[0] 175 | ratio_bg = np.sum(keep_bg)/length 176 | 177 | keep_action = label != 0 178 | acc_action = np.sum( 179 | pred[keep_action] == label[keep_action]) / label[keep_action].shape[0] 180 | 181 | # Average precision 182 | y_true = np.zeros((len(label), n_classes)) 183 | y_true[np.arange(len(label)), label] = 1 184 | acc = np.sum(pred == label)/label.shape[0] 185 | aps = average_precision_score(y_true, scores, average=None) 186 | aps = list(filter(lambda x: not np.isnan(x), aps)) 187 | ap = np.mean(aps) 188 | 189 | return ap, acc, acc_bg, acc_action, ratio_bg, pred, label 190 | 191 | 192 | def info(text): 193 | print('\033[94m' + text + '\033[0m') 194 | 195 | 196 | def warn(text): 197 | print('\033[93m' + text + '\033[0m') 198 | 199 | 200 | def err(text): 201 | print('\033[91m' + text + '\033[0m') 202 | -------------------------------------------------------------------------------- /utils/skelproc.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Process Skeleton Feature.""" 17 | 18 | import numpy as np 19 | import utils 20 | 21 | 22 | def read_skel(dset, path): 23 | """ 24 | :param dset: name of dataset, either 'ntu-rgbd' or 'pku-mmd' 25 | :param path: path to the skeleton file 26 | :return: 27 | """ 28 | if dset == 'ntu-rgbd': 29 | file = open(path, 'r') 30 | lines = file.readlines() 31 | num_lines = len(lines) 32 | num_frames = int(lines[0]) 33 | # print(num_lines, num_frames) 34 | 35 | line_id = 1 36 | data = [] 37 | for i in range(num_frames): 38 | num_skels = int(lines[line_id]) 39 | # print(num_skels) 40 | 41 | joints = [] 42 | for _ in range(num_skels): 43 | num_joints = int(lines[line_id+2]) 44 | # print(num_joints) 45 | 46 | joint = [] 47 | for k in range(num_joints): 48 | tmp = lines[line_id+3+k].rstrip().split(' ') 49 | x_3d, y_3d, z_3d, x_depth, y_depth, x_rgb, y_rgb, orientation_w,\ 50 | orientation_x, orientation_y, orientation_z = list( 51 | map(float, tmp[:-1])) 52 | joint.append([x_3d, y_3d, z_3d]) 53 | joints.append(joint) 54 | line_id += 2+num_joints 55 | joints = np.array(joints) 56 | data.append(joints) 57 | line_id += 1 58 | 59 | assert line_id == num_lines 60 | 61 | elif dset == 'pku-mmd': 62 | file = open(path, 'r') 63 | lines = file.readlines() 64 | # num_lines = len(lines) 65 | 66 | data = [] 67 | for line in lines: 68 | joints = list(map(float, line.rstrip().split(' '))) 69 | joints = np.array(joints).reshape(2, -1, 3) 70 | 71 | if not np.any(joints[1]): 72 | joints = joints[0][np.newaxis, :, :] 73 | 74 | data.append(joints) 75 | 76 | elif dset == 'cad-60': 77 | f = open(path, 'r') 78 | lines = f.readlines() 79 | data = [] 80 | 81 | # Last line is "END" 82 | for line in lines[:-1]: 83 | # fist item is frame number, last item is empty 84 | row = line.split(',')[1:-1] 85 | row = list(map(float, row)) 86 | joints = [] 87 | for i in range(15): 88 | if i < 11: 89 | # First 11 joints 90 | index = 14 * i + 10 91 | else: 92 | # Joint 12 ~ 15 93 | index = 11 * 14 + (i - 11) * 4 94 | joint = row[index: index+3] 95 | joints.append(joint) 96 | joints = np.array(joints) / 1000.0 # millimeter to meter 97 | joints = joints[np.newaxis, :, :] # To match ntu-rgb format 98 | data.append(joints) 99 | 100 | else: 101 | raise NotImplementedError 102 | 103 | return data 104 | 105 | 106 | def flip_skel(skel, dset): 107 | """processed skel (normalized and center shifted to the origin).""" 108 | # Shape: (N x NUM_JOINTS x 3) 109 | if dset == 'cad-60': 110 | num_joints = 15 111 | assert skel.ndim == 3 and skel.shape[1] == num_joints 112 | assert np.sum(np.mean(skel, axis=(0, 1))) < 1e-8, 'Skeleton not centered.' 
113 |     new_skel = skel.copy()
114 |     # Head, neck, torso
115 |     new_skel[:, 0, 0] = -skel[:, 0, 0]
116 |     new_skel[:, 1, 0] = -skel[:, 1, 0]
117 |     new_skel[:, 2, 0] = -skel[:, 2, 0]
118 |     # Shoulders
119 |     new_skel[:, 3, 0] = -skel[:, 5, 0]
120 |     new_skel[:, 5, 0] = -skel[:, 3, 0]
121 |     new_skel[:, 3, 1:] = skel[:, 5, 1:]
122 |     new_skel[:, 5, 1:] = skel[:, 3, 1:]
123 |     # Elbows
124 |     new_skel[:, 4, 0] = -skel[:, 6, 0]
125 |     new_skel[:, 6, 0] = -skel[:, 4, 0]
126 |     new_skel[:, 4, 1:] = skel[:, 6, 1:]
127 |     new_skel[:, 6, 1:] = skel[:, 4, 1:]
128 |     # Hips
129 |     new_skel[:, 7, 0] = -skel[:, 9, 0]
130 |     new_skel[:, 9, 0] = -skel[:, 7, 0]
131 |     new_skel[:, 7, 1:] = skel[:, 9, 1:]
132 |     new_skel[:, 9, 1:] = skel[:, 7, 1:]
133 |     # Knees
134 |     new_skel[:, 8, 0] = -skel[:, 10, 0]
135 |     new_skel[:, 10, 0] = -skel[:, 8, 0]
136 |     new_skel[:, 8, 1:] = skel[:, 10, 1:]
137 |     new_skel[:, 10, 1:] = skel[:, 8, 1:]
138 |     # Hands
139 |     new_skel[:, 11, 0] = -skel[:, 12, 0]
140 |     new_skel[:, 12, 0] = -skel[:, 11, 0]
141 |     new_skel[:, 11, 1:] = skel[:, 12, 1:]
142 |     new_skel[:, 12, 1:] = skel[:, 11, 1:]
143 |     # Feet
144 |     new_skel[:, 13, 0] = -skel[:, 14, 0]
145 |     new_skel[:, 14, 0] = -skel[:, 13, 0]
146 |     new_skel[:, 13, 1:] = skel[:, 14, 1:]
147 |     new_skel[:, 14, 1:] = skel[:, 13, 1:]
148 |     return new_skel
149 | 
150 | 
151 | def pad_skel(skel, axis=1):
152 |   if skel.shape[axis] == 1:
153 |     skel = np.repeat(skel, 2, axis=axis)
154 |   return skel
155 | 
156 | 
157 | def extract(skel):
158 |   """Extract joint-joint distance and vector features. skel: timestep x 2 x num_joints (25) x 3"""
159 |   timestep = skel.shape[0]
160 |   keep_joints = [1, 4, 6, 8, 10, 12, 14, 16, 18, 20, 21]
161 |   skel = skel[:, :, [keep_joint-1 for keep_joint in keep_joints]]
162 |   skel = skel.reshape(timestep, -1, 3)  # timestep x 22 x 3
163 |   num_joints = len(keep_joints)
164 | 
165 |   jjd, jjv = [], []
166 |   for t in range(timestep):
167 |     jjd_t, jjv_t = [], []
168 |     for i in range(num_joints):
169 |       for j in range(i, num_joints, 1):
170 |         # joint-joint distance
171 |         jjd_t.append(np.linalg.norm(skel[t, i] - skel[t, j]))
172 | 
173 |         # joint-joint vector
174 |         jjv_t.append(skel[t, i]-skel[t, j])
175 |     jjd.append(jjd_t)
176 |     jjv.append(jjv_t)
177 |   return np.array(jjd), np.array(jjv)
-------------------------------------------------------------------------------- /utils/visualize.py: --------------------------------------------------------------------------------
1 | # Copyright 2018 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | 
16 | """Visualization utilities for RGB, optical flow, and warped frames."""
17 | 
18 | import os
19 | import cv2
20 | import numpy as np
21 | import utils
22 | from utils import imgproc
23 | 
24 | 
25 | def visualize_rgb(images):
26 |   """Visualize RGB modality."""
27 |   images = utils.to_numpy(images)
28 | 
29 |   mean = np.array([0.485, 0.456, 0.406])
30 |   std = np.array([0.229, 0.224, 0.225])
31 |   images = np.moveaxis(images, -3, -1)
32 |   images = images*std+mean
33 |   images = np.clip(images*255, 0, 255)
34 |   images = images[..., ::-1].astype(np.uint8)
35 |   images = images[0, 0]  # subsample
36 | 
37 |   imgproc.save_avi('/home/luoa/research/rgb.avi', images)
38 | 
39 | 
40 | def visualize_oflow(images):
41 |   """Visualize optical flow modality."""
42 |   images = utils.to_numpy(images)
43 | 
44 |   images = np.moveaxis(images, -3, -1)
45 |   images = images[0, 0]  # subsample
46 | 
47 |   images = imgproc.proc_oflow(images)
48 |   imgproc.save_avi('/home/luoa/research/oflow.avi', images)
49 | 
50 | 
51 | def visualize_warp(rgb, oflow):
52 |   """Visualize warping an RGB frame to the next frame with optical flow."""
53 |   rgb = utils.to_numpy(rgb)
54 |   oflow = utils.to_numpy(oflow)
55 | 
56 |   mean = np.array([0.485, 0.456, 0.406])
57 |   std = np.array([0.229, 0.224, 0.225])
58 |   rgb = np.moveaxis(rgb, -3, -1)
59 |   rgb = rgb*std+mean
60 |   rgb = np.clip(rgb*255, 0, 255)
61 |   bgr = rgb[..., ::-1].astype(np.uint8)
62 |   bgr = bgr[0, 0]  # subsample
63 |   print(bgr.shape, np.amin(bgr), np.amax(bgr), np.mean(bgr),
64 |         np.mean(np.absolute(bgr)))
65 | 
66 |   oflow = np.moveaxis(oflow, -3, -1)
67 |   oflow = oflow[0, 0]  # subsample
68 |   print(oflow.shape, np.amin(oflow), np.amax(oflow), np.mean(oflow),
69 |         np.mean(np.absolute(oflow)))
70 | 
71 |   warp = imgproc.warp(bgr[4], bgr[5], oflow[4])
72 | 
73 |   root = '/home/luoa/research'
74 |   cv2.imwrite(os.path.join(root, 'bgr1.jpg'), bgr[4])
75 |   cv2.imwrite(os.path.join(root, 'bgr2.jpg'), bgr[5])
76 |   cv2.imwrite(os.path.join(root, 'warp.jpg'), warp)
77 | 
--------------------------------------------------------------------------------
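The proposal and ground-truth layout documented in the comments of `third_party/pku_mmd/evaluate.py` is easiest to see on a toy input. Below is a minimal, hypothetical sketch (not part of the repo) that calls `ap` directly; it assumes the repository root is on `PYTHONPATH` so the module imports as `third_party.pku_mmd.evaluate`, and the segment boundaries and video name are made up for illustration:

```python
# Toy check of the interpolated-AP routine in third_party/pku_mmd/evaluate.py.
# Each proposal and ground-truth entry follows the documented layout:
# (label, start, end, confidence, video_name).
from third_party.pku_mmd.evaluate import ap

ground = [
    [3, 10.0, 50.0, 1.0, 'video_001'],   # action instance of class 3
    [3, 80.0, 120.0, 1.0, 'video_001'],  # second instance, same video
]
proposals = [
    [3, 12.0, 48.0, 0.9, 'video_001'],   # temporal IoU ~0.9 with instance 1
    [3, 78.0, 118.0, 0.8, 'video_001'],  # temporal IoU ~0.9 with instance 2
]

# theta = 0.5 is the detection threshold used by process(); a proposal only
# counts as positive when its temporal IoU with a same-class, same-video
# ground truth is at least theta.
print(ap(proposals, 0.5, ground))  # both proposals match, so this prints 1.0
```

`process()` builds these per-video lists from the ground-truth files on disk, regroups them per action class, and averages `ap` over classes to report mAP_action, so feeding it real data changes only where the tuples come from, not their shape.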