├── LICENSE ├── README.md ├── assets ├── intuition.png └── results.png ├── backbone ├── convrnn.py ├── resnet_2d3d.py └── select_backbone.py ├── process_data ├── readme.md └── src │ ├── build_rawframes_optimized.py │ ├── extract_features.py │ ├── extract_frame.py │ └── write_csv.py ├── requirements.txt ├── test ├── dataset_3d_lc.py ├── model_3d_lc.py ├── test.py └── transform_utils.py ├── train ├── data_utils.py ├── dataset_3d.py ├── finetune_utils.py ├── mask_utils.py ├── model_3d.py ├── model_trainer.py ├── model_utils.py └── sim_utils.py └── utils ├── augmentation.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Nishant Rai 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## CoCon: Coooperative Contrastive Learning for Video Representation Learning 2 | 3 | This repository contains the implementation of [CoCon - Cooperative Contrastive Learning for video representation 4 | learning](https://arxiv.org/abs/2104.14764). We utilize multiple views of videos in order to learn better representations capturing semantics suitable for tasks related to video understanding. CoCon was presented at [BayLearn 2020](http://www.baylearn.org/overview) and will be part of [Holistic Video Understanding at CVPR '21](https://holistic-video-understanding.github.io/workshops/cvpr2021.html#awards). 5 | 6 | ![arch](assets/intuition.png) 7 | 8 | ### Authors 9 | 10 | * [Nishant Rai - Stanford University](https://www.linkedin.com/in/nishantrai18/) 11 | * Ehsan Adeli - Stanford University 12 | * Kuan-Hui Lee - Toyota Research Institute 13 | * Adrien Gaidon - Toyota Research Institute 14 | * Juan Carlos Niebles - Stanford University 15 | 16 | ### Installation 17 | 18 | Our implementation should work with python >= 3.6, pytorch >= 0.4, torchvision >= 0.2.2. The repo also requires cv2 19 | (`conda install -c menpo opencv`), tensorboardX >= 1.7 (`pip install tensorboardX`), tqdm. 20 | 21 | A requirements.txt has been provided which can be used to create the exact environment required. 22 | ``` 23 | pip install -r requirements.txt 24 | ``` 25 | 26 | ### Prepare data 27 | 28 | Follow the instructions [here](process_data/). Instructions to generate multi-view data for custom datasets will be 29 | added soon. 
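After frame extraction and csv generation (`process_data/src/write_csv.py`), every row of a split csv is simply a frame-folder path followed by its frame count. Below is a minimal sanity check for a generated split; the csv path is only an example and should point to wherever `write_csv.py` wrote your splits.
```
import csv, os

# Example location; adjust to your own output of process_data/src/write_csv.py
split_csv = 'process_data/data/ucf101/train_split01.csv'

with open(split_csv) as f:
    rows = [(path, int(num_frames)) for path, num_frames in csv.reader(f)]

missing = [p for p, _ in rows if not os.path.isdir(p)]
empty = [p for p, n in rows if n == 0]
print('%d clips | %d missing frame folders | %d clips with zero frames'
      % (len(rows), len(missing), len(empty)))
```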
30 | 
31 | ### Cooperative Contrastive Learning (CoCon) 
32 | 
33 | Training scripts are present in `CoCon/train/` 
34 | 
35 | Run `python model_trainer.py --help` to get details about the command-line args. The most useful ones are 
36 | `--dataset` and `--modalities`, which select the dataset to run experiments on and the input 
37 | modalities to use. 
38 | 
39 | Our implementation has been tested with RGB, Optical Flow, Segmentation Masks and Human Keypoints. 
40 | However, it is easy to extend it to custom views; look at `dataset_3d.py` for details. 
41 | 
42 | * Single-View Training: train CoCon using 2 GPUs, with RGB inputs and a 3D-ResNet18 backbone, on UCF101 at 224x224 
43 | resolution, for 100 epochs. Batch size is per GPU. 
44 | ``` 
45 | CUDA_VISIBLE_DEVICES="0,1" python model_trainer.py --net resnet18 --dataset ucf101 --modalities imgs 
46 | --batch_size 16 --img_dim 224 --epochs 100 
47 | ``` 
48 | 
49 | * Multi-View Training: train CoCon using 4 GPUs, with RGB, Flow, Segmentation-Mask and Keypoint inputs and a 3D-ResNet18 
50 | backbone, on HMDB51 at 128x128 resolution, for 100 epochs. 
51 | ``` 
52 | CUDA_VISIBLE_DEVICES="0,1,2,3" python model_trainer.py --net resnet18 --dataset hmdb 
53 | --modalities imgs_flow_seg_kphm --batch_size 16 --img_dim 128 --epochs 100 
54 | ``` 
55 | 
56 | * Heavy Multi-View Training: train CoCon using 4 GPUs, with RGB and Flow inputs and a 3D-ResNet34 backbone, on the 
57 | Kinetics400 dataset at 128x128 resolution, for 50 epochs. 
58 | ``` 
59 | CUDA_VISIBLE_DEVICES="0,1,2,3" python model_trainer.py --net resnet34 --dataset kinetics 
60 | --modalities imgs_flow --batch_size 8 --img_dim 128 --epochs 50 
61 | ``` 
62 | 
63 | ### Evaluation: Video Action Recognition 
64 | 
65 | Testing scripts are present in `CoCon/test/` 
66 | 
67 | * Evaluate model: fine-tune the pre-trained weights (replace `{model_path}` with the path to your pretrained checkpoint) 
68 | ``` 
69 | python test.py --net resnet18 --dataset ucf101 --modality imgs --batch_size 8 --img_dim 128 
70 | --pretrain {model_path} --epochs 100 
71 | ``` 
72 | 
73 | ### Results 
74 | 
75 | ![arch](assets/results.png) 
76 | 
77 | ### Qualitative Evaluation 
78 | 
79 | Scripts for qualitative evaluation will be added here. 
80 | 
81 | ### Acknowledgements 
82 | 
83 | Portions of code have been borrowed from [DPC](https://github.com/TengdaHan/DPC). Feel free to refer to their great 
84 | work as well if you're interested in the field. 
85 | 
86 | ### Citing 
87 | 
88 | If our paper or the codebase was useful to you, please consider citing it using the BibTeX below. 
89 | 90 | ``` 91 | @InProceedings{Rai_2021_CVPR, 92 | author = {Rai, Nishant and Adeli, Ehsan and Lee, Kuan-Hui and Gaidon, Adrien and Niebles, Juan Carlos}, 93 | title = {CoCon: Cooperative-Contrastive Learning}, 94 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR) Workshops}, 95 | month = {June}, 96 | year = {2021}, 97 | pages = {3384-3393} 98 | } 99 | ``` 100 | 101 | ### Keywords 102 | * Multi-view Video Representation Learning 103 | * Video Contrastive Learning 104 | * Multi-view Self-supervised Learning 105 | -------------------------------------------------------------------------------- /assets/intuition.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nishantrai18/cocon/f8ffd8a9988b0344cd7759ed02743d29da7a17b9/assets/intuition.png -------------------------------------------------------------------------------- /assets/results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nishantrai18/cocon/f8ffd8a9988b0344cd7759ed02743d29da7a17b9/assets/results.png -------------------------------------------------------------------------------- /backbone/convrnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class ConvGRUCell(nn.Module): 5 | ''' Initialize ConvGRU cell ''' 6 | def __init__(self, input_size, hidden_size, kernel_size): 7 | super(ConvGRUCell, self).__init__() 8 | self.input_size = input_size 9 | self.hidden_size = hidden_size 10 | self.kernel_size = kernel_size 11 | padding = kernel_size // 2 12 | 13 | self.reset_gate = nn.Conv2d(input_size+hidden_size, hidden_size, kernel_size, padding=padding) 14 | self.update_gate = nn.Conv2d(input_size+hidden_size, hidden_size, kernel_size, padding=padding) 15 | self.out_gate = nn.Conv2d(input_size+hidden_size, hidden_size, kernel_size, padding=padding) 16 | 17 | nn.init.orthogonal_(self.reset_gate.weight) 18 | nn.init.orthogonal_(self.update_gate.weight) 19 | nn.init.orthogonal_(self.out_gate.weight) 20 | nn.init.constant_(self.reset_gate.bias, 0.0) 21 | nn.init.constant_(self.update_gate.bias, 0.0) 22 | nn.init.constant_(self.out_gate.bias, 0.) 
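        # Gate equations implemented in forward() below -- a standard GRU cell with the
        # fully-connected layers replaced by 2D convolutions over the spatial map:
        #   update z  = sigmoid(update_gate([x, h]))
        #   reset  r  = sigmoid(reset_gate([x, h]))
        #   cand   h~ = tanh(out_gate([x, r * h]))
        #   h_new     = (1 - z) * h + z * h~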
23 | 24 | def forward(self, input_tensor, hidden_state): 25 | if hidden_state is None: 26 | B, C, *spatial_dim = input_tensor.size() 27 | hidden_state = torch.zeros([B,self.hidden_size,*spatial_dim]).to(input_tensor.device) 28 | # [B, C, H, W] 29 | combined = torch.cat([input_tensor, hidden_state], dim=1) #concat in C 30 | update = torch.sigmoid(self.update_gate(combined)) 31 | reset = torch.sigmoid(self.reset_gate(combined)) 32 | out = torch.tanh(self.out_gate(torch.cat([input_tensor, hidden_state * reset], dim=1))) 33 | new_state = hidden_state * (1 - update) + out * update 34 | return new_state 35 | 36 | 37 | class ConvGRU(nn.Module): 38 | ''' Initialize a multi-layer Conv GRU ''' 39 | def __init__(self, input_size, hidden_size, kernel_size, num_layers, dropout=0.1): 40 | super(ConvGRU, self).__init__() 41 | self.input_size = input_size 42 | self.hidden_size = hidden_size 43 | self.kernel_size = kernel_size 44 | self.num_layers = num_layers 45 | 46 | cell_list = [] 47 | for i in range(self.num_layers): 48 | if i == 0: 49 | input_dim = self.input_size 50 | else: 51 | input_dim = self.hidden_size 52 | cell = ConvGRUCell(input_dim, self.hidden_size, self.kernel_size) 53 | name = 'ConvGRUCell_' + str(i).zfill(2) 54 | 55 | setattr(self, name, cell) 56 | cell_list.append(getattr(self, name)) 57 | 58 | self.cell_list = nn.ModuleList(cell_list) 59 | self.dropout_layer = nn.Dropout(p=dropout) 60 | 61 | def forward(self, x, hidden_state=None): 62 | [B, seq_len, *_] = x.size() 63 | 64 | if hidden_state is None: 65 | hidden_state = [None] * self.num_layers 66 | # input: image sequences [B, T, C, H, W] 67 | current_layer_input = x 68 | del x 69 | 70 | last_state_list = [] 71 | 72 | for idx in range(self.num_layers): 73 | cell_hidden = hidden_state[idx] 74 | output_inner = [] 75 | for t in range(seq_len): 76 | cell_hidden = self.cell_list[idx](current_layer_input[:,t,:], cell_hidden) 77 | cell_hidden = self.dropout_layer(cell_hidden) # dropout in each time step 78 | output_inner.append(cell_hidden) 79 | 80 | layer_output = torch.stack(output_inner, dim=1) 81 | current_layer_input = layer_output 82 | 83 | last_state_list.append(cell_hidden) 84 | 85 | last_state_list = torch.stack(last_state_list, dim=1) 86 | 87 | return layer_output, last_state_list 88 | 89 | 90 | if __name__ == '__main__': 91 | crnn = ConvGRU(input_size=10, hidden_size=20, kernel_size=3, num_layers=2) 92 | data = torch.randn(4, 5, 10, 6, 6) # [B, seq_len, C, H, W], temporal axis=1 93 | output, hn = crnn(data) 94 | import ipdb; ipdb.set_trace() 95 | -------------------------------------------------------------------------------- /backbone/resnet_2d3d.py: -------------------------------------------------------------------------------- 1 | ## modified from https://github.com/kenshohara/3D-ResNets-PyTorch 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | import math 7 | 8 | __all__ = [ 9 | 'ResNet2d3d_full', 'resnet18_2d3d_full', 'resnet34_2d3d_full', 'resnet50_2d3d_full', 'resnet101_2d3d_full', 10 | 'resnet152_2d3d_full', 'resnet200_2d3d_full', 11 | ] 12 | 13 | def conv3x3x3(in_planes, out_planes, stride=1, bias=False): 14 | # 3x3x3 convolution with padding 15 | return nn.Conv3d( 16 | in_planes, 17 | out_planes, 18 | kernel_size=3, 19 | stride=stride, 20 | padding=1, 21 | bias=bias) 22 | 23 | def conv1x3x3(in_planes, out_planes, stride=1, bias=False): 24 | # 1x3x3 convolution with padding 25 | return nn.Conv3d( 26 | in_planes, 27 | out_planes, 28 | 
kernel_size=(1,3,3), 29 | stride=(1,stride,stride), 30 | padding=(0,1,1), 31 | bias=bias) 32 | 33 | 34 | def downsample_basic_block(x, planes, stride): 35 | out = F.avg_pool3d(x, kernel_size=1, stride=stride) 36 | zero_pads = torch.Tensor( 37 | out.size(0), planes - out.size(1), out.size(2), out.size(3), 38 | out.size(4)).zero_() 39 | if isinstance(out.data, torch.cuda.FloatTensor): 40 | zero_pads = zero_pads.cuda() 41 | 42 | out = Variable(torch.cat([out.data, zero_pads], dim=1)) 43 | 44 | return out 45 | 46 | 47 | class BasicBlock3d(nn.Module): 48 | expansion = 1 49 | 50 | def __init__(self, inplanes, planes, stride=1, downsample=None, track_running_stats=True, use_final_relu=True): 51 | super(BasicBlock3d, self).__init__() 52 | bias = False 53 | self.use_final_relu = use_final_relu 54 | self.conv1 = conv3x3x3(inplanes, planes, stride, bias=bias) 55 | self.bn1 = nn.BatchNorm3d(planes, track_running_stats=track_running_stats) 56 | 57 | self.relu = nn.ReLU(inplace=True) 58 | self.conv2 = conv3x3x3(planes, planes, bias=bias) 59 | self.bn2 = nn.BatchNorm3d(planes, track_running_stats=track_running_stats) 60 | 61 | self.downsample = downsample 62 | self.stride = stride 63 | 64 | def forward(self, x): 65 | residual = x 66 | 67 | out = self.conv1(x) 68 | out = self.bn1(out) 69 | out = self.relu(out) 70 | 71 | out = self.conv2(out) 72 | out = self.bn2(out) 73 | 74 | if self.downsample is not None: 75 | residual = self.downsample(x) 76 | 77 | out += residual 78 | if self.use_final_relu: out = self.relu(out) 79 | 80 | return out 81 | 82 | 83 | class BasicBlock2d(nn.Module): 84 | expansion = 1 85 | 86 | def __init__(self, inplanes, planes, stride=1, downsample=None, track_running_stats=True, use_final_relu=True): 87 | super(BasicBlock2d, self).__init__() 88 | bias = False 89 | self.use_final_relu = use_final_relu 90 | self.conv1 = conv1x3x3(inplanes, planes, stride, bias=bias) 91 | self.bn1 = nn.BatchNorm3d(planes, track_running_stats=track_running_stats) 92 | 93 | self.relu = nn.ReLU(inplace=True) 94 | self.conv2 = conv1x3x3(planes, planes, bias=bias) 95 | self.bn2 = nn.BatchNorm3d(planes, track_running_stats=track_running_stats) 96 | 97 | self.downsample = downsample 98 | self.stride = stride 99 | 100 | def forward(self, x): 101 | residual = x 102 | 103 | out = self.conv1(x) 104 | out = self.bn1(out) 105 | out = self.relu(out) 106 | 107 | out = self.conv2(out) 108 | out = self.bn2(out) 109 | 110 | if self.downsample is not None: 111 | residual = self.downsample(x) 112 | 113 | out += residual 114 | if self.use_final_relu: out = self.relu(out) 115 | 116 | return out 117 | 118 | 119 | class Bottleneck3d(nn.Module): 120 | expansion = 4 121 | 122 | def __init__(self, inplanes, planes, stride=1, downsample=None, track_running_stats=True, use_final_relu=True): 123 | super(Bottleneck3d, self).__init__() 124 | bias = False 125 | self.use_final_relu = use_final_relu 126 | self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=1, bias=bias) 127 | self.bn1 = nn.BatchNorm3d(planes, track_running_stats=track_running_stats) 128 | 129 | self.conv2 = nn.Conv3d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=bias) 130 | self.bn2 = nn.BatchNorm3d(planes, track_running_stats=track_running_stats) 131 | 132 | self.conv3 = nn.Conv3d(planes, planes * 4, kernel_size=1, bias=bias) 133 | self.bn3 = nn.BatchNorm3d(planes * 4, track_running_stats=track_running_stats) 134 | 135 | self.relu = nn.ReLU(inplace=True) 136 | self.downsample = downsample 137 | self.stride = stride 138 | 139 | def forward(self, x): 140 | 
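        # 3D bottleneck: 1x1x1 reduce -> 3x3x3 (with stride) -> 1x1x1 expand to 4x planes,
        # followed by the residual addition below.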
residual = x 141 | 142 | out = self.conv1(x) 143 | out = self.bn1(out) 144 | out = self.relu(out) 145 | 146 | out = self.conv2(out) 147 | out = self.bn2(out) 148 | out = self.relu(out) 149 | 150 | out = self.conv3(out) 151 | out = self.bn3(out) 152 | 153 | if self.downsample is not None: 154 | residual = self.downsample(x) 155 | 156 | out += residual 157 | if self.use_final_relu: out = self.relu(out) 158 | 159 | return out 160 | 161 | 162 | class Bottleneck2d(nn.Module): 163 | expansion = 4 164 | 165 | def __init__(self, inplanes, planes, stride=1, downsample=None, track_running_stats=True, use_final_relu=True): 166 | super(Bottleneck2d, self).__init__() 167 | bias = False 168 | self.use_final_relu = use_final_relu 169 | self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=1, bias=bias) 170 | self.bn1 = nn.BatchNorm3d(planes, track_running_stats=track_running_stats) 171 | 172 | self.conv2 = nn.Conv3d(planes, planes, kernel_size=(1,3,3), stride=(1,stride,stride), padding=(0,1,1), bias=bias) 173 | self.bn2 = nn.BatchNorm3d(planes, track_running_stats=track_running_stats) 174 | 175 | self.conv3 = nn.Conv3d(planes, planes * 4, kernel_size=1, bias=bias) 176 | self.bn3 = nn.BatchNorm3d(planes * 4, track_running_stats=track_running_stats) 177 | 178 | self.relu = nn.ReLU(inplace=True) 179 | self.downsample = downsample 180 | self.stride = stride 181 | self.batchnorm = True 182 | 183 | def forward(self, x): 184 | residual = x 185 | 186 | out = self.conv1(x) 187 | if self.batchnorm: out = self.bn1(out) 188 | out = self.relu(out) 189 | 190 | out = self.conv2(out) 191 | if self.batchnorm: out = self.bn2(out) 192 | out = self.relu(out) 193 | 194 | out = self.conv3(out) 195 | if self.batchnorm: out = self.bn3(out) 196 | 197 | if self.downsample is not None: 198 | residual = self.downsample(x) 199 | 200 | out += residual 201 | if self.use_final_relu: out = self.relu(out) 202 | 203 | return out 204 | 205 | 206 | class ResNet2d3d_full(nn.Module): 207 | def __init__(self, block, layers, track_running_stats=True, in_channels=3): 208 | super(ResNet2d3d_full, self).__init__() 209 | self.inplanes = 64 210 | self.track_running_stats = track_running_stats 211 | bias = False 212 | self.conv1 = nn.Conv3d(in_channels, 64, kernel_size=(1,7,7), stride=(1, 2, 2), padding=(0, 3, 3), bias=bias) 213 | self.bn1 = nn.BatchNorm3d(64, track_running_stats=track_running_stats) 214 | self.relu = nn.ReLU(inplace=True) 215 | self.maxpool = nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)) 216 | 217 | if not isinstance(block, list): 218 | block = [block] * 4 219 | 220 | self.layer1 = self._make_layer(block[0], 64, layers[0]) 221 | self.layer2 = self._make_layer(block[1], 128, layers[1], stride=2) 222 | self.layer3 = self._make_layer(block[2], 256, layers[2], stride=2) 223 | self.layer4 = self._make_layer(block[3], 256, layers[3], stride=2, is_final=True) 224 | # modify layer4 from exp=512 to exp=256 225 | for m in self.modules(): 226 | if isinstance(m, nn.Conv3d): 227 | m.weight = nn.init.kaiming_normal_(m.weight, mode='fan_out') 228 | if m.bias is not None: m.bias.data.zero_() 229 | elif isinstance(m, nn.BatchNorm3d): 230 | m.weight.data.fill_(1) 231 | m.bias.data.zero_() 232 | 233 | def _make_layer(self, block, planes, blocks, stride=1, is_final=False): 234 | downsample = None 235 | if stride != 1 or self.inplanes != planes * block.expansion: 236 | # customized_stride to deal with 2d or 3d residual blocks 237 | if (block == Bottleneck2d) or (block == BasicBlock2d): 238 | customized_stride = (1, stride, 
stride) 239 | else: 240 | customized_stride = stride 241 | 242 | downsample = nn.Sequential( 243 | nn.Conv3d(self.inplanes, planes * block.expansion, kernel_size=1, stride=customized_stride, bias=False), 244 | nn.BatchNorm3d(planes * block.expansion, track_running_stats=self.track_running_stats) 245 | ) 246 | 247 | layers = [] 248 | layers.append(block(self.inplanes, planes, stride, downsample, track_running_stats=self.track_running_stats)) 249 | self.inplanes = planes * block.expansion 250 | if is_final: # if is final block, no ReLU in the final output 251 | for i in range(1, blocks-1): 252 | layers.append(block(self.inplanes, planes, track_running_stats=self.track_running_stats)) 253 | layers.append(block(self.inplanes, planes, track_running_stats=self.track_running_stats, use_final_relu=False)) 254 | else: 255 | for i in range(1, blocks): 256 | layers.append(block(self.inplanes, planes, track_running_stats=self.track_running_stats)) 257 | 258 | return nn.Sequential(*layers) 259 | 260 | def forward(self, x): 261 | x = self.conv1(x) 262 | x = self.bn1(x) 263 | x = self.relu(x) 264 | x = self.maxpool(x) 265 | 266 | x = self.layer1(x) 267 | x = self.layer2(x) 268 | x = self.layer3(x) 269 | x = self.layer4(x) 270 | 271 | return x 272 | 273 | 274 | ## full resnet 275 | def resnet18_2d3d_full(**kwargs): 276 | '''Constructs a ResNet-18 model. ''' 277 | model = ResNet2d3d_full([BasicBlock2d, BasicBlock2d, BasicBlock3d, BasicBlock3d], 278 | [2, 2, 2, 2], **kwargs) 279 | return model 280 | 281 | def resnet34_2d3d_full(**kwargs): 282 | '''Constructs a ResNet-34 model. ''' 283 | model = ResNet2d3d_full([BasicBlock2d, BasicBlock2d, BasicBlock3d, BasicBlock3d], 284 | [3, 4, 6, 3], **kwargs) 285 | return model 286 | 287 | def resnet50_2d3d_full(**kwargs): 288 | '''Constructs a ResNet-50 model. ''' 289 | model = ResNet2d3d_full([Bottleneck2d, Bottleneck2d, Bottleneck3d, Bottleneck3d], 290 | [3, 4, 6, 3], **kwargs) 291 | return model 292 | 293 | def resnet101_2d3d_full(**kwargs): 294 | '''Constructs a ResNet-101 model. ''' 295 | model = ResNet2d3d_full([Bottleneck2d, Bottleneck2d, Bottleneck3d, Bottleneck3d], 296 | [3, 4, 23, 3], **kwargs) 297 | return model 298 | 299 | def resnet152_2d3d_full(**kwargs): 300 | '''Constructs a ResNet-101 model. ''' 301 | model = ResNet2d3d_full([Bottleneck2d, Bottleneck2d, Bottleneck3d, Bottleneck3d], 302 | [3, 8, 36, 3], **kwargs) 303 | return model 304 | 305 | def resnet200_2d3d_full(**kwargs): 306 | '''Constructs a ResNet-101 model. 
''' 307 | model = ResNet2d3d_full([Bottleneck2d, Bottleneck2d, Bottleneck3d, Bottleneck3d], 308 | [3, 24, 36, 3], **kwargs) 309 | return model 310 | 311 | def neq_load_customized(model, pretrained_dict): 312 | ''' load pre-trained model in a not-equal way, 313 | when new model has been partially modified ''' 314 | model_dict = model.state_dict() 315 | tmp = {} 316 | print('\n=======Check Weights Loading======') 317 | print('Weights not used from pretrained file:') 318 | names = [] 319 | for k, v in pretrained_dict.items(): 320 | if k in model_dict: 321 | tmp[k] = v 322 | else: 323 | names.append(k) 324 | print(set([k.split('.')[-1] for k in names])) 325 | print('---------------------------') 326 | print('Weights not loaded into new model:') 327 | names = [] 328 | for k, v in model_dict.items(): 329 | if k not in pretrained_dict: 330 | names.append(k) 331 | print(set([k.split('.')[-1] for k in names])) 332 | print('===================================\n') 333 | # pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 334 | del pretrained_dict 335 | model_dict.update(tmp) 336 | del tmp 337 | model.load_state_dict(model_dict) 338 | return model 339 | 340 | 341 | if __name__ == '__main__': 342 | mymodel = resnet18_2d3d_full() 343 | mydata = torch.FloatTensor(4, 3, 16, 128, 128) 344 | nn.init.normal_(mydata) 345 | import ipdb; ipdb.set_trace() 346 | mymodel(mydata) 347 | -------------------------------------------------------------------------------- /backbone/select_backbone.py: -------------------------------------------------------------------------------- 1 | from resnet_2d3d import * 2 | 3 | def select_resnet(network, track_running_stats=True, in_channels=3): 4 | param = {'feature_size': 1024} 5 | if network == 'resnet18': 6 | model = resnet18_2d3d_full(track_running_stats=track_running_stats, in_channels=in_channels) 7 | param['feature_size'] = 256 8 | elif network == 'resnet34': 9 | model = resnet34_2d3d_full(track_running_stats=track_running_stats, in_channels=in_channels) 10 | param['feature_size'] = 256 11 | elif network == 'resnet50': 12 | model = resnet50_2d3d_full(track_running_stats=track_running_stats, in_channels=in_channels) 13 | elif network == 'resnet101': 14 | model = resnet101_2d3d_full(track_running_stats=track_running_stats, in_channels=in_channels) 15 | elif network == 'resnet152': 16 | model = resnet152_2d3d_full(track_running_stats=track_running_stats, in_channels=in_channels) 17 | elif network == 'resnet200': 18 | model = resnet200_2d3d_full(track_running_stats=track_running_stats, in_channels=in_channels) 19 | else: raise IOError('model type is wrong') 20 | 21 | return model, param -------------------------------------------------------------------------------- /process_data/readme.md: -------------------------------------------------------------------------------- 1 | ## Process data 2 | 3 | This folder has some tools to process UCF101, HMDB51 and Kinetics400 datasets. 4 | 5 | ### 1. Download 6 | 7 | Download the videos from source: 8 | [UCF101 source](https://www.crcv.ucf.edu/data/UCF101.php), 9 | [HMDB51 source](http://serre-lab.clps.brown.edu/resource/hmdb-a-large-human-motion-database/#Downloads), 10 | [Kinetics400 source](https://deepmind.com/research/publications/kinetics-human-action-video-dataset). 
11 | 12 | Make sure datasets are stored as follows: 13 | 14 | * UCF101 15 | ``` 16 | {your_path}/UCF101/videos/{action class}/{video name}.avi 17 | {your_path}/UCF101/splits_classification/trainlist{01/02/03}.txt 18 | {your_path}/UCF101/splits_classification/testlist{01/02/03}}.txt 19 | ``` 20 | 21 | * HMDB51 22 | ``` 23 | {your_path}/HMDB51/videos/{action class}/{video name}.avi 24 | {your_path}/HMDB51/split/testTrainMulti_7030_splits/{action class}_test_split{1/2/3}.txt 25 | ``` 26 | 27 | * Kinetics400 28 | ``` 29 | {your_path}/Kinetics400/videos/train_split/{action class}/{video name}.mp4 30 | {your_path}/Kinetics400/videos/val_split/{action class}/{video name}.mp4 31 | ``` 32 | Also keep the downloaded csv files, make sure you have: 33 | ``` 34 | {your_path}/Kinetics/kinetics_train/kinetics_train.csv 35 | {your_path}/Kinetics/kinetics_val/kinetics_val.csv 36 | {your_path}/Kinetics/kinetics_test/kinetics_test.csv 37 | ``` 38 | 39 | ### 2. Extract frames 40 | 41 | Edit path arguments in `main_*()` functions, and `python extract_frame.py`. Video frames will be extracted. 42 | 43 | ### 3. Collect all paths into csv 44 | 45 | Edit path arguments in `main_*()` functions, and `python write_csv.py`. csv files will be stored in `data/` directory. 46 | 47 | 48 | 49 | 50 | 51 | -------------------------------------------------------------------------------- /process_data/src/build_rawframes_optimized.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import sys 3 | import os 4 | import os.path as osp 5 | import glob 6 | import cv2 7 | 8 | from pipes import quote 9 | from multiprocessing import Pool, current_process 10 | from tqdm import tqdm 11 | from subprocess import check_call,CalledProcessError 12 | 13 | import mmcv 14 | 15 | 16 | def dump_frames(vid_item): 17 | full_path, vid_path, vid_id = vid_item 18 | vid_name = vid_path.split('.')[0] 19 | out_full_path = osp.join(args.out_dir, vid_name) 20 | try: 21 | os.mkdir(out_full_path) 22 | except OSError: 23 | pass 24 | vr = mmcv.VideoReader(full_path) 25 | for i in range(len(vr)): 26 | if vr[i] is not None: 27 | mmcv.imwrite( 28 | vr[i], '{}/img_{:05d}.jpg'.format(out_full_path, i + 1)) 29 | else: 30 | print('[Warning] length inconsistent!' 
31 | 'Early stop with {} out of {} frames'.format(i + 1, len(vr))) 32 | break 33 | print('{} done with {} frames'.format(vid_name, len(vr))) 34 | sys.stdout.flush() 35 | return True 36 | 37 | 38 | def num_frames_in_vid(v_path): 39 | vidcap = cv2.VideoCapture(v_path) 40 | nb_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT)) 41 | vidcap.release() 42 | return nb_frames 43 | 44 | 45 | def run_optical_flow(vid_item, dev_id=0): 46 | full_path, vid_path, vid_id = vid_item 47 | vid_name = vid_path.split('.')[0] 48 | out_full_path = osp.join(args.out_dir, vid_name) 49 | try: 50 | os.mkdir(out_full_path) 51 | except OSError: 52 | pass 53 | 54 | current = current_process() 55 | dev_id = (int(current._identity[0]) - 1) % args.num_gpu 56 | image_path = '{}/img'.format(out_full_path) 57 | flow_x_path = '{}/flow_x'.format(out_full_path) 58 | flow_y_path = '{}/flow_y'.format(out_full_path) 59 | 60 | num_frames = num_frames_in_vid(full_path) 61 | if os.path.exists(image_path + '_%05d.jpg' % (num_frames - 1)): 62 | return True 63 | 64 | try: 65 | check_call( 66 | [osp.join(args.df_path, 'build/extract_gpu'), 67 | '-f={}'.format(quote(full_path)), '-x={}'.format(quote(flow_x_path)), '-y={}'.format(quote(flow_y_path)), 68 | '-i={}'.format(quote(image_path)), '-b=20', '-t=0', '-d={}'.format(dev_id), 69 | '-s=1', '-o={}'.format(args.out_format), '-w={}'.format(args.new_width), '-h={}'.format(args.new_height)] 70 | ) 71 | except CalledProcessError as e: 72 | print(e.stdout()) 73 | 74 | return True 75 | 76 | 77 | def run_warp_optical_flow(vid_item, dev_id=0): 78 | full_path, vid_path, vid_id = vid_item 79 | vid_name = vid_path.split('.')[0] 80 | out_full_path = osp.join(args.out_dir, vid_name) 81 | try: 82 | os.mkdir(out_full_path) 83 | except OSError: 84 | pass 85 | 86 | current = current_process() 87 | dev_id = (int(current._identity[0]) - 1) % args.num_gpu 88 | flow_x_path = '{}/flow_x'.format(out_full_path) 89 | flow_y_path = '{}/flow_y'.format(out_full_path) 90 | 91 | cmd = osp.join(args.df_path + 'build/extract_warp_gpu') + \ 92 | ' -f={} -x={} -y={} -b=20 -t=1 -d={} -s=1 -o={}'.format( 93 | quote(full_path), quote(flow_x_path), quote(flow_y_path), 94 | dev_id, args.out_format) 95 | 96 | os.system(cmd) 97 | print('warp on {} {} done'.format(vid_id, vid_name)) 98 | sys.stdout.flush() 99 | return True 100 | 101 | 102 | def parse_args(): 103 | parser = argparse.ArgumentParser(description='extract optical flows') 104 | parser.add_argument('src_dir', type=str) 105 | parser.add_argument('out_dir', type=str) 106 | parser.add_argument('--level', type=int, 107 | choices=[1, 2], 108 | default=2) 109 | parser.add_argument('--num_worker', type=int, default=8) 110 | parser.add_argument('--flow_type', type=str, 111 | default=None, choices=[None, 'tvl1', 'warp_tvl1']) 112 | parser.add_argument('--df_path', type=str, 113 | default='../mmaction/third_party/dense_flow') 114 | parser.add_argument("--out_format", type=str, default='dir', 115 | choices=['dir', 'zip'], help='output format') 116 | parser.add_argument("--ext", type=str, default='avi', 117 | choices=['avi', 'mp4'], help='video file extensions') 118 | parser.add_argument("--new_width", type=int, default=0, 119 | help='resize image width') 120 | parser.add_argument("--new_height", type=int, 121 | default=0, help='resize image height') 122 | parser.add_argument("--num_gpu", type=int, default=8, help='number of GPU') 123 | parser.add_argument("--resume", action='store_true', default=False, 124 | help='resume optical flow extraction ' 125 | 'instead of overwriting') 
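    # NOTE: --flow_type tvl1 / warp_tvl1 relies on the dense_flow GPU tools
    # (build/extract_gpu, build/extract_warp_gpu) being built under --df_path;
    # with --flow_type left unset, only raw frames are dumped via mmcv.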
126 | parser.add_argument("--debug", type=int, default=0, help='debug mode') 127 | args = parser.parse_args() 128 | 129 | return args 130 | 131 | 132 | if __name__ == '__main__': 133 | args = parse_args() 134 | 135 | if not osp.isdir(args.out_dir): 136 | print('Creating folder: {}'.format(args.out_dir)) 137 | os.makedirs(args.out_dir) 138 | if args.level == 2: 139 | classes = os.listdir(args.src_dir) 140 | for classname in classes: 141 | new_dir = osp.join(args.out_dir, classname) 142 | if not osp.isdir(new_dir): 143 | print('Creating folder: {}'.format(new_dir)) 144 | os.makedirs(new_dir) 145 | 146 | print('Reading videos from folder: ', args.src_dir) 147 | print('Extension of videos: ', args.ext) 148 | if args.level == 2: 149 | fullpath_list = glob.glob(args.src_dir + '/*/*.' + args.ext) 150 | done_fullpath_list = glob.glob(args.out_dir + '/*/*') 151 | elif args.level == 1: 152 | fullpath_list = glob.glob(args.src_dir + '/*.' + args.ext) 153 | done_fullpath_list = glob.glob(args.out_dir + '/*') 154 | print('Total number of videos found: ', len(fullpath_list)) 155 | if args.resume: 156 | fullpath_list = set(fullpath_list).difference(set(done_fullpath_list)) 157 | fullpath_list = list(fullpath_list) 158 | print('Resuming. number of videos to be done: ', len(fullpath_list)) 159 | 160 | fullpath_list = sorted(fullpath_list) 161 | 162 | if args.level == 2: 163 | vid_list = list(map(lambda p: osp.join( 164 | '/'.join(p.split('/')[-2:])), fullpath_list)) 165 | elif args.level == 1: 166 | vid_list = list(map(lambda p: p.split('/')[-1], fullpath_list)) 167 | 168 | if args.debug: 169 | K = 5 170 | fullpath_list = fullpath_list[:K] 171 | vid_list = vid_list[:K] 172 | args.num_worker = 4 173 | 174 | pbar = tqdm(total=len(vid_list), smoothing=0.001) 175 | pool = Pool(args.num_worker) 176 | 177 | def update(*a): 178 | pbar.update() 179 | 180 | call_func = None 181 | if args.flow_type == 'tvl1': 182 | call_func = run_optical_flow 183 | elif args.flow_type == 'warp_tvl1': 184 | call_func = run_warp_optical_flow 185 | else: 186 | call_func = dump_frames 187 | 188 | for arg in zip(fullpath_list, vid_list, range(len(vid_list))): 189 | pool.apply_async(call_func, args=(arg,), callback=update) 190 | 191 | pool.close() 192 | pool.join() -------------------------------------------------------------------------------- /process_data/src/extract_features.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import cv2 4 | import os 5 | import glob 6 | import torch 7 | 8 | cv2.setNumThreads(0) 9 | 10 | from tqdm import tqdm 11 | from torch.utils import data 12 | from typing import Dict, List, Union 13 | 14 | # import some common detectron2 utilities 15 | from detectron2 import model_zoo 16 | from detectron2.config import get_cfg 17 | from detectron2.modeling import build_model 18 | from detectron2.checkpoint import DetectionCheckpointer 19 | from detectron2.modeling import ROI_HEADS_REGISTRY, StandardROIHeads 20 | from detectron2.structures import Boxes, ImageList, Instances 21 | from detectron2.layers import interpolate, cat 22 | from detectron2.utils.logger import setup_logger 23 | setup_logger() 24 | 25 | 26 | def str2bool(s): 27 | """Convert string to bool (in argparse context).""" 28 | if s.lower() not in ['true', 'false']: 29 | raise ValueError('Need bool; got %r' % s) 30 | return {'true': True, 'false': False}[s.lower()] 31 | 32 | 33 | imgShape = None 34 | 35 | from typing import Dict, List, Optional, Tuple, Union 36 | from 
detectron2.modeling import ROI_HEADS_REGISTRY, StandardROIHeads 37 | from detectron2.structures import Boxes, ImageList, Instances 38 | from detectron2.layers import interpolate, cat 39 | 40 | 41 | @torch.no_grad() 42 | def process_heatmaps(maps, rois, img_shapes): 43 | """ 44 | Extract predicted keypoint locations from heatmaps. 45 | Args: 46 | maps (Tensor): (#ROIs, #keypoints, POOL_H, POOL_W). The predicted heatmap of logits for 47 | each ROI and each keypoint. 48 | rois (Tensor): (#ROIs, 4). The box of each ROI. 49 | Returns: 50 | Tensor of shape (#ROIs, #keypoints, POOL_H, POOL_W) representing confidence scores 51 | """ 52 | 53 | offset_i = (rois[:, 1]).int() 54 | offset_j = (rois[:, 0]).int() 55 | 56 | widths = (rois[:, 2] - rois[:, 0]).clamp(min=1) 57 | heights = (rois[:, 3] - rois[:, 1]).clamp(min=1) 58 | widths_ceil = widths.ceil() 59 | heights_ceil = heights.ceil() 60 | 61 | # roi_map_scores = torch.zeros((maps.shape[0], maps.shape[1], imgShape[0], imgShape[1])) 62 | roi_map_scores = [torch.zeros((maps.shape[1], img_shapes[i][0], img_shapes[i][1])) for i in range(maps.shape[0])] 63 | num_rois, num_keypoints = maps.shape[:2] 64 | 65 | for i in range(num_rois): 66 | outsize = (int(heights_ceil[i]), int(widths_ceil[i])) 67 | # #keypoints x H x W 68 | roi_map = interpolate(maps[[i]], size=outsize, mode="bicubic", align_corners=False).squeeze(0) 69 | 70 | # softmax over the spatial region 71 | max_score, _ = roi_map.view(num_keypoints, -1).max(1) 72 | max_score = max_score.view(num_keypoints, 1, 1) 73 | tmp_full_resolution = (roi_map - max_score).exp_() 74 | tmp_pool_resolution = (maps[i] - max_score).exp_() 75 | 76 | norm_score = ((tmp_full_resolution / tmp_pool_resolution.sum((1, 2), keepdim=True)) * 255.0).to(torch.uint8) 77 | 78 | # Produce scores over the region H x W, but normalize with POOL_H x POOL_W, 79 | # so that the scores of objects of different absolute sizes will be more comparable 80 | for idx in range(num_keypoints): 81 | roi_map_scores[i][idx, offset_i[i]:(offset_i[i] + outsize[0]), offset_j[i]:(offset_j[i] + outsize[1])] = \ 82 | norm_score[idx, ...].float() 83 | 84 | return roi_map_scores 85 | 86 | 87 | def heatmap_rcnn_inference(pred_keypoint_logits, pred_instances): 88 | bboxes_flat = cat([b.pred_boxes.tensor for b in pred_instances], dim=0) 89 | 90 | num_instances_per_image = [len(i) for i in pred_instances] 91 | img_shapes = [instance._image_size for instance in pred_instances for _ in range(len(instance))] 92 | hm_results = process_heatmaps(pred_keypoint_logits.detach(), bboxes_flat.detach(), img_shapes) 93 | 94 | hm_logits = [] 95 | cumsum_idx = np.cumsum(num_instances_per_image) 96 | 97 | assert len(hm_results) == cumsum_idx[-1], \ 98 | "Invalid sizes: {}, {}, {}".format(len(hm_results), cumsum_idx[-1], cumsum_idx) 99 | 100 | for idx in range(len(num_instances_per_image)): 101 | l = 0 if idx == 0 else cumsum_idx[idx - 1] 102 | if num_instances_per_image[idx] == 0: 103 | hm_logits.append(torch.zeros((0, 17, 0, 0))) 104 | else: 105 | hm_logits.append(torch.stack(hm_results[l:l + num_instances_per_image[idx]])) 106 | 107 | for idx in range(min(len(pred_instances), len(hm_logits))): 108 | pred_instances[idx].heat_maps = hm_logits[idx] 109 | 110 | 111 | @ROI_HEADS_REGISTRY.register() 112 | class HeatmapROIHeads(StandardROIHeads): 113 | """ 114 | A Standard ROIHeads which contains returns HeatMaps instead of keypoints. 
115 | """ 116 | 117 | def __init__(self, cfg, input_shape): 118 | super().__init__(cfg, input_shape) 119 | 120 | def _forward_keypoint( 121 | self, features: List[torch.Tensor], instances: List[Instances] 122 | ) -> Union[Dict[str, torch.Tensor], List[Instances]]: 123 | if not self.keypoint_on: 124 | return {} if self.training else instances 125 | 126 | if self.training: 127 | assert False, "Not implemented yet!" 128 | else: 129 | pred_boxes = [x.pred_boxes for x in instances] 130 | keypoint_features = self.keypoint_pooler(features, pred_boxes) 131 | keypoint_logits = self.keypoint_head(keypoint_features) 132 | heatmap_rcnn_inference(keypoint_logits, instances) 133 | return instances 134 | 135 | 136 | def get_heatmap_detection_module(): 137 | # Inference with a keypoint detection module 138 | cfg = get_cfg() 139 | cfg.merge_from_file(model_zoo.get_config_file("COCO-Keypoints/keypoint_rcnn_R_50_FPN_3x.yaml")) 140 | cfg.MODEL.ROI_HEADS.NAME = "HeatmapROIHeads" 141 | cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.8 # set threshold for this model 142 | cfg.MODEL.WEIGHTS = "detectron2://COCO-Keypoints/keypoint_rcnn_R_50_FPN_3x/137849621/model_final_a6e10b.pkl" 143 | predictor = build_model(cfg) 144 | print("heatmap head:", cfg.MODEL.ROI_HEADS.NAME) 145 | DetectionCheckpointer(predictor).load(cfg.MODEL.WEIGHTS) 146 | predictor.eval() 147 | return cfg, predictor 148 | 149 | 150 | def get_panoptic_segmentation_module(): 151 | # Inference with a segmentation module 152 | cfg = get_cfg() 153 | cfg.merge_from_file(model_zoo.get_config_file("COCO-PanopticSegmentation/panoptic_fpn_R_101_3x.yaml")) 154 | cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-PanopticSegmentation/panoptic_fpn_R_101_3x.yaml") 155 | cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.8 # set threshold for this model 156 | predictor = build_model(cfg) 157 | print("segmask head:", cfg.MODEL.ROI_HEADS.NAME) 158 | DetectionCheckpointer(predictor).load(cfg.MODEL.WEIGHTS) 159 | predictor.eval() 160 | return cfg, predictor 161 | 162 | 163 | def individual_collate(batch): 164 | """ 165 | Custom collation function for collate with new implementation of individual samples in data pipeline 166 | """ 167 | 168 | data = batch 169 | 170 | # Assuming there's at least one instance in the batch 171 | add_data_keys = data[0].keys() 172 | collected_data = {k: [] for k in add_data_keys} 173 | 174 | for i in range(len(list(data))): 175 | for k in add_data_keys: 176 | collected_data[k].extend(data[i][k]) 177 | 178 | return collected_data 179 | 180 | 181 | def resize_dim(w, h, target): 182 | '''resize (w, h), such that the smaller side is target, keep the aspect ratio''' 183 | if w >= h: 184 | return (int(target * w / h), int(target)) 185 | else: 186 | return (int(target), int(target * h / w)) 187 | 188 | 189 | class VideoDataset(data.Dataset): 190 | 191 | def __init__(self, v_root, vid_range, save_path, skip_len=2): 192 | super(VideoDataset, self).__init__() 193 | 194 | self.v_root = v_root 195 | self.vid_range = vid_range 196 | self.save_path = save_path 197 | 198 | self.init_videos() 199 | 200 | self.max_idx = len(self.v_names) 201 | self.skip = skip_len 202 | 203 | self.width, self.height = 320, 240 204 | self.dim = 192 205 | 206 | def num_frames_in_vid(self, v_path): 207 | vidcap = cv2.VideoCapture(v_path) 208 | nb_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT)) 209 | vidcap.release() 210 | return nb_frames 211 | 212 | def extract_video_opencv(self, v_path): 213 | 214 | global imgShape 215 | 216 | v_class = v_path.split('/')[-2] 217 | v_name = 
os.path.basename(v_path)[0:-4] 218 | 219 | vidcap = cv2.VideoCapture(v_path) 220 | nb_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT)) 221 | width = vidcap.get(cv2.CAP_PROP_FRAME_WIDTH) # float 222 | height = vidcap.get(cv2.CAP_PROP_FRAME_HEIGHT) # float 223 | 224 | if (width == 0) or (height == 0): 225 | print(v_path, 'not successfully loaded, drop ..') 226 | return 227 | 228 | new_dim = resize_dim(width, height, self.dim) 229 | 230 | fnames, imgs = [], [] 231 | 232 | success, image = vidcap.read() 233 | count = 1 234 | while success: 235 | image = cv2.resize(image, new_dim, interpolation=cv2.INTER_LINEAR) 236 | if (count % self.skip == 0): 237 | fnames.append((v_class, v_name, count)) 238 | imgs.append(image) 239 | 240 | success, image = vidcap.read() 241 | count += 1 242 | 243 | if int(nb_frames * 0.8) > count: 244 | print(v_path, 'NOT extracted successfully: %df/%df' % (count, nb_frames)) 245 | 246 | vidcap.release() 247 | 248 | return imgs, fnames 249 | 250 | def vid_already_processed(self, v_path): 251 | v_class = v_path.split('/')[-2] 252 | # Remove avi extension 253 | v_name = os.path.basename(v_path)[0:-4] 254 | 255 | out_dir = os.path.join(self.save_path, v_class, v_name) 256 | num_frames = self.num_frames_in_vid(v_path) 257 | for count in range(max(0, num_frames - 10), num_frames): 258 | fpath = os.path.join(out_dir, 'segmask_%05d.npz' % count) 259 | if os.path.exists(fpath): 260 | return True 261 | 262 | return False 263 | 264 | def init_videos(self): 265 | print('processing videos from %s' % self.v_root) 266 | 267 | self.v_names = [] 268 | 269 | v_act_root = sorted(glob.glob(os.path.join(self.v_root, '*/'))) 270 | 271 | num_skip, tot_files = 0, 0 272 | for vid_dir in v_act_root: 273 | v_class = vid_dir.split('/')[-2] 274 | 275 | if (v_class[0].lower() >= self.vid_range[0]) and (v_class[0].lower() <= self.vid_range[1]): 276 | v_paths = glob.glob(os.path.join(vid_dir, '*.avi')) 277 | v_paths = sorted(v_paths) 278 | 279 | for v_path in v_paths: 280 | tot_files += 1 281 | if self.vid_already_processed(v_path): 282 | num_skip += 1 283 | continue 284 | self.v_names.append(v_path) 285 | 286 | print('Processing: {} files. 
Skipped: {}/{} files.'.format(len(self.v_names), num_skip, tot_files)) 287 | 288 | def __getitem__(self, idx): 289 | vname = self.v_names[idx] 290 | imgs, fnames = self.extract_video_opencv(vname) 291 | return {"img": imgs, "filename": fnames} 292 | 293 | def __len__(self): 294 | return self.max_idx 295 | 296 | 297 | def get_video_data_loader(path, vid_range, save_path, batch_size=2): 298 | dataset = VideoDataset(path, vid_range, save_path) 299 | data_loader = data.DataLoader( 300 | dataset, 301 | batch_size=batch_size, 302 | sampler=data.SequentialSampler(dataset), 303 | shuffle=False, 304 | num_workers=2, 305 | collate_fn=individual_collate, 306 | pin_memory=True, 307 | drop_last=True 308 | ) 309 | return data_loader 310 | 311 | 312 | def write_heatmap_to_file(root, fname, heatmap): 313 | # fname is a list of (class, vname, count) 314 | v_class, v_name, count = fname 315 | out_dir = os.path.join(root, v_class, v_name) 316 | 317 | if not os.path.exists(out_dir): 318 | os.makedirs(out_dir) 319 | 320 | np.savez_compressed(os.path.join(out_dir, 'heatmap_%05d.npz' % count), hm=heatmap) 321 | 322 | 323 | def write_segmask_to_file(root, fname, segmask): 324 | # fname is a list of (class, vname, count) 325 | v_class, v_name, count = fname 326 | out_dir = os.path.join(root, v_class, v_name) 327 | 328 | if not os.path.exists(out_dir): 329 | os.makedirs(out_dir) 330 | 331 | np.savez_compressed(os.path.join(out_dir, 'segmask_%05d.npz' % count), seg=segmask) 332 | 333 | 334 | def convert_to_uint8(x): 335 | x[x < 0.0] = 0.0 336 | x[x > 255.0] = 255.0 337 | nx = x.to(torch.uint8).numpy() 338 | return nx 339 | 340 | 341 | def process_videos(root, vid_provider, args, batch_size=32, debug=False): 342 | 343 | _, modelKP = get_heatmap_detection_module() 344 | _, modelPS = get_panoptic_segmentation_module() 345 | 346 | for batch in tqdm(vid_provider): 347 | imgsTot, fnamesTot = batch['img'], batch['filename'] 348 | 349 | for idx in range(0, len(imgsTot), batch_size): 350 | 351 | imgs, fnames = imgsTot[idx: idx + batch_size], fnamesTot[idx: idx + batch_size] 352 | 353 | imgsDict = [{'image': torch.Tensor(img).float().permute(2, 0, 1)} for img in imgs] 354 | 355 | with torch.no_grad(): 356 | if args.heatmap: 357 | outputsKP = modelKP(imgsDict) 358 | if args.segmask: 359 | outputsPS = modelPS(imgsDict) 360 | 361 | for i in range(len(imgs)): 362 | if args.heatmap: 363 | # Process the keypoints 364 | try: 365 | heatmap = outputsKP[i]['instances'].heat_maps.cpu() 366 | scores = outputsKP[i]['instances'].scores.cpu() 367 | avgHeatmap = (heatmap * scores.view(-1, 1, 1, 1)).sum(dim=0) 368 | # Clamp the max values 369 | avgHeatmap = convert_to_uint8(avgHeatmap) 370 | except: 371 | print("Heatmap generation:", fnames[i]) 372 | print(outputsKP[i]) 373 | else: 374 | assert avgHeatmap.shape[0] == 17, "Invalid size: {}".format(heatmap.shape) 375 | if not debug: 376 | write_heatmap_to_file(root, fnames[i], avgHeatmap) 377 | 378 | if args.segmask: 379 | # Process the segmentation mask 380 | try: 381 | semantic_map = torch.softmax(outputsPS[i]['sem_seg'].detach(), dim=0)[0].cpu() * 255.0 382 | semantic_map = convert_to_uint8(semantic_map) 383 | except: 384 | print("Segmask generation:", fnames[i]) 385 | print(outputsPS[i]) 386 | else: 387 | if not debug: 388 | write_segmask_to_file(root, fnames[i], semantic_map) 389 | 390 | 391 | if __name__ == '__main__': 392 | 393 | parser = argparse.ArgumentParser() 394 | parser.add_argument('--save_path', default='/scr/nishantr/data/ucf101/features/', type=str) 395 | 
parser.add_argument('--dataset', default='/scr/nishantr/data/ucf101/videos', type=str) 396 | parser.add_argument('--batch_size', default=32, type=int) 397 | parser.add_argument('--vid_range', default='az', type=str) 398 | parser.add_argument('--debug', default=0, type=int) 399 | parser.add_argument('--heatmap', default=0, type=int) 400 | parser.add_argument('--segmask', default=0, type=int) 401 | args = parser.parse_args() 402 | 403 | vid_provider = get_video_data_loader(args.dataset, args.vid_range, args.save_path) 404 | 405 | process_videos(args.save_path, vid_provider, batch_size=args.batch_size, debug=args.debug, args=args) 406 | -------------------------------------------------------------------------------- /process_data/src/extract_frame.py: -------------------------------------------------------------------------------- 1 | from joblib import delayed, Parallel 2 | import os 3 | import sys 4 | import glob 5 | from tqdm import tqdm 6 | import cv2 7 | import argparse 8 | 9 | import matplotlib.pyplot as plt 10 | plt.switch_backend('agg') 11 | 12 | 13 | def str2bool(s): 14 | """Convert string to bool (in argparse context).""" 15 | if s.lower() not in ['true', 'false']: 16 | raise ValueError('Need bool; got %r' % s) 17 | return {'true': True, 'false': False}[s.lower()] 18 | 19 | 20 | def extract_video_opencv(v_path, f_root, dim=240): 21 | '''v_path: single video path; 22 | f_root: root to store frames''' 23 | v_class = v_path.split('/')[-2] 24 | v_name = os.path.basename(v_path)[0:-4] 25 | out_dir = os.path.join(f_root, v_class, v_name) 26 | if not os.path.exists(out_dir): 27 | os.makedirs(out_dir) 28 | 29 | vidcap = cv2.VideoCapture(v_path) 30 | nb_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT)) 31 | width = vidcap.get(cv2.CAP_PROP_FRAME_WIDTH) # float 32 | height = vidcap.get(cv2.CAP_PROP_FRAME_HEIGHT) # float 33 | if (width == 0) or (height == 0): 34 | print(v_path, 'not successfully loaded, drop ..'); return 35 | new_dim = resize_dim(width, height, dim) 36 | 37 | success, image = vidcap.read() 38 | count = 1 39 | while success: 40 | image = cv2.resize(image, new_dim, interpolation = cv2.INTER_LINEAR) 41 | cv2.imwrite(os.path.join(out_dir, 'image_%05d.jpg' % count), image, 42 | [cv2.IMWRITE_JPEG_QUALITY, 80])# quality from 0-100, 95 is default, high is good 43 | success, image = vidcap.read() 44 | count += 1 45 | if nb_frames > count: 46 | print('/'.join(out_dir.split('/')[-2::]), 'NOT extracted successfully: %df/%df' % (count, nb_frames)) 47 | vidcap.release() 48 | 49 | 50 | def resize_dim(w, h, target): 51 | '''resize (w, h), such that the smaller side is target, keep the aspect ratio''' 52 | if w >= h: 53 | return (int(target * w / h), int(target)) 54 | else: 55 | return (int(target), int(target * h / w)) 56 | 57 | 58 | def main_UCF101(v_root, f_root): 59 | print('extracting UCF101 ... ') 60 | print('extracting videos from %s' % v_root) 61 | print('frame save to %s' % f_root) 62 | 63 | if not os.path.exists(f_root): os.makedirs(f_root) 64 | v_act_root = glob.glob(os.path.join(v_root, '*/')) 65 | print(len(v_act_root)) 66 | for i, j in tqdm(enumerate(v_act_root), total=len(v_act_root)): 67 | v_paths = glob.glob(os.path.join(j, '*.avi')) 68 | v_paths = sorted(v_paths) 69 | Parallel(n_jobs=32)(delayed(extract_video_opencv)(p, f_root) for p in tqdm(v_paths, total=len(v_paths))) 70 | 71 | 72 | def main_HMDB51(v_root, f_root): 73 | print('extracting HMDB51 ... 
') 74 | print('extracting videos from %s' % v_root) 75 | print('frame save to %s' % f_root) 76 | 77 | if not os.path.exists(f_root): os.makedirs(f_root) 78 | v_act_root = glob.glob(os.path.join(v_root, '*/')) 79 | for i, j in tqdm(enumerate(v_act_root), total=len(v_act_root)): 80 | v_paths = glob.glob(os.path.join(j, '*.avi')) 81 | v_paths = sorted(v_paths) 82 | Parallel(n_jobs=32)(delayed(extract_video_opencv)(p, f_root) for p in tqdm(v_paths, total=len(v_paths))) 83 | 84 | 85 | def main_JHMDB(v_root, f_root): 86 | print('extracting JHMDB ... ') 87 | print('extracting videos from %s' % v_root) 88 | print('frame save to %s' % f_root) 89 | 90 | if not os.path.exists(f_root): os.makedirs(f_root) 91 | v_act_root = glob.glob(os.path.join(v_root, '*/')) 92 | for i, j in tqdm(enumerate(v_act_root), total=len(v_act_root)): 93 | v_paths = glob.glob(os.path.join(j, '*.avi')) 94 | v_paths = sorted(v_paths) 95 | Parallel(n_jobs=32)(delayed(extract_video_opencv)(p, f_root) for p in tqdm(v_paths, total=len(v_paths))) 96 | 97 | 98 | def main_kinetics400(v_root, f_root, dim=128): 99 | print('extracting Kinetics400 ... ') 100 | for basename in ['train', 'val']: 101 | v_root_real = v_root + '/' + basename 102 | if not os.path.exists(v_root_real): 103 | print('Wrong v_root'); sys.exit() 104 | 105 | f_root_real = f_root + '/' + basename 106 | print('Extract to: \nframe: %s' % f_root_real) 107 | if not os.path.exists(f_root_real): 108 | os.makedirs(f_root_real) 109 | v_act_root = glob.glob(os.path.join(v_root_real, '*/')) 110 | v_act_root = sorted(v_act_root) 111 | 112 | # if resume, remember to delete the last video folder 113 | for i, j in tqdm(enumerate(v_act_root), total=len(v_act_root)): 114 | v_paths = glob.glob(os.path.join(j, '*.mp4')) 115 | v_paths = sorted(v_paths) 116 | # for resume: 117 | v_class = j.split('/')[-2] 118 | out_dir = os.path.join(f_root_real, v_class) 119 | if os.path.exists(out_dir): print(out_dir, 'exists!'); continue 120 | print('extracting: %s' % v_class) 121 | # dim = 150 (crop to 128 later) or 256 (crop to 224 later) 122 | Parallel(n_jobs=32)(delayed(extract_video_opencv)(p, f_root_real, dim=dim) for p in tqdm(v_paths, total=len(v_paths))) 123 | 124 | 125 | def main_Panasonic(v_root, f_root): 126 | print('extracting Panasonic ... 
') 127 | print('extracting videos from %s' % v_root) 128 | print('frame save to %s' % f_root) 129 | 130 | if not os.path.exists(f_root): os.makedirs(f_root) 131 | v_act_root = glob.glob(os.path.join(v_root, '*/')) 132 | print(len(v_act_root)) 133 | for i, j in tqdm(enumerate(v_act_root), total=len(v_act_root)): 134 | v_paths = glob.glob(os.path.join(j, '*.mkv')) 135 | v_paths = sorted(v_paths) 136 | Parallel(n_jobs=32)(delayed(extract_video_opencv)(p, f_root) for p in tqdm(v_paths, total=len(v_paths))) 137 | 138 | 139 | if __name__ == '__main__': 140 | # v_root is the video source path, f_root is where to store frames 141 | # edit 'your_path' here: 142 | #dataset_path = '/vision/u/nishantr/data' 143 | 144 | parser = argparse.ArgumentParser() 145 | parser.add_argument('--ucf101', default=False, type=str2bool) 146 | parser.add_argument('--jhmdb', default=False, type=str2bool) 147 | parser.add_argument('--hmdb51', default=False, type=str2bool) 148 | parser.add_argument('--kinetics', default=False, type=str2bool) 149 | parser.add_argument('--panasonic', default=False, type=str2bool) 150 | parser.add_argument('--dataset_path', default='/scr/nishantr/data', type=str) 151 | parser.add_argument('--dim', default=128, type=int) 152 | args = parser.parse_args() 153 | 154 | dataset_path = args.dataset_path 155 | 156 | if args.ucf101: 157 | main_UCF101(v_root=dataset_path + '/ucf101/videos/', f_root=dataset_path + '/ucf101/frame/') 158 | 159 | if args.jhmdb: 160 | main_JHMDB(v_root=dataset_path + '/jhmdb/videos/', f_root=dataset_path + '/jhmdb/frame/') 161 | 162 | if args.hmdb51: 163 | main_HMDB51(v_root=dataset_path+'/hmdb/videos', f_root=dataset_path+'/hmdb/frame') 164 | 165 | if args.panasonic: 166 | main_Panasonic(v_root=dataset_path+'/action_split_data/V1.0', f_root=dataset_path+'/frame', dim=256) 167 | 168 | if args.kinetics: 169 | if args.dim == 256: 170 | main_kinetics400( 171 | v_root=dataset_path + '/kinetics/video', f_root=dataset_path + '/kinetics/frame256', dim=args.dim 172 | ) 173 | else: 174 | assert args.dim == 128, "Invalid dim: {}".format(args.dim) 175 | main_kinetics400(v_root=dataset_path+'/kinetics/video', f_root=dataset_path+'/kinetics/frame', dim=128) 176 | 177 | # main_kinetics400(v_root='your_path/Kinetics400_256/videos', 178 | # f_root='your_path/Kinetics400_256/frame', dim=256) 179 | -------------------------------------------------------------------------------- /process_data/src/write_csv.py: -------------------------------------------------------------------------------- 1 | import os 2 | import csv 3 | import glob 4 | import sys 5 | 6 | import pandas as pd 7 | 8 | from joblib import delayed, Parallel 9 | from tqdm import tqdm 10 | 11 | 12 | def str2bool(s): 13 | """Convert string to bool (in argparse context).""" 14 | if s.lower() not in ['true', 'false']: 15 | raise ValueError('Need bool; got %r' % s) 16 | return {'true': True, 'false': False}[s.lower()] 17 | 18 | 19 | def write_list(data_list, path, ): 20 | with open(path, 'w') as f: 21 | writer = csv.writer(f, delimiter=',') 22 | for row in data_list: 23 | if row: writer.writerow(row) 24 | print('split saved to %s' % path) 25 | 26 | 27 | def main_UCF101(f_root, splits_root, csv_root='../data/ucf101/'): 28 | '''generate training/testing split, count number of available frames, save in csv''' 29 | if not os.path.exists(csv_root): os.makedirs(csv_root) 30 | for which_split in [1,2,3]: 31 | train_set = [] 32 | test_set = [] 33 | train_split_file = os.path.join(splits_root, 'trainlist%02d.txt' % which_split) 34 | with 
open(train_split_file, 'r') as f: 35 | for line in f: 36 | vpath = os.path.join(f_root, line.split(' ')[0][0:-4]) + '/' 37 | train_set.append([vpath, len(glob.glob(os.path.join(vpath, '*.jpg')))]) 38 | 39 | test_split_file = os.path.join(splits_root, 'testlist%02d.txt' % which_split) 40 | with open(test_split_file, 'r') as f: 41 | for line in f: 42 | vpath = os.path.join(f_root, line.rstrip()[0:-4]) + '/' 43 | test_set.append([vpath, len(glob.glob(os.path.join(vpath, '*.jpg')))]) 44 | 45 | write_list(train_set, os.path.join(csv_root, 'train_split%02d.csv' % which_split)) 46 | write_list(test_set, os.path.join(csv_root, 'test_split%02d.csv' % which_split)) 47 | 48 | 49 | def main_HMDB51(f_root, splits_root, csv_root='../data/hmdb51/'): 50 | '''generate training/testing split, count number of available frames, save in csv''' 51 | if not os.path.exists(csv_root): os.makedirs(csv_root) 52 | for which_split in [1,2,3]: 53 | train_set = [] 54 | test_set = [] 55 | split_files = sorted(glob.glob(os.path.join(splits_root, '*_test_split%d.txt' % which_split))) 56 | assert len(split_files) == 51 57 | for split_file in split_files: 58 | action_name = os.path.basename(split_file)[0:-16] 59 | with open(split_file, 'r') as f: 60 | for line in f: 61 | video_name = line.split(' ')[0] 62 | _type = line.split(' ')[1] 63 | vpath = os.path.join(f_root, action_name, video_name[0:-4]) + '/' 64 | if _type == '1': 65 | train_set.append([vpath, len(glob.glob(os.path.join(vpath, '*.jpg')))]) 66 | elif _type == '2': 67 | test_set.append([vpath, len(glob.glob(os.path.join(vpath, '*.jpg')))]) 68 | 69 | write_list(train_set, os.path.join(csv_root, 'train_split%02d.csv' % which_split)) 70 | write_list(test_set, os.path.join(csv_root, 'test_split%02d.csv' % which_split)) 71 | 72 | 73 | def main_JHMDB(f_root, splits_root, csv_root='../data/jhmdb/'): 74 | '''generate training/testing split, count number of available frames, save in csv''' 75 | if not os.path.exists(csv_root): os.makedirs(csv_root) 76 | for which_split in [1,2,3]: 77 | train_set = [] 78 | test_set = [] 79 | split_files = sorted(glob.glob(os.path.join(splits_root, '*_test_split%d.txt' % which_split))) 80 | assert len(split_files) == 21 81 | for split_file in split_files: 82 | action_name = os.path.basename(split_file)[0:-16] 83 | with open(split_file, 'r') as f: 84 | for line in f: 85 | video_name = line.split(' ')[0] 86 | _type = line.split(' ')[1].strip('\n') 87 | vpath = os.path.join(f_root, action_name, video_name[0:-4]) + '/' 88 | if _type == '1': 89 | train_set.append([vpath, len(glob.glob(os.path.join(vpath, '*.jpg')))]) 90 | elif _type == '2': 91 | test_set.append([vpath, len(glob.glob(os.path.join(vpath, '*.jpg')))]) 92 | 93 | write_list(train_set, os.path.join(csv_root, 'train_split%02d.csv' % which_split)) 94 | write_list(test_set, os.path.join(csv_root, 'test_split%02d.csv' % which_split)) 95 | 96 | 97 | ### For Kinetics ### 98 | def get_split(root, split_path, mode): 99 | print('processing %s split ...' 
% mode) 100 | print('checking %s' % root) 101 | split_list = [] 102 | split_content = pd.read_csv(split_path).iloc[:,0:4] 103 | split_list = Parallel(n_jobs=64)\ 104 | (delayed(check_exists)(row, root) \ 105 | for i, row in tqdm(split_content.iterrows(), total=len(split_content))) 106 | return split_list 107 | 108 | missedCnt = 0  # NOTE: check_exists updates this inside joblib worker processes, so the increments do not propagate back here when n_jobs > 1 109 | 110 | def check_exists(row, root): 111 | global missedCnt 112 | 113 | dirname = '_'.join([row['youtube_id'], '%06d' % row['time_start'], '%06d' % row['time_end']]) 114 | full_dirname = os.path.join(root, row['label'], dirname) 115 | # replace spaces with underscores 116 | full_dirname = full_dirname.replace(' ', '_') 117 | if os.path.exists(full_dirname): 118 | n_frames = len(glob.glob(os.path.join(full_dirname, '*.jpg'))) 119 | return [full_dirname, n_frames] 120 | else: 121 | missedCnt += 1 122 | return None 123 | 124 | def main_Kinetics400(mode, k400_path, f_root, csv_root='../data/kinetics400'): 125 | global missedCnt 126 | missedCnt = 0 127 | 128 | train_split_path = os.path.join(k400_path, 'kinetics-400_train.csv') 129 | val_split_path = os.path.join(k400_path, 'kinetics-400_val.csv') 130 | test_split_path = os.path.join(k400_path, 'kinetics-400_test.csv') 131 | 132 | if not os.path.exists(csv_root): 133 | os.makedirs(csv_root) 134 | 135 | if mode == 'train': 136 | train_split = get_split(os.path.join(f_root, 'train'), train_split_path, 'train') 137 | write_list(train_split, os.path.join(csv_root, 'train_split.csv')) 138 | elif mode == 'val': 139 | val_split = get_split(os.path.join(f_root, 'val'), val_split_path, 'val') 140 | write_list(val_split, os.path.join(csv_root, 'val_split.csv')) 141 | elif mode == 'test': 142 | test_split = get_split(f_root, test_split_path, 'test') 143 | write_list(test_split, os.path.join(csv_root, 'test_split.csv')) 144 | else: 145 | raise IOError('wrong mode') 146 | 147 | print("Total files missed:", missedCnt) 148 | 149 | 150 | import argparse 151 | 152 | if __name__ == '__main__': 153 | # f_root is the frame path 154 | # edit 'your_path' here: 155 | 156 | #dataset_dir = '/vision/u/nishantr/data/' 157 | # dataset_dir = '/scr/nishantr/data/' 158 | 159 | parser = argparse.ArgumentParser() 160 | parser.add_argument('--ucf101', default=False, type=str2bool) 161 | parser.add_argument('--jhmdb', default=False, type=str2bool) 162 | parser.add_argument('--hmdb51', default=False, type=str2bool) 163 | parser.add_argument('--kinetics', default=False, type=str2bool) 164 | parser.add_argument('--dataset_path', default='/scr/nishantr/data', type=str) 165 | args = parser.parse_args() 166 | 167 | dataset_path = args.dataset_path 168 | 169 | if args.ucf101: 170 | main_UCF101(f_root=dataset_path + '/ucf101/frame', 171 | splits_root=dataset_path + '/ucf101/splits_classification') 172 | 173 | if args.jhmdb: 174 | main_JHMDB(f_root=dataset_path + '/jhmdb/frame', splits_root=dataset_path + '/jhmdb/splits') 175 | 176 | if args.hmdb51: 177 | main_HMDB51(f_root=dataset_path + '/hmdb/frame', splits_root=dataset_path + '/hmdb/splits/') 178 | 179 | if args.kinetics: 180 | main_Kinetics400( 181 | mode='train', # train or val or test 182 | k400_path=dataset_path + '/kinetics/splits', 183 | f_root=dataset_path + '/kinetics/frame256', 184 | csv_root='../data/kinetics400_256', 185 | ) 186 | 187 | main_Kinetics400( 188 | mode='val', # train or val or test 189 | k400_path=dataset_path + '/kinetics/splits', 190 | f_root=dataset_path + '/kinetics/frame256', 191 | csv_root='../data/kinetics400_256', 192 | ) 193 | 194 | # main_Kinetics400(mode='train', # train or val 
or test 195 | # k400_path='your_path/Kinetics', 196 | # f_root='your_path/Kinetics400_256/frame', 197 | # csv_root='../data/kinetics400_256') 198 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.8.1 2 | astor==0.8.0 3 | attrs==19.3.0 4 | backcall==0.1.0 5 | bleach==3.1.0 6 | cachetools==3.1.1 7 | calmsize==0.1.3 8 | certifi==2019.11.28 9 | cffi==1.13.1 10 | chardet==3.0.4 11 | cloudpickle==1.2.2 12 | coclust==0.2.1 13 | cycler==0.10.0 14 | Cython==0.29.14 15 | cytoolz==0.10.0 16 | dask==2.6.0 17 | decorator==4.4.1 18 | defusedxml==0.6.0 19 | detectron2==0.1+cu92 20 | docopt==0.6.2 21 | entrypoints==0.3 22 | environment-kernels==1.1.1 23 | future==0.18.2 24 | fvcore==0.1.dev200112 25 | gast==0.2.2 26 | google-auth==1.7.1 27 | google-auth-oauthlib==0.4.1 28 | google-pasta==0.1.8 29 | grpcio==1.25.0 30 | h5py==2.10.0 31 | idna==2.8 32 | imagecodecs==2020.2.18 33 | imageio==2.6.1 34 | importlib-metadata==0.23 35 | ipdb==0.12.3 36 | ipykernel==5.1.3 37 | ipython==7.9.0 38 | ipython-genutils==0.2.0 39 | ipywidgets==7.5.1 40 | jedi==0.15.1 41 | Jinja2==2.10.3 42 | joblib==0.14.0 43 | jsonschema==3.1.1 44 | jupyter==1.0.0 45 | jupyter-client==5.3.4 46 | jupyter-console==6.0.0 47 | jupyter-core==4.6.1 48 | Keras-Applications==1.0.8 49 | Keras-Preprocessing==1.1.0 50 | kiwisolver==1.1.0 51 | line-profiler==2.1.2 52 | Markdown==3.1.1 53 | MarkupSafe==1.1.1 54 | matplotlib==3.1.1 55 | mistune==0.8.4 56 | mkl-fft==1.0.14 57 | mkl-random==1.1.0 58 | mkl-service==2.3.0 59 | more-itertools==7.2.0 60 | nb-conda==2.2.1 61 | nb-conda-kernels==2.2.2 62 | nbconvert==5.6.1 63 | nbformat==4.4.0 64 | networkx==2.4 65 | nltk==3.4.5 66 | notebook==6.0.1 67 | numpy==1.17.0 68 | oauthlib==3.1.0 69 | olefile==0.46 70 | opencv-python==4.1.1.26 71 | opt-einsum==3.1.0 72 | pandas==0.25.2 73 | pandocfilters==1.4.2 74 | parso==0.5.1 75 | pexpect==4.7.0 76 | pickleshare==0.7.5 77 | Pillow==6.2.2 78 | portalocker==1.5.2 79 | prometheus-client==0.7.1 80 | prompt-toolkit==2.0.10 81 | protobuf==3.10.0 82 | ptyprocess==0.6.0 83 | pyasn1==0.4.8 84 | pyasn1-modules==0.2.7 85 | pycparser==2.19 86 | pydot==1.4.1 87 | Pygments==2.4.2 88 | pyparsing==2.4.2 89 | pyrsistent==0.15.4 90 | python-dateutil==2.8.0 91 | pytorch-memlab==0.1.0 92 | pytz==2019.3 93 | PyWavelets==1.1.1 94 | PyYAML==5.3 95 | pyzmq==18.1.0 96 | qtconsole==4.5.5 97 | requests==2.22.0 98 | requests-oauthlib==1.3.0 99 | resample2d-cuda==0.0.0 100 | rsa==4.0 101 | scikit-image==0.15.0 102 | scikit-learn==0.21.3 103 | scipy==1.1.0 104 | seaborn==0.10.1 105 | Send2Trash==1.5.0 106 | sewar==0.4.2 107 | six==1.12.0 108 | tabulate==0.8.6 109 | tensorboard==2.0.1 110 | tensorboardX==1.9 111 | tensorflow==2.0.0 112 | tensorflow-estimator==2.0.1 113 | termcolor==1.1.0 114 | terminado==0.8.2 115 | testpath==0.4.2 116 | tifffile==2020.2.16 117 | toolz==0.10.0 118 | torch==1.3.0 119 | torchvision==0.4.1 120 | tornado==6.0.3 121 | tqdm==4.42.0 122 | traitlets==4.3.3 123 | traj-conv-cuda==0.0.0 124 | tsne==0.1.8 125 | urllib3==1.25.7 126 | wcwidth==0.1.7 127 | webencodings==0.5.1 128 | Werkzeug==0.16.0 129 | widgetsnbextension==3.5.1 130 | wrapt==1.11.2 131 | yacs==0.1.6 132 | zipp==0.6.0 133 | -------------------------------------------------------------------------------- /test/dataset_3d_lc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils import data 3 | from 
torchvision import transforms 4 | import os 5 | import sys 6 | import time 7 | import pickle 8 | import csv 9 | import glob 10 | import pandas as pd 11 | import numpy as np 12 | import cv2 13 | 14 | sys.path.append('../train') 15 | import model_utils as mu 16 | 17 | sys.path.append('../utils') 18 | from augmentation import * 19 | from tqdm import tqdm 20 | from joblib import Parallel, delayed 21 | 22 | 23 | def pil_loader(path): 24 | with open(path, 'rb') as f: 25 | with Image.open(f) as img: 26 | return img.convert('RGB') 27 | 28 | 29 | toTensor = transforms.ToTensor() 30 | toPILImage = transforms.ToPILImage() 31 | def flow_loader(path): 32 | try: 33 | img = Image.open(path) 34 | except: 35 | return None 36 | f = toTensor(img) 37 | if f.mean() > 0.3: 38 | f -= 0.5 39 | return f 40 | 41 | 42 | def fetch_imgs_seq(vpath, idx_block): 43 | seq = [pil_loader(os.path.join(vpath, 'image_%05d.jpg' % (i+1))) for i in idx_block] 44 | return seq 45 | 46 | 47 | def fill_nones(l): 48 | l = [l[i-1] if l[i] is None else l[i] for i in range(len(l))] 49 | l = [l[i-1] if l[i] is None else l[i] for i in range(len(l))] 50 | try: 51 | nonNoneL = [item for item in l if item is not None][0] 52 | except: 53 | nonNoneL = torch.zeros((1, 256, 256)) 54 | return [torch.zeros(nonNoneL.shape) if l[i] is None else l[i] for i in range(len(l))] 55 | 56 | 57 | def get_u_flow_path_list(vpath, idx_block): 58 | dataset = 'ucf101' if 'ucf101' in vpath else 'hmdb51' 59 | flow_base_path = os.path.join('/dev/shm/data/nishantr/flow/', dataset + '_flow/') 60 | vid_name = os.path.basename(os.path.normpath(vpath)) 61 | return [os.path.join(flow_base_path, 'u', vid_name, 'frame%06d.jpg' % (i + 1)) for i in idx_block] 62 | 63 | 64 | def get_v_flow_path_list(vpath, idx_block): 65 | dataset = 'ucf101' if 'ucf101' in vpath else 'hmdb51' 66 | flow_base_path = os.path.join('/dev/shm/data/nishantr/flow/', dataset + '_flow/') 67 | vid_name = os.path.basename(os.path.normpath(vpath)) 68 | return [os.path.join(flow_base_path, 'v', vid_name, 'frame%06d.jpg' % (i + 1)) for i in idx_block] 69 | 70 | 71 | def fetch_flow_seq(vpath, idx_block): 72 | u_flow_list = get_u_flow_path_list(vpath, idx_block) 73 | v_flow_list = get_v_flow_path_list(vpath, idx_block) 74 | 75 | u_seq = fill_nones([flow_loader(f) for f in u_flow_list]) 76 | v_seq = fill_nones([flow_loader(f) for f in v_flow_list]) 77 | 78 | seq = [toPILImage(torch.cat([u, v])) for u, v in zip(u_seq, v_seq)] 79 | return seq 80 | 81 | 82 | def get_class_vid(vpath): 83 | return os.path.normpath(vpath).split('/')[-2:] 84 | 85 | 86 | def load_detectron_feature(fdir, idx, opt): 87 | # opt is either hm or seg 88 | 89 | shape = (192, 256) 90 | num_channels = 17 if opt == 'hm' else 1 91 | 92 | def load_feature(path): 93 | try: 94 | x = np.load(path)[opt] 95 | except: 96 | x = np.zeros((0, 0, 0)) 97 | 98 | # Match non-existent values 99 | if x.shape[1] == 0: 100 | x = np.zeros((num_channels, shape[0], shape[1])) 101 | 102 | x = torch.tensor(x, dtype=torch.float) / 255.0 103 | 104 | # Add extra channel in case it's not present 105 | if len(x.shape) < 3: 106 | x = x.unsqueeze(0) 107 | return x 108 | 109 | suffix = 'heatmap' if opt == 'hm' else 'segmask' 110 | fpath = os.path.join(fdir, suffix + '_%05d.npz' % idx) 111 | if os.path.isfile(fpath): 112 | return load_feature(fpath) 113 | else: 114 | # We do not have results lower than idx=2 115 | idx = max(3, idx) 116 | # We assume having all results for every two frames 117 | fpath0 = os.path.join(fdir, suffix + '_%05d.npz' % (idx - 1)) 118 | fpath1 = 
os.path.join(fdir, suffix + '_%05d.npz' % (idx + 1)) 119 | # This is not guaranteed to exist 120 | if not os.path.isfile(fpath1): 121 | fpath1 = fpath0 122 | a0, a1 = load_feature(fpath0), load_feature(fpath1) 123 | try: 124 | a_avg = (a0 + a1) / 2.0 125 | except: 126 | a_avg = None 127 | return a_avg 128 | 129 | 130 | def fetch_kp_heatmap_seq(vpath, idx_block): 131 | assert '/frame/' in vpath, "Incorrect vpath received: {}".format(vpath) 132 | feature_vpath = vpath.replace('/frame/', '/heatmaps/') 133 | seq = fill_nones([load_detectron_feature(feature_vpath, idx, opt='hm') for idx in idx_block]) 134 | 135 | if len(set([x.shape for x in seq])) > 1: 136 | # We now know the invalid paths, so no need to print them 137 | # print("Invalid path:", vpath) 138 | seq = [seq[len(seq) // 2] for _ in seq] 139 | return seq 140 | 141 | 142 | def fetch_seg_mask_seq(vpath, idx_block): 143 | assert '/frame/' in vpath, "Incorrect vpath received: {}".format(vpath) 144 | feature_vpath = vpath.replace('/frame/', '/segmasks/') 145 | seq = fill_nones([load_detectron_feature(feature_vpath, idx, opt='seg') for idx in idx_block]) 146 | return seq 147 | 148 | 149 | class UCF101_3d(data.Dataset): 150 | def __init__(self, 151 | mode='train', 152 | transform=None, 153 | seq_len=10, 154 | num_seq =1, 155 | downsample=3, 156 | epsilon=5, 157 | which_split=1, 158 | modality=mu.ImgMode): 159 | self.mode = mode 160 | self.transform = transform 161 | self.seq_len = seq_len 162 | self.num_seq = num_seq 163 | self.downsample = downsample 164 | self.epsilon = epsilon 165 | self.which_split = which_split 166 | self.modality = modality 167 | 168 | # splits 169 | if mode == 'train': 170 | split = '../process_data/data/ucf101/train_split%02d.csv' % self.which_split 171 | video_info = pd.read_csv(split, header=None) 172 | elif (mode == 'val') or (mode == 'test'): 173 | split = '../process_data/data/ucf101/test_split%02d.csv' % self.which_split # use test for val, temporary 174 | video_info = pd.read_csv(split, header=None) 175 | else: raise ValueError('wrong mode') 176 | 177 | # get action list 178 | self.action_dict_encode = {} 179 | self.action_dict_decode = {} 180 | 181 | action_file = os.path.join('../process_data/data/ucf101', 'classInd.txt') 182 | action_df = pd.read_csv(action_file, sep=' ', header=None) 183 | for _, row in action_df.iterrows(): 184 | act_id, act_name = row 185 | act_id = int(act_id) - 1 # let id start from 0 186 | self.action_dict_decode[act_id] = act_name 187 | self.action_dict_encode[act_name] = act_id 188 | 189 | # filter out too short videos: 190 | drop_idx = [] 191 | for idx, row in video_info.iterrows(): 192 | vpath, vlen = row 193 | if vlen <= 0: 194 | drop_idx.append(idx) 195 | self.video_info = video_info.drop(drop_idx, axis=0) 196 | 197 | # if mode == 'val': self.video_info = self.video_info.sample(frac=0.3) 198 | # shuffle not required 199 | 200 | def idx_sampler(self, vlen, vpath): 201 | '''sample index from a video''' 202 | downsample = self.downsample 203 | if (vlen - (self.num_seq * self.seq_len * self.downsample)) <= 0: 204 | downsample = ((vlen - 1) / (self.num_seq * self.seq_len * 1.0)) * 0.9 205 | 206 | n = 1 207 | if self.mode == 'test': 208 | seq_idx_block = np.arange(0, vlen, downsample) # all possible frames with downsampling 209 | seq_idx_block = seq_idx_block.astype(int) 210 | return [seq_idx_block, vpath] 211 | start_idx = np.random.choice(range(vlen-int(self.num_seq*self.seq_len*downsample)), n) 212 | seq_idx = np.expand_dims(np.arange(self.num_seq), -1)*downsample*self.seq_len + 
start_idx 213 | seq_idx_block = seq_idx + np.expand_dims(np.arange(self.seq_len),0)*downsample 214 | seq_idx_block = seq_idx_block.astype(int) 215 | return [seq_idx_block, vpath] 216 | 217 | def __getitem__(self, index): 218 | vpath, vlen = self.video_info.iloc[index] 219 | items = self.idx_sampler(vlen, vpath) 220 | if items is None: print(vpath) 221 | 222 | idx_block, vpath = items 223 | if self.mode != 'test': 224 | assert idx_block.shape == (self.num_seq, self.seq_len) 225 | idx_block = idx_block.reshape(self.num_seq*self.seq_len) 226 | 227 | seq = None 228 | if self.modality == mu.ImgMode: 229 | seq = fetch_imgs_seq(vpath, idx_block) 230 | elif self.modality == mu.FlowMode: 231 | seq = fetch_flow_seq(vpath, idx_block) 232 | elif self.modality == mu.KeypointHeatmap: 233 | seq = fetch_kp_heatmap_seq(vpath, idx_block) 234 | elif self.modality == mu.SegMask: 235 | seq = fetch_seg_mask_seq(vpath, idx_block) 236 | 237 | if self.modality in [mu.KeypointHeatmap, mu.SegMask]: 238 | seq = torch.stack(seq) 239 | 240 | # if self.mode == 'test': 241 | # # apply same transform 242 | # t_seq = [self.transform(seq) for _ in range(5)] 243 | # else: 244 | t_seq = self.transform(seq) # apply same transform 245 | # Convert tensor into list of tensors 246 | if self.modality in [mu.KeypointHeatmap, mu.SegMask]: 247 | t_seq = [t_seq[idx] for idx in range(t_seq.shape[0])] 248 | 249 | num_crop = None 250 | try: 251 | (C, H, W) = t_seq[0].size() 252 | t_seq = torch.stack(t_seq, 0) 253 | except: 254 | (C, H, W) = t_seq[0][0].size() 255 | tmp = [torch.stack(i, 0) for i in t_seq] 256 | assert len(tmp) == 5 257 | num_crop = 5 258 | t_seq = torch.stack(tmp, 1) 259 | 260 | if self.mode == 'test': 261 | # return all available clips, but cut into length = num_seq 262 | SL = t_seq.size(0) 263 | clips = []; i = 0 264 | while i+self.seq_len <= SL: 265 | clips.append(t_seq[i:i+self.seq_len, :]) 266 | # i += self.seq_len//2 267 | i += self.seq_len 268 | if num_crop: 269 | # half overlap: 270 | clips = [torch.stack(clips[i:i+self.num_seq], 0).permute(2,0,3,1,4,5) for i in range(0,len(clips)+1-self.num_seq,self.num_seq//2)] 271 | NC = len(clips) 272 | t_seq = torch.stack(clips, 0).view(NC*num_crop, self.num_seq, C, self.seq_len, H, W) 273 | else: 274 | # half overlap: 275 | clips = [torch.stack(clips[i:i+self.num_seq], 0).transpose(1,2) for i in range(0,len(clips)+1-self.num_seq,self.num_seq//2)] 276 | t_seq = torch.stack(clips, 0) 277 | else: 278 | t_seq = t_seq.view(self.num_seq, self.seq_len, C, H, W).transpose(1,2) 279 | 280 | try: 281 | vname = vpath.split('/')[-3] 282 | vid = self.encode_action(vname) 283 | except: 284 | vname = vpath.split('/')[-2] 285 | vid = self.encode_action(vname) 286 | 287 | label = torch.LongTensor([vid]) 288 | idx = torch.LongTensor([index]) 289 | 290 | return t_seq, label, idx 291 | 292 | def __len__(self): 293 | return len(self.video_info) 294 | 295 | def encode_action(self, action_name): 296 | '''give action name, return category''' 297 | return self.action_dict_encode[action_name] 298 | 299 | def decode_action(self, action_code): 300 | '''give action code, return action name''' 301 | return self.action_dict_decode[action_code] 302 | 303 | 304 | class HMDB51_3d(data.Dataset): 305 | def __init__(self, 306 | mode='train', 307 | transform=None, 308 | seq_len=10, 309 | num_seq=1, 310 | downsample=1, 311 | epsilon=5, 312 | which_split=1, 313 | modality=mu.ImgMode): 314 | self.mode = mode 315 | self.transform = transform 316 | self.seq_len = seq_len 317 | self.num_seq = num_seq 318 | 
self.downsample = downsample 319 | self.epsilon = epsilon 320 | self.which_split = which_split 321 | self.modality = modality 322 | 323 | # splits 324 | if mode == 'train': 325 | split = '../process_data/data/hmdb51/train_split%02d.csv' % self.which_split 326 | video_info = pd.read_csv(split, header=None) 327 | elif (mode == 'val') or (mode == 'test'): 328 | split = '../process_data/data/hmdb51/test_split%02d.csv' % self.which_split # use test for val, temporary 329 | video_info = pd.read_csv(split, header=None) 330 | else: raise ValueError('wrong mode') 331 | 332 | # get action list 333 | self.action_dict_encode = {} 334 | self.action_dict_decode = {} 335 | 336 | action_file = os.path.join('../process_data/data/hmdb51', 'classInd.txt') 337 | action_df = pd.read_csv(action_file, sep=' ', header=None) 338 | for _, row in action_df.iterrows(): 339 | act_id, act_name = row 340 | act_id = int(act_id) - 1 # let id start from 0 341 | self.action_dict_decode[act_id] = act_name 342 | self.action_dict_encode[act_name] = act_id 343 | 344 | # filter out too short videos: 345 | drop_idx = [] 346 | for idx, row in video_info.iterrows(): 347 | vpath, vlen = row 348 | if vlen <= 0: 349 | drop_idx.append(idx) 350 | self.video_info = video_info.drop(drop_idx, axis=0) 351 | 352 | # if mode == 'val': self.video_info = self.video_info.sample(frac=0.3) 353 | # shuffle not required 354 | 355 | def idx_sampler(self, vlen, vpath): 356 | '''sample index from a video''' 357 | downsample = self.downsample 358 | if (vlen - (self.num_seq * self.seq_len * self.downsample)) <= 0: 359 | downsample = ((vlen - 1) / (self.num_seq * self.seq_len * 1.0)) * 0.9 360 | 361 | n=1 362 | if self.mode == 'test': 363 | seq_idx_block = np.arange(0, vlen, downsample) # all possible frames with downsampling 364 | seq_idx_block = seq_idx_block.astype(int) 365 | return [seq_idx_block, vpath] 366 | start_idx = np.random.choice(range(vlen-int(self.num_seq*self.seq_len*downsample)), n) 367 | seq_idx = np.expand_dims(np.arange(self.num_seq), -1)*downsample*self.seq_len + start_idx 368 | seq_idx_block = seq_idx + np.expand_dims(np.arange(self.seq_len),0)*downsample 369 | seq_idx_block = seq_idx_block.astype(int) 370 | return [seq_idx_block, vpath] 371 | 372 | def __getitem__(self, index): 373 | vpath, vlen = self.video_info.iloc[index] 374 | items = self.idx_sampler(vlen, vpath) 375 | if items is None: print(vpath) 376 | 377 | idx_block, vpath = items 378 | if self.mode != 'test': 379 | assert idx_block.shape == (self.num_seq, self.seq_len) 380 | idx_block = idx_block.reshape(self.num_seq*self.seq_len) 381 | 382 | seq = None 383 | if self.modality == mu.ImgMode: 384 | seq = fetch_imgs_seq(vpath, idx_block) 385 | elif self.modality == mu.FlowMode: 386 | seq = fetch_flow_seq(vpath, idx_block) 387 | elif self.modality == mu.KeypointHeatmap: 388 | seq = fetch_kp_heatmap_seq(vpath, idx_block) 389 | elif self.modality == mu.SegMask: 390 | seq = fetch_seg_mask_seq(vpath, idx_block) 391 | 392 | if self.modality in [mu.KeypointHeatmap, mu.SegMask]: 393 | seq = torch.stack(seq) 394 | 395 | t_seq = self.transform(seq) # apply same transform 396 | # Convert tensor into list of tensors 397 | if self.modality in [mu.KeypointHeatmap, mu.SegMask]: 398 | t_seq = [t_seq[idx] for idx in range(t_seq.shape[0])] 399 | 400 | num_crop = None 401 | try: 402 | (C, H, W) = t_seq[0].size() 403 | t_seq = torch.stack(t_seq, 0) 404 | except: 405 | (C, H, W) = t_seq[0][0].size() 406 | tmp = [torch.stack(i, 0) for i in t_seq] 407 | assert len(tmp) == 5 408 | num_crop = 5 409 
| t_seq = torch.stack(tmp, 1) 410 | # print(t_seq.size()) 411 | # import ipdb; ipdb.set_trace() 412 | if self.mode == 'test': 413 | # return all available clips, but cut into length = num_seq 414 | SL = t_seq.size(0) 415 | clips = []; i = 0 416 | while i+self.seq_len <= SL: 417 | clips.append(t_seq[i:i+self.seq_len, :]) 418 | # i += self.seq_len//2 419 | i += self.seq_len 420 | if num_crop: 421 | # half overlap: 422 | clips = [torch.stack(clips[i:i+self.num_seq], 0).permute(2,0,3,1,4,5) for i in range(0,len(clips)+1-self.num_seq,self.num_seq//2)] 423 | NC = len(clips) 424 | t_seq = torch.stack(clips, 0).view(NC*num_crop, self.num_seq, C, self.seq_len, H, W) 425 | else: 426 | # half overlap: 427 | clips = [torch.stack(clips[i:i+self.num_seq], 0).transpose(1,2) for i in range(0,len(clips)+1-self.num_seq,3*self.num_seq//4)] 428 | t_seq = torch.stack(clips, 0) 429 | else: 430 | t_seq = t_seq.view(self.num_seq, self.seq_len, C, H, W).transpose(1,2) 431 | 432 | try: 433 | vname = vpath.split('/')[-3] 434 | vid = self.encode_action(vname) 435 | except: 436 | vname = vpath.split('/')[-2] 437 | vid = self.encode_action(vname) 438 | 439 | label = torch.LongTensor([vid]) 440 | idx = torch.LongTensor([index]) 441 | 442 | return t_seq, label, idx 443 | 444 | def __len__(self): 445 | return len(self.video_info) 446 | 447 | def encode_action(self, action_name): 448 | '''give action name, return category''' 449 | return self.action_dict_encode[action_name] 450 | 451 | def decode_action(self, action_code): 452 | '''give action code, return action name''' 453 | return self.action_dict_decode[action_code] 454 | 455 | -------------------------------------------------------------------------------- /test/model_3d_lc.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import sys 4 | sys.path.append('../backbone') 5 | from select_backbone import select_resnet 6 | from convrnn import ConvGRU 7 | 8 | sys.path.append('../train') 9 | import model_utils as mu 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | 15 | 16 | class LC(nn.Module): 17 | def __init__(self, sample_size, num_seq, seq_len, in_channels, 18 | network='resnet18', dropout=0.5, num_class=101): 19 | super(LC, self).__init__() 20 | torch.cuda.manual_seed(666) 21 | self.sample_size = sample_size 22 | self.num_seq = num_seq 23 | self.seq_len = seq_len 24 | self.num_class = num_class 25 | self.in_channels = in_channels 26 | print('=> Using RNN + FC model with ic:', self.in_channels) 27 | 28 | print('=> Use 2D-3D %s!' 
% network) 29 | self.last_duration = int(math.ceil(seq_len / 4)) 30 | self.last_size = int(math.ceil(sample_size / 32)) 31 | track_running_stats = True 32 | 33 | self.backbone, self.param = \ 34 | select_resnet(network, track_running_stats=track_running_stats, in_channels=self.in_channels) 35 | self.param['num_layers'] = 1 36 | self.param['hidden_size'] = self.param['feature_size'] 37 | 38 | print('=> using ConvRNN, kernel_size = 1') 39 | self.agg = ConvGRU(input_size=self.param['feature_size'], 40 | hidden_size=self.param['hidden_size'], 41 | kernel_size=1, 42 | num_layers=self.param['num_layers']) 43 | self._initialize_weights(self.agg) 44 | 45 | self.final_bn = nn.BatchNorm1d(self.param['feature_size']) 46 | self.final_bn.weight.data.fill_(1) 47 | self.final_bn.bias.data.zero_() 48 | 49 | self.num_classes = num_class 50 | self.dropout = dropout 51 | self.hidden_size = 128 52 | self.final_fc = nn.Sequential( 53 | nn.Dropout(self.dropout), 54 | nn.Linear(self.param['feature_size'], self.num_classes), 55 | ) 56 | 57 | self._initialize_weights(self.final_fc) 58 | 59 | def forward(self, block): 60 | # seq1: [B, N, C, SL, W, H] 61 | (B, N, C, SL, H, W) = block.shape 62 | block = block.view(B*N, C, SL, H, W) 63 | feature = self.backbone(block) 64 | del block 65 | # TODO: Do we need ReLU 66 | # feature = F.relu(feature) 67 | 68 | feature = F.avg_pool3d(feature, (self.last_duration, 1, 1), stride=1) 69 | feature = feature.view(B, N, self.param['feature_size'], self.last_size, self.last_size) # [B*N,D,last_size,last_size] 70 | context, _ = self.agg(feature) 71 | context = context[:,-1,:].unsqueeze(1) 72 | context = F.avg_pool3d(context, (1, self.last_size, self.last_size), stride=1).squeeze(-1).squeeze(-1) 73 | del feature 74 | 75 | context = self.final_bn(context.transpose(-1,-2)).transpose(-1,-2) # [B,N,C] -> [B,C,N] -> BN() -> [B,N,C], because BN operates on id=1 channel. 
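# Illustrative shape trace for this forward pass (values are assumptions for a typical run:
# a 3D-ResNet18 backbone whose feature_size is 256, img_dim=128 and seq_len=5, so
# last_size = ceil(128/32) = 4 and last_duration = ceil(5/4) = 2; the real numbers come
# from select_backbone and the command-line arguments):
#   block:   [B, N, C, SL, H, W]               input clips
#   feature: [B*N, 256, last_duration, 4, 4]   after the 2D-3D ResNet backbone
#   context: [B, 1, 256]                       last ConvGRU state after spatial avg-pooling
# The transpose pair around final_bn above is needed because nn.BatchNorm1d normalizes
# dim=1, so the 256-channel axis has to be moved there and back before the dropout +
# linear classifier below maps the context to num_class logits.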
76 | output = self.final_fc(context).view(B, -1, self.num_class) 77 | 78 | return output, context 79 | 80 | def _initialize_weights(self, module): 81 | for name, param in module.named_parameters(): 82 | if 'bias' in name: 83 | nn.init.constant_(param, 0.0) 84 | elif 'weight' in name: 85 | nn.init.orthogonal_(param, 1) 86 | # other resnet weights have been initialized in resnet_3d.py 87 | 88 | 89 | -------------------------------------------------------------------------------- /test/test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import argparse 5 | import pickle 6 | import re 7 | import numpy as np 8 | import transform_utils as tu 9 | 10 | from tqdm import tqdm 11 | from tensorboardX import SummaryWriter 12 | 13 | sys.path.append('../utils') 14 | sys.path.append('../backbone') 15 | from dataset_3d_lc import UCF101_3d, HMDB51_3d 16 | from model_3d_lc import * 17 | from resnet_2d3d import neq_load_customized 18 | from augmentation import * 19 | from utils import AverageMeter, AccuracyTable, ConfusionMeter, save_checkpoint, write_log, calc_topk_accuracy, denorm, calc_accuracy 20 | 21 | import torch 22 | import torch.optim as optim 23 | from torch.utils import data 24 | import torch.nn as nn 25 | from torchvision import datasets, models, transforms 26 | import torchvision.utils as vutils 27 | 28 | parser = argparse.ArgumentParser() 29 | parser.add_argument('--save_dir', default='/data/nishantr/svl/', type=str, help='dir to save intermediate results') 30 | parser.add_argument('--net', default='resnet18', type=str) 31 | parser.add_argument('--model', default='lc', type=str) 32 | parser.add_argument('--dataset', default='ucf101', type=str) 33 | parser.add_argument('--modality', required=True, type=str, help="Modality to use") 34 | parser.add_argument('--split', default=1, type=int) 35 | parser.add_argument('--seq_len', default=5, type=int) 36 | parser.add_argument('--num_seq', default=8, type=int) 37 | parser.add_argument('--num_class', default=101, type=int) 38 | parser.add_argument('--dropout', default=0.5, type=float) 39 | parser.add_argument('--ds', default=3, type=int) 40 | parser.add_argument('--batch_size', default=4, type=int) 41 | parser.add_argument('--lr', default=1e-3, type=float) 42 | parser.add_argument('--wd', default=1e-5, type=float, help='weight decay') 43 | parser.add_argument('--resume', default='', type=str) 44 | parser.add_argument('--pretrain', default='random', type=str) 45 | parser.add_argument('--test', default='', type=str) 46 | parser.add_argument('--extensive', default=0, type=int) 47 | parser.add_argument('--epochs', default=50, type=int, help='number of total epochs to run') 48 | parser.add_argument('--start-epoch', default=0, type=int, help='manual epoch number (useful on restarts)') 49 | parser.add_argument('--print_freq', default=5, type=int) 50 | parser.add_argument('--reset_lr', action='store_true', help='Reset learning rate when resume training?') 51 | parser.add_argument('--train_what', default='last', type=str, help='Train what parameters?') 52 | parser.add_argument('--prefix', default='tmp', type=str) 53 | parser.add_argument('--img_dim', default=128, type=int) 54 | parser.add_argument('--full_eval_freq', default=10, type=int) 55 | parser.add_argument('--num_workers', default=8, type=int) 56 | parser.add_argument('--notes', default='', type=str) 57 | 58 | parser.add_argument('--ensemble', default=0, type=int) 59 | parser.add_argument('--prob_imgs', default='', 
type=str) 60 | parser.add_argument('--prob_flow', default='', type=str) 61 | parser.add_argument('--prob_seg', default='', type=str) 62 | parser.add_argument('--prob_kphm', default='', type=str) 63 | 64 | 65 | def get_data_loader(args, mode='train'): 66 | print("Getting data loader for:", args.modality) 67 | transform = None 68 | if mode == 'train': 69 | transform = tu.get_train_transforms(args) 70 | elif mode == 'val': 71 | transform = tu.get_val_transforms(args) 72 | elif mode == 'test': 73 | transform = tu.get_test_transforms(args) 74 | loader = get_data(transform, mode) 75 | return loader 76 | 77 | 78 | def get_num_channels(modality): 79 | if modality == mu.ImgMode: 80 | return 3 81 | elif modality == mu.FlowMode: 82 | return 2 83 | elif modality == mu.FnbFlowMode: 84 | return 2 85 | elif modality == mu.KeypointHeatmap: 86 | return 17 87 | elif modality == mu.SegMask: 88 | return 1 89 | else: 90 | assert False, "Invalid modality: {}".format(modality) 91 | 92 | 93 | def freeze_backbone(model): 94 | print('Freezing the backbone...') 95 | for name, param in model.module.named_parameters(): 96 | if ('resnet' in name) or ('rnn' in name) or ('agg' in name): 97 | param.requires_grad = False 98 | return model 99 | 100 | 101 | def unfreeze_backbone(model): 102 | print('Unfreezing the backbone...') 103 | for name, param in model.module.named_parameters(): 104 | if ('resnet' in name) or ('rnn' in name): 105 | param.requires_grad = True 106 | return model 107 | 108 | 109 | def main(): 110 | global args; args = parser.parse_args() 111 | global cuda; cuda = torch.device('cuda') 112 | 113 | if args.dataset == 'ucf101': args.num_class = 101 114 | elif args.dataset == 'hmdb51': args.num_class = 51 115 | 116 | if args.ensemble: 117 | def read_pkl(fname): 118 | if fname == '': 119 | return None 120 | with open(fname, 'rb') as f: 121 | prob = pickle.load(f) 122 | return prob 123 | ensemble(read_pkl(args.prob_imgs), read_pkl(args.prob_flow), read_pkl(args.prob_seg), read_pkl(args.prob_kphm)) 124 | sys.exit() 125 | 126 | args.in_channels = get_num_channels(args.modality) 127 | 128 | ### classifier model ### 129 | if args.model == 'lc': 130 | model = LC(sample_size=args.img_dim, 131 | num_seq=args.num_seq, 132 | seq_len=args.seq_len, 133 | in_channels=args.in_channels, 134 | network=args.net, 135 | num_class=args.num_class, 136 | dropout=args.dropout) 137 | else: 138 | raise ValueError('wrong model!') 139 | 140 | model = nn.DataParallel(model) 141 | model = model.to(cuda) 142 | global criterion; criterion = nn.CrossEntropyLoss() 143 | 144 | ### optimizer ### 145 | params = None 146 | if args.train_what == 'ft': 147 | print('=> finetune backbone with smaller lr') 148 | params = [] 149 | for name, param in model.module.named_parameters(): 150 | if ('resnet' in name) or ('rnn' in name): 151 | params.append({'params': param, 'lr': args.lr/10}) 152 | else: 153 | params.append({'params': param}) 154 | elif args.train_what == 'freeze': 155 | print('=> Freeze backbone') 156 | params = [] 157 | for name, param in model.module.named_parameters(): 158 | param.requires_grad = False 159 | else: 160 | pass # train all layers 161 | 162 | print('\n===========Check Grad============') 163 | for name, param in model.named_parameters(): 164 | if param.requires_grad == False: 165 | print(name, param.requires_grad) 166 | print('=================================\n') 167 | 168 | if params is None: 169 | params = model.parameters() 170 | 171 | optimizer = optim.Adam(params, lr=args.lr, weight_decay=args.wd) 172 | # Old version 173 | 
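# Illustrative note on the learning-rate schedule set a few lines below (assuming the
# default branch for img_dim != 224, i.e. gamma=0.1, step=[50, 70, 90], repeat=1):
# MultiStepLR_Restart_Multiplier, defined near the bottom of this file, then gives an
# lr multiplier of 1.0 for epochs 0-49, 0.1 for epochs 50-69, 0.01 for epochs 70-89,
# and stays at 0.01 from epoch 90 onwards.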
# if args.dataset == 'hmdb51': 174 | # lr_lambda = lambda ep: MultiStepLR_Restart_Multiplier(ep, gamma=0.1, step=[50,70,90], repeat=1) 175 | # elif args.dataset == 'ucf101': 176 | # if args.img_dim == 224: lr_lambda = lambda ep: MultiStepLR_Restart_Multiplier(ep, gamma=0.1, step=[90,140,180], repeat=1) 177 | # else: lr_lambda = lambda ep: MultiStepLR_Restart_Multiplier(ep, gamma=0.1, step=[50, 70, 90], repeat=1) 178 | if args.img_dim == 224: 179 | lr_lambda = lambda ep: MultiStepLR_Restart_Multiplier(ep, gamma=0.1, step=[60,120,180], repeat=1) 180 | else: 181 | lr_lambda = lambda ep: MultiStepLR_Restart_Multiplier(ep, gamma=0.1, step=[50, 70, 90], repeat=1) 182 | scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda) 183 | 184 | args.old_lr = None 185 | best_acc = 0 186 | global iteration; iteration = 0 187 | global num_epoch; num_epoch = 0 188 | 189 | ### restart training ### 190 | if args.test: 191 | if os.path.isfile(args.test): 192 | print("=> loading testing checkpoint '{}'".format(args.test)) 193 | checkpoint = torch.load(args.test) 194 | try: model.load_state_dict(checkpoint['state_dict']) 195 | except: 196 | print('=> [Warning]: weight structure is not equal to test model; Use non-equal load ==') 197 | model = neq_load_customized(model, checkpoint['state_dict']) 198 | print("=> loaded testing checkpoint '{}' (epoch {})".format(args.test, checkpoint['epoch'])) 199 | num_epoch = checkpoint['epoch'] 200 | elif args.test == 'random': 201 | print("=> [Warning] loaded random weights") 202 | else: 203 | raise ValueError() 204 | 205 | test_loader = get_data_loader(args, 'test') 206 | test_loss, test_acc = test(test_loader, model, extensive=args.extensive) 207 | sys.exit() 208 | else: # not test 209 | torch.backends.cudnn.benchmark = True 210 | 211 | if args.resume: 212 | if os.path.isfile(args.resume): 213 | # args.old_lr = float(re.search('_lr(.+?)_', args.resume).group(1)) 214 | args.old_lr = 1e-3 215 | print("=> loading resumed checkpoint '{}'".format(args.resume)) 216 | checkpoint = torch.load(args.resume, map_location=torch.device('cpu')) 217 | args.start_epoch = checkpoint['epoch'] 218 | best_acc = checkpoint['best_acc'] 219 | model.load_state_dict(checkpoint['state_dict']) 220 | if not args.reset_lr: # if didn't reset lr, load old optimizer 221 | optimizer.load_state_dict(checkpoint['optimizer']) 222 | else: print('==== Change lr from %f to %f ====' % (args.old_lr, args.lr)) 223 | iteration = checkpoint['iteration'] 224 | print("=> loaded resumed checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch'])) 225 | else: 226 | print("=> no checkpoint found at '{}'".format(args.resume)) 227 | 228 | if (not args.resume) and args.pretrain: 229 | if args.pretrain == 'random': 230 | print('=> using random weights') 231 | elif os.path.isfile(args.pretrain): 232 | print("=> loading pretrained checkpoint '{}'".format(args.pretrain)) 233 | checkpoint = torch.load(args.pretrain, map_location=torch.device('cpu')) 234 | model = neq_load_customized(model, checkpoint['state_dict']) 235 | print("=> loaded pretrained checkpoint '{}' (epoch {})".format(args.pretrain, checkpoint['epoch'])) 236 | else: 237 | print("=> no checkpoint found at '{}'".format(args.pretrain)) 238 | 239 | ### load data ### 240 | train_loader = get_data_loader(args, 'train') 241 | val_loader = get_data_loader(args, 'val') 242 | test_loader = get_data_loader(args, 'test') 243 | 244 | # setup tools 245 | global de_normalize; de_normalize = denorm() 246 | global img_path; img_path, model_path = 
set_path(args) 247 | global writer_train 248 | try: # old version 249 | writer_val = SummaryWriter(log_dir=os.path.join(img_path, 'val')) 250 | writer_train = SummaryWriter(log_dir=os.path.join(img_path, 'train')) 251 | except: # v1.7 252 | writer_val = SummaryWriter(logdir=os.path.join(img_path, 'val')) 253 | writer_train = SummaryWriter(logdir=os.path.join(img_path, 'train')) 254 | 255 | args.test = model_path 256 | print("Model path:", model_path) 257 | 258 | # Freeze the model backbone initially 259 | model = freeze_backbone(model) 260 | cooldown = 0 261 | 262 | ### main loop ### 263 | for epoch in range(args.start_epoch, args.epochs): 264 | num_epoch = epoch 265 | 266 | train_loss, train_acc = train(train_loader, model, optimizer, epoch) 267 | val_loss, val_acc = validate(val_loader, model) 268 | scheduler.step(epoch) 269 | 270 | writer_train.add_scalar('global/loss', train_loss, epoch) 271 | writer_train.add_scalar('global/accuracy', train_acc, epoch) 272 | writer_val.add_scalar('global/loss', val_loss, epoch) 273 | writer_val.add_scalar('global/accuracy', val_acc, epoch) 274 | 275 | # save check_point 276 | is_best = val_acc > best_acc 277 | best_acc = max(val_acc, best_acc) 278 | 279 | # Perform testing if either the frequency is hit or the model is the best after a few epochs 280 | if (epoch + 1) % args.full_eval_freq == 0: 281 | test(test_loader, model) 282 | elif (epoch > 70) and (cooldown >= 5) and is_best: 283 | test(test_loader, model) 284 | cooldown = 0 285 | else: 286 | cooldown += 1 287 | 288 | save_checkpoint( 289 | state={ 290 | 'epoch': epoch+1, 291 | 'net': args.net, 292 | 'state_dict': model.state_dict(), 293 | 'best_acc': best_acc, 294 | 'optimizer': optimizer.state_dict(), 295 | 'iteration': iteration 296 | }, 297 | mode=args.modality, 298 | is_best=is_best, 299 | gap=5, 300 | filename=os.path.join(model_path, 'epoch%s.pth.tar' % str(epoch+1)), 301 | keep_all=False) 302 | 303 | # Unfreeze the model backbone after the first run 304 | if epoch == (args.start_epoch): 305 | model = unfreeze_backbone(model) 306 | 307 | print('Training from ep %d to ep %d finished' % (args.start_epoch, args.epochs)) 308 | print("Model path:", model_path) 309 | 310 | 311 | def train(data_loader, model, optimizer, epoch): 312 | losses = AverageMeter() 313 | accuracy = AverageMeter() 314 | model.train() 315 | global iteration 316 | 317 | tq = tqdm(data_loader, desc="Train progress: Ep {}".format(epoch)) 318 | 319 | for idx, (input_seq, target, _) in enumerate(tq): 320 | tic = time.time() 321 | input_seq = input_seq.to(cuda) 322 | target = target.to(cuda) 323 | B = input_seq.size(0) 324 | output, _ = model(input_seq) 325 | 326 | # visualize 327 | if (iteration == 0) or (iteration == args.print_freq): 328 | if B > 2: input_seq = input_seq[0:2,:] 329 | writer_train.add_image('input_seq', 330 | de_normalize(vutils.make_grid( 331 | input_seq[:, :3, ...].transpose(2,3).contiguous().view(-1,3,args.img_dim,args.img_dim), 332 | nrow=args.num_seq*args.seq_len)), 333 | iteration) 334 | del input_seq 335 | 336 | [_, N, D] = output.size() 337 | output = output.view(B*N, D) 338 | target = target.repeat(1, N).view(-1) 339 | 340 | loss = criterion(output, target) 341 | acc = calc_accuracy(output, target) 342 | 343 | del target 344 | 345 | losses.update(loss.item(), B) 346 | accuracy.update(acc.item(), B) 347 | 348 | optimizer.zero_grad() 349 | loss.backward() 350 | optimizer.step() 351 | 352 | total_weight = 0.0 353 | decay_weight = 0.0 354 | for m in model.parameters(): 355 | if m.requires_grad: 
decay_weight += m.norm(2).data 356 | total_weight += m.norm(2).data 357 | 358 | tq_stats = { 359 | 'loss': losses.local_avg, 360 | 'acc': accuracy.local_avg, 361 | 'decay_wt': decay_weight.item(), 362 | 'total_wt': total_weight.item(), 363 | } 364 | 365 | tq.set_postfix(tq_stats) 366 | 367 | if idx % args.print_freq == 0: 368 | writer_train.add_scalar('local/loss', losses.val, iteration) 369 | writer_train.add_scalar('local/accuracy', accuracy.val, iteration) 370 | 371 | iteration += 1 372 | 373 | return losses.local_avg, accuracy.local_avg 374 | 375 | 376 | def validate(data_loader, model): 377 | losses = AverageMeter() 378 | accuracy = AverageMeter() 379 | model.eval() 380 | with torch.no_grad(): 381 | tq = tqdm(data_loader, desc="Val progress: ") 382 | for idx, (input_seq, target, _) in enumerate(tq): 383 | input_seq = input_seq.to(cuda) 384 | target = target.to(cuda) 385 | B = input_seq.size(0) 386 | output, _ = model(input_seq) 387 | 388 | [_, N, D] = output.size() 389 | output = output.view(B*N, D) 390 | target = target.repeat(1, N).view(-1) 391 | 392 | loss = criterion(output, target) 393 | acc = calc_accuracy(output, target) 394 | 395 | losses.update(loss.item(), B) 396 | accuracy.update(acc.item(), B) 397 | 398 | tq.set_postfix({ 399 | 'loss': losses.avg, 400 | 'acc': accuracy.avg, 401 | }) 402 | 403 | print('Val - Loss {loss.avg:.4f}\t' 404 | 'Acc: {acc.avg:.4f} \t'.format(loss=losses, acc=accuracy)) 405 | return losses.avg, accuracy.avg 406 | 407 | 408 | def test(data_loader, model, extensive=False): 409 | losses = AverageMeter() 410 | acc_top1 = AverageMeter() 411 | acc_top5 = AverageMeter() 412 | acc_table = AccuracyTable(data_loader.dataset.action_dict_decode) 413 | confusion_mat = ConfusionMeter(args.num_class) 414 | probs = {} 415 | 416 | model.eval() 417 | with torch.no_grad(): 418 | tq = tqdm(data_loader, desc="Test progress: ") 419 | for idx, (input_seq, target, index) in enumerate(tq): 420 | input_seq = input_seq.to(cuda) 421 | target = target.to(cuda) 422 | B = input_seq.size(0) 423 | input_seq = input_seq.squeeze(0) # squeeze the '1' batch dim 424 | output, _ = model(input_seq) 425 | del input_seq 426 | 427 | prob = torch.mean(torch.mean(nn.functional.softmax(output, 2), 0), 0, keepdim=True) 428 | top1, top5 = calc_topk_accuracy(prob, target, (1,5)) 429 | acc_top1.update(top1.item(), B) 430 | acc_top5.update(top5.item(), B) 431 | del top1, top5 432 | 433 | output = torch.mean(torch.mean(output, 0), 0, keepdim=True) 434 | loss = criterion(output, target.squeeze(-1)) 435 | 436 | losses.update(loss.item(), B) 437 | del loss 438 | 439 | _, pred = torch.max(output, 1) 440 | confusion_mat.update(pred, target.view(-1).byte()) 441 | acc_table.update(pred, target) 442 | probs[index] = {'prob': prob.detach().cpu(), 'target': target.detach().cpu()} 443 | 444 | tq.set_postfix({ 445 | 'loss': losses.avg, 446 | 'acc1': acc_top1.avg, 447 | 'acc5': acc_top5.avg, 448 | }) 449 | 450 | print('Test - Loss {loss.avg:.4f}\t' 451 | 'Acc top1: {top1.avg:.4f} Acc top5: {top5.avg:.4f} \t'.format(loss=losses, top1=acc_top1, top5=acc_top5)) 452 | confusion_mat.plot_mat(args.test+'.svg') 453 | write_log(content='Loss {loss.avg:.4f}\t Acc top1: {top1.avg:.4f} Acc top5: {top5.avg:.4f} \t'.format(loss=losses, top1=acc_top1, top5=acc_top5, args=args), 454 | epoch=num_epoch, 455 | filename=os.path.join(os.path.dirname(args.test), 'test_log_{}.md').format(args.notes)) 456 | with open(os.path.join(os.path.dirname(args.test), 'test_probs_{}.pkl').format(args.notes), 'wb') as f: 457 | pickle.dump(probs, 
f) 458 | 459 | if extensive: 460 | acc_table.print_table() 461 | acc_table.print_dict() 462 | 463 | # import ipdb; ipdb.set_trace() 464 | return losses.avg, [acc_top1.avg, acc_top5.avg] 465 | 466 | 467 | def ensemble(prob_imgs=None, prob_flow=None, prob_seg=None, prob_kphm=None): 468 | acc_top1 = AverageMeter() 469 | acc_top5 = AverageMeter() 470 | 471 | probs = [prob_imgs, prob_flow, prob_seg, prob_kphm] 472 | for idx in range(len(probs)): 473 | if probs[idx] is not None: 474 | probs[idx] = {k[0][0].data: v for k, v in probs[idx].items()} 475 | valid_probs = [x for x in probs if x is not None] 476 | weights = [2, 2, 1, 1] 477 | 478 | ovr_probs = {} 479 | for k in valid_probs[0].keys(): 480 | ovr_probs[k] = valid_probs[0][k]['prob'] * 0.0 481 | total = 0 482 | for idx in range(len(probs)): 483 | p = probs[idx] 484 | if p is not None: 485 | total += weights[idx] 486 | ovr_probs[k] += p[k]['prob'] * weights[idx] 487 | ovr_probs[k] /= total 488 | 489 | top1, top5 = calc_topk_accuracy(ovr_probs[k], valid_probs[0][k]['target'], (1, 5)) 490 | acc_top1.update(top1.item(), 1) 491 | acc_top5.update(top5.item(), 1) 492 | 493 | print('Test - Acc top1: {top1.avg:.4f} Acc top5: {top5.avg:.4f} \t'.format(top1=acc_top1, top5=acc_top5)) 494 | 495 | 496 | def get_data(transform, mode='train'): 497 | print('Loading data for "%s" ...' % mode) 498 | global dataset 499 | if args.dataset == 'ucf101': 500 | dataset = UCF101_3d(mode=mode, 501 | transform=transform, 502 | seq_len=args.seq_len, 503 | num_seq=args.num_seq, 504 | downsample=args.ds, 505 | which_split=args.split, 506 | modality=args.modality 507 | ) 508 | elif args.dataset == 'hmdb51': 509 | dataset = HMDB51_3d(mode=mode, 510 | transform=transform, 511 | seq_len=args.seq_len, 512 | num_seq=args.num_seq, 513 | downsample=args.ds, 514 | which_split=args.split, 515 | modality=args.modality 516 | ) 517 | else: 518 | raise ValueError('dataset not supported') 519 | my_sampler = data.RandomSampler(dataset) 520 | if mode == 'train': 521 | data_loader = data.DataLoader(dataset, 522 | batch_size=args.batch_size, 523 | sampler=my_sampler, 524 | shuffle=False, 525 | num_workers=args.num_workers, 526 | pin_memory=True, 527 | drop_last=True) 528 | elif mode == 'val': 529 | data_loader = data.DataLoader(dataset, 530 | batch_size=args.batch_size, 531 | sampler=my_sampler, 532 | shuffle=False, 533 | num_workers=args.num_workers, 534 | pin_memory=True, 535 | drop_last=True) 536 | elif mode == 'test': 537 | data_loader = data.DataLoader(dataset, 538 | batch_size=1, 539 | sampler=my_sampler, 540 | shuffle=False, 541 | num_workers=args.num_workers, 542 | pin_memory=True) 543 | print('"%s" dataset size: %d' % (mode, len(dataset))) 544 | return data_loader 545 | 546 | 547 | def set_path(args): 548 | if args.resume: exp_path = os.path.dirname(os.path.dirname(args.resume)) 549 | else: 550 | exp_path = 'log/{args.prefix}/ft_{args.dataset}-{args.img_dim}_mode-{args.modality}_' \ 551 | 'sp{args.split}_{0}_{args.model}_bs{args.batch_size}_' \ 552 | 'lr{1}_wd{args.wd}_ds{args.ds}_seq{args.num_seq}_len{args.seq_len}_' \ 553 | 'dp{args.dropout}_train-{args.train_what}{2}'.format( 554 | 'r%s' % args.net[6::], 555 | args.old_lr if args.old_lr is not None else args.lr, 556 | '_'+args.notes, 557 | args=args) 558 | exp_path = os.path.join(args.save_dir, exp_path) 559 | img_path = os.path.join(exp_path, 'img') 560 | model_path = os.path.join(exp_path, 'model') 561 | if not os.path.exists(img_path): os.makedirs(img_path) 562 | if not os.path.exists(model_path): os.makedirs(model_path) 563 
| return img_path, model_path 564 | 565 | 566 | def MultiStepLR_Restart_Multiplier(epoch, gamma=0.1, step=[10,15,20], repeat=3): 567 | '''return the multipier for LambdaLR, 568 | 0 <= ep < 10: gamma^0 569 | 10 <= ep < 15: gamma^1 570 | 15 <= ep < 20: gamma^2 571 | 20 <= ep < 30: gamma^0 ... repeat 3 cycles and then keep gamma^2''' 572 | max_step = max(step) 573 | effective_epoch = epoch % max_step 574 | if epoch // max_step >= repeat: 575 | exp = len(step) - 1 576 | else: 577 | exp = len([i for i in step if effective_epoch>=i]) 578 | return gamma ** exp 579 | 580 | 581 | if __name__ == '__main__': 582 | main() 583 | -------------------------------------------------------------------------------- /test/transform_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import sys 3 | 4 | sys.path.append('../utils') 5 | from augmentation import * 6 | 7 | sys.path.append('../train') 8 | import model_utils as mu 9 | 10 | 11 | def get_train_transforms(args): 12 | if args.modality == mu.ImgMode: 13 | return get_imgs_train_transforms(args) 14 | elif args.modality == mu.FlowMode: 15 | return get_flow_transforms(args) 16 | elif args.modality == mu.KeypointHeatmap: 17 | return get_heatmap_transforms(args) 18 | elif args.modality == mu.SegMask: 19 | return get_segmask_transforms(args) 20 | 21 | 22 | def get_val_transforms(args): 23 | if args.modality == mu.ImgMode: 24 | return get_imgs_val_transforms(args) 25 | elif args.modality == mu.FlowMode: 26 | return get_flow_transforms(args) 27 | elif args.modality == mu.KeypointHeatmap: 28 | return get_heatmap_transforms(args) 29 | elif args.modality == mu.SegMask: 30 | return get_segmask_transforms(args) 31 | 32 | 33 | def get_test_transforms(args): 34 | if args.modality == mu.ImgMode: 35 | return get_imgs_test_transforms(args) 36 | elif args.modality == mu.FlowMode: 37 | return get_flow_test_transforms(args) 38 | elif args.modality == mu.KeypointHeatmap: 39 | return get_heatmap_test_transforms(args) 40 | elif args.modality == mu.SegMask: 41 | return get_segmask_test_transforms(args) 42 | 43 | 44 | def get_imgs_test_transforms(args): 45 | 46 | transform = transforms.Compose([ 47 | RandomSizedCrop(consistent=True, size=224, p=0.0), 48 | Scale(size=(args.img_dim, args.img_dim)), 49 | ToTensor(), 50 | Normalize() 51 | ]) 52 | 53 | return transform 54 | 55 | 56 | def get_flow_test_transforms(args): 57 | center_crop_size = 224 58 | if args.dataset == 'kinetics': 59 | center_crop_size = 128 60 | 61 | transform = transforms.Compose([ 62 | CenterCrop(size=center_crop_size, consistent=True), 63 | Scale(size=(args.img_dim, args.img_dim)), 64 | ToTensor(), 65 | ]) 66 | 67 | return transform 68 | 69 | 70 | def get_heatmap_test_transforms(_): 71 | transform = transforms.Compose([ 72 | CenterCropForTensors(size=192), 73 | ScaleForTensors(size=(64, 64)), 74 | ]) 75 | return transform 76 | 77 | 78 | def get_segmask_test_transforms(_): 79 | transform = transforms.Compose([ 80 | CenterCropForTensors(size=192), 81 | ScaleForTensors(size=(64, 64)), 82 | ]) 83 | return transform 84 | 85 | 86 | def get_imgs_train_transforms(args): 87 | transform = None 88 | 89 | # designed for ucf101, short size=256, rand crop to 224x224 then scale to 128x128 90 | if args.dataset == 'ucf101': 91 | transform = transforms.Compose([ 92 | RandomSizedCrop(consistent=True, size=224, p=1.0), 93 | Scale(size=(args.img_dim, args.img_dim)), 94 | RandomHorizontalFlip(consistent=True), 95 | ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, 
hue=0.25, p=0.3, consistent=True), 96 | ToTensor(), 97 | Normalize() 98 | ]) 99 | elif (args.dataset == 'jhmdb') or (args.dataset == 'hmdb51'): 100 | transform = transforms.Compose([ 101 | RandomSizedCrop(consistent=True, size=224, p=1.0), 102 | Scale(size=(args.img_dim, args.img_dim)), 103 | RandomHorizontalFlip(consistent=True), 104 | ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.25, p=0.3, consistent=True), 105 | ToTensor(), 106 | Normalize() 107 | ]) 108 | # designed for kinetics400, short size=150, rand crop to 128x128 109 | elif args.dataset == 'kinetics': 110 | transform = transforms.Compose([ 111 | RandomSizedCrop(size=args.img_dim, consistent=True, p=1.0), 112 | RandomHorizontalFlip(consistent=True), 113 | ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.25, p=0.3, consistent=True), 114 | ToTensor(), 115 | Normalize() 116 | ]) 117 | 118 | return transform 119 | 120 | 121 | def get_imgs_val_transforms(args): 122 | transform = None 123 | 124 | # designed for ucf101, short size=256, rand crop to 224x224 then scale to 128x128 125 | if args.dataset == 'ucf101': 126 | transform = transforms.Compose([ 127 | RandomSizedCrop(consistent=True, size=224, p=0.3), 128 | Scale(size=(args.img_dim, args.img_dim)), 129 | RandomHorizontalFlip(consistent=True), 130 | ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, p=0.3, consistent=True), 131 | ToTensor(), 132 | Normalize() 133 | ]) 134 | elif (args.dataset == 'jhmdb') or (args.dataset == 'hmdb51'): 135 | transform = transforms.Compose([ 136 | RandomSizedCrop(consistent=True, size=224, p=0.3), 137 | Scale(size=(args.img_dim, args.img_dim)), 138 | RandomHorizontalFlip(consistent=True), 139 | ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, p=0.3, consistent=True), 140 | ToTensor(), 141 | Normalize() 142 | ]) 143 | # designed for kinetics400, short size=150, rand crop to 128x128 144 | elif args.dataset == 'kinetics': 145 | transform = transforms.Compose([ 146 | RandomSizedCrop(consistent=True, size=224, p=0.3), 147 | RandomHorizontalFlip(consistent=True), 148 | ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, p=0.3, consistent=True), 149 | ToTensor(), 150 | Normalize() 151 | ]) 152 | 153 | return transform 154 | 155 | 156 | def get_flow_transforms(args): 157 | transform = None 158 | 159 | # designed for ucf101, short size=256, rand crop to 224x224 then scale to 128x128 160 | if (args.dataset == 'ucf101') or (args.dataset == 'jhmdb') or (args.dataset == 'hmdb51'): 161 | transform = transforms.Compose([ 162 | RandomIntensityCropForFlow(size=224), 163 | Scale(size=(args.img_dim, args.img_dim)), 164 | ToTensor(), 165 | ]) 166 | # designed for kinetics400, short size=150, rand crop to 128x128 167 | elif args.dataset == 'kinetics': 168 | transform = transforms.Compose([ 169 | RandomIntensityCropForFlow(size=args.img_dim), 170 | ToTensor(), 171 | ]) 172 | 173 | return transform 174 | 175 | 176 | def get_heatmap_transforms(_): 177 | crop_size = int(192 * 0.8) 178 | transform = transforms.Compose([ 179 | RandomIntensityCropForTensors(size=crop_size), 180 | ScaleForTensors(size=(64, 64)), 181 | ]) 182 | return transform 183 | 184 | 185 | def get_segmask_transforms(_): 186 | crop_size = int(192 * 0.8) 187 | transform = transforms.Compose([ 188 | RandomIntensityCropForTensors(size=crop_size), 189 | ScaleForTensors(size=(64, 64)), 190 | ]) 191 | return transform 192 | -------------------------------------------------------------------------------- /train/data_utils.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def individual_collate(batch): 5 | """ 6 | Custom collate function for a batch of per-sample dicts: gathers the values for each key across the batch and stacks them into tensors. 7 | """ 8 | 9 | data = batch 10 | 11 | # Assuming there's at least one instance in the batch 12 | add_data_keys = data[0].keys() 13 | collected_data = {k: [] for k in add_data_keys} 14 | 15 | for i in range(len(data)): 16 | for k in add_data_keys: 17 | collected_data[k].append(data[i][k]) 18 | 19 | for k in add_data_keys: 20 | collected_data[k] = torch.stack(collected_data[k]) 21 | 22 | return collected_data 23 | -------------------------------------------------------------------------------- /train/dataset_3d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils import data 3 | from torchvision import transforms 4 | import os 5 | import sys 6 | import time 7 | import pickle 8 | import glob 9 | import csv 10 | import scipy.io 11 | import pandas as pd 12 | import numpy as np 13 | import cv2 14 | import random 15 | 16 | import model_utils as mu 17 | 18 | sys.path.append('../utils') 19 | 20 | from copy import deepcopy 21 | from augmentation import * 22 | from tqdm import tqdm 23 | from joblib import Parallel, delayed 24 | 25 | 26 | def pil_loader(path): 27 | img = Image.open(path) 28 | return img.convert('RGB') 29 | 30 | 31 | toTensor = transforms.ToTensor() 32 | toPILImage = transforms.ToPILImage() 33 | def flow_loader(path): 34 | try: 35 | img = Image.open(path) 36 | except: 37 | return None 38 | return toTensor(img) 39 | 40 | 41 | class BaseDataloader(data.Dataset): 42 | 43 | def __init__( 44 | self, 45 | mode, 46 | transform, 47 | seq_len, 48 | num_seq, 49 | downsample, 50 | which_split, 51 | vals_to_return, 52 | sampling_method, 53 | dataset, 54 | debug=False 55 | ): 56 | super(BaseDataloader, self).__init__() 57 | 58 | self.dataset = dataset 59 | self.mode = mode 60 | self.debug = debug 61 | self.transform = transform 62 | self.seq_len = seq_len 63 | self.num_seq = num_seq 64 | self.downsample = downsample 65 | self.which_split = which_split 66 | # Describes which particular items to return e.g.
["imgs", "poses", "labels"] 67 | self.vals_to_return = set(vals_to_return) 68 | self.sampling_method = sampling_method 69 | self.num_classes = mu.get_num_classes(self.dataset) 70 | 71 | assert not ((self.dataset == "hmdb51") and ("poses" in self.vals_to_return)), \ 72 | "HMDB51 does not support poses yet" 73 | 74 | assert not ((self.dataset == "jhmdb") and ("flow" in self.vals_to_return)), \ 75 | "JHMDB does not support flow yet" 76 | 77 | if self.sampling_method == "random": 78 | assert "imgs" not in self.vals_to_return, \ 79 | "Invalid sampling method provided for imgs: {}".format(self.sampling_method) 80 | 81 | # splits 82 | mode_str = "test" if ((mode == 'val') or (mode == 'test')) else mode 83 | mode_split_str = '/' + mode_str + '_split%02d.csv' % self.which_split 84 | 85 | if "kinetics400" in dataset: 86 | mode_str = "val" if ((mode == 'val') or (mode == 'test')) else mode 87 | mode_split_str = '/' + mode_str + '_split.csv' 88 | 89 | split = '../process_data/data/' + self.dataset + mode_split_str 90 | video_info = pd.read_csv(split, header=None) 91 | 92 | # poses_mat_dict: vpath to poses_mat 93 | self.poses_dict = {} 94 | 95 | # get action list 96 | self.action_dict_encode = {} 97 | self.action_dict_decode = {} 98 | 99 | action_file = os.path.join('../process_data/data/' + self.dataset, 'classInd.txt') 100 | action_df = pd.read_csv(action_file, sep=' ', header=None) 101 | for _, row in action_df.iterrows(): 102 | act_id, act_name = row 103 | act_id = int(act_id) - 1 # let id start from 0 104 | assert 0 <= act_id < self.num_classes, "Incorrect class_id: {}".format(act_id) 105 | self.action_dict_decode[act_id] = act_name 106 | self.action_dict_encode[act_name] = act_id 107 | 108 | drop_idx = [] 109 | 110 | # filter out too short videos: 111 | for idx, row in tqdm(video_info.iterrows(), total=len(video_info)): 112 | vpath, vlen = row 113 | if self.sampling_method == 'disjoint': 114 | if vlen-self.num_seq*self.seq_len*self.downsample <= 0: 115 | drop_idx.append(idx) 116 | else: 117 | if vlen <= 0: 118 | drop_idx.append(idx) 119 | 120 | self.video_info = video_info.drop(drop_idx, axis=0) 121 | 122 | if self.debug: 123 | self.video_info = self.video_info.sample(frac=0.0025, random_state=42) 124 | elif self.mode == 'val': 125 | self.video_info = self.video_info.sample(frac=0.3) 126 | # self.video_info = self.video_info.head(int(0.3 * len(self.video_info))) 127 | 128 | self.idx_sampler = None 129 | if self.sampling_method == "dynamic": 130 | self.idx_sampler = self.idx_sampler_dynamic 131 | if self.sampling_method == "disjoint": 132 | self.idx_sampler = self.idx_sampler_disjoint 133 | elif self.sampling_method == "random": 134 | self.idx_sampler = self.idx_sampler_random 135 | 136 | if self.mode == 'test': 137 | self.idx_sampler = self.idx_sampler_test 138 | 139 | if mu.FlowMode in self.vals_to_return: 140 | self.setup_flow_modality() 141 | 142 | # shuffle not required due to external sampler 143 | 144 | def setup_flow_modality(self): 145 | '''Can be overriden in the derived classes''' 146 | vpath, _ = self.video_info.iloc[0] 147 | vpath = vpath.rstrip('/') 148 | base_dir = vpath.split(self.dataset)[0] 149 | print("Base dir for flow:", base_dir) 150 | self.flow_base_path = os.path.join(base_dir, 'flow', self.dataset + '_flow/') 151 | 152 | def idx_sampler_test(self, seq_len, num_seq, vlen, vpath): 153 | ''' 154 | sample index uniformly from a video 155 | ''' 156 | 157 | downsample = self.downsample 158 | if (vlen - (num_seq * seq_len * self.downsample)) <= 0: 159 | downsample = ((vlen - 1) / 
(num_seq * seq_len * 1.0)) * 0.9 160 | 161 | seq_idx = np.expand_dims(np.arange(num_seq), -1) * downsample * seq_len 162 | seq_idx_block = seq_idx + np.expand_dims(np.arange(seq_len), 0) * downsample 163 | seq_idx_block = seq_idx_block.astype(int) 164 | 165 | return [seq_idx_block, vpath] 166 | 167 | def idx_sampler_dynamic(self, seq_len, num_seq, vlen, vpath): 168 | '''sample index from a video''' 169 | downsample = self.downsample 170 | if (vlen - (num_seq * seq_len * self.downsample)) <= 0: 171 | downsample = ((vlen - 1) / (num_seq * seq_len * 1.0)) * 0.9 172 | 173 | n = 1 174 | try: 175 | start_idx = np.random.choice(range(vlen - int(num_seq * seq_len * downsample)), n) 176 | except: 177 | print("Error!", vpath, vlen, num_seq, seq_len, downsample, n) 178 | 179 | seq_idx = np.expand_dims(np.arange(num_seq), -1) * downsample * seq_len + start_idx 180 | seq_idx_block = seq_idx + np.expand_dims(np.arange(seq_len), 0) * downsample 181 | seq_idx_block = seq_idx_block.astype(int) 182 | 183 | return [seq_idx_block, vpath] 184 | 185 | def idx_sampler_disjoint(self, seq_len, num_seq, vlen, vpath): 186 | '''sample index from a video''' 187 | 188 | if (vlen - (num_seq * seq_len * self.downsample)) <= 0: 189 | return None 190 | 191 | n = 1 192 | if self.mode == 'test': 193 | seq_idx_block = np.arange(0, vlen, self.downsample) # all possible frames with downsampling 194 | return [seq_idx_block, vpath] 195 | 196 | start_idx = np.random.choice(range(vlen - (num_seq * seq_len * self.downsample)), n) 197 | seq_idx = np.expand_dims(np.arange(num_seq), -1) * self.downsample * seq_len + start_idx 198 | # Shape num_seq x seq_len 199 | seq_idx_block = seq_idx + np.expand_dims(np.arange(seq_len), 0) * self.downsample 200 | 201 | return [seq_idx_block, vpath] 202 | 203 | def idx_sampler_random(self, seq_len, num_seq, vlen, vpath): 204 | '''sample index from a video''' 205 | 206 | # Here we compute the max downsampling we could perform 207 | max_ds = ((vlen - 1) // seq_len) 208 | 209 | if max_ds <= 0: 210 | return None 211 | 212 | if self.mode == 'test': 213 | seq_idx_block = np.arange(0, vlen, self.downsample) 214 | # all possible frames with downsampling 215 | return [seq_idx_block, vpath] 216 | 217 | seq_idx_block = [] 218 | for i in range(num_seq): 219 | rand_ds = random.randint(1, max_ds) 220 | start_idx = random.randint(0, vlen - (seq_len * rand_ds) - 1) 221 | seq_idx = np.arange(start=start_idx, stop=(start_idx + (seq_len*rand_ds)), step=rand_ds) 222 | seq_idx_block.append(seq_idx) 223 | 224 | seq_idx_block = np.array(seq_idx_block) 225 | 226 | return [seq_idx_block, vpath] 227 | 228 | def fetch_imgs_seq(self, vpath, seq_len, idx_block): 229 | '''Can be overriden in the derived classes''' 230 | img_list = [os.path.join(vpath, 'image_%05d.jpg' % (i + 1)) for i in idx_block] 231 | seq = [pil_loader(f) for f in img_list] 232 | img_t_seq = self.transform["imgs"](seq) # apply same transform 233 | (IC, IH, IW) = img_t_seq[0].size() 234 | img_t_seq = torch.stack(img_t_seq, 0) 235 | img_t_seq = img_t_seq.view(self.num_seq, seq_len, IC, IH, IW).transpose(1, 2) 236 | return img_t_seq 237 | 238 | @staticmethod 239 | def fill_nones(l): 240 | l = [l[i - 1] if l[i] is None else l[i] for i in range(len(l))] 241 | l = [l[i - 1] if l[i] is None else l[i] for i in range(len(l))] 242 | try: 243 | nonNoneL = [item for item in l if item is not None][0] 244 | except: 245 | nonNoneL = torch.zeros((1, 256, 256)) 246 | return [torch.zeros(nonNoneL.shape) if l[i] is None else l[i] for i in range(len(l))] 247 | 248 | def 
get_u_flow_path_list(self, vpath, idx_block): 249 | vid_name = os.path.basename(os.path.normpath(vpath)) 250 | return [os.path.join(self.flow_base_path, 'u', vid_name, 'frame%06d.jpg' % (i + 1)) for i in idx_block] 251 | 252 | def get_v_flow_path_list(self, vpath, idx_block): 253 | vid_name = os.path.basename(os.path.normpath(vpath)) 254 | return [os.path.join(self.flow_base_path, 'v', vid_name, 'frame%06d.jpg' % (i + 1)) for i in idx_block] 255 | 256 | def fetch_flow_seq(self, vpath, seq_len, idx_block): 257 | ''' 258 | Can be overriden in the derived classes 259 | - TODO: implement and experiment with stack flow, later on 260 | ''' 261 | 262 | u_flow_list = self.get_u_flow_path_list(vpath, idx_block) 263 | v_flow_list = self.get_v_flow_path_list(vpath, idx_block) 264 | 265 | u_seq = self.fill_nones([flow_loader(f) for f in u_flow_list]) 266 | v_seq = self.fill_nones([flow_loader(f) for f in v_flow_list]) 267 | 268 | seq = [toPILImage(torch.cat([u, v])) for u, v in zip(u_seq, v_seq)] 269 | flow_t_seq = self.transform["flow"](seq) 270 | 271 | (FC, FH, FW) = flow_t_seq[0].size() 272 | flow_t_seq = torch.stack(flow_t_seq, 0) 273 | flow_t_seq = flow_t_seq.view(self.num_seq, seq_len, FC, FH, FW).transpose(1, 2) 274 | 275 | if flow_t_seq.mean() > 0.3: 276 | flow_t_seq -= 0.5 277 | 278 | return flow_t_seq 279 | 280 | def fetch_fnb_flow_seq(self, vpath, seq_len, idx_block): 281 | pass 282 | 283 | def get_class_vid(self, vpath): 284 | return os.path.normpath(vpath).split('/')[-2:] 285 | 286 | def load_detectron_feature(self, fdir, idx, opt): 287 | # opt is either hm or seg 288 | 289 | shape = (192, 256) 290 | 291 | def load_feature(path): 292 | try: 293 | x = np.load(path)[opt] 294 | except: 295 | x = np.zeros((0, 0, 0)) 296 | 297 | # Match non-existent values 298 | if x.shape[1] == 0: 299 | num_channels = 17 if opt == 'hm' else 1 300 | x = np.zeros((num_channels, shape[0], shape[1])) 301 | 302 | x = torch.tensor(x, dtype=torch.float) / 255.0 303 | 304 | # Add extra channel in case it's not present 305 | if len(x.shape) < 3: 306 | x = x.unsqueeze(0) 307 | return x 308 | 309 | suffix = 'heatmap' if opt == 'hm' else 'segmask' 310 | fpath = os.path.join(fdir, suffix + '_%05d.npz' % idx) 311 | if os.path.isfile(fpath): 312 | return load_feature(fpath) 313 | else: 314 | # We do not have results lower than idx=2 315 | idx = max(3, idx) 316 | # We assume having all results for every two frames 317 | fpath0 = os.path.join(fdir, suffix + '_%05d.npz' % (idx - 1)) 318 | fpath1 = os.path.join(fdir, suffix + '_%05d.npz' % (idx + 1)) 319 | # This is not guaranteed to exist 320 | if not os.path.isfile(fpath1): 321 | fpath1 = fpath0 322 | a0, a1 = load_feature(fpath0), load_feature(fpath1) 323 | try: 324 | a_avg = (a0 + a1) / 2.0 325 | except: 326 | a_avg = None 327 | return a_avg 328 | 329 | def fetch_kp_heatmap_seq(self, vpath, seq_len, idx_block): 330 | assert '/frame/' in vpath, "Incorrect vpath received: {}".format(vpath) 331 | 332 | feature_vpath = vpath.replace('/frame/', '/heatmaps/') 333 | seq = self.fill_nones([self.load_detectron_feature(feature_vpath, idx, opt='hm') for idx in idx_block]) 334 | 335 | if len(set([x.shape for x in seq])) > 1: 336 | # We now know the invalid paths, so no need to print them 337 | # print("Invalid path:", vpath) 338 | seq = [seq[len(seq) // 2] for _ in seq] 339 | 340 | hm_t_seq = self.transform[mu.KeypointHeatmap](seq) # apply same transform 341 | (IC, IH, IW) = hm_t_seq[0].size() 342 | 343 | hm_t_seq = hm_t_seq.view(self.num_seq, seq_len, IC, IH, IW).transpose(1, 2) 344 | 
return hm_t_seq 345 | 346 | def fetch_seg_mask_seq(self, vpath, seq_len, idx_block): 347 | assert '/frame/' in vpath, "Incorrect vpath received: {}".format(vpath) 348 | 349 | feature_vpath = vpath.replace('/frame/', '/segmasks/') 350 | seq = self.fill_nones([self.load_detectron_feature(feature_vpath, idx, opt='seg') for idx in idx_block]) 351 | 352 | seg_t_seq = self.transform[mu.SegMask](seq) # apply same transform 353 | (IC, IH, IW) = seg_t_seq[0].size() 354 | 355 | seg_t_seq = seg_t_seq.view(self.num_seq, seq_len, IC, IH, IW).transpose(1, 2) 356 | return seg_t_seq 357 | 358 | def __getitem__(self, index): 359 | vpath, vlen = self.video_info.iloc[index] 360 | # Remove trailing slash if any 361 | vpath = vpath.rstrip('/') 362 | 363 | seq_len = self.seq_len 364 | if "tgt" in self.vals_to_return: 365 | seq_len = 2 * self.seq_len 366 | 367 | items = self.idx_sampler(seq_len, self.num_seq, vlen, vpath) 368 | if items is None: 369 | raise ValueError("Could not sample frame indices for video: {}".format(vpath)) 370 | 371 | idx_block, vpath = items 372 | assert idx_block.shape == (self.num_seq, seq_len) 373 | idx_block = idx_block.reshape(self.num_seq * seq_len) 374 | 375 | vals = {} 376 | 377 | # Populate return list 378 | if mu.ImgMode in self.vals_to_return: 379 | img_t_seq = self.fetch_imgs_seq(vpath, seq_len, idx_block) 380 | vals[mu.ImgMode] = img_t_seq 381 | if mu.FlowMode in self.vals_to_return: 382 | flow_t_seq = self.fetch_flow_seq(vpath, seq_len, idx_block) 383 | vals[mu.FlowMode] = flow_t_seq 384 | if mu.FnbFlowMode in self.vals_to_return: 385 | fnb_flow_t_seq = self.fetch_fnb_flow_seq(vpath, seq_len, idx_block) 386 | vals[mu.FnbFlowMode] = fnb_flow_t_seq 387 | if mu.KeypointHeatmap in self.vals_to_return: 388 | hm_t_seq = self.fetch_kp_heatmap_seq(vpath, seq_len, idx_block) 389 | vals[mu.KeypointHeatmap] = hm_t_seq 390 | if mu.SegMask in self.vals_to_return: 391 | seg_t_seq = self.fetch_seg_mask_seq(vpath, seq_len, idx_block) 392 | vals[mu.SegMask] = seg_t_seq 393 | 394 | # Process double length target results 395 | if "tgt" in self.vals_to_return: 396 | orig_keys = list(vals.keys()) 397 | for k in orig_keys: 398 | full_x = vals[k] 399 | vals[k] = full_x[:, :self.seq_len, ...] 400 | vals["tgt_" + k] = full_x[:, self.seq_len:, ...]
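# The remaining fields attach per-sample metadata: the action label (when "labels" is requested) and the video index.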
401 | if "labels" in self.vals_to_return: 402 | try: 403 | vname = vpath.split('/')[-3] 404 | vid = self.encode_action(vname) 405 | except: 406 | vname = vpath.split('/')[-2] 407 | vid = self.encode_action(vname) 408 | label = torch.LongTensor([vid]) 409 | vals["labels"] = label 410 | 411 | # Add video index field 412 | vals["vnames"] = torch.LongTensor([index]) 413 | 414 | return vals 415 | 416 | def __len__(self): 417 | return len(self.video_info) 418 | 419 | def encode_action(self, action_name): 420 | '''give action name, return category''' 421 | return self.action_dict_encode[action_name] 422 | 423 | def decode_action(self, action_code): 424 | '''give action code, return action name''' 425 | return self.action_dict_decode[action_code] 426 | 427 | 428 | class Kinetics_3d(BaseDataloader): 429 | 430 | def setup_flow_modality(self): 431 | '''Can be overriden in the derived classes''' 432 | self.flow_base_path = '/data/nishantr/kinetics/fnb_frames/' 433 | 434 | def get_u_flow_path_list(self, vpath, idx_block): 435 | v_class, v_name = self.get_class_vid(vpath) 436 | return [os.path.join(self.flow_base_path, self.mode, v_class, v_name, 'flow_x_%05d.jpg' % (i + 1)) for i in idx_block] 437 | 438 | def get_v_flow_path_list(self, vpath, idx_block): 439 | v_class, v_name = self.get_class_vid(vpath) 440 | return [os.path.join(self.flow_base_path, self.mode, v_class, v_name, 'flow_y_%05d.jpg' % (i + 1)) for i in idx_block] 441 | 442 | def __init__( 443 | self, 444 | mode='train', 445 | transform=None, 446 | seq_len=5, 447 | num_seq=6, 448 | downsample=3, 449 | which_split=1, 450 | vals_to_return=["imgs"], 451 | sampling_method="dynamic", 452 | use_big=False, 453 | ): 454 | dataset = "kinetics400" 455 | if use_big: 456 | dataset += "_256" 457 | super(Kinetics_3d, self).__init__( 458 | mode, 459 | transform, 460 | seq_len, 461 | num_seq, 462 | downsample, 463 | which_split, 464 | vals_to_return, 465 | sampling_method, 466 | dataset=dataset 467 | ) 468 | 469 | self.vid_shapes = {} 470 | 471 | if mu.FlowMode in self.vals_to_return: 472 | self.setup_flow_modality() 473 | 474 | def get_vid_shape(self, vpath, idx_block): 475 | v_class, v_name = self.get_class_vid(vpath) 476 | 477 | if (v_class, v_name) not in self.vid_shapes: 478 | img_list = [os.path.join(vpath, 'image_%05d.jpg' % (i + 1)) for i in idx_block[0:1]] 479 | seq = [pil_loader(f) for f in img_list] 480 | self.vid_shapes[(v_class, v_name)] = seq[0].size 481 | 482 | return self.vid_shapes[(v_class, v_name)] 483 | 484 | def fetch_flow_seq(self, vpath, seq_len, idx_block): 485 | ''' 486 | Can be overriden in the derived classes 487 | ''' 488 | 489 | shape = self.get_vid_shape(vpath, idx_block) 490 | 491 | def reshape_flow(img): 492 | new_img = img.resize((shape[0], shape[1])) 493 | assert new_img.size == shape, "Shape mismatch: {}, {}".format(new_img.shape, shape) 494 | return new_img 495 | 496 | def fill_nones(l): 497 | if l[0] is None: 498 | l[0] = torch.zeros((1, 128, 128)) 499 | for i in range(1, len(l)): 500 | if l[i] is None: 501 | l[i] = l[i-1] 502 | return l 503 | 504 | u_flow_list = self.get_u_flow_path_list(vpath, idx_block) 505 | v_flow_list = self.get_v_flow_path_list(vpath, idx_block) 506 | 507 | u_seq = fill_nones([flow_loader(f) for f in u_flow_list]) 508 | v_seq = fill_nones([flow_loader(f) for f in v_flow_list]) 509 | 510 | seq = [reshape_flow(toPILImage(torch.cat([u, v]))) for u, v in zip(u_seq, v_seq)] 511 | flow_t_seq = self.transform["flow"](seq) 512 | 513 | (FC, FH, FW) = flow_t_seq[0].size() 514 | flow_t_seq = 
torch.stack(flow_t_seq, 0) 515 | flow_t_seq = flow_t_seq.view(self.num_seq, seq_len, FC, FH, FW).transpose(1, 2) 516 | 517 | # Subract the mean to get interpretable optical flow 518 | if flow_t_seq.mean() > 0.3: 519 | flow_t_seq -= 0.5 520 | 521 | return flow_t_seq 522 | 523 | 524 | class UCF101_3d(BaseDataloader): 525 | 526 | def __init__( 527 | self, 528 | mode='train', 529 | transform=None, 530 | seq_len=5, 531 | num_seq=6, 532 | downsample=3, 533 | which_split=1, 534 | vals_to_return=["imgs"], 535 | sampling_method="dynamic", 536 | debug=False, 537 | ): 538 | super(UCF101_3d, self).__init__( 539 | mode, 540 | transform, 541 | seq_len, 542 | num_seq, 543 | downsample, 544 | which_split, 545 | vals_to_return, 546 | sampling_method, 547 | dataset="ucf101", 548 | debug=debug 549 | ) 550 | 551 | self.vid_shapes = {} 552 | self.fnb_flow_base_path = '/data/nishantr/ucf101/fnb_frames/' 553 | 554 | def get_fnb_u_flow_path_list(self, vpath, idx_block): 555 | v_class, v_name = self.get_class_vid(vpath) 556 | return [os.path.join(self.fnb_flow_base_path, v_class, v_name, 'flow_x_%05d.jpg' % (i + 1)) for i in idx_block] 557 | 558 | def get_fnb_v_flow_path_list(self, vpath, idx_block): 559 | v_class, v_name = self.get_class_vid(vpath) 560 | return [os.path.join(self.fnb_flow_base_path, v_class, v_name, 'flow_y_%05d.jpg' % (i + 1)) for i in idx_block] 561 | 562 | def get_vid_shape(self, vpath, idx_block): 563 | v_class, v_name = self.get_class_vid(vpath) 564 | 565 | if (v_class, v_name) not in self.vid_shapes: 566 | img_list = [os.path.join(vpath, 'image_%05d.jpg' % (i + 1)) for i in idx_block[0:1]] 567 | seq = [pil_loader(f) for f in img_list] 568 | self.vid_shapes[(v_class, v_name)] = seq[0].size 569 | 570 | return self.vid_shapes[(v_class, v_name)] 571 | 572 | def fetch_fnb_flow_seq(self, vpath, seq_len, idx_block): 573 | shape = self.get_vid_shape(vpath, idx_block) 574 | 575 | def reshape_flow(img): 576 | new_img = img.resize((shape[0], shape[1])) 577 | assert new_img.size == shape, "Shape mismatch: {}, {}".format(new_img.shape, shape) 578 | return new_img 579 | 580 | def fill_nones(l): 581 | if l[0] is None: 582 | l[0] = torch.zeros((1, 128, 128)) 583 | for i in range(1, len(l)): 584 | if l[i] is None: 585 | l[i] = l[i-1] 586 | return l 587 | 588 | u_flow_list = self.get_fnb_u_flow_path_list(vpath, idx_block) 589 | v_flow_list = self.get_fnb_v_flow_path_list(vpath, idx_block) 590 | 591 | u_seq = fill_nones([flow_loader(f) for f in u_flow_list]) 592 | v_seq = fill_nones([flow_loader(f) for f in v_flow_list]) 593 | 594 | seq = [reshape_flow(toPILImage(torch.cat([u, v]))) for u, v in zip(u_seq, v_seq)] 595 | flow_t_seq = self.transform["flow"](seq) 596 | 597 | (FC, FH, FW) = flow_t_seq[0].size() 598 | flow_t_seq = torch.stack(flow_t_seq, 0) 599 | flow_t_seq = flow_t_seq.view(self.num_seq, seq_len, FC, FH, FW).transpose(1, 2) 600 | 601 | # Subract the mean to get interpretable optical flow 602 | if flow_t_seq.mean() > 0.3: 603 | flow_t_seq -= 0.5 604 | 605 | return flow_t_seq 606 | 607 | 608 | class BaseDataloaderHMDB(BaseDataloader): 609 | 610 | def __init__( 611 | self, 612 | mode, 613 | transform, 614 | seq_len, 615 | num_seq, 616 | downsample, 617 | which_split, 618 | vals_to_return, 619 | sampling_method, 620 | dataset 621 | ): 622 | super(BaseDataloaderHMDB, self).__init__( 623 | mode, 624 | transform, 625 | seq_len, 626 | num_seq, 627 | downsample, 628 | which_split, 629 | vals_to_return, 630 | sampling_method, 631 | dataset=dataset 632 | ) 633 | 634 | 635 | class 
HMDB51_3d(BaseDataloaderHMDB): 636 | def __init__( 637 | self, 638 | mode='train', 639 | transform=None, 640 | seq_len=5, 641 | num_seq=6, 642 | downsample=1, 643 | which_split=1, 644 | vals_to_return=["imgs"], 645 | sampling_method="dynamic" 646 | ): 647 | super(HMDB51_3d, self).__init__( 648 | mode, 649 | transform, 650 | seq_len, 651 | num_seq, 652 | downsample, 653 | which_split, 654 | vals_to_return, 655 | sampling_method, 656 | dataset="hmdb51" 657 | ) 658 | 659 | 660 | class JHMDB_3d(BaseDataloaderHMDB): 661 | def __init__( 662 | self, 663 | mode='train', 664 | transform=None, 665 | seq_len=5, 666 | num_seq=6, 667 | downsample=1, 668 | which_split=1, 669 | vals_to_return=["imgs"], 670 | sampling_method="dynamic" 671 | ): 672 | super(JHMDB_3d, self).__init__( 673 | mode, 674 | transform, 675 | seq_len, 676 | num_seq, 677 | downsample, 678 | which_split, 679 | vals_to_return, 680 | sampling_method, 681 | dataset="jhmdb" 682 | ) -------------------------------------------------------------------------------- /train/finetune_utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | from sklearn import metrics 4 | from sklearn.linear_model import RidgeClassifier 5 | from sklearn.cluster import MiniBatchKMeans 6 | 7 | 8 | class QuickSupervisedModelTrainer(object): 9 | 10 | def __init__(self, num_classes, modes): 11 | self.modes = modes 12 | self.mode_pairs = [(m0, m1) for m0 in self.modes for m1 in self.modes if m0 < m1] 13 | self.ridge = {m: RidgeClassifier() for m in self.modes} 14 | self.kmeans = {m: MiniBatchKMeans(n_clusters=num_classes, random_state=0, batch_size=256) for m in self.modes} 15 | 16 | def evaluate_classification(self, trainD, valD): 17 | tic = time.time() 18 | trainY, valY = trainD["Y"].cpu().numpy(), valD["Y"].cpu().numpy() 19 | for mode in self.modes: 20 | self.ridge[mode].fit(trainD["X"][mode].cpu().numpy(), trainY) 21 | score = round(self.ridge[mode].score(valD["X"][mode].cpu().numpy(), valY), 3) 22 | print("--- Mode: {} - RidgeAcc: {}".format(mode, score)) 23 | print("Time taken to perform classification evaluation:", time.time() - tic) 24 | 25 | def fit_and_predict_clustering(self, data, tag): 26 | tic = time.time() 27 | preds = {} 28 | for mode in self.modes: 29 | preds[mode] = self.kmeans[mode].fit_predict(data["X"][mode].cpu().numpy()) 30 | print("Time taken to perform {} clustering:".format(tag), time.time() - tic) 31 | return preds 32 | 33 | def evaluate_clustering_based_on_ground_truth(self, preds, label, tag): 34 | tic = time.time() 35 | for mode in self.modes: 36 | ars = round(metrics.adjusted_rand_score(preds[mode], label), 3) 37 | v_measure = round(metrics.v_measure_score(preds[mode], label), 3) 38 | print("--- Mode: {} - Adj Rand. Score: {}, V-Measure: {}".format(mode, ars, v_measure)) 39 | print("Time taken to evaluate {} clustering:".format(tag), time.time() - tic) 40 | 41 | def evaluate_clustering_based_on_mutual_information(self, preds, tag): 42 | tic = time.time() 43 | for m0, m1 in self.mode_pairs: 44 | ami = round(metrics.adjusted_mutual_info_score(preds[m0], preds[m1], average_method='max'), 3) 45 | v_measure = round(metrics.v_measure_score(preds[m0], preds[m1]), 3) 46 | print("--- Modes: {}/{} - Adj MI: {}, V Measure: {}".format(m0, m1, ami, v_measure)) 47 | print("Time taken to evaluate {} clustering MI:".format(tag), time.time() - tic) 48 | 49 | def evaluate_clustering(self, data, tag): 50 | ''' 51 | Need to evaluate clustering using the following methods, 52 | 1. 
Correctness of clustering based on ground truth labels 53 | a. Adjusted Rand Score 54 | b. Homogeneity, completeness and V-measure 55 | 2. Mutual information based scores (across modalities) 56 | ''' 57 | label = data["Y"].cpu().numpy() 58 | preds = self.fit_and_predict_clustering(data, tag) 59 | self.evaluate_clustering_based_on_ground_truth(preds, label, tag) 60 | self.evaluate_clustering_based_on_mutual_information(preds, tag) 61 | -------------------------------------------------------------------------------- /train/mask_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def get_standard_grid_mask(batch_size0, batch_size1, pred_step, last_size, device="cuda"): 5 | B0, B1, N, LS = batch_size0, batch_size1, pred_step, last_size 6 | device = torch.device(device) 7 | 8 | assert B0 <= B1, "Invalid B0, B1: {} {}".format(B0, B1) 9 | 10 | # mask meaning: -2: omit, -1: temporal neg (hard), 0: easy neg, 1: pos, -3: spatial neg 11 | mask = torch.zeros((B0, N, LS ** 2, B1, N, LS ** 2), dtype=torch.int8, requires_grad=False).detach().to(device) 12 | # spatial neg pairs 13 | mask[torch.arange(B0), :, :, torch.arange(B0), :, :] = -3 14 | # temporal neg pairs 15 | for k in range(B0): 16 | mask[k, :, torch.arange(LS ** 2), k, :, torch.arange(LS ** 2)] = -1 17 | tmp = mask.permute(0, 2, 1, 3, 5, 4).contiguous().view(B0 * LS ** 2, N, B1 * LS ** 2, N) 18 | # positive pairs 19 | for j in range(B0 * LS ** 2): 20 | tmp[j, torch.arange(N), j, torch.arange(N - N, N)] = 1 21 | mask = tmp.view(B0, LS ** 2, N, B1, LS ** 2, N).permute(0, 2, 1, 3, 5, 4) 22 | # Final shape: (B, N, LS**2, B, N, LS**2) 23 | assert torch.allclose(mask[:, :, :, B0:, :, :], torch.tensor(0, dtype=torch.int8)), "Invalid values" 24 | 25 | return mask 26 | 27 | 28 | def get_multi_modal_grid_mask(batch_size0, batch_size1, pred_step, last_size0, last_size1, device="cuda"): 29 | B0, B1, N, LS0, LS1 = batch_size0, batch_size1, pred_step, last_size0, last_size1 30 | device = torch.device(device) 31 | 32 | assert B0 <= B1, "Invalid B0, B1: {} {}".format(B0, B1) 33 | 34 | # mask meaning: -2: omit, -1: temporal neg (hard), 0: easy neg, 1: pos, -3: spatial neg 35 | mask = torch.zeros((B0, N, LS0 ** 2, B1, N, LS1 ** 2), dtype=torch.int8, requires_grad=False).detach().to(device) 36 | # spatial neg pairs 37 | mask[torch.arange(B0), :, :, torch.arange(B0), :, :] = -3 38 | 39 | # temporal neg pairs 40 | for k in range(B0): 41 | mask[k, :, torch.arange(LS0 ** 2), k, :, torch.arange(LS1 ** 2)] = -1 42 | tmp = mask.permute(0, 2, 1, 3, 5, 4).contiguous().view(B0, LS0, LS0, N, B1, LS1, LS1, N) 43 | # shape: (B, LS0, LS0, N, B, LS1, LS1, N) 44 | 45 | # Generate downsamplings 46 | ds0, ds1 = LS0 // min(LS0, LS1), LS1 // min(LS0, LS1) 47 | 48 | # positive pairs 49 | for j in range(B0): 50 | for i in range(min(LS0, LS1)): 51 | tmp[j, i * ds0:(i + 1) * ds0, i * ds0:(i + 1) * ds0, torch.arange(N), 52 | j, i * ds1:(i + 1) * ds1, i * ds1:(i + 1) * ds1, torch.arange(N)] = 1 53 | 54 | # Sanity check 55 | for ib in range(B0): 56 | for jn in range(N): 57 | for jls0 in range(LS0): 58 | for jls1 in range(LS1): 59 | for jls01 in range(LS0): 60 | for jls11 in range(LS1): 61 | # Check that values match 62 | if (jls0 // ds0) == (jls1 // ds1) == (jls01 // ds0) == (jls11 // ds1): 63 | assert tmp[ib, jls0, jls01, jn, ib, jls1, jls11, jn] == 1, \ 64 | "Invalid value at {}".format((ib, jls0, jls01, jn, ib, jls1, jls11, jn)) 65 | else: 66 | assert tmp[ib, jls0, jls01, jn, ib, jls1, jls11, jn] < 1, \ 67 | "Invalid 
value at {}".format((ib, jls0, jls01, jn, ib, jls1, jls11, jn)) 68 | assert torch.allclose(tmp[:, :, :, :, B0:, :, :, :], torch.tensor(0, dtype=torch.int8)), "Invalid values" 69 | 70 | mask = tmp.view(B0, LS0 ** 2, N, B1, LS1 ** 2, N).permute(0, 2, 1, 3, 5, 4) 71 | # Shape: (B, N, LS0**2, B, N, LS1**2) 72 | mask = mask.contiguous().view(B0, N * LS0 ** 2, B1, N * LS1 ** 2) 73 | 74 | return mask 75 | 76 | 77 | def get_standard_instance_mask(batch_size0, batch_size1, pred_step, device="cuda"): 78 | B0, B1, N = batch_size0, batch_size1, pred_step 79 | device = torch.device(device) 80 | 81 | assert B0 <= B1, "Invalid B0, B1: {} {}".format(B0, B1) 82 | 83 | # mask meaning: -2: omit, -1: temporal neg (hard), 0: easy neg, 1: pos, -3: spatial neg 84 | mask = torch.zeros((B0, N, B1, N), dtype=torch.int8, requires_grad=False).detach().to(device) 85 | # temporal neg pairs 86 | for k in range(B0): 87 | mask[k, :, k, :] = -1 88 | # positive pairs 89 | for j in range(B0): 90 | mask[j, torch.arange(N), j, torch.arange(N)] = 1 91 | for i in range(B0): 92 | for j in range(N): 93 | assert mask[i, j, i, j] == 1, "Invalid value at {}, {}".format(i, j) 94 | for xi in range(B0): 95 | if i == xi: 96 | continue 97 | for xj in range(N): 98 | if j == xj: 99 | continue 100 | assert mask[i, j, xi, xj] < 1, "Invalid value at {}, {}".format(i, j) 101 | assert torch.allclose(mask[:, :, B0:, :], torch.tensor(0, dtype=torch.int8)), "Invalid values" 102 | 103 | return mask 104 | 105 | 106 | def process_mask(mask): 107 | # dot product is computed in parallel gpus, so get less easy neg, bounded by batch size in each gpu''' 108 | # mask meaning: -2: omit, -1: temporal neg (hard), 0: easy neg, 1: pos, -3: spatial neg 109 | target = mask == 1 110 | # This doesn't seem to cause any issues in our implementation 111 | target.requires_grad = False 112 | return target 113 | -------------------------------------------------------------------------------- /train/model_3d.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import math 3 | import torch 4 | 5 | import torch.nn as nn 6 | import sim_utils as su 7 | import model_utils as mu 8 | import torch.nn.functional as F 9 | sys.path.append('../backbone') 10 | 11 | from select_backbone import select_resnet 12 | from convrnn import ConvGRU 13 | 14 | 15 | eps = 1e-7 16 | INF = 25.0 17 | 18 | 19 | class MyDataParallel(torch.nn.DataParallel): 20 | """ 21 | Allow nn.DataParallel to call model's attributes. 
22 | """ 23 | def __getattr__(self, name): 24 | try: 25 | return super().__getattr__(name) 26 | except AttributeError: 27 | return getattr(self.module, name) 28 | 29 | 30 | def get_parallel_model(model): 31 | if torch.cuda.is_available(): 32 | dev_count = torch.cuda.device_count() 33 | print("Using {} GPUs".format(dev_count)) 34 | model = MyDataParallel(model, device_ids=list(range(dev_count))) 35 | return model 36 | 37 | 38 | def get_num_channels(modality): 39 | if modality == mu.ImgMode: 40 | return 3 41 | elif modality == mu.FlowMode: 42 | return 2 43 | elif modality == mu.FnbFlowMode: 44 | return 2 45 | elif modality == mu.KeypointHeatmap: 46 | return 17 47 | elif modality == mu.SegMask: 48 | return 1 49 | else: 50 | assert False, "Invalid modality: {}".format(modality) 51 | 52 | 53 | class ImageFetCombiner(nn.Module): 54 | 55 | def __init__(self, img_fet_dim, img_segments): 56 | super(ImageFetCombiner, self).__init__() 57 | 58 | # Input feature dimension is [B, dim, s, s] 59 | self.dim = img_fet_dim 60 | self.s = img_segments 61 | self.flat_dim = self.dim * self.s * self.s 62 | 63 | layers = [] 64 | if self.s == 7: 65 | layers.append(nn.MaxPool2d(2, 2, padding=1)) 66 | layers.append(nn.MaxPool2d(2, 2)) 67 | layers.append(nn.AvgPool2d(2, 2)) 68 | if self.s == 4: 69 | layers.append(nn.MaxPool2d(2, 2)) 70 | layers.append(nn.AvgPool2d(2, 2)) 71 | elif self.s == 2: 72 | layers.append(nn.AvgPool2d(2, 2)) 73 | 74 | # input is B x dim x s x s 75 | self.feature = nn.Sequential(*layers) 76 | # TODO: Normalize 77 | # Output is B x dim 78 | 79 | def forward(self, input: torch.Tensor): 80 | # input is B, N, D, s, s 81 | B, N, D, s, s = input.shape 82 | input = input.view(B * N, D, s, s) 83 | y = self.feature(input) 84 | y = y.reshape(B, N, -1) 85 | return y 86 | 87 | 88 | class IdentityFlatten(nn.Module): 89 | 90 | def __init__(self): 91 | super(IdentityFlatten, self).__init__() 92 | 93 | def forward(self, input: torch.Tensor): 94 | # input is B, N, D, s, s 95 | B, N, D, s, s = input.shape 96 | return input.reshape(B, N, -1) 97 | 98 | 99 | class DpcRnn(nn.Module): 100 | 101 | def get_modality_feature_extractor(self): 102 | if self.mode in [mu.ImgMode, mu.FlowMode, mu.KeypointHeatmap, mu.SegMask]: 103 | return ImageFetCombiner(self.final_feature_size, self.last_size) 104 | else: 105 | assert False, "Invalid mode provided: {}".format(self.mode) 106 | 107 | '''DPC with RNN''' 108 | def __init__(self, args): 109 | super(DpcRnn, self).__init__() 110 | 111 | torch.cuda.manual_seed(233) 112 | 113 | print('Using DPC-RNN model for mode: {}'.format(args["mode"])) 114 | self.num_seq = args["num_seq"] 115 | self.seq_len = args["seq_len"] 116 | self.pred_step = args["pred_step"] 117 | self.sample_size = args["img_dim"] 118 | self.last_duration = int(math.ceil(self.seq_len / 4)) 119 | self.last_size = int(math.ceil(self.sample_size / 32)) 120 | print('final feature map has size %dx%d' % (self.last_size, self.last_size)) 121 | 122 | self.mode = args["mode"] 123 | self.in_channels = get_num_channels(self.mode) 124 | self.l2_norm = args["l2_norm"] 125 | 126 | track_running_stats = True 127 | print("Track running stats: {}".format(track_running_stats)) 128 | self.backbone, self.param = select_resnet( 129 | args["net"], track_running_stats=track_running_stats, in_channels=self.in_channels 130 | ) 131 | 132 | # params for GRU 133 | self.param['num_layers'] = 1 134 | self.param['hidden_size'] = self.param['feature_size'] 135 | 136 | # param for current model 137 | self.final_feature_size = self.param["feature_size"] 
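# final_feature_size is the per-location channel width of the backbone feature map; total_feature_size below flattens that width over the last_size x last_size spatial grid.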
138 | # self.final_feature_size = self.param['hidden_size'] * (self.last_size ** 2) 139 | self.total_feature_size = self.param['hidden_size'] * (self.last_size ** 2) 140 | 141 | self.agg = ConvGRU(input_size=self.param['feature_size'], 142 | hidden_size=self.param['hidden_size'], 143 | kernel_size=1, 144 | num_layers=self.param['num_layers']) 145 | self.network_pred = nn.Sequential( 146 | nn.Conv2d(self.param['feature_size'], self.param['feature_size'], kernel_size=1, padding=0), 147 | nn.ReLU(inplace=True), 148 | nn.Conv2d(self.param['feature_size'], self.param['feature_size'], kernel_size=1, padding=0) 149 | ) 150 | 151 | self.compiled_features = self.get_modality_feature_extractor() 152 | self.interModeDotHandler = su.InterModeDotHandler(self.last_size) 153 | self.cosSimHandler = su.CosSimHandler() 154 | 155 | self.mask = None 156 | # self.relu = nn.ReLU(inplace=False) 157 | self._initialize_weights(self.agg) 158 | self._initialize_weights(self.network_pred) 159 | 160 | def get_representation(self, block, detach=False): 161 | 162 | (B, N, C, SL, H, W) = block.shape 163 | block = block.view(B*N, C, SL, H, W) 164 | feature = self.backbone(block) 165 | del block 166 | feature = F.relu(feature) 167 | 168 | feature = F.avg_pool3d(feature, (self.last_duration, 1, 1), stride=1) 169 | feature = feature.view(B, N, self.param['feature_size'], self.last_size, self.last_size) # [B*N,D,last_size,last_size] 170 | context, _ = self.agg(feature) 171 | context = context[:,-1,:].unsqueeze(1) 172 | context = F.avg_pool3d(context, (1, self.last_size, self.last_size), stride=1).squeeze(-1).squeeze(-1) 173 | del feature 174 | 175 | if self.l2_norm: 176 | context = self.cosSimHandler.l2NormedVec(context, dim=2) 177 | 178 | # Return detached version if required 179 | if detach: 180 | return context.detach() 181 | else: 182 | return context 183 | 184 | def compute_cdot_features(self, feature): 185 | comp_feature = self.compiled_features(feature).unsqueeze(3).unsqueeze(3) 186 | cdot, cdot_fet = self.interModeDotHandler(comp_fet=comp_feature) 187 | return cdot, cdot_fet 188 | 189 | def forward(self, block, ret_rep=False): 190 | # ret_cdot values: [c, z, zt] 191 | 192 | # block: [B, N, C, SL, W, H] 193 | # B: Batch, N: Number of sequences per instance, C: Channels, SL: Sequence Length, W, H: Dims 194 | 195 | ### extract feature ### 196 | (B, N, C, SL, H, W) = block.shape 197 | 198 | block = block.view(B*N, C, SL, H, W) 199 | feature = self.backbone(block) 200 | 201 | del block 202 | 203 | feature = F.avg_pool3d(feature, (self.last_duration, 1, 1), stride=(1, 1, 1)) 204 | 205 | if self.l2_norm: 206 | feature = self.cosSimHandler.l2NormedVec(feature, dim=1) 207 | 208 | # before ReLU, (-inf, +inf) 209 | feature_inf_all = feature.view(B, N, self.param['feature_size'], self.last_size, self.last_size) 210 | feature = feature.view(B, N, self.param['feature_size'], self.last_size, self.last_size) 211 | 212 | # Generate feature for future frames 213 | feature_inf = feature_inf_all[:, N - self.pred_step::, :].contiguous() 214 | 215 | del feature_inf_all 216 | 217 | ### aggregate, predict future ### 218 | # Generate inferred future (stored in feature_inf) through the initial frames 219 | _, hidden = self.agg(feature[:, 0:N-self.pred_step, :].contiguous()) 220 | 221 | if self.l2_norm: 222 | hidden = self.cosSimHandler.l2NormedVec(hidden, dim=2) 223 | 224 | # Get the last hidden state, this gives us the predicted representation 225 | # after tanh, (-1,1). 
get the hidden state of last layer, last time step 226 | hidden = hidden[:, -1, :] 227 | 228 | # Predict next pred_step time steps for this instance 229 | pred = [] 230 | for i in range(self.pred_step): 231 | # sequentially pred future based on the hidden states 232 | p_tmp = self.network_pred(hidden) 233 | 234 | if self.l2_norm: 235 | p_tmp = self.cosSimHandler.l2NormedVec(p_tmp, dim=1) 236 | 237 | pred.append(p_tmp) 238 | _, hidden = self.agg(p_tmp.unsqueeze(1), hidden.unsqueeze(0)) 239 | 240 | if self.l2_norm: 241 | hidden = self.cosSimHandler.l2NormedVec(hidden, dim=2) 242 | 243 | hidden = hidden[:, -1, :] 244 | # Contains the representations for each of the next pred steps 245 | pred = torch.stack(pred, 1) # B, pred_step, xxx 246 | 247 | cdot, cdot_fet = self.compute_cdot_features(feature) 248 | 249 | # Both are of the form [B, pred_step, D, s, s] 250 | return pred, feature_inf, feature, hidden 251 | 252 | def _initialize_weights(self, module): 253 | for name, param in module.named_parameters(): 254 | if 'weight' in name: 255 | nn.init.orthogonal_(param, 1) 256 | # other resnet weights have been initialized in resnet itself 257 | 258 | def reset_mask(self): 259 | self.mask = None 260 | -------------------------------------------------------------------------------- /train/model_utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from collections import namedtuple 3 | 4 | import data_utils 5 | import os 6 | 7 | 8 | from torch.utils import data 9 | from tensorboardX import SummaryWriter 10 | from torchvision import transforms 11 | from copy import deepcopy 12 | from collections import defaultdict 13 | 14 | from dataset_3d import * 15 | 16 | sys.path.append('../utils') 17 | from utils import AverageMeter 18 | 19 | sys.path.append('../backbone') 20 | from resnet_2d3d import neq_load_customized 21 | 22 | 23 | # Constants for the framework 24 | eps = 1e-7 25 | 26 | CPCLoss = "cpc" 27 | CooperativeLoss = "coop" 28 | 29 | # Losses for mode sync 30 | ModeSim = "sim" 31 | CosSimLoss = "cossim" 32 | CorrLoss = "corr" 33 | DenseCosSimLoss = "dcssim" 34 | DenseCorrLoss = "dcrr" 35 | 36 | # Sets of different losses 37 | LossList = [CPCLoss, CosSimLoss, CorrLoss, DenseCorrLoss, DenseCosSimLoss, CooperativeLoss] 38 | ModeSyncLossList = [CosSimLoss, CorrLoss, DenseCorrLoss, DenseCosSimLoss] 39 | 40 | ImgMode = "imgs" 41 | FlowMode = "flow" 42 | FnbFlowMode = "farne" 43 | KeypointHeatmap = "kphm" 44 | SegMask = "seg" 45 | ModeList = [ImgMode, FlowMode, KeypointHeatmap, SegMask, FnbFlowMode] 46 | 47 | ModeParams = namedtuple('ModeParams', ['mode', 'img_fet_dim', 'img_fet_segments', 'final_dim']) 48 | 49 | 50 | def str2bool(s): 51 | """Convert string to bool (in argparse context).""" 52 | if s.lower() not in ['true', 'false']: 53 | raise ValueError('Need bool; got %r' % s) 54 | return {'true': True, 'false': False}[s.lower()] 55 | 56 | 57 | def str2list(s): 58 | """Convert string to list of strs, split on _""" 59 | return s.split('_') 60 | 61 | 62 | def get_multi_modal_model_train_args(): 63 | parser = argparse.ArgumentParser() 64 | 65 | # General global training parameters 66 | parser.add_argument('--save_dir', default='', type=str, help='dir to save intermediate results') 67 | parser.add_argument('--dataset', default='ucf101', type=str) 68 | parser.add_argument('--ds', default=3, type=int, help='frame downsampling rate') 69 | parser.add_argument('--seq_len', default=5, type=int, help='number of frames in each video block') 70 | 
parser.add_argument('--num_seq', default=8, type=int, help='number of video blocks') 71 | parser.add_argument('--pred_step', default=3, type=int) 72 | parser.add_argument('--batch_size', default=16, type=int) 73 | parser.add_argument('--epochs', default=10, type=int, help='number of total epochs to run') 74 | parser.add_argument('--start-epoch', default=0, type=int, help='manual epoch number (useful on restarts)') 75 | parser.add_argument('--print_freq', default=5, type=int, help='frequency of printing output during training') 76 | parser.add_argument('--num_workers', default=2, type=int, help='Number of workers for dataloader') 77 | parser.add_argument('--reset_lr', action='store_true', help='Reset learning rate when resume training?') 78 | parser.add_argument('--notes', default="", type=str, help='Additional notes') 79 | parser.add_argument('--vis_log_freq', default=100, type=int, help='Visualization frequency') 80 | 81 | # Evaluation specific flags 82 | parser.add_argument('--ft_freq', default=10, type=int, help='frequency to perform finetuning') 83 | 84 | # Global network and model details. Can be overriden using specific flags 85 | parser.add_argument('--net', default='resnet18', type=str) 86 | parser.add_argument('--train_what', default='all', type=str) 87 | parser.add_argument('--img_dim', default=128, type=int) 88 | parser.add_argument('--sampling', default="dynamic", type=str, help='sampling method (disjoint, random, dynamic)') 89 | parser.add_argument('--l2_norm', default=True, type=str2bool, help='Whether to perform L2 normalization') 90 | parser.add_argument('--temp', default=0.07, type=float, help='Temperature to use with L2 normalization') 91 | 92 | # Training specific flags 93 | parser.add_argument('--lr', default=1e-3, type=float, help='learning rate') 94 | parser.add_argument('--wd', default=1e-5, type=float, help='weight decay') 95 | parser.add_argument('--losses', default="cpc", type=str2list, help='Losses to use (CPC, Align, Rep, Sim)') 96 | parser.add_argument('--dropout', default=0.3, type=float, help='Dropout to use for supervised training') 97 | parser.add_argument('--tune_bb', default=-1.0, type=float, 98 | help='Fine-tune back-bone lr degradation. Useful for pretrained weights. (0.5, 0.1, 0.05)') 99 | parser.add_argument('--tune_imgs_bb', default=-1.0, type=float, help='Fine-tune imgs back-bone lr degradation.') 100 | 101 | # Hyper-parameters 102 | parser.add_argument('--msync_wt', default=10.0, type=float, help='Loss weight to use for mode sync loss') 103 | parser.add_argument('--dot_wt', default=10.0, type=float, help='Loss weight to use for mode sync loss') 104 | 105 | # Multi-modal related flags 106 | parser.add_argument('--data_sources', default='imgs', type=str2list, help='data sources separated by _') 107 | parser.add_argument('--modalities', default="imgs", type=str2list, help='Modalitiles to consider. 
Separate by _') 108 | 109 | # Checkpoint flags 110 | parser.add_argument('--imgs_restore_ckpt', default=None, type=str, help='Restore checkpoint for imgs') 111 | parser.add_argument('--flow_restore_ckpt', default=None, type=str, help='Restore checkpoint for flow') 112 | parser.add_argument('--farne_restore_ckpt', default=None, type=str, help='Restore checkpoint for farne flow') 113 | parser.add_argument('--kphm_restore_ckpt', default=None, type=str, help='Restore checkpoint for kp heatmap') 114 | parser.add_argument('--seg_restore_ckpt', default=None, type=str, help='Restore checkpoint for seg') 115 | 116 | # TODO: Flags to be fixed/revamped 117 | # Need to change restore for each ckpt 118 | 119 | # Flags which need not be touched 120 | parser.add_argument('--resume', default='', type=str, help='path of model to resume') 121 | parser.add_argument('--pretrain', default='', type=str, help='path of pretrained model') 122 | parser.add_argument('--prefix', default='22Mar', type=str, help='prefix of checkpoint filename') 123 | 124 | # Extra arguments 125 | parser.add_argument('--debug', default=False, type=str2bool, help='Reduces latency for data ops') 126 | 127 | return parser 128 | 129 | 130 | def get_num_classes(dataset): 131 | if 'kinetics' in dataset: 132 | return 400 133 | elif dataset == 'ucf101': 134 | return 101 135 | elif dataset == 'jhmdb': 136 | return 21 137 | elif dataset == 'hmdb51': 138 | return 51 139 | else: 140 | return None 141 | 142 | 143 | def get_transforms(args): 144 | return { 145 | ImgMode: get_imgs_transforms(args), 146 | FlowMode: get_flow_transforms(args), 147 | FnbFlowMode: get_flow_transforms(args), 148 | KeypointHeatmap: get_heatmap_transforms(args), 149 | SegMask: get_segmask_transforms(args), 150 | } 151 | 152 | 153 | def get_test_transforms(args): 154 | return { 155 | ImgMode: get_imgs_test_transforms(args), 156 | FlowMode: get_flow_test_transforms(args), 157 | FnbFlowMode: get_flow_test_transforms(args), 158 | KeypointHeatmap: get_heatmap_test_transforms(args), 159 | SegMask: get_segmask_test_transforms(args), 160 | } 161 | 162 | 163 | def convert_to_dict(args): 164 | if type(args) != dict: 165 | args_dict = vars(args) 166 | else: 167 | args_dict = args 168 | return args_dict 169 | 170 | 171 | def get_imgs_test_transforms(args): 172 | args_dict = convert_to_dict(args) 173 | 174 | transform = transforms.Compose([ 175 | CenterCrop(size=224, consistent=True), 176 | Scale(size=(args_dict["img_dim"], args_dict["img_dim"])), 177 | ToTensor(), 178 | Normalize() 179 | ]) 180 | 181 | return transform 182 | 183 | 184 | def get_flow_test_transforms(args): 185 | args_dict = convert_to_dict(args) 186 | dim = min(128, args_dict["img_dim"]) 187 | 188 | center_crop_size = 224 189 | if args_dict["dataset"] == 'kinetics': 190 | center_crop_size = 128 191 | 192 | transform = transforms.Compose([ 193 | CenterCrop(size=center_crop_size, consistent=True), 194 | Scale(size=(dim, dim)), 195 | ToTensor(), 196 | ]) 197 | 198 | return transform 199 | 200 | 201 | def get_heatmap_test_transforms(_): 202 | transform = transforms.Compose([ 203 | CenterCropForTensors(size=192), 204 | ScaleForTensors(size=(64, 64)), 205 | ]) 206 | return transform 207 | 208 | 209 | def get_segmask_test_transforms(_): 210 | transform = transforms.Compose([ 211 | CenterCropForTensors(size=192), 212 | ScaleForTensors(size=(64, 64)), 213 | ]) 214 | return transform 215 | 216 | 217 | def get_imgs_transforms(args): 218 | 219 | args_dict = convert_to_dict(args) 220 | transform = None 221 | 222 | if 
args_dict["debug"]: 223 | return transforms.Compose([ 224 | CenterCrop(size=224, consistent=True), 225 | Scale(size=(args_dict["img_dim"], args_dict["img_dim"])), 226 | ToTensor(), 227 | Normalize() 228 | ]) 229 | 230 | # designed for ucf101, short size=256, rand crop to 224x224 then scale to 128x128 231 | if args_dict["dataset"] == 'ucf101': 232 | transform = transforms.Compose([ 233 | RandomHorizontalFlip(consistent=True), 234 | RandomCrop(size=224, consistent=True), 235 | Scale(size=(args_dict["img_dim"], args_dict["img_dim"])), 236 | RandomGray(consistent=False, p=0.5), 237 | ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.25, p=1.0), 238 | ToTensor(), 239 | Normalize() 240 | ]) 241 | elif (args_dict["dataset"] == 'jhmdb') or (args_dict["dataset"] == 'hmdb51'): 242 | transform = transforms.Compose([ 243 | RandomHorizontalFlip(consistent=True), 244 | RandomCrop(size=224, consistent=True), 245 | Scale(size=(args_dict["img_dim"], args_dict["img_dim"])), 246 | RandomGray(consistent=False, p=0.5), 247 | ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.25, p=1.0), 248 | ToTensor(), 249 | Normalize() 250 | ]) 251 | # designed for kinetics400, short size=150, rand crop to 128x128 252 | elif args_dict["dataset"] == 'kinetics': 253 | transform = transforms.Compose([ 254 | RandomSizedCrop(size=args_dict["img_dim"], consistent=True, p=1.0), 255 | RandomHorizontalFlip(consistent=True), 256 | RandomGray(consistent=False, p=0.5), 257 | ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.25, p=1.0), 258 | ToTensor(), 259 | Normalize() 260 | ]) 261 | 262 | return transform 263 | 264 | 265 | def get_flow_transforms(args): 266 | # TODO: Add random horizontal flip 267 | 268 | args_dict = convert_to_dict(args) 269 | dim = min(128, args_dict["img_dim"]) 270 | transform = None 271 | 272 | if args_dict["debug"]: 273 | return transforms.Compose([ 274 | Scale(size=(dim, dim)), 275 | ToTensor(), 276 | ]) 277 | 278 | # designed for ucf101, short size=256, rand crop to 224x224 then scale to 128x128 279 | if (args_dict["dataset"] == 'ucf101') or (args_dict["dataset"] == 'jhmdb') or (args_dict["dataset"] == 'hmdb51'): 280 | transform = transforms.Compose([ 281 | RandomIntensityCropForFlow(size=224), 282 | Scale(size=(dim, dim)), 283 | ToTensor(), 284 | ]) 285 | # designed for kinetics400, short size=150, rand crop to 128x128 286 | elif args_dict["dataset"] == 'kinetics': 287 | transform = transforms.Compose([ 288 | RandomIntensityCropForFlow(size=dim), 289 | ToTensor(), 290 | ]) 291 | 292 | return transform 293 | 294 | 295 | def get_heatmap_transforms(_): 296 | crop_size = int(192 * 0.8) 297 | transform = transforms.Compose([ 298 | RandomIntensityCropForTensors(size=crop_size), 299 | ScaleForTensors(size=(64, 64)), 300 | ]) 301 | return transform 302 | 303 | 304 | def get_segmask_transforms(_): 305 | crop_size = int(192 * 0.8) 306 | transform = transforms.Compose([ 307 | RandomIntensityCropForTensors(size=crop_size), 308 | ScaleForTensors(size=(64, 64)), 309 | ]) 310 | return transform 311 | 312 | 313 | def get_poses_transforms(): 314 | return transforms.Compose([pu.RandomShift(), pu.Rescale()]) 315 | 316 | 317 | def get_writers(img_path): 318 | 319 | try: # old version 320 | writer_val = SummaryWriter(log_dir=os.path.join(img_path, 'val')) 321 | writer_train = SummaryWriter(log_dir=os.path.join(img_path, 'train')) 322 | except: # v1.7 323 | writer_val = SummaryWriter(logdir=os.path.join(img_path, 'val')) 324 | writer_train = SummaryWriter(logdir=os.path.join(img_path, 
'train')) 325 | 326 | return writer_train, writer_val 327 | 328 | 329 | def get_dataset_loaders(args, transform, mode='train'): 330 | print('Loading data for "%s" ...' % mode) 331 | 332 | if type(args) != dict: 333 | args_dict = deepcopy(vars(args)) 334 | else: 335 | args_dict = args 336 | 337 | if args_dict['debug']: 338 | orig_mode = mode 339 | mode = 'train' 340 | 341 | use_big_K400 = False 342 | if args_dict["dataset"] == 'kinetics': 343 | use_big_K400 = args_dict["img_dim"] > 150 344 | dataset = Kinetics_3d( 345 | mode=mode, 346 | transform=transform, 347 | seq_len=args_dict["seq_len"], 348 | num_seq=args_dict["num_seq"], 349 | downsample=args_dict["ds"], 350 | vals_to_return=args_dict["data_sources"].split('_'), 351 | use_big=use_big_K400, 352 | ) 353 | elif args_dict["dataset"] == 'ucf101': 354 | dataset = UCF101_3d( 355 | mode=mode, 356 | transform=transform, 357 | seq_len=args_dict["seq_len"], 358 | num_seq=args_dict["num_seq"], 359 | downsample=args_dict["ds"], 360 | vals_to_return=args_dict["data_sources"].split('_'), 361 | debug=args_dict["debug"]) 362 | elif args_dict["dataset"] == 'jhmdb': 363 | dataset = JHMDB_3d(mode=mode, 364 | transform=transform, 365 | seq_len=args_dict["seq_len"], 366 | num_seq=args_dict["num_seq"], 367 | downsample=args_dict["ds"], 368 | vals_to_return=args_dict["data_sources"].split('_'), 369 | sampling_method=args_dict["sampling"]) 370 | elif args_dict["dataset"] == 'hmdb51': 371 | dataset = HMDB51_3d(mode=mode, 372 | transform=transform, 373 | seq_len=args_dict["seq_len"], 374 | num_seq=args_dict["num_seq"], 375 | downsample=args_dict["ds"], 376 | vals_to_return=args_dict["data_sources"].split('_'), 377 | sampling_method=args_dict["sampling"]) 378 | else: 379 | raise ValueError('dataset not supported') 380 | 381 | val_sampler = data.SequentialSampler(dataset) 382 | if use_big_K400: 383 | train_sampler = data.RandomSampler(dataset, replacement=True, num_samples=int(0.2 * len(dataset))) 384 | else: 385 | train_sampler = data.RandomSampler(dataset) 386 | 387 | if args_dict["debug"]: 388 | if orig_mode == 'val': 389 | train_sampler = data.RandomSampler(dataset, replacement=True, num_samples=200) 390 | else: 391 | train_sampler = data.RandomSampler(dataset, replacement=True, num_samples=2000) 392 | val_sampler = data.RandomSampler(dataset) 393 | # train_sampler = data.RandomSampler(dataset, replacement=True, num_samples=100) 394 | 395 | data_loader = None 396 | if mode == 'train': 397 | data_loader = data.DataLoader(dataset, 398 | batch_size=args_dict["batch_size"], 399 | sampler=train_sampler, 400 | shuffle=False, 401 | num_workers=args_dict["num_workers"], 402 | collate_fn=data_utils.individual_collate, 403 | pin_memory=True, 404 | drop_last=True) 405 | elif mode == 'val': 406 | data_loader = data.DataLoader(dataset, 407 | sampler=val_sampler, 408 | batch_size=args_dict["batch_size"], 409 | shuffle=False, 410 | num_workers=args_dict["num_workers"], 411 | collate_fn=data_utils.individual_collate, 412 | pin_memory=True, 413 | drop_last=True) 414 | elif mode == 'test': 415 | data_loader = data.DataLoader(dataset, 416 | sampler=val_sampler, 417 | batch_size=args_dict["batch_size"], 418 | shuffle=False, 419 | num_workers=args_dict["num_workers"], 420 | collate_fn=data_utils.individual_collate, 421 | pin_memory=True, 422 | drop_last=False) 423 | 424 | print('"%s" dataset size: %d' % (mode, len(dataset))) 425 | return data_loader 426 | 427 | 428 | def set_multi_modal_path(args): 429 | if args.resume: 430 | exp_path = 
os.path.dirname(os.path.dirname(args.resume)) 431 | else: 432 | args.modes_str = '_'.join(args.modes) 433 | args.l2norm_str = str(args.l2_norm) 434 | exp_path = 'logs/{args.prefix}/{args.dataset}-{args.img_dim}_{0}_' \ 435 | 'bs{args.batch_size}_seq{args.num_seq}_pred{args.pred_step}_len{args.seq_len}_ds{args.ds}_' \ 436 | 'train-{args.train_what}{1}_modes-{args.modes_str}_l2norm' \ 437 | '_{args.l2norm_str}_{args.notes}'.format( 438 | 'r%s' % args.net[6::], 439 | '_pt=%s' % args.pretrain.replace('/','-') if args.pretrain else '', 440 | args=args 441 | ) 442 | exp_path = os.path.join(args.save_dir, exp_path) 443 | 444 | img_path = os.path.join(exp_path, 'img') 445 | model_path = os.path.join(exp_path, 'model') 446 | if not os.path.exists(img_path): os.makedirs(img_path) 447 | if not os.path.exists(model_path): os.makedirs(model_path) 448 | return img_path, model_path 449 | 450 | 451 | def process_output(mask): 452 | '''task mask as input, compute the target for contrastive loss''' 453 | # dot product is computed in parallel gpus, so get less easy neg, bounded by batch size in each gpu''' 454 | # mask meaning: -2: omit, -1: temporal neg (hard), 0: easy neg, 1: pos, -3: spatial neg 455 | (B, NP, SQ, B2, NS, _) = mask.size() # [B, P, SQ, B, N, SQ] 456 | target = mask == 1 457 | target.requires_grad = False 458 | return target, (B, B2, NS, NP, SQ) 459 | 460 | 461 | def check_name_to_be_avoided(k): 462 | # modules_to_avoid = ['.agg.', '.network_pred.'] 463 | modules_to_avoid = [] 464 | for m in modules_to_avoid: 465 | if m in k: 466 | return True 467 | return False 468 | 469 | 470 | def load_model(model, model_path): 471 | if os.path.isfile(model_path): 472 | print("=> loading resumed checkpoint '{}'".format(model_path)) 473 | checkpoint = torch.load(model_path, map_location=torch.device('cpu')) 474 | model = neq_load_customized(model, checkpoint['state_dict']) 475 | print("=> loaded resumed checkpoint '{}' (epoch {})".format(model_path, checkpoint['epoch'])) 476 | else: 477 | print("[WARNING] no checkpoint found at '{}'".format(model_path)) 478 | return model 479 | 480 | 481 | def get_stats_dict(losses_dict, stats, eval=False): 482 | postfix_dict = {} 483 | 484 | # Populate accuracies 485 | for loss in stats.keys(): 486 | for mode in stats[loss].keys(): 487 | for stat, meter in stats[loss][mode].items(): 488 | val = meter.avg if eval else meter.local_avg 489 | postfix_dict[loss[:3] + '_' + mode[:3] + "_" + str(stat)] = round(val, 3) 490 | 491 | # Populate losses 492 | for loss in losses_dict.keys(): 493 | for key, meter in losses_dict[loss].items(): 494 | key_str = "l_{}_{}".format(loss, key[:3]) 495 | val = meter.avg if eval else meter.local_avg 496 | postfix_dict[key_str] = round(val, 3) 497 | 498 | return postfix_dict 499 | 500 | 501 | def init_loggers(losses): 502 | losses_dict = {l: defaultdict(lambda: AverageMeter()) for l in losses} 503 | 504 | stats = {} 505 | for loss in losses: 506 | # Creates a nested default dict 507 | stats[loss] = defaultdict(lambda: defaultdict(lambda: AverageMeter())) 508 | 509 | return losses_dict, stats 510 | -------------------------------------------------------------------------------- /train/sim_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import sys 3 | import time 4 | import math 5 | 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import model_utils as mu 9 | 10 | sys.path.append('../utils') 11 | from utils import calc_topk_accuracy 12 | from random import random 13 | 
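# eps below guards the L2 normalisation in SimHandler.l2NormedVec against division by zero.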
14 | 15 | eps = 1e-5 16 | INF = 1000.0 17 | 18 | 19 | class MemoryBank(nn.Module): 20 | 21 | def __init__(self, size): 22 | super(MemoryBank, self).__init__() 23 | self.maxlen = size 24 | self.dim = None 25 | self.bank = None 26 | 27 | def bootstrap(self, X): 28 | self.dim = X.shape[1:] 29 | gcd = math.gcd(X.shape[0], self.maxlen) 30 | self.bank = torch.cat([X[:gcd]] * (self.maxlen // gcd), dim=0).detach().to(X.device) 31 | assert self.bank.shape[0] == self.maxlen, "Invalid shape: {}".format(self.bank.shape) 32 | self.bank.requires_grad = False 33 | 34 | def update(self, X): 35 | # Initialize the memory bank 36 | N = X.shape[0] 37 | if self.dim is None: 38 | self.bootstrap(X) 39 | 40 | assert X.shape[1:] == self.dim, "Invalid size: {} {}".format(X.shape, self.dim) 41 | self.bank = torch.cat([self.bank[N:], X.detach().to(X.device)], dim=0).detach() 42 | 43 | def fetchBank(self): 44 | if self.bank is not None: 45 | assert self.bank.requires_grad is False, "Bank grad not false: {}".format(self.bank.requires_grad) 46 | return self.bank 47 | 48 | def fetchAppended(self, X): 49 | if self.bank is None: 50 | self.bootstrap(X) 51 | return self.fetchAppended(X) 52 | assert X.shape[1:] == self.bank.shape[1:], "Invalid shapes: {}, {}".format(X.shape, self.bank.shape) 53 | assert self.bank.requires_grad is False, "Bank grad not false: {}".format(self.bank.requires_grad) 54 | return torch.cat([X, self.bank], dim=0) 55 | 56 | 57 | class WeightNormalizedMarginLoss(nn.Module): 58 | def __init__(self, target): 59 | super(WeightNormalizedMarginLoss, self).__init__() 60 | 61 | self.target = target.float().clone() 62 | 63 | # Parameters for the weight loss 64 | self.f = 0.5 65 | self.one_ratio = self.target[self.target == 1].numel() / (self.target.numel() * 1.0) 66 | 67 | # Setup weight mask 68 | self.weight_mask = target.float().clone() 69 | self.weight_mask[self.weight_mask >= 1.] = self.f * (1 - self.one_ratio) 70 | self.weight_mask[self.weight_mask <= 0.] = (1. - self.f) * self.one_ratio 71 | 72 | # Normalize the weight accordingly 73 | self.weight_mask = self.weight_mask.to(self.target.device) / (self.one_ratio * (1. - self.one_ratio)) 74 | 75 | self.hinge_target = self.target.clone() 76 | self.hinge_target[self.hinge_target >= 1] = 1 77 | self.hinge_target[self.hinge_target <= 0] = -1 78 | 79 | self.dummy_target = self.target.clone() 80 | 81 | self.criteria = nn.HingeEmbeddingLoss(margin=((1 - self.f) / (1 - self.one_ratio))) 82 | 83 | def forward(self, value): 84 | distance = 1.0 - value 85 | return self.criteria(self.weight_mask * distance, self.hinge_target) 86 | 87 | 88 | class SimHandler(nn.Module): 89 | 90 | def __init__(self): 91 | super(SimHandler, self).__init__() 92 | 93 | def verify_shape_for_dot_product(self, mode0, mode1): 94 | 95 | B, N, D = mode0.shape 96 | assert (B, N, D) == tuple(mode1.shape), \ 97 | "Mismatch between mode0 and mode1 features: {}, {}".format(mode0.shape, mode1.shape) 98 | 99 | # dot product in mode0-mode1 pair, get a 4d tensor. 
First 2 dims are from mode0, the last from mode1 100 | nmode0 = mode0.view(B * N, D) 101 | nmode1 = mode1.view(B * N, D) 102 | 103 | return nmode0, nmode1, B, N, D 104 | 105 | def get_feature_cross_pair_score(self, mode0, mode1): 106 | """ 107 | Gives us all pair wise scores 108 | (mode0/mode1)features: [B, N, D], [B2, N2, D] 109 | Returns 4D pair score tensor 110 | """ 111 | 112 | B1, N1, D1 = mode0.shape 113 | B2, N2, D2 = mode1.shape 114 | 115 | assert D1 == D2, "Different dimensions: {} {}".format(mode0.shape, mode1.shape) 116 | nmode0 = mode0.view(B1 * N1, D1) 117 | nmode1 = mode1.view(B2 * N2, D2) 118 | 119 | score = torch.matmul( 120 | nmode0.reshape(B1 * N1, D1), 121 | nmode1.reshape(B2 * N2, D1).transpose(0, 1) 122 | ).view(B1, N1, B2, N2) 123 | 124 | return score 125 | 126 | def get_feature_pair_score(self, mode0, mode1): 127 | """ 128 | Returns aligned pair scores 129 | (pred/gt)features: [B, N, D] 130 | Returns 2D pair score tensor 131 | """ 132 | 133 | nmode0, nmode1, B, N, D = self.verify_shape_for_dot_product(mode0, mode1) 134 | score = torch.bmm( 135 | nmode0.view(B * N, 1, D), 136 | nmode1.view(B * N, D, 1) 137 | ).view(B, N) 138 | 139 | return score 140 | 141 | def l2NormedVec(self, x, dim=-1): 142 | assert x.shape[dim] >= 256, "Invalid dimension for reduction: {}".format(x.shape) 143 | return x / (torch.norm(x, p=2, dim=dim, keepdim=True) + eps) 144 | 145 | 146 | class CosSimHandler(SimHandler): 147 | 148 | def __init__(self): 149 | super(CosSimHandler, self).__init__() 150 | 151 | self.target = None 152 | self.criterion = nn.MSELoss() 153 | 154 | def score(self, mode0, mode1): 155 | cosSim = self.get_feature_pair_score(self.l2NormedVec(mode0), self.l2NormedVec(mode1)) 156 | 157 | assert cosSim.min() >= -1. - eps, "Invalid value for cos sim: {}".format(cosSim) 158 | assert cosSim.max() <= 1. + eps, "Invalid value for cos sim: {}".format(cosSim) 159 | 160 | return cosSim 161 | 162 | def forward(self, mode0, mode1): 163 | score = self.score(mode0, mode1) 164 | 165 | if self.target is None: 166 | self.target = torch.ones_like(score) 167 | 168 | stats = {"m": score.mean()} 169 | 170 | return self.criterion(score, self.target), stats 171 | 172 | 173 | class CorrSimHandler(SimHandler): 174 | 175 | def __init__(self): 176 | super(CorrSimHandler, self).__init__() 177 | 178 | self.shapeMode0, self.shapeMode1 = None, None 179 | self.runningMeanMode0 = None 180 | self.runningMeanMode1 = None 181 | 182 | self.retention = 0.7 183 | self.target = None 184 | 185 | self.criterion = nn.L1Loss() 186 | 187 | self.noInitYet = True 188 | 189 | @staticmethod 190 | def get_ovr_mean(mode): 191 | return mode.mean(dim=0, keepdim=True).mean(dim=1, keepdim=True).detach().cpu() 192 | 193 | def init_vars(self, mode0, mode1): 194 | 195 | self.shapeMode0 = mode0.shape 196 | self.shapeMode1 = mode1.shape 197 | 198 | assert len(self.shapeMode0) == 3 199 | 200 | self.runningMeanMode0 = self.get_ovr_mean(mode0) 201 | self.runningMeanMode1 = self.get_ovr_mean(mode1) 202 | 203 | self.noInitYet = False 204 | 205 | def update_means(self, mean0, mean1): 206 | 207 | self.runningMeanMode0 = (self.runningMeanMode0 * self.retention) + (mean0 * (1. - self.retention)) 208 | self.runningMeanMode1 = (self.runningMeanMode1 * self.retention) + (mean1 * (1. 
- self.retention)) 209 | 210 | def get_means_on_device(self, device): 211 | return self.runningMeanMode0.to(device), self.runningMeanMode1.to(device) 212 | 213 | def score(self, mode0, mode1): 214 | 215 | if self.noInitYet: 216 | self.init_vars(mode0, mode1) 217 | 218 | meanMode0 = self.get_ovr_mean(mode0) 219 | meanMode1 = self.get_ovr_mean(mode1) 220 | self.update_means(meanMode0.detach().cpu(), meanMode1.detach().cpu()) 221 | runningMean0, runningMean1 = self.get_means_on_device(mode0.device) 222 | 223 | corr = self.get_feature_pair_score( 224 | self.l2NormedVec(mode0 - runningMean0), 225 | self.l2NormedVec(mode1 - runningMean1) 226 | ) 227 | 228 | assert corr.min() >= -1. - eps, "Invalid value for correlation: {}".format(corr) 229 | assert corr.max() <= 1. + eps, "Invalid value for correlation: {}".format(corr) 230 | 231 | return corr 232 | 233 | def forward(self, mode0, mode1): 234 | score = self.score(mode0, mode1) 235 | 236 | if self.target is None: 237 | self.target = torch.ones_like(score) 238 | 239 | stats = {"m": score.mean()} 240 | 241 | return self.criterion(score, self.target), stats 242 | 243 | 244 | class DenseCorrSimHandler(CorrSimHandler): 245 | 246 | def __init__(self, instance_label): 247 | super(DenseCorrSimHandler, self).__init__() 248 | 249 | self.target = instance_label.float().clone() 250 | # self.criterion = WeightNormalizedMSELoss(self.target) 251 | self.criterion = WeightNormalizedMarginLoss(self.target) 252 | 253 | def get_feature_pair_score(self, mode0, mode1): 254 | return self.get_feature_cross_pair_score(mode0, mode1) 255 | 256 | def forward(self, mode0, mode1): 257 | score = self.score(mode0, mode1) 258 | 259 | B, N, B2, N2 = score.shape 260 | assert (B, N) == (B2, N2), "Invalid shape: {}".format(score.shape) 261 | assert score.shape == self.target.shape, "Invalid shape: {}, {}".format(score.shape, self.target.shape) 262 | 263 | stats = { 264 | "m": (self.criterion.weight_mask * score).mean(), 265 | "m-": score[self.target <= 0].mean(), 266 | "m+": score[self.target > 0].mean(), 267 | } 268 | 269 | return self.criterion(score), stats 270 | 271 | 272 | class DenseCosSimHandler(CosSimHandler): 273 | 274 | def __init__(self, instance_label): 275 | super(DenseCosSimHandler, self).__init__() 276 | 277 | self.target = instance_label.float() 278 | # self.criterion = WeightNormalizedMSELoss(self.target) 279 | self.criterion = WeightNormalizedMarginLoss(self.target) 280 | 281 | def get_feature_pair_score(self, mode0, mode1): 282 | return self.get_feature_cross_pair_score(mode0, mode1) 283 | 284 | def forward(self, mode0, mode1): 285 | score = self.score(mode0, mode1) 286 | assert score.shape == self.target.shape, "Invalid shape: {}, {}".format(score.shape, self.target.shape) 287 | 288 | stats = { 289 | "m": (self.criterion.weight_mask * score).mean(), 290 | "m-": score[self.target <= 0].mean(), 291 | "m+": score[self.target > 0].mean(), 292 | } 293 | 294 | return self.criterion(score), stats 295 | 296 | 297 | class InterModeDotHandler(nn.Module): 298 | 299 | def __init__(self, last_size=1): 300 | super(InterModeDotHandler, self).__init__() 301 | 302 | self.cosSimHandler = CosSimHandler() 303 | self.last_size = last_size 304 | 305 | def contextFetHelper(self, context): 306 | context = context[:, -1, :].unsqueeze(1) 307 | context = F.avg_pool3d(context, (1, self.last_size, self.last_size), stride=1).squeeze(-1).squeeze(-1) 308 | return context 309 | 310 | def fetHelper(self, z): 311 | B, N, D, S, S = z.shape 312 | z = z.permute(0, 1, 3, 4, 2).contiguous().view(B, N * 
S * S, D) 313 | return z 314 | 315 | def dotProdHelper(self, z, zt): 316 | return self.cosSimHandler.get_feature_cross_pair_score( 317 | self.cosSimHandler.l2NormedVec(z), self.cosSimHandler.l2NormedVec(zt) 318 | ) 319 | 320 | def get_cluster_dots(self, feature): 321 | fet = self.fetHelper(feature) 322 | return self.dotProdHelper(fet, fet) 323 | 324 | def forward(self, context=None, comp_pred=None, comp_fet=None): 325 | cdot = self.fetHelper(comp_fet) 326 | return self.dotProdHelper(cdot, cdot), cdot 327 | -------------------------------------------------------------------------------- /utils/augmentation.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numbers 3 | import math 4 | import collections 5 | import numpy as np 6 | from PIL import ImageOps, Image 7 | from joblib import Parallel, delayed 8 | 9 | import torchvision 10 | from torchvision import transforms 11 | import torchvision.transforms.functional as F 12 | 13 | 14 | class Padding: 15 | def __init__(self, pad): 16 | self.pad = pad 17 | 18 | def __call__(self, img): 19 | return ImageOps.expand(img, border=self.pad, fill=0) 20 | 21 | 22 | class Scale: 23 | def __init__(self, size, interpolation=Image.NEAREST): 24 | assert isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2) 25 | self.size = size 26 | self.interpolation = interpolation 27 | 28 | def __call__(self, imgmap): 29 | # assert len(imgmap) > 1 # list of images 30 | img1 = imgmap[0] 31 | if isinstance(self.size, int): 32 | w, h = img1.size 33 | if (w <= h and w == self.size) or (h <= w and h == self.size): 34 | return imgmap 35 | if w < h: 36 | ow = self.size 37 | oh = int(self.size * h / w) 38 | return [i.resize((ow, oh), self.interpolation) for i in imgmap] 39 | else: 40 | oh = self.size 41 | ow = int(self.size * w / h) 42 | return [i.resize((ow, oh), self.interpolation) for i in imgmap] 43 | else: 44 | return [i.resize(self.size, self.interpolation) for i in imgmap] 45 | 46 | 47 | class ScaleForTensors: 48 | def __init__(self, size, interpolation=Image.NEAREST): 49 | assert isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2) 50 | self.size = size 51 | self.interpolation = interpolation 52 | self.toTensor = transforms.ToTensor() 53 | self.toPILImage = transforms.ToPILImage() 54 | 55 | def resize_multi_channel_image(self, img_tensor_list, size): 56 | c, h, w = img_tensor_list[0].shape 57 | assert c < 20, "Invalid shape: {}".format(img_tensor_list.shape) 58 | 59 | resized_channels = [ 60 | torch.stack([ 61 | self.toTensor(self.toPILImage(img_tensor_list[idx][c]).resize(size, self.interpolation)).squeeze(0) 62 | for c in range(img_tensor_list[idx].shape[0]) 63 | ]) for idx in range(len(img_tensor_list)) 64 | ] 65 | resized_img_tensor = torch.stack(resized_channels) 66 | assert resized_img_tensor[0].shape == (c, size[0], size[1]), \ 67 | "Invalid shape: {}, orig: {}".format(resized_img_tensor.shape, img_tensor_list[0].shape) 68 | return resized_img_tensor 69 | 70 | def __call__(self, img_tensor_list): 71 | # assert len(imgmap) > 1 # list of images 72 | img1 = img_tensor_list[0] 73 | if isinstance(self.size, int): 74 | c, h, w = img1.shape 75 | if (w <= h and w == self.size) or (h <= w and h == self.size): 76 | return img_tensor_list 77 | if w < h: 78 | ow = self.size 79 | oh = int(self.size * h / w) 80 | return self.resize_multi_channel_image(img_tensor_list, (ow, oh)) 81 | else: 82 | oh = self.size 83 | ow = int(self.size * w / h) 84 | return 
self.resize_multi_channel_image(img_tensor_list, (ow, oh)) 85 | else: 86 | return self.resize_multi_channel_image(img_tensor_list, self.size) 87 | 88 | 89 | class CenterCrop: 90 | def __init__(self, size, consistent=True): 91 | if isinstance(size, numbers.Number): 92 | self.size = (int(size), int(size)) 93 | else: 94 | self.size = size 95 | 96 | def __call__(self, imgmap): 97 | img1 = imgmap[0] 98 | w, h = img1.size 99 | th, tw = self.size 100 | x1 = int(round((w - tw) / 2.)) 101 | y1 = int(round((h - th) / 2.)) 102 | return [i.crop((x1, y1, x1 + tw, y1 + th)) for i in imgmap] 103 | 104 | 105 | class RandomCropWithProb: 106 | def __init__(self, size, p=0.8, consistent=True): 107 | if isinstance(size, numbers.Number): 108 | self.size = (int(size), int(size)) 109 | else: 110 | self.size = size 111 | self.consistent = consistent 112 | self.threshold = p 113 | 114 | def __call__(self, imgmap): 115 | img1 = imgmap[0] 116 | w, h = img1.size 117 | if self.size is not None: 118 | th, tw = self.size 119 | if w == tw and h == th: 120 | return imgmap 121 | if self.consistent: 122 | if random.random() < self.threshold: 123 | x1 = random.randint(0, w - tw) 124 | y1 = random.randint(0, h - th) 125 | else: 126 | x1 = int(round((w - tw) / 2.)) 127 | y1 = int(round((h - th) / 2.)) 128 | return [i.crop((x1, y1, x1 + tw, y1 + th)) for i in imgmap] 129 | else: 130 | result = [] 131 | for i in imgmap: 132 | if random.random() < self.threshold: 133 | x1 = random.randint(0, w - tw) 134 | y1 = random.randint(0, h - th) 135 | else: 136 | x1 = int(round((w - tw) / 2.)) 137 | y1 = int(round((h - th) / 2.)) 138 | result.append(i.crop((x1, y1, x1 + tw, y1 + th))) 139 | return result 140 | else: 141 | return imgmap 142 | 143 | 144 | class RandomCrop: 145 | def __init__(self, size, consistent=True): 146 | if isinstance(size, numbers.Number): 147 | self.size = (int(size), int(size)) 148 | else: 149 | self.size = size 150 | self.consistent = consistent 151 | 152 | def __call__(self, imgmap, flowmap=None): 153 | img1 = imgmap[0] 154 | w, h = img1.size 155 | if self.size is not None: 156 | th, tw = self.size 157 | if w == tw and h == th: 158 | return imgmap 159 | if not flowmap: 160 | if self.consistent: 161 | x1 = random.randint(0, w - tw) 162 | y1 = random.randint(0, h - th) 163 | return [i.crop((x1, y1, x1 + tw, y1 + th)) for i in imgmap] 164 | else: 165 | result = [] 166 | for i in imgmap: 167 | x1 = random.randint(0, w - tw) 168 | y1 = random.randint(0, h - th) 169 | result.append(i.crop((x1, y1, x1 + tw, y1 + th))) 170 | return result 171 | elif flowmap is not None: 172 | assert (not self.consistent) 173 | result = [] 174 | for idx, i in enumerate(imgmap): 175 | proposal = [] 176 | for j in range(3): # number of proposal: use the one with largest optical flow 177 | x = random.randint(0, w - tw) 178 | y = random.randint(0, h - th) 179 | proposal.append([x, y, abs(np.mean(flowmap[idx,y:y+th,x:x+tw,:]))]) 180 | [x1, y1, _] = max(proposal, key=lambda x: x[-1]) 181 | result.append(i.crop((x1, y1, x1 + tw, y1 + th))) 182 | return result 183 | else: 184 | raise ValueError('wrong case') 185 | else: 186 | return imgmap 187 | 188 | 189 | import torch 190 | 191 | 192 | class RandomIntensityCropForTensors: 193 | def __init__(self, size): 194 | if isinstance(size, numbers.Number): 195 | self.size = (int(size), int(size)) 196 | else: 197 | self.size = size 198 | 199 | def __call__(self, img_tensor_list): 200 | img1 = img_tensor_list[0] 201 | # Expected format 202 | c, h, w = img1.shape 203 | assert c < 20, "Invalid channel size: 
{}".format(img1.shape) 204 | 205 | if self.size is not None: 206 | th, tw = self.size 207 | if w == tw and h == th: 208 | return img_tensor_list 209 | 210 | proposals = [] 211 | # number of proposal: use the one with largest sum of values 212 | for j in range(3): 213 | x = random.randint(0, w - tw) 214 | y = random.randint(0, h - th) 215 | val = \ 216 | sum([torch.mean(torch.abs(img_tensor_list[idx][:, y:y + th, x:x + tw])) for idx in range(len(img_tensor_list))]) 217 | proposals.append(((x, y), val)) 218 | 219 | (x, y), _ = max(proposals, key=lambda x: x[1]) 220 | crops = [i[:, y:y + th, x:x + tw] for i in img_tensor_list] 221 | return crops 222 | else: 223 | return img_tensor_list 224 | 225 | 226 | class RandomIntensityCropForFlow: 227 | def __init__(self, size): 228 | if isinstance(size, numbers.Number): 229 | self.size = (int(size), int(size)) 230 | else: 231 | self.size = size 232 | 233 | def __call__(self, imgmap): 234 | img1 = imgmap[0] 235 | w, h = img1.size 236 | if self.size is not None: 237 | th, tw = self.size 238 | if w == tw and h == th: 239 | return imgmap 240 | 241 | proposals = [] 242 | 243 | # Process img_arrs 244 | img_arrs = [np.asarray(img, dtype=float) for img in imgmap] 245 | img_arrs = [(img * 0.0) + 127. if np.max(img) < 10.0 else img for img in img_arrs] 246 | # Assuming that flow data passed has mean > 100.0 247 | img_arrs = np.array(img_arrs) - 127. 248 | 249 | # number of proposal: use the one with largest sum of values 250 | for j in range(3): 251 | try: 252 | x = random.randint(0, w - tw) 253 | y = random.randint(0, h - th) 254 | except: 255 | print("Error:", w, h, tw, th, img_arrs.shape) 256 | val = np.mean(np.abs(img_arrs[:, y:y + th, x:x + tw, :])) 257 | proposals.append(((x, y), val)) 258 | 259 | (x, y), _ = max(proposals, key=lambda x: x[1]) 260 | crops = [i.crop((x, y, x + tw, y + th)) for i in imgmap] 261 | return crops 262 | else: 263 | return imgmap 264 | 265 | 266 | class CenterCropForTensors: 267 | def __init__(self, size): 268 | if isinstance(size, numbers.Number): 269 | self.size = (int(size), int(size)) 270 | else: 271 | self.size = size 272 | 273 | def __call__(self, img_tensor_list): 274 | img1 = img_tensor_list[0] 275 | # Expected format 276 | c, h, w = img1.shape 277 | assert c < 20, "Invalid channel size: {}".format(img1.shape) 278 | 279 | th, tw = self.size 280 | x = int(round((w - tw) / 2.)) 281 | y = int(round((h - th) / 2.)) 282 | try: 283 | result = [img_tensor[:, y:y + th, x:x + tw] for img_tensor in img_tensor_list] 284 | except: 285 | print(img_tensor_list[0].shape, y, th, x, tw) 286 | return result 287 | 288 | 289 | class RandomSizedCrop: 290 | def __init__(self, size, interpolation=Image.BILINEAR, consistent=True, p=1.0): 291 | self.size = size 292 | self.interpolation = interpolation 293 | self.consistent = consistent 294 | self.threshold = p 295 | 296 | def __call__(self, imgmap): 297 | img1 = imgmap[0] 298 | if random.random() < self.threshold: # do RandomSizedCrop 299 | for attempt in range(10): 300 | area = img1.size[0] * img1.size[1] 301 | target_area = random.uniform(0.5, 1) * area 302 | aspect_ratio = random.uniform(3. / 4, 4. 
/ 3) 303 | 304 | w = int(round(math.sqrt(target_area * aspect_ratio))) 305 | h = int(round(math.sqrt(target_area / aspect_ratio))) 306 | 307 | if self.consistent: 308 | if random.random() < 0.5: 309 | w, h = h, w 310 | if w <= img1.size[0] and h <= img1.size[1]: 311 | x1 = random.randint(0, img1.size[0] - w) 312 | y1 = random.randint(0, img1.size[1] - h) 313 | 314 | imgmap = [i.crop((x1, y1, x1 + w, y1 + h)) for i in imgmap] 315 | for i in imgmap: assert(i.size == (w, h)) 316 | 317 | return [i.resize((self.size, self.size), self.interpolation) for i in imgmap] 318 | else: 319 | result = [] 320 | for i in imgmap: 321 | if random.random() < 0.5: 322 | w, h = h, w 323 | if w <= img1.size[0] and h <= img1.size[1]: 324 | x1 = random.randint(0, img1.size[0] - w) 325 | y1 = random.randint(0, img1.size[1] - h) 326 | result.append(i.crop((x1, y1, x1 + w, y1 + h))) 327 | assert(result[-1].size == (w, h)) 328 | else: 329 | result.append(i) 330 | 331 | assert len(result) == len(imgmap) 332 | return [i.resize((self.size, self.size), self.interpolation) for i in result] 333 | 334 | # Fallback 335 | scale = Scale(self.size, interpolation=self.interpolation) 336 | crop = CenterCrop(self.size) 337 | return crop(scale(imgmap)) 338 | else: # don't do RandomSizedCrop, do CenterCrop 339 | crop = CenterCrop(self.size) 340 | return crop(imgmap) 341 | 342 | 343 | class RandomHorizontalFlip: 344 | def __init__(self, consistent=True, command=None): 345 | self.consistent = consistent 346 | if command == 'left': 347 | self.threshold = 0 348 | elif command == 'right': 349 | self.threshold = 1 350 | else: 351 | self.threshold = 0.5 352 | 353 | def __call__(self, imgmap): 354 | if self.consistent: 355 | if random.random() < self.threshold: 356 | return [i.transpose(Image.FLIP_LEFT_RIGHT) for i in imgmap] 357 | else: 358 | return imgmap 359 | else: 360 | result = [] 361 | for i in imgmap: 362 | if random.random() < self.threshold: 363 | result.append(i.transpose(Image.FLIP_LEFT_RIGHT)) 364 | else: 365 | result.append(i) 366 | assert len(result) == len(imgmap) 367 | return result 368 | 369 | 370 | class RandomGray: 371 | '''Actually it is a channel splitting, not strictly grayscale images''' 372 | def __init__(self, consistent=True, p=0.5): 373 | self.consistent = consistent 374 | self.p = p # probability to apply grayscale 375 | 376 | def __call__(self, imgmap): 377 | if self.consistent: 378 | if random.random() < self.p: 379 | return [self.grayscale(i) for i in imgmap] 380 | else: 381 | return imgmap 382 | else: 383 | result = [] 384 | for i in imgmap: 385 | if random.random() < self.p: 386 | result.append(self.grayscale(i)) 387 | else: 388 | result.append(i) 389 | assert len(result) == len(imgmap) 390 | return result 391 | 392 | def grayscale(self, img): 393 | channel = np.random.choice(3) 394 | np_img = np.array(img)[:,:,channel] 395 | np_img = np.dstack([np_img, np_img, np_img]) 396 | img = Image.fromarray(np_img, 'RGB') 397 | return img 398 | 399 | 400 | class ColorJitter(object): 401 | """Randomly change the brightness, contrast and saturation of an image. --modified from pytorch source code 402 | Args: 403 | brightness (float or tuple of float (min, max)): How much to jitter brightness. 404 | brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness] 405 | or the given [min, max]. Should be non negative numbers. 406 | contrast (float or tuple of float (min, max)): How much to jitter contrast. 
407 | contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast] 408 | or the given [min, max]. Should be non-negative numbers. 409 | saturation (float or tuple of float (min, max)): How much to jitter saturation. 410 | saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation] 411 | or the given [min, max]. Should be non-negative numbers. 412 | hue (float or tuple of float (min, max)): How much to jitter hue. 413 | hue_factor is chosen uniformly from [-hue, hue] or the given [min, max]. 414 | Should have 0 <= hue <= 0.5 or -0.5 <= min <= max <= 0.5. 415 | """ 416 | def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, consistent=False, p=1.0): 417 | self.brightness = self._check_input(brightness, 'brightness') 418 | self.contrast = self._check_input(contrast, 'contrast') 419 | self.saturation = self._check_input(saturation, 'saturation') 420 | self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5), 421 | clip_first_on_zero=False) 422 | self.consistent = consistent 423 | self.threshold = p 424 | 425 | def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True): 426 | if isinstance(value, numbers.Number): 427 | if value < 0: 428 | raise ValueError("If {} is a single number, it must be non-negative.".format(name)) 429 | value = [center - value, center + value] 430 | if clip_first_on_zero: 431 | value[0] = max(value[0], 0) 432 | elif isinstance(value, (tuple, list)) and len(value) == 2: 433 | if not bound[0] <= value[0] <= value[1] <= bound[1]: 434 | raise ValueError("{} values should be between {}".format(name, bound)) 435 | else: 436 | raise TypeError("{} should be a single number or a list/tuple with length 2.".format(name)) 437 | 438 | # if value is 0 or (1., 1.) for brightness/contrast/saturation 439 | # or (0., 0.) for hue, do nothing 440 | if value[0] == value[1] == center: 441 | value = None 442 | return value 443 | 444 | @staticmethod 445 | def get_params(brightness, contrast, saturation, hue): 446 | """Get a randomized transform to be applied on the image. 447 | Arguments are the same as those of __init__. 448 | Returns: 449 | Transform which randomly adjusts brightness, contrast and 450 | saturation in a random order.
451 | """ 452 | transforms = [] 453 | 454 | if brightness is not None: 455 | brightness_factor = random.uniform(brightness[0], brightness[1]) 456 | transforms.append(torchvision.transforms.Lambda(lambda img: F.adjust_brightness(img, brightness_factor))) 457 | 458 | if contrast is not None: 459 | contrast_factor = random.uniform(contrast[0], contrast[1]) 460 | transforms.append(torchvision.transforms.Lambda(lambda img: F.adjust_contrast(img, contrast_factor))) 461 | 462 | if saturation is not None: 463 | saturation_factor = random.uniform(saturation[0], saturation[1]) 464 | transforms.append(torchvision.transforms.Lambda(lambda img: F.adjust_saturation(img, saturation_factor))) 465 | 466 | if hue is not None: 467 | hue_factor = random.uniform(hue[0], hue[1]) 468 | transforms.append(torchvision.transforms.Lambda(lambda img: F.adjust_hue(img, hue_factor))) 469 | 470 | random.shuffle(transforms) 471 | transform = torchvision.transforms.Compose(transforms) 472 | 473 | return transform 474 | 475 | def __call__(self, imgmap): 476 | if random.random() < self.threshold: # do ColorJitter 477 | if self.consistent: 478 | transform = self.get_params(self.brightness, self.contrast, 479 | self.saturation, self.hue) 480 | return [transform(i) for i in imgmap] 481 | else: 482 | result = [] 483 | for img in imgmap: 484 | transform = self.get_params(self.brightness, self.contrast, 485 | self.saturation, self.hue) 486 | result.append(transform(img)) 487 | return result 488 | else: # don't do ColorJitter, do nothing 489 | return imgmap 490 | 491 | def __repr__(self): 492 | format_string = self.__class__.__name__ + '(' 493 | format_string += 'brightness={0}'.format(self.brightness) 494 | format_string += ', contrast={0}'.format(self.contrast) 495 | format_string += ', saturation={0}'.format(self.saturation) 496 | format_string += ', hue={0})'.format(self.hue) 497 | return format_string 498 | 499 | 500 | class RandomRotation: 501 | def __init__(self, consistent=True, degree=15, p=1.0): 502 | self.consistent = consistent 503 | self.degree = degree 504 | self.threshold = p 505 | def __call__(self, imgmap): 506 | if random.random() < self.threshold: # do RandomRotation 507 | if self.consistent: 508 | deg = np.random.randint(-self.degree, self.degree, 1)[0] 509 | return [i.rotate(deg, expand=True) for i in imgmap] 510 | else: 511 | return [i.rotate(np.random.randint(-self.degree, self.degree, 1)[0], expand=True) for i in imgmap] 512 | else: # don't do RandomRotation, do nothing 513 | return imgmap 514 | 515 | class ToTensor: 516 | def __call__(self, imgmap): 517 | totensor = transforms.ToTensor() 518 | return [totensor(i) for i in imgmap] 519 | 520 | class Normalize: 521 | def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]): 522 | self.mean = mean 523 | self.std = std 524 | def __call__(self, imgmap): 525 | normalize = transforms.Normalize(mean=self.mean, std=self.std) 526 | return [normalize(i) for i in imgmap] 527 | 528 | 529 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import pickle 4 | import os 5 | from datetime import datetime 6 | import glob 7 | import re 8 | import matplotlib.pyplot as plt 9 | plt.switch_backend('agg') 10 | from collections import deque 11 | from tqdm import tqdm 12 | from torchvision import transforms 13 | 14 | 15 | def save_checkpoint(state, mode, is_best=0, gap=1, 
filename='models/checkpoint.pth.tar', keep_all=False): 16 | torch.save(state, filename) 17 | last_epoch_path = os.path.join( 18 | os.path.dirname(filename), 'mode_' + mode + '_epoch%s.pth.tar' % str(state['epoch']-gap)) 19 | alternate_last_epoch_path = os.path.join(os.path.dirname(filename), 'epoch%s.pth.tar' % str(state['epoch']-gap)) 20 | if not keep_all: 21 | try: 22 | os.remove(last_epoch_path) 23 | except: 24 | try: 25 | os.remove(alternate_last_epoch_path) 26 | except: 27 | print("Couldn't remove last_epoch_path: ", last_epoch_path, alternate_last_epoch_path) 28 | pass 29 | if is_best: 30 | past_best = glob.glob(os.path.join(os.path.dirname(filename), 'mode_' + mode + '_model_best_*.pth.tar')) 31 | for i in past_best: 32 | try: os.remove(i) 33 | except: pass 34 | torch.save( 35 | state, 36 | os.path.join( 37 | os.path.dirname(filename), 38 | 'mode_' + mode + '_model_best_epoch%s.pth.tar' % str(state['epoch']) 39 | ) 40 | ) 41 | 42 | 43 | def write_log(content, epoch, filename): 44 | if not os.path.exists(filename): 45 | log_file = open(filename, 'w') 46 | else: 47 | log_file = open(filename, 'a') 48 | log_file.write('## Epoch %d:\n' % epoch) 49 | log_file.write('time: %s\n' % str(datetime.now())) 50 | log_file.write(content + '\n\n') 51 | log_file.close() 52 | 53 | 54 | def calc_topk_accuracy(output, target, topk=(1,)): 55 | ''' 56 | Modified from: https://gist.github.com/agermanidis/275b23ad7a10ee89adccf021536bb97e 57 | Given predicted and ground truth labels, 58 | calculate top-k accuracies. 59 | ''' 60 | maxk = max(topk) 61 | batch_size = target.size(0) 62 | 63 | _, pred = output.topk(maxk, 1, True, True) 64 | pred = pred.t() 65 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 66 | 67 | res = [] 68 | for k in topk: 69 | correct_k = correct[:k].view(-1).float().sum(0) 70 | res.append(correct_k.mul_(1 / batch_size)) 71 | return res 72 | 73 | 74 | def calc_accuracy(output, target): 75 | '''output: (B, N); target: (B)''' 76 | target = target.squeeze() 77 | _, pred = torch.max(output, 1) 78 | return torch.mean((pred == target).float()) 79 | 80 | def calc_accuracy_binary(output, target): 81 | '''output, target: (B, N), output is logits, before sigmoid ''' 82 | pred = output > 0 83 | acc = torch.mean((pred == target.byte()).float()) 84 | del pred, output, target 85 | return acc 86 | 87 | def denorm(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]): 88 | assert len(mean) == len(std) == 3 89 | inv_mean = [-mean[i]/std[i] for i in range(3)] 90 | inv_std = [1/i for i in std] 91 | return transforms.Normalize(mean=inv_mean, std=inv_std) 92 | 93 | 94 | class AverageMeter(object): 95 | """Computes and stores the average and current value""" 96 | def __init__(self): 97 | self.reset() 98 | 99 | def reset(self): 100 | self.val = 0 101 | self.avg = 0 102 | self.sum = 0 103 | self.count = 0 104 | self.local_history = deque([]) 105 | self.local_avg = 0 106 | self.history = [] 107 | self.dict = {} # save all data values here 108 | self.save_dict = {} # save mean and std here, for summary table 109 | 110 | def update(self, val, n=1, history=0, step=100): 111 | self.val = val 112 | self.sum += val * n 113 | self.count += n 114 | self.avg = self.sum / self.count 115 | if history: 116 | self.history.append(val) 117 | if step > 0: 118 | self.local_history.append(val) 119 | if len(self.local_history) > step: 120 | self.local_history.popleft() 121 | self.local_avg = np.average(self.local_history) 122 | 123 | def dict_update(self, val, key): 124 | if key in self.dict.keys(): 125 | 
self.dict[key].append(val) 126 | else: 127 | self.dict[key] = [val] 128 | 129 | def __len__(self): 130 | return self.count 131 | 132 | 133 | class AccuracyTable(object): 134 | '''compute accuracy for each class''' 135 | def __init__(self, names): 136 | self.names = names 137 | self.dict = {} 138 | 139 | def update(self, pred, tar): 140 | pred = pred.flatten() 141 | tar = tar.flatten() 142 | for i, j in zip(pred, tar): 143 | i = int(i) 144 | j = int(j) 145 | if j not in self.dict.keys(): 146 | self.dict[j] = {'count':0,'correct':0} 147 | self.dict[j]['count'] += 1 148 | if i == j: 149 | self.dict[j]['correct'] += 1 150 | 151 | def print_table(self): 152 | for key in sorted(self.dict.keys()): 153 | acc = self.dict[key]['correct'] / self.dict[key]['count'] 154 | print('%25s: %5d, acc: %3d/%3d = %0.6f' \ 155 | % (self.names[key], key, self.dict[key]['correct'], self.dict[key]['count'], acc)) 156 | 157 | def print_dict(self): 158 | acc_dict = {} 159 | for key in sorted(self.dict.keys()): 160 | acc_dict[self.names[key].lower()] = self.dict[key]['correct'] / self.dict[key]['count'] 161 | print(acc_dict) 162 | 163 | class ConfusionMeter(object): 164 | '''compute and show confusion matrix''' 165 | def __init__(self, num_class): 166 | self.num_class = num_class 167 | self.mat = np.zeros((num_class, num_class)) 168 | self.precision = [] 169 | self.recall = [] 170 | 171 | def update(self, pred, tar): 172 | pred, tar = pred.cpu().numpy(), tar.cpu().numpy() 173 | pred = np.squeeze(pred) 174 | tar = np.squeeze(tar) 175 | for p,t in zip(pred.flat, tar.flat): 176 | self.mat[p][t] += 1 177 | 178 | def print_mat(self): 179 | print('Confusion Matrix: (target in columns)') 180 | print(self.mat) 181 | 182 | def plot_mat(self, path, dictionary=None, annotate=False): 183 | plt.figure(dpi=600) 184 | plt.imshow(self.mat, 185 | cmap=plt.cm.jet, 186 | interpolation=None, 187 | extent=(0.5, np.shape(self.mat)[0]+0.5, np.shape(self.mat)[1]+0.5, 0.5)) 188 | width, height = self.mat.shape 189 | if annotate: 190 | for x in range(width): 191 | for y in range(height): 192 | plt.annotate(str(int(self.mat[x][y])), xy=(y+1, x+1), 193 | horizontalalignment='center', 194 | verticalalignment='center', 195 | fontsize=8) 196 | 197 | if dictionary is not None: 198 | plt.xticks([i+1 for i in range(width)], 199 | [dictionary[i] for i in range(width)], 200 | rotation='vertical') 201 | plt.yticks([i+1 for i in range(height)], 202 | [dictionary[i] for i in range(height)]) 203 | plt.xlabel('Ground Truth') 204 | plt.ylabel('Prediction') 205 | plt.colorbar() 206 | plt.tight_layout() 207 | plt.savefig(path, format='svg') 208 | plt.clf() 209 | 210 | # for i in range(width): 211 | # if np.sum(self.mat[i,:]) != 0: 212 | # self.precision.append(self.mat[i,i] / np.sum(self.mat[i,:])) 213 | # if np.sum(self.mat[:,i]) != 0: 214 | # self.recall.append(self.mat[i,i] / np.sum(self.mat[:,i])) 215 | # print('Average Precision: %0.4f' % np.mean(self.precision)) 216 | # print('Average Recall: %0.4f' % np.mean(self.recall)) 217 | 218 | 219 | 220 | 221 | --------------------------------------------------------------------------------
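For reference, here is a minimal usage sketch (not part of the repository) exercising two of the utilities shown above: `MemoryBank` from `train/sim_utils.py` and `AverageMeter` from `utils/utils.py`. The import paths, bank size, and dummy tensor shapes are illustrative assumptions; adjust them to however the repo sits on your PYTHONPATH.
```
# Illustrative sketch, assuming train/ and utils/ are importable as packages.
import torch
from train.sim_utils import MemoryBank
from utils.utils import AverageMeter

bank = MemoryBank(size=64)        # keeps the 64 most recent feature rows
meter = AverageMeter()

for step in range(4):
    feats = torch.randn(16, 128)            # dummy batch of 16 feature vectors
    bank.update(feats)                      # drops the oldest 16 rows, appends the new batch
    appended = bank.fetchAppended(feats)    # current batch concatenated with the stored bank
    meter.update(val=appended.shape[0], n=1)

print(bank.fetchBank().shape)     # torch.Size([64, 128])
print(meter.avg)                  # running average of the logged values (80.0 here)
```
Note that `MemoryBank.update` keeps the bank at a fixed length by removing the oldest `N` rows before concatenating the detached new batch, so the stored features never receive gradients.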