├── .gitignore
├── LICENSE
├── asset
│   ├── arch.png
│   └── finetune.png
├── backbone
│   ├── __init__.py
│   ├── convrnn.py
│   ├── resnet_2d3d.py
│   └── select_backbone.py
├── eval
│   ├── model_lc.py
│   └── test.py
├── memdpc
│   ├── dataset.py
│   ├── main.py
│   └── model.py
├── process_data
│   ├── readme.md
│   └── src
│       ├── extract_ff.py
│       ├── resize_video.py
│       └── write_csv.py
├── readme.md
└── utils
    ├── __init__.py
    ├── augmentation.py
    └── utils.py
/.gitignore: -------------------------------------------------------------------------------- 1 | */__pycache__/ 2 | */*/__pycache__/ 3 | *.pyc 4 | *.pth.tar 5 | *.tar 6 | *.pth 7 | *.cluster.local 8 | *tfevent* 9 | *.png 10 | *.pkl 11 | */tmp/ 12 | process_data/data/ 13 | !asset/*.png 14 | #!asset/*/*.png 15 | #!asset/*/*/*.png 16 | *.nfs* 17 | */model/* 18 | *.tsv 19 | *.pbtxt 20 | *.svg 21 | */test_log.md 22 | *test_log.md 23 | *notes.md 24 | *.pdf 25 | *.csv 26 | *.txt 27 | *.log 28 | *.json 29 | --------------------------------------------------------------------------------
/LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof.
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [2020] [Tengda Han, Weidi Xie, Andrew Zisserman] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /asset/arch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TengdaHan/MemDPC/d7dbbf0dc6ec4aa8ff9a5dc8c189d78f4e5e34a7/asset/arch.png -------------------------------------------------------------------------------- /asset/finetune.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TengdaHan/MemDPC/d7dbbf0dc6ec4aa8ff9a5dc8c189d78f4e5e34a7/asset/finetune.png -------------------------------------------------------------------------------- /backbone/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TengdaHan/MemDPC/d7dbbf0dc6ec4aa8ff9a5dc8c189d78f4e5e34a7/backbone/__init__.py -------------------------------------------------------------------------------- /backbone/convrnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class ConvGRUCell(nn.Module): 5 | ''' Initialize ConvGRU cell ''' 6 | def __init__(self, input_size, hidden_size, kernel_size): 7 | super(ConvGRUCell, self).__init__() 8 | self.input_size = input_size 9 | self.hidden_size = hidden_size 10 | self.kernel_size = kernel_size 11 | padding = kernel_size // 2 12 | 13 | self.reset_gate = nn.Conv2d(input_size+hidden_size, hidden_size, kernel_size, padding=padding) 14 | self.update_gate = nn.Conv2d(input_size+hidden_size, hidden_size, kernel_size, padding=padding) 15 | self.out_gate = nn.Conv2d(input_size+hidden_size, hidden_size, kernel_size, padding=padding) 16 | 17 | nn.init.orthogonal_(self.reset_gate.weight) 18 | nn.init.orthogonal_(self.update_gate.weight) 19 | nn.init.orthogonal_(self.out_gate.weight) 20 | nn.init.constant_(self.reset_gate.bias, 0.) 21 | nn.init.constant_(self.update_gate.bias, 0.) 22 | nn.init.constant_(self.out_gate.bias, 0.) 
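# The three gate convolutions above implement a convolutional GRU cell. Per time step, forward() below
# computes: update z = sigmoid(update_gate([x, h])), reset r = sigmoid(reset_gate([x, h])),
# candidate h_tilde = tanh(out_gate([x, r * h])), and new state h' = (1 - z) * h + z * h_tilde,
# where [., .] is channel-wise concatenation; the hidden state is lazily created as zeros on first use.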
23 | 24 | def forward(self, input_tensor, hidden_state): 25 | if hidden_state is None: 26 | B, C, *spatial_dim = input_tensor.size() 27 | hidden_state = torch.zeros([B,self.hidden_size,*spatial_dim]).cuda() 28 | # [B, C, H, W] 29 | combined = torch.cat([input_tensor, hidden_state], dim=1) #concat in C 30 | update = torch.sigmoid(self.update_gate(combined)) 31 | reset = torch.sigmoid(self.reset_gate(combined)) 32 | out = torch.tanh(self.out_gate(torch.cat([input_tensor, hidden_state * reset], dim=1))) 33 | new_state = hidden_state * (1 - update) + out * update 34 | return new_state 35 | 36 | 37 | class ConvGRU(nn.Module): 38 | ''' Initialize a multi-layer Conv GRU ''' 39 | def __init__(self, input_size, hidden_size, kernel_size, num_layers, dropout=0.1): 40 | super(ConvGRU, self).__init__() 41 | self.input_size = input_size 42 | self.hidden_size = hidden_size 43 | self.kernel_size = kernel_size 44 | self.num_layers = num_layers 45 | 46 | cell_list = [] 47 | for i in range(self.num_layers): 48 | if i == 0: 49 | input_dim = self.input_size 50 | else: 51 | input_dim = self.hidden_size 52 | cell = ConvGRUCell(input_dim, self.hidden_size, self.kernel_size) 53 | name = 'ConvGRUCell_' + str(i).zfill(2) 54 | 55 | setattr(self, name, cell) 56 | cell_list.append(getattr(self, name)) 57 | 58 | self.cell_list = nn.ModuleList(cell_list) 59 | self.dropout_layer = nn.Dropout(p=dropout) 60 | 61 | 62 | def forward(self, x, hidden_state=None): 63 | [B, seq_len, *_] = x.size() 64 | 65 | if hidden_state is None: 66 | hidden_state = [None] * self.num_layers 67 | # input: image sequences [B, T, C, H, W] 68 | current_layer_input = x 69 | del x 70 | 71 | last_state_list = [] 72 | 73 | for idx in range(self.num_layers): 74 | cell_hidden = hidden_state[idx] 75 | output_inner = [] 76 | for t in range(seq_len): 77 | cell_hidden = self.cell_list[idx](current_layer_input[:,t,:], cell_hidden) 78 | cell_hidden = self.dropout_layer(cell_hidden) # dropout in each time step 79 | output_inner.append(cell_hidden) 80 | 81 | layer_output = torch.stack(output_inner, dim=1) 82 | current_layer_input = layer_output 83 | 84 | last_state_list.append(cell_hidden) 85 | 86 | last_state_list = torch.stack(last_state_list, dim=1) 87 | 88 | return layer_output, last_state_list 89 | 90 | -------------------------------------------------------------------------------- /backbone/resnet_2d3d.py: -------------------------------------------------------------------------------- 1 | ## modified from https://github.com/kenshohara/3D-ResNets-PyTorch 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | import math 7 | 8 | __all__ = [ 9 | 'ResNet2d3d_full', 'resnet18_2d3d_full', 'resnet34_2d3d_full', 'resnet50_2d3d_full' 10 | ] 11 | 12 | def conv3x3x3(in_planes, out_planes, stride=1, bias=False): 13 | # 3x3x3 convolution with padding 14 | return nn.Conv3d( 15 | in_planes, 16 | out_planes, 17 | kernel_size=3, 18 | stride=stride, 19 | padding=1, 20 | bias=bias) 21 | 22 | def conv1x3x3(in_planes, out_planes, stride=1, bias=False): 23 | # 1x3x3 convolution with padding 24 | return nn.Conv3d( 25 | in_planes, 26 | out_planes, 27 | kernel_size=(1,3,3), 28 | stride=(1,stride,stride), 29 | padding=(0,1,1), 30 | bias=bias) 31 | 32 | 33 | def downsample_basic_block(x, planes, stride): 34 | out = F.avg_pool3d(x, kernel_size=1, stride=stride) 35 | zero_pads = torch.Tensor( 36 | out.size(0), planes - out.size(1), out.size(2), out.size(3), 37 | out.size(4)).zero_() 38 | if isinstance(out.data, 
torch.cuda.FloatTensor): 39 | zero_pads = zero_pads.cuda() 40 | 41 | out = Variable(torch.cat([out.data, zero_pads], dim=1)) 42 | 43 | return out 44 | 45 | 46 | class BasicBlock3d(nn.Module): 47 | expansion = 1 48 | 49 | def __init__(self, inplanes, planes, stride=1, downsample=None, track_running_stats=True, use_final_relu=True): 50 | super(BasicBlock3d, self).__init__() 51 | bias = False 52 | self.use_final_relu = use_final_relu 53 | self.conv1 = conv3x3x3(inplanes, planes, stride, bias=bias) 54 | self.bn1 = nn.BatchNorm3d(planes, track_running_stats=track_running_stats) 55 | 56 | self.relu = nn.ReLU(inplace=True) 57 | self.conv2 = conv3x3x3(planes, planes, bias=bias) 58 | self.bn2 = nn.BatchNorm3d(planes, track_running_stats=track_running_stats) 59 | 60 | self.downsample = downsample 61 | self.stride = stride 62 | 63 | def forward(self, x): 64 | residual = x 65 | 66 | out = self.conv1(x) 67 | out = self.bn1(out) 68 | out = self.relu(out) 69 | 70 | out = self.conv2(out) 71 | out = self.bn2(out) 72 | 73 | if self.downsample is not None: 74 | residual = self.downsample(x) 75 | 76 | out += residual 77 | if self.use_final_relu: out = self.relu(out) 78 | 79 | return out 80 | 81 | 82 | class BasicBlock2d(nn.Module): 83 | expansion = 1 84 | 85 | def __init__(self, inplanes, planes, stride=1, downsample=None, track_running_stats=True, use_final_relu=True): 86 | super(BasicBlock2d, self).__init__() 87 | bias = False 88 | self.use_final_relu = use_final_relu 89 | self.conv1 = conv1x3x3(inplanes, planes, stride, bias=bias) 90 | self.bn1 = nn.BatchNorm3d(planes, track_running_stats=track_running_stats) 91 | 92 | self.relu = nn.ReLU(inplace=True) 93 | self.conv2 = conv1x3x3(planes, planes, bias=bias) 94 | self.bn2 = nn.BatchNorm3d(planes, track_running_stats=track_running_stats) 95 | 96 | self.downsample = downsample 97 | self.stride = stride 98 | 99 | def forward(self, x): 100 | residual = x 101 | 102 | out = self.conv1(x) 103 | out = self.bn1(out) 104 | out = self.relu(out) 105 | 106 | out = self.conv2(out) 107 | out = self.bn2(out) 108 | 109 | if self.downsample is not None: 110 | residual = self.downsample(x) 111 | 112 | out += residual 113 | if self.use_final_relu: out = self.relu(out) 114 | 115 | return out 116 | 117 | 118 | class Bottleneck3d(nn.Module): 119 | expansion = 4 120 | 121 | def __init__(self, inplanes, planes, stride=1, downsample=None, track_running_stats=True, use_final_relu=True): 122 | super(Bottleneck3d, self).__init__() 123 | bias = False 124 | self.use_final_relu = use_final_relu 125 | self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=1, bias=bias) 126 | self.bn1 = nn.BatchNorm3d(planes, track_running_stats=track_running_stats) 127 | 128 | self.conv2 = nn.Conv3d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=bias) 129 | self.bn2 = nn.BatchNorm3d(planes, track_running_stats=track_running_stats) 130 | 131 | self.conv3 = nn.Conv3d(planes, planes * 4, kernel_size=1, bias=bias) 132 | self.bn3 = nn.BatchNorm3d(planes * 4, track_running_stats=track_running_stats) 133 | 134 | self.relu = nn.ReLU(inplace=True) 135 | self.downsample = downsample 136 | self.stride = stride 137 | 138 | def forward(self, x): 139 | residual = x 140 | 141 | out = self.conv1(x) 142 | out = self.bn1(out) 143 | out = self.relu(out) 144 | 145 | out = self.conv2(out) 146 | out = self.bn2(out) 147 | out = self.relu(out) 148 | 149 | out = self.conv3(out) 150 | out = self.bn3(out) 151 | 152 | if self.downsample is not None: 153 | residual = self.downsample(x) 154 | 155 | out += residual 156 | if 
self.use_final_relu: out = self.relu(out) 157 | 158 | return out 159 | 160 | 161 | class Bottleneck2d(nn.Module): 162 | expansion = 4 163 | 164 | def __init__(self, inplanes, planes, stride=1, downsample=None, track_running_stats=True, use_final_relu=True): 165 | super(Bottleneck2d, self).__init__() 166 | bias = False 167 | self.use_final_relu = use_final_relu 168 | self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=1, bias=bias) 169 | self.bn1 = nn.BatchNorm3d(planes, track_running_stats=track_running_stats) 170 | 171 | self.conv2 = nn.Conv3d(planes, planes, kernel_size=(1,3,3), stride=(1,stride,stride), padding=(0,1,1), bias=bias) 172 | self.bn2 = nn.BatchNorm3d(planes, track_running_stats=track_running_stats) 173 | 174 | self.conv3 = nn.Conv3d(planes, planes * 4, kernel_size=1, bias=bias) 175 | self.bn3 = nn.BatchNorm3d(planes * 4, track_running_stats=track_running_stats) 176 | 177 | self.relu = nn.ReLU(inplace=True) 178 | self.downsample = downsample 179 | self.stride = stride 180 | 181 | def forward(self, x): 182 | residual = x 183 | 184 | out = self.conv1(x) 185 | out = self.bn1(out) 186 | out = self.relu(out) 187 | 188 | out = self.conv2(out) 189 | out = self.bn2(out) 190 | out = self.relu(out) 191 | 192 | out = self.conv3(out) 193 | out = self.bn3(out) 194 | 195 | if self.downsample is not None: 196 | residual = self.downsample(x) 197 | 198 | out += residual 199 | if self.use_final_relu: out = self.relu(out) 200 | 201 | return out 202 | 203 | 204 | class ResNet2d3d_full(nn.Module): 205 | def __init__(self, block, layers, track_running_stats=True): 206 | super(ResNet2d3d_full, self).__init__() 207 | self.inplanes = 64 208 | self.track_running_stats = track_running_stats 209 | bias = False 210 | self.conv1 = nn.Conv3d(3, 64, kernel_size=(1,7,7), stride=(1, 2, 2), padding=(0, 3, 3), bias=bias) 211 | self.bn1 = nn.BatchNorm3d(64, track_running_stats=track_running_stats) 212 | self.relu = nn.ReLU(inplace=True) 213 | self.maxpool = nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)) 214 | 215 | if not isinstance(block, list): 216 | block = [block] * 4 217 | 218 | self.layer1 = self._make_layer(block[0], 64, layers[0]) 219 | self.layer2 = self._make_layer(block[1], 128, layers[1], stride=2) 220 | self.layer3 = self._make_layer(block[2], 256, layers[2], stride=2) 221 | self.layer4 = self._make_layer(block[3], 256, layers[3], stride=2, is_final=True) 222 | # modify layer4 from exp=512 to exp=256 223 | for m in self.modules(): 224 | if isinstance(m, nn.Conv3d): 225 | m.weight = nn.init.kaiming_normal_(m.weight, mode='fan_out') 226 | if m.bias is not None: m.bias.data.zero_() 227 | elif isinstance(m, nn.BatchNorm3d): 228 | m.weight.data.fill_(1) 229 | m.bias.data.zero_() 230 | 231 | def _make_layer(self, block, planes, blocks, stride=1, is_final=False): 232 | downsample = None 233 | if stride != 1 or self.inplanes != planes * block.expansion: 234 | # customized_stride to deal with 2d or 3d residual blocks 235 | if (block == Bottleneck2d) or (block == BasicBlock2d): 236 | customized_stride = (1, stride, stride) 237 | else: 238 | customized_stride = stride 239 | 240 | downsample = nn.Sequential( 241 | nn.Conv3d(self.inplanes, planes * block.expansion, kernel_size=1, stride=customized_stride, bias=False), 242 | nn.BatchNorm3d(planes * block.expansion, track_running_stats=self.track_running_stats) 243 | ) 244 | 245 | layers = [] 246 | layers.append(block(self.inplanes, planes, stride, downsample,
track_running_stats=self.track_running_stats)) 247 | self.inplanes = planes * block.expansion 248 | if is_final: # if is final block, no ReLU in the final output 249 | for i in range(1, blocks-1): 250 | layers.append(block(self.inplanes, planes, track_running_stats=self.track_running_stats)) 251 | layers.append(block(self.inplanes, planes, track_running_stats=self.track_running_stats, use_final_relu=False)) 252 | else: 253 | for i in range(1, blocks): 254 | layers.append(block(self.inplanes, planes, track_running_stats=self.track_running_stats)) 255 | 256 | return nn.Sequential(*layers) 257 | 258 | def forward(self, x): 259 | x = self.conv1(x) 260 | x = self.bn1(x) 261 | x = self.relu(x) 262 | x = self.maxpool(x) 263 | 264 | x = self.layer1(x) 265 | x = self.layer2(x) 266 | x = self.layer3(x) 267 | x = self.layer4(x) 268 | 269 | return x 270 | 271 | 272 | def resnet18_2d3d_full(**kwargs): 273 | '''Constructs a ResNet-18 model. ''' 274 | model = ResNet2d3d_full([BasicBlock2d, BasicBlock2d, BasicBlock3d, BasicBlock3d], 275 | [2, 2, 2, 2], **kwargs) 276 | return model 277 | 278 | def resnet34_2d3d_full(**kwargs): 279 | '''Constructs a ResNet-34 model. ''' 280 | model = ResNet2d3d_full([BasicBlock2d, BasicBlock2d, BasicBlock3d, BasicBlock3d], 281 | [3, 4, 6, 3], **kwargs) 282 | return model 283 | 284 | def resnet50_2d3d_full(**kwargs): 285 | '''Constructs a ResNet-50 model. ''' 286 | model = ResNet2d3d_full([Bottleneck2d, Bottleneck2d, Bottleneck3d, Bottleneck3d], 287 | [3, 4, 6, 3], **kwargs) 288 | return model 289 | 290 | -------------------------------------------------------------------------------- /backbone/select_backbone.py: -------------------------------------------------------------------------------- 1 | from .resnet_2d3d import * 2 | 3 | def select_resnet(network,): 4 | param = {'feature_size': 1024} 5 | if network == 'resnet18': 6 | model = resnet18_2d3d_full(track_running_stats=True) 7 | param['feature_size'] = 256 8 | elif network == 'resnet34': 9 | model = resnet34_2d3d_full(track_running_stats=True) 10 | param['feature_size'] = 256 11 | elif network == 'resnet50': 12 | model = resnet50_2d3d_full(track_running_stats=True) 13 | else: 14 | raise NotImplementedError 15 | 16 | return model, param -------------------------------------------------------------------------------- /eval/model_lc.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import sys 4 | sys.path.append('../') 5 | from backbone.select_backbone import select_resnet 6 | from backbone.convrnn import ConvGRU 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | class LC(nn.Module): 13 | '''Linear Classifier''' 14 | def __init__(self, sample_size, num_seq, seq_len, 15 | network='resnet18', dropout=0.5, num_class=101, train_what='all'): 16 | super(LC, self).__init__() 17 | torch.cuda.manual_seed(666) 18 | self.sample_size = sample_size 19 | self.num_seq = num_seq 20 | self.seq_len = seq_len 21 | self.num_class = num_class 22 | self.train_what = train_what 23 | 24 | print('=> Use 2D-3D %s backbone' % network) 25 | self.last_duration = int(math.ceil(seq_len / 4)) 26 | self.last_size = int(math.ceil(sample_size / 32)) 27 | 28 | self.backbone, self.param = select_resnet(network) 29 | self.param['num_layers'] = 1 30 | self.param['hidden_size'] = self.param['feature_size'] 31 | 32 | self.agg_f = ConvGRU(input_size=self.param['feature_size'], 33 | hidden_size=self.param['hidden_size'], 34 | kernel_size=1, 35 | 
num_layers=self.param['num_layers']) 36 | self.agg_b = ConvGRU(input_size=self.param['feature_size'], 37 | hidden_size=self.param['hidden_size'], 38 | kernel_size=1, 39 | num_layers=self.param['num_layers']) 40 | self._initialize_weights(self.agg_f) 41 | self._initialize_weights(self.agg_b) 42 | 43 | self.final_bn = nn.BatchNorm1d(self.param['feature_size']*2) 44 | self.final_bn.weight.data.fill_(1) 45 | self.final_bn.bias.data.zero_() 46 | 47 | self.final_fc = nn.Sequential(nn.Dropout(dropout), 48 | nn.Linear(self.param['feature_size']*2, self.num_class)) 49 | self._initialize_weights(self.final_fc) 50 | 51 | def forward(self, block): 52 | # input block: [B, N, C, SL, H, W] 53 | (B, N, C, SL, H, W) = block.shape 54 | block = block.view(B*N, C, SL, H, W) 55 | enable_grad = self.train_what!='last' 56 | with torch.set_grad_enabled(enable_grad): 57 | feature = self.backbone(block) 58 | feature = F.relu(feature) 59 | feature = F.avg_pool3d(feature, (self.last_duration, 1, 1), stride=1) 60 | feature = feature.view(B, N, self.param['feature_size'], self.last_size, self.last_size) # [B,N,D,last_size,last_size] 61 | 62 | context_forward, _ = self.agg_f(feature) 63 | context_forward = context_forward[:,-1,:].unsqueeze(1) 64 | context_forward = F.avg_pool3d(context_forward, (1, self.last_size, self.last_size), stride=1).squeeze(-1).squeeze(-1) 65 | 66 | feature_back = torch.flip(feature, dims=(1,)) 67 | context_back, _ = self.agg_b(feature_back) 68 | context_back = context_back[:,-1,:].unsqueeze(1) 69 | context_back = F.avg_pool3d(context_back, (1, self.last_size, self.last_size), stride=1).squeeze(-1).squeeze(-1) 70 | 71 | context = torch.cat([context_forward, context_back], dim=-1) # [B, N, 2C] 72 | 73 | context = self.final_bn(context.transpose(-1,-2)).transpose(-1,-2) # [B,N,C] -> [B,C,N] -> BN() -> [B,N,C], because BN operates on dim=1 (the channel dim).
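# final_fc is applied to every aggregated step, so the logits below come out as [B, N, num_class];
# the caller (eval/test.py) flattens over N for the training loss, or averages the softmax over N at test time.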
74 | output = self.final_fc(context).view(B, -1, self.num_class) 75 | 76 | return output, context 77 | 78 | def _initialize_weights(self, module): 79 | for name, param in module.named_parameters(): 80 | if 'bias' in name: 81 | nn.init.constant_(param, 0.0) 82 | elif 'weight' in name: 83 | nn.init.orthogonal_(param, 1) 84 | 85 | 86 | -------------------------------------------------------------------------------- /eval/test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import argparse 5 | import re 6 | import numpy as np 7 | import random 8 | import json 9 | from tqdm import tqdm 10 | from tensorboardX import SummaryWriter 11 | 12 | sys.path.append('../') 13 | sys.path.append('../memdpc/') 14 | from dataset import UCF101Dataset, HMDB51Dataset 15 | from model_lc import LC 16 | import utils.augmentation as A 17 | from utils.utils import AverageMeter, ConfusionMeter, save_checkpoint, \ 18 | calc_topk_accuracy, denorm, calc_accuracy, neq_load_customized, Logger 19 | 20 | import torch 21 | import torch.optim as optim 22 | from torch.utils import data 23 | import torch.nn as nn 24 | from torchvision import datasets, models, transforms 25 | import torchvision.utils as vutils 26 | 27 | def parse_args(): 28 | parser = argparse.ArgumentParser() 29 | parser.add_argument('--net', default='resnet18', type=str) 30 | parser.add_argument('--model', default='lc', type=str) 31 | parser.add_argument('--dataset', default='ucf101', type=str) 32 | parser.add_argument('--split', default=1, type=int) 33 | parser.add_argument('--seq_len', default=5, type=int) 34 | parser.add_argument('--num_seq', default=8, type=int) 35 | parser.add_argument('--num_class', default=101, type=int) 36 | parser.add_argument('--dropout', default=0.9, type=float) 37 | parser.add_argument('--ds', default=3, type=int) 38 | parser.add_argument('--batch_size', default=4, type=int) 39 | parser.add_argument('--lr', default=1e-3, type=float) 40 | parser.add_argument('--schedule', default=[], nargs='*', type=int, help='learning rate schedule (when to drop lr by 10x)') 41 | parser.add_argument('--wd', default=1e-3, type=float, help='weight decay') 42 | parser.add_argument('--resume', default='', type=str) 43 | parser.add_argument('--pretrain', default='random', type=str) 44 | parser.add_argument('--test', default='', type=str) 45 | parser.add_argument('--center_crop', action='store_true') 46 | parser.add_argument('--five_crop', action='store_true') 47 | parser.add_argument('--ten_crop', action='store_true') 48 | parser.add_argument('--epochs', default=10, type=int, help='number of total epochs to run') 49 | parser.add_argument('--start-epoch', default=0, type=int, help='manual epoch number (useful on restarts)') 50 | parser.add_argument('--gpu', default='0,1', type=str) 51 | parser.add_argument('--print_freq', default=5, type=int) 52 | parser.add_argument('--reset_lr', action='store_true', help='Reset learning rate when resume training?') 53 | parser.add_argument('--train_what', default='last', type=str, help='Train what parameters?') 54 | parser.add_argument('--prefix', default='tmp', type=str) 55 | parser.add_argument('--img_dim', default=128, type=int) 56 | parser.add_argument('--seed', default=0, type=int) 57 | parser.add_argument('-j', '--workers', default=16, type=int) 58 | args = parser.parse_args() 59 | return args 60 | 61 | 62 | def main(args): 63 | torch.manual_seed(args.seed) 64 | np.random.seed(args.seed) 65 | random.seed(args.seed) 66 | 67 | 
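# note: the GPUs listed in --gpu are exposed below via CUDA_VISIBLE_DEVICES and the model is later wrapped
# in nn.DataParallel, so --batch_size is treated as a per-GPU value and multiplied by the GPU count.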
os.environ["CUDA_VISIBLE_DEVICES"]=str(args.gpu) 68 | device = torch.device('cuda') 69 | num_gpu = len(str(args.gpu).split(',')) 70 | args.batch_size = num_gpu * args.batch_size 71 | 72 | if args.dataset == 'ucf101': args.num_class = 101 73 | elif args.dataset == 'hmdb51': args.num_class = 51 74 | 75 | ### classifier model ### 76 | if args.model == 'lc': 77 | model = LC(sample_size=args.img_dim, 78 | num_seq=args.num_seq, 79 | seq_len=args.seq_len, 80 | network=args.net, 81 | num_class=args.num_class, 82 | dropout=args.dropout, 83 | train_what=args.train_what) 84 | else: 85 | raise ValueError('wrong model!') 86 | 87 | model.to(device) 88 | model = nn.DataParallel(model) 89 | model_without_dp = model.module 90 | criterion = nn.CrossEntropyLoss() 91 | 92 | ### optimizer ### 93 | params = None 94 | if args.train_what == 'ft': 95 | print('=> finetune backbone with smaller lr') 96 | params = [] 97 | for name, param in model.module.named_parameters(): 98 | if ('resnet' in name) or ('rnn' in name): 99 | params.append({'params': param, 'lr': args.lr/10}) 100 | else: 101 | params.append({'params': param}) 102 | elif args.train_what == 'last': 103 | print('=> train only last layer') 104 | params = [] 105 | for name, param in model.named_parameters(): 106 | if ('bone' in name) or ('agg' in name) or ('mb' in name) or ('network_pred' in name): 107 | param.requires_grad = False 108 | else: params.append({'params': param}) 109 | else: 110 | pass # train all layers 111 | 112 | print('\n===========Check Grad============') 113 | for name, param in model.named_parameters(): 114 | print(name, param.requires_grad) 115 | print('=================================\n') 116 | 117 | if params is None: params = model.parameters() 118 | optimizer = optim.Adam(params, lr=args.lr, weight_decay=args.wd) 119 | 120 | ### scheduler ### 121 | if args.dataset == 'hmdb51': 122 | step = args.schedule 123 | if step == []: step = [150,250] 124 | lr_lambda = lambda ep: MultiStepLR_Restart_Multiplier(ep, gamma=0.1, step=step, repeat=1) 125 | elif args.dataset == 'ucf101': 126 | step = args.schedule 127 | if step == []: step = [300, 400] 128 | lr_lambda = lambda ep: MultiStepLR_Restart_Multiplier(ep, gamma=0.1, step=step, repeat=1) 129 | lr_scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda) 130 | print('=> Using scheduler at {} epochs'.format(step)) 131 | 132 | args.old_lr = None 133 | best_acc = 0 134 | args.iteration = 1 135 | 136 | ### if in test mode ### 137 | if args.test: 138 | if os.path.isfile(args.test): 139 | print("=> loading test checkpoint '{}'".format(args.test)) 140 | checkpoint = torch.load(args.test, map_location=torch.device('cpu')) 141 | try: 142 | model_without_dp.load_state_dict(checkpoint['state_dict']) 143 | except: 144 | print('=> [Warning]: weight structure is not equal to test model; Load anyway ==') 145 | model_without_dp = neq_load_customized(model_without_dp, checkpoint['state_dict']) 146 | epoch = checkpoint['epoch'] 147 | print("=> loaded testing checkpoint '{}' (epoch {})".format(args.test, checkpoint['epoch'])) 148 | elif args.test == 'random': 149 | epoch = 0 150 | print("=> loaded random weights") 151 | else: 152 | print("=> no checkpoint found at '{}'".format(args.test)) 153 | sys.exit(0) 154 | 155 | args.logger = Logger(path=os.path.dirname(args.test)) 156 | _, test_dataset = get_data(None, 'test') 157 | test_loss, test_acc = test(test_dataset, model, criterion, device, epoch, args) 158 | sys.exit() 159 | 160 | ### restart training ### 161 | if args.resume: 162 | if 
os.path.isfile(args.resume): 163 | print("=> loading resumed checkpoint '{}'".format(args.resume)) 164 | checkpoint = torch.load(args.resume, map_location=torch.device('cpu')) 165 | args.start_epoch = checkpoint['epoch'] 166 | args.iteration = checkpoint['iteration'] 167 | best_acc = checkpoint['best_acc'] 168 | model_without_dp.load_state_dict(checkpoint['state_dict']) 169 | try: 170 | optimizer.load_state_dict(checkpoint['optimizer']) 171 | except: 172 | print('[WARNING] Not loading optimizer states') 173 | print("=> loaded resumed checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch'])) 174 | else: 175 | print("=> no checkpoint found at '{}'".format(args.resume)) 176 | sys.exit(0) 177 | 178 | if (not args.resume) and args.pretrain: 179 | if args.pretrain == 'random': 180 | print('=> using random weights') 181 | elif os.path.isfile(args.pretrain): 182 | print("=> loading pretrained checkpoint '{}'".format(args.pretrain)) 183 | checkpoint = torch.load(args.pretrain, map_location=torch.device('cpu')) 184 | model_without_dp = neq_load_customized(model_without_dp, checkpoint['state_dict']) 185 | print("=> loaded pretrained checkpoint '{}' (epoch {})".format(args.pretrain, checkpoint['epoch'])) 186 | else: 187 | print("=> no checkpoint found at '{}'".format(args.pretrain)) 188 | sys.exit(0) 189 | 190 | ### data ### 191 | transform = transforms.Compose([ 192 | A.RandomSizedCrop(consistent=True, size=224, p=1.0), 193 | A.Scale(size=(args.img_dim,args.img_dim)), 194 | A.RandomHorizontalFlip(consistent=True), 195 | A.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.25, p=0.3, consistent=True), 196 | A.ToTensor(), 197 | A.Normalize() 198 | ]) 199 | val_transform = transforms.Compose([ 200 | A.RandomSizedCrop(consistent=True, size=224, p=0.3), 201 | A.Scale(size=(args.img_dim,args.img_dim)), 202 | A.RandomHorizontalFlip(consistent=True), 203 | A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, p=0.3, consistent=True), 204 | A.ToTensor(), 205 | A.Normalize() 206 | ]) 207 | 208 | train_loader, _ = get_data(transform, 'train') 209 | val_loader, _ = get_data(val_transform, 'val') 210 | 211 | # setup tools 212 | args.img_path, args.model_path = set_path(args) 213 | args.writer_val = SummaryWriter(logdir=os.path.join(args.img_path, 'val')) 214 | args.writer_train = SummaryWriter(logdir=os.path.join(args.img_path, 'train')) 215 | torch.backends.cudnn.benchmark = True 216 | 217 | ### main loop ### 218 | for epoch in range(args.start_epoch, args.epochs): 219 | train_loss, train_acc = train_one_epoch(train_loader, 220 | model, 221 | criterion, 222 | optimizer, 223 | device, 224 | epoch, 225 | args) 226 | val_loss, val_acc = validate(val_loader, 227 | model, 228 | criterion, 229 | device, 230 | epoch, 231 | args) 232 | lr_scheduler.step(epoch) 233 | 234 | # save check_point 235 | is_best = val_acc > best_acc 236 | best_acc = max(val_acc, best_acc) 237 | save_dict = { 238 | 'epoch': epoch, 239 | 'backbone': args.net, 240 | 'state_dict': model_without_dp.state_dict(), 241 | 'best_acc': best_acc, 242 | 'optimizer': optimizer.state_dict(), 243 | 'iteration': args.iteration} 244 | save_checkpoint(save_dict, is_best, 245 | filename=os.path.join(args.model_path, 'epoch%s.pth.tar' % str(epoch)), 246 | keep_all=False) 247 | 248 | print('Training from ep %d to ep %d finished' 249 | % (args.start_epoch, args.epochs)) 250 | sys.exit(0) 251 | 252 | 253 | def train_one_epoch(data_loader, model, criterion, optimizer, device, epoch, args): 254 | batch_time = AverageMeter() 255 | 
data_time = AverageMeter() 256 | losses = AverageMeter() 257 | accuracy = AverageMeter() 258 | 259 | if args.train_what == 'last': 260 | model.eval() 261 | model.module.final_bn.train() 262 | model.module.final_fc.train() 263 | print('[Warning] train model with eval mode, except final layer') 264 | else: 265 | model.train() 266 | 267 | end = time.time() 268 | tic = time.time() 269 | 270 | for idx, (input_seq, target) in enumerate(data_loader): 271 | data_time.update(time.time() - end) 272 | input_seq = input_seq.to(device) 273 | target = target.to(device) 274 | B = input_seq.size(0) 275 | output, _ = model(input_seq) 276 | 277 | [_, N, D] = output.size() 278 | output = output.view(B*N, D) 279 | target = target.repeat(1, N).view(-1) 280 | 281 | loss = criterion(output, target) 282 | acc = calc_accuracy(output, target) 283 | 284 | losses.update(loss.item(), B) 285 | accuracy.update(acc.item(), B) 286 | 287 | optimizer.zero_grad() 288 | loss.backward() 289 | optimizer.step() 290 | 291 | batch_time.update(time.time() - end) 292 | end = time.time() 293 | 294 | if idx % args.print_freq == 0: 295 | print('Epoch: [{0}][{1}/{2}]\t' 296 | 'Loss {loss.val:.4f} ({loss.local_avg:.4f})\t' 297 | 'Acc: {acc.val:.4f} ({acc.local_avg:.4f})\t' 298 | 'T-data:{dt.val:.2f} T-batch:{bt.val:.2f}\t'.format( 299 | epoch, idx, len(data_loader), 300 | loss=losses, acc=accuracy, dt=data_time, bt=batch_time)) 301 | 302 | args.writer_train.add_scalar('local/loss', losses.val, args.iteration) 303 | args.writer_train.add_scalar('local/accuracy', accuracy.val, args.iteration) 304 | 305 | args.iteration += 1 306 | print('Epoch: [{0}]\t' 307 | 'T-epoch:{t:.2f}\t'.format(epoch, t=time.time()-tic)) 308 | 309 | args.writer_train.add_scalar('global/loss', losses.avg, epoch) 310 | args.writer_train.add_scalar('global/accuracy', accuracy.avg, epoch) 311 | 312 | return losses.avg, accuracy.avg 313 | 314 | 315 | def validate(data_loader, model, criterion, device, epoch, args): 316 | losses = AverageMeter() 317 | accuracy = AverageMeter() 318 | model.eval() 319 | with torch.no_grad(): 320 | for idx, (input_seq, target) in tqdm(enumerate(data_loader), total=len(data_loader)): 321 | input_seq = input_seq.to(device) 322 | target = target.to(device) 323 | B = input_seq.size(0) 324 | output, _ = model(input_seq) 325 | 326 | [_, N, D] = output.size() 327 | output = output.view(B*N, D) 328 | target = target.repeat(1, N).view(-1) 329 | 330 | loss = criterion(output, target) 331 | acc = calc_accuracy(output, target) 332 | 333 | losses.update(loss.item(), B) 334 | accuracy.update(acc.item(), B) 335 | 336 | print('Loss {loss.avg:.4f}\t' 337 | 'Acc: {acc.avg:.4f} \t'.format(loss=losses, acc=accuracy)) 338 | args.writer_val.add_scalar('global/loss', losses.avg, epoch) 339 | args.writer_val.add_scalar('global/accuracy', accuracy.avg, epoch) 340 | 341 | return losses.avg, accuracy.avg 342 | 343 | 344 | def test(dataset, model, criterion, device, epoch, args): 345 | # 10-crop then average the probability 346 | prob_dict = {} 347 | model.eval() 348 | 349 | # aug_list: 1,2,3,4,5 = top-left, top-right, bottom-left, bottom-right, center 350 | # flip_list: 0,1 = original, horizontal-flip 351 | if args.center_crop: 352 | print('Test using center crop') 353 | args.logger.log('Test using center_crop\n') 354 | aug_list = [5]; flip_list = [0]; title = 'center' 355 | if args.five_crop: 356 | print('Test using 5 crop') 357 | args.logger.log('Test using 5_crop\n') 358 | aug_list = [5,1,2,3,4]; flip_list = [0]; title = 'five' 359 | if args.ten_crop: 360 | 
print('Test using 10 crop') 361 | args.logger.log('Test using 10_crop\n') 362 | aug_list = [5,1,2,3,4]; flip_list = [0,1]; title = 'ten' 363 | 364 | with torch.no_grad(): 365 | end = time.time() 366 | for flip_idx in flip_list: 367 | for aug_idx in aug_list: 368 | print('Aug type: %d; flip: %d' % (aug_idx, flip_idx)) 369 | if flip_idx == 0: 370 | transform = transforms.Compose([ 371 | A.RandomHorizontalFlip(command='left'), 372 | A.FiveCrop(size=(224,224), where=aug_idx), 373 | A.Scale(size=(args.img_dim,args.img_dim)), 374 | A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, p=0.3, consistent=True), 375 | A.ToTensor(), 376 | ]) 377 | else: 378 | transform = transforms.Compose([ 379 | A.RandomHorizontalFlip(command='right'), 380 | A.FiveCrop(size=(224,224), where=aug_idx), 381 | A.Scale(size=(args.img_dim,args.img_dim)), 382 | A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, p=0.3, consistent=True), 383 | A.ToTensor(), 384 | ]) 385 | 386 | dataset.transform = transform 387 | dataset.return_path = True 388 | dataset.return_label = True 389 | data_sampler = data.RandomSampler(dataset) 390 | data_loader = data.DataLoader(dataset, 391 | batch_size=1, 392 | sampler=data_sampler, 393 | shuffle=False, 394 | num_workers=16, 395 | pin_memory=True) 396 | 397 | 398 | for idx, (input_seq, target) in tqdm(enumerate(data_loader), total=len(data_loader)): 399 | B = 1 400 | input_seq = input_seq.to(device) 401 | target, vname = target 402 | target = target.to(device) 403 | input_seq = input_seq.squeeze(0) # squeeze the '1' batch dim 404 | output, _ = model(input_seq) 405 | 406 | prob_mean = nn.functional.softmax(output, 2).mean(1).mean(0, keepdim=True) 407 | 408 | vname = vname[0] 409 | if vname not in prob_dict.keys(): 410 | prob_dict[vname] = [] 411 | prob_dict[vname].append(prob_mean) 412 | 413 | # show intermediate result 414 | if (title == 'ten') and (flip_idx == 0) and (aug_idx == 5): 415 | print('center-crop result:') 416 | acc_1 = summarize_probability(prob_dict, 417 | data_loader.dataset.encode_action, 'center') 418 | args.logger.log('center-crop:') 419 | args.logger.log('test Epoch: [{0}]\t' 420 | 'Mean: Acc@1: {acc[0].avg:.4f} Acc@5: {acc[1].avg:.4f}' 421 | .format(epoch, acc=acc_1)) 422 | 423 | # show intermediate result 424 | if (title == 'ten') and (flip_idx == 0): 425 | print('five-crop result:') 426 | acc_5 = summarize_probability(prob_dict, 427 | data_loader.dataset.encode_action, 'five') 428 | args.logger.log('five-crop:') 429 | args.logger.log('test Epoch: [{0}]\t' 430 | 'Mean: Acc@1: {acc[0].avg:.4f} Acc@5: {acc[1].avg:.4f}' 431 | .format(epoch, acc=acc_5)) 432 | 433 | # show final result 434 | print('%s-crop result:' % title) 435 | acc_final = summarize_probability(prob_dict, 436 | data_loader.dataset.encode_action, 'ten') 437 | args.logger.log('%s-crop:' % title) 438 | args.logger.log('test Epoch: [{0}]\t' 439 | 'Mean: Acc@1: {acc[0].avg:.4f} Acc@5: {acc[1].avg:.4f}' 440 | .format(epoch, acc=acc_final)) 441 | sys.exit(0) 442 | 443 | 444 | def summarize_probability(prob_dict, action_to_idx, title): 445 | acc = [AverageMeter(),AverageMeter()] 446 | stat = {} 447 | for vname, item in tqdm(prob_dict.items(), total=len(prob_dict)): 448 | try: 449 | action_name = vname.split('/')[-3] 450 | except: 451 | action_name = vname.split('/')[-2] 452 | target = action_to_idx(action_name) 453 | mean_prob = torch.stack(item, 0).mean(0) 454 | mean_top1, mean_top5 = calc_topk_accuracy(mean_prob, torch.LongTensor([target]).cuda(), (1,5)) 455 | stat[vname] = 
{'mean_prob': mean_prob.tolist()} 456 | acc[0].update(mean_top1.item(), 1) 457 | acc[1].update(mean_top5.item(), 1) 458 | 459 | print('Mean: Acc@1: {acc[0].avg:.4f} Acc@5: {acc[1].avg:.4f}' 460 | .format(acc=acc)) 461 | 462 | with open(os.path.join(os.path.dirname(args.test), 463 | '%s-prob-%s.json' % (os.path.basename(args.test), title)), 'w') as fp: 464 | json.dump(stat, fp) 465 | return acc 466 | 467 | 468 | def get_data(transform, mode='train'): 469 | print('Loading data for "%s" ...' % mode) 470 | global dataset 471 | if args.dataset == 'ucf101': 472 | dataset = UCF101Dataset(mode=mode, 473 | transform=transform, 474 | seq_len=args.seq_len, 475 | num_seq=args.num_seq, 476 | downsample=args.ds, 477 | which_split=args.split, 478 | return_label=True) 479 | elif args.dataset == 'hmdb51': 480 | dataset = HMDB51Dataset(mode=mode, 481 | transform=transform, 482 | seq_len=args.seq_len, 483 | num_seq=args.num_seq, 484 | downsample=args.ds, 485 | which_split=args.split, 486 | return_label=True) 487 | else: 488 | raise ValueError('dataset not supported') 489 | my_sampler = data.RandomSampler(dataset) 490 | if mode == 'train': 491 | data_loader = data.DataLoader(dataset, 492 | batch_size=args.batch_size, 493 | sampler=my_sampler, 494 | shuffle=False, 495 | num_workers=args.workers, 496 | pin_memory=True, 497 | drop_last=True) 498 | elif mode == 'val': 499 | data_loader = data.DataLoader(dataset, 500 | batch_size=args.batch_size, 501 | sampler=my_sampler, 502 | shuffle=False, 503 | num_workers=args.workers, 504 | pin_memory=True, 505 | drop_last=True) 506 | elif mode == 'test': 507 | data_loader = data.DataLoader(dataset, 508 | batch_size=1, 509 | sampler=my_sampler, 510 | shuffle=False, 511 | num_workers=args.workers, 512 | pin_memory=True) 513 | print('"%s" dataset size: %d' % (mode, len(dataset))) 514 | return data_loader, dataset 515 | 516 | 517 | def set_path(args): 518 | if args.resume: exp_path = os.path.dirname(os.path.dirname(args.resume)) 519 | else: 520 | exp_path = 'log_{args.prefix}/{args.dataset}-{args.img_dim}-\ 521 | sp{args.split}_{args.net}_{args.model}_bs{args.batch_size}_\ 522 | lr{0}_wd{args.wd}_ds{args.ds}_seq{args.num_seq}_len{args.seq_len}_\ 523 | dp{args.dropout}_train-{args.train_what}{1}'.format( 524 | args.old_lr if args.old_lr is not None else args.lr, \ 525 | '_pt='+args.pretrain.replace('/','-') if args.pretrain else '', \ 526 | args=args) 527 | img_path = os.path.join(exp_path, 'img') 528 | model_path = os.path.join(exp_path, 'model') 529 | if not os.path.exists(img_path): os.makedirs(img_path) 530 | if not os.path.exists(model_path): os.makedirs(model_path) 531 | return img_path, model_path 532 | 533 | 534 | def MultiStepLR_Restart_Multiplier(epoch, gamma=0.1, step=[10,15,20], repeat=3): 535 | '''return the multipier for LambdaLR, 536 | 0 <= ep < 10: gamma^0 537 | 10 <= ep < 15: gamma^1 538 | 15 <= ep < 20: gamma^2 539 | 20 <= ep < 30: gamma^0 ... 
repeat 3 cycles and then keep gamma^2''' 540 | max_step = max(step) 541 | effective_epoch = epoch % max_step 542 | if epoch // max_step >= repeat: 543 | exp = len(step) - 1 544 | else: 545 | exp = len([i for i in step if effective_epoch>=i]) 546 | return gamma ** exp 547 | 548 | 549 | if __name__ == '__main__': 550 | args = parse_args() 551 | main(args) 552 | -------------------------------------------------------------------------------- /memdpc/dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils import data 3 | from torchvision import transforms 4 | import os 5 | import sys 6 | import time 7 | import pickle 8 | import glob 9 | import csv 10 | import pandas as pd 11 | import numpy as np 12 | from tqdm import tqdm 13 | from PIL import Image 14 | 15 | def read_file(path): 16 | with open(path, 'r') as f: 17 | content = f.readlines() 18 | content = [i.strip() for i in content] 19 | return content 20 | 21 | def pil_loader(path): 22 | with open(path, 'rb') as f: 23 | with Image.open(f) as img: 24 | return img.convert('RGB') 25 | 26 | 27 | class K400Dataset(data.Dataset): 28 | def __init__(self, 29 | root='%s/../process_data/data/k400' % os.path.dirname(os.path.abspath(__file__)), 30 | mode='val', 31 | transform=None, 32 | seq_len=5, 33 | num_seq=8, 34 | downsample=3, 35 | return_label=False): 36 | self.mode = mode 37 | self.transform = transform 38 | self.seq_len = seq_len 39 | self.num_seq = num_seq 40 | self.downsample = downsample 41 | self.return_label = return_label 42 | 43 | classes = read_file(os.path.join(root, 'ClassInd.txt')) 44 | print('Frame Dataset from {} has #class {}'.format(root, len(classes))) 45 | self.num_class = len(classes) 46 | self.class_to_idx = {classes[i]:i for i in range(len(classes))} 47 | self.idx_to_class = {i:classes[i] for i in range(len(classes))} 48 | 49 | # splits 50 | if mode == 'train': 51 | split = '../process_data/data/kinetics400/train_split.csv' 52 | video_info = pd.read_csv(split, header=None) 53 | elif (mode == 'val') or (mode == 'test'): 54 | split = '../process_data/data/kinetics400/val_split.csv' 55 | video_info = pd.read_csv(split, header=None) 56 | else: raise ValueError('wrong mode') 57 | 58 | drop_idx = [] 59 | print('filter out too short videos ...') 60 | for idx, row in tqdm(video_info.iterrows(), total=len(video_info)): 61 | vpath, vlen = row 62 | if vlen-self.num_seq*self.seq_len*self.downsample <= 0: 63 | drop_idx.append(idx) 64 | self.video_info = video_info.drop(drop_idx, axis=0) 65 | 66 | if mode == 'val': 67 | self.video_info = self.video_info.sample(frac=0.3, random_state=666) 68 | 69 | def idx_sampler(self, vlen, vpath): 70 | '''sample index from a video''' 71 | start_idx = np.random.choice(range(vlen-self.num_seq*self.seq_len*self.downsample), 1) 72 | seq_idx = np.arange(self.num_seq*self.seq_len)*self.downsample + start_idx 73 | return seq_idx 74 | 75 | def __getitem__(self, index): 76 | vpath, vlen = self.video_info.iloc[index] 77 | frame_index = self.idx_sampler(vlen, vpath) 78 | 79 | seq = [pil_loader(os.path.join(vpath, 'image_%05d.jpg' % (i+1))) for i in frame_index] 80 | t_seq = self.transform(seq) 81 | 82 | (C, H, W) = t_seq[0].size() 83 | t_seq = torch.stack(t_seq, 0) 84 | t_seq = t_seq.view(self.num_seq, self.seq_len, C, H, W).transpose(1,2) 85 | 86 | if self.return_label: 87 | try: 88 | vname = vpath.split('/')[-3] 89 | vid = self.encode_action(vname) 90 | except: 91 | vname = vpath.split('/')[-2] 92 | vid = self.encode_action(vname) 93 | 94 | 
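# the action label is recovered from the frame directory path: the try/except above handles both
# .../<class_name>/<video>/ and .../<class_name>/ layouts, then maps the class name to an index from ClassInd.txt.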
label = torch.LongTensor([vid]) 95 | return t_seq, label 96 | 97 | return t_seq 98 | 99 | def __len__(self): 100 | return len(self.video_info) 101 | 102 | def encode_action(self, action_name): 103 | return self.class_to_idx[action_name] 104 | 105 | def decode_action(self, action_code): 106 | return self.idx_to_class[action_code] 107 | 108 | 109 | class UCF101Dataset(data.Dataset): 110 | def __init__(self, 111 | root='%s/../process_data/data/ucf101' % os.path.dirname(os.path.abspath(__file__)), 112 | mode='val', 113 | transform=None, 114 | seq_len=5, 115 | num_seq=8, 116 | downsample=3, 117 | which_split=1, 118 | return_label=False, 119 | return_path=False): 120 | self.mode = mode 121 | self.transform = transform 122 | self.seq_len = seq_len 123 | self.num_seq = num_seq 124 | self.downsample = downsample 125 | self.which_split = which_split 126 | self.return_label = return_label 127 | self.return_path = return_path 128 | 129 | # splits 130 | if mode == 'train': 131 | split = '../process_data/data/ucf101/train_split%02d.csv' % self.which_split 132 | video_info = pd.read_csv(split, header=None) 133 | elif (mode == 'val') or (mode == 'test'): # use val for test 134 | split = '../process_data/data/ucf101/test_split%02d.csv' % self.which_split 135 | video_info = pd.read_csv(split, header=None) 136 | else: raise ValueError('wrong mode') 137 | 138 | # get action list 139 | classes = read_file(os.path.join(root, 'ClassInd.txt')) 140 | print('Frame Dataset from {} has #class {}'.format(root, len(classes))) 141 | self.num_class = len(classes) 142 | self.class_to_idx = {classes[i]:i for i in range(len(classes))} 143 | self.idx_to_class = {i:classes[i] for i in range(len(classes))} 144 | 145 | # filter out too short videos: 146 | drop_idx = [] 147 | for idx, row in video_info.iterrows(): 148 | vpath, vlen = row 149 | if vlen-self.num_seq*self.seq_len*self.downsample <= 0: 150 | drop_idx.append(idx) 151 | self.video_info = video_info.drop(drop_idx, axis=0) 152 | 153 | if mode == 'val': 154 | self.video_info = self.video_info.sample(frac=0.3) 155 | 156 | def idx_sampler(self, vlen, vpath): 157 | '''sample index from a video''' 158 | if self.mode == 'test': 159 | available = vlen-self.num_seq*self.seq_len*self.downsample 160 | start_idx = np.expand_dims(np.arange(0, available+1, self.num_seq*self.seq_len*self.downsample//2-1), 1) 161 | seq_idx = np.expand_dims(np.arange(self.num_seq*self.seq_len)*self.downsample, 0) + start_idx # [test_sample, num_frames] 162 | seq_idx = seq_idx.flatten(0) 163 | else: 164 | start_idx = np.random.choice(range(vlen-self.num_seq*self.seq_len*self.downsample), 1) 165 | seq_idx = np.arange(self.num_seq*self.seq_len)*self.downsample + start_idx 166 | return seq_idx 167 | 168 | 169 | def __getitem__(self, index): 170 | vpath, vlen = self.video_info.iloc[index] 171 | frame_index = self.idx_sampler(vlen, vpath) 172 | 173 | seq = [pil_loader(os.path.join(vpath, 'image_%05d.jpg' % (i+1))) for i in frame_index] 174 | t_seq = self.transform(seq) 175 | 176 | (C, H, W) = t_seq[0].size() 177 | t_seq = torch.stack(t_seq, 0) 178 | if self.mode == 'test': 179 | t_seq = t_seq.view(-1, self.num_seq, self.seq_len, C, H, W).transpose(2,3) 180 | else: 181 | t_seq = t_seq.view(self.num_seq, self.seq_len, C, H, W).transpose(1,2) 182 | 183 | if self.return_label: 184 | try: 185 | vname = vpath.split('/')[-3] 186 | vid = self.encode_action(vname) 187 | except: 188 | vname = vpath.split('/')[-2] 189 | vid = self.encode_action(vname) 190 | label = torch.LongTensor([vid]) 191 | if self.return_path: 192 
| return t_seq, (label, vpath) 193 | else: 194 | return t_seq, label 195 | 196 | return t_seq 197 | 198 | def __len__(self): 199 | return len(self.video_info) 200 | 201 | def encode_action(self, action_name): 202 | return self.class_to_idx[action_name] 203 | 204 | def decode_action(self, action_code): 205 | return self.idx_to_class[action_code] 206 | 207 | 208 | class HMDB51Dataset(data.Dataset): 209 | def __init__(self, 210 | root='%s/../process_data/data/hmdb51' % os.path.dirname(os.path.abspath(__file__)), 211 | mode='val', 212 | transform=None, 213 | seq_len=5, 214 | num_seq=8, 215 | downsample=3, 216 | which_split=1, 217 | return_label=False, 218 | return_path=False): 219 | self.mode = mode 220 | self.transform = transform 221 | self.seq_len = seq_len 222 | self.num_seq = num_seq 223 | self.downsample = downsample 224 | self.which_split = which_split 225 | self.return_label = return_label 226 | self.return_path = return_path 227 | 228 | # splits 229 | if mode == 'train': 230 | split = '../process_data/data/hmdb51/train_split%02d.csv' % self.which_split 231 | video_info = pd.read_csv(split, header=None) 232 | elif (mode == 'val') or (mode == 'test'): # use val for test 233 | split = '../process_data/data/hmdb51/test_split%02d.csv' % self.which_split 234 | video_info = pd.read_csv(split, header=None) 235 | else: raise ValueError('wrong mode') 236 | 237 | # get action list 238 | classes = read_file(os.path.join(root, 'ClassInd.txt')) 239 | print('Frame Dataset from {} has #class {}'.format(root, len(classes))) 240 | self.num_class = len(classes) 241 | self.class_to_idx = {classes[i]:i for i in range(len(classes))} 242 | self.idx_to_class = {i:classes[i] for i in range(len(classes))} 243 | 244 | # filter out too short videos: 245 | drop_idx = [] 246 | for idx, row in video_info.iterrows(): 247 | vpath, vlen = row 248 | if vlen-self.num_seq*self.seq_len*self.downsample <= 0: 249 | drop_idx.append(idx) 250 | self.video_info = video_info.drop(drop_idx, axis=0) 251 | 252 | if mode == 'val': 253 | self.video_info = self.video_info.sample(frac=0.3) 254 | 255 | def idx_sampler(self, vlen, vpath): 256 | '''sample index from a video''' 257 | if self.mode == 'test': 258 | available = vlen-self.num_seq*self.seq_len*self.downsample 259 | start_idx = np.expand_dims(np.arange(0, available+1, self.num_seq*self.seq_len*self.downsample//2-1), 1) 260 | seq_idx = np.expand_dims(np.arange(self.num_seq*self.seq_len)*self.downsample, 0) + start_idx # [test_sample, num_frames] 261 | seq_idx = seq_idx.flatten(0) 262 | else: 263 | start_idx = np.random.choice(range(vlen-self.num_seq*self.seq_len*self.downsample), 1) 264 | seq_idx = np.arange(self.num_seq*self.seq_len)*self.downsample + start_idx 265 | return seq_idx 266 | 267 | 268 | def __getitem__(self, index): 269 | vpath, vlen = self.video_info.iloc[index] 270 | frame_index = self.idx_sampler(vlen, vpath) 271 | 272 | seq = [pil_loader(os.path.join(vpath, 'image_%05d.jpg' % (i+1))) for i in frame_index] 273 | t_seq = self.transform(seq) 274 | 275 | (C, H, W) = t_seq[0].size() 276 | t_seq = torch.stack(t_seq, 0) 277 | if self.mode == 'test': 278 | t_seq = t_seq.view(-1, self.num_seq, self.seq_len, C, H, W).transpose(2,3) 279 | else: 280 | t_seq = t_seq.view(self.num_seq, self.seq_len, C, H, W).transpose(1,2) 281 | 282 | if self.return_label: 283 | try: 284 | vname = vpath.split('/')[-3] 285 | vid = self.encode_action(vname) 286 | except: 287 | vname = vpath.split('/')[-2] 288 | vid = self.encode_action(vname) 289 | label = torch.LongTensor([vid]) 290 | if 
self.return_path: 291 | return t_seq, (label, vpath) 292 | else: 293 | return t_seq, label 294 | 295 | return t_seq 296 | 297 | def __len__(self): 298 | return len(self.video_info) 299 | 300 | def encode_action(self, action_name): 301 | return self.class_to_idx[action_name] 302 | 303 | def decode_action(self, action_code): 304 | return self.idx_to_class[action_code] 305 | -------------------------------------------------------------------------------- /memdpc/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import re 5 | import argparse 6 | import numpy as np 7 | import random 8 | from tqdm import tqdm 9 | from tensorboardX import SummaryWriter 10 | import matplotlib.pyplot as plt 11 | plt.switch_backend('agg') 12 | 13 | sys.path.append('../') 14 | from dataset import K400Dataset, UCF101Dataset 15 | from model import MemDPC_BD 16 | 17 | import utils.augmentation as A 18 | from utils.utils import AverageMeter, save_checkpoint, Logger,\ 19 | calc_topk_accuracy, neq_load_customized, MultiStepLR_Restart_Multiplier 20 | 21 | import torch 22 | import torch.nn as nn 23 | import torch.optim as optim 24 | from torch.utils import data 25 | from torchvision import datasets, models, transforms 26 | import torchvision.utils as vutils 27 | 28 | 29 | def parse_args(): 30 | parser = argparse.ArgumentParser() 31 | parser.add_argument('--net', default='resnet18', type=str) 32 | parser.add_argument('--model', default='memdpc', type=str) 33 | parser.add_argument('--dataset', default='ucf101', type=str) 34 | parser.add_argument('--seq_len', default=5, type=int, help='number of frames in each video block') 35 | parser.add_argument('--num_seq', default=8, type=int, help='number of video blocks') 36 | parser.add_argument('--pred_step', default=3, type=int) 37 | parser.add_argument('--ds', default=3, type=int, help='frame downsampling rate') 38 | parser.add_argument('--mem_size', default=1024, type=int, help='memory size') 39 | parser.add_argument('--batch_size', default=4, type=int) 40 | parser.add_argument('--lr', default=1e-3, type=float, help='learning rate') 41 | parser.add_argument('--wd', default=1e-5, type=float, help='weight decay') 42 | parser.add_argument('--resume', default='', type=str, help='path of model to resume') 43 | parser.add_argument('--pretrain', default='', type=str, help='path of pretrained model') 44 | parser.add_argument('--epochs', default=10, type=int, help='number of total epochs to run') 45 | parser.add_argument('--start-epoch', default=0, type=int, help='manual epoch number (useful on restarts)') 46 | parser.add_argument('--gpu', default='0,1', type=str) 47 | parser.add_argument('--print_freq', default=5, type=int, help='frequency of printing output during training') 48 | parser.add_argument('--reset_lr', action='store_true', help='Reset learning rate when resume training?') 49 | parser.add_argument('--prefix', default='tmp', type=str, help='prefix of checkpoint filename') 50 | parser.add_argument('--img_dim', default=128, type=int) 51 | parser.add_argument('--seed', default=0, type=int) 52 | parser.add_argument('-j', '--workers', default=16, type=int) 53 | args = parser.parse_args() 54 | return args 55 | 56 | 57 | def main(args): 58 | torch.manual_seed(args.seed) 59 | np.random.seed(args.seed) 60 | random.seed(args.seed) 61 | 62 | os.environ["CUDA_VISIBLE_DEVICES"]=str(args.gpu) 63 | device = torch.device('cuda') 64 | num_gpu = len(str(args.gpu).split(',')) 65 | args.batch_size = num_gpu * 
args.batch_size 66 | 67 | ### model ### 68 | if args.model == 'memdpc': 69 | model = MemDPC_BD(sample_size=args.img_dim, 70 | num_seq=args.num_seq, 71 | seq_len=args.seq_len, 72 | network=args.net, 73 | pred_step=args.pred_step, 74 | mem_size=args.mem_size) 75 | else: 76 | raise NotImplementedError('wrong model!') 77 | 78 | model.to(device) 79 | model = nn.DataParallel(model) 80 | model_without_dp = model.module 81 | 82 | ### optimizer ### 83 | params = model.parameters() 84 | optimizer = optim.Adam(params, lr=args.lr, weight_decay=args.wd) 85 | criterion = nn.CrossEntropyLoss() 86 | 87 | ### data ### 88 | transform = transforms.Compose([ 89 | A.RandomSizedCrop(size=224, consistent=True, p=1.0), # crop from 256 to 224 90 | A.Scale(size=(args.img_dim,args.img_dim)), 91 | A.RandomHorizontalFlip(consistent=True), 92 | A.RandomGray(consistent=False, p=0.25), 93 | A.ColorJitter(0.5, 0.5, 0.5, 0.25, consistent=False, p=1.0), 94 | A.ToTensor(), 95 | A.Normalize() 96 | ]) 97 | 98 | train_loader = get_data(transform, 'train') 99 | val_loader = get_data(transform, 'val') 100 | 101 | if 'ucf' in args.dataset: 102 | lr_milestones_eps = [300,400] 103 | elif 'k400' in args.dataset: 104 | lr_milestones_eps = [120,160] 105 | else: 106 | lr_milestones_eps = [1000] # NEVER 107 | lr_milestones = [len(train_loader) * m for m in lr_milestones_eps] 108 | print('=> Use lr_scheduler: %s eps == %s iters' % (str(lr_milestones_eps), str(lr_milestones))) 109 | lr_lambda = lambda ep: MultiStepLR_Restart_Multiplier(ep, gamma=0.1, step=lr_milestones, repeat=1) 110 | lr_scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda) 111 | 112 | best_acc = 0 113 | args.iteration = 1 114 | 115 | ### restart training ### 116 | if args.resume: 117 | if os.path.isfile(args.resume): 118 | print("=> loading resumed checkpoint '{}'".format(args.resume)) 119 | checkpoint = torch.load(args.resume, map_location=torch.device('cpu')) 120 | args.start_epoch = checkpoint['epoch'] 121 | args.iteration = checkpoint['iteration'] 122 | best_acc = checkpoint['best_acc'] 123 | model_without_dp.load_state_dict(checkpoint['state_dict']) 124 | try: 125 | optimizer.load_state_dict(checkpoint['optimizer']) 126 | except: 127 | print('[WARNING] Not loading optimizer states') 128 | print("=> loaded resumed checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch'])) 129 | else: 130 | print("[Warning] no checkpoint found at '{}'".format(args.resume)) 131 | sys.exit(0) 132 | 133 | # logging tools 134 | args.img_path, args.model_path = set_path(args) 135 | args.logger = Logger(path=args.img_path) 136 | args.logger.log('args=\n\t\t'+'\n\t\t'.join(['%s:%s'%(str(k),str(v)) for k,v in vars(args).items()])) 137 | 138 | args.writer_val = SummaryWriter(logdir=os.path.join(args.img_path, 'val')) 139 | args.writer_train = SummaryWriter(logdir=os.path.join(args.img_path, 'train')) 140 | 141 | torch.backends.cudnn.benchmark = True 142 | 143 | ### main loop ### 144 | for epoch in range(args.start_epoch, args.epochs): 145 | np.random.seed(epoch) 146 | random.seed(epoch) 147 | 148 | train_loss, train_acc = train_one_epoch(train_loader, 149 | model, 150 | criterion, 151 | optimizer, 152 | lr_scheduler, 153 | device, 154 | epoch, 155 | args) 156 | val_loss, val_acc = validate(val_loader, 157 | model, 158 | criterion, 159 | device, 160 | epoch, 161 | args) 162 | 163 | # save check_point 164 | is_best = val_acc > best_acc 165 | best_acc = max(val_acc, best_acc) 166 | save_dict = {'epoch': epoch, 167 | 'state_dict': model_without_dp.state_dict(), 168 | 
'best_acc': best_acc, 169 | 'optimizer': optimizer.state_dict(), 170 | 'iteration': args.iteration} 171 | save_checkpoint(save_dict, is_best, 172 | filename=os.path.join(args.model_path, 'epoch%s.pth.tar' % str(epoch)), 173 | keep_all=False) 174 | 175 | print('Training from ep %d to ep %d finished' 176 | % (args.start_epoch, args.epochs)) 177 | sys.exit(0) 178 | 179 | 180 | def train_one_epoch(data_loader, model, criterion, optimizer, lr_scheduler, device, epoch, args): 181 | batch_time = AverageMeter() 182 | data_time = AverageMeter() 183 | losses = AverageMeter() 184 | accuracy = [[AverageMeter(), AverageMeter()], # forward top1, top5 185 | [AverageMeter(), AverageMeter()]] # backward top1, top5 186 | 187 | model.train() 188 | end = time.time() 189 | tic = time.time() 190 | 191 | for idx, input_seq in enumerate(data_loader): 192 | data_time.update(time.time() - end) 193 | 194 | input_seq = input_seq.to(device) 195 | B = input_seq.size(0) 196 | loss, loss_step, acc, extra = model(input_seq) 197 | 198 | for i in range(2): 199 | top1, top5 = acc[i].mean(0) # average acc across multi-gpus 200 | accuracy[i][0].update(top1.item(), B) 201 | accuracy[i][1].update(top5.item(), B) 202 | 203 | loss = loss.mean() # average loss across multi-gpus 204 | losses.update(loss.item(), B) 205 | 206 | optimizer.zero_grad() 207 | loss.backward() 208 | optimizer.step() 209 | 210 | batch_time.update(time.time() - end) 211 | end = time.time() 212 | 213 | if idx % args.print_freq == 0: 214 | print('Epoch: [{0}][{1}/{2}]\t' 215 | 'Loss {loss.val:.6f}\t' 216 | 'Acc: {acc[0][0].val:.4f}\t' 217 | 'T-data:{dt.val:.2f} T-batch:{bt.val:.2f}\t'.format( 218 | epoch, idx, len(data_loader), 219 | loss=losses, acc=accuracy, dt=data_time, bt=batch_time)) 220 | 221 | args.writer_train.add_scalar('local/loss', losses.val, args.iteration) 222 | args.writer_train.add_scalar('local/F-top1', accuracy[0][0].val, args.iteration) 223 | args.writer_train.add_scalar('local/F-top5', accuracy[0][1].val, args.iteration) 224 | args.writer_train.add_scalar('local/B-top1', accuracy[1][0].val, args.iteration) 225 | args.writer_train.add_scalar('local/B-top5', accuracy[1][1].val, args.iteration) 226 | 227 | args.iteration += 1 228 | if lr_scheduler is not None: lr_scheduler.step() 229 | 230 | print('Epoch: [{0}]\t' 231 | 'T-epoch:{t:.2f}\t'.format(epoch, t=time.time()-tic)) 232 | 233 | args.writer_train.add_scalar('global/loss', losses.avg, epoch) 234 | args.writer_train.add_scalar('global/F-top1', accuracy[0][0].avg, epoch) 235 | args.writer_train.add_scalar('global/F-top5', accuracy[0][1].avg, epoch) 236 | args.writer_train.add_scalar('global/B-top1', accuracy[1][0].avg, epoch) 237 | args.writer_train.add_scalar('global/B-top5', accuracy[1][1].avg, epoch) 238 | 239 | return losses.avg, np.mean([accuracy[0][0].avg, accuracy[1][0].avg]) 240 | 241 | 242 | def validate(data_loader, model, criterion, device, epoch, args): 243 | losses = AverageMeter() 244 | accuracy = [[AverageMeter(), AverageMeter()], # forward top1, top5 245 | [AverageMeter(), AverageMeter()]] # backward top1, top5 246 | 247 | model.eval() 248 | 249 | with torch.no_grad(): 250 | for idx, input_seq in enumerate(data_loader): 251 | input_seq = input_seq.to(device) 252 | B = input_seq.size(0) 253 | loss, loss_step, acc, extra = model(input_seq) 254 | 255 | for i in range(2): 256 | top1, top5 = acc[i].mean(0) # average acc across multi-gpus 257 | accuracy[i][0].update(top1.item(), B) 258 | accuracy[i][1].update(top5.item(), B) 259 | 260 | loss = loss.mean() # average loss across 
multi-gpus 261 | losses.update(loss.item(), B) 262 | 263 | print('Epoch: [{0}/{1}]\t' 264 | 'Loss {loss.val:.6f}\t' 265 | 'Acc: {acc[0][0].val:.4f}\t'.format( 266 | epoch, args.epochs, 267 | loss=losses, acc=accuracy)) 268 | 269 | args.writer_val.add_scalar('global/loss', losses.avg, epoch) 270 | args.writer_val.add_scalar('global/F-top1', accuracy[0][0].avg, epoch) 271 | args.writer_val.add_scalar('global/F-top5', accuracy[0][1].avg, epoch) 272 | args.writer_val.add_scalar('global/B-top1', accuracy[1][0].avg, epoch) 273 | args.writer_val.add_scalar('global/B-top5', accuracy[1][1].avg, epoch) 274 | 275 | return losses.avg, np.mean([accuracy[0][0].avg, accuracy[1][0].avg]) 276 | 277 | 278 | def get_data(transform, mode='train'): 279 | print('Loading {} dataset for {}'.format(args.dataset, mode)) 280 | if args.dataset == 'k400': 281 | dataset = K400Dataset(mode=mode, 282 | transform=transform, 283 | seq_len=args.seq_len, 284 | num_seq=args.num_seq, 285 | downsample=args.ds) 286 | elif args.dataset == 'ucf101': 287 | dataset = UCF101Dataset(mode=mode, 288 | transform=transform, 289 | seq_len=args.seq_len, 290 | num_seq=args.num_seq, 291 | downsample=args.ds) 292 | else: 293 | raise NotImplementedError('dataset not supported') 294 | 295 | sampler = data.RandomSampler(dataset) 296 | data_loader = data.DataLoader(dataset, 297 | batch_size=args.batch_size, 298 | sampler=sampler, 299 | shuffle=False, 300 | num_workers=args.workers, 301 | pin_memory=True, 302 | drop_last=True) 303 | print('"%s" dataset size: %d' % (mode, len(dataset))) 304 | return data_loader 305 | 306 | def set_path(args): 307 | if args.resume: 308 | exp_path = os.path.dirname(os.path.dirname(args.resume)) 309 | else: 310 | exp_path = 'log_{args.prefix}/{args.model}_{args.dataset}-{args.img_dim}_{args.net}_\ 311 | mem{args.mem_size}_bs{args.batch_size}_lr{args.lr}_seq{args.num_seq}_pred{args.pred_step}_\ 312 | len{args.seq_len}_ds{args.ds}'.format(args=args) 313 | 314 | img_path = os.path.join(exp_path, 'img') 315 | model_path = os.path.join(exp_path, 'model') 316 | if not os.path.exists(img_path): 317 | os.makedirs(img_path) 318 | if not os.path.exists(model_path): 319 | os.makedirs(model_path) 320 | 321 | return img_path, model_path 322 | 323 | if __name__ == '__main__': 324 | args = parse_args() 325 | main(args) 326 | -------------------------------------------------------------------------------- /memdpc/model.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import time 3 | import math 4 | import random 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | sys.path.append('../') 10 | from backbone.select_backbone import select_resnet 11 | from backbone.convrnn import ConvGRU 12 | from utils.utils import calc_topk_accuracy 13 | 14 | 15 | class MemDPC_BD(nn.Module): 16 | '''MemDPC with bi-directional RNN''' 17 | def __init__(self, 18 | sample_size, 19 | num_seq=8, 20 | seq_len=5, 21 | pred_step=3, 22 | network='resnet18', 23 | mem_size=1024): 24 | super(MemDPC_BD, self).__init__() 25 | print('Using MemDPC-BiDirectional model with {} and mem_size {}'\ 26 | .format(network, mem_size)) 27 | self.sample_size = sample_size 28 | self.num_seq = num_seq 29 | self.seq_len = seq_len 30 | self.pred_step = pred_step 31 | self.last_duration = int(math.ceil(seq_len / 4)) 32 | self.last_size = int(math.ceil(sample_size / 32)) 33 | self.mem_size = mem_size 34 | self.tgt_dict = {} 35 | print('final feature map has size %dx%d' % 
(self.last_size, self.last_size)) 36 | 37 | self.backbone, self.param = select_resnet(network) 38 | self.param['num_layers'] = 1 # param for GRU 39 | self.param['hidden_size'] = self.param['feature_size'] # param for GRU 40 | self.param['membanks_size'] = mem_size 41 | self.mb = torch.nn.Parameter(torch.randn(self.param['membanks_size'], self.param['feature_size'])) 42 | print('MEM Bank has size %dx%d' % (self.param['membanks_size'], self.param['feature_size'])) 43 | 44 | # bi-directional RNN 45 | self.agg_f = ConvGRU(input_size=self.param['feature_size'], 46 | hidden_size=self.param['hidden_size'], 47 | kernel_size=1, 48 | num_layers=self.param['num_layers']) 49 | self.agg_b = ConvGRU(input_size=self.param['feature_size'], 50 | hidden_size=self.param['hidden_size'], 51 | kernel_size=1, 52 | num_layers=self.param['num_layers']) 53 | 54 | self.network_pred = nn.Sequential( 55 | nn.Conv2d(self.param['feature_size'], self.param['feature_size'], kernel_size=1, padding=0), 56 | nn.ReLU(inplace=True), 57 | nn.Conv2d(self.param['feature_size'], self.param['membanks_size'], kernel_size=1, padding=0) 58 | ) 59 | self.mask = None 60 | self.relu = nn.ReLU(inplace=False) 61 | self.ce_loss = nn.CrossEntropyLoss(reduction='none') 62 | 63 | self._initialize_weights(self.agg_f) 64 | self._initialize_weights(self.agg_b) 65 | self._initialize_weights(self.network_pred) 66 | 67 | 68 | def get_loss(self, pred, gt, B, SL, last_size, feature_size, kernel=1): 69 | # pred: B,C,N,H,W 70 | # GT: C,B,N,H*H 71 | score = torch.matmul(pred.permute(0,2,3,4,1).contiguous().view(B*SL*last_size**2,feature_size), 72 | gt.contiguous().view(feature_size, B*SL*last_size**2)) 73 | if SL not in self.tgt_dict: 74 | self.tgt_dict[SL] = torch.arange(B*SL*last_size**2) 75 | tgt = self.tgt_dict[SL].to(score.device) 76 | loss = self.ce_loss(score, tgt) 77 | top1, top5 = calc_topk_accuracy(score, tgt, (1,5)) 78 | return loss, top1, top5 79 | 80 | 81 | def forward(self, block): 82 | # extract feature 83 | (B, N, C, SL, H, W) = block.shape 84 | block = block.view(B*N, C, SL, H, W) 85 | feat3d = self.backbone(block) 86 | 87 | feat3d = F.avg_pool3d(feat3d, (self.last_duration, 1, 1), stride=(1, 1, 1)) 88 | feat3d = feat3d.view(B, N, self.param['feature_size'], self.last_size, self.last_size) # before ReLU, (-inf, +inf) 89 | 90 | losses = [] # all loss 91 | acc = [] # all acc 92 | loss = 0 93 | gt = feat3d.permute(2,0,1,3,4).contiguous().view(self.param['feature_size'], B, N, self.last_size**2) 94 | 95 | feat3d_b = torch.flip(feat3d, dims=(1,)) 96 | gt_b = torch.flip(gt, dims=(2,)) 97 | 98 | # forward MemDPC 99 | pd_tmp_pool = [] 100 | for j in range(self.pred_step): 101 | if j == 0: 102 | feat_tmp = feat3d[:,0:(N-self.pred_step),:,:,:] 103 | _, hidden = self.agg_f(F.relu(feat_tmp)) 104 | context_feature = hidden.clone() 105 | else: 106 | _, hidden = self.agg_f(F.relu(pd_tmp).unsqueeze(1), hidden.unsqueeze(0)) 107 | hidden = hidden[:,-1,:] # after tanh, (-1,1). 
get the hidden state of last layer, last time step 108 | pd_tmp = self.network_pred(hidden) 109 | pd_tmp = F.softmax(pd_tmp, dim=1) # B,MEM,H,W 110 | pd_tmp = torch.einsum('bmhw,mc->bchw', pd_tmp, self.mb) 111 | pd_tmp_pool.append(pd_tmp) 112 | 113 | pd_tmp_pool = torch.stack(pd_tmp_pool, dim=2); SL_tmp = pd_tmp_pool.size(2) 114 | gt_tmp = gt[:,:,-self.pred_step::,:] 115 | loss_tmp, top1, top5 = self.get_loss(pd_tmp_pool, gt_tmp, B, SL_tmp, self.last_size, self.param['feature_size']) 116 | loss_tmp = loss_tmp.mean() 117 | loss = loss_tmp 118 | losses.append(loss_tmp.data.unsqueeze(0)) 119 | acc.append(torch.stack([top1, top5], 0).unsqueeze(0)) 120 | 121 | 122 | # backward MemDPC 123 | pd_tmp_pool_b = [] 124 | for j in range(self.pred_step): 125 | if j == 0: 126 | feat_tmp = feat3d_b[:,0:(N-self.pred_step),:,:,:] 127 | _, hidden = self.agg_b(F.relu(feat_tmp)) 128 | else: 129 | _, hidden = self.agg_b(F.relu(pd_tmp_b).unsqueeze(1), hidden.unsqueeze(0)) 130 | hidden = hidden[:,-1,:] # after tanh, (-1,1). get the hidden state of last layer, last time step 131 | pd_tmp_b = self.network_pred(hidden) 132 | pd_tmp_b = F.softmax(pd_tmp_b, dim=1) # B,MEM,H,W 133 | pd_tmp_b = torch.einsum('bmhw,mc->bchw', pd_tmp_b, self.mb) 134 | pd_tmp_pool_b.append(pd_tmp_b) 135 | 136 | pd_tmp_pool_b = torch.stack(pd_tmp_pool_b, dim=2); SL_tmp = pd_tmp_pool_b.size(2) 137 | gt_tmp_b = gt_b[:,:,-self.pred_step::,:] 138 | loss_tmp_b, top1, top5 = self.get_loss(pd_tmp_pool_b, gt_tmp_b, B, SL_tmp, self.last_size, self.param['feature_size']) 139 | loss_tmp_b = loss_tmp_b.mean() 140 | losses.append(loss_tmp_b.data.unsqueeze(0)) 141 | acc.append(torch.stack([top1, top5], 0).unsqueeze(0)) 142 | 143 | loss = loss + loss_tmp_b 144 | 145 | return loss, losses, acc, context_feature 146 | 147 | 148 | def _initialize_weights(self, module): 149 | for name, param in module.named_parameters(): 150 | if 'bias' in name: 151 | nn.init.constant_(param, 0.0) 152 | elif 'weight' in name: 153 | nn.init.orthogonal_(param, 0.1) 154 | 155 | -------------------------------------------------------------------------------- /process_data/readme.md: -------------------------------------------------------------------------------- 1 | ## Process data 2 | 3 | This folder has some tools to process UCF101, HMDB51 and Kinetics400 datasets. 4 | 5 | ### 1. Download 6 | 7 | Download the videos from source: 8 | [UCF101 source](https://www.crcv.ucf.edu/data/UCF101.php), 9 | [HMDB51 source](http://serre-lab.clps.brown.edu/resource/hmdb-a-large-human-motion-database/#Downloads), 10 | [Kinetics400 source](https://deepmind.com/research/publications/kinetics-human-action-video-dataset). 
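Before moving on to the data-preparation steps below, the memory-addressing step at the heart of `MemDPC_BD.forward` above can be illustrated in isolation: `network_pred` produces one score per memory slot at every spatial position, a softmax turns those scores into addressing weights, and the predicted future feature is a convex combination of the rows of the learnable memory bank `self.mb`. The sketch below uses made-up tensor sizes and a stand-alone `Conv2d` standing in for `network_pred`; it only demonstrates the mechanism, not the exact training configuration.

```
import torch
import torch.nn as nn
import torch.nn.functional as F

B, C, H, W = 2, 256, 4, 4        # illustrative batch / channel / spatial sizes
M = 1024                         # number of memory slots (mem_size)

mem_bank = nn.Parameter(torch.randn(M, C))   # plays the role of self.mb: [M, C]
hidden = torch.randn(B, C, H, W)             # ConvGRU hidden state at the last time step

# stand-in for network_pred: one score per memory slot at each spatial position
to_address = nn.Conv2d(C, M, kernel_size=1)
address = F.softmax(to_address(hidden), dim=1)            # [B, M, H, W], sums to 1 over M

# predicted future feature = addressing weights combined with the memory bank
pred = torch.einsum('bmhw,mc->bchw', address, mem_bank)   # [B, C, H, W]
print(pred.shape)                                         # torch.Size([2, 256, 4, 4])
```

In the repository this is exactly what the `F.softmax(..., dim=1)` and `torch.einsum('bmhw,mc->bchw', ...)` calls in `forward` implement, so every prediction is constrained to a weighted combination of the learned memory slots rather than an unconstrained regression target.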
11 | 12 | Make sure datasets are stored as follows: 13 | 14 | * UCF101 15 | ``` 16 | {your_path}/UCF101/videos/{action class}/{video name}.avi 17 | {your_path}/UCF101/splits_classification/trainlist{01/02/03}.txt 18 | {your_path}/UCF101/splits_classification/testlist{01/02/03}}.txt 19 | ``` 20 | 21 | * HMDB51 22 | ``` 23 | {your_path}/HMDB51/videos/{action class}/{video name}.avi 24 | {your_path}/HMDB51/split/testTrainMulti_7030_splits/{action class}_test_split{1/2/3}.txt 25 | ``` 26 | 27 | * Kinetics400 28 | ``` 29 | {your_path}/Kinetics400/videos/train_split/{action class}/{video name}.mp4 30 | {your_path}/Kinetics400/videos/val_split/{action class}/{video name}.mp4 31 | ``` 32 | Also keep the downloaded csv files, make sure you have: 33 | ``` 34 | {your_path}/Kinetics/kinetics_train/kinetics_train.csv 35 | {your_path}/Kinetics/kinetics_val/kinetics_val.csv 36 | {your_path}/Kinetics/kinetics_test/kinetics_test.csv 37 | ``` 38 | 39 | ### 2. Extract frame and flow 40 | 41 | [Optional for Kinetics400] Reduce the video dimension to short size = 256 pixel with `ffmpeg`, to save space and time. An example is provided in [src/resize_video.py](src/resize_video.py). 42 | 43 | Edit path arguments in `main_*()` functions, and `python extract_ff.py`. Video frame and TV-L1 optical flow will be extracted. 44 | 45 | ### 3. Collect paths into csv 46 | 47 | Edit path arguments in `main_*()` functions, and `python write_csv.py`. csv files will be stored in `data/` directory. 48 | 49 | then prepare `data/ClassInd.txt` that stores sorted action names for corresponding dataset, like 50 | ``` 51 | ApplyEyeMakeup 52 | ApplyLipstick 53 | Archery 54 | ... 55 | ``` 56 | is for UCF101. 57 | 58 | 59 | 60 | -------------------------------------------------------------------------------- /process_data/src/extract_ff.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import sys 4 | import cv2 5 | import numpy as np 6 | from joblib import delayed, Parallel 7 | from tqdm import tqdm 8 | import platform 9 | import argparse 10 | import glob 11 | 12 | 13 | def compute_TVL1(prev, curr, bound=20): 14 | """Compute the TV-L1 optical flow.""" 15 | TVL1 = cv2.DualTVL1OpticalFlow_create() 16 | flow = TVL1.calc(prev, curr, None) 17 | flow = np.clip(flow, -bound, bound) 18 | 19 | flow = (flow + bound) * (255.0 / (2*bound)) 20 | flow = np.round(flow).astype('uint8') 21 | return flow 22 | 23 | 24 | def extract_ff_opencv(v_path, frame_root, flow_root): 25 | '''opencv version: 26 | v_path: single video path xxx/action/vname.mp4 27 | frame_root: root to store flow 28 | flow_root: root to store flow ''' 29 | v_class = v_path.split('/')[-2] 30 | v_name = os.path.basename(v_path)[0:-4] 31 | frame_out_dir = os.path.join(frame_root, v_class, v_name) 32 | flow_out_dir = os.path.join(flow_root, v_class, v_name) 33 | for i in [frame_out_dir, flow_out_dir]: 34 | if not os.path.exists(i): 35 | os.makedirs(i) 36 | else: 37 | print('[WARNING]', i, 'exists, continue...') 38 | 39 | 40 | vidcap = cv2.VideoCapture(v_path) 41 | nb_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT)) 42 | 43 | if len(glob.glob(os.path.join(frame_out_dir, '*.jpg'))) >= nb_frames - 3: # tolerance = 3 frame difference 44 | print('[WARNING]', frame_out_dir, 'has finished, dropped!') 45 | vidcap.release() 46 | return 47 | 48 | width = vidcap.get(cv2.CAP_PROP_FRAME_WIDTH) # float 49 | height = vidcap.get(cv2.CAP_PROP_FRAME_HEIGHT) # float 50 | if (width == 0) or (height==0): 51 | print(width, height, v_path) 52 
| 53 | empty_img = 128 * np.ones((int(height),int(width),3)).astype(np.uint8) 54 | success, image = vidcap.read() 55 | count = 1 56 | 57 | pbar = tqdm(total=nb_frames) 58 | while success: 59 | cv2.imwrite(os.path.join(frame_out_dir, 'image_%05d.jpg' % count), 60 | image, 61 | [cv2.IMWRITE_JPEG_QUALITY, 100]) # quality from 0-100, 95 is default, high is good 62 | image_gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) 63 | if count != 1: 64 | flow = compute_TVL1(prev_gray, image_gray) 65 | flow_img = empty_img.copy() 66 | flow_img[:,:,0:2] = flow 67 | cv2.imwrite(os.path.join(flow_out_dir, 'flow_%05d.jpg' % (count-1)), 68 | flow_img, 69 | [cv2.IMWRITE_JPEG_QUALITY, 100]) 70 | 71 | prev_gray = image_gray 72 | success, image = vidcap.read() 73 | count += 1 74 | pbar.update(1) 75 | 76 | if nb_frames > count: 77 | print(frame_out_dir, 'is NOT extracted successfully', nb_frames, count) 78 | vidcap.release() 79 | 80 | return 81 | 82 | 83 | 84 | def main_UCF101(v_root, frame_root, flow_root): 85 | print('extracting UCF101 ... ') 86 | print('extracting videos from %s' % v_root) 87 | print('frame save to %s' % frame_root) 88 | print('flow save to %s' % flow_root) 89 | 90 | if not os.path.exists(frame_root): os.makedirs(frame_root) 91 | if not os.path.exists(flow_root): os.makedirs(flow_root) 92 | 93 | v_act_root = glob.glob(os.path.join(v_root, '*/')) 94 | for i, j in tqdm(enumerate(v_act_root), total=len(v_act_root)): 95 | v_paths = glob.glob(os.path.join(j, '*.avi')) 96 | v_paths = sorted(v_paths) 97 | Parallel(n_jobs=32)(delayed(extract_ff_opencv)\ 98 | (p, frame_root, flow_root) for p in tqdm(v_paths, total=len(v_paths))) 99 | 100 | 101 | def main_HMDB51(v_root, frame_root, flow_root): 102 | print('extracting HMDB51 ... ') 103 | print('extracting videos from %s' % v_root) 104 | print('frame save to %s' % frame_root) 105 | print('flow save to %s' % flow_root) 106 | 107 | if not os.path.exists(frame_root): os.makedirs(frame_root) 108 | if not os.path.exists(flow_root): os.makedirs(flow_root) 109 | 110 | v_act_root = glob.glob(os.path.join(v_root, '*/')) 111 | for i, j in tqdm(enumerate(v_act_root), total=len(v_act_root)): 112 | v_paths = glob.glob(os.path.join(j, '*.avi')) 113 | v_paths = sorted(v_paths) 114 | Parallel(n_jobs=32)(delayed(extract_ff_opencv)\ 115 | (p, frame_root, flow_root) for p in tqdm(v_paths, total=len(v_paths))) 116 | 117 | 118 | def main_kinetics400(v_root, frame_root, flow_root): 119 | print('extracting Kinetics400 ... 
') 120 | for basename in ['train_split', 'val_split']: 121 | v_root_real = v_root + '/' + basename 122 | if not os.path.exists(v_root_real): 123 | print('Wrong v_root'); sys.exit() 124 | 125 | frame_root_real = os.path.join(frame_root, basename) 126 | flow_root_real = os.path.join(flow_root, basename) 127 | print('frame save to %s' % frame_root_real) 128 | print('flow save to %s' % flow_root_real) 129 | 130 | if not os.path.exists(frame_root_real): os.makedirs(frame_root_real) 131 | if not os.path.exists(flow_root_real): os.makedirs(flow_root_real) 132 | 133 | v_act_root = glob.glob(os.path.join(v_root_real, '*/')) 134 | v_act_root = sorted(v_act_root) 135 | 136 | # if resume, remember to delete the last video folder 137 | for i, j in tqdm(enumerate(v_act_root), total=len(v_act_root)): 138 | v_paths = glob.glob(os.path.join(j, '*.mp4')) 139 | v_paths = sorted(v_paths) 140 | 141 | # for resume: 142 | v_class = j.split('/')[-2] 143 | out_dir = os.path.join(frame_root_real, v_class) 144 | if os.path.exists(out_dir): 145 | print(out_dir, 'exists!') 146 | continue 147 | 148 | print('extracting: %s' % v_class) 149 | Parallel(n_jobs=32)(delayed(extract_ff_opencv)\ 150 | (p, frame_root_real, flow_root_real) for p in tqdm(v_paths, total=len(v_paths))) 151 | 152 | 153 | if __name__ == '__main__': 154 | # edit 'your_path' here: 155 | main_UCF101(v_root='your_path/UCF101/videos', 156 | frame_root='your_path/UCF101/frame', 157 | flow_root='your_path/UCF101/flow') 158 | 159 | main_HMDB51(v_root='your_path/HMDB51/videos', 160 | frame_root='your_path/HMDB51/frame', 161 | flow_root='your_path/HMDB51/flow') 162 | 163 | main_kinetics400(v_root='your_path/Kinetics400/videos', 164 | frame_root='your_path/Kinetics400/frame', 165 | flow_root='your_path/Kinetics400/flow') -------------------------------------------------------------------------------- /process_data/src/resize_video.py: -------------------------------------------------------------------------------- 1 | # by htd@robots.ox.ac.uk 2 | from joblib import delayed, Parallel 3 | import os 4 | import sys 5 | import glob 6 | import subprocess 7 | from tqdm import tqdm 8 | import cv2 9 | import numpy as np 10 | import matplotlib.pyplot as plt 11 | plt.switch_backend('agg') 12 | 13 | def resize_video_ffmpeg(v_path, out_path, dim=256): 14 | '''v_path: single video path; 15 | out_path: root to store output videos''' 16 | v_class = v_path.split('/')[-2] 17 | v_name = os.path.basename(v_path)[0:-4] 18 | out_dir = os.path.join(out_path, v_class) 19 | if not os.path.exists(out_dir): 20 | raise ValueError("directory not exist, it shouldn't happen") 21 | 22 | vidcap = cv2.VideoCapture(v_path) 23 | width = vidcap.get(cv2.CAP_PROP_FRAME_WIDTH) # float 24 | height = vidcap.get(cv2.CAP_PROP_FRAME_HEIGHT) # float 25 | if (width == 0) or (height==0): 26 | print(v_path, 'not successfully loaded, drop ..'); return 27 | new_dim = resize_dim(width, height, dim) 28 | if new_dim[0] == dim: 29 | dim_cmd = '%d:-2' % dim 30 | elif new_dim[1] == dim: 31 | dim_cmd = '-2:%d' % dim 32 | 33 | cmd = ['ffmpeg', '-loglevel', 'quiet', '-y', 34 | '-i', '%s'%v_path, 35 | '-vf', 36 | 'scale=%s'%dim_cmd, 37 | '%s' % os.path.join(out_dir, os.path.basename(v_path))] 38 | ffmpeg = subprocess.call(cmd) 39 | 40 | def resize_dim(w, h, target): 41 | '''resize (w, h), such that the smaller side is target, keep the aspect ratio''' 42 | if w >= h: 43 | return [int(target * w / h), int(target)] 44 | else: 45 | return [int(target), int(target * h / w)] 46 | 47 | def 
main_kinetics400(output_path='your_path/kinetics400'): 48 | print('save to %s ... ' % output_path) 49 | for splitname in ['val_split', 'train_split']: 50 | v_root = '/datasets/KineticsVideo' + '/' + splitname 51 | if not os.path.exists(v_root): 52 | print('Wrong v_root') 53 | import ipdb; ipdb.set_trace() # for debug 54 | out_path = os.path.join(output_path, splitname) 55 | if not os.path.exists(out_path): 56 | os.makedirs(out_path) 57 | v_act_root = glob.glob(os.path.join(v_root, '*/')) 58 | v_act_root = sorted(v_act_root) 59 | 60 | # if resume, remember to delete the last video folder 61 | for i, j in tqdm(enumerate(v_act_root), total=len(v_act_root)): 62 | v_paths = glob.glob(os.path.join(j, '*.mp4')) 63 | v_paths = sorted(v_paths) 64 | v_class = j.split('/')[-2] 65 | out_dir = os.path.join(out_path, v_class) 66 | if os.path.exists(out_dir): 67 | print(out_dir, 'exists!'); continue 68 | else: 69 | os.makedirs(out_dir) 70 | 71 | print('extracting: %s' % v_class) 72 | Parallel(n_jobs=8)(delayed(resize_video_ffmpeg)(p, out_path, dim=256) for p in tqdm(v_paths, total=len(v_paths))) 73 | 74 | 75 | 76 | if __name__ == '__main__': 77 | main_kinetics400(output_path='your_path/kinetics400') 78 | # users need to change output_path and v_root 79 | -------------------------------------------------------------------------------- /process_data/src/write_csv.py: -------------------------------------------------------------------------------- 1 | import os 2 | import csv 3 | import glob 4 | 5 | def write_list(data_list, path, ): 6 | with open(path, 'w') as f: 7 | writer = csv.writer(f, delimiter=',') 8 | for row in data_list: 9 | if row: writer.writerow(row) 10 | print('split saved to %s' % path) 11 | 12 | def main_UCF101(f_root, splits_root, csv_root='../data/ucf101/'): 13 | '''generate training/testing split, count number of available frames, save in csv''' 14 | if not os.path.exists(csv_root): os.makedirs(csv_root) 15 | for which_split in [1,2,3]: 16 | train_set = [] 17 | test_set = [] 18 | train_split_file = os.path.join(splits_root, 'trainlist%02d.txt' % which_split) 19 | with open(train_split_file, 'r') as f: 20 | for line in f: 21 | vpath = os.path.join(f_root, line.split(' ')[0][0:-4]) + '/' 22 | train_set.append([vpath, len(glob.glob(os.path.join(vpath, '*.jpg')))]) 23 | 24 | test_split_file = os.path.join(splits_root, 'testlist%02d.txt' % which_split) 25 | with open(test_split_file, 'r') as f: 26 | for line in f: 27 | vpath = os.path.join(f_root, line.rstrip()[0:-4]) + '/' 28 | test_set.append([vpath, len(glob.glob(os.path.join(vpath, '*.jpg')))]) 29 | 30 | write_list(train_set, os.path.join(csv_root, 'train_split%02d.csv' % which_split)) 31 | write_list(test_set, os.path.join(csv_root, 'test_split%02d.csv' % which_split)) 32 | 33 | 34 | def main_HMDB51(f_root, splits_root, csv_root='../data/hmdb51/'): 35 | '''generate training/testing split, count number of available frames, save in csv''' 36 | if not os.path.exists(csv_root): os.makedirs(csv_root) 37 | for which_split in [1,2,3]: 38 | train_set = [] 39 | test_set = [] 40 | split_files = sorted(glob.glob(os.path.join(splits_root, '*_test_split%d.txt' % which_split))) 41 | assert len(split_files) == 51 42 | for split_file in split_files: 43 | action_name = os.path.basename(split_file)[0:-16] 44 | with open(split_file, 'r') as f: 45 | for line in f: 46 | video_name = line.split(' ')[0] 47 | _type = line.split(' ')[1] 48 | vpath = os.path.join(f_root, action_name, video_name[0:-4]) + '/' 49 | if _type == '1': 50 | train_set.append([vpath, 
len(glob.glob(os.path.join(vpath, '*.jpg')))]) 51 | elif _type == '2': 52 | test_set.append([vpath, len(glob.glob(os.path.join(vpath, '*.jpg')))]) 53 | 54 | write_list(train_set, os.path.join(csv_root, 'train_split%02d.csv' % which_split)) 55 | write_list(test_set, os.path.join(csv_root, 'test_split%02d.csv' % which_split)) 56 | 57 | ### For Kinetics ### 58 | def get_split(root, split_path, mode): 59 | print('processing %s split ...' % mode) 60 | print('checking %s' % root) 61 | split_list = [] 62 | split_content = pd.read_csv(split_path).iloc[:,0:4] 63 | split_list = Parallel(n_jobs=64)\ 64 | (delayed(check_exists)(row, root) \ 65 | for i, row in tqdm(split_content.iterrows(), total=len(split_content))) 66 | return split_list 67 | 68 | def check_exists(row, root): 69 | dirname = '_'.join([row['youtube_id'], '%06d' % row['time_start'], '%06d' % row['time_end']]) 70 | full_dirname = os.path.join(root, row['label'], dirname) 71 | if os.path.exists(full_dirname): 72 | n_frames = len(glob.glob(os.path.join(full_dirname, '*.jpg'))) 73 | return [full_dirname, n_frames] 74 | else: 75 | return None 76 | 77 | def main_Kinetics400(mode, k400_path, f_root, csv_root='../data/kinetics400'): 78 | train_split_path = os.path.join(k400_path, 'kinetics_train/kinetics_train.csv') 79 | val_split_path = os.path.join(k400_path, 'kinetics_val/kinetics_val.csv') 80 | test_split_path = os.path.join(k400_path, 'kinetics_test/kinetics_test.csv') 81 | if not os.path.exists(csv_root): os.makedirs(csv_root) 82 | if mode == 'train': 83 | train_split = get_split(os.path.join(f_root, 'train_split'), train_split_path, 'train') 84 | write_list(train_split, os.path.join(csv_root, 'train_split.csv')) 85 | elif mode == 'val': 86 | val_split = get_split(os.path.join(f_root, 'val_split'), val_split_path, 'val') 87 | write_list(val_split, os.path.join(csv_root, 'val_split.csv')) 88 | elif mode == 'test': 89 | test_split = get_split(f_root, test_split_path, 'test') 90 | write_list(test_split, os.path.join(csv_root, 'test_split.csv')) 91 | else: 92 | raise IOError('wrong mode') 93 | 94 | if __name__ == '__main__': 95 | # f_root is the frame path 96 | # edit 'your_path' here: 97 | 98 | main_UCF101(f_root='your_path/UCF101/frame', 99 | splits_root='your_path/UCF101/splits_classification') 100 | 101 | # main_HMDB51(f_root='your_path/HMDB51/frame', 102 | # splits_root='your_path/HMDB51/split/testTrainMulti_7030_splits') 103 | 104 | # main_Kinetics400(mode='train', # train or val or test 105 | # k400_path='your_path/Kinetics', 106 | # f_root='your_path/Kinetics400/frame') 107 | 108 | # main_Kinetics400(mode='train', # train or val or test 109 | # k400_path='your_path/Kinetics', 110 | # f_root='your_path/Kinetics400_256/frame', 111 | # csv_root='../data/kinetics400_256') -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | ## Memory-augmented Dense Predictive Coding for Video Representation Learning 2 | 3 | This repository contains the implementation of Memory-augmented Dense Predictive Coding (MemDPC). 
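A quick note on the csv format produced above: each row written by `write_csv.py` is simply `[frame_directory, number_of_frames]`, and the dataset classes in `memdpc/dataset.py` read the same files back with `pandas` (`header=None`), dropping videos that are too short. The sketch below shows that round trip with hypothetical paths; it writes with `pandas` only for brevity, whereas the repository uses `csv.writer`.

```
import pandas as pd

# one row per video: [path to extracted frames, number of jpg frames] -- hypothetical paths
rows = [['your_path/UCF101/frame/Archery/v_Archery_g01_c01/', 175],
        ['your_path/UCF101/frame/Archery/v_Archery_g01_c02/', 201]]
pd.DataFrame(rows).to_csv('train_split01.csv', header=False, index=False)

# the dataset classes read it back with header=None, so each row unpacks to (vpath, vlen)
video_info = pd.read_csv('train_split01.csv', header=None)
for idx, row in video_info.iterrows():
    vpath, vlen = row
    print(vpath, vlen)
```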
4 | 5 | Links: [[arXiv](https://arxiv.org/abs/2008.01065)] [[PDF](http://www.robots.ox.ac.uk/~vgg/publications/2020/Han20/han20.pdf)] [[Video](https://www.youtube.com/watch?v=XlR7QoM053k)] [[Project page](http://www.robots.ox.ac.uk/~vgg/research/DPC/)] 6 | 7 | ![arch](asset/arch.png) 8 | 9 | ### News 10 | 11 | * 2020/09/08: uploaded [evaluation code](https://github.com/TengdaHan/MemDPC#evaluation) for action classification and [pretrained weights on Kinetics400](https://github.com/TengdaHan/MemDPC#memdpc-pretrained-weights). 12 | 13 | * 2020/08/26: corrected the DynamoNet statistics in the figure: DynamoNet uses 500K videos from YouTube8M but only a 10-second clip from each, so the total video length is about 58 days. 14 | 15 | ### Preparation 16 | 17 | This repository is implemented in PyTorch 1.2, but newer versions should also work. 18 | Additionally, it needs cv2, joblib, tqdm and tensorboardX. 19 | 20 | For the dataset, please follow the instructions [here](process_data/). 21 | 22 | 23 | 24 | ### Self-supervised training (MemDPC) 25 | 26 | * Change directory: `cd memdpc/` 27 | 28 | * Train MemDPC on the UCF101 RGB stream 29 | ``` 30 | python main.py --gpu 0,1 --net resnet18 --dataset ucf101 --batch_size 16 --img_dim 128 --epochs 500 31 | ``` 32 | 33 | * Train MemDPC on the Kinetics400 RGB stream 34 | ``` 35 | python main.py --gpu 0,1,2,3 --net resnet34 --dataset k400 --batch_size 16 --img_dim 224 --epochs 200 36 | ``` 37 | 38 | ### Evaluation 39 | 40 | Finetune the entire network for action classification on UCF101: 41 | ![arch](asset/finetune.png) 42 | 43 | * Change directory: `cd eval/` 44 | 45 | * Train an action classifier by finetuning the pretrained weights 46 | ``` 47 | python test.py --gpu 0,1 --net resnet34 --dataset ucf101 --batch_size 16 \ 48 | --img_dim 224 --epochs 500 --train_what ft --schedule 300 400 49 | ``` 50 | 51 | * Train an action classifier by freezing the pretrained weights and training only a linear layer 52 | ``` 53 | python test.py --gpu 0,1 --net resnet34 --dataset ucf101 --batch_size 16 \ 54 | --img_dim 224 --epochs 100 --train_what last --schedule 60 80 --dropout 0.5 55 | ``` 56 | 57 | ### MemDPC pretrained weights 58 | 59 | * [MemDPC-ResNet34-K400-RGB-224](http://www.robots.ox.ac.uk/~htd/memdpc/k400-rgb-224_resnet34_memdpc.pth.tar) 60 | 61 | * [MemDPC-ResNet34-K400-Flow-224](http://www.robots.ox.ac.uk/~htd/memdpc/k400-flow-224_resnet34_memdpc.pth.tar) 62 | 63 | ### Citation 64 | 65 | If you find the repo useful for your research, please consider citing our paper: 66 | ``` 67 | @InProceedings{Han20, 68 | author = "Tengda Han and Weidi Xie and Andrew Zisserman", 69 | title = "Memory-augmented Dense Predictive Coding for Video Representation Learning", 70 | booktitle = "European Conference on Computer Vision", 71 | year = "2020", 72 | } 73 | ``` 74 | For any questions, please create an issue or contact Tengda Han ([htd@robots.ox.ac.uk](mailto:htd@robots.ox.ac.uk)). 
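To connect the pretrained weights above with the rest of the code: checkpoints written by `memdpc/main.py` are dictionaries whose weights live under the `'state_dict'` key, and `utils/utils.py` provides `neq_load_customized` for loading such a dictionary into a (possibly partially modified) model. Below is a minimal sketch of restoring a checkpoint into the pretraining model; it assumes the released `.pth.tar` files follow the same format as checkpoints saved by `main.py`, and that it is run from the `memdpc/` directory (hence the `sys.path.append('../')`, as in `main.py`).

```
import sys
sys.path.append('../')           # run from memdpc/, same as main.py
import torch
from model import MemDPC_BD
from utils.utils import neq_load_customized

ckpt_path = 'k400-rgb-224_resnet34_memdpc.pth.tar'   # one of the released checkpoints

model = MemDPC_BD(sample_size=224, num_seq=8, seq_len=5,
                  network='resnet34', pred_step=3, mem_size=1024)

checkpoint = torch.load(ckpt_path, map_location='cpu')
# main.py saves {'epoch', 'state_dict', 'best_acc', 'optimizer', 'iteration'}
model = neq_load_customized(model, checkpoint['state_dict'])
print('loaded checkpoint from epoch', checkpoint.get('epoch'))
```

The same dictionary format is what `main.py` reads back through its `--resume` option.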
75 | 76 | 77 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TengdaHan/MemDPC/d7dbbf0dc6ec4aa8ff9a5dc8c189d78f4e5e34a7/utils/__init__.py -------------------------------------------------------------------------------- /utils/augmentation.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numbers 3 | import math 4 | import collections 5 | import numpy as np 6 | from PIL import ImageOps, Image 7 | from joblib import Parallel, delayed 8 | 9 | import torchvision 10 | from torchvision import transforms 11 | import torchvision.transforms.functional as F 12 | 13 | class Padding: 14 | def __init__(self, pad): 15 | self.pad = pad 16 | 17 | def __call__(self, img): 18 | return ImageOps.expand(img, border=self.pad, fill=0) 19 | 20 | class Scale: 21 | def __init__(self, size, interpolation=Image.NEAREST): 22 | assert isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2) 23 | self.size = size 24 | self.interpolation = interpolation 25 | 26 | def __call__(self, imgmap): 27 | # assert len(imgmap) > 1 # list of images 28 | img1 = imgmap[0] 29 | if isinstance(self.size, int): 30 | w, h = img1.size 31 | if (w <= h and w == self.size) or (h <= w and h == self.size): 32 | return imgmap 33 | if w < h: 34 | ow = self.size 35 | oh = int(self.size * h / w) 36 | return [i.resize((ow, oh), self.interpolation) for i in imgmap] 37 | else: 38 | oh = self.size 39 | ow = int(self.size * w / h) 40 | return [i.resize((ow, oh), self.interpolation) for i in imgmap] 41 | else: 42 | return [i.resize(self.size, self.interpolation) for i in imgmap] 43 | 44 | 45 | class CenterCrop: 46 | def __init__(self, size, consistent=True): 47 | if isinstance(size, numbers.Number): 48 | self.size = (int(size), int(size)) 49 | else: 50 | self.size = size 51 | 52 | def __call__(self, imgmap): 53 | img1 = imgmap[0] 54 | w, h = img1.size 55 | th, tw = self.size 56 | x1 = int(round((w - tw) / 2.)) 57 | y1 = int(round((h - th) / 2.)) 58 | return [i.crop((x1, y1, x1 + tw, y1 + th)) for i in imgmap] 59 | 60 | class FiveCrop: 61 | def __init__(self, size, where=1): 62 | # 1=topleft, 2=topright, 3=botleft, 4=botright, 5=center 63 | if isinstance(size, numbers.Number): 64 | self.size = (int(size), int(size)) 65 | else: 66 | self.size = size 67 | self.where = where 68 | 69 | def __call__(self, imgmap): 70 | img1 = imgmap[0] 71 | w, h = img1.size 72 | th, tw = self.size 73 | if (th > h) or (tw > w): 74 | raise ValueError("Requested crop size {} is bigger than input size {}".format(self.size, (h,w))) 75 | if self.where == 1: 76 | return [i.crop((0, 0, tw, th)) for i in imgmap] 77 | elif self.where == 2: 78 | return [i.crop((w-tw, 0, w, th)) for i in imgmap] 79 | elif self.where == 3: 80 | return [i.crop((0, h-th, tw, h)) for i in imgmap] 81 | elif self.where == 4: 82 | return [i.crop((w-tw, h-tw, w, h)) for i in imgmap] 83 | elif self.where == 5: 84 | x1 = int(round((w - tw) / 2.)) 85 | y1 = int(round((h - th) / 2.)) 86 | return [i.crop((x1, y1, x1 + tw, y1 + th)) for i in imgmap] 87 | 88 | 89 | class RandomCropWithProb: 90 | def __init__(self, size, p=0.8, consistent=True): 91 | if isinstance(size, numbers.Number): 92 | self.size = (int(size), int(size)) 93 | else: 94 | self.size = size 95 | self.consistent = consistent 96 | self.threshold = p 97 | 98 | def __call__(self, imgmap): 99 | img1 = 
imgmap[0] 100 | w, h = img1.size 101 | if self.size is not None: 102 | th, tw = self.size 103 | if w == tw and h == th: 104 | return imgmap 105 | if self.consistent: 106 | if random.random() < self.threshold: 107 | x1 = random.randint(0, w - tw) 108 | y1 = random.randint(0, h - th) 109 | else: 110 | x1 = int(round((w - tw) / 2.)) 111 | y1 = int(round((h - th) / 2.)) 112 | return [i.crop((x1, y1, x1 + tw, y1 + th)) for i in imgmap] 113 | else: 114 | result = [] 115 | for i in imgmap: 116 | if random.random() < self.threshold: 117 | x1 = random.randint(0, w - tw) 118 | y1 = random.randint(0, h - th) 119 | else: 120 | x1 = int(round((w - tw) / 2.)) 121 | y1 = int(round((h - th) / 2.)) 122 | result.append(i.crop((x1, y1, x1 + tw, y1 + th))) 123 | return result 124 | else: 125 | return imgmap 126 | 127 | class RandomCrop: 128 | def __init__(self, size, consistent=True): 129 | if isinstance(size, numbers.Number): 130 | self.size = (int(size), int(size)) 131 | else: 132 | self.size = size 133 | self.consistent = consistent 134 | 135 | def __call__(self, imgmap, flowmap=None): 136 | img1 = imgmap[0] 137 | w, h = img1.size 138 | if self.size is not None: 139 | th, tw = self.size 140 | if w == tw and h == th: 141 | return imgmap 142 | if not flowmap: 143 | if self.consistent: 144 | x1 = random.randint(0, w - tw) 145 | y1 = random.randint(0, h - th) 146 | return [i.crop((x1, y1, x1 + tw, y1 + th)) for i in imgmap] 147 | else: 148 | result = [] 149 | for i in imgmap: 150 | x1 = random.randint(0, w - tw) 151 | y1 = random.randint(0, h - th) 152 | result.append(i.crop((x1, y1, x1 + tw, y1 + th))) 153 | return result 154 | elif flowmap is not None: 155 | assert (not self.consistent) 156 | result = [] 157 | for idx, i in enumerate(imgmap): 158 | proposal = [] 159 | for j in range(3): # number of proposal: use the one with largest optical flow 160 | x = random.randint(0, w - tw) 161 | y = random.randint(0, h - th) 162 | proposal.append([x, y, abs(np.mean(flowmap[idx,y:y+th,x:x+tw,:]))]) 163 | [x1, y1, _] = max(proposal, key=lambda x: x[-1]) 164 | result.append(i.crop((x1, y1, x1 + tw, y1 + th))) 165 | return result 166 | else: 167 | raise ValueError('wrong case') 168 | else: 169 | return imgmap 170 | 171 | 172 | class RandomSizedCrop: 173 | def __init__(self, size, interpolation=Image.BILINEAR, consistent=True, p=1.0): 174 | self.size = size 175 | self.interpolation = interpolation 176 | self.consistent = consistent 177 | self.threshold = p 178 | 179 | def __call__(self, imgmap): 180 | img1 = imgmap[0] 181 | if random.random() < self.threshold: # do RandomSizedCrop 182 | for attempt in range(10): 183 | area = img1.size[0] * img1.size[1] 184 | target_area = random.uniform(0.5, 1) * area 185 | aspect_ratio = random.uniform(3. / 4, 4. 
/ 3) 186 | 187 | w = int(round(math.sqrt(target_area * aspect_ratio))) 188 | h = int(round(math.sqrt(target_area / aspect_ratio))) 189 | 190 | if self.consistent: 191 | if random.random() < 0.5: 192 | w, h = h, w 193 | if w <= img1.size[0] and h <= img1.size[1]: 194 | x1 = random.randint(0, img1.size[0] - w) 195 | y1 = random.randint(0, img1.size[1] - h) 196 | 197 | imgmap = [i.crop((x1, y1, x1 + w, y1 + h)) for i in imgmap] 198 | for i in imgmap: assert(i.size == (w, h)) 199 | 200 | return [i.resize((self.size, self.size), self.interpolation) for i in imgmap] 201 | else: 202 | result = [] 203 | for i in imgmap: 204 | if random.random() < 0.5: 205 | w, h = h, w 206 | if w <= img1.size[0] and h <= img1.size[1]: 207 | x1 = random.randint(0, img1.size[0] - w) 208 | y1 = random.randint(0, img1.size[1] - h) 209 | result.append(i.crop((x1, y1, x1 + w, y1 + h))) 210 | assert(result[-1].size == (w, h)) 211 | else: 212 | result.append(i) 213 | 214 | assert len(result) == len(imgmap) 215 | return [i.resize((self.size, self.size), self.interpolation) for i in result] 216 | 217 | # Fallback 218 | scale = Scale(self.size, interpolation=self.interpolation) 219 | crop = CenterCrop(self.size) 220 | return crop(scale(imgmap)) 221 | else: # don't do RandomSizedCrop, do CenterCrop 222 | crop = CenterCrop(self.size) 223 | return crop(imgmap) 224 | 225 | 226 | class RandomHorizontalFlip: 227 | def __init__(self, consistent=True, command=None): 228 | self.consistent = consistent 229 | if command == 'left': 230 | self.threshold = 0 231 | elif command == 'right': 232 | self.threshold = 1 233 | else: 234 | self.threshold = 0.5 235 | def __call__(self, imgmap): 236 | if self.consistent: 237 | if random.random() < self.threshold: 238 | return [i.transpose(Image.FLIP_LEFT_RIGHT) for i in imgmap] 239 | else: 240 | return imgmap 241 | else: 242 | result = [] 243 | for i in imgmap: 244 | if random.random() < self.threshold: 245 | result.append(i.transpose(Image.FLIP_LEFT_RIGHT)) 246 | else: 247 | result.append(i) 248 | assert len(result) == len(imgmap) 249 | return result 250 | 251 | 252 | class RandomGray: 253 | '''Actually it is a channel splitting, not strictly grayscale images''' 254 | def __init__(self, consistent=True, p=0.5): 255 | self.consistent = consistent 256 | self.p = p # probability to apply grayscale 257 | def __call__(self, imgmap): 258 | if self.consistent: 259 | if random.random() < self.p: 260 | return [self.grayscale(i) for i in imgmap] 261 | else: 262 | return imgmap 263 | else: 264 | result = [] 265 | for i in imgmap: 266 | if random.random() < self.p: 267 | result.append(self.grayscale(i)) 268 | else: 269 | result.append(i) 270 | assert len(result) == len(imgmap) 271 | return result 272 | 273 | def grayscale(self, img): 274 | channel = np.random.choice(3) 275 | np_img = np.array(img)[:,:,channel] 276 | np_img = np.dstack([np_img, np_img, np_img]) 277 | img = Image.fromarray(np_img, 'RGB') 278 | return img 279 | 280 | 281 | class ColorJitter(object): 282 | """Randomly change the brightness, contrast and saturation of an image. --modified from pytorch source code 283 | Args: 284 | brightness (float or tuple of float (min, max)): How much to jitter brightness. 285 | brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness] 286 | or the given [min, max]. Should be non negative numbers. 287 | contrast (float or tuple of float (min, max)): How much to jitter contrast. 
288 | contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast] 289 | or the given [min, max]. Should be non negative numbers. 290 | saturation (float or tuple of float (min, max)): How much to jitter saturation. 291 | saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation] 292 | or the given [min, max]. Should be non negative numbers. 293 | hue (float or tuple of float (min, max)): How much to jitter hue. 294 | hue_factor is chosen uniformly from [-hue, hue] or the given [min, max]. 295 | Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5. 296 | """ 297 | def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, consistent=False, p=1.0): 298 | self.brightness = self._check_input(brightness, 'brightness') 299 | self.contrast = self._check_input(contrast, 'contrast') 300 | self.saturation = self._check_input(saturation, 'saturation') 301 | self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5), 302 | clip_first_on_zero=False) 303 | self.consistent = consistent 304 | self.threshold = p 305 | 306 | def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True): 307 | if isinstance(value, numbers.Number): 308 | if value < 0: 309 | raise ValueError("If {} is a single number, it must be non negative.".format(name)) 310 | value = [center - value, center + value] 311 | if clip_first_on_zero: 312 | value[0] = max(value[0], 0) 313 | elif isinstance(value, (tuple, list)) and len(value) == 2: 314 | if not bound[0] <= value[0] <= value[1] <= bound[1]: 315 | raise ValueError("{} values should be between {}".format(name, bound)) 316 | else: 317 | raise TypeError("{} should be a single number or a list/tuple with lenght 2.".format(name)) 318 | 319 | # if value is 0 or (1., 1.) for brightness/contrast/saturation 320 | # or (0., 0.) for hue, do nothing 321 | if value[0] == value[1] == center: 322 | value = None 323 | return value 324 | 325 | @staticmethod 326 | def get_params(brightness, contrast, saturation, hue): 327 | """Get a randomized transform to be applied on image. 328 | Arguments are same as that of __init__. 329 | Returns: 330 | Transform which randomly adjusts brightness, contrast and 331 | saturation in a random order. 
332 | """ 333 | transforms = [] 334 | 335 | if brightness is not None: 336 | brightness_factor = random.uniform(brightness[0], brightness[1]) 337 | transforms.append(torchvision.transforms.Lambda(lambda img: F.adjust_brightness(img, brightness_factor))) 338 | 339 | if contrast is not None: 340 | contrast_factor = random.uniform(contrast[0], contrast[1]) 341 | transforms.append(torchvision.transforms.Lambda(lambda img: F.adjust_contrast(img, contrast_factor))) 342 | 343 | if saturation is not None: 344 | saturation_factor = random.uniform(saturation[0], saturation[1]) 345 | transforms.append(torchvision.transforms.Lambda(lambda img: F.adjust_saturation(img, saturation_factor))) 346 | 347 | if hue is not None: 348 | hue_factor = random.uniform(hue[0], hue[1]) 349 | transforms.append(torchvision.transforms.Lambda(lambda img: F.adjust_hue(img, hue_factor))) 350 | 351 | random.shuffle(transforms) 352 | transform = torchvision.transforms.Compose(transforms) 353 | 354 | return transform 355 | 356 | def __call__(self, imgmap): 357 | if random.random() < self.threshold: # do ColorJitter 358 | if self.consistent: 359 | transform = self.get_params(self.brightness, self.contrast, 360 | self.saturation, self.hue) 361 | return [transform(i) for i in imgmap] 362 | else: 363 | result = [] 364 | for img in imgmap: 365 | transform = self.get_params(self.brightness, self.contrast, 366 | self.saturation, self.hue) 367 | result.append(transform(img)) 368 | return result 369 | else: # don't do ColorJitter, do nothing 370 | return imgmap 371 | 372 | def __repr__(self): 373 | format_string = self.__class__.__name__ + '(' 374 | format_string += 'brightness={0}'.format(self.brightness) 375 | format_string += ', contrast={0}'.format(self.contrast) 376 | format_string += ', saturation={0}'.format(self.saturation) 377 | format_string += ', hue={0})'.format(self.hue) 378 | return format_string 379 | 380 | 381 | class RandomRotation: 382 | def __init__(self, consistent=True, degree=15, p=1.0): 383 | self.consistent = consistent 384 | self.degree = degree 385 | self.threshold = p 386 | def __call__(self, imgmap): 387 | if random.random() < self.threshold: # do RandomRotation 388 | if self.consistent: 389 | deg = np.random.randint(-self.degree, self.degree, 1)[0] 390 | return [i.rotate(deg, expand=True) for i in imgmap] 391 | else: 392 | return [i.rotate(np.random.randint(-self.degree, self.degree, 1)[0], expand=True) for i in imgmap] 393 | else: # don't do RandomRotation, do nothing 394 | return imgmap 395 | 396 | class ToTensor: 397 | def __call__(self, imgmap): 398 | totensor = transforms.ToTensor() 399 | return [totensor(i) for i in imgmap] 400 | 401 | class Normalize: 402 | def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]): 403 | self.mean = mean 404 | self.std = std 405 | def __call__(self, imgmap): 406 | normalize = transforms.Normalize(mean=self.mean, std=self.std) 407 | return [normalize(i) for i in imgmap] 408 | 409 | 410 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import pickle 4 | import os 5 | from datetime import datetime 6 | import glob 7 | import re 8 | import matplotlib.pyplot as plt 9 | plt.switch_backend('agg') 10 | from collections import deque 11 | from tqdm import tqdm 12 | from torchvision import transforms 13 | 14 | def save_checkpoint(state, is_best=0, gap=1, filename='models/checkpoint.pth.tar', 
 15 |     torch.save(state, filename)
 16 |     last_epoch_path = os.path.join(os.path.dirname(filename),
 17 |                                    'epoch%s.pth.tar' % str(state['epoch']-gap))
 18 |     if not keep_all:
 19 |         try: os.remove(last_epoch_path)
 20 |         except: pass
 21 |     if is_best:
 22 |         past_best = glob.glob(os.path.join(os.path.dirname(filename), 'model_best_*.pth.tar'))
 23 |         for i in past_best:
 24 |             try: os.remove(i)
 25 |             except: pass
 26 |         torch.save(state, os.path.join(os.path.dirname(filename), 'model_best_epoch%s.pth.tar' % str(state['epoch'])))
 27 | 
 28 | def write_log(content, epoch, filename):
 29 |     if not os.path.exists(filename):
 30 |         log_file = open(filename, 'w')
 31 |     else:
 32 |         log_file = open(filename, 'a')
 33 |     log_file.write('## Epoch %d:\n' % epoch)
 34 |     log_file.write('time: %s\n' % str(datetime.now()))
 35 |     log_file.write(content + '\n\n')
 36 |     log_file.close()
 37 | 
 38 | class Logger(object):
 39 |     '''Append time-stamped messages to a plain-text .log file.'''
 40 |     def __init__(self, path):
 41 |         self.birth_time = datetime.now()
 42 |         filepath = os.path.join(path, self.birth_time.strftime('%Y-%m-%d-%H:%M:%S')+'.log')
 43 |         self.filepath = filepath
 44 |         with open(filepath, 'a') as f:
 45 |             f.write(self.birth_time.strftime('%Y-%m-%d %H:%M:%S')+'\n')
 46 | 
 47 |     def log(self, string):
 48 |         with open(self.filepath, 'a') as f:
 49 |             time_stamp = datetime.now() - self.birth_time
 50 |             f.write(strfdelta(time_stamp,"{d}-{h:02d}:{m:02d}:{s:02d}")+'\t'+string+'\n')
 51 | 
 52 | def calc_topk_accuracy(output, target, topk=(1,)):
 53 |     '''
 54 |     Modified from: https://gist.github.com/agermanidis/275b23ad7a10ee89adccf021536bb97e
 55 |     Given predicted and ground truth labels,
 56 |     calculate top-k accuracies.
 57 |     '''
 58 |     maxk = max(topk)
 59 |     batch_size = target.size(0)
 60 | 
 61 |     _, pred = output.topk(maxk, 1, True, True)
 62 |     pred = pred.t()
 63 |     correct = pred.eq(target.view(1, -1).expand_as(pred))
 64 | 
 65 |     res = []
 66 |     for k in topk:
 67 |         correct_k = correct[:k].reshape(-1).float().sum(0)
 68 |         res.append(correct_k.mul_(1 / batch_size))
 69 |     return res
 70 | 
 71 | def calc_accuracy(output, target):
 72 |     '''output: (B, N); target: (B)'''
 73 |     target = target.squeeze()
 74 |     _, pred = torch.max(output, 1)
 75 |     return torch.mean((pred == target).float())
 76 | 
 77 | def denorm(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
 78 |     assert len(mean)==len(std)==3
 79 |     inv_mean = [-mean[i]/std[i] for i in range(3)]
 80 |     inv_std = [1/i for i in std]
 81 |     return transforms.Normalize(mean=inv_mean, std=inv_std)
 82 | 
 83 | 
 84 | def neq_load_customized(model, pretrained_dict):
 85 |     '''Load a pre-trained state_dict whose keys only partially match,
 86 |     e.g. when the new model has been modified relative to the pretrained one.'''
 87 |     model_dict = model.state_dict()
 88 |     tmp = {}
 89 |     print('\n=======Check Weights Loading======')
 90 |     print('Weights not used from pretrained file:')
 91 |     for k, v in pretrained_dict.items():
 92 |         if k in model_dict:
 93 |             tmp[k] = v
 94 |         else:
 95 |             print(k)
 96 |     print('---------------------------')
 97 |     print('Weights not loaded into new model:')
 98 |     for k, v in model_dict.items():
 99 |         if k not in pretrained_dict:
100 |             print(k)
101 |     print('===================================\n')
102 |     # pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
103 |     del pretrained_dict
104 |     model_dict.update(tmp)
105 |     del tmp
106 |     model.load_state_dict(model_dict)
107 |     return model
108 | 
109 | 
110 | class AverageMeter(object):
111 |     """Computes and stores the average and current value"""
112 |     def __init__(self):
113 |         self.reset()
114 | 
115 |     def reset(self):
116 |         self.val = 0
117 |         self.avg = 0
118 |         self.sum = 0
119 |         self.count = 0
120 |         self.local_history = deque([])
121 |         self.local_avg = 0
122 |         self.history = []
123 |         self.dict = {} # save all data values here
124 |         self.save_dict = {} # save mean and std here, for summary table
125 | 
126 |     def update(self, val, n=1, history=0, step=5):
127 |         self.val = val
128 |         self.sum += val * n
129 |         self.count += n
130 |         self.avg = self.sum / self.count
131 |         if history:
132 |             self.history.append(val)
133 |         if step > 0:
134 |             self.local_history.append(val)
135 |             if len(self.local_history) > step:
136 |                 self.local_history.popleft()
137 |             self.local_avg = np.average(self.local_history)
138 | 
139 |     def dict_update(self, val, key):
140 |         if key in self.dict.keys():
141 |             self.dict[key].append(val)
142 |         else:
143 |             self.dict[key] = [val]
144 | 
145 |     def __len__(self):
146 |         return self.count
147 | 
148 | 
149 | class AccuracyTable(object):
150 |     '''compute accuracy for each class'''
151 |     def __init__(self):
152 |         self.dict = {}
153 | 
154 |     def update(self, pred, tar):
155 |         pred = torch.squeeze(pred)
156 |         tar = torch.squeeze(tar)
157 |         for i, j in zip(pred, tar):
158 |             i = int(i)
159 |             j = int(j)
160 |             if j not in self.dict.keys():
161 |                 self.dict[j] = {'count':0,'correct':0}
162 |             self.dict[j]['count'] += 1
163 |             if i == j:
164 |                 self.dict[j]['correct'] += 1
165 | 
166 |     def print_table(self, label):
167 |         for key in self.dict.keys():
168 |             acc = self.dict[key]['correct'] / self.dict[key]['count']
169 |             print('%s: %2d, accuracy: %3d/%3d = %0.6f' \
170 |                 % (label, key, self.dict[key]['correct'], self.dict[key]['count'], acc))
171 | 
172 | 
173 | class ConfusionMeter(object):
174 |     '''compute and show confusion matrix'''
175 |     def __init__(self, num_class):
176 |         self.num_class = num_class
177 |         self.mat = np.zeros((num_class, num_class))
178 |         self.precision = []
179 |         self.recall = []
180 | 
181 |     def update(self, pred, tar):
182 |         pred, tar = pred.cpu().numpy(), tar.cpu().numpy()
183 |         pred = np.squeeze(pred)
184 |         tar = np.squeeze(tar)
185 |         for p,t in zip(pred.flat, tar.flat):
186 |             self.mat[p][t] += 1
187 | 
188 |     def print_mat(self):
189 |         print('Confusion Matrix: (target in columns)')
190 |         print(self.mat)
191 | 
192 |     def plot_mat(self, path, dictionary=None, annotate=False):
193 |         plt.figure(dpi=600)
194 |         plt.imshow(self.mat,
195 |                    cmap=plt.cm.jet,
196 |                    interpolation=None,
197 |                    extent=(0.5, np.shape(self.mat)[0]+0.5, np.shape(self.mat)[1]+0.5, 0.5))
198 |         width, height = self.mat.shape
199 |         if annotate:
200 |             for x in range(width):
201 |                 for y in range(height):
202 |                     plt.annotate(str(int(self.mat[x][y])), xy=(y+1, x+1),
203 |                                  horizontalalignment='center',
204 |                                  verticalalignment='center',
205 |                                  fontsize=8)
206 | 
207 |         if dictionary is not None:
208 |             plt.xticks([i+1 for i in range(width)],
209 |                        [dictionary[i] for i in range(width)],
210 |                        rotation='vertical')
211 |             plt.yticks([i+1 for i in range(height)],
212 |                        [dictionary[i] for i in range(height)])
213 |         plt.xlabel('Ground Truth')
214 |         plt.ylabel('Prediction')
215 |         plt.colorbar()
216 |         plt.tight_layout()
217 |         plt.savefig(path, format='svg')
218 |         plt.clf()
219 | 
220 |         for i in range(width):
221 |             if np.sum(self.mat[i,:]) != 0:
222 |                 self.precision.append(self.mat[i,i] / np.sum(self.mat[i,:]))
223 |             if np.sum(self.mat[:,i]) != 0:
224 |                 self.recall.append(self.mat[i,i] / np.sum(self.mat[:,i]))
225 |         print('Average Precision: %0.4f' % np.mean(self.precision))
226 |         print('Average Recall: %0.4f' % np.mean(self.recall))
227 | 
228 | 
229 | def MultiStepLR_Restart_Multiplier(epoch, gamma=0.1, step=[10,15,20], repeat=3):
230 |     '''return the multiplier for LambdaLR,
231 |     0 <= ep < 10: gamma^0
232 |     10 <= ep < 15: gamma^1
233 |     15 <= ep < 20: gamma^2
234 |     20 <= ep < 30: gamma^0 ... repeat 3 cycles and then keep gamma^2'''
235 |     max_step = max(step)
236 |     effective_epoch = epoch % max_step
237 |     if epoch // max_step >= repeat:
238 |         exp = len(step) - 1
239 |     else:
240 |         exp = len([i for i in step if effective_epoch>=i])
241 |     return gamma ** exp
242 | 
243 | 
244 | def strfdelta(tdelta, fmt):
245 |     d = {"d": tdelta.days}
246 |     d["h"], rem = divmod(tdelta.seconds, 3600)
247 |     d["m"], d["s"] = divmod(rem, 60)
248 |     return fmt.format(**d)
249 | 
--------------------------------------------------------------------------------
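
A few usage sketches follow; they are illustrative additions, not files from the repository. First, the augmentation classes above differ from the stock torchvision transforms in that each __call__ takes imgmap, a list of frames from one clip, and returns a list, so a single random draw can be applied consistently across the clip. A minimal sketch of chaining them with torchvision.transforms.Compose; the blank PIL frames, the clip length, and the utils.augmentation import path are assumptions.

from PIL import Image
import torchvision
from utils.augmentation import RandomRotation, ToTensor, Normalize  # import path assumed

# A fake 8-frame "clip" of blank RGB images, standing in for decoded video frames.
clip = [Image.new('RGB', (128, 128)) for _ in range(8)]

transform = torchvision.transforms.Compose([
    RandomRotation(consistent=True, degree=15, p=1.0),  # one rotation angle shared by all frames
    ToTensor(),                                          # list of PIL images -> list of CxHxW tensors
    Normalize(),                                         # ImageNet mean/std, as defined above
])

frames = transform(clip)               # still a list, one tensor per frame
print(len(frames), frames[0].shape)    # 8 frames, each 3xHxW (H, W grow slightly because expand=True)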
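
MultiStepLR_Restart_Multiplier only returns a multiplier; its docstring says it is meant for torch.optim.lr_scheduler.LambdaLR. A sketch of that wiring, where the stand-in linear model, the Adam optimizer, the base learning rate, and the particular gamma/step/repeat values are placeholders rather than the repository's actual schedule.

import torch
from torch.optim.lr_scheduler import LambdaLR
from utils.utils import MultiStepLR_Restart_Multiplier  # import path assumed

model = torch.nn.Linear(16, 4)                              # stand-in for the real network
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

def lr_lambda(ep):
    # LambdaLR multiplies the optimizer's base lr by this return value for epoch `ep`.
    return MultiStepLR_Restart_Multiplier(ep, gamma=0.1, step=[10, 15, 20], repeat=3)

scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)

for epoch in range(25):
    # ... one epoch of training would go here ...
    scheduler.step()
    # epochs 0-9: lr*1, 10-14: lr*0.1, 15-19: lr*0.01, 20-29: restart at lr*1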
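
A sketch of how AverageMeter, calc_topk_accuracy, and save_checkpoint are typically combined in an evaluation loop. model, loader, criterion, and the extra checkpoint keys are placeholders; only state['epoch'] is actually read by save_checkpoint (for naming and pruning checkpoint files).

import torch
from utils.utils import AverageMeter, calc_topk_accuracy, save_checkpoint  # import path assumed

def validate(model, loader, criterion, device, epoch, best_acc):
    losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()
    model.eval()
    with torch.no_grad():
        for input_seq, target in loader:
            input_seq, target = input_seq.to(device), target.to(device)
            output = model(input_seq)              # assumed to return (B, num_class) logits
            loss = criterion(output, target)
            top1_acc, top5_acc = calc_topk_accuracy(output, target, topk=(1, 5))

            B = input_seq.size(0)
            losses.update(loss.item(), B)          # running averages weighted by batch size
            top1.update(top1_acc.item(), B)
            top5.update(top5_acc.item(), B)

    is_best = top1.avg > best_acc
    save_checkpoint({'epoch': epoch,                    # read by save_checkpoint for file naming/cleanup
                     'state_dict': model.state_dict(),  # extra keys are up to the caller
                     'best_acc': max(top1.avg, best_acc)},
                    is_best=is_best,
                    filename='models/checkpoint.pth.tar')  # the 'models/' directory must already exist
    return losses.avg, top1.avg, top5.avg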
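
Finally, a sketch of neq_load_customized, which copies only the state_dict entries whose key names match the new model; the two stand-in networks below are hypothetical and exist only to produce partially overlapping state_dicts.

import torch.nn as nn
from collections import OrderedDict
from utils.utils import neq_load_customized  # import path assumed

# A "pretrained" backbone with an old head, and a new model that keeps the backbone
# but replaces the classifier, so only the backbone.* keys match by name.
pretrained = nn.Sequential(OrderedDict([('backbone', nn.Linear(8, 8)),
                                        ('old_head', nn.Linear(8, 4))]))
finetune = nn.Sequential(OrderedDict([('backbone', nn.Linear(8, 8)),
                                      ('new_head', nn.Linear(8, 101))]))

# Copies backbone.* weights, skips old_head.*/new_head.*, and prints both skipped lists.
finetune = neq_load_customized(finetune, pretrained.state_dict())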