├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── classification
│   ├── get_dataloader.py
│   ├── get_model.py
│   ├── model.py
│   └── run.py
├── data_pipeline
│   ├── __init__.py
│   ├── ntu_rgbd.py
│   ├── pku_mmd.py
│   └── ucf_101.py
├── detection
│   ├── evaluation
│   │   └── map.py
│   ├── get_dataloader.py
│   ├── get_model.py
│   ├── model.py
│   └── run.py
├── images
│   └── pull.png
├── nets
│   ├── get_distillation_kernel.py
│   ├── get_gru.py
│   └── get_tad.py
├── scripts
│   ├── download_models.sh
│   ├── test_ntu_rgbd.sh
│   ├── test_pku_mmd.sh
│   ├── train_ntu_rgbd.sh
│   ├── train_ntu_rgbd_distillation.sh
│   ├── train_pku_mmd.sh
│   └── train_pku_mmd_distillation.sh
├── third_party
│   ├── pku_mmd
│   │   ├── LICENSE
│   │   └── evaluate.py
│   ├── pytorch
│   │   ├── LICENSE
│   │   └── get_cnn.py
│   └── two_stream_pytorch
│       ├── LICENSE
│       └── video_transforms.py
└── utils
    ├── __init__.py
    ├── imgproc.py
    ├── logging.py
    ├── misc.py
    ├── skelproc.py
    └── visualize.py
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | We'd love to accept your patches and contributions to this project. There are
4 | just a few small guidelines you need to follow.
5 |
6 | ## Contributor License Agreement
7 |
8 | Contributions to this project must be accompanied by a Contributor License
9 | Agreement. You (or your employer) retain the copyright to your contribution;
10 | this simply gives us permission to use and redistribute your contributions as
11 | part of the project. Head over to <https://cla.developers.google.com/> to see
12 | your current agreements on file or to sign a new one.
13 |
14 | You generally only need to submit a CLA once, so if you've already submitted one
15 | (even if it was for a different project), you probably don't need to do it
16 | again.
17 |
18 | ## Code reviews
19 |
20 | All submissions, including submissions by project members, require review. We
21 | use GitHub pull requests for this purpose. Consult
22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
23 | information on using pull requests.
24 |
25 | ## Community Guidelines
26 |
27 | This project follows
28 | [Google's Open Source Community Guidelines](https://opensource.google.com/conduct/).
29 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright 2018 Google LLC
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | https://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
203 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Graph Distillation
2 |
3 |
4 |
5 | This is the code for the paper
6 | **[Graph Distillation for Action Detection with Privileged Modalities](https://arxiv.org/abs/1712.00108)**
7 | presented at [ECCV 2018](https://eccv2018.org/).
8 |
9 | *Please note that this is not an officially supported Google product.*
10 |
11 | In this work, we propose a method termed **graph distillation** that incorporates rich privileged information from a large-scale multi-modal dataset in the source domain, and improves learning in the target domain where training data and modalities are scarce.
12 |
13 | If you find this code useful in your research, please cite:
14 |
15 | ```
16 | @inproceedings{luo2018graph,
17 | title={Graph Distillation for Action Detection with Privileged Modalities},
18 | author={Luo, Zelun and Hsieh, Jun-Ting and Jiang, Lu and Niebles, Juan Carlos and Fei-Fei, Li},
19 | booktitle={ECCV},
20 | year={2018}
21 | }
22 | ```
23 |
24 | ## Setup
25 | All code was developed and tested on Ubuntu 16.04 with Python 3.6 and PyTorch 0.3.1.
26 |
27 |
28 |
29 |
30 | ## Pretrained Models
31 | You can download the pretrained models used in our paper by running the script:
32 |
33 | ```
34 | sh scripts/download_models.sh
35 | ```
36 |
37 | Alternatively, you can fetch them with the Google Cloud SDK:
38 |
39 | 1. Install Google Cloud SDK (https://cloud.google.com/sdk/install)
40 | 2. Copy the pretrained model using the following commands:
41 |
42 | ```
43 | gsutil -m cp -r gs://graph_distillation/ckpt .
44 | ```
45 |
46 |
47 | ## Running Models
48 | You can use the scripts in `scripts/` to train models on different modalities.
49 |
50 |
51 |
52 | ### Classification
53 | See `classification/run.py` for descriptions of the arguments.
54 |
55 | `scripts/train_ntu_rgbd.sh` trains a model for a single modality.
56 |
57 | `scripts/train_ntu_rgbd_distillation.sh` trains a model with graph distillation. The modality being trained is specified by the `xfer_to` argument, and the modalities to distill from are specified by the `modalities` argument.
58 |
59 | ### Detection
60 | See `detection/run.py` for descriptions of the arguments. Note that the `visual_encoder_ckpt_path` argument is the pretrained visual encoder checkpoint, which should come from training the classification models.
61 |
62 | `scripts/train_pku_mmd.sh` trains a model for a single modality.
63 |
64 | `scripts/train_pku_mmd_distillation.sh` trains a model with graph distillation. The modality being trained is specified by the `xfer_to` argument, and the modalities to distill from are specified by the `modalities` argument.
65 |
--------------------------------------------------------------------------------
/classification/get_dataloader.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Get data loader for classification."""
17 |
18 | import os
19 | from data_pipeline.ntu_rgbd import NTU_RGBD
20 | import third_party.two_stream_pytorch.video_transforms as vtransforms
21 | import torch.utils.data as data
22 | import torchvision.transforms as transforms
23 |
24 |
25 | CNN_MODALITIES = ['rgb', 'oflow', 'depth']
26 | GRU_MODALITIES = ['jjd', 'jjv', 'jld']
27 |
28 |
29 | def get_dataloader(opt):
30 | """Constructs dataset with transforms, and wrap it in a data loader."""
31 | idx_t = 0 if opt.split == 'train' else 1
32 |
33 | xforms = []
34 | for modality in opt.modalities:
35 | if opt.dset == 'ntu-rgbd':
36 | mean, std = NTU_RGBD.MEAN_STD[modality]
37 | else:
38 | raise NotImplementedError
39 |
40 | if opt.split == 'train' and (modality == 'rgb' or modality == 'depth'):
41 | xform = transforms.Compose([
42 | vtransforms.RandomSizedCrop(224),
43 | vtransforms.RandomHorizontalFlip(),
44 | vtransforms.ToTensor(),
45 | vtransforms.Normalize(mean, std)
46 | ])
47 | elif opt.split == 'train' and modality == 'oflow':
48 | # Special handling when flipping optical flow
49 | xform = transforms.Compose([
50 | vtransforms.RandomSizedCrop(224, True),
51 | vtransforms.RandomHorizontalFlip(True),
52 | vtransforms.ToTensor(),
53 | vtransforms.Normalize(mean, std)
54 | ])
55 | elif opt.split != 'train' and modality in CNN_MODALITIES:
56 | xform = transforms.Compose([
57 | vtransforms.Scale(256),
58 | vtransforms.CenterCrop(224),
59 | vtransforms.ToTensor(),
60 | vtransforms.Normalize(mean, std)
61 | ])
62 | elif modality in GRU_MODALITIES:
63 | xform = transforms.Compose([vtransforms.SkelNormalize(mean, std)])
64 | else:
65 | raise Exception
66 |
67 | xforms.append(xform)
68 |
69 | if opt.dset == 'ntu-rgbd':
70 | root = os.path.join(opt.dset_path, opt.dset)
71 | dset = NTU_RGBD(root, opt.split, 'cross-subject', opt.modalities,
72 | opt.n_samples[idx_t], opt.n_frames, opt.downsample, xforms,
73 | opt.subsample)
74 | else:
75 | raise NotImplementedError
76 |
77 | dataloader = data.DataLoader(
78 | dset,
79 | batch_size=opt.batch_sizes[idx_t],
80 | shuffle=(opt.split == 'train'),
81 | num_workers=opt.n_workers)
82 |
83 | return dataloader
84 |
--------------------------------------------------------------------------------
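A minimal usage sketch of `get_dataloader` (hypothetical: the option names mirror the argparse flags in `classification/run.py`; the paths and values below are placeholders and assume the NTU RGB-D data has been prepared under `<dset_path>/ntu-rgbd`):

```python
# Hypothetical usage sketch, not part of the original source. Option names
# mirror the flags in classification/run.py; values are illustrative only.
from argparse import Namespace

from classification.get_dataloader import get_dataloader

opt = Namespace(
    dset='ntu-rgbd',
    dset_path='/path/to/datasets',  # assumes <dset_path>/ntu-rgbd exists
    split='train',
    modalities=['rgb'],
    n_samples=[1, 5],     # clips per video: [train, test]
    n_frames=10,
    downsample=3,
    subsample=0,          # 0: use the full dataset
    batch_sizes=[64, 8],  # [train, test]
    n_workers=4)

loader = get_dataloader(opt)
for inputs, label in loader:
  # inputs is a list with one tensor per requested modality
  break
```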
/classification/get_model.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Get classification model."""
17 |
18 | from .model import GraphDistillation
19 | from .model import SingleStream
20 |
21 | ALL_MODALITIES = ['rgb', 'oflow', 'depth', 'jjd', 'jjv', 'jld']
22 |
23 |
24 | def get_model(opt):
25 | """Get model given the dataset and modalities."""
26 | if opt.dset == 'ntu-rgbd':
27 | n_classes = 60
28 | all_input_sizes = [-1, -1, -1, 276, 828, 836]
29 | all_n_channels = [3, 2, 1, -1, -1, -1]
30 | else:
31 | raise NotImplementedError
32 |
33 | n_channels = [all_n_channels[ALL_MODALITIES.index(m)] for m in opt.modalities]
34 | input_sizes = [
35 | all_input_sizes[ALL_MODALITIES.index(m)] for m in opt.modalities
36 | ]
37 |
38 | if len(opt.modalities) == 1:
39 | # Single stream
40 | model = SingleStream(opt.modalities, n_classes, opt.n_frames, n_channels,
41 | input_sizes, opt.hidden_size, opt.n_layers,
42 | opt.dropout, opt.lr, opt.lr_decay_rate, opt.ckpt_path)
43 | else:
44 | model = GraphDistillation(
45 | opt.modalities, n_classes, opt.n_frames, n_channels, input_sizes,
46 | opt.hidden_size, opt.n_layers, opt.dropout, opt.lr, opt.lr_decay_rate,
47 | opt.ckpt_path, opt.w_losses, opt.w_modalities, opt.metric, opt.xfer_to,
48 | opt.gd_size, opt.gd_reg)
49 |
50 | return model
51 |
--------------------------------------------------------------------------------
/classification/model.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Model to train classification."""
17 |
18 | from collections import OrderedDict
19 | import os
20 | from third_party.pytorch.get_cnn import *
21 | from nets.get_distillation_kernel import *
22 | from nets.get_gru import *
23 | from torch.autograd import Variable
24 | import torch.backends.cudnn as cudnn
25 | import torch.nn.functional as F
26 | import torch.optim as optim
27 | import utils
28 |
29 |
30 | CNN_MODALITIES = ['rgb', 'oflow', 'depth']
31 | GRU_MODALITIES = ['jjd', 'jjv', 'jld']
32 |
33 |
34 | class BaseModel:
35 | """Base class for the model."""
36 |
37 | def __init__(self, modalities, n_classes, n_frames, n_channels, input_sizes,
38 | hidden_size, n_layers, dropout, lr, lr_decay_rate, ckpt_path):
39 | super(BaseModel, self).__init__()
40 | cudnn.benchmark = True
41 |
42 | self.embeds = []
43 | for _, (modality, n_channels_m, input_size) in enumerate(
44 | zip(modalities, n_channels, input_sizes)):
45 | if modality in CNN_MODALITIES:
46 | self.embeds.append(
47 | nn.DataParallel(
48 | get_resnet(n_frames * n_channels_m, n_classes).cuda()))
49 | elif modality in GRU_MODALITIES:
50 | self.embeds.append(
51 | nn.DataParallel(
52 | get_gru(input_size, hidden_size, n_layers, dropout,
53 | n_classes).cuda()))
54 | else:
55 | raise NotImplementedError
56 |
57 | self.optimizer = None
58 | self.criterion_cls = nn.CrossEntropyLoss().cuda()
59 |
60 | self.modalities = modalities
61 | self.lr = lr
62 | self.lr_decay_rate = lr_decay_rate
63 | self.ckpt_path = ckpt_path
64 |
65 | def _forward(self, inputs):
66 | """Forward pass for all modalities. Return the representation and logits."""
67 | logits, reprs = [], []
68 |
69 | # Forward pass for all modalities
70 | for i, (input, embed, modality) in enumerate(
71 | zip(list(inputs), self.embeds, self.modalities)):
72 | if modality in CNN_MODALITIES:
73 | batch_size, n_samples, n_frames, n_channels, h, w = input.size()
74 | input = input.view(batch_size * n_samples, n_frames * n_channels, h, w)
75 | elif modality in GRU_MODALITIES:
76 | batch_size, n_samples, n_frames, input_size = input.size()
77 | input = input.view(batch_size * n_samples, n_frames, input_size)
78 | else:
79 | raise NotImplementedError
80 |
81 | logit, representation = embed(input)
82 | logit = logit.view(batch_size, n_samples, -1)
83 | representation = representation.view(batch_size, n_samples, -1)
84 |
85 | logits.append(logit.mean(1))
86 | reprs.append(representation.mean(1))
87 |
88 | logits = torch.stack(logits)
89 | reprs = torch.stack(reprs)
90 | return [logits, reprs]
91 |
92 | def _backward(self, results, label):
93 | raise NotImplementedError
94 |
95 | def train(self, inputs, label):
96 | """Train the model.
97 |
98 | Args:
99 | inputs: a list, each is batch_size x n_sample x n_frames x
100 | (n_channels x h x w) or (input_size).
101 | label: batch_size x n_samples.
102 | Returns:
103 | info: dictionary of results
104 | """
105 | for embed in self.embeds:
106 | embed.train()
107 |
108 | for i in range(len(inputs)):
109 | inputs[i] = Variable(inputs[i].cuda(), requires_grad=False)
110 | label = Variable(label.cuda(), requires_grad=False)
111 |
112 | results = self._forward(inputs)
113 | info_loss = self._backward(results, label)
114 | info_acc = self._get_acc(results[0], label)
115 |
116 | return OrderedDict(info_loss + info_acc)
117 |
118 | def test(self, inputs, label):
119 | """Test the model.
120 |
121 | Args:
122 | inputs: a list, each is batch_size x n_sample x n_frames x
123 | (n_channels x h x w) or (input_size).
124 | label: batch_size x n_samples.
125 | Returns:
126 | info_acc: dictionary of results
127 | """
128 | for embed in self.embeds:
129 | embed.eval()
130 |
131 | inputs = [Variable(x.cuda(), volatile=True) for x in inputs]
132 | label = Variable(label.cuda(), volatile=True)
133 |
134 | logits, _ = self._forward(inputs)
135 | info_acc = self._get_acc(logits, label)
136 |
137 | return OrderedDict(info_acc), logits
138 |
139 | def _get_acc(self, logits, label):
140 | info_acc = []
141 | for _, (logit, modality) in enumerate(zip(logits, self.modalities)):
142 | acc, _, label = utils.get_stats(logit, label)
143 | info_acc.append(('acc_{}'.format(modality), acc))
144 | return info_acc
145 |
146 | def lr_decay(self):
147 | lrs = []
148 | for param_group in self.optimizer.param_groups:
149 | param_group['lr'] *= self.lr_decay_rate
150 | lrs.append(param_group['lr'])
151 | return lrs
152 |
153 | def save(self, epoch):
154 | path = os.path.join(self.ckpt_path, 'embed_{}.pth'.format(epoch))
155 | torch.save(self.embeds[self.to_idx].state_dict(), path)
156 |
157 | def load(self, load_ckpt_paths, epoch=200):
158 | """Load trained models."""
159 | assert len(load_ckpt_paths) == len(self.embeds)
160 | for i, ckpt_path in enumerate(load_ckpt_paths):
161 | if len(ckpt_path) > 0:
162 | path = os.path.join(ckpt_path, 'embed_{}.pth'.format(epoch))
163 | self.embeds[i].load_state_dict(torch.load(path))
164 | utils.info('{}: ckpt {} loaded'.format(self.modalities[i], path))
165 | else:
166 | utils.info('{}: training from scratch'.format(self.modalities[i]))
167 |
168 |
169 | class SingleStream(BaseModel):
170 | """Model to train a single modality."""
171 |
172 | def __init__(self, *args, **kwargs):
173 | super(SingleStream, self).__init__(*args, **kwargs)
174 | assert len(self.embeds) == 1
175 | self.optimizer = optim.SGD(
176 | self.embeds[0].parameters(),
177 | lr=self.lr,
178 | momentum=0.9,
179 | weight_decay=5e-4)
180 | self.to_idx = 0
181 |
182 | def _backward(self, results, label):
183 | logits, _ = results
184 | logits = logits.view(*logits.size()[1:])
185 | loss = self.criterion_cls(logits, label)
186 | loss.backward()
187 | self.optimizer.step()
188 | self.optimizer.zero_grad()
189 |
190 | info_loss = [('loss', loss.data[0])]
191 | return info_loss
192 |
193 |
194 | class GraphDistillation(BaseModel):
195 | """Model to train with graph distillation.
196 |
197 | xfer_to is the modality to train.
198 | """
199 |
200 | def __init__(self, modalities, n_classes, n_frames, n_channels, input_sizes,
201 | hidden_size, n_layers, dropout, lr, lr_decay_rate, ckpt_path,
202 | w_losses, w_modalities, metric, xfer_to, gd_size, gd_reg):
203 | super(GraphDistillation, self).__init__( \
204 | modalities, n_classes, n_frames, n_channels, input_sizes,
205 | hidden_size, n_layers, dropout, lr, lr_decay_rate, ckpt_path)
206 |
207 | # Index of the modality to distill
208 | to_idx = self.modalities.index(xfer_to)
209 | from_idx = [x for x in range(len(self.modalities)) if x != to_idx]
210 | assert len(from_idx) >= 1
211 |
212 | # Prior
213 | w_modalities = [w_modalities[i] for i in from_idx
214 | ] # remove modality being transferred to
215 | gd_prior = utils.softmax(w_modalities, 0.25)
216 | # Distillation model
217 | self.distillation_kernel = get_distillation_kernel(
218 | n_classes, hidden_size, gd_size, to_idx, from_idx, gd_prior, gd_reg,
219 | w_losses, metric).cuda()
220 |
221 | params = list(self.embeds[to_idx].parameters()) + \
222 | list(self.distillation_kernel.parameters())
223 | self.optimizer = optim.SGD(params, lr=lr, momentum=0.9, weight_decay=5e-4)
224 |
225 | self.xfer_to = xfer_to
226 | self.to_idx = to_idx
227 | self.from_idx = from_idx
228 |
229 | def _forward(self, inputs):
230 | logits, reprs = super(GraphDistillation, self)._forward(inputs)
231 | # Get edge weights of the graph
232 | graph = self.distillation_kernel(logits, reprs)
233 | return logits, reprs, graph
234 |
235 | def _backward(self, results, label):
236 | logits, reprs, graph = results # graph: size = len(from_idx) x batch_size
237 | info_loss = []
238 |
239 | # Classification loss
240 | loss_cls = self.criterion_cls(logits[self.to_idx], label)
241 | # Graph distillation loss
242 | loss_reg, loss_logit, loss_repr = \
243 | self.distillation_kernel.distillation_loss(logits, reprs, graph)
244 |
245 | loss = loss_cls + loss_reg + loss_logit + loss_repr
246 | loss.backward()
247 | if self.xfer_to in GRU_MODALITIES:
248 | torch.nn.utils.clip_grad_norm(self.embeds[self.to_idx].parameters(), 5)
249 | self.optimizer.step()
250 | self.optimizer.zero_grad()
251 |
252 | info_loss = [('loss_cls', loss_cls.data[0]), ('loss_reg', loss_reg.data[0]),
253 | ('loss_logit', loss_logit.data[0]), ('loss_repr',
254 | loss_repr.data[0])]
255 | return info_loss
256 |
--------------------------------------------------------------------------------
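The distillation kernel itself lives in `nets/get_distillation_kernel.py`, which is not shown in this listing. As a rough, illustrative sketch only (not the actual kernel), the three terms returned by `distillation_loss` can be thought of along these lines, here with the edge weights fixed to the prior rather than predicted from logits and representations, and with an assumed squared-error metric:

```python
# Illustrative sketch ONLY; the real kernel learns the edge weights and
# supports several distance metrics (cosine, kl, l2, l1). Shapes follow
# BaseModel._forward: logits/reprs are (n_modalities, batch, dim).
import torch
import torch.nn.functional as F

def toy_distillation_loss(logits, reprs, to_idx, from_idx, prior, gd_reg):
  batch_size = logits.size(1)
  # One edge weight per (source modality, example); here just the prior
  # broadcast over the batch, standing in for the learned graph.
  graph = prior.unsqueeze(1).expand(len(from_idx), batch_size)

  loss_logit, loss_repr = 0., 0.
  for k, j in enumerate(from_idx):
    w = graph[k]  # (batch,)
    # Teachers treated as fixed here (detach is an assumption of this sketch).
    d_logit = ((logits[to_idx] - logits[j].detach()) ** 2).mean(dim=1)
    d_repr = ((reprs[to_idx] - reprs[j].detach()) ** 2).mean(dim=1)
    loss_logit += (w * d_logit).mean()
    loss_repr += (w * d_repr).mean()

  # Regularize the average edge distribution toward the prior (trivially
  # zero in this toy version, since the graph *is* the prior).
  loss_reg = gd_reg * F.mse_loss(graph.mean(dim=1), prior)
  return loss_reg, loss_logit, loss_repr
```

Note that in `GraphDistillation.__init__` only the target modality's encoder and the kernel are handed to the optimizer, so the source-modality encoders act as fixed teachers even though the combined loss is backpropagated.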
/classification/run.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Train and test classification."""
17 |
18 | import argparse
19 | import os
20 | from .get_dataloader import *
21 | from .get_model import *
22 | import numpy as np
23 | from sklearn.metrics import average_precision_score
24 | import utils
25 | import utils.logging as logging
26 |
27 | parser = argparse.ArgumentParser()
28 |
29 | # experimental settings
30 | parser.add_argument('--n_workers', type=int, default=24)
31 | parser.add_argument('--gpus', type=str, default='0')
32 | parser.add_argument('--split', type=str, choices=['train', 'test'])
33 |
34 | # ckpt and logging
35 | parser.add_argument('--ckpt_path', type=str, default='./ckpt',
36 | help='directory path that stores all checkpoints')
37 | parser.add_argument('--ckpt_name', type=str, default='ckpt')
38 | parser.add_argument('--pretrained_ckpt_name', type=str, default='ckpt',
39 | help='prefix of checkpoints used for graph distillation')
40 | parser.add_argument('--load_ckpt_path', type=str, default='',
41 | help='checkpoint path to load for testing/initialization')
42 | parser.add_argument('--load_epoch', type=int, default=200,
43 | help='Checkpoint epoch to load for testing.')
44 | parser.add_argument('--print_every', type=int, default=50)
45 | parser.add_argument('--save_every', type=int, default=50)
46 |
47 | # hyperparameters
48 | parser.add_argument('--batch_sizes', type=int, nargs='+', default=[64, 8],
49 | help='batch sizes: [train, test]')
50 | parser.add_argument('--n_epochs', type=int, default=200)
51 | parser.add_argument('--lr', type=float, default=1e-2)
52 | parser.add_argument('--lr_decay_at', type=int, nargs='+', default=[125, 175])
53 | parser.add_argument('--lr_decay_rate', type=float, default=0.1)
54 |
55 | # data pipeline
56 | parser.add_argument('--dset', type=str, default='ntu-rgbd')
57 | parser.add_argument('--dset_path', type=str,
58 | default=os.path.join(os.environ['HOME'], 'slowbro'))
59 | parser.add_argument('--modalities', type=str, nargs='+',
60 | choices=['rgb', 'oflow', 'depth', 'jjd', 'jjv', 'jld'])
61 | parser.add_argument('--n_samples', type=int, nargs='+', default=[1, 5],
62 | help='Number of samples clips per video: [train, test]')
63 | parser.add_argument('--step_size', type=int, default=10,
64 | help='step size between samples (after downsample)')
65 | parser.add_argument('--n_frames', type=int, default=10,
66 | help='num frames per sample')
67 | parser.add_argument('--downsample', type=int, default=3,
68 | help='fps /= downsample')
69 | parser.add_argument('--subsample', type=int, default=33,
70 | help='subsample the dataset. 0: False, >0: '
71 | 'number of examples per class')
72 |
73 | # GRU
74 | parser.add_argument('--dropout', type=float, default=0.5)
75 | parser.add_argument('--hidden_size', type=int, default=512)
76 | parser.add_argument('--n_layers', type=int, default=3)
77 |
78 | # Graph Distillation parameters
79 | parser.add_argument('--metric', type=str, default='cosine',
80 | choices=['cosine', 'kl', 'l2', 'l1'],
81 | help='distance metric for distillation loss')
82 | parser.add_argument('--w_losses', type=float, nargs='+', default=[10, 1],
83 | help='weights for losses: [logit, repr]')
84 | parser.add_argument('--w_modalities', type=float, nargs='+',
85 | default=[1, 1, 1, 1, 1, 1],
86 | help='modality prior')
87 | parser.add_argument('--xfer_to', type=str, default='',
88 | help='modality to train with graph distillation')
89 | parser.add_argument('--gd_size', type=int, default=32,
90 | help='hidden size of graph distillation')
91 | parser.add_argument('--gd_reg', type=float, default=10,
92 | help='regularization for graph distillation')
93 |
94 |
95 | def single_stream(opt):
96 | """Train a single modality from scratch."""
97 | # Checkpoint path example: ckpt_path/ntu-rgbd/rgb/ckpt
98 | opt.ckpt_path = os.path.join(opt.ckpt_path, opt.dset, opt.modalities[0],
99 | opt.ckpt_name)
100 | opt.load_ckpt_paths = [opt.load_ckpt_path]
101 | os.makedirs(opt.ckpt_path, exist_ok=True)
102 |
103 | # Data loader and model
104 | dataloader = get_dataloader(opt)
105 | model = get_model(opt)
106 | if opt.split == 'train':
107 | train(opt, model, dataloader)
108 | else:
109 | test(opt, model, dataloader)
110 |
111 |
112 | def multi_stream(opt):
113 | """Train a modality with graph distillation from other modalities.
114 |
115 | The modality is specified by opt.xfer_to
116 | """
117 | assert opt.xfer_to in opt.modalities, 'opt.xfer_to must be in opt.modalities'
118 | # Checkpoints to load
119 | opt.load_ckpt_paths = []
120 | for m in opt.modalities:
121 | if m != opt.xfer_to:
122 | # Checkpoint from single_stream
123 | path = os.path.join(opt.ckpt_path, opt.dset, m, opt.pretrained_ckpt_name)
124 | assert os.path.exists(path), '{} checkpoint does not exist.'.format(path)
125 | opt.load_ckpt_paths.append(path)
126 | else:
127 | opt.load_ckpt_paths.append(opt.load_ckpt_path)
128 |
129 | # Checkpoint path example: ckpt_path/ntu-rgbd/xfer_rgb/ckpt_rgb_depth
130 | opt.ckpt_path = os.path.join(
131 | opt.ckpt_path, opt.dset, 'xfer_{}'.format(opt.xfer_to), '{}_{}'.format(
132 | opt.ckpt_name, '_'.join([m for m in opt.modalities])))
133 | os.makedirs(opt.ckpt_path, exist_ok=True)
134 |
135 | # Data loader and model
136 | dataloader = get_dataloader(opt)
137 | model = get_model(opt)
138 | train(opt, model, dataloader)
139 |
140 |
141 | def train(opt, model, dataloader):
142 | """Train the model."""
143 | # Logging
144 | logger = logging.Logger(opt.ckpt_path, opt.split)
145 | stats = logging.Statistics(opt.ckpt_path, opt.split)
146 | logger.log(opt)
147 |
148 | model.load(opt.load_ckpt_paths, opt.load_epoch)
149 | for epoch in range(1, opt.n_epochs + 1):
150 | for step, data in enumerate(dataloader, 1):
151 | ret = model.train(*data)
152 | update = stats.update(data[-1].size(0), ret)
153 | if utils.is_due(step, opt.print_every):
154 | utils.info('epoch {}/{}, step {}/{}: {}'.format(
155 | epoch, opt.n_epochs, step, len(dataloader), update))
156 |
157 | logger.log('[Summary] epoch {}/{}: {}'.format(epoch, opt.n_epochs,
158 | stats.summarize()))
159 |
160 | if utils.is_due(epoch, opt.n_epochs, opt.save_every):
161 | model.save(epoch)
162 | logger.log('***** saved *****')
163 |
164 | if utils.is_due(epoch, opt.lr_decay_at):
165 | lrs = model.lr_decay()
166 | logger.log('***** lr decay *****: {}'.format(lrs))
167 |
168 |
169 | def test(opt, model, dataloader):
170 | """Test the model."""
171 | # Logging
172 | logger = logging.Logger(opt.load_ckpt_path, opt.split)
173 | stats = logging.Statistics(opt.ckpt_path, opt.split)
174 | logger.log(opt)
175 |
176 | logits, labels = [], []
177 | model.load(opt.load_ckpt_paths, opt.load_epoch)
178 | for step, data in enumerate(dataloader, 1):
179 | inputs, label = data
180 | info_acc, logit = model.test(inputs, label)
181 | logits.append(utils.to_numpy(logit.squeeze(0)))
182 | labels.append(utils.to_numpy(label))
183 | update = stats.update(label.size(0), info_acc)
184 | if utils.is_due(step, opt.print_every):
185 | utils.info('step {}/{}: {}'.format(step, len(dataloader), update))
186 |
187 | logits = np.concatenate(logits, axis=0)
188 | length, n_classes = logits.shape
189 | labels = np.concatenate(labels)
190 | scores = utils.softmax(logits, axis=1)
191 |
192 | # Accuracy
193 | preds = np.argmax(scores, axis=1)
194 | acc = np.sum(preds == labels) / length
195 | # Average precision
196 | y_true = np.zeros((length, n_classes))
197 | y_true[np.arange(length), labels] = 1
198 | aps = average_precision_score(y_true, scores, average=None)
199 | aps = list(filter(lambda x: not np.isnan(x), aps))
200 | mAP = np.mean(aps)
201 |
202 | logger.log('[Summary]: {}'.format(stats.summarize()))
203 | logger.log('Acc: {}, mAP: {}'.format(acc, mAP))
204 |
205 |
206 | if __name__ == '__main__':
207 | opt = parser.parse_args()
208 | os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpus
209 |
210 | if opt.split == 'test':
211 | assert len(opt.modalities) == 1, 'specify only 1 modality for testing'
212 | assert len(opt.load_ckpt_path) > 0, 'specify load_ckpt_path for testing'
213 |
214 | if len(opt.modalities) == 1:
215 | single_stream(opt)
216 | else:
217 | multi_stream(opt)
218 |
--------------------------------------------------------------------------------
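The accuracy and mAP computation at the end of `test` can be sanity-checked on toy data. A self-contained rerun (the local `softmax` stands in for `utils.softmax`, which is not shown in this listing):

```python
# Self-contained check of the accuracy/mAP computation in test().
import numpy as np
from sklearn.metrics import average_precision_score

def softmax(x, axis=1):
  e = np.exp(x - x.max(axis=axis, keepdims=True))
  return e / e.sum(axis=axis, keepdims=True)

logits = np.array([[2.0, 0.5, -1.0],
                   [0.1, 1.5, 0.3],
                   [0.0, 0.2, 2.2]])  # 3 examples, 3 classes
labels = np.array([0, 1, 2])

scores = softmax(logits, axis=1)
preds = np.argmax(scores, axis=1)
acc = np.mean(preds == labels)  # 1.0 on this toy data

y_true = np.zeros_like(scores)
y_true[np.arange(len(labels)), labels] = 1
aps = average_precision_score(y_true, scores, average=None)  # per-class AP
mAP = np.mean([ap for ap in aps if not np.isnan(ap)])
print(acc, mAP)
```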
/data_pipeline/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/graph_distillation/1a7ce5125098e7df869f08e15e2d6d8bb3189382/data_pipeline/__init__.py
--------------------------------------------------------------------------------
/data_pipeline/ntu_rgbd.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """NTU RGB-D dataset."""
17 |
18 |
19 | import glob
20 | import os
21 | import numpy as np
22 | import torch.utils.data as data
23 |
24 | import utils
25 | import utils.imgproc as imgproc
26 |
27 |
28 | rgb_folder_name = 'rgb'
29 | rgb_pattern = 'RGB-%08d.jpg'
30 | oflow_folder_name = 'oflow'
31 | oflow_pattern = 'OFlow-%08d.jpg'
32 | depth_folder_name = 'depth'
33 | depth_pattern = 'Depth-%08d.png'
34 | skel_folder_name = 'skeleton'
35 | feat_folder_name = 'feature'
36 | ALL_MODALITIES = ['rgb', 'oflow', 'depth', 'jjd', 'jjv', 'jld']
37 | train_cross_subject = [
38 | 1, 2, 4, 5, 8, 9, 13, 14, 15, 16, 17, 18, 19, 25, 27, 28, 31, 34, 35, 38
39 | ]
40 | train_cross_view = [2, 3]
41 | n_classes = 60
42 |
43 |
44 | def make_dataset(root, has_skel, evaluation, split, subsample):
45 | """Returns a list of (video_name, class)."""
46 | if has_skel or split == 'test' or subsample:
47 | vid_names = [os.path.splitext(n)[0] \
48 | for n in glob.glob1(os.path.join(root, skel_folder_name, 'jjd')
49 | , '*.npy')]
50 | else:
51 | vid_names = os.listdir(os.path.join(root, rgb_folder_name))
52 | vid_names = sorted(vid_names)
53 |
54 | if evaluation == 'cross-subject' and split != 'test':
55 | vid_names = [n for n in vid_names if int(n[9:12]) in train_cross_subject]
56 | elif evaluation == 'cross-subject' and split == 'test':
57 | vid_names = [
58 | n for n in vid_names if int(n[9:12]) not in train_cross_subject
59 | ]
60 | elif evaluation == 'cross-view' and split != 'test':
61 | vid_names = [n for n in vid_names if int(n[5:8]) in train_cross_view]
62 | elif evaluation == 'cross-view' and split == 'test':
63 | vid_names = [n for n in vid_names if int(n[5:8]) not in train_cross_view]
64 | else:
65 | raise NotImplementedError
66 |
67 | if subsample:
68 | labels = np.array([int(n[-3:])-1 for n in vid_names])
69 | vid_names_subsample = []
70 | for i in range(n_classes):
71 | keep = np.where(labels == i)[0]
72 | keep = keep[np.linspace(0, len(keep)-1, subsample).astype(int)]
73 | vid_names_subsample += [vid_names[j] for j in keep]
74 | vid_names = vid_names_subsample
75 |
76 | elif has_skel and split == 'train' and not subsample:
77 | labels = np.array([int(n[-3:])-1 for n in vid_names])
78 | vid_names_add = []
79 | class_54 = np.where(labels == 54)[0].tolist()
80 | class_58 = np.where(labels == 58)[0].tolist()
81 | class_59 = np.where(labels == 59)[0].tolist()
82 |
83 | # deterministic oversampling for consistency
84 | for i in class_54[::7]:
85 | vid_names_add.append(vid_names[i])
86 | for i in class_58+class_58[::3]:
87 | vid_names_add.append(vid_names[i])
88 | for i in class_59+class_59[::2]+class_59[1::3]:
89 | vid_names_add.append(vid_names[i])
90 | vid_names += vid_names_add
91 |
92 | dataset = [(n, int(n[-3:]) - 1) for n in vid_names]
93 |
94 | utils.info('NTU-RGBD: {}, {}, {} videos'.format(split, evaluation,
95 | len(dataset)))
96 | return dataset
97 |
98 |
99 | def rgb_loader(root, vid_name, frame_ids):
100 | """Loads the RGB data."""
101 | vid = []
102 | for frame_ids_s in frame_ids:
103 | vid_s = []
104 | for frame_id in frame_ids_s:
105 | path = os.path.join(root, rgb_folder_name, vid_name,
106 | rgb_pattern % (frame_id + 1))
107 | img = imgproc.imread_rgb('ntu-rgbd', path)
108 | vid_s.append(img)
109 | vid.append(vid_s)
110 | return np.array(vid)
111 |
112 |
113 | def oflow_loader(root, vid_name, frame_ids):
114 | """Loads the flow data."""
115 | vid = []
116 | for frame_ids_s in frame_ids:
117 | vid_s = []
118 | for frame_id in frame_ids_s:
119 | path = os.path.join(root, oflow_folder_name, vid_name,
120 | oflow_pattern % (frame_id + 1))
121 | img = imgproc.imread_oflow('ntu-rgbd', path)
122 | vid_s.append(img)
123 | vid.append(vid_s)
124 | return np.array(vid)
125 |
126 |
127 | def depth_loader(root, vid_name, frame_ids):
128 | """Loads the depth data."""
129 | vid = []
130 | for frame_ids_s in frame_ids:
131 | vid_s = []
132 | for frame_id in frame_ids_s:
133 | path = os.path.join(root, depth_folder_name, vid_name,
134 | depth_pattern % (frame_id + 1))
135 | img = imgproc.imread_depth('ntu-rgbd', path)
136 | vid_s.append(img)
137 | vid.append(vid_s)
138 | return np.array(vid)
139 |
140 |
141 | def jjd_loader(root, vid_name, frame_ids):
142 | """Loads the skeleton data JJD."""
143 | path = os.path.join(root, skel_folder_name, 'jjd', vid_name + '.npy')
144 | skel = np.load(path).astype(np.float32)
145 | skel = skel[frame_ids]
146 | return skel
147 |
148 |
149 | def jjv_loader(root, vid_name, frame_ids):
150 | """Loads the skeleton data JJV."""
151 | path = os.path.join(root, skel_folder_name, 'jjv', vid_name + '.npy')
152 | skel = np.load(path).astype(np.float32)
153 | skel = skel[frame_ids]
154 | return skel
155 |
156 |
157 | def jld_loader(root, vid_name, frame_ids):
158 | """Loads the skeleton data JLD."""
159 | path = os.path.join(root, skel_folder_name, 'jld', vid_name + '.npy')
160 | skel = np.load(path).astype(np.float32)
161 | skel = skel[frame_ids]
162 | return skel
163 |
164 |
165 | class NTU_RGBD(data.Dataset):
166 | """Class for NTU RGBD Dataset"""
167 | MEAN_STD = {
168 | 'rgb': ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
169 | 'oflow': (0.5, 1. / 255),
170 | 'depth': (4084.1213735 / (255 * 256), 1008.31271366 / (255 * 256)),
171 | 'jjd': (0.53968146, 0.32319776),
172 | 'jjv': (0, 0.35953656),
173 | 'jld': (0.15982792, 0.12776225)
174 | }
175 |
176 | def __init__(self,
177 | root,
178 | split,
179 | evaluation,
180 | modalities,
181 | n_samples,
182 | n_frames,
183 | downsample,
184 | transforms=None,
185 | subsample=0):
186 | """NTU RGBD dataset.
187 |
188 | Args:
189 | root: dataset root
190 | split: train to randomly select n_samples samples; test to uniformly
191 | select n_samples spanning the whole video
192 | evaluation: one of ['cross-subject', 'cross-view']
193 | modalities: subset of ['rgb', 'oflow', 'depth', 'jjd', 'jjv', 'jld']
194 | n_samples: number of samples from the video
195 | n_frames: number of frames per sample
196 | downsample: fps /= downsample
197 | transforms: transformations to apply to data
198 | subsample: number of samples per class. 0 if using full dataset.
199 | """
200 | modalities = utils.unsqueeze(modalities)
201 | transforms = utils.unsqueeze(transforms)
202 |
203 | # Loader functions
204 | loaders = {
205 | 'rgb': rgb_loader,
206 | 'oflow': oflow_loader,
207 | 'depth': depth_loader,
208 | 'jjd': jjd_loader,
209 | 'jjv': jjv_loader,
210 | 'jld': jld_loader
211 | }
212 | has_skel = any([m in ALL_MODALITIES[3:] for m in modalities])
213 | dataset = make_dataset(root, has_skel, evaluation, split, subsample)
214 |
215 | self.root = root
216 | self.split = split
217 | self.modalities = modalities
218 | self.n_samples = n_samples
219 | self.n_frames = n_frames
220 | self.downsample = downsample
221 | self.transforms = transforms
222 | self.loaders = loaders
223 | self.dataset = dataset
224 |
225 | def __getitem__(self, idx):
226 | vid_name, label = self.dataset[idx]
227 | # -1 because len(oflow) = len(rgb)-1
228 | length = len(
229 | glob.glob(
230 | os.path.join(self.root, rgb_folder_name, vid_name,
231 | '*.' + rgb_pattern.split('.')[1]))) - 1
232 |
233 | length_ds = length // self.downsample
234 | if length_ds < self.n_frames:
235 | frame_ids_s = np.arange(0, length_ds, 1) # arange: exclusive
236 | frame_ids_s = np.concatenate(
237 | (frame_ids_s,
238 | np.array([frame_ids_s[-1]] * (self.n_frames - length_ds))))
239 | frame_ids = np.repeat(
240 | frame_ids_s[np.newaxis, :], self.n_samples,
241 | axis=0).astype(int) * self.downsample
242 | else:
243 | if self.split == 'train': # randomly select n_samples samples
244 | starts = np.random.randint(0, length_ds - self.n_frames + 1,
245 | self.n_samples) # randint: exclusive
246 | # uniformly select n_samples spanning the whole video
247 | elif self.split == 'val' or self.split == 'test':
248 | starts = np.linspace(
249 | 0, length_ds - self.n_frames, self.n_samples,
250 | dtype=int) # linspace: inclusive
251 | else:
252 | starts = np.arange(0,
253 | length_ds - self.n_frames + 1) # arange: exclusive
254 |
255 | frame_ids = []
256 | for start in starts:
257 | frame_ids_s = np.arange(start, start + self.n_frames,
258 | 1) * self.downsample # arange: exclusive
259 | frame_ids.append(frame_ids_s)
260 | frame_ids = np.stack(frame_ids)
261 |
262 | # load raw data
263 | inputs = []
264 | for modality in self.modalities:
265 | vid = self.loaders[modality](self.root, vid_name, frame_ids)
266 | inputs.append(vid)
267 |
268 | # transform
269 | if self.transforms is not None:
270 | for i in range(len(self.transforms)):
271 | if self.transforms[i] is not None:
272 | inputs[i] = self.transforms[i](inputs[i])
273 |
274 | return inputs, label
275 |
276 | def __len__(self):
277 | return len(self.dataset)
278 |
--------------------------------------------------------------------------------
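The clip sampling in `NTU_RGBD.__getitem__` is easiest to see in isolation: training draws random clip starts, testing spaces them uniformly across the video, and both are then expanded into per-clip frame ids at the original frame rate. A toy rerun with made-up values:

```python
# Toy rerun of the clip sampling in NTU_RGBD.__getitem__; all values made up.
import numpy as np

length_ds, n_frames, n_samples, downsample = 40, 10, 3, 3

# train: random starts (randint upper bound is exclusive)
train_starts = np.random.randint(0, length_ds - n_frames + 1, n_samples)
# test: uniformly spaced starts spanning the video (linspace is inclusive)
test_starts = np.linspace(0, length_ds - n_frames, n_samples, dtype=int)

frame_ids = np.stack(
    [np.arange(s, s + n_frames) for s in test_starts]) * downsample
print(frame_ids.shape)  # (n_samples, n_frames), ids at the original fps
```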
/data_pipeline/pku_mmd.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """PKU-MMD dataset."""
17 |
18 | import glob
19 | import numpy as np
20 | import os
21 | import random
22 | import torch.utils.data as data
23 |
24 | import utils
25 | import utils.imgproc as imgproc
26 |
27 | rgb_folder_name = 'rgb'
28 | rgb_pattern = 'RGB-%08d.jpg'
29 | oflow_folder_name = 'oflow'
30 | oflow_pattern = 'OFlow-%08d.jpg'
31 | depth_folder_name = 'depth'
32 | depth_pattern = 'Depth-%08d.png'
33 | skel_folder_name = 'skeleton'
34 | feat_folder_name = 'feature'
35 | ALL_MODALITIES = ['rgb', 'oflow', 'depth', 'jjd', 'jjv', 'jld']
36 |
37 |
38 | def make_dataset(root, evaluation, mode, folder_name=rgb_folder_name, subsample_rate=0):
39 | """
40 | mode: train (0), test (1)
41 | subsample_rate: Rate of subsampling. 0 if using full dataset.
42 | """
43 | vid_names = os.listdir(os.path.join(root, folder_name))
44 | f = open(os.path.join(root, 'split', evaluation + '.txt')).read().split('\n')
45 | idx = 1 if mode == 0 else 3
46 | vid_names_split = f[idx].split(', ')[:-1]
47 | vid_names = sorted(list(set(vid_names).intersection(set(vid_names_split))))
48 |
49 | if mode == 0 and subsample_rate:
50 | # Subsample for training
51 | vid_names = vid_names[::subsample_rate]
52 |
53 | dataset = []
54 | for vid_name in vid_names:
55 | label = np.loadtxt(
56 | os.path.join(root, 'label', vid_name + '.txt'),
57 | delimiter=',',
58 | dtype=int)
59 | dataset.append((vid_name, label))
60 |
61 | utils.info('PKU-MMD: {}, {} videos'.format(evaluation, len(dataset)))
62 | return dataset
63 |
64 |
65 | def rgb_loader(root, vid_name, frame_ids):
66 | vid = []
67 | for frame_ids_s in frame_ids:
68 | vid_s = []
69 | for frame_id in frame_ids_s:
70 | path = os.path.join(root, rgb_folder_name, vid_name,
71 | rgb_pattern % (frame_id + 1))
72 | img = imgproc.imread_rgb('pku-mmd', path)
73 | vid_s.append(img)
74 | vid.append(vid_s)
75 | return np.array(vid)
76 |
77 |
78 | def oflow_loader(root, vid_name, frame_ids):
79 | vid = []
80 | for frame_ids_s in frame_ids:
81 | vid_s = []
82 | for frame_id in frame_ids_s:
83 | path = os.path.join(root, oflow_folder_name, vid_name,
84 | oflow_pattern % (frame_id + 1))
85 | img = imgproc.imread_oflow('pku-mmd', path)
86 | vid_s.append(img)
87 | vid.append(vid_s)
88 | return np.array(vid)
89 |
90 |
91 | def depth_loader(root, vid_name, frame_ids):
92 | vid = []
93 | for frame_ids_s in frame_ids:
94 | vid_s = []
95 | for frame_id in frame_ids_s:
96 | path = os.path.join(root, depth_folder_name, vid_name,
97 | depth_pattern % (frame_id + 1))
98 | img = imgproc.imread_depth('pku-mmd', path)
99 | vid_s.append(img)
100 | vid.append(vid_s)
101 | return np.array(vid)
102 |
103 |
104 | def jjd_loader(root, vid_name, frame_ids):
105 | path = os.path.join(root, skel_folder_name, 'jjd', vid_name + '.npy')
106 | skel = np.load(path).astype(np.float32)
107 | skel = skel[frame_ids]
108 | return skel
109 |
110 |
111 | def jjv_loader(root, vid_name, frame_ids):
112 | path = os.path.join(root, skel_folder_name, 'jjv', vid_name + '.npy')
113 | skel = np.load(path).astype(np.float32)
114 | skel = skel[frame_ids]
115 | return skel
116 |
117 |
118 | def jld_loader(root, vid_name, frame_ids):
119 | path = os.path.join(root, skel_folder_name, 'jld', vid_name + '.npy')
120 | skel = np.load(path).astype(np.float32)
121 | skel = skel[frame_ids]
122 | return skel
123 |
124 |
125 | def get_overlap(a, b):
126 | return max(0, min(a[1], b[1]) - max(a[0], b[0]))
127 |
128 |
129 | class PKU_MMD(data.Dataset):
130 | MEAN_STD = {
131 | 'rgb': ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
132 | 'oflow': (0.5, 1. / 255),
133 | 'depth': (4084.1213735 / (255 * 256), 1008.31271366 / (255 * 256)),
134 | 'jjd': (0.53968146, 0.32319776),
135 | 'jjv': (0, 0.35953656),
136 | 'jld': (0.15982792, 0.12776225)
137 | }
138 |
139 | def __init__(self,
140 | root,
141 | mode,
142 | evaluation,
143 | modalities,
144 | step_size,
145 | n_frames,
146 | downsample,
147 | timestep,
148 | transforms=None,
149 | subsample_rate=0):
150 | """PKU_MMD Constructor.
151 |
152 | Constructs the PKU-MMD dataset.
153 |
154 | Args:
155 | root: dataset root
156 | mode: train (0), test (1)
157 | evaluation: one of ['cross-subject', 'cross-view']
158 | modalities: subset of ['rgb', 'oflow', 'depth', 'jjd', 'jjv', 'jld']
159 | step_size: step size between clips
160 | n_frames: number of frames per clip
161 | downsample: fps /= downsample
162 | timestep: number of clips in a sequence.
163 | transforms: transformations to apply to each modality
164 | subsample_rate: rate of subsampling. 0 if using the full dataset.
165 | """
166 | modalities = utils.unsqueeze(modalities)
167 | transforms = utils.unsqueeze(transforms)
168 |
169 | loaders = {
170 | 'rgb': rgb_loader,
171 | 'oflow': oflow_loader,
172 | 'depth': depth_loader,
173 | 'jjd': jjd_loader,
174 | 'jjv': jjv_loader,
175 | 'jld': jld_loader
176 | }
177 | self.loaders = loaders
178 | self.dataset = make_dataset(
179 | root, evaluation, mode, subsample_rate=subsample_rate)
180 |
181 | self.root = root
182 | self.modalities = modalities
183 |
184 | self.step_size = step_size
185 | self.n_frames = n_frames
186 | self.downsample = downsample
187 | self.timestep = timestep
188 | self.all = mode != 0 # True if test mode, return the entire video
189 | self.transforms = transforms
190 |
191 | def __getitem__(self, idx):
192 | vid_name, label = self.dataset[idx]
193 | # label: action_id, start_frame, end_frame, confidence
194 | # -1 because len(oflow) = len(rgb)-1
195 | length = len(
196 | glob.glob(
197 | os.path.join(self.root, rgb_folder_name, vid_name,
198 | '*.' + rgb_pattern.split('.')[1]))) - 1
199 | length_ds = length // self.downsample
200 |
201 | if self.all:
202 | # Return entire video
203 | starts = np.arange(0, length_ds - self.n_frames + 1,
204 | self.step_size) # arange: exclusive
205 | else:
206 | start = random.randint(
207 | 0, length_ds - ((self.timestep - 1) * self.step_size + self.n_frames))
208 | starts = [start + i * self.step_size for i in range(self.timestep)
209 | ] # randint: inclusive
210 |
211 | frame_ids = []
212 | for start in starts:
213 | frame_ids_s = np.arange(start, start + self.n_frames,
214 | 1) * self.downsample # arange: exclusive
215 | frame_ids.append(frame_ids_s)
216 | frame_ids = np.stack(frame_ids)
217 |
218 | targets = []
219 | for frame_ids_s in frame_ids:
220 | target = 0
221 | max_ratio = 0.5  # an action must cover more than half the clip to label it
222 | for action_id, start_frame, end_frame, _ in label:
223 | overlap = get_overlap([frame_ids_s[0], frame_ids_s[-1] - 1],
224 | [start_frame, end_frame - 1])
225 | ratio = overlap / (frame_ids_s[-1] - frame_ids_s[0])
226 | if ratio > max_ratio:
227 | target = int(action_id)
228 | targets.append(target)
229 | targets = np.stack(targets)
230 |
231 | # load raw data
232 | inputs = []
233 | for modality in self.modalities:
234 | vid = self.loaders[modality](self.root, vid_name, frame_ids)
235 | inputs.append(vid)
236 |
237 | # transform
238 | if self.transforms is not None:
239 | for i, transform in enumerate(self.transforms):
240 | if transform is not None:
241 | inputs[i] = transform(inputs[i])
242 |
243 | return inputs, targets, vid_name
244 |
245 | def __len__(self):
246 | return len(self.dataset)
247 |
--------------------------------------------------------------------------------
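A minimal sketch of the clip-labeling rule in PKU_MMD.__getitem__ above, with hypothetical frame ranges: a clip keeps the background label 0 unless an annotated action covers more than half of it.

import numpy as np

def get_overlap(a, b):
  return max(0, min(a[1], b[1]) - max(a[0], b[0]))

frame_ids_s = np.arange(30, 40)                 # one clip: frames 30..39
label = [(7, 25, 36, 1.0), (12, 38, 60, 1.0)]   # (action_id, start, end, conf)

target = 0
max_ratio = 0.5
for action_id, start_frame, end_frame, _ in label:
  overlap = get_overlap([frame_ids_s[0], frame_ids_s[-1] - 1],
                        [start_frame, end_frame - 1])
  ratio = overlap / (frame_ids_s[-1] - frame_ids_s[0])
  if ratio > max_ratio:
    target = int(action_id)

print(target)  # 7: action 7 overlaps 5/9 of the clip, action 12 not at all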
/data_pipeline/ucf_101.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """TODO: One-sentence doc string."""
17 |
18 | import glob
19 | import os
20 |
21 | import numpy as np
22 | import torch.utils.data as data
23 |
24 | import utils
25 | import utils.imgproc as imgproc
26 |
27 | rgb_folder_name = 'jpegs_256'
28 | rgb_pattern = 'frame%06d.jpg'
29 | oflow_folder_name = 'tvl1_flow'
30 | oflow_pattern = 'frame%06d.jpg'
31 | train_split_subpath = 'ucfTrainTestlist/trainlist01.txt'
32 | test_split_subpath = 'ucfTrainTestlist/testlist01.txt'
33 | ALL_MODALITIES = ['rgb', 'oflow']
34 |
35 |
36 | def find_classes(root):
37 | rgb_folder_path = os.path.join(root, rgb_folder_name)
38 | classes = [
39 | n.split('_')[1]
40 | for n in os.listdir(rgb_folder_path)
41 | if os.path.isdir(os.path.join(rgb_folder_path, n))
42 | ]
43 | classes = sorted(list(set(classes)))
44 | class_to_idx = {classes[i]: i for i in range(len(classes))}
45 |
46 | return classes, class_to_idx
47 |
48 |
49 | def make_dataset(root, class_to_idx, train):
50 | dataset = []
51 | split_subpath = train_split_subpath if train else test_split_subpath
52 | split_path = os.path.join(root, split_subpath)
53 |
54 | with open(split_path) as split_file:
55 | split_lines = split_file.readlines()
56 |
57 | for line in split_lines:
58 | vid_name = line.split()[0].split('/')[1].replace('.avi', '')
59 | class_name = vid_name.split('_')[1]
60 | item = (vid_name, class_to_idx[class_name])
61 | dataset.append(item)
62 |
63 | return dataset
64 |
65 |
66 | def rgb_loader(root, vid_name, frame_id):
67 | rgb_path = os.path.join(root, rgb_folder_name, vid_name,
68 | rgb_pattern % frame_id)
69 | return imgproc.imread_rgb('ucf-101', rgb_path)
70 |
71 |
72 | def oflow_loader(root, vid_name, frame_id):
73 | oflow_path_u = os.path.join(root, oflow_folder_name, 'u', vid_name,
74 | oflow_pattern % frame_id)
75 | oflow_path_v = os.path.join(root, oflow_folder_name, 'v', vid_name,
76 | oflow_pattern % frame_id)
77 | return imgproc.imread_oflow('ucf-101', oflow_path_u, oflow_path_v)
78 |
79 |
80 | class UCF_101(data.Dataset):
81 |
82 | def __init__(self,
83 | root,
84 | train,
85 | modalities,
86 | n_samples,
87 | n_frames,
88 | transforms=None,
89 | target_transform=None):
90 | classes, class_to_idx = find_classes(root)
91 | dataset = make_dataset(root, class_to_idx, train)
92 |
93 | modalities = utils.unsqueeze(modalities)
94 | transforms = utils.unsqueeze(transforms)
95 |
96 | all_loaders = [rgb_loader, oflow_loader]
97 | all_modalities = ['rgb', 'oflow']
98 | loaders = [
99 | all_loaders[i]
100 | for i in range(len(all_loaders))
101 | if all_modalities[i] in modalities
102 | ]
103 |
104 | assert len(modalities) == len(loaders)
105 |
106 | self.root = root
107 | self.train = train
108 | self.modalities = modalities
109 | self.n_samples = n_samples
110 | self.n_frames = n_frames
111 | self.transforms = transforms
112 | self.target_transform = target_transform
113 |
114 | self.loaders = loaders
115 |
116 | self.classes = classes
117 | self.class_to_idx = class_to_idx
118 | self.dataset = dataset
119 |
120 | def __getitem__(self, idx):
121 | vid_name, target = self.dataset[idx]
122 | length = len(
123 | glob.glob(
124 | os.path.join(self.root, rgb_folder_name, vid_name,
125 | '*.' + rgb_pattern.split('.')[1]))) - 1
126 |
127 | if self.train:
128 | samples = np.random.randint(0, length - self.n_frames + 1,
129 | self.n_samples) # randint: exclusive
130 | else:
131 | if length > self.n_samples:
132 | samples = np.round(
133 | np.linspace(0, length - self.n_frames,
134 | self.n_samples)).astype(int) # linspace: inclusive
135 | else:
136 | samples = np.arange(0, length - self.n_frames + 1) # arange: exclusive
137 |
138 | # load raw data
139 | inputs = []
140 | for loader in self.loaders:
141 | vid = []
142 | for s in samples:
143 | clip = []
144 | for t in range(self.n_frames):
145 | frame_id = s + t + 1
146 | image = loader(self.root, vid_name, frame_id)
147 | clip.append(image)
148 | vid.append(clip)
149 | inputs.append(np.array(vid))
150 |
151 | # transform
152 | if self.transforms is not None:
153 | for i, transform in enumerate(self.transforms):
154 | if self.transforms[i] is not None:
155 | inputs[i] = transform(inputs[i])
156 |
157 | return utils.squeeze(inputs), target
158 |
159 | def __len__(self):
160 | return len(self.dataset)
161 |
--------------------------------------------------------------------------------
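A quick sketch of the two clip-sampling strategies in UCF_101.__getitem__ above, with hypothetical sizes: training draws n_samples random clip starts, while testing spreads them evenly over the video.

import numpy as np

length, n_frames, n_samples = 120, 10, 5

# train: random starts (randint upper bound is exclusive)
train_starts = np.random.randint(0, length - n_frames + 1, n_samples)

# test: evenly spaced starts (linspace endpoints are inclusive)
test_starts = np.round(np.linspace(0, length - n_frames,
                                   n_samples)).astype(int)
print(test_starts)  # [  0  28  55  82 110]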
/detection/evaluation/map.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Calculate mAP @ IoU thresholds for detection."""
17 |
18 | import numpy as np
19 | from third_party.pku_mmd.evaluate import process
20 |
21 |
22 | def get_segments(scores, activity_threshold):
23 | """Get prediction segments of a video."""
24 | # Each segment contains start, end, class, confidence score.
25 | # Sum of all action probabilities, i.e. 1 - probability of no-activity
26 | activity_prob = 1 - scores[:, 0]
27 | # Binary vector indicating whether a clip is an activity or no-activity
28 | activity_tag = np.zeros(activity_prob.shape, dtype=np.int32)
29 | activity_tag[activity_prob >= activity_threshold] = 1
30 | assert activity_tag.ndim == 1
31 | # For each index, subtract the previous index, getting -1, 0, or 1
32 | # 1 indicates the start of a segment, and -1 indicates the end.
33 | padded = np.pad(activity_tag, pad_width=1, mode='constant')
34 | diff = padded[1:] - padded[:-1]
35 | indexes = np.arange(diff.size)
36 | startings = indexes[diff == 1]
37 | endings = indexes[diff == -1]
38 | assert startings.size == endings.size
39 |
40 | segments = []
41 | for start, end in zip(startings, endings):
42 | segment_scores = scores[start:end, :]
43 | class_prob = np.mean(segment_scores, axis=0)
44 | segment_class_index = np.argmax(class_prob[1:]) + 1
45 | confidence = np.mean(segment_scores[:, segment_class_index])
46 | segments.append((start, end, segment_class_index, confidence))
47 | return segments
48 |
49 |
50 | def calc_map(opt, video_scores, video_names, groundtruth_dir, iou_thresholds):
51 | """Get mAP (action) for IoU 0.1, 0.3 and 0.5."""
52 | activity_threshold = 0.4
53 | num_videos = len(video_scores)
54 | video_files = [name + '.txt' for name in video_names]
55 |
56 | v_props = []
57 | for i in range(num_videos):
58 | # video_name = video_names[i]
59 | scores = video_scores[i]
60 | segments = get_segments(scores, activity_threshold)
61 |
62 | prop = []
63 | for segment in segments:
64 | start, end, cls, score = segment
65 | # start, end are indices of clips. Transform to frame index.
66 | start_index = start * opt.step_size * opt.downsample
67 | end_index = (
68 | (end - 1) * opt.step_size + opt.n_frames) * opt.downsample - 1
69 | prop.append([cls, start_index, end_index, score, video_files[i]])
70 | v_props.append(prop)
71 |
72 | # Run evaluation on different IoU thresholds.
73 | mean_aps = []
74 | for iou in iou_thresholds:
75 | mean_ap = process(v_props, video_files, groundtruth_dir, iou)
76 | mean_aps.append(mean_ap)
77 | return mean_aps
78 |
--------------------------------------------------------------------------------
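A self-contained sketch of the run-detection trick in get_segments above, on made-up scores: thresholding gives a 0/1 tag per clip, and the first-order difference of the zero-padded tag marks segment starts (+1) and ends (-1).

import numpy as np

activity_prob = np.array([0.1, 0.7, 0.9, 0.8, 0.2, 0.6, 0.5])
activity_tag = (activity_prob >= 0.4).astype(np.int32)  # [0 1 1 1 0 1 1]

padded = np.pad(activity_tag, pad_width=1, mode='constant')
diff = padded[1:] - padded[:-1]
indexes = np.arange(diff.size)
startings = indexes[diff == 1]   # [1 5]
endings = indexes[diff == -1]    # [4 7]
print(list(zip(startings, endings)))  # half-open segments [1, 4) and [5, 7)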
/detection/get_dataloader.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Get data loader for detection."""
17 |
18 | import os
19 | import torch.utils.data as data
20 | import torchvision.transforms as transforms
21 |
22 | import third_party.two_stream_pytorch.video_transforms as vtransforms
23 | from data_pipeline.pku_mmd import PKU_MMD
24 |
25 | CNN_MODALITIES = ['rgb', 'oflow', 'depth']
26 | GRU_MODALITIES = ['jjd', 'jjv', 'jld']
27 |
28 |
29 | def get_dataloader(opt):
30 | idx_t = 0 if opt.split == 'train' else 1
31 |
32 | xforms = []
33 | for modality in opt.modalities:
34 | if opt.dset == 'pku-mmd':
35 | mean, std = PKU_MMD.MEAN_STD[modality]
36 | else:
37 | raise NotImplementedError
38 |
39 | if opt.split == 'train' and (modality == 'rgb' or modality == 'depth'):
40 | xform = transforms.Compose([
41 | vtransforms.RandomSizedCrop(224),
42 | vtransforms.RandomHorizontalFlip(),
43 | vtransforms.ToTensor(),
44 | vtransforms.Normalize(mean, std)
45 | ])
46 | elif opt.split == 'train' and modality == 'oflow':
47 | # Special handling when flipping optical flow
48 | xform = transforms.Compose([
49 | vtransforms.RandomSizedCrop(224, True),
50 | vtransforms.RandomHorizontalFlip(True),
51 | vtransforms.ToTensor(),
52 | vtransforms.Normalize(mean, std)
53 | ])
54 | elif opt.split != 'train' and modality in CNN_MODALITIES:
55 | xform = transforms.Compose([
56 | vtransforms.Scale(256),
57 | vtransforms.CenterCrop(224),
58 | vtransforms.ToTensor(),
59 | vtransforms.Normalize(mean, std)
60 | ])
61 | elif modality in GRU_MODALITIES:
62 | xform = transforms.Compose([vtransforms.SkelNormalize(mean, std)])
63 | else:
64 | raise NotImplementedError
65 |
66 | xforms.append(xform)
67 |
68 | if opt.dset == 'pku-mmd':
69 | root = os.path.join(opt.dset_path, opt.dset)
70 | dset = PKU_MMD(root, idx_t, 'cross-subject', opt.modalities, opt.step_size,
71 | opt.n_frames, opt.downsample, opt.timestep, xforms,
72 | opt.subsample_rate)
73 | else:
74 | raise NotImplementedError
75 |
76 | dataloader = data.DataLoader(
77 | dset,
78 | batch_size=opt.batch_sizes[idx_t],
79 | shuffle=(opt.split == 'train'),
80 | num_workers=opt.n_workers)
81 |
82 | return dataloader
83 |
--------------------------------------------------------------------------------
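For reference, a hypothetical usage sketch for get_dataloader: the attribute names mirror the argparse flags defined in detection/run.py, and the values are examples only.

from argparse import Namespace

opt = Namespace(
    split='train', dset='pku-mmd', dset_path='/path/to/datasets',
    modalities=['depth'], step_size=10, n_frames=10, downsample=3,
    timestep=10, subsample_rate=20, batch_sizes=[8, 1], n_workers=4)

# Requires the PKU-MMD data on disk; each batch is (inputs, targets, vid_name).
# dataloader = get_dataloader(opt)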
/detection/get_model.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Get detection model."""
17 |
18 | from .model import SingleStream
19 | from .model import GraphDistillation
20 |
21 | ALL_MODALITIES = ['rgb', 'oflow', 'depth', 'jjd', 'jjv', 'jld']
22 |
23 |
24 | def get_model(opt):
25 | if opt.dset == 'pku-mmd':
26 | n_classes = 51
27 | all_input_sizes = [-1, -1, -1, 276, 828, 836]
28 | all_n_channels = [3, 2, 1, -1, -1, -1]
29 | else:
30 | raise NotImplementedError
31 |
32 | n_channels = [all_n_channels[ALL_MODALITIES.index(m)] for m in opt.modalities]
33 | input_sizes = [
34 | all_input_sizes[ALL_MODALITIES.index(m)] for m in opt.modalities
35 | ]
36 |
37 | if len(opt.modalities) == 1:
38 | # Single stream
39 | index = 0
40 | model = SingleStream(opt.modalities, n_classes, opt.n_frames, n_channels,
41 | input_sizes, opt.hidden_size, opt.n_layers,
42 | opt.dropout, opt.hidden_size_seq, opt.n_layers_seq,
43 | opt.dropout_seq, opt.bg_w, opt.lr, opt.lr_decay_rate,
44 | index, opt.ckpt_path)
45 | else:
46 | index = opt.modalities.index(opt.xfer_to)
47 | model = GraphDistillation(
48 | opt.modalities, n_classes, opt.n_frames, n_channels, input_sizes,
49 | opt.hidden_size, opt.n_layers, opt.dropout, opt.hidden_size_seq,
50 | opt.n_layers_seq, opt.dropout_seq, opt.bg_w, opt.lr, opt.lr_decay_rate,
51 | index, opt.ckpt_path, opt.w_losses, opt.w_modalities, opt.metric,
52 | opt.xfer_to, opt.gd_size, opt.gd_reg)
53 |
54 | return model
55 |
--------------------------------------------------------------------------------
/detection/model.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Model to train detection."""
17 |
18 | from collections import OrderedDict
19 | import os
20 | import numpy as np
21 | import torch
22 | import torch.backends.cudnn as cudnn
23 | import torch.optim as optim
24 | import utils
25 | from nets.get_tad import *
26 | from nets.get_distillation_kernel import *
27 |
28 | CNN_MODALITIES = ['rgb', 'oflow', 'depth']
29 | GRU_MODALITIES = ['jjd', 'jjv', 'jld']
30 |
31 |
32 | class BaseModel:
33 | def __init__(self, modalities, n_classes, n_frames, n_channels, input_sizes,
34 | hidden_size, n_layers, dropout, hidden_size_seq, n_layers_seq,
35 | dropout_seq, bg_w, lr, lr_decay_rate, to_idx, ckpt_path):
36 | super(BaseModel, self).__init__()
37 | cudnn.benchmark = True
38 | utils.info('{} modality'.format(modalities[to_idx]))
39 |
40 | self.embeds = []
41 | for i, m in enumerate(modalities):
42 | encoder_type = 'cnn' if m in CNN_MODALITIES else 'rnn'
43 | embed = nn.DataParallel(
44 | get_tad(n_classes, n_frames, n_channels[i], input_sizes[i],
45 | hidden_size, n_layers, dropout, hidden_size_seq, n_layers_seq,
46 | dropout_seq, encoder_type).cuda())
47 | self.embeds.append(embed)
48 |
49 | # Multiple optimizers
50 | self.optimizers = []
51 | self.lr_decay_rates = []
52 | # Visual encoder: SGD
53 | visual_params = list(self.embeds[to_idx].module.embed.parameters())
54 | visual_optimizer = optim.SGD(
55 | visual_params, lr=lr, momentum=0.9, weight_decay=5e-4)
56 | self.optimizers.append(visual_optimizer)
57 | self.lr_decay_rates.append(lr_decay_rate)
58 | # Sequence encoder: Adam
59 | sequence_params = list(self.embeds[to_idx].module.rnn.parameters()) + \
60 | list(self.embeds[to_idx].module.fc.parameters())
61 | sequence_optimizer = optim.Adam(sequence_params, lr=1e-3)
62 | self.optimizers.append(sequence_optimizer)
63 | self.lr_decay_rates.append(1) # No learning rate decay for Adam
64 |
65 | # Weighted cross-entropy loss
66 | self.criterion_cls = nn.CrossEntropyLoss(
67 | torch.FloatTensor([bg_w] + [1] * n_classes)).cuda()
68 |
69 | self.n_classes = n_classes
70 | self.modalities = modalities
71 | self.to_idx = to_idx
72 | self.ckpt_path = ckpt_path
73 |
74 | def _forward(self, inputs):
75 | """Forward pass for all modalities.
76 | """
77 | logits, reprs = [], []
78 | for i in range(len(inputs)):
79 | logit, repr = self.embeds[i](inputs[i])
80 | logits.append(logit)
81 | reprs.append(repr)
82 |
83 | logits = torch.stack(logits)
84 | reprs = torch.stack(reprs)
85 | return [logits, reprs]
86 |
87 | def _backward(self, results, label):
88 | raise NotImplementedError
89 |
90 | def train(self, inputs, label):
91 | """Train model.
92 | :param inputs: a list, each is batch_size x timestep x n_frames x (n_channels x h x w) or (input_size)
93 | :param label: batch_size x timestep
94 | """
95 | for embed in self.embeds:
96 | embed.train()
97 |
98 | for i in range(len(inputs)):
99 | inputs[i] = Variable(inputs[i].cuda(), requires_grad=False)
100 | label = Variable(label.cuda(), requires_grad=False)
101 |
102 | results = self._forward(inputs)
103 | info_loss = self._backward(results, label)
104 | info_acc = self._get_acc(results[0], label)
105 | return OrderedDict(info_loss + info_acc)
106 |
107 | def test(self, inputs, label, timestep):
108 | """Test model.
109 | :param timestep: split into segments of length timestep.
110 | """
111 | for embed in self.embeds:
112 | embed.eval()
113 |
114 | input = Variable(inputs[0].cuda(), requires_grad=False)
115 | label = Variable(label.cuda(), requires_grad=False)
116 | length = input.size(1)
117 |
118 | # Split video into segments
119 | input, start_indices = utils.get_segments(input, timestep)
120 | inputs = [input]
121 |
122 | logits, _ = self._forward(inputs)
123 | logits = utils.to_numpy(logits).squeeze(0)
124 | all_logits = [[] for i in range(length)]
125 | for i in range(len(start_indices)):
126 | s = start_indices[i]
127 | for j in range(timestep):
128 | all_logits[s + j].append(logits[i][j])
129 | # Average logits for each time step.
130 | final_logits = np.zeros((length, self.n_classes + 1))
131 | for i in range(length):
132 | final_logits[i] = np.mean(all_logits[i], axis=0)
133 | logits = final_logits
134 |
135 | info_acc = self._get_acc([torch.Tensor(logits)], label)
136 | scores = utils.softmax(logits, axis=1)
137 | return OrderedDict(info_acc), logits, scores
138 |
139 | def _get_acc(self, logits, label):
140 | """Get detection statistics for modality.
141 | """
142 | info_acc = []
143 | for i, m in enumerate(self.modalities):
144 | logit = logits[i].view(-1, self.n_classes + 1)
145 | label = label.view(-1)
146 | stats = utils.get_stats_detection(logit, label, self.n_classes + 1)
147 | info_acc.append(('ap_{}'.format(m), stats[0]))
148 | info_acc.append(('acc_{}'.format(m), stats[1]))
149 | info_acc.append(('acc_bg_{}'.format(m), stats[2]))
150 | info_acc.append(('acc_action_{}'.format(m), stats[3]))
151 | return info_acc
152 |
153 | def save(self, epoch):
154 | path = os.path.join(self.ckpt_path, 'embed_{}.pth'.format(epoch))
155 | torch.save(self.embeds[self.to_idx].state_dict(), path)
156 |
157 | def load(self, load_ckpt_paths, options, epoch=200):
158 | """Load checkpoints.
159 | """
160 | assert len(load_ckpt_paths) == len(self.embeds)
161 | for i in range(len(self.embeds)):
162 | ckpt_path = load_ckpt_paths[i]
163 | load_opt = options[i]
164 | if len(ckpt_path) == 0:
165 | utils.info('{}: training from scratch'.format(self.modalities[i]))
166 | continue
167 |
168 | if load_opt == 0: # load teacher model (visual + sequence)
169 | path = os.path.join(ckpt_path, 'embed_{}.pth'.format(epoch))
170 | ckpt = torch.load(path)
171 | try:
172 | self.embeds[i].load_state_dict(ckpt)
173 | except:
174 | utils.warn('Check that the "modalities" argument is correct.')
175 | exit(0)
176 | utils.info('{}: ckpt {} loaded'.format(self.modalities[i], path))
177 | elif load_opt == 1: # load pretrained visual encoder
178 | ckpt = torch.load(ckpt_path)
179 | # Change keys in the ckpt
180 | new_state_dict = OrderedDict()
181 | for key in list(ckpt.keys())[:-2]: # exclude fc weights
182 | new_key = key[7:] # Remove 'module.'
183 | new_state_dict[new_key] = ckpt[key]
184 | # update state_dict
185 | state_dict = self.embeds[i].module.embed.state_dict()
186 | state_dict.update(new_state_dict)
187 | self.embeds[i].module.embed.load_state_dict(state_dict)
188 | utils.info('{}: visual encoder from {} loaded'.format(
189 | self.modalities[i], ckpt_path))
190 | else:
191 | raise NotImplementedError
192 |
193 | def lr_decay(self):
194 | lrs = []
195 | for optimizer, decay_rate in zip(self.optimizers, self.lr_decay_rates):
196 | for param_group in optimizer.param_groups:
197 | param_group['lr'] *= decay_rate
198 | lrs.append(param_group['lr'])
199 | return lrs
200 |
201 |
202 | class SingleStream(BaseModel):
203 | """Model to train a single modality.
204 | """
205 |
206 | def __init__(self, *args, **kwargs):
207 | super(SingleStream, self).__init__(*args, **kwargs)
208 | assert len(self.embeds) == 1
209 |
210 | def _backward(self, results, label):
211 | logits, _ = results
212 | logits = logits.view(-1, logits.size(-1))
213 |
214 | loss = self.criterion_cls(logits, label.view(-1))
215 | loss.backward()
216 | torch.nn.utils.clip_grad_norm(self.embeds[self.to_idx].parameters(), 5)
217 | for optimizer in self.optimizers:
218 | optimizer.step()
219 | optimizer.zero_grad()
220 |
221 | info_loss = [('loss', loss.data[0])]
222 | return info_loss
223 |
224 |
225 | class GraphDistillation(BaseModel):
226 | """Model to train with graph distillation.
227 |
228 | xfer_to is the modality to train.
229 | """
230 |
231 | def __init__(self, modalities, n_classes, n_frames, n_channels, input_sizes,
232 | hidden_size, n_layers, dropout, hidden_size_seq, n_layers_seq,
233 | dropout_seq, bg_w, lr, lr_decay_rate, to_idx, ckpt_path,
234 | w_losses, w_modalities, metric, xfer_to, gd_size, gd_reg):
235 | super(GraphDistillation, self).__init__(\
236 | modalities, n_classes, n_frames, n_channels, input_sizes,
237 | hidden_size, n_layers, dropout, hidden_size_seq, n_layers_seq, dropout_seq,
238 | bg_w, lr, lr_decay_rate, to_idx, ckpt_path)
239 |
240 | # Index of the modality to distill
241 | to_idx = self.modalities.index(xfer_to)
242 | from_idx = [x for x in range(len(self.modalities)) if x != to_idx]
243 | assert len(from_idx) >= 1
244 |
245 | # Prior
246 | w_modalities = [w_modalities[i] for i in from_idx
247 | ] # remove modality being transferred to
248 | gd_prior = utils.softmax(w_modalities, 0.25)
249 | # Distillation model
250 | self.distillation_kernel = \
251 | get_distillation_kernel(n_classes + 1, hidden_size, gd_size, to_idx, from_idx,
252 | gd_prior, gd_reg, w_losses, metric).cuda()
253 |
254 | # Add optimizer to self.optimizers
255 | gd_optimizer = optim.SGD(
256 | self.distillation_kernel.parameters(),
257 | lr=lr,
258 | momentum=0.9,
259 | weight_decay=5e-4)
260 | self.optimizers.append(gd_optimizer)
261 | self.lr_decay_rates.append(lr_decay_rate)
262 |
263 | self.xfer_to = xfer_to
264 | self.to_idx = to_idx
265 | self.from_idx = from_idx
266 |
267 | def _forward(self, inputs):
268 | logits, reprs = super(GraphDistillation, self)._forward(inputs)
269 | n_modalities, batch_size, length, _ = logits.size()
270 | logits = logits.view(n_modalities, batch_size * length, -1)
271 | reprs = reprs.view(n_modalities, batch_size * length, -1)
272 | # Get edge weights of the graph
273 | graph = self.distillation_kernel(logits, reprs)
274 | return logits, reprs, graph
275 |
276 | def _backward(self, results, label):
277 | logits, reprs, graph = results # graph: size (len(from_idx) x batch_size)
278 | label = label.view(-1)
279 | info_loss = []
280 |
281 | # Classification loss
282 | loss_cls = self.criterion_cls(logits[self.to_idx], label)
283 | # Graph distillation loss
284 | loss_reg, loss_logit, loss_repr = \
285 | self.distillation_kernel.distillation_loss(logits, reprs, graph)
286 |
287 | loss = loss_cls + loss_reg + loss_logit + loss_repr
288 | loss.backward()
289 | torch.nn.utils.clip_grad_norm(self.embeds[self.to_idx].parameters(), 5)
290 | for optimizer in self.optimizers:
291 | optimizer.step()
292 | optimizer.zero_grad()
293 |
294 | info_loss = [('loss_cls', loss_cls.data[0]), ('loss_reg', loss_reg.data[0]),
295 | ('loss_logit', loss_logit.data[0]), ('loss_repr',
296 | loss_repr.data[0])]
297 | return info_loss
298 |
--------------------------------------------------------------------------------
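A small numpy sketch (hypothetical shapes, not repository code) of the overlapping-window averaging in BaseModel.test above: every time step collects the logits of each segment that covers it, then takes their mean.

import numpy as np

length, timestep, n_out = 6, 4, 3
start_indices = [0, 2]  # segments cover steps 0..3 and 2..5
logits = np.random.rand(len(start_indices), timestep, n_out)

all_logits = [[] for _ in range(length)]
for i, s in enumerate(start_indices):
  for j in range(timestep):
    all_logits[s + j].append(logits[i][j])

final_logits = np.zeros((length, n_out))
for t in range(length):
  final_logits[t] = np.mean(all_logits[t], axis=0)  # steps 2, 3 average both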
/detection/run.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Train detection"""
17 |
18 | import argparse
19 | import os
20 |
21 | import utils
22 | import utils.logging as logging
23 | from .get_dataloader import *
24 | from .get_model import *
25 | from .evaluation.map import calc_map
26 |
27 | ALL_MODALITIES = ['rgb', 'oflow', 'depth', 'jjd', 'jjv', 'jld']
28 |
29 | parser = argparse.ArgumentParser()
30 |
31 | # experimental settings
32 | parser.add_argument('--n_workers', type=int, default=24)
33 | parser.add_argument('--gpus', type=str, default='0')
34 | parser.add_argument('--split', type=str, choices=['train', 'val', 'test'])
35 |
36 | # ckpt and logging
37 | parser.add_argument('--ckpt_path', type=str, default='./ckpt',
38 | help='directory path that stores all checkpoints')
39 | parser.add_argument('--ckpt_name', type=str, default='ckpt')
40 | parser.add_argument('--pretrained_ckpt_name', type=str, default='ckpt',
41 | help='name of the teacher detection models')
42 | parser.add_argument('--load_epoch', type=int, default=400,
43 | help='epoch to load teacher model')
44 | parser.add_argument('--visual_encoder_ckpt_path', type=str, default='',
45 | help='classification checkpoint to initialize '
46 | 'the visual encoder weights')
47 | parser.add_argument('--load_ckpt_path', type=str, default='',
48 | help='checkpoint path to load for testing')
49 | parser.add_argument('--print_every', type=int, default=50)
50 | parser.add_argument('--save_every', type=int, default=50)
51 |
52 | # hyperparameters
53 | parser.add_argument('--batch_sizes', type=int, nargs='+', default=[8, 1],
54 | help='batch sizes: [train, test]')
55 | parser.add_argument('--n_epochs', type=int, default=400)
56 | parser.add_argument('--lr', type=float, default=1e-3)
57 | parser.add_argument('--lr_decay_at', type=int, nargs='+', default=[250, 350])
58 | parser.add_argument('--lr_decay_rate', type=float, default=0.1)
59 |
60 | # data pipeline
61 | parser.add_argument('--dset', type=str, default='pku-mmd')
62 | parser.add_argument('--dset_path', type=str,
63 | default=os.path.join(os.environ['HOME'], 'slowbro'))
64 | parser.add_argument('--modalities', type=str, nargs='+',
65 | choices=['rgb', 'oflow', 'depth', 'jjd', 'jjv', 'jld'])
66 | parser.add_argument('--step_size', type=int, default=10,
67 | help='step size between samples (after downsample)')
68 | parser.add_argument('--n_frames', type=int, default=10,
69 | help='num frames per clip')
70 | parser.add_argument('--downsample', type=int, default=3,
71 | help='fps /= downsample')
72 | parser.add_argument('--timestep', type=int, default=10,
73 | help='number of clips in a sequence')
74 | parser.add_argument('--bg_w', type=float, default=0.5)
75 | parser.add_argument('--subsample_rate', type=int, default=20,
76 | help='rate to subsample the dataset; 0 uses the full dataset')
77 |
78 | # visual encoder GRU (for modalities jjd, jjv, jld)
79 | parser.add_argument('--dropout', type=float, default=0.5)
80 | parser.add_argument('--hidden_size', type=int, default=512)
81 | parser.add_argument('--n_layers', type=int, default=3)
82 | # sequence encoder GRU
83 | parser.add_argument('--dropout_seq', type=float, default=0.5)
84 | parser.add_argument('--hidden_size_seq', type=int, default=512)
85 | parser.add_argument('--n_layers_seq', type=int, default=1)
86 |
87 | # Graph Distillation parameters
88 | parser.add_argument('--metric', type=str, default='cosine',
89 | choices=['cosine', 'kl', 'l2', 'l1'],
90 | help='distance metric for distillation loss')
91 | parser.add_argument('--w_losses', type=float, nargs='+', default=[10, 1],
92 | help='weights for losses: [logit, repr]')
93 | parser.add_argument('--w_modalities', type=float, nargs='+',
94 | default=[1, 1, 1, 1, 1, 1],
95 | help='modality prior')
96 | parser.add_argument('--xfer_to', type=str, default='',
97 | help='modality to train with graph distillation')
98 | parser.add_argument('--gd_size', type=int, default=32,
99 | help='hidden size of graph distillation')
100 | parser.add_argument('--gd_reg', type=float, default=10,
101 | help='regularization for graph distillation')
102 |
103 |
104 | def single_stream(opt):
105 | """Train a single modality from scratch."""
106 | # Checkpoint path example: ckpt_path/pku-mmd/rgb/ckpt
107 | opt.ckpt_path = os.path.join(opt.ckpt_path, opt.dset,
108 | opt.modalities[0], opt.ckpt_name)
109 | os.makedirs(opt.ckpt_path, exist_ok=True)
110 | if opt.split == 'train':
111 | if opt.visual_encoder_ckpt_path != '':
112 | assert os.path.exists(opt.visual_encoder_ckpt_path), \
113 | '{} does not exist'.format(opt.visual_encoder_ckpt_path)
114 | opt.load_ckpt_paths = [opt.visual_encoder_ckpt_path]
115 | opt.load_opts = [1] # load visual encoder
116 | else:
117 | opt.load_ckpt_paths = [opt.load_ckpt_path]
118 | opt.load_opts = [0] # load visual + sequence encoder
119 |
120 | # Data loader and model
121 | dataloader = get_dataloader(opt)
122 | model = get_model(opt)
123 | if opt.split == 'train':
124 | train(opt, model, dataloader)
125 | else:
126 | test(opt, model, dataloader)
127 |
128 |
129 | def multi_stream(opt):
130 | """Train a modality with graph distillation from other modalities."""
131 | assert opt.xfer_to in opt.modalities, 'xfer_to must be in opt.modalities'
132 | # Checkpoints to load
133 | opt.load_ckpt_paths = []
134 | opt.load_opts = []
135 | for m in opt.modalities:
136 | if m != opt.xfer_to:
137 | # Checkpoint from single_stream
138 | path = os.path.join(opt.ckpt_path, opt.dset, m, opt.pretrained_ckpt_name)
139 | assert os.path.exists(path), '{} checkpoint does not exist.'.format(path)
140 | opt.load_ckpt_paths.append(path)
141 | opt.load_opts.append(0)
142 | else:
143 | opt.load_ckpt_paths.append(opt.visual_encoder_ckpt_path)
144 | opt.load_opts.append(1)
145 |
146 | # Checkpoint path example: ckpt_path/ntu-rgbd/xfer_rgb/ckpt_rgb_depth
147 | opt.ckpt_path = os.path.join(
148 | opt.ckpt_path, opt.dset, 'xfer_{}'.format(opt.xfer_to), '{}_{}'.format(
149 | opt.ckpt_name, '_'.join([m for m in opt.modalities])))
150 | os.makedirs(opt.ckpt_path, exist_ok=True)
151 |
152 | # Data loader and model
153 | dataloader = get_dataloader(opt)
154 | model = get_model(opt)
155 | train(opt, model, dataloader)
156 |
157 |
158 | def train(opt, model, dataloader):
159 | # Logging
160 | logger = logging.Logger(opt.ckpt_path, opt.split)
161 | stats = logging.Statistics(opt.ckpt_path, opt.split)
162 | logger.log(opt)
163 |
164 | model.load(opt.load_ckpt_paths, opt.load_opts, opt.load_epoch)
165 | for epoch in range(1, opt.n_epochs + 1):
166 | for step, data in enumerate(dataloader, 1):
167 | # inputs is a list of input of each modality
168 | inputs, label, _ = data
169 | ret = model.train(inputs, label)
170 | update = stats.update(len(label), ret)
171 | if utils.is_due(step, opt.print_every):
172 | utils.info('epoch {}/{}, step {}/{}: {}'.format(
173 | epoch, opt.n_epochs, step, len(dataloader), update))
174 |
175 | logger.log('[Summary] epoch {}/{}: {}'.format(epoch, opt.n_epochs,
176 | stats.summarize()))
177 |
178 | if utils.is_due(epoch, opt.n_epochs, opt.save_every):
179 | model.save(epoch)
180 | stats.save()
181 | logger.log('***** saved *****')
182 |
183 | if utils.is_due(epoch, opt.lr_decay_at):
184 | lrs = model.lr_decay()
185 | logger.log('***** lr decay *****: {}'.format(lrs))
186 |
187 |
188 | def test(opt, model, dataloader):
189 | # Logging
190 | logger = logging.Logger(opt.ckpt_path, opt.split)
191 | stats = logging.Statistics(opt.ckpt_path, opt.split)
192 | logger.log(opt)
193 |
194 | model.load(opt.load_ckpt_paths, opt.load_opts, opt.load_epoch)
195 | all_scores = []
196 | video_names = []
197 | for step, data in enumerate(dataloader, 1):
198 | inputs, label, vid_name = data
199 | info_acc, logits, scores = model.test(inputs, label, opt.timestep)
200 |
201 | all_scores.append(scores)
202 | video_names.append(vid_name[0])
203 | update = stats.update(logits.shape[0], info_acc)
204 | if utils.is_due(step, opt.print_every):
205 | utils.info('step {}/{}: {}'.format(step, len(dataloader), update))
206 |
207 | logger.log('[Summary] {}'.format(stats.summarize()))
208 |
209 | # Evaluate
210 | iou_thresholds = [0.1, 0.3, 0.5]
211 | groundtruth_dir = os.path.join(opt.dset_path, opt.dset, 'groundtruth',
212 | 'validation/cross-subject')
213 | assert os.path.exists(groundtruth_dir), '{} does not exist'.format(groundtruth_dir)
214 | mean_aps = calc_map(opt, all_scores, video_names, groundtruth_dir, iou_thresholds)
215 |
216 | for i in range(len(iou_thresholds)):
217 | logger.log('IoU: {}, mAP: {}'.format(iou_thresholds[i], mean_aps[i]))
218 |
219 |
220 | if __name__ == '__main__':
221 | opt = parser.parse_args()
222 | os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpus
223 |
224 | if opt.split == 'test':
225 | assert len(opt.modalities) == 1, 'specify only 1 modality for testing'
226 | assert len(opt.load_ckpt_path) > 0, 'specify load_ckpt_path for testing'
227 |
228 | if len(opt.modalities) == 1:
229 | single_stream(opt)
230 | else:
231 | multi_stream(opt)
232 |
--------------------------------------------------------------------------------
/images/pull.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/graph_distillation/1a7ce5125098e7df869f08e15e2d6d8bb3189382/images/pull.png
--------------------------------------------------------------------------------
/nets/get_distillation_kernel.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Graph distillation kernel."""
17 |
18 | import torch
19 | from torch.autograd import Variable
20 | import torch.nn as nn
21 | import torch.nn.functional as F
22 | import utils
23 |
24 |
25 | class DistillationKernel(nn.Module):
26 | """Graph Distillation kernel.
27 |
28 | Calculate the edge weights e_{j->k} for each j. Modality k is specified by
29 | to_idx, and the other modalities are specified by from_idx.
30 | """
31 |
32 | def __init__(self, n_classes, hidden_size, gd_size, to_idx, from_idx,
33 | gd_prior, gd_reg, w_losses, metric, alpha):
34 | super(DistillationKernel, self).__init__()
35 | self.W_logit = nn.Linear(n_classes, gd_size)
36 | self.W_repr = nn.Linear(hidden_size, gd_size)
37 | self.W_edge = nn.Linear(gd_size * 4, 1)
38 |
39 | self.gd_size = gd_size
40 | self.to_idx = to_idx
41 | self.from_idx = from_idx
42 | self.alpha = alpha
43 | # For calculating distillation loss
44 | self.gd_prior = Variable(torch.FloatTensor(gd_prior).cuda())
45 | self.gd_reg = gd_reg
46 | self.w_losses = w_losses # [logit weight, repr weight]
47 | self.metric = metric
48 |
49 |
50 | def forward(self, logits, reprs):
51 | """
52 | Args:
53 | logits: (n_modalities, batch_size, n_classes)
54 | reprs: (n_modalities, batch_siz`, hidden_size)
55 | Return:
56 | edges: weights e_{j->k} (n_modalities_from, batch_size)
57 | """
58 | n_modalities, batch_size = logits.size()[:2]
59 | z_logits = self.W_logit(logits.view(n_modalities * batch_size, -1))
60 | z_reprs = self.W_repr(reprs.view(n_modalities * batch_size, -1))
61 | z = torch.cat(
62 | (z_logits, z_reprs), dim=1).view(n_modalities, batch_size,
63 | self.gd_size * 2)
64 |
65 | edges = []
66 | for i in self.from_idx:
67 | # To calculate e_{j->k}, concatenate z^j, z^k
68 | e = self.W_edge(torch.cat((z[self.to_idx], z[i]), dim=1))
69 | edges.append(e)
70 | edges = torch.cat(edges, dim=1)
71 | edges = F.softmax(edges * self.alpha, dim=1).transpose(0, 1)
72 | return edges
73 |
74 |
75 | def distillation_loss(self, logits, reprs, edges):
76 | """Calculate graph distillation losses, which include:
77 | regularization loss, loss for logits, and loss for representation.
78 | """
79 | # Regularization for graph distillation (average across batch)
80 | loss_reg = (edges.mean(1) - self.gd_prior).pow(2).sum() * self.gd_reg
81 |
82 | loss_logit, loss_repr = 0, 0
83 | for i, idx in enumerate(self.from_idx):
84 | w_distill = edges[i] + self.gd_prior[i] # add graph prior
85 | loss_logit += self.w_losses[0] * utils.distance_metric(
86 | logits[self.to_idx], logits[idx], self.metric, w_distill)
87 | loss_repr += self.w_losses[1] * utils.distance_metric(
88 | reprs[self.to_idx], reprs[idx], self.metric, w_distill)
89 | return loss_reg, loss_logit, loss_repr
90 |
91 |
92 | def get_distillation_kernel(n_classes,
93 | hidden_size,
94 | gd_size,
95 | to_idx,
96 | from_idx,
97 | gd_prior,
98 | gd_reg,
99 | w_losses,
100 | metric,
101 | alpha=1 / 8):
102 | return DistillationKernel(n_classes, hidden_size, gd_size, to_idx, from_idx,
103 | gd_prior, gd_reg, w_losses, metric, alpha)
104 |
--------------------------------------------------------------------------------
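A CPU-only sketch of the edge computation in DistillationKernel.forward above (the real module runs on CUDA; sizes here are hypothetical): each source modality j is scored against the target modality k from their concatenated embeddings, and a temperature-scaled softmax turns the scores into edge weights.

import torch
import torch.nn as nn
import torch.nn.functional as F

n_modalities, batch_size, gd_size = 3, 4, 8
to_idx, from_idx, alpha = 0, [1, 2], 1 / 8
z = torch.randn(n_modalities, batch_size, gd_size * 2)  # per-modality embeddings
W_edge = nn.Linear(gd_size * 4, 1)

edges = []
for i in from_idx:
  # e_{j->k}: score the pair (target to_idx, source i)
  edges.append(W_edge(torch.cat((z[to_idx], z[i]), dim=1)))
edges = torch.cat(edges, dim=1)                           # batch_size x n_from
edges = F.softmax(edges * alpha, dim=1).transpose(0, 1)   # n_from x batch_size
print(edges.shape)  # torch.Size([2, 4])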
/nets/get_gru.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """GRU module."""
17 |
18 | import torch
19 | from torch.autograd import Variable
20 | import torch.nn as nn
21 |
22 |
23 | class GRU(nn.Module):
24 | """ GRU Class"""
25 |
26 | def __init__(self, input_size, hidden_size, n_layers, dropout, n_classes):
27 | super(GRU, self).__init__()
28 | self.gru = nn.GRU(
29 | input_size=input_size,
30 | hidden_size=hidden_size,
31 | num_layers=n_layers,
32 | batch_first=True,
33 | dropout=dropout)
34 | self.fc = nn.Linear(hidden_size, n_classes)
35 |
36 | self.hidden_size = hidden_size
37 | self.n_layers = n_layers
38 |
39 | def _get_states(self, batch_size):
40 | h0 = Variable(
41 | torch.zeros(self.n_layers, batch_size, self.hidden_size).cuda(),
42 | requires_grad=False)
43 | return h0
44 |
45 | def forward(self, x):
46 | """:param x: input of size batch_size' x n_frames x input_size (batch_size' = batch_size*n_samples) :return:
47 | """
48 | batch_size = x.size(0)
49 | h0 = self._get_states(batch_size)
50 | x, _ = self.gru(x, h0)
51 | representation = x.mean(1)
52 | logit = self.fc(representation)
53 | return logit, representation
54 |
55 |
56 | def get_gru(input_size, hidden_size, n_layers, dropout, n_classes):
57 | return GRU(input_size, hidden_size, n_layers, dropout, n_classes)
58 |
--------------------------------------------------------------------------------
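A hypothetical usage sketch for the GRU module above, run on CPU (the class itself allocates its hidden state on CUDA): hidden_size=512 and n_layers=3 follow the defaults in detection/run.py, and input_size=276 is the 'jjd' size from detection/get_model.py.

import torch
import torch.nn as nn

gru = nn.GRU(input_size=276, hidden_size=512, num_layers=3,
             batch_first=True, dropout=0.5)
x = torch.randn(8, 10, 276)   # batch_size' x n_frames x input_size
h0 = torch.zeros(3, 8, 512)   # n_layers x batch_size' x hidden_size
out, _ = gru(x, h0)
representation = out.mean(1)  # temporal mean pooling, as in GRU.forward
print(representation.shape)   # torch.Size([8, 512])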
/nets/get_tad.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Model for temporal action detection."""
17 |
18 | import torch.nn as nn
19 | from third_party.pytorch.get_cnn import *
20 | from .get_gru import *
21 |
22 |
23 | class TAD(nn.Module):
24 | """Temporal action detection model.
25 |
26 | Consists of visual encoder (specified by encoder_type) and sequence
27 | encoder.
28 | """
29 |
30 | def __init__(self, n_classes, n_frames, n_channels, input_size, hidden_size,
31 | n_layers, dropout, hidden_size_seq, n_layers_seq, dropout_seq,
32 | encoder_type):
33 | super(TAD, self).__init__()
34 |
35 | # Visual encoder
36 | if encoder_type == 'cnn':
37 | self.embed = get_resnet(n_frames * n_channels, n_classes)
38 | elif encoder_type == 'rnn':
39 | self.embed = get_gru(input_size, hidden_size, n_layers, dropout,
40 | n_classes)
41 | else:
42 | raise NotImplementedError
43 |
44 | # Sequence encoder
45 | self.rnn = nn.GRU(
46 | input_size=hidden_size,
47 | hidden_size=hidden_size_seq,
48 | num_layers=n_layers_seq,
49 | batch_first=True,
50 | dropout=dropout_seq,
51 | bidirectional=True)
52 | # Classification layer
53 | self.fc = nn.Linear(hidden_size_seq,
54 | n_classes + 1) # plus 1 class for background
55 |
56 | self.n_classes = n_classes
57 | self.hidden_size = hidden_size
58 | self.hidden_size_seq = hidden_size_seq
59 | self.n_layers_seq = n_layers_seq
60 | self.encoder_type = encoder_type
61 |
62 | def forward(self, x):
63 | """:param x: if encoder_type == 'cnn', batch_size x timestep x n_frames x n_channels x h x w
64 |
65 | if encoder_type == 'lstm', batch_size x timestep x n_frames x
66 | input_size
67 | :return: representation and logits
68 | """
69 | if self.encoder_type == 'cnn':
70 | batch_size, timestep, n_frames, n_channels, h, w = x.size()
71 | x = x.view(batch_size * timestep, n_frames * n_channels, h, w)
72 | _, x = self.embed(x)
73 | x = x.view(batch_size, timestep, self.hidden_size)
74 | elif self.encoder_type == 'rnn':
75 | batch_size, timestep, n_frames, input_size = x.size()
76 | x = x.view(batch_size * timestep, n_frames, input_size)
77 | _, x = self.embed(x)
78 | x = x.view(batch_size, timestep, self.hidden_size)
79 | else:
80 | raise NotImplementedError
81 |
82 | batch_size = x.size(0)
83 | x, _ = self.rnn(x)
84 | x = x.contiguous().view(x.size(0), x.size(1), 2, -1).sum(2)
85 | representation = x
86 | logit = self.fc(representation)
87 |
88 | return logit, representation
89 |
90 |
91 | def get_tad(n_classes, n_frames, n_channels, input_size, hidden_size, n_layers,
92 | dropout, hidden_size_seq, n_layers_seq, dropout_seq, encoder_type):
93 | return TAD(n_classes, n_frames, n_channels, input_size, hidden_size, n_layers,
94 | dropout, hidden_size_seq, n_layers_seq, dropout_seq, encoder_type)
95 |
--------------------------------------------------------------------------------
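A short sketch of the bidirectional merge in TAD.forward above (hypothetical sizes): the forward and backward halves of the biGRU output are summed, so the sequence feature keeps width hidden_size_seq rather than 2 * hidden_size_seq.

import torch

batch_size, timestep, hidden_size_seq = 2, 5, 4
x = torch.randn(batch_size, timestep, 2 * hidden_size_seq)  # biGRU output
merged = x.contiguous().view(batch_size, timestep, 2, -1).sum(2)
print(merged.shape)  # torch.Size([2, 5, 4])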
/scripts/download_models.sh:
--------------------------------------------------------------------------------
1 | wget https://storage.googleapis.com/graph_distillation/ckpt/ntu-rgbd/depth.zip
2 | wget https://storage.googleapis.com/graph_distillation/ckpt/ntu-rgbd/jjd.zip
3 | wget https://storage.googleapis.com/graph_distillation/ckpt/ntu-rgbd/jjv.zip
4 | wget https://storage.googleapis.com/graph_distillation/ckpt/ntu-rgbd/jld.zip
5 | wget https://storage.googleapis.com/graph_distillation/ckpt/ntu-rgbd/oflow.zip
6 | wget https://storage.googleapis.com/graph_distillation/ckpt/ntu-rgbd/rgb.zip
7 | wget https://storage.googleapis.com/graph_distillation/ckpt/ntu-rgbd/xfer_depth.zip
8 | wget https://storage.googleapis.com/graph_distillation/ckpt/ntu-rgbd/xfer_jjd.zip
9 | wget https://storage.googleapis.com/graph_distillation/ckpt/pku-mmd/depth.zip
10 | wget https://storage.googleapis.com/graph_distillation/ckpt/pku-mmd/jjd.zip
11 | wget https://storage.googleapis.com/graph_distillation/ckpt/pku-mmd/jjv.zip
12 | wget https://storage.googleapis.com/graph_distillation/ckpt/pku-mmd/jld.zip
13 | wget https://storage.googleapis.com/graph_distillation/ckpt/pku-mmd/oflow.zip
14 | wget https://storage.googleapis.com/graph_distillation/ckpt/pku-mmd/rgb.zip
15 |
16 |
--------------------------------------------------------------------------------
/scripts/test_ntu_rgbd.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Copyright 2018 Google Inc. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 |
17 | python -m classification.run \
18 | --gpus 0 \
19 | --split test \
20 | --dset ntu-rgbd \
21 | --load_ckpt_path ckpt/ntu-rgbd/rgb/sub \
22 | --modalities rgb
23 |
--------------------------------------------------------------------------------
/scripts/test_pku_mmd.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Copyright 2018 Google Inc. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 |
17 | python -m detection.run \
18 | --gpus 0,1,2,3 \
19 | --split test \
20 | --dset pku-mmd \
21 | --load_ckpt_path ckpt/pku-mmd/depth/ckpt \
22 | --modalities depth
23 |
--------------------------------------------------------------------------------
/scripts/train_ntu_rgbd.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Copyright 2018 Google Inc. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 |
17 | python -m classification.run \
18 | --gpus 0 \
19 | --split train \
20 | --dset ntu-rgbd \
21 | --subsample 33 \
22 | --modalities rgb
23 |
--------------------------------------------------------------------------------
/scripts/train_ntu_rgbd_distillation.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Copyright 2018 Google Inc. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 |
17 | python -m classification.run \
18 | --gpus 0,1 \
19 | --split train \
20 | --dset ntu-rgbd \
21 | --subsample 33 \
22 | --pretrained_ckpt_name sub \
23 | --xfer_to jjd \
24 | --modalities rgb oflow depth jjd jjv jld
25 |
--------------------------------------------------------------------------------
/scripts/train_pku_mmd.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Copyright 2018 Google Inc. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 |
17 | python -m detection.run \
18 | --gpus 3 \
19 | --split train \
20 | --dset pku-mmd \
21 | --visual_encoder_ckpt_path ./ckpt/ntu-rgbd/jld/ckpt/embed_200.pth \
22 | --subsample_rate 20 \
23 | --ckpt_name sub \
24 | --modalities jld
25 |
--------------------------------------------------------------------------------
/scripts/train_pku_mmd_distillation.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Copyright 2018 Google Inc. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 |
17 | python -m detection.run \
18 | --gpus 0 \
19 | --split train \
20 | --dset pku-mmd \
21 | --visual_encoder_ckpt_path ./ckpt/ntu-rgbd/depth/ckpt/embed_200.pth \
22 | --xfer_to depth \
23 | --modalities depth jjd
24 |
--------------------------------------------------------------------------------
/third_party/pku_mmd/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [2017] [Institute of Computer Science and Technology, Peking University]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/third_party/pku_mmd/evaluate.py:
--------------------------------------------------------------------------------
1 | # Copyright [2017] [Institute of Computer Science and Technology, Peking University]
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """
16 | Code adapted and modified from
17 | https://github.com/ECHO960/PKU-MMD/blob/master/evaluate.py
18 | """
19 |
20 | import argparse
21 | import os
22 | import numpy as np
23 |
24 | number_label = 52  # 51 action classes plus background
25 |
26 |
27 | # calc_pr: calculate precision and recall
28 | # @positive: number of positive proposals
29 | # @proposal: total number of proposals
30 | # @ground: number of ground-truth instances
31 | def calc_pr(positive, proposal, ground):
32 | if (proposal == 0):
33 | return 0, 0
34 | if (ground == 0):
35 | return 0, 0
36 | return (1.0 * positive) / proposal, (1.0 * positive) / ground
37 |
38 |
39 | # match: match proposals to ground truth
40 | # @lst: list of proposals (label, start, end, confidence, video_name)
41 | # @ratio: overlap (IoU) ratio threshold
42 | # @ground: list of ground truth (label, start, end, confidence, video_name)
43 | #
44 | # cos_map: the matched ground-truth index for each proposal (-1 if none)
45 | # count_map: how many proposals each ground truth is matched by
46 | # index_map: for each class label, the indices of its ground-truth entries
47 | def match(lst, ratio, ground):
48 |
49 | def overlap(prop, ground):
50 | l_p, s_p, e_p, c_p, v_p = prop
51 | l_g, s_g, e_g, c_g, v_g = ground
52 | if (int(l_p) != int(l_g)):
53 | return 0
54 | if (v_p != v_g):
55 | return 0
56 | return (min(e_p, e_g) - max(s_p, s_g)) / (max(e_p, e_g) - min(s_p, s_g))
57 |
58 | cos_map = [-1 for x in range(len(lst))]
59 | count_map = [0 for x in range(len(ground))]
60 | #generate index_map to speed up
61 | index_map = [[] for x in range(number_label)]
62 | for x in range(len(ground)):
63 | index_map[int(ground[x][0])].append(x)
64 |
65 | for x in range(len(lst)):
66 | for y in index_map[int(lst[x][0])]:
67 | if (overlap(lst[x], ground[y]) < ratio):
68 | continue
69 | if (overlap(lst[x], ground[y]) < overlap(lst[x], ground[cos_map[x]])):
70 | continue
71 | cos_map[x] = y
72 | if (cos_map[x] != -1):
73 | count_map[cos_map[x]] += 1
74 | positive = sum([(x > 0) for x in count_map])
75 | return cos_map, count_map, positive
76 |
77 |
78 | # Interpolated Average Precision:
79 | # @lst: list of proposals (label, start, end, confidence, video_name)
80 | # @ratio: overlap (IoU) ratio threshold
81 | # @ground: list of ground truth (label, start, end, confidence, video_name)
82 | #
83 | # score = sum over recall levels of precision(recall) * delta(recall)
84 | # Note that when the overlap ratio is < 0.5,
85 | # one ground truth may correspond to many proposals;
86 | # in that case, only one positive proposal is counted.
87 | def ap(lst, ratio, ground):
88 | lst.sort(key=lambda x: x[3]) # sorted by confidence
89 | cos_map, count_map, positive = match(lst, ratio, ground)
90 | score = 0
91 | number_proposal = len(lst)
92 | number_ground = len(ground)
93 | old_precision, old_recall = calc_pr(positive, number_proposal, number_ground)
94 |
95 | for x in range(len(lst)):
96 | number_proposal -= 1
97 | if (cos_map[x] == -1):
98 | continue
99 | count_map[cos_map[x]] -= 1
100 | if (count_map[cos_map[x]] == 0):
101 | positive -= 1
102 |
103 | precision, recall = calc_pr(positive, number_proposal, number_ground)
104 | if precision > old_precision:
105 | old_precision = precision
106 | score += old_precision * (old_recall - recall)
107 | old_recall = recall
108 | return score
109 |
110 |
111 | def process(v_props, video_files, groundtruth_dir, theta):
112 | v_grounds = [] # ground-truth list separated by video
113 |
114 |   #========== find all ground truth separated by video========
115 | for video in video_files:
116 | ground = open(os.path.join(groundtruth_dir, video), "r").readlines()
117 | ground = [ground[x].replace(",", " ") for x in range(len(ground))]
118 | ground = [[float(y) for y in ground[x].split()] for x in range(len(ground))]
119 | #append video name
120 | for x in ground:
121 | x.append(video)
122 | v_grounds.append(ground)
123 |
124 | assert len(v_props) == len(v_grounds), "{} != {}".format(
125 | len(v_props), len(v_grounds))
126 |
127 | #========== find all proposals separated by action categories========
128 | # proposal list separated by class
129 | a_props = [[] for x in range(number_label)]
130 | # ground-truth list separated by class
131 | a_grounds = [[] for x in range(number_label)]
132 |
133 | for x in range(len(v_props)):
134 | for y in range(len(v_props[x])):
135 | a_props[int(v_props[x][y][0])].append(v_props[x][y])
136 |
137 | for x in range(len(v_grounds)):
138 | for y in range(len(v_grounds[x])):
139 | a_grounds[int(v_grounds[x][y][0])].append(v_grounds[x][y])
140 |
141 | #========== find all proposals========
142 | all_props = sum(a_props, [])
143 | all_grounds = sum(a_grounds, [])
144 |
145 | #========== calculate protocols========
146 | # mAP_action
147 | aps_action = np.array([
148 | ap(a_props[x + 1], theta, a_grounds[x + 1])
149 | for x in range(number_label - 1)
150 | ])
151 | # mAP_video
152 | aps_video = np.array(
153 | [ap(v_props[x], theta, v_grounds[x]) for x in range(len(v_props))])
154 | # Return mAP action.
155 | return np.mean(aps_action)
156 |
--------------------------------------------------------------------------------
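A minimal usage sketch of the interpolated-AP routine above (not part of the
repository; it assumes the repo root is on PYTHONPATH and that the third_party
packages are importable). Proposals and ground truth follow the
(label, start, end, confidence, video_name) convention that match() unpacks:

  from third_party.pku_mmd.evaluate import ap

  ground = [[1, 0, 100, 1.0, 'video_a'],
            [2, 150, 300, 1.0, 'video_a']]
  props = [[1, 10, 90, 0.9, 'video_a'],    # IoU 0.8 with the first action
           [2, 140, 310, 0.8, 'video_a'],  # IoU ~0.88 with the second action
           [1, 400, 500, 0.7, 'video_a']]  # false positive, lowest confidence
  print(ap(props, 0.5, ground))  # 1.0: both actions rank above the false positive

--------------------------------------------------------------------------------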
/third_party/pytorch/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) Soumith Chintala 2016,
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | * Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | * Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | * Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 |
--------------------------------------------------------------------------------
/third_party/pytorch/get_cnn.py:
--------------------------------------------------------------------------------
1 | # BSD 3-Clause License
2 | #
3 | # Copyright (c) Soumith Chintala 2016,
4 | # All rights reserved.
5 | #
6 | # Redistribution and use in source and binary forms, with or without
7 | # modification, are permitted provided that the following conditions are met:
8 | #
9 | # * Redistributions of source code must retain the above copyright notice, this
10 | # list of conditions and the following disclaimer.
11 | #
12 | # * Redistributions in binary form must reproduce the above copyright notice,
13 | # this list of conditions and the following disclaimer in the documentation
14 | # and/or other materials provided with the distribution.
15 | #
16 | # * Neither the name of the copyright holder nor the names of its
17 | # contributors may be used to endorse or promote products derived from
18 | # this software without specific prior written permission.
19 |
20 | """Code for ResNet.
21 | Adapted from
22 | https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
23 | """
24 |
25 | import math
26 |
27 | import torch.nn as nn
28 | import torch.utils.model_zoo as model_zoo
29 | from torchvision.models.resnet import BasicBlock
30 | from torchvision.models.vgg import cfg, model_urls, VGG
31 |
32 |
33 | class ResNet(nn.Module):
34 |
35 | def __init__(self, block, layers, in_channels, n_classes):
36 | self.inplanes = 64
37 | super(ResNet, self).__init__()
38 | self.conv1 = nn.Conv2d(
39 | in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
40 | self.bn1 = nn.BatchNorm2d(64)
41 | self.relu = nn.ReLU(inplace=True)
42 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
43 | self.layer1 = self._make_layer(block, 64, layers[0])
44 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
45 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
46 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
47 | self.avgpool = nn.AvgPool2d(7)
48 | self.fc = nn.Linear(512 * block.expansion, n_classes)
49 |
50 | for m in self.modules():
51 | if isinstance(m, nn.Conv2d):
52 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
53 | m.weight.data.normal_(0, math.sqrt(2. / n))
54 | elif isinstance(m, nn.BatchNorm2d):
55 | m.weight.data.fill_(1)
56 | m.bias.data.zero_()
57 |
58 | def _make_layer(self, block, planes, blocks, stride=1):
59 | downsample = None
60 | if stride != 1 or self.inplanes != planes * block.expansion:
61 | downsample = nn.Sequential(
62 | nn.Conv2d(
63 | self.inplanes,
64 | planes * block.expansion,
65 | kernel_size=1,
66 | stride=stride,
67 | bias=False),
68 | nn.BatchNorm2d(planes * block.expansion),
69 | )
70 |
71 | layers = []
72 | layers.append(block(self.inplanes, planes, stride, downsample))
73 | self.inplanes = planes * block.expansion
74 | for i in range(1, blocks):
75 | layers.append(block(self.inplanes, planes))
76 |
77 | return nn.Sequential(*layers)
78 |
79 | def forward(self, x):
80 | x = self.conv1(x)
81 | x = self.bn1(x)
82 | x = self.relu(x)
83 | x = self.maxpool(x)
84 |
85 | x = self.layer1(x)
86 | x = self.layer2(x)
87 | x = self.layer3(x)
88 | conv4 = self.layer4(x)
89 |
90 |     feat = self.avgpool(conv4).view(conv4.size(0), -1)
91 |     logit = self.fc(feat)
92 | 
93 |     return logit, feat
94 |
95 |
96 | def resnet18(**kwargs):
97 | """Constructs a ResNet-18 model.
98 | Args:
99 | pretrained (bool): If True, returns a model pre-trained on ImageNet
100 | """
101 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
102 | return model
103 |
104 |
105 | def get_resnet(in_channels, n_classes):
106 | return resnet18(in_channels=in_channels, n_classes=n_classes)
107 |
108 |
109 | def make_layers(cfg, in_channels=3, batch_norm=False):
110 | layers = []
111 | for v in cfg:
112 | if v == 'M':
113 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
114 | else:
115 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
116 | if batch_norm:
117 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
118 | else:
119 | layers += [conv2d, nn.ReLU(inplace=True)]
120 | in_channels = v
121 | return nn.Sequential(*layers)
122 |
123 |
124 | def get_vgg(in_channels=3, **kwargs):
125 | model = VGG(make_layers(cfg['D'], in_channels), **kwargs)
126 | return model
127 |
--------------------------------------------------------------------------------
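A quick shape check for the single-channel ResNet-18 above (not from the
repository). It assumes an era-appropriate torchvision, since this module
imports cfg and model_urls from torchvision.models.vgg, which newer releases
have removed; n_classes=60 is only an illustrative value:

  import torch
  from third_party.pytorch.get_cnn import get_resnet

  model = get_resnet(in_channels=1, n_classes=60)
  x = torch.randn(4, 1, 224, 224)  # batch of 4 single-channel (depth) frames
  logit, feat = model(x)
  print(logit.shape, feat.shape)   # torch.Size([4, 60]) torch.Size([4, 512])

--------------------------------------------------------------------------------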
/third_party/two_stream_pytorch/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2017 Yi Zhu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/third_party/two_stream_pytorch/video_transforms.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2017 Yi Zhu
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 |
15 | """Transformations for videos.
16 |
17 | Code adapted from
18 | https://github.com/bryanyzhu/two-stream-pytorch/blob/master/video_transforms.py
19 | """
20 |
21 | import collections
22 | import math
23 | import numbers
24 | import random
25 |
26 | import numpy as np
27 | import torch
28 | import torch.nn.functional as F
29 |
30 | import utils
31 | from utils import imgproc
32 |
33 |
34 | class ToTensor(object):
35 | """Converts a numpy.ndarray (...
36 |
37 | x H x W x C) in the range
38 | [0, 255] to a torch.FloatTensor of shape (... x C x H x W) in the range [0.0,
39 | 1.0].
40 | """
41 |
42 | def __init__(self, scale=True, to_float=True):
43 | self.scale = scale
44 | self.to_float = to_float
45 |
46 | def __call__(self, arr):
47 | if isinstance(arr, np.ndarray):
48 | video = torch.from_numpy(np.rollaxis(arr, axis=-1, start=-3))
49 |
50 | if self.to_float:
51 | video = video.float()
52 |
53 | if self.scale:
54 | return video.div(255)
55 | else:
56 | return video
57 | else:
58 | raise NotImplementedError
59 |
60 |
61 | class Normalize(object):
62 | """Given mean and std,
63 | will normalize each channel of the torch.*Tensor, i.e.
64 | channel = (channel - mean) / std
65 | """
66 |
67 | def __init__(self, mean, std):
68 | if not isinstance(mean, list):
69 | mean = [mean]
70 | if not isinstance(std, list):
71 | std = [std]
72 |
73 | self.mean = torch.FloatTensor(mean).unsqueeze(1).unsqueeze(2)
74 | self.std = torch.FloatTensor(std).unsqueeze(1).unsqueeze(2)
75 |
76 | def __call__(self, tensor):
77 | return tensor.sub_(self.mean).div_(self.std)
78 |
79 |
80 | class Scale(object):
81 | """Rescale the input numpy.ndarray to the given size.
82 | Args:
83 | size (sequence or int): Desired output size. If size is a sequence like
84 |       (h, w), output size will be matched to this. If size is an int,
85 | smaller edge of the image will be matched to this number.
86 | i.e, if height > width, then image will be rescaled to
87 | (size * height / width, size)
88 | interpolation (int, optional): Desired interpolation. Default is
89 | ``bilinear``
90 | """
91 | def __init__(self, size, transform_pixel=False, interpolation='bilinear'):
92 | """:param size: output size :param transform_pixel: transform pixel values for flow :param interpolation: 'bilinear', 'nearest'
93 | """
94 | assert isinstance(size, int) or (isinstance(size, collections.Iterable) and
95 | len(size) == 2)
96 | self.size = size
97 | self.transform_pixel = transform_pixel
98 | self.interpolation = interpolation
99 |
100 | def __call__(self, video):
101 | """Args: video (numpy.ndarray): Video to be scaled.
102 |
103 | Returns:
104 | numpy.ndarray: Rescaled video.
105 | """
106 | w, h = video.shape[-2], video.shape[-3]
107 |
108 | if isinstance(self.size, int):
109 | if (w <= h and w == self.size) or (h <= w and h == self.size):
110 | return video
111 |
112 | if w < h:
113 | ow = self.size
114 | oh = int(self.size * h / w)
115 | video = imgproc.resize(video, (oh, ow), self.interpolation)
116 | else:
117 | oh = self.size
118 | ow = int(self.size * w / h)
119 | video = imgproc.resize(video, (oh, ow), self.interpolation)
120 |
121 | if self.transform_pixel:
122 | video[..., 0] = (video[..., 0] - 128) * (ow / w) + 128
123 | video[..., 1] = (video[..., 1] - 128) * (oh / h) + 128
124 | else:
125 | video = imgproc.resize(video, self.size, self.interpolation)
126 |
127 | if self.transform_pixel:
128 | video[..., 0] = (video[..., 0] - 128) * (self.size / w) + 128
129 | video[..., 1] = (video[..., 1] - 128) * (self.size / h) + 128
130 |
131 | return video
132 |
133 |
134 | class CenterCrop(object):
135 | """Crops the given numpy.ndarray at the center to have a region of
136 | the given size. size can be a tuple (target_height, target_width)
137 | or an integer, in which case the target will be of a square shape (size, size)
138 | """
139 |
140 | def __init__(self, size):
141 | if isinstance(size, numbers.Number):
142 | self.size = (int(size), int(size))
143 | else:
144 | self.size = size
145 |
146 | def __call__(self, video):
147 | h, w = video.shape[-3:-1]
148 | th, tw = self.size
149 | x1 = int(round((w - tw) / 2.))
150 | y1 = int(round((h - th) / 2.))
151 |
152 | return video[..., y1:y1 + th, x1:x1 + tw, :]
153 |
154 |
155 | class Pad(object):
156 | """Pad the given np.ndarray on all sides with the given "pad" value.
157 |
158 |   Args: padding (int): Padding on each border. A single number is applied
159 |   to the left, top, right and bottom borders alike; sequences are not
160 |   supported. fill: Pixel fill value. Default is 0.
161 | """
162 |
163 | def __init__(self, padding, fill=0):
164 | assert isinstance(padding, numbers.Number)
165 | assert isinstance(fill, numbers.Number) or isinstance(
166 | fill, str) or isinstance(fill, tuple)
167 | self.padding = padding
168 | self.fill = fill
169 |
170 | def __call__(self, video):
171 | """Args: video (np.ndarray): Video to be padded.
172 |
173 | Returns:
174 | np.ndarray: Padded video.
175 | """
176 | pad_width = ((0, 0), (self.padding, self.padding), (self.padding,
177 | self.padding), (0, 0))
178 | return np.pad(
179 | video, pad_width=pad_width, mode='constant', constant_values=self.fill)
180 |
181 |
182 | class RandomCrop(object):
183 | """Crop the given numpy.ndarray at a random location.
184 | Args:
185 | size (sequence or int): Desired output size of the crop. If size is an
186 | int instead of sequence like (h, w), a square crop (size, size) is
187 | made.
188 |     padding (int, optional): Optional padding on each border of the
189 |       image. Default is 0, i.e. no padding. A single int is applied to
190 |       all four borders, since the Pad transform used internally only
191 |       accepts a single number.
192 | """
193 |
194 | def __init__(self, size, padding=0):
195 | if isinstance(size, numbers.Number):
196 | self.size = (int(size), int(size))
197 | else:
198 | self.size = size
199 | self.padding = padding
200 |
201 | def __call__(self, video):
202 | """Args: video (np.ndarray): Video to be cropped.
203 |
204 | Returns:
205 | np.ndarray: Cropped video.
206 | """
207 | if self.padding > 0:
208 | pad = Pad(self.padding, 0)
209 | video = pad(video)
210 |
211 | w, h = video.shape[-2], video.shape[-3]
212 | th, tw = self.size
213 | if w == tw and h == th:
214 | return video
215 |
216 | x1 = random.randint(0, w - tw)
217 | y1 = random.randint(0, h - th)
218 | return video[..., y1:y1 + th, x1:x1 + tw, :]
219 |
220 |
221 | class RandomHorizontalFlip(object):
222 | """Randomly horizontally flips the given numpy.ndarray with a probability of 0.5
223 |
224 | """
225 |
226 | def __init__(self, transform_pixel=False):
227 | """:param transform_pixel: transform pixel values for flow
228 | """
229 | self.transform_pixel = transform_pixel if isinstance(
230 | transform_pixel, list) else [transform_pixel]
231 |
232 | def __call__(self, videos):
233 | """Support joint transform
234 | :param videos: np.ndarray or a list of np.ndarray
235 | :return:
236 | """
237 | if random.random() < 0.5:
238 | videos = utils.unsqueeze(videos)
239 | ret = []
240 | for tp, video in zip(self.transform_pixel, videos):
241 | video = video[..., ::-1, :]
242 | if tp:
243 | video[..., 0] = 255 - video[..., 0]
244 | ret.append(video.copy())
245 | return utils.squeeze(ret)
246 | else:
247 | return videos
248 |
249 |
250 | class RandomSizedCrop(object):
251 | """Crop the given np.ndarray to random size and aspect ratio.
252 | A crop of random size of (0.4 to 1.0) of the original size and a random
253 | aspect ratio of 3/4 to 4/3 of the original aspect ratio is made. This crop
254 | is finally resized to given size.
255 | This is popularly used to train the Inception networks.
256 | """
257 |
258 | def __init__(self, size, transform_pixel=False):
259 | """:param size: size of the smaller edge :param transform_pixel: transform pixel values for flow
260 | """
261 | self.size = size
262 | self.transform_pixel = transform_pixel if isinstance(
263 | transform_pixel, list) else [transform_pixel]
264 |
265 | def __call__(self, videos):
266 | """Support joint transform
267 | :param videos: np.ndarray or a list of np.ndarray
268 | :return:
269 | """
270 | videos = utils.unsqueeze(videos)
271 | h_orig, w_orig = videos[0].shape[-3:-1]
272 |
273 | for attempt in range(10):
274 | ret = []
275 |
276 | area = h_orig * w_orig
277 | target_area = random.uniform(0.4, 1.0) * area
278 | aspect_ratio = random.uniform(3. / 4, 4. / 3)
279 |
280 | w = int(round(math.sqrt(target_area * aspect_ratio)))
281 | h = int(round(math.sqrt(target_area / aspect_ratio)))
282 |
283 | if random.random() < 0.5:
284 | w, h = h, w
285 |
286 | if w <= w_orig and h <= h_orig:
287 | x1 = random.randint(0, w_orig - w)
288 | y1 = random.randint(0, h_orig - h)
289 |
290 | for tp, video in zip(self.transform_pixel, videos):
291 | video = video[..., y1:y1 + h, x1:x1 + w, :]
292 | video = imgproc.resize(video, (self.size, self.size), 'bilinear')
293 | if tp:
294 | video[..., 0] = (video[..., 0] - 128) * (self.size / w) + 128
295 | video[..., 1] = (video[..., 1] - 128) * (self.size / h) + 128
296 |
297 | ret.append(video)
298 |
299 | return utils.squeeze(ret)
300 |
301 | # Fallback
302 | ret = []
303 | scales = [Scale(self.size, tp, 'bilinear') for tp in self.transform_pixel]
304 | crop = CenterCrop(self.size)
305 | for scale, video in zip(scales, videos):
306 | video = crop(scale(video))
307 | ret.append(video)
308 |
309 | return utils.squeeze(ret)
310 |
311 |
312 | class SkelNormalize(object):
313 | """Given mean and std, will normalize the numpy array
314 | """
315 |
316 | def __init__(self, mean, std):
317 | self.mean = mean
318 | self.std = std
319 |
320 | def __call__(self, skel):
321 | return (skel - self.mean) / self.std
322 |
323 |
324 | class AvgPool(object):
325 | """Rescale the input by performing average pooling The height and width are scaled down by a factor of kernel_size
326 | """
327 |
328 | def __init__(self, kernel_size):
329 | self.kernel_size = kernel_size
330 |
331 | def __call__(self, x):
332 | batch_size, n_samples, n_channels, h, w = x.size()
333 | x = x.view(-1, n_channels, h, w)
334 | x = F.avg_pool2d(x, self.kernel_size, stride=self.kernel_size).data
335 | return x.view(batch_size, n_samples, *x.size()[-3:])
336 |
337 |
338 | class Clip(object):
339 | """Clip values of the numpy array
340 | """
341 |
342 | def __init__(self, lower, upper):
343 | self.lower = lower
344 | self.upper = upper
345 |
346 | def __call__(self, x):
347 | return np.clip(x, self.lower, self.upper)
348 |
--------------------------------------------------------------------------------
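A small sketch (not in the repository) chaining these transforms with
torchvision's Compose on a random uint8 clip shaped (frames, H, W, C). It
assumes the repo's own dependencies are available, since this module pulls in
utils.imgproc (OpenCV and an older SciPy) at import time:

  import numpy as np
  import torchvision.transforms as transforms
  from third_party.two_stream_pytorch import video_transforms

  clip = np.random.randint(0, 256, size=(16, 128, 171, 3), dtype=np.uint8)
  transform = transforms.Compose([
      video_transforms.CenterCrop(112),
      video_transforms.ToTensor(),  # float tensor in [0, 1], channels first
      video_transforms.Normalize([0.485, 0.456, 0.406],
                                 [0.229, 0.224, 0.225]),
  ])
  print(transform(clip).shape)      # torch.Size([16, 3, 112, 112])

--------------------------------------------------------------------------------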
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .misc import *
2 |
--------------------------------------------------------------------------------
/utils/imgproc.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Utility functions for reading and processing data."""
17 |
18 | import cv2
19 | import numpy as np
20 | from scipy import interpolate
21 | from scipy.misc import imresize  # removed in SciPy >= 1.3; needs an older SciPy
22 |
23 |
24 | def imread_rgb(dset, path):
25 | if dset == 'ucf-101':
26 | rgb = cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB)
27 | return rgb[:, :-1] # oflow is 1px smaller than rgb in ucf-101
28 | elif dset == 'ntu-rgbd' or dset == 'pku-mmd' or dset == 'cad-60':
29 | rgb = cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB)
30 | return rgb
31 | else:
32 | assert False
33 |
34 |
35 | def imread_oflow(dset, *paths):
36 | if dset == 'ucf-101':
37 | path_u, path_v = paths
38 | oflow_u = cv2.imread(path_u, cv2.IMREAD_GRAYSCALE)
39 | oflow_v = cv2.imread(path_v, cv2.IMREAD_GRAYSCALE)
40 | oflow = np.stack((oflow_u, oflow_v), axis=2)
41 | return oflow
42 | elif dset == 'ntu-rgbd' or dset == 'pku-mmd' or dset == 'cad-60':
43 | path = paths[0]
44 | oflow = cv2.imread(path)[..., ::-1][..., :2]
45 | return oflow
46 | else:
47 | assert False
48 |
49 |
50 | def imread_depth(dset, path):
51 | # dset == 'ntu-rgbd' or dset == 'pku-mmd'
52 | depth = cv2.imread(path, cv2.IMREAD_UNCHANGED)[:, :, np.newaxis]
53 | depth = np.clip(depth/256, 0, 255).astype(np.uint8)
54 | return depth
55 |
56 |
57 | def inpaint(img, threshold=1):
58 | h, w = img.shape[:2]
59 |
60 | if len(img.shape) == 3: # RGB
61 | mask = np.all(img == 0, axis=2).astype(np.uint8)
62 | img = cv2.inpaint(img, mask, inpaintRadius=3, flags=cv2.INPAINT_TELEA)
63 |
64 | else: # depth
65 | mask = np.where(img > threshold)
66 | xx, yy = np.meshgrid(np.arange(w), np.arange(h))
67 | xym = np.vstack((np.ravel(xx[mask]), np.ravel(yy[mask]))).T
68 | img = np.ravel(img[mask])
69 | interp = interpolate.NearestNDInterpolator(xym, img)
70 | img = interp(np.ravel(xx), np.ravel(yy)).reshape(xx.shape)
71 |
72 | return img
73 |
74 |
75 | def resize(video, size, interpolation):
76 | """
77 | :param video: ... x h x w x num_channels
78 | :param size: (h, w)
79 | :param interpolation: 'bilinear', 'nearest'
80 |   :return: resized video of shape ... x size[0] x size[1] x num_channels
81 | """
82 | shape = video.shape[:-3]
83 | num_channels = video.shape[-1]
84 | video = video.reshape((-1, *video.shape[-3:]))
85 | resized_video = np.zeros((video.shape[0], *size, video.shape[-1]))
86 |
87 | for i in range(video.shape[0]):
88 | if num_channels == 3:
89 | resized_video[i] = imresize(video[i], size, interpolation)
90 | elif num_channels == 2:
91 | resized_video[i, ..., 0] = imresize(video[i, ..., 0], size, interpolation)
92 | resized_video[i, ..., 1] = imresize(video[i, ..., 1], size, interpolation)
93 | elif num_channels == 1:
94 | resized_video[i, ..., 0] = imresize(video[i, ..., 0], size, interpolation)
95 | else:
96 | raise NotImplementedError
97 |
98 | return resized_video.reshape((*shape, *size, video.shape[-1]))
99 |
100 |
101 | def proc_oflow(images):
102 | h, w = images.shape[-3:-1]
103 |
104 | processed_images = []
105 | for image in images:
106 | hsv = np.zeros((h, w, 3), dtype=np.uint8)
107 | hsv[:, :, 0] = 255
108 | hsv[:, :, 1] = 255
109 |
110 | mag, ang = cv2.cartToPolar(image[..., 0], image[..., 1])
111 | hsv[..., 0] = ang*180/np.pi/2
112 | hsv[..., 2] = cv2.normalize(mag, None, 0, 255, cv2.NORM_MINMAX)
113 |
114 | processed_image = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)
115 | processed_images.append(processed_image)
116 |
117 | return np.stack(processed_images)
118 |
--------------------------------------------------------------------------------
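A short sketch of resize() above (not part of the repository). It assumes
SciPy < 1.3, since scipy.misc.imresize was removed in later releases:

  import numpy as np
  from utils import imgproc

  video = np.random.randint(0, 256, size=(8, 120, 160, 3), dtype=np.uint8)
  small = imgproc.resize(video, (60, 80), 'bilinear')
  print(small.shape)  # (8, 60, 80, 3): every frame resized to 60 x 80

--------------------------------------------------------------------------------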
/utils/logging.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Calculate result statistics, logging."""
17 |
18 | from collections import OrderedDict
19 | import logging
20 | import os
21 | import sys
22 | import numpy as np
23 |
24 |
25 | class _AverageMeter(object):
26 | """ Average Meter Class."""
27 |
28 | def __init__(self):
29 | self.val = 0
30 | self.avg = 0
31 | self.sum = 0
32 | self.count = 0
33 |
34 | def update(self, val, n=1):
35 | self.val = val
36 | self.sum += val*n
37 | self.count += n
38 | self.avg = self.sum/self.count
39 |
40 |
41 | class Statistics(object):
42 | """ Statistics Class."""
43 |
44 | def __init__(self, ckpt_path=None, name='history'):
45 | self.meters = OrderedDict()
46 | self.history = OrderedDict()
47 | self.ckpt_path = ckpt_path
48 | self.name = name
49 |
50 | def update(self, n, ordered_dict):
51 | info = ''
52 | for key in ordered_dict:
53 | if key not in self.meters:
54 | self.meters.update({key: _AverageMeter()})
55 | self.meters[key].update(ordered_dict[key], n)
56 | info += '{key}={var.val:.4f}, avg {key}={var.avg:.4f}, '.format(
57 | key=key, var=self.meters[key])
58 |
59 | return info[:-2]
60 |
61 | def summarize(self, reset=True):
62 | info = ''
63 | for key in self.meters:
64 | info += '{key}={var:.4f}, '.format(key=key, var=self.meters[key].avg)
65 |
66 | if reset:
67 | self.reset()
68 |
69 | return info[:-2]
70 |
71 | def reset(self):
72 | for key in self.meters:
73 | if key in self.history:
74 | self.history[key].append(self.meters[key].avg)
75 | else:
76 | self.history.update({key: [self.meters[key].avg]})
77 |
78 | self.meters = OrderedDict()
79 |
80 | def load(self):
81 | self.history = np.load(
82 | os.path.join(self.ckpt_path, '{}.npy'.format(self.name))).item()
83 |
84 | def save(self):
85 | np.save(
86 | os.path.join(self.ckpt_path, '{}.npy'.format(self.name)), self.history)
87 |
88 |
89 | class Logger(object):
90 | """ Logger Class."""
91 |
92 | def __init__(self, path, name='debug'):
93 | self.logger = logging.getLogger()
94 | self.logger.setLevel(logging.INFO)
95 | formatter = logging.Formatter(
96 | '%(asctime)s %(message)s', datefmt='%Y-%m-%d,%H:%M:%S')
97 |
98 | fh = logging.FileHandler(os.path.join(path, '{}.log'.format(name)), 'w')
99 | fh.setLevel(logging.INFO)
100 | fh.setFormatter(formatter)
101 | self.logger.addHandler(fh)
102 |
103 | ch = logging.StreamHandler(sys.stdout)
104 | ch.setLevel(logging.INFO)
105 | ch.setFormatter(formatter)
106 | self.logger.addHandler(ch)
107 |
108 | def log(self, info):
109 | self.logger.info(info)
110 |
--------------------------------------------------------------------------------
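A minimal sketch of the Statistics helper (not in the repository): update()
keeps a running average per key and returns a formatted string; summarize()
reports the averages and, by default, resets the meters into history:

  from collections import OrderedDict
  from utils.logging import Statistics

  stats = Statistics()
  for step in range(3):
      print(stats.update(1, OrderedDict(loss=0.5 - 0.1 * step)))
  print(stats.summarize())  # loss=0.4000

--------------------------------------------------------------------------------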
/utils/misc.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Miscellaneous utility functions."""
17 |
18 | import numpy as np
19 | from sklearn.metrics import average_precision_score
20 | import torch
21 | import torch.nn.functional as F
22 | import torch.utils.data
23 |
24 |
25 | def to_numpy(array):
26 | if isinstance(array, np.ndarray):
27 | return array
28 | if isinstance(array, torch.autograd.Variable):
29 | array = array.data
30 | if array.is_cuda:
31 | array = array.cpu()
32 |
33 | return array.numpy()
34 |
35 |
36 | def squeeze(array):
37 | if not isinstance(array, list) or len(array) > 1:
38 | return array
39 | else: # len(array) == 1:
40 | return array[0]
41 |
42 |
43 | def unsqueeze(array):
44 | if isinstance(array, list):
45 | return array
46 | else:
47 | return [array]
48 |
49 |
50 | def is_due(*args):
51 | """Determines whether to perform an action or not, depending on the epoch.
52 | Used for logging, saving, learning rate decay, etc.
53 |
54 | Args:
55 |     *args: one of (epoch, due_at): due when epoch is in the list due_at;
56 |       (epoch, num_epochs, due_every): due every due_every epochs;
57 |       (step, due_every): due every due_every steps.
58 | Returns:
59 | due: boolean: perform action or not
60 | """
61 | if len(args) == 2 and isinstance(args[1], list):
62 | epoch, due_at = args
63 | due = epoch in due_at
64 | elif len(args) == 3:
65 | epoch, num_epochs, due_every = args
66 | due = (due_every >= 0) and (epoch % due_every == 0 or epoch == num_epochs)
67 | else:
68 | step, due_every = args
69 | due = (due_every > 0) and (step % due_every == 0)
70 |
71 | return due
72 |
73 |
74 | def softmax(w, t=1.0, axis=None):
75 | w = np.array(w) / t
76 | e = np.exp(w - np.amax(w, axis=axis, keepdims=True))
77 | dist = e / np.sum(e, axis=axis, keepdims=True)
78 | return dist
79 |
80 |
81 | def distance_metric(student, teacher, option, weights=None):
82 | """Distance metric to calculate the imitation loss.
83 |
84 | Args:
85 | student: batch_size x n_classes
86 | teacher: batch_size x n_classes
87 |     option: one of [cosine, l2, l1, kl]
88 | weights: batch_size or float
89 |
90 | Returns:
91 | The computed distance metric.
92 | """
93 | if option == 'cosine':
94 | dists = 1 - F.cosine_similarity(student, teacher.detach(), dim=1)
95 | elif option == 'l2':
96 | dists = (student-teacher.detach()).pow(2).sum(1)
97 | elif option == 'l1':
98 | dists = torch.abs(student-teacher.detach()).sum(1)
99 | elif option == 'kl':
100 | assert weights is None
101 | T = 8
102 | # averaged for each minibatch
103 | dist = F.kl_div(
104 | F.log_softmax(student / T), F.softmax(teacher.detach() / T)) * (
105 | T * T)
106 | return dist
107 | else:
108 | raise NotImplementedError
109 |
110 | if weights is None:
111 | dist = dists.mean()
112 | else:
113 | dist = (dists * weights).mean()
114 |
115 | return dist
116 |
117 |
118 | def get_segments(input, timestep):
119 | """Split entire input into segments of length timestep.
120 |
121 | Args:
122 | input: 1 x total_length x n_frames x ...
123 |     timestep: length of each segment, in frames.
124 |
125 | Returns:
126 | input: concatenated video segments
127 | start_indices: indices of the segments
128 | """
129 | assert input.size(0) == 1, 'Test time, batch_size must be 1'
130 |
131 | input.squeeze_(dim=0)
132 | # Find overlapping segments
133 | length = input.size()[0]
134 | step = timestep // 2
135 | num_segments = (length - timestep) // step + 1
136 | start_indices = (np.arange(num_segments) * step).tolist()
137 | if length % step > 0:
138 | start_indices.append(length - timestep)
139 |
140 | # Get the segments
141 | segments = []
142 | for s in start_indices:
143 | segment = input[s: (s + timestep)].unsqueeze(0)
144 | segments.append(segment)
145 | input = torch.cat(segments, dim=0)
146 | return input, start_indices
147 |
148 | def get_stats(logit, label):
149 | '''
150 | Calculate the accuracy.
151 | '''
152 | logit = to_numpy(logit)
153 | label = to_numpy(label)
154 |
155 | pred = np.argmax(logit, 1)
156 | acc = np.sum(pred == label)/label.shape[0]
157 |
158 | return acc, pred, label
159 |
160 |
161 | def get_stats_detection(logit, label, n_classes=52):
162 | '''
163 | Calculate the accuracy and average precisions.
164 | '''
165 | logit = to_numpy(logit)
166 | label = to_numpy(label)
167 | scores = softmax(logit, axis=1)
168 |
169 | pred = np.argmax(logit, 1)
170 | length = label.shape[0]
171 | acc = np.sum(pred == label)/length
172 |
173 | keep_bg = label == 0
174 | acc_bg = np.sum(pred[keep_bg] == label[keep_bg])/label[keep_bg].shape[0]
175 | ratio_bg = np.sum(keep_bg)/length
176 |
177 | keep_action = label != 0
178 | acc_action = np.sum(
179 | pred[keep_action] == label[keep_action]) / label[keep_action].shape[0]
180 |
181 | # Average precision
182 | y_true = np.zeros((len(label), n_classes))
183 | y_true[np.arange(len(label)), label] = 1
184 | acc = np.sum(pred == label)/label.shape[0]
185 | aps = average_precision_score(y_true, scores, average=None)
186 | aps = list(filter(lambda x: not np.isnan(x), aps))
187 | ap = np.mean(aps)
188 |
189 | return ap, acc, acc_bg, acc_action, ratio_bg, pred, label
190 |
191 |
192 | def info(text):
193 | print('\033[94m' + text + '\033[0m')
194 |
195 |
196 | def warn(text):
197 | print('\033[93m' + text + '\033[0m')
198 |
199 |
200 | def err(text):
201 | print('\033[91m' + text + '\033[0m')
202 |
--------------------------------------------------------------------------------
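A small sketch of the imitation loss used for distillation (not in the
repository); the batch size is arbitrary and 52 simply mirrors the PKU-MMD
class count used elsewhere in this file:

  import torch
  from utils.misc import distance_metric

  student = torch.randn(8, 52, requires_grad=True)  # batch x n_classes
  teacher = torch.randn(8, 52)
  loss = distance_metric(student, teacher, option='cosine')
  loss.backward()  # the teacher is detached, so only the student gets gradients

--------------------------------------------------------------------------------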
/utils/skelproc.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Process Skeleton Feature."""
17 |
18 | import numpy as np
19 | import utils
20 |
21 |
22 | def read_skel(dset, path):
23 | """
24 |   :param dset: name of dataset: 'ntu-rgbd', 'pku-mmd', or 'cad-60'
25 |   :param path: path to the skeleton file
26 |   :return: a list with one num_skels x num_joints x 3 array per frame
27 | """
28 | if dset == 'ntu-rgbd':
29 | file = open(path, 'r')
30 | lines = file.readlines()
31 | num_lines = len(lines)
32 | num_frames = int(lines[0])
33 | # print(num_lines, num_frames)
34 |
35 | line_id = 1
36 | data = []
37 | for i in range(num_frames):
38 | num_skels = int(lines[line_id])
39 | # print(num_skels)
40 |
41 | joints = []
42 | for _ in range(num_skels):
43 | num_joints = int(lines[line_id+2])
44 | # print(num_joints)
45 |
46 | joint = []
47 | for k in range(num_joints):
48 | tmp = lines[line_id+3+k].rstrip().split(' ')
49 | x_3d, y_3d, z_3d, x_depth, y_depth, x_rgb, y_rgb, orientation_w,\
50 | orientation_x, orientation_y, orientation_z = list(
51 | map(float, tmp[:-1]))
52 | joint.append([x_3d, y_3d, z_3d])
53 | joints.append(joint)
54 | line_id += 2+num_joints
55 | joints = np.array(joints)
56 | data.append(joints)
57 | line_id += 1
58 |
59 | assert line_id == num_lines
60 |
61 | elif dset == 'pku-mmd':
62 | file = open(path, 'r')
63 | lines = file.readlines()
64 | # num_lines = len(lines)
65 |
66 | data = []
67 | for line in lines:
68 | joints = list(map(float, line.rstrip().split(' ')))
69 | joints = np.array(joints).reshape(2, -1, 3)
70 |
71 | if not np.any(joints[1]):
72 | joints = joints[0][np.newaxis, :, :]
73 |
74 | data.append(joints)
75 |
76 | elif dset == 'cad-60':
77 | f = open(path, 'r')
78 | lines = f.readlines()
79 | data = []
80 |
81 | # Last line is "END"
82 | for line in lines[:-1]:
83 |       # first item is frame number, last item is empty
84 | row = line.split(',')[1:-1]
85 | row = list(map(float, row))
86 | joints = []
87 | for i in range(15):
88 | if i < 11:
89 | # First 11 joints
90 | index = 14 * i + 10
91 | else:
92 | # Joint 12 ~ 15
93 | index = 11 * 14 + (i - 11) * 4
94 | joint = row[index: index+3]
95 | joints.append(joint)
96 | joints = np.array(joints) / 1000.0 # millimeter to meter
97 | joints = joints[np.newaxis, :, :] # To match ntu-rgb format
98 | data.append(joints)
99 |
100 | else:
101 | raise NotImplementedError
102 |
103 | return data
104 |
105 |
106 | def flip_skel(skel, dset):
107 | """processed skel (normalized and center shifted to the origin)."""
108 | # Shape: (N x NUM_JOINTS x 3)
109 | if dset == 'cad-60':
110 | num_joints = 15
111 | assert skel.ndim == 3 and skel.shape[1] == num_joints
112 | assert np.sum(np.mean(skel, axis=(0, 1))) < 1e-8, 'Skeleton not centered.'
113 | new_skel = skel.copy()
114 | # Head, neck, torso
115 | new_skel[:, 0, 0] = -skel[:, 0, 0]
116 | new_skel[:, 1, 0] = -skel[:, 1, 0]
117 | new_skel[:, 2, 0] = -skel[:, 2, 0]
118 | # Shoulder
119 | new_skel[:, 3, 0] = -skel[:, 5, 0]
120 | new_skel[:, 5, 0] = -skel[:, 3, 0]
121 | new_skel[:, 3, 1:] = skel[:, 5, 1:]
122 | new_skel[:, 5, 1:] = skel[:, 3, 1:]
123 | # elbow
124 | new_skel[:, 4, 0] = -skel[:, 6, 0]
125 | new_skel[:, 6, 0] = -skel[:, 4, 0]
126 | new_skel[:, 4, 1:] = skel[:, 6, 1:]
127 | new_skel[:, 6, 1:] = skel[:, 4, 1:]
128 | # hip
129 | new_skel[:, 7, 0] = -skel[:, 9, 0]
130 | new_skel[:, 9, 0] = -skel[:, 7, 0]
131 | new_skel[:, 7, 1:] = skel[:, 9, 1:]
132 | new_skel[:, 9, 1:] = skel[:, 7, 1:]
133 | # knee
134 | new_skel[:, 8, 0] = -skel[:, 10, 0]
135 | new_skel[:, 10, 0] = -skel[:, 8, 0]
136 | new_skel[:, 8, 1:] = skel[:, 10, 1:]
137 | new_skel[:, 10, 1:] = skel[:, 8, 1:]
138 | # hand
139 | new_skel[:, 11, 0] = -skel[:, 12, 0]
140 | new_skel[:, 12, 0] = -skel[:, 11, 0]
141 | new_skel[:, 11, 1:] = skel[:, 12, 1:]
142 | new_skel[:, 12, 1:] = skel[:, 11, 1:]
143 | # foot
144 | new_skel[:, 13, 0] = -skel[:, 14, 0]
145 | new_skel[:, 14, 0] = -skel[:, 13, 0]
146 | new_skel[:, 13, 1:] = skel[:, 14, 1:]
147 | new_skel[:, 14, 1:] = skel[:, 13, 1:]
148 | return new_skel
149 |
150 |
151 | def pad_skel(skel, axis=1):
152 | if skel.shape[axis] == 1:
153 | skel = np.repeat(skel, 2, axis=axis)
154 | return skel
155 |
156 |
157 | def extract(skel):
158 | """Extract. timestep x 2 x num_joints (25) x 3"""
159 | timestep = skel.shape[0]
160 | keep_joints = [1, 4, 6, 8, 10, 12, 14, 16, 18, 20, 21]
161 | skel = skel[:, :, [keep_joint-1 for keep_joint in keep_joints]]
162 | skel = skel.reshape(timestep, -1, 3) # timestep x 22 x 3
163 | num_joints = len(keep_joints)
164 |
165 | jjd, jjv = [], []
166 | for t in range(timestep):
167 | jjd_t, jjv_t = [], []
168 | for i in range(num_joints):
169 | for j in range(i, num_joints, 1):
170 |         # joint-joint distance
171 |         jjd_t.append(np.linalg.norm(skel[t, i] - skel[t, j]))
172 |
173 | # joint-joint vector
174 | jjv_t.append(skel[t, i]-skel[t, j])
175 | jjd.append(jjd_t)
176 | jjv.append(jjv_t)
177 |   return np.array(jjd), np.array(jjv)
--------------------------------------------------------------------------------
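A tiny sketch of pad_skel() above (not in the repository; shapes are
illustrative). It duplicates a lone skeleton along the person axis,
presumably so downstream code can assume two skeletons per frame:

  import numpy as np
  from utils.skelproc import pad_skel

  skel = np.zeros((1, 25, 3))          # one person, 25 joints, xyz
  print(pad_skel(skel, axis=0).shape)  # (2, 25, 3)

--------------------------------------------------------------------------------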
/utils/visualize.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """TODO: One-sentence doc string."""
17 |
18 | import os
19 | import cv2
20 | import numpy as np
21 | import utils
22 | from utils import imgproc
23 |
24 |
25 | def visualize_rgb(images):
26 | """Visualize RGB modality."""
27 | images = utils.to_numpy(images)
28 |
29 | mean = np.array([0.485, 0.456, 0.406])
30 | std = np.array([0.229, 0.224, 0.225])
31 | images = np.moveaxis(images, -3, -1)
32 | images = images*std+mean
33 | images = np.clip(images*255, 0, 255)
34 | images = images[..., ::-1].astype(np.uint8)
35 | images = images[0, 0] # subsample
36 |
37 | imgproc.save_avi('/home/luoa/research/rgb.avi', images)
38 |
39 |
40 | def visualize_oflow(images):
41 | """Visualize optical flow modality."""
42 | images = utils.to_numpy(images)
43 |
44 | images = np.moveaxis(images, -3, -1)
45 | images = images[0, 0] # subsample
46 |
47 | images = imgproc.proc_oflow(images)
48 | imgproc.save_avi('/home/luoa/research/oflow.avi', images)
49 |
50 |
51 | def visualize_warp(rgb, oflow):
52 | """TODO: add info."""
53 | rgb = utils.to_numpy(rgb)
54 | oflow = utils.to_numpy(oflow)
55 |
56 | mean = np.array([0.485, 0.456, 0.406])
57 | std = np.array([0.229, 0.224, 0.225])
58 | rgb = np.moveaxis(rgb, -3, -1)
59 | rgb = rgb*std+mean
60 | rgb = np.clip(rgb*255, 0, 255)
61 | bgr = rgb[..., ::-1].astype(np.uint8)
62 | bgr = bgr[0, 0] # subsample
63 | print(bgr.shape, np.amin(bgr), np.amax(bgr), np.mean(bgr),
64 | np.mean(np.absolute(bgr)))
65 |
66 | oflow = np.moveaxis(oflow, -3, -1)
67 | oflow = oflow[0, 0] # subsample
68 | print(oflow.shape, np.amin(oflow), np.amax(oflow), np.mean(oflow),
69 | np.mean(np.absolute(oflow)))
70 |
71 | warp = imgproc.warp(bgr[4], bgr[5], oflow[4])
72 |
73 | root = '/home/luoa/research'
74 | cv2.imwrite(os.path.join(root, 'bgr1.jpg'), bgr[4])
75 | cv2.imwrite(os.path.join(root, 'bgr2.jpg'), bgr[5])
76 | cv2.imwrite(os.path.join(root, 'warp.jpg'), warp)
77 |
--------------------------------------------------------------------------------