├── train ├── data_utils.py ├── finetune_utils.py ├── mask_utils.py ├── sim_utils.py ├── hierarchical_trainer.py └── model_utils.py ├── LICENSE ├── backbone ├── select_backbone.py ├── vgg.py ├── convrnn.py └── resnet_2d3d.py ├── process_data ├── readme.md └── src │ ├── build_rawframes_optimized.py │ ├── extract_frame.py │ ├── write_csv.py │ └── extract_features.py ├── requirements.txt ├── test ├── model_3d_lc.py ├── transform_utils.py ├── dataset_3d_lc.py └── test.py ├── README.md └── utils ├── utils.py └── augmentation.py /train/data_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from collections import defaultdict 3 | 4 | 5 | def individual_collate(batch): 6 | """ 7 | Custom collation function for collate with new implementation of individual samples in data pipeline 8 | """ 9 | 10 | data = batch 11 | 12 | collected_data = defaultdict(list) 13 | 14 | for i in range(len(list(data))): 15 | for k in data[i].keys(): 16 | collected_data[k].append(data[i][k]) 17 | 18 | for k in collected_data.keys(): 19 | collected_data[k] = torch.stack(collected_data[k]) 20 | 21 | return collected_data 22 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Nishant Rai 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | 
--------------------------------------------------------------------------------
/backbone/select_backbone.py:
--------------------------------------------------------------------------------
1 | from resnet_2d3d import *
2 | 
3 | def select_resnet(network, track_running_stats=True, in_channels=3):
4 |     param = {'feature_size': 1024}
5 |     if network == 'resnet18':
6 |         model = resnet18_2d3d_full(track_running_stats=track_running_stats, in_channels=in_channels)
7 |         param['feature_size'] = 256
8 |     elif network == 'resnet34':
9 |         model = resnet34_2d3d_full(track_running_stats=track_running_stats, in_channels=in_channels)
10 |         param['feature_size'] = 256
11 |     elif network == 'resnet50':
12 |         model = resnet50_2d3d_full(track_running_stats=track_running_stats, in_channels=in_channels)
13 |     elif network == 'resnet101':
14 |         model = resnet101_2d3d_full(track_running_stats=track_running_stats, in_channels=in_channels)
15 |     elif network == 'resnet152':
16 |         model = resnet152_2d3d_full(track_running_stats=track_running_stats, in_channels=in_channels)
17 |     elif network == 'resnet200':
18 |         model = resnet200_2d3d_full(track_running_stats=track_running_stats, in_channels=in_channels)
19 |     else: raise ValueError('Unknown network type: {}'.format(network))
20 | 
21 |     return model, param
--------------------------------------------------------------------------------
/process_data/readme.md:
--------------------------------------------------------------------------------
1 | ## Process data
2 | 
3 | This folder contains tools to process the UCF101, HMDB51 and Kinetics400 datasets.
4 | 
5 | ### 1. Download
6 | 
7 | Download the videos from the sources:
8 | [UCF101 source](https://www.crcv.ucf.edu/data/UCF101.php),
9 | [HMDB51 source](http://serre-lab.clps.brown.edu/resource/hmdb-a-large-human-motion-database/#Downloads),
10 | [Kinetics400 source](https://deepmind.com/research/publications/kinetics-human-action-video-dataset).
11 | 
12 | Make sure the datasets are stored as follows:
13 | 
14 | * UCF101
15 | ```
16 | {your_path}/UCF101/videos/{action class}/{video name}.avi
17 | {your_path}/UCF101/splits_classification/trainlist{01/02/03}.txt
18 | {your_path}/UCF101/splits_classification/testlist{01/02/03}.txt
19 | ```
20 | 
21 | * HMDB51
22 | ```
23 | {your_path}/HMDB51/videos/{action class}/{video name}.avi
24 | {your_path}/HMDB51/split/testTrainMulti_7030_splits/{action class}_test_split{1/2/3}.txt
25 | ```
26 | 
27 | * Kinetics400
28 | ```
29 | {your_path}/Kinetics400/videos/train_split/{action class}/{video name}.mp4
30 | {your_path}/Kinetics400/videos/val_split/{action class}/{video name}.mp4
31 | ```
32 | Also keep the downloaded csv files; make sure you have:
33 | ```
34 | {your_path}/Kinetics/kinetics_train/kinetics_train.csv
35 | {your_path}/Kinetics/kinetics_val/kinetics_val.csv
36 | {your_path}/Kinetics/kinetics_test/kinetics_test.csv
37 | ```
38 | 
39 | ### 2. Extract frames
40 | 
41 | Edit the path arguments in the `main_*()` functions, then run `python extract_frame.py` to extract the video frames.
42 | 
43 | ### 3. Collect all paths into csv
44 | 
45 | Edit the path arguments in the `main_*()` functions, then run `python write_csv.py`. The csv files will be stored in the `data/` directory.
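For reference, the sketch below shows one way the extracted frame folders could be indexed into a csv. It is only an illustration: the helper name and the `(frame_folder, num_frames)` row format are assumptions, so check `write_csv.py` for the exact columns and train/test split handling it uses.
```
# Illustrative sketch only; write_csv.py is the authoritative script.
# Assumed row format: one (frame_folder_path, num_frames) row per extracted video.
import csv
import glob
import os

def collect_frame_folders(frame_root, out_csv):
    rows = []
    # {frame_root}/{action class}/{video name}/ holds the extracted jpg frames.
    for video_dir in sorted(glob.glob(os.path.join(frame_root, '*', '*/'))):
        num_frames = len(glob.glob(os.path.join(video_dir, '*.jpg')))
        if num_frames > 0:
            rows.append((video_dir.rstrip('/'), num_frames))
    with open(out_csv, 'w', newline='') as f:
        csv.writer(f).writerows(rows)

# e.g. collect_frame_folders('{your_path}/UCF101/frame', 'data/ucf101_frames.csv')
```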
46 | 47 | 48 | 49 | 50 | 51 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.8.1 2 | astor==0.8.0 3 | attrs==19.3.0 4 | backcall==0.1.0 5 | bleach==3.1.0 6 | cachetools==3.1.1 7 | calmsize==0.1.3 8 | certifi==2019.11.28 9 | cffi==1.13.1 10 | chardet==3.0.4 11 | cloudpickle==1.2.2 12 | coclust==0.2.1 13 | cycler==0.10.0 14 | Cython==0.29.14 15 | cytoolz==0.10.0 16 | dask==2.6.0 17 | decorator==4.4.1 18 | defusedxml==0.6.0 19 | detectron2==0.1+cu92 20 | docopt==0.6.2 21 | entrypoints==0.3 22 | environment-kernels==1.1.1 23 | future==0.18.2 24 | fvcore==0.1.dev200112 25 | gast==0.2.2 26 | google-auth==1.7.1 27 | google-auth-oauthlib==0.4.1 28 | google-pasta==0.1.8 29 | grpcio==1.25.0 30 | h5py==2.10.0 31 | idna==2.8 32 | imagecodecs==2020.2.18 33 | imageio==2.6.1 34 | importlib-metadata==0.23 35 | ipdb==0.12.3 36 | ipykernel==5.1.3 37 | ipython==7.9.0 38 | ipython-genutils==0.2.0 39 | ipywidgets==7.5.1 40 | jedi==0.15.1 41 | Jinja2==2.10.3 42 | joblib==0.14.0 43 | jsonschema==3.1.1 44 | jupyter==1.0.0 45 | jupyter-client==5.3.4 46 | jupyter-console==6.0.0 47 | jupyter-core==4.6.1 48 | Keras-Applications==1.0.8 49 | Keras-Preprocessing==1.1.0 50 | kiwisolver==1.1.0 51 | line-profiler==2.1.2 52 | Markdown==3.1.1 53 | MarkupSafe==1.1.1 54 | matplotlib==3.1.1 55 | mistune==0.8.4 56 | mkl-fft==1.0.14 57 | mkl-random==1.1.0 58 | mkl-service==2.3.0 59 | more-itertools==7.2.0 60 | nb-conda==2.2.1 61 | nb-conda-kernels==2.2.2 62 | nbconvert==5.6.1 63 | nbformat==4.4.0 64 | networkx==2.4 65 | nltk==3.4.5 66 | notebook==6.0.1 67 | numpy==1.17.0 68 | oauthlib==3.1.0 69 | olefile==0.46 70 | opencv-python==4.1.1.26 71 | opt-einsum==3.1.0 72 | pandas==0.25.2 73 | pandocfilters==1.4.2 74 | parso==0.5.1 75 | pexpect==4.7.0 76 | pickleshare==0.7.5 77 | Pillow==6.2.2 78 | portalocker==1.5.2 79 | prometheus-client==0.7.1 80 | prompt-toolkit==2.0.10 81 | protobuf==3.10.0 82 | ptyprocess==0.6.0 83 | pyasn1==0.4.8 84 | pyasn1-modules==0.2.7 85 | pycparser==2.19 86 | pydot==1.4.1 87 | Pygments==2.4.2 88 | pyparsing==2.4.2 89 | pyrsistent==0.15.4 90 | python-dateutil==2.8.0 91 | pytorch-memlab==0.1.0 92 | pytz==2019.3 93 | PyWavelets==1.1.1 94 | PyYAML==5.3 95 | pyzmq==18.1.0 96 | qtconsole==4.5.5 97 | requests==2.22.0 98 | requests-oauthlib==1.3.0 99 | resample2d-cuda==0.0.0 100 | rsa==4.0 101 | scikit-image==0.15.0 102 | scikit-learn==0.21.3 103 | scipy==1.1.0 104 | seaborn==0.10.1 105 | Send2Trash==1.5.0 106 | sewar==0.4.2 107 | six==1.12.0 108 | tabulate==0.8.6 109 | tensorboard==2.0.1 110 | tensorboardX==1.9 111 | tensorflow==2.0.0 112 | tensorflow-estimator==2.0.1 113 | termcolor==1.1.0 114 | terminado==0.8.2 115 | testpath==0.4.2 116 | tifffile==2020.2.16 117 | toolz==0.10.0 118 | torch==1.3.0 119 | torchvision==0.4.1 120 | tornado==6.0.3 121 | tqdm==4.42.0 122 | traitlets==4.3.3 123 | traj-conv-cuda==0.0.0 124 | tsne==0.1.8 125 | urllib3==1.25.7 126 | wcwidth==0.1.7 127 | webencodings==0.5.1 128 | Werkzeug==0.16.0 129 | widgetsnbextension==3.5.1 130 | wrapt==1.11.2 131 | yacs==0.1.6 132 | zipp==0.6.0 133 | -------------------------------------------------------------------------------- /test/model_3d_lc.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import sys 4 | sys.path.append('../backbone') 5 | from select_backbone import select_resnet 6 | from convrnn import ConvGRU 7 | 8 | 
sys.path.append('../train') 9 | import model_utils as mu 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | 15 | 16 | class LC(nn.Module): 17 | def __init__(self, sample_size, num_seq, seq_len, in_channels, 18 | network='resnet18', dropout=0.5, num_class=101): 19 | super(LC, self).__init__() 20 | torch.cuda.manual_seed(666) 21 | self.sample_size = sample_size 22 | self.num_seq = num_seq 23 | self.seq_len = seq_len 24 | self.num_class = num_class 25 | self.in_channels = in_channels 26 | print('=> Using RNN + FC model with ic:', self.in_channels) 27 | 28 | print('=> Use 2D-3D %s!' % network) 29 | self.last_duration = int(math.ceil(seq_len / 4)) 30 | self.last_size = int(math.ceil(sample_size / 32)) 31 | track_running_stats = True 32 | 33 | self.backbone, self.param = \ 34 | select_resnet(network, track_running_stats=track_running_stats, in_channels=self.in_channels) 35 | self.param['num_layers'] = 1 36 | self.param['hidden_size'] = self.param['feature_size'] 37 | 38 | print('=> using ConvRNN, kernel_size = 1') 39 | self.agg = ConvGRU(input_size=self.param['feature_size'], 40 | hidden_size=self.param['hidden_size'], 41 | kernel_size=1, 42 | num_layers=self.param['num_layers']) 43 | self._initialize_weights(self.agg) 44 | 45 | self.final_bn = nn.BatchNorm1d(self.param['feature_size']) 46 | self.final_bn.weight.data.fill_(1) 47 | self.final_bn.bias.data.zero_() 48 | 49 | self.num_classes = num_class 50 | self.dropout = dropout 51 | self.hidden_size = 128 52 | self.final_fc = nn.Sequential( 53 | nn.Dropout(self.dropout), 54 | nn.Linear(self.param['feature_size'], self.num_classes), 55 | ) 56 | 57 | self._initialize_weights(self.final_fc) 58 | 59 | def forward(self, block): 60 | # seq1: [B, N, C, SL, W, H] 61 | (B, N, C, SL, H, W) = block.shape 62 | block = block.view(B*N, C, SL, H, W) 63 | feature = self.backbone(block) 64 | del block 65 | # TODO: Do we need ReLU 66 | # feature = F.relu(feature) 67 | 68 | feature = F.avg_pool3d(feature, (self.last_duration, 1, 1), stride=1) 69 | feature = feature.view(B, N, self.param['feature_size'], self.last_size, self.last_size) # [B*N,D,last_size,last_size] 70 | context, _ = self.agg(feature) 71 | context = context[:,-1,:].unsqueeze(1) 72 | context = F.avg_pool3d(context, (1, self.last_size, self.last_size), stride=1).squeeze(-1).squeeze(-1) 73 | del feature 74 | 75 | context = self.final_bn(context.transpose(-1,-2)).transpose(-1,-2) # [B,N,C] -> [B,C,N] -> BN() -> [B,N,C], because BN operates on id=1 channel. 
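# context: [B, 1, feature_size] at this point (last GRU step, spatially average-pooled);
# the dropout + linear head below maps it to class scores, so output has shape [B, 1, num_class].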
76 | output = self.final_fc(context).view(B, -1, self.num_class) 77 | 78 | return output, context 79 | 80 | def _initialize_weights(self, module): 81 | for name, param in module.named_parameters(): 82 | if 'bias' in name: 83 | nn.init.constant_(param, 0.0) 84 | elif 'weight' in name: 85 | nn.init.orthogonal_(param, 1) 86 | # other resnet weights have been initialized in resnet_3d.py 87 | 88 | 89 | -------------------------------------------------------------------------------- /backbone/vgg.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Modified from https://github.com/pytorch/vision.git 3 | ''' 4 | import math 5 | 6 | import torch.nn as nn 7 | import torch.nn.init as init 8 | 9 | __all__ = [ 10 | 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 11 | 'vgg19_bn', 'vgg19', 12 | ] 13 | 14 | 15 | class VGG(nn.Module): 16 | ''' 17 | VGG model with classification 18 | ''' 19 | def __init__(self, features): 20 | super(VGG, self).__init__() 21 | self.features = features 22 | self.classifier = nn.Sequential( 23 | nn.Dropout(), 24 | nn.Linear(512, 512), 25 | nn.ReLU(True), 26 | nn.Dropout(), 27 | nn.Linear(512, 512), 28 | nn.ReLU(True), 29 | nn.Linear(512, 10), 30 | ) 31 | # Initialize weights 32 | for m in self.modules(): 33 | if isinstance(m, nn.Conv2d): 34 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 35 | m.weight.data.normal_(0, math.sqrt(2. / n)) 36 | m.bias.data.zero_() 37 | 38 | 39 | def forward(self, x): 40 | x = self.features(x) 41 | x = x.view(x.size(0), -1) 42 | x = self.classifier(x) 43 | return x 44 | 45 | 46 | def make_layers(cfg, batch_norm=False): 47 | layers = [] 48 | in_channels = 3 49 | for v in cfg: 50 | if v == 'M': 51 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 52 | else: 53 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 54 | if batch_norm: 55 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 56 | else: 57 | layers += [conv2d, nn.ReLU(inplace=True)] 58 | in_channels = v 59 | return nn.Sequential(*layers) 60 | 61 | 62 | cfg = { 63 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 64 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 65 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 66 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 67 | 512, 512, 512, 512, 'M'], 68 | } 69 | 70 | 71 | def vgg11(): 72 | """VGG 11-layer model (configuration "A")""" 73 | return VGG(make_layers(cfg['A'])) 74 | 75 | 76 | def vgg11_bn(): 77 | """VGG 11-layer model (configuration "A") with batch normalization""" 78 | return VGG(make_layers(cfg['A'], batch_norm=True)) 79 | 80 | 81 | def vgg13(): 82 | """VGG 13-layer model (configuration "B")""" 83 | return VGG(make_layers(cfg['B'])) 84 | 85 | 86 | def vgg13_bn(): 87 | """VGG 13-layer model (configuration "B") with batch normalization""" 88 | return VGG(make_layers(cfg['B'], batch_norm=True)) 89 | 90 | 91 | def vgg16(): 92 | """VGG 16-layer model (configuration "D")""" 93 | return VGG(make_layers(cfg['D'])) 94 | 95 | 96 | def vgg16_bn(): 97 | """VGG 16-layer model (configuration "D") with batch normalization""" 98 | return VGG(make_layers(cfg['D'], batch_norm=True)) 99 | 100 | 101 | def vgg19(): 102 | """VGG 19-layer model (configuration "E")""" 103 | return VGG(make_layers(cfg['E'])) 104 | 105 | 106 | def vgg19_bn(): 107 | """VGG 19-layer model (configuration 'E') with batch normalization""" 108 
| return VGG(make_layers(cfg['E'], batch_norm=True)) 109 | 110 | 111 | def vgg19_custom(): 112 | """VGG 19-layer tweaked model (configuration 'F') with batch normalization""" 113 | return VGG(make_layers(cfg['F'], batch_norm=True)) -------------------------------------------------------------------------------- /train/finetune_utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | from sklearn import metrics 4 | from sklearn.linear_model import RidgeClassifier 5 | from sklearn.cluster import MiniBatchKMeans 6 | 7 | 8 | class QuickSupervisedModelTrainer(object): 9 | 10 | def __init__(self, num_classes, modes): 11 | self.modes = modes 12 | self.mode_pairs = [(m0, m1) for m0 in self.modes for m1 in self.modes if m0 < m1] 13 | self.ridge = {m: RidgeClassifier() for m in self.modes} 14 | self.kmeans = {m: MiniBatchKMeans(n_clusters=num_classes, random_state=0, batch_size=256) for m in self.modes} 15 | 16 | def evaluate_classification(self, trainD, valD): 17 | tic = time.time() 18 | trainY, valY = trainD["Y"].cpu().numpy(), valD["Y"].cpu().numpy() 19 | print(trainY.shape, valY.shape) 20 | for mode in self.modes: 21 | trainX, valX = trainD["X"][mode].cpu().numpy(), valD["X"][mode].cpu().numpy() 22 | print(trainX.shape, valX.shape) 23 | self.ridge[mode].fit(trainX, trainY) 24 | score = round(self.ridge[mode].score(valX, valY), 3) 25 | print("--- Mode: {} - RidgeAcc: {}".format(mode, score)) 26 | print("Time taken to perform classification evaluation:", time.time() - tic) 27 | 28 | def fit_and_predict_clustering(self, data, tag): 29 | tic = time.time() 30 | preds = {} 31 | for mode in self.modes: 32 | preds[mode] = self.kmeans[mode].fit_predict(data["X"][mode].cpu().numpy()) 33 | print("Time taken to perform {} clustering:".format(tag), time.time() - tic) 34 | return preds 35 | 36 | def evaluate_clustering_based_on_ground_truth(self, preds, label, tag): 37 | tic = time.time() 38 | return_dict = {} 39 | for mode in self.modes: 40 | ars = round(metrics.adjusted_rand_score(preds[mode], label), 3) 41 | v_measure = round(metrics.v_measure_score(preds[mode], label), 3) 42 | print("--- Mode: {} - Adj Rand. Score: {}, V-Measure: {}".format(mode, ars, v_measure)) 43 | return_dict[mode] = dict(ars=ars, v_measure=v_measure) 44 | print("Time taken to evaluate {} clustering:".format(tag), time.time() - tic) 45 | return return_dict 46 | 47 | def evaluate_clustering_based_on_mutual_information(self, preds, tag): 48 | tic = time.time() 49 | return_dict = {} 50 | for m0, m1 in self.mode_pairs: 51 | ami = round(metrics.adjusted_mutual_info_score(preds[m0], preds[m1], average_method='max'), 3) 52 | v_measure = round(metrics.v_measure_score(preds[m0], preds[m1]), 3) 53 | print("--- Modes: {}/{} - Adj MI: {}, V Measure: {}".format(m0, m1, ami, v_measure)) 54 | return_dict['{}_{}'.format(m0, m1)] = dict(ami=ami, v_measure=v_measure) 55 | print("Time taken to evaluate {} clustering MI:".format(tag), time.time() - tic) 56 | return return_dict 57 | 58 | def evaluate_clustering(self, data, tag): 59 | ''' 60 | Need to evaluate clustering using the following methods, 61 | 1. Correctness of clustering based on ground truth labels 62 | a. Adjusted Rand Score 63 | b. Homogeneity, completeness and V-measure 64 | 2. 
Mutual information based scores (across modalities) 65 | ''' 66 | label = data["Y"].cpu().numpy() 67 | 68 | preds = self.fit_and_predict_clustering(data, tag) 69 | return_dict = {} 70 | return_dict['gt'] = self.evaluate_clustering_based_on_ground_truth(preds, label, tag) 71 | return_dict['mi'] = self.evaluate_clustering_based_on_mutual_information(preds, tag) 72 | 73 | return return_dict 74 | -------------------------------------------------------------------------------- /backbone/convrnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class ConvGRUCell(nn.Module): 5 | ''' Initialize ConvGRU cell ''' 6 | def __init__(self, input_size, hidden_size, kernel_size): 7 | super(ConvGRUCell, self).__init__() 8 | self.input_size = input_size 9 | self.hidden_size = hidden_size 10 | self.kernel_size = kernel_size 11 | padding = kernel_size // 2 12 | 13 | self.reset_gate = nn.Conv2d(input_size+hidden_size, hidden_size, kernel_size, padding=padding) 14 | self.update_gate = nn.Conv2d(input_size+hidden_size, hidden_size, kernel_size, padding=padding) 15 | self.out_gate = nn.Conv2d(input_size+hidden_size, hidden_size, kernel_size, padding=padding) 16 | 17 | nn.init.orthogonal_(self.reset_gate.weight) 18 | nn.init.orthogonal_(self.update_gate.weight) 19 | nn.init.orthogonal_(self.out_gate.weight) 20 | nn.init.constant_(self.reset_gate.bias, 0.0) 21 | nn.init.constant_(self.update_gate.bias, 0.0) 22 | nn.init.constant_(self.out_gate.bias, 0.) 23 | 24 | def forward(self, input_tensor, hidden_state): 25 | if hidden_state is None: 26 | B, C, *spatial_dim = input_tensor.size() 27 | hidden_state = torch.zeros([B,self.hidden_size,*spatial_dim]).to(input_tensor.device) 28 | # [B, C, H, W] 29 | combined = torch.cat([input_tensor, hidden_state], dim=1) #concat in C 30 | update = torch.sigmoid(self.update_gate(combined)) 31 | reset = torch.sigmoid(self.reset_gate(combined)) 32 | out = torch.tanh(self.out_gate(torch.cat([input_tensor, hidden_state * reset], dim=1))) 33 | new_state = hidden_state * (1 - update) + out * update 34 | return new_state 35 | 36 | 37 | class ConvGRU(nn.Module): 38 | ''' Initialize a multi-layer Conv GRU ''' 39 | def __init__(self, input_size, hidden_size, kernel_size, num_layers, dropout=0.1): 40 | super(ConvGRU, self).__init__() 41 | self.input_size = input_size 42 | self.hidden_size = hidden_size 43 | self.kernel_size = kernel_size 44 | self.num_layers = num_layers 45 | 46 | cell_list = [] 47 | for i in range(self.num_layers): 48 | if i == 0: 49 | input_dim = self.input_size 50 | else: 51 | input_dim = self.hidden_size 52 | cell = ConvGRUCell(input_dim, self.hidden_size, self.kernel_size) 53 | name = 'ConvGRUCell_' + str(i).zfill(2) 54 | 55 | setattr(self, name, cell) 56 | cell_list.append(getattr(self, name)) 57 | 58 | self.cell_list = nn.ModuleList(cell_list) 59 | self.dropout_layer = nn.Dropout(p=dropout) 60 | 61 | def forward(self, x, hidden_state=None): 62 | [B, seq_len, *_] = x.size() 63 | 64 | if hidden_state is None: 65 | hidden_state = [None] * self.num_layers 66 | # input: image sequences [B, T, C, H, W] 67 | current_layer_input = x 68 | del x 69 | 70 | last_state_list = [] 71 | 72 | for idx in range(self.num_layers): 73 | cell_hidden = hidden_state[idx] 74 | output_inner = [] 75 | for t in range(seq_len): 76 | cell_hidden = self.cell_list[idx](current_layer_input[:,t,:], cell_hidden) 77 | cell_hidden = self.dropout_layer(cell_hidden) # dropout in each time step 78 | 
output_inner.append(cell_hidden)
79 | 
80 |             layer_output = torch.stack(output_inner, dim=1)
81 |             current_layer_input = layer_output
82 | 
83 |             last_state_list.append(cell_hidden)
84 | 
85 |         last_state_list = torch.stack(last_state_list, dim=1)
86 | 
87 |         return layer_output, last_state_list
88 | 
89 | 
90 | if __name__ == '__main__':
91 |     crnn = ConvGRU(input_size=10, hidden_size=20, kernel_size=3, num_layers=2)
92 |     data = torch.randn(4, 5, 10, 6, 6) # [B, seq_len, C, H, W], temporal axis=1
93 |     output, hn = crnn(data)
94 |     import ipdb; ipdb.set_trace()
95 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Home Action Genome
2 | 
3 | The repository for our work [Home Action Genome: Cooperative Compositional Action Understanding](https://arxiv.org/abs/2105.05226), presented at CVPR '21.
4 | 
5 | This repository contains the implementation of Cooperative Compositional Action Understanding (CCAU). We release a multi-view action dataset with multiple modalities and view-points supplemented with hierarchical activity and atomic action labels together with dense scene composition labels, along with a supplementary approach to leverage such rich annotations.
6 | 
7 | > Existing research on action recognition treats activities as monolithic events occurring in videos. Recently, the benefits of formulating actions as a combination of atomic-actions have shown promise in improving action understanding with the emergence of datasets containing such annotations, allowing us to learn representations capturing this information. However, there remains a lack of studies that extend action composition and leverage multiple viewpoints and multiple modalities of data for representation learning. To promote research in this direction, we introduce Home Action Genome (HOMAGE): a multi-view action dataset with multiple modalities and view-points supplemented with hierarchical activity and atomic action labels together with dense scene composition labels. Leveraging rich multi-modal and multi-view settings, we propose Cooperative Compositional Action Understanding (CCAU), a cooperative learning framework for hierarchical action recognition that is aware of compositional action elements. CCAU shows consistent performance improvements across all modalities. Furthermore, we demonstrate the utility of co-learning compositions in few-shot action recognition by achieving 28.6% mAP with just a single sample.
8 | 
9 | ### Installation
10 | 
11 | Our implementation should work with Python >= 3.6, PyTorch >= 0.4 and torchvision >= 0.2.2. The repo also requires cv2
12 | (`conda install -c menpo opencv`), tensorboardX >= 1.7 (`pip install tensorboardX`) and tqdm.
13 | 
14 | A `requirements.txt` has been provided, which can be used to recreate the exact environment required.
15 | ```
16 | pip install -r requirements.txt
17 | ```
18 | 
19 | ### Prepare data
20 | 
21 | Follow the instructions [here](process_data/).
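Once the frames and csv index files are ready, the training code batches dict-style samples using the custom collate function in `train/data_utils.py`, which stacks each key's tensors across the batch. The snippet below is only a toy illustration of that contract (the dataset class and tensor shapes are made up); the real dataset definitions live in `dataset_3d.py`.
```
import torch
from torch.utils.data import Dataset, DataLoader
from data_utils import individual_collate  # train/data_utils.py (run from within train/)

class ToyDictDataset(Dataset):
    """Stand-in dataset returning a dict of tensors per sample."""
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return {'imgs': torch.randn(3, 8, 64, 64), 'label': torch.tensor(idx % 2)}

loader = DataLoader(ToyDictDataset(), batch_size=4, collate_fn=individual_collate)
batch = next(iter(loader))
print(batch['imgs'].shape, batch['label'].shape)  # [4, 3, 8, 64, 64] and [4]
```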
22 | 
23 | ## Data
24 | 
25 | ### Multi Camera Perspective Videos
26 | 
27 | ![homage_dataset](https://user-images.githubusercontent.com/7645118/123186633-6a8b0c80-d44d-11eb-8928-82fbf3d06eb7.png)
28 | 
29 | ### Atomic Action Annotations
30 | 
31 | ![dataset_annotations](https://user-images.githubusercontent.com/7645118/123186626-65c65880-d44d-11eb-85b9-9bc1a15102a1.png)
32 | 
33 | ### Frame-level Scene Graph Annotation
34 | 
35 | ![scene_graph](https://user-images.githubusercontent.com/7645118/123186630-68c14900-d44d-11eb-84bb-523edc580ba1.png)
36 | 
37 | ### Cooperative Compositional Action Understanding (CCAU)
38 | 
39 | Training scripts are located in `homage/train/` (`cd homage/train/`).
40 | 
41 | Run `python model_trainer.py --help` for details about the command-line args. The most useful ones are `--dataset` and `--modalities`, which select the dataset to run experiments on and the input modalities to use.
42 | 
43 | Our implementation has been tested with ego-view RGB images, third-person view RGB images and audio. However, it is easy to extend to custom views; look at `dataset_3d.py` for details.
44 | 
45 | * Single-View Training: train CCAU on 2 GPUs, using only ego and third-person RGB inputs, with a 3D-ResNet18 backbone at 224x224 resolution, for 100 epochs. The batch size is per GPU.
46 | ```
47 | CUDA_VISIBLE_DEVICES="0,1" python model_trainer.py --net resnet18 --modalities imgs
48 | --batch_size 16 --img_dim 224 --epochs 100
49 | ```
50 | 
51 | * Multi-View Training: train CCAU on 4 GPUs, using ego and third-person RGB views and audio, with a 3D-ResNet18 backbone at 128x128 resolution, for 100 epochs.
52 | ```
53 | CUDA_VISIBLE_DEVICES="0,1,2,3" python model_trainer.py --net resnet18 --modalities imgs_audio --batch_size 16 --img_dim 128 --epochs 100
54 | ```
55 | 
56 | ### Evaluation: Video Action Recognition
57 | 
58 | Testing is available through `homage/train/model_trainer.py` via the `--test` flag, as well as through the scripts in `homage/test/`.
59 | 
60 | ## Citing
61 | 
62 | If our paper or dataset was useful to you, please consider citing it using the entry below.
63 | ~~~
64 | @InProceedings{Rai_2021_CVPR,
65 |     author = {Rai, Nishant and Chen, Haofeng and Ji, Jingwei and Desai, Rishi and Kozuka, Kazuki and Ishizaka, Shun and Adeli, Ehsan and Niebles, Juan Carlos},
66 |     title = {Home Action Genome: Cooperative Compositional Action Understanding},
67 |     booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
68 |     month = {June},
69 |     year = {2021},
70 |     pages = {11184-11193}
71 | }
72 | ~~~
73 | 
74 | ### Acknowledgements
75 | 
76 | Portions of the code have been borrowed from [CoCon](https://github.com/nishantrai18/cocon). Feel free to refer to it as well if you're interested in the field.
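The next file, `train/mask_utils.py`, constructs the integer target masks used by the contrastive losses (1: positive, -1: temporal negative, -3: spatial negative, 0: easy negative, -2: omit). Below is a small usage sketch of the instance-level variant, runnable on CPU and assuming `train/` is on the import path:
```
from mask_utils import get_standard_instance_mask  # train/mask_utils.py

# Two "query" clips, three "key" clips, four prediction steps.
mask = get_standard_instance_mask(batch_size0=2, batch_size1=3, pred_step=4, device="cpu")

print(mask.shape)        # torch.Size([2, 4, 3, 4])
print(mask[0, 1, 0, 1])  # 1  -> positive: same clip, same step
print(mask[0, 1, 0, 2])  # -1 -> temporal negative: same clip, different step
print(mask[0, 1, 1, 1])  # 0  -> easy negative: different clip
```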
77 | -------------------------------------------------------------------------------- /train/mask_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def get_standard_grid_mask(batch_size0, batch_size1, pred_step, last_size, device="cuda"): 5 | B0, B1, N, LS = batch_size0, batch_size1, pred_step, last_size 6 | device = torch.device(device) 7 | 8 | assert B0 <= B1, "Invalid B0, B1: {} {}".format(B0, B1) 9 | 10 | # mask meaning: -2: omit, -1: temporal neg (hard), 0: easy neg, 1: pos, -3: spatial neg 11 | mask = torch.zeros((B0, N, LS ** 2, B1, N, LS ** 2), dtype=torch.int8, requires_grad=False).detach().to(device) 12 | # spatial neg pairs 13 | mask[torch.arange(B0), :, :, torch.arange(B0), :, :] = -3 14 | # temporal neg pairs 15 | for k in range(B0): 16 | mask[k, :, torch.arange(LS ** 2), k, :, torch.arange(LS ** 2)] = -1 17 | tmp = mask.permute(0, 2, 1, 3, 5, 4).contiguous().view(B0 * LS ** 2, N, B1 * LS ** 2, N) 18 | # positive pairs 19 | for j in range(B0 * LS ** 2): 20 | tmp[j, torch.arange(N), j, torch.arange(N - N, N)] = 1 21 | mask = tmp.view(B0, LS ** 2, N, B1, LS ** 2, N).permute(0, 2, 1, 3, 5, 4) 22 | # Final shape: (B, N, LS**2, B, N, LS**2) 23 | assert torch.allclose(mask[:, :, :, B0:, :, :], torch.tensor(0, dtype=torch.int8)), "Invalid values" 24 | 25 | return mask 26 | 27 | 28 | def get_multi_modal_grid_mask(batch_size0, batch_size1, pred_step, last_size0, last_size1, device="cuda"): 29 | B0, B1, N, LS0, LS1 = batch_size0, batch_size1, pred_step, last_size0, last_size1 30 | device = torch.device(device) 31 | 32 | assert B0 <= B1, "Invalid B0, B1: {} {}".format(B0, B1) 33 | 34 | # mask meaning: -2: omit, -1: temporal neg (hard), 0: easy neg, 1: pos, -3: spatial neg 35 | mask = torch.zeros((B0, N, LS0 ** 2, B1, N, LS1 ** 2), dtype=torch.int8, requires_grad=False).detach().to(device) 36 | # spatial neg pairs 37 | mask[torch.arange(B0), :, :, torch.arange(B0), :, :] = -3 38 | 39 | # temporal neg pairs 40 | for k in range(B0): 41 | mask[k, :, torch.arange(LS0 ** 2), k, :, torch.arange(LS1 ** 2)] = -1 42 | tmp = mask.permute(0, 2, 1, 3, 5, 4).contiguous().view(B0, LS0, LS0, N, B1, LS1, LS1, N) 43 | # shape: (B, LS0, LS0, N, B, LS1, LS1, N) 44 | 45 | # Generate downsamplings 46 | ds0, ds1 = LS0 // min(LS0, LS1), LS1 // min(LS0, LS1) 47 | 48 | # positive pairs 49 | for j in range(B0): 50 | for i in range(min(LS0, LS1)): 51 | tmp[j, i * ds0:(i + 1) * ds0, i * ds0:(i + 1) * ds0, torch.arange(N), 52 | j, i * ds1:(i + 1) * ds1, i * ds1:(i + 1) * ds1, torch.arange(N)] = 1 53 | 54 | # Sanity check 55 | for ib in range(B0): 56 | for jn in range(N): 57 | for jls0 in range(LS0): 58 | for jls1 in range(LS1): 59 | for jls01 in range(LS0): 60 | for jls11 in range(LS1): 61 | # Check that values match 62 | if (jls0 // ds0) == (jls1 // ds1) == (jls01 // ds0) == (jls11 // ds1): 63 | assert tmp[ib, jls0, jls01, jn, ib, jls1, jls11, jn] == 1, \ 64 | "Invalid value at {}".format((ib, jls0, jls01, jn, ib, jls1, jls11, jn)) 65 | else: 66 | assert tmp[ib, jls0, jls01, jn, ib, jls1, jls11, jn] < 1, \ 67 | "Invalid value at {}".format((ib, jls0, jls01, jn, ib, jls1, jls11, jn)) 68 | assert torch.allclose(tmp[:, :, :, :, B0:, :, :, :], torch.tensor(0, dtype=torch.int8)), "Invalid values" 69 | 70 | mask = tmp.view(B0, LS0 ** 2, N, B1, LS1 ** 2, N).permute(0, 2, 1, 3, 5, 4) 71 | # Shape: (B, N, LS0**2, B, N, LS1**2) 72 | mask = mask.contiguous().view(B0, N * LS0 ** 2, B1, N * LS1 ** 2) 73 | 74 | return mask 75 | 76 | 77 | def 
get_standard_instance_mask(batch_size0, batch_size1, pred_step, device="cuda"): 78 | B0, B1, N = batch_size0, batch_size1, pred_step 79 | device = torch.device(device) 80 | 81 | assert B0 <= B1, "Invalid B0, B1: {} {}".format(B0, B1) 82 | 83 | # mask meaning: -2: omit, -1: temporal neg (hard), 0: easy neg, 1: pos, -3: spatial neg 84 | mask = torch.zeros((B0, N, B1, N), dtype=torch.int8, requires_grad=False).detach().to(device) 85 | # temporal neg pairs 86 | for k in range(B0): 87 | mask[k, :, k, :] = -1 88 | # positive pairs 89 | for j in range(B0): 90 | mask[j, torch.arange(N), j, torch.arange(N)] = 1 91 | for i in range(B0): 92 | for j in range(N): 93 | assert mask[i, j, i, j] == 1, "Invalid value at {}, {}".format(i, j) 94 | for xi in range(B0): 95 | if i == xi: 96 | continue 97 | for xj in range(N): 98 | if j == xj: 99 | continue 100 | assert mask[i, j, xi, xj] < 1, "Invalid value at {}, {}".format(i, j) 101 | assert torch.allclose(mask[:, :, B0:, :], torch.tensor(0, dtype=torch.int8)), "Invalid values" 102 | 103 | return mask 104 | 105 | 106 | def process_mask(mask): 107 | # dot product is computed in parallel gpus, so get less easy neg, bounded by batch size in each gpu''' 108 | # mask meaning: -2: omit, -1: temporal neg (hard), 0: easy neg, 1: pos, -3: spatial neg 109 | target = mask == 1 110 | # This doesn't seem to cause any issues in our implementation 111 | target.requires_grad = False 112 | return target 113 | -------------------------------------------------------------------------------- /test/transform_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import sys 3 | 4 | sys.path.append('../utils') 5 | from augmentation import * 6 | 7 | sys.path.append('../train') 8 | import model_utils as mu 9 | 10 | 11 | def get_train_transforms(args): 12 | if args.modality == mu.ImgMode: 13 | return get_imgs_train_transforms(args) 14 | elif args.modality == mu.FlowMode: 15 | return get_flow_transforms(args) 16 | elif args.modality == mu.KeypointHeatmap: 17 | return get_heatmap_transforms(args) 18 | elif args.modality == mu.SegMask: 19 | return get_segmask_transforms(args) 20 | 21 | 22 | def get_val_transforms(args): 23 | if args.modality == mu.ImgMode: 24 | return get_imgs_val_transforms(args) 25 | elif args.modality == mu.FlowMode: 26 | return get_flow_transforms(args) 27 | elif args.modality == mu.KeypointHeatmap: 28 | return get_heatmap_transforms(args) 29 | elif args.modality == mu.SegMask: 30 | return get_segmask_transforms(args) 31 | 32 | 33 | def get_test_transforms(args): 34 | if args.modality == mu.ImgMode: 35 | return get_imgs_test_transforms(args) 36 | elif args.modality == mu.FlowMode: 37 | return get_flow_test_transforms(args) 38 | elif args.modality == mu.KeypointHeatmap: 39 | return get_heatmap_test_transforms(args) 40 | elif args.modality == mu.SegMask: 41 | return get_segmask_test_transforms(args) 42 | 43 | 44 | def get_imgs_test_transforms(args): 45 | 46 | transform = transforms.Compose([ 47 | RandomSizedCrop(consistent=True, size=224, p=0.0), 48 | Scale(size=(args.img_dim, args.img_dim)), 49 | ToTensor(), 50 | Normalize() 51 | ]) 52 | 53 | return transform 54 | 55 | 56 | def get_flow_test_transforms(args): 57 | center_crop_size = 224 58 | if args.dataset == 'kinetics': 59 | center_crop_size = 128 60 | 61 | transform = transforms.Compose([ 62 | CenterCrop(size=center_crop_size, consistent=True), 63 | Scale(size=(args.img_dim, args.img_dim)), 64 | ToTensor(), 65 | ]) 66 | 67 | return transform 68 | 69 | 70 | def 
get_heatmap_test_transforms(_): 71 | transform = transforms.Compose([ 72 | CenterCropForTensors(size=192), 73 | ScaleForTensors(size=(64, 64)), 74 | ]) 75 | return transform 76 | 77 | 78 | def get_segmask_test_transforms(_): 79 | transform = transforms.Compose([ 80 | CenterCropForTensors(size=192), 81 | ScaleForTensors(size=(64, 64)), 82 | ]) 83 | return transform 84 | 85 | 86 | def get_imgs_train_transforms(args): 87 | transform = None 88 | 89 | # designed for ucf101, short size=256, rand crop to 224x224 then scale to 128x128 90 | if args.dataset == 'ucf101': 91 | transform = transforms.Compose([ 92 | RandomSizedCrop(consistent=True, size=224, p=1.0), 93 | Scale(size=(args.img_dim, args.img_dim)), 94 | RandomHorizontalFlip(consistent=True), 95 | ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.25, p=0.3, consistent=True), 96 | ToTensor(), 97 | Normalize() 98 | ]) 99 | elif (args.dataset == 'jhmdb') or (args.dataset == 'hmdb51'): 100 | transform = transforms.Compose([ 101 | RandomSizedCrop(consistent=True, size=224, p=1.0), 102 | Scale(size=(args.img_dim, args.img_dim)), 103 | RandomHorizontalFlip(consistent=True), 104 | ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.25, p=0.3, consistent=True), 105 | ToTensor(), 106 | Normalize() 107 | ]) 108 | # designed for kinetics400, short size=150, rand crop to 128x128 109 | elif args.dataset == 'kinetics': 110 | transform = transforms.Compose([ 111 | RandomSizedCrop(size=args.img_dim, consistent=True, p=1.0), 112 | RandomHorizontalFlip(consistent=True), 113 | ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.25, p=0.3, consistent=True), 114 | ToTensor(), 115 | Normalize() 116 | ]) 117 | 118 | return transform 119 | 120 | 121 | def get_imgs_val_transforms(args): 122 | transform = None 123 | 124 | # designed for ucf101, short size=256, rand crop to 224x224 then scale to 128x128 125 | if args.dataset == 'ucf101': 126 | transform = transforms.Compose([ 127 | RandomSizedCrop(consistent=True, size=224, p=0.3), 128 | Scale(size=(args.img_dim, args.img_dim)), 129 | RandomHorizontalFlip(consistent=True), 130 | ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, p=0.3, consistent=True), 131 | ToTensor(), 132 | Normalize() 133 | ]) 134 | elif (args.dataset == 'jhmdb') or (args.dataset == 'hmdb51'): 135 | transform = transforms.Compose([ 136 | RandomSizedCrop(consistent=True, size=224, p=0.3), 137 | Scale(size=(args.img_dim, args.img_dim)), 138 | RandomHorizontalFlip(consistent=True), 139 | ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, p=0.3, consistent=True), 140 | ToTensor(), 141 | Normalize() 142 | ]) 143 | # designed for kinetics400, short size=150, rand crop to 128x128 144 | elif args.dataset == 'kinetics': 145 | transform = transforms.Compose([ 146 | RandomSizedCrop(consistent=True, size=224, p=0.3), 147 | RandomHorizontalFlip(consistent=True), 148 | ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, p=0.3, consistent=True), 149 | ToTensor(), 150 | Normalize() 151 | ]) 152 | 153 | return transform 154 | 155 | 156 | def get_flow_transforms(args): 157 | transform = None 158 | 159 | # designed for ucf101, short size=256, rand crop to 224x224 then scale to 128x128 160 | if (args.dataset == 'ucf101') or (args.dataset == 'jhmdb') or (args.dataset == 'hmdb51'): 161 | transform = transforms.Compose([ 162 | RandomIntensityCropForFlow(size=224), 163 | Scale(size=(args.img_dim, args.img_dim)), 164 | ToTensor(), 165 | ]) 166 | # designed for kinetics400, short size=150, 
rand crop to 128x128 167 | elif args.dataset == 'kinetics': 168 | transform = transforms.Compose([ 169 | RandomIntensityCropForFlow(size=args.img_dim), 170 | ToTensor(), 171 | ]) 172 | 173 | return transform 174 | 175 | 176 | def get_heatmap_transforms(_): 177 | crop_size = int(192 * 0.8) 178 | transform = transforms.Compose([ 179 | RandomIntensityCropForTensors(size=crop_size), 180 | ScaleForTensors(size=(64, 64)), 181 | ]) 182 | return transform 183 | 184 | 185 | def get_segmask_transforms(_): 186 | crop_size = int(192 * 0.8) 187 | transform = transforms.Compose([ 188 | RandomIntensityCropForTensors(size=crop_size), 189 | ScaleForTensors(size=(64, 64)), 190 | ]) 191 | return transform 192 | -------------------------------------------------------------------------------- /process_data/src/build_rawframes_optimized.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import sys 3 | import os 4 | import os.path as osp 5 | import glob 6 | import cv2 7 | 8 | from pipes import quote 9 | from multiprocessing import Pool, current_process 10 | from tqdm import tqdm 11 | from subprocess import check_call,CalledProcessError 12 | 13 | import mmcv 14 | 15 | 16 | def dump_frames(vid_item): 17 | full_path, vid_path, vid_id = vid_item 18 | vid_name = vid_path.split('.')[0] 19 | out_full_path = osp.join(args.out_dir, vid_name) 20 | try: 21 | os.mkdir(out_full_path) 22 | except OSError: 23 | pass 24 | vr = mmcv.VideoReader(full_path) 25 | for i in range(len(vr)): 26 | if vr[i] is not None: 27 | mmcv.imwrite( 28 | vr[i], '{}/img_{:05d}.jpg'.format(out_full_path, i + 1)) 29 | else: 30 | print('[Warning] length inconsistent!' 31 | 'Early stop with {} out of {} frames'.format(i + 1, len(vr))) 32 | break 33 | print('{} done with {} frames'.format(vid_name, len(vr))) 34 | sys.stdout.flush() 35 | return True 36 | 37 | 38 | def num_frames_in_vid(v_path): 39 | vidcap = cv2.VideoCapture(v_path) 40 | nb_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT)) 41 | vidcap.release() 42 | return nb_frames 43 | 44 | 45 | def run_optical_flow(vid_item, dev_id=0): 46 | full_path, vid_path, vid_id = vid_item 47 | vid_name = vid_path.split('.')[0] 48 | out_full_path = osp.join(args.out_dir, vid_name) 49 | try: 50 | os.mkdir(out_full_path) 51 | except OSError: 52 | pass 53 | 54 | current = current_process() 55 | dev_id = (int(current._identity[0]) - 1) % args.num_gpu 56 | image_path = '{}/img'.format(out_full_path) 57 | flow_x_path = '{}/flow_x'.format(out_full_path) 58 | flow_y_path = '{}/flow_y'.format(out_full_path) 59 | 60 | num_frames = num_frames_in_vid(full_path) 61 | if os.path.exists(image_path + '_%05d.jpg' % (num_frames - 1)): 62 | return True 63 | 64 | try: 65 | check_call( 66 | [osp.join(args.df_path, 'build/extract_gpu'), 67 | '-f={}'.format(quote(full_path)), '-x={}'.format(quote(flow_x_path)), '-y={}'.format(quote(flow_y_path)), 68 | '-i={}'.format(quote(image_path)), '-b=20', '-t=0', '-d={}'.format(dev_id), 69 | '-s=1', '-o={}'.format(args.out_format), '-w={}'.format(args.new_width), '-h={}'.format(args.new_height)] 70 | ) 71 | except CalledProcessError as e: 72 | print(e.stdout()) 73 | 74 | return True 75 | 76 | 77 | def run_warp_optical_flow(vid_item, dev_id=0): 78 | full_path, vid_path, vid_id = vid_item 79 | vid_name = vid_path.split('.')[0] 80 | out_full_path = osp.join(args.out_dir, vid_name) 81 | try: 82 | os.mkdir(out_full_path) 83 | except OSError: 84 | pass 85 | 86 | current = current_process() 87 | dev_id = (int(current._identity[0]) - 1) % 
args.num_gpu 88 | flow_x_path = '{}/flow_x'.format(out_full_path) 89 | flow_y_path = '{}/flow_y'.format(out_full_path) 90 | 91 | cmd = osp.join(args.df_path + 'build/extract_warp_gpu') + \ 92 | ' -f={} -x={} -y={} -b=20 -t=1 -d={} -s=1 -o={}'.format( 93 | quote(full_path), quote(flow_x_path), quote(flow_y_path), 94 | dev_id, args.out_format) 95 | 96 | os.system(cmd) 97 | print('warp on {} {} done'.format(vid_id, vid_name)) 98 | sys.stdout.flush() 99 | return True 100 | 101 | 102 | def parse_args(): 103 | parser = argparse.ArgumentParser(description='extract optical flows') 104 | parser.add_argument('src_dir', type=str) 105 | parser.add_argument('out_dir', type=str) 106 | parser.add_argument('--level', type=int, 107 | choices=[1, 2], 108 | default=2) 109 | parser.add_argument('--num_worker', type=int, default=8) 110 | parser.add_argument('--flow_type', type=str, 111 | default=None, choices=[None, 'tvl1', 'warp_tvl1']) 112 | parser.add_argument('--df_path', type=str, 113 | default='../mmaction/third_party/dense_flow') 114 | parser.add_argument("--out_format", type=str, default='dir', 115 | choices=['dir', 'zip'], help='output format') 116 | parser.add_argument("--ext", type=str, default='avi', 117 | choices=['avi', 'mp4'], help='video file extensions') 118 | parser.add_argument("--new_width", type=int, default=0, 119 | help='resize image width') 120 | parser.add_argument("--new_height", type=int, 121 | default=0, help='resize image height') 122 | parser.add_argument("--num_gpu", type=int, default=8, help='number of GPU') 123 | parser.add_argument("--resume", action='store_true', default=False, 124 | help='resume optical flow extraction ' 125 | 'instead of overwriting') 126 | parser.add_argument("--debug", type=int, default=0, help='debug mode') 127 | args = parser.parse_args() 128 | 129 | return args 130 | 131 | 132 | if __name__ == '__main__': 133 | args = parse_args() 134 | 135 | if not osp.isdir(args.out_dir): 136 | print('Creating folder: {}'.format(args.out_dir)) 137 | os.makedirs(args.out_dir) 138 | if args.level == 2: 139 | classes = os.listdir(args.src_dir) 140 | for classname in classes: 141 | new_dir = osp.join(args.out_dir, classname) 142 | if not osp.isdir(new_dir): 143 | print('Creating folder: {}'.format(new_dir)) 144 | os.makedirs(new_dir) 145 | 146 | print('Reading videos from folder: ', args.src_dir) 147 | print('Extension of videos: ', args.ext) 148 | if args.level == 2: 149 | fullpath_list = glob.glob(args.src_dir + '/*/*.' + args.ext) 150 | done_fullpath_list = glob.glob(args.out_dir + '/*/*') 151 | elif args.level == 1: 152 | fullpath_list = glob.glob(args.src_dir + '/*.' + args.ext) 153 | done_fullpath_list = glob.glob(args.out_dir + '/*') 154 | print('Total number of videos found: ', len(fullpath_list)) 155 | if args.resume: 156 | fullpath_list = set(fullpath_list).difference(set(done_fullpath_list)) 157 | fullpath_list = list(fullpath_list) 158 | print('Resuming. 
number of videos to be done: ', len(fullpath_list)) 159 | 160 | fullpath_list = sorted(fullpath_list) 161 | 162 | if args.level == 2: 163 | vid_list = list(map(lambda p: osp.join( 164 | '/'.join(p.split('/')[-2:])), fullpath_list)) 165 | elif args.level == 1: 166 | vid_list = list(map(lambda p: p.split('/')[-1], fullpath_list)) 167 | 168 | if args.debug: 169 | K = 5 170 | fullpath_list = fullpath_list[:K] 171 | vid_list = vid_list[:K] 172 | args.num_worker = 4 173 | 174 | pbar = tqdm(total=len(vid_list), smoothing=0.001) 175 | pool = Pool(args.num_worker) 176 | 177 | def update(*a): 178 | pbar.update() 179 | 180 | call_func = None 181 | if args.flow_type == 'tvl1': 182 | call_func = run_optical_flow 183 | elif args.flow_type == 'warp_tvl1': 184 | call_func = run_warp_optical_flow 185 | else: 186 | call_func = dump_frames 187 | 188 | for arg in zip(fullpath_list, vid_list, range(len(vid_list))): 189 | pool.apply_async(call_func, args=(arg,), callback=update) 190 | 191 | pool.close() 192 | pool.join() -------------------------------------------------------------------------------- /process_data/src/extract_frame.py: -------------------------------------------------------------------------------- 1 | from joblib import delayed, Parallel 2 | import os 3 | import sys 4 | import glob 5 | from tqdm import tqdm 6 | import cv2 7 | import argparse 8 | 9 | import matplotlib.pyplot as plt 10 | plt.switch_backend('agg') 11 | 12 | 13 | def str2bool(s): 14 | """Convert string to bool (in argparse context).""" 15 | if s.lower() not in ['true', 'false']: 16 | raise ValueError('Need bool; got %r' % s) 17 | return {'true': True, 'false': False}[s.lower()] 18 | 19 | 20 | def extract_video_opencv(v_path, f_root, dim=240): 21 | '''v_path: single video path; 22 | f_root: root to store frames''' 23 | v_class = v_path.split('/')[-2] 24 | v_name = os.path.basename(v_path)[0:-4] 25 | out_dir = os.path.join(f_root, v_class, v_name) 26 | if not os.path.exists(out_dir): 27 | os.makedirs(out_dir) 28 | 29 | vidcap = cv2.VideoCapture(v_path) 30 | nb_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT)) 31 | width = vidcap.get(cv2.CAP_PROP_FRAME_WIDTH) # float 32 | height = vidcap.get(cv2.CAP_PROP_FRAME_HEIGHT) # float 33 | if (width == 0) or (height == 0): 34 | print(v_path, 'not successfully loaded, drop ..'); return 35 | new_dim = resize_dim(width, height, dim) 36 | 37 | success, image = vidcap.read() 38 | count = 1 39 | while success: 40 | image = cv2.resize(image, new_dim, interpolation = cv2.INTER_LINEAR) 41 | cv2.imwrite(os.path.join(out_dir, 'image_%05d.jpg' % count), image, 42 | [cv2.IMWRITE_JPEG_QUALITY, 80])# quality from 0-100, 95 is default, high is good 43 | success, image = vidcap.read() 44 | count += 1 45 | 46 | # Correct the amount of frames 47 | if (count * 30) < nb_frames: 48 | nb_frames = int(nb_frames * 30 / 1000) 49 | 50 | if nb_frames > count: 51 | print('/'.join(out_dir.split('/')[-2::]), 'NOT extracted successfully: %df/%df' % (count, nb_frames)) 52 | 53 | vidcap.release() 54 | 55 | 56 | def resize_dim(w, h, target): 57 | '''resize (w, h), such that the smaller side is target, keep the aspect ratio''' 58 | if w >= h: 59 | return (int(target * w / h), int(target)) 60 | else: 61 | return (int(target), int(target * h / w)) 62 | 63 | 64 | def main_UCF101(v_root, f_root): 65 | print('extracting UCF101 ... 
') 66 | print('extracting videos from %s' % v_root) 67 | print('frame save to %s' % f_root) 68 | 69 | if not os.path.exists(f_root): os.makedirs(f_root) 70 | v_act_root = glob.glob(os.path.join(v_root, '*/')) 71 | print(len(v_act_root)) 72 | for i, j in tqdm(enumerate(v_act_root), total=len(v_act_root)): 73 | v_paths = glob.glob(os.path.join(j, '*.avi')) 74 | v_paths = sorted(v_paths) 75 | Parallel(n_jobs=32)(delayed(extract_video_opencv)(p, f_root) for p in tqdm(v_paths, total=len(v_paths))) 76 | 77 | 78 | def main_HMDB51(v_root, f_root): 79 | print('extracting HMDB51 ... ') 80 | print('extracting videos from %s' % v_root) 81 | print('frame save to %s' % f_root) 82 | 83 | if not os.path.exists(f_root): os.makedirs(f_root) 84 | v_act_root = glob.glob(os.path.join(v_root, '*/')) 85 | for i, j in tqdm(enumerate(v_act_root), total=len(v_act_root)): 86 | v_paths = glob.glob(os.path.join(j, '*.avi')) 87 | v_paths = sorted(v_paths) 88 | Parallel(n_jobs=32)(delayed(extract_video_opencv)(p, f_root) for p in tqdm(v_paths, total=len(v_paths))) 89 | 90 | 91 | def main_JHMDB(v_root, f_root): 92 | print('extracting JHMDB ... ') 93 | print('extracting videos from %s' % v_root) 94 | print('frame save to %s' % f_root) 95 | 96 | if not os.path.exists(f_root): os.makedirs(f_root) 97 | v_act_root = glob.glob(os.path.join(v_root, '*/')) 98 | for i, j in tqdm(enumerate(v_act_root), total=len(v_act_root)): 99 | v_paths = glob.glob(os.path.join(j, '*.avi')) 100 | v_paths = sorted(v_paths) 101 | Parallel(n_jobs=32)(delayed(extract_video_opencv)(p, f_root) for p in tqdm(v_paths, total=len(v_paths))) 102 | 103 | 104 | def main_kinetics400(v_root, f_root, dim=128): 105 | print('extracting Kinetics400 ... ') 106 | for basename in ['train', 'val']: 107 | v_root_real = v_root + '/' + basename 108 | if not os.path.exists(v_root_real): 109 | print('Wrong v_root'); sys.exit() 110 | 111 | f_root_real = f_root + '/' + basename 112 | print('Extract to: \nframe: %s' % f_root_real) 113 | if not os.path.exists(f_root_real): 114 | os.makedirs(f_root_real) 115 | v_act_root = glob.glob(os.path.join(v_root_real, '*/')) 116 | v_act_root = sorted(v_act_root) 117 | 118 | # if resume, remember to delete the last video folder 119 | for i, j in tqdm(enumerate(v_act_root), total=len(v_act_root)): 120 | v_paths = glob.glob(os.path.join(j, '*.mp4')) 121 | v_paths = sorted(v_paths) 122 | # for resume: 123 | v_class = j.split('/')[-2] 124 | out_dir = os.path.join(f_root_real, v_class) 125 | if os.path.exists(out_dir): print(out_dir, 'exists!'); continue 126 | print('extracting: %s' % v_class) 127 | # dim = 150 (crop to 128 later) or 256 (crop to 224 later) 128 | Parallel(n_jobs=32)(delayed(extract_video_opencv)(p, f_root_real, dim=dim) for p in tqdm(v_paths, total=len(v_paths))) 129 | 130 | 131 | def main_Panasonic(v_root, f_root, dim): 132 | print('extracting Panasonic ... 
') 133 | print('extracting videos from %s' % v_root) 134 | print('frame save to %s' % f_root) 135 | 136 | if not os.path.exists(f_root): os.makedirs(f_root) 137 | v_act_root = glob.glob(os.path.join(v_root, '*/')) 138 | print(len(v_act_root)) 139 | for i, j in tqdm(enumerate(v_act_root), total=len(v_act_root)): 140 | v_paths = glob.glob(os.path.join(j, '*.mkv')) 141 | v_paths = sorted(v_paths) 142 | Parallel(n_jobs=32)(delayed(extract_video_opencv)(p, f_root, dim) for p in tqdm(v_paths, total=len(v_paths))) 143 | 144 | 145 | if __name__ == '__main__': 146 | # v_root is the video source path, f_root is where to store frames 147 | # edit 'your_path' here: 148 | 149 | parser = argparse.ArgumentParser() 150 | parser.add_argument('--ucf101', default=False, type=str2bool) 151 | parser.add_argument('--jhmdb', default=False, type=str2bool) 152 | parser.add_argument('--hmdb51', default=False, type=str2bool) 153 | parser.add_argument('--kinetics', default=False, type=str2bool) 154 | parser.add_argument('--panasonic', default=False, type=str2bool) 155 | parser.add_argument('--dataset_path', default='/scr/data', type=str) 156 | parser.add_argument('--dim', default=128, type=int) 157 | args = parser.parse_args() 158 | 159 | dataset_path = args.dataset_path 160 | 161 | if args.ucf101: 162 | main_UCF101(v_root=dataset_path + '/ucf101/videos/', f_root=dataset_path + '/ucf101/frame/') 163 | 164 | if args.jhmdb: 165 | main_JHMDB(v_root=dataset_path + '/jhmdb/videos/', f_root=dataset_path + '/jhmdb/frame/') 166 | 167 | if args.hmdb51: 168 | main_HMDB51(v_root=dataset_path+'/hmdb/videos', f_root=dataset_path+'/hmdb/frame') 169 | 170 | if args.panasonic: 171 | main_Panasonic(v_root=dataset_path+'/action_split_data/V1.0', f_root=dataset_path+'/frame', dim=192) 172 | 173 | if args.kinetics: 174 | if args.dim == 256: 175 | main_kinetics400( 176 | v_root=dataset_path + '/kinetics/video', f_root=dataset_path + '/kinetics/frame256', dim=args.dim 177 | ) 178 | else: 179 | assert args.dim == 128, "Invalid dim: {}".format(args.dim) 180 | main_kinetics400(v_root=dataset_path+'/kinetics/video', f_root=dataset_path+'/kinetics/frame', dim=128) 181 | 182 | # main_kinetics400(v_root='your_path/Kinetics400_256/videos', 183 | # f_root='your_path/Kinetics400_256/frame', dim=256) 184 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import pickle 4 | import os 5 | from datetime import datetime 6 | import glob 7 | import re 8 | import matplotlib.pyplot as plt 9 | plt.switch_backend('agg') 10 | from collections import deque 11 | from tqdm import tqdm 12 | from torchvision import transforms 13 | 14 | 15 | def save_checkpoint(state, mode, is_best=0, gap=1, filename='models/checkpoint.pth.tar', keep_all=False): 16 | torch.save(state, filename) 17 | last_epoch_path = os.path.join( 18 | os.path.dirname(filename), 'mode_' + mode + '_epoch%s.pth.tar' % str(state['epoch']-gap)) 19 | alternate_last_epoch_path = os.path.join(os.path.dirname(filename), 'epoch%s.pth.tar' % str(state['epoch']-gap)) 20 | if not keep_all: 21 | try: 22 | os.remove(last_epoch_path) 23 | except: 24 | try: 25 | os.remove(alternate_last_epoch_path) 26 | except: 27 | print("Couldn't remove last_epoch_path: ", last_epoch_path, alternate_last_epoch_path) 28 | pass 29 | if is_best: 30 | past_best = glob.glob(os.path.join(os.path.dirname(filename), 'mode_' + mode + '_model_best_*.pth.tar')) 31 | for 
i in past_best: 32 | try: os.remove(i) 33 | except: pass 34 | torch.save( 35 | state, 36 | os.path.join( 37 | os.path.dirname(filename), 38 | 'mode_' + mode + '_model_best_epoch%s.pth.tar' % str(state['epoch']) 39 | ) 40 | ) 41 | 42 | 43 | def write_log(content, epoch, filename): 44 | if not os.path.exists(filename): 45 | log_file = open(filename, 'w') 46 | else: 47 | log_file = open(filename, 'a') 48 | log_file.write('## Epoch %d:\n' % epoch) 49 | log_file.write('time: %s\n' % str(datetime.now())) 50 | log_file.write(content + '\n\n') 51 | log_file.close() 52 | 53 | 54 | def calc_topk_accuracy(output, target, topk=(1,)): 55 | ''' 56 | Modified from: https://gist.github.com/agermanidis/275b23ad7a10ee89adccf021536bb97e 57 | Given predicted and ground truth labels, 58 | calculate top-k accuracies. 59 | ''' 60 | maxk = max(topk) 61 | batch_size = target.size(0) 62 | 63 | _, pred = output.topk(maxk, 1, True, True) 64 | pred = pred.t() 65 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 66 | 67 | res = [] 68 | for k in topk: 69 | correct_k = correct[:k].contiguous().view(-1).float().sum(0) 70 | res.append(correct_k.mul_(1 / batch_size)) 71 | return res 72 | 73 | 74 | def calc_accuracy(output, target): 75 | '''output: (B, N); target: (B)''' 76 | target = target.squeeze() 77 | _, pred = torch.max(output, 1) 78 | return torch.mean((pred == target).float()) 79 | 80 | def calc_accuracy_binary(output, target): 81 | '''output, target: (B, N), output is logits, before sigmoid ''' 82 | pred = output > 0 83 | acc = torch.mean((pred == target.byte()).float()) 84 | del pred, output, target 85 | return acc 86 | 87 | def get_topk_single_from_multilabel_pred(pred, target, ks): 88 | assert ks == sorted(ks) 89 | pred = pred.clone() 90 | out = [] 91 | valid_indices = torch.zeros_like(pred, dtype=bool) 92 | for k in range(1, ks[-1] + 1): 93 | max_indices = torch.argmax(pred, dim=1) 94 | valid_indices[:, max_indices] = pred[:, max_indices] > 0 95 | if k in ks: 96 | topk_single = torch.sum(valid_indices & target) 97 | out.append(topk_single) 98 | pred[:, max_indices] = 0 99 | 100 | return out 101 | 102 | def calc_per_class_multilabel_counts(output, target, num_ths=101): 103 | tp, tn, fp, fn = [torch.zeros((num_ths, output.shape[1])) for _ in range(4)] 104 | target = target > 0 105 | single_mask = target.sum(dim=-1) == 1 106 | for th in range(num_ths): 107 | pred = output >= th / (num_ths - 1) 108 | tp[th] = torch.sum(pred & target, axis=0) 109 | fp[th] = torch.sum(pred & ~target, axis=0) 110 | fn[th] = torch.sum(~pred & target, axis=0) 111 | tn[th] = torch.sum(~pred & ~target, axis=0) 112 | # single label metrics 113 | top1_single, top3_single = get_topk_single_from_multilabel_pred(output[single_mask], target[single_mask], [1, 3]) 114 | all_single = torch.sum(single_mask) 115 | return tp, tn, fp, fn, top1_single, top3_single, all_single 116 | 117 | def calc_hamming_loss(tp, tn, fp, fn, num_ths=101): 118 | th = (num_ths - 1) // 2 119 | tp, fp, tn, fn = tp[th].sum(), fp[th].sum(), tn[th].sum(), fn[th].sum() 120 | return ((fp + fn) / (fp + fn + tp + tn)).cpu().data.numpy() 121 | 122 | def calc_mAP(tp, fp, num_ths=101): 123 | # clamp to > 0.5 124 | tp, fp = tp[(num_ths - 1) // 2 : -1], fp[(num_ths - 1) // 2 : -1] 125 | m = (tp + fp) > 0 126 | precision = torch.zeros_like(tp) 127 | precision[m] = tp[m] / (tp[m] + fp[m]) 128 | return precision[m].mean().cpu().data.numpy() 129 | 130 | def denorm(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]): 131 | assert len(mean) == len(std) == 3 132 | inv_mean = 
[-mean[i]/std[i] for i in range(3)] 133 | inv_std = [1/i for i in std] 134 | return transforms.Normalize(mean=inv_mean, std=inv_std) 135 | 136 | 137 | class AverageMeter(object): 138 | """Computes and stores the average and current value""" 139 | def __init__(self): 140 | self.reset() 141 | 142 | def reset(self): 143 | self.val = 0 144 | self.avg = 0 145 | self.sum = 0 146 | self.count = 0 147 | self.local_history = deque([]) 148 | self.local_avg = 0 149 | self.history = [] 150 | self.dict = {} # save all data values here 151 | self.save_dict = {} # save mean and std here, for summary table 152 | 153 | def update(self, val, n=1, history=0, step=100): 154 | self.val = val 155 | self.sum += val * n 156 | self.count += n 157 | self.avg = self.sum / self.count 158 | if history: 159 | self.history.append(val) 160 | if step > 0: 161 | self.local_history.append(val) 162 | if len(self.local_history) > step: 163 | self.local_history.popleft() 164 | self.local_avg = np.average(self.local_history) 165 | 166 | def dict_update(self, val, key): 167 | if key in self.dict.keys(): 168 | self.dict[key].append(val) 169 | else: 170 | self.dict[key] = [val] 171 | 172 | def __len__(self): 173 | return self.count 174 | 175 | 176 | class AccuracyTable(object): 177 | '''compute accuracy for each class''' 178 | def __init__(self, names): 179 | self.names = names 180 | self.dict = {} 181 | 182 | def update(self, pred, tar): 183 | pred = pred.flatten() 184 | tar = tar.flatten() 185 | for i, j in zip(pred, tar): 186 | i = int(i) 187 | j = int(j) 188 | if j not in self.dict.keys(): 189 | self.dict[j] = {'count':0,'correct':0} 190 | self.dict[j]['count'] += 1 191 | if i == j: 192 | self.dict[j]['correct'] += 1 193 | 194 | def print_table(self): 195 | for key in sorted(self.dict.keys()): 196 | acc = self.dict[key]['correct'] / self.dict[key]['count'] 197 | print('%25s: %5d, acc: %3d/%3d = %0.6f' \ 198 | % (self.names[key], key, self.dict[key]['correct'], self.dict[key]['count'], acc)) 199 | 200 | def print_dict(self): 201 | acc_dict = {} 202 | for key in sorted(self.dict.keys()): 203 | acc_dict[self.names[key].lower()] = self.dict[key]['correct'] / self.dict[key]['count'] 204 | print(acc_dict) 205 | 206 | class ConfusionMeter(object): 207 | '''compute and show confusion matrix''' 208 | def __init__(self, num_class): 209 | self.num_class = num_class 210 | self.mat = np.zeros((num_class, num_class)) 211 | self.precision = [] 212 | self.recall = [] 213 | 214 | def update(self, pred, tar): 215 | pred, tar = pred.cpu().numpy(), tar.cpu().numpy() 216 | pred = np.squeeze(pred) 217 | tar = np.squeeze(tar) 218 | for p,t in zip(pred.flat, tar.flat): 219 | self.mat[p][t] += 1 220 | 221 | def print_mat(self): 222 | print('Confusion Matrix: (target in columns)') 223 | print(self.mat) 224 | 225 | def plot_mat(self, path, dictionary=None, annotate=False): 226 | plt.figure(dpi=600) 227 | plt.imshow(self.mat, 228 | cmap=plt.cm.jet, 229 | interpolation=None, 230 | extent=(0.5, np.shape(self.mat)[0]+0.5, np.shape(self.mat)[1]+0.5, 0.5)) 231 | width, height = self.mat.shape 232 | if annotate: 233 | for x in range(width): 234 | for y in range(height): 235 | plt.annotate(str(int(self.mat[x][y])), xy=(y+1, x+1), 236 | horizontalalignment='center', 237 | verticalalignment='center', 238 | fontsize=8) 239 | 240 | if dictionary is not None: 241 | plt.xticks([i+1 for i in range(width)], 242 | [dictionary[i] for i in range(width)], 243 | rotation='vertical') 244 | plt.yticks([i+1 for i in range(height)], 245 | [dictionary[i] for i in range(height)]) 
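        # update() fills self.mat[pred][target], so predictions run down the rows (y axis)
        # and ground-truth labels across the columns (x axis); the axis labels below follow that layout.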
246 | plt.xlabel('Ground Truth') 247 | plt.ylabel('Prediction') 248 | plt.colorbar() 249 | plt.tight_layout() 250 | plt.savefig(path, format='svg') 251 | plt.clf() 252 | 253 | # for i in range(width): 254 | # if np.sum(self.mat[i,:]) != 0: 255 | # self.precision.append(self.mat[i,i] / np.sum(self.mat[i,:])) 256 | # if np.sum(self.mat[:,i]) != 0: 257 | # self.recall.append(self.mat[i,i] / np.sum(self.mat[:,i])) 258 | # print('Average Precision: %0.4f' % np.mean(self.precision)) 259 | # print('Average Recall: %0.4f' % np.mean(self.recall)) 260 | 261 | 262 | 263 | 264 | -------------------------------------------------------------------------------- /train/sim_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import sys 3 | import time 4 | import math 5 | 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import model_utils as mu 9 | 10 | sys.path.append('../utils') 11 | from utils import calc_topk_accuracy 12 | from random import random 13 | 14 | 15 | eps = 1e-5 16 | INF = 1000.0 17 | 18 | 19 | class MemoryBank(nn.Module): 20 | 21 | def __init__(self, size): 22 | super(MemoryBank, self).__init__() 23 | self.maxlen = size 24 | self.dim = None 25 | self.bank = None 26 | 27 | def bootstrap(self, X): 28 | self.dim = X.shape[1:] 29 | gcd = math.gcd(X.shape[0], self.maxlen) 30 | self.bank = torch.cat([X[:gcd]] * (self.maxlen // gcd), dim=0).detach().to(X.device) 31 | assert self.bank.shape[0] == self.maxlen, "Invalid shape: {}".format(self.bank.shape) 32 | self.bank.requires_grad = False 33 | 34 | def update(self, X): 35 | # Initialize the memory bank 36 | N = X.shape[0] 37 | if self.dim is None: 38 | self.bootstrap(X) 39 | 40 | assert X.shape[1:] == self.dim, "Invalid size: {} {}".format(X.shape, self.dim) 41 | self.bank = torch.cat([self.bank[N:], X.detach().to(X.device)], dim=0).detach() 42 | 43 | def fetchBank(self): 44 | if self.bank is not None: 45 | assert self.bank.requires_grad is False, "Bank grad not false: {}".format(self.bank.requires_grad) 46 | return self.bank 47 | 48 | def fetchAppended(self, X): 49 | if self.bank is None: 50 | self.bootstrap(X) 51 | return self.fetchAppended(X) 52 | assert X.shape[1:] == self.bank.shape[1:], "Invalid shapes: {}, {}".format(X.shape, self.bank.shape) 53 | assert self.bank.requires_grad is False, "Bank grad not false: {}".format(self.bank.requires_grad) 54 | return torch.cat([X, self.bank], dim=0) 55 | 56 | 57 | class WeightNormalizedMarginLoss(nn.Module): 58 | def __init__(self, target): 59 | super(WeightNormalizedMarginLoss, self).__init__() 60 | 61 | self.target = target.float().clone() 62 | 63 | # Parameters for the weight loss 64 | self.f = 0.5 65 | self.one_ratio = self.target[self.target == 1].numel() / (self.target.numel() * 1.0) 66 | 67 | # Setup weight mask 68 | self.weight_mask = target.float().clone() 69 | self.weight_mask[self.weight_mask >= 1.] = self.f * (1 - self.one_ratio) 70 | self.weight_mask[self.weight_mask <= 0.] = (1. - self.f) * self.one_ratio 71 | 72 | # Normalize the weight accordingly 73 | self.weight_mask = self.weight_mask.to(self.target.device) / (self.one_ratio * (1. 
- self.one_ratio)) 74 | 75 | self.hinge_target = self.target.clone() 76 | self.hinge_target[self.hinge_target >= 1] = 1 77 | self.hinge_target[self.hinge_target <= 0] = -1 78 | 79 | self.dummy_target = self.target.clone() 80 | 81 | self.criteria = nn.HingeEmbeddingLoss(margin=((1 - self.f) / (1 - self.one_ratio))) 82 | 83 | def forward(self, value): 84 | distance = 1.0 - value 85 | return self.criteria(self.weight_mask * distance, self.hinge_target) 86 | 87 | 88 | class SimHandler(nn.Module): 89 | 90 | def __init__(self): 91 | super(SimHandler, self).__init__() 92 | 93 | def verify_shape_for_dot_product(self, mode0, mode1): 94 | 95 | B, N, D = mode0.shape 96 | assert (B, N, D) == tuple(mode1.shape), \ 97 | "Mismatch between mode0 and mode1 features: {}, {}".format(mode0.shape, mode1.shape) 98 | 99 | # dot product in mode0-mode1 pair, get a 4d tensor. First 2 dims are from mode0, the last from mode1 100 | nmode0 = mode0.view(B * N, D) 101 | nmode1 = mode1.view(B * N, D) 102 | 103 | return nmode0, nmode1, B, N, D 104 | 105 | def get_feature_cross_pair_score(self, mode0, mode1): 106 | """ 107 | Gives us all pair wise scores 108 | (mode0/mode1)features: [B, N, D], [B2, N2, D] 109 | Returns 4D pair score tensor 110 | """ 111 | 112 | B1, N1, D1 = mode0.shape 113 | B2, N2, D2 = mode1.shape 114 | 115 | assert D1 == D2, "Different dimensions: {} {}".format(mode0.shape, mode1.shape) 116 | nmode0 = mode0.view(B1 * N1, D1) 117 | nmode1 = mode1.view(B2 * N2, D2) 118 | 119 | score = torch.matmul( 120 | nmode0.reshape(B1 * N1, D1), 121 | nmode1.reshape(B2 * N2, D1).transpose(0, 1) 122 | ).view(B1, N1, B2, N2) 123 | 124 | return score 125 | 126 | def get_feature_pair_score(self, mode0, mode1): 127 | """ 128 | Returns aligned pair scores 129 | (pred/gt)features: [B, N, D] 130 | Returns 2D pair score tensor 131 | """ 132 | 133 | nmode0, nmode1, B, N, D = self.verify_shape_for_dot_product(mode0, mode1) 134 | score = torch.bmm( 135 | nmode0.view(B * N, 1, D), 136 | nmode1.view(B * N, D, 1) 137 | ).view(B, N) 138 | 139 | return score 140 | 141 | def l2NormedVec(self, x, dim=-1): 142 | assert x.shape[dim] >= 256, "Invalid dimension for reduction: {}".format(x.shape) 143 | return x / (torch.norm(x, p=2, dim=dim, keepdim=True) + eps) 144 | 145 | 146 | class AlignSimHandler(SimHandler): 147 | 148 | def __init__(self, instance_label): 149 | super(AlignSimHandler, self).__init__() 150 | 151 | self.target = instance_label.clone() 152 | self.criterion_base = nn.CrossEntropyLoss() 153 | self.criterion = lambda x, y: self.criterion_base(x, y.float().argmax(dim=1)) 154 | 155 | self.accuracyKList = [1, 3] 156 | 157 | def forward(self, mode0, mode1): 158 | score = self.get_feature_cross_pair_score(mode0, mode1) 159 | 160 | B, NS, _, _ = score.shape 161 | 162 | score_flattened = score.view(B * NS, B * NS) 163 | 164 | assert self.target.shape == (B, NS, B, NS), "Invalid shape: {}, {}".format(self.target.shape, score.shape) 165 | 166 | instance_target_flattened = self.target.view(B * NS, B * NS) 167 | instance_target_lbl = instance_target_flattened.float().argmax(dim=1) 168 | 169 | # Compute and log performance metrics 170 | stats = {} 171 | topKs = calc_topk_accuracy(score_flattened, instance_target_lbl, self.accuracyKList) 172 | for i in range(len(self.accuracyKList)): 173 | stats["acc" + str(self.accuracyKList[i])] = topKs[i] 174 | 175 | return self.criterion(score_flattened, instance_target_flattened), stats 176 | 177 | 178 | class CosSimHandler(SimHandler): 179 | 180 | def __init__(self): 181 | super(CosSimHandler, 
self).__init__() 182 | 183 | self.target = None 184 | self.criterion = nn.MSELoss() 185 | 186 | def score(self, mode0, mode1): 187 | cosSim = self.get_feature_pair_score(self.l2NormedVec(mode0), self.l2NormedVec(mode1)) 188 | 189 | assert cosSim.min() >= -1. - eps, "Invalid value for cos sim: {}".format(cosSim) 190 | assert cosSim.max() <= 1. + eps, "Invalid value for cos sim: {}".format(cosSim) 191 | 192 | return cosSim 193 | 194 | def forward(self, mode0, mode1): 195 | score = self.score(mode0, mode1) 196 | 197 | if self.target is None: 198 | self.target = torch.ones_like(score) 199 | 200 | stats = {"m": score.mean()} 201 | 202 | return self.criterion(score, self.target), stats 203 | 204 | 205 | class CorrSimHandler(SimHandler): 206 | 207 | def __init__(self): 208 | super(CorrSimHandler, self).__init__() 209 | 210 | self.shapeMode0, self.shapeMode1 = None, None 211 | self.runningMeanMode0 = None 212 | self.runningMeanMode1 = None 213 | 214 | self.retention = 0.7 215 | self.target = None 216 | 217 | self.criterion = nn.L1Loss() 218 | 219 | self.noInitYet = True 220 | 221 | @staticmethod 222 | def get_ovr_mean(mode): 223 | return mode.mean(dim=0, keepdim=True).mean(dim=1, keepdim=True).detach().cpu() 224 | 225 | def init_vars(self, mode0, mode1): 226 | 227 | self.shapeMode0 = mode0.shape 228 | self.shapeMode1 = mode1.shape 229 | 230 | assert len(self.shapeMode0) == 3 231 | 232 | self.runningMeanMode0 = self.get_ovr_mean(mode0) 233 | self.runningMeanMode1 = self.get_ovr_mean(mode1) 234 | 235 | self.noInitYet = False 236 | 237 | def update_means(self, mean0, mean1): 238 | 239 | self.runningMeanMode0 = (self.runningMeanMode0 * self.retention) + (mean0 * (1. - self.retention)) 240 | self.runningMeanMode1 = (self.runningMeanMode1 * self.retention) + (mean1 * (1. - self.retention)) 241 | 242 | def get_means_on_device(self, device): 243 | return self.runningMeanMode0.to(device), self.runningMeanMode1.to(device) 244 | 245 | def score(self, mode0, mode1): 246 | 247 | if self.noInitYet: 248 | self.init_vars(mode0, mode1) 249 | 250 | meanMode0 = self.get_ovr_mean(mode0) 251 | meanMode1 = self.get_ovr_mean(mode1) 252 | self.update_means(meanMode0.detach().cpu(), meanMode1.detach().cpu()) 253 | runningMean0, runningMean1 = self.get_means_on_device(mode0.device) 254 | 255 | corr = self.get_feature_pair_score( 256 | self.l2NormedVec(mode0 - runningMean0), 257 | self.l2NormedVec(mode1 - runningMean1) 258 | ) 259 | 260 | assert corr.min() >= -1. - eps, "Invalid value for correlation: {}".format(corr) 261 | assert corr.max() <= 1. 
+ eps, "Invalid value for correlation: {}".format(corr) 262 | 263 | return corr 264 | 265 | def forward(self, mode0, mode1): 266 | score = self.score(mode0, mode1) 267 | 268 | if self.target is None: 269 | self.target = torch.ones_like(score) 270 | 271 | stats = {"m": score.mean()} 272 | 273 | return self.criterion(score, self.target), stats 274 | 275 | 276 | class DenseCorrSimHandler(CorrSimHandler): 277 | 278 | def __init__(self, instance_label): 279 | super(DenseCorrSimHandler, self).__init__() 280 | 281 | self.target = instance_label.float().clone() 282 | # self.criterion = WeightNormalizedMSELoss(self.target) 283 | self.criterion = WeightNormalizedMarginLoss(self.target) 284 | 285 | def get_feature_pair_score(self, mode0, mode1): 286 | return self.get_feature_cross_pair_score(mode0, mode1) 287 | 288 | def forward(self, mode0, mode1): 289 | score = self.score(mode0, mode1) 290 | 291 | B, N, B2, N2 = score.shape 292 | assert (B, N) == (B2, N2), "Invalid shape: {}".format(score.shape) 293 | assert score.shape == self.target.shape, "Invalid shape: {}, {}".format(score.shape, self.target.shape) 294 | 295 | stats = { 296 | "m": (self.criterion.weight_mask * score).mean(), 297 | "m-": score[self.target <= 0].mean(), 298 | "m+": score[self.target > 0].mean(), 299 | } 300 | 301 | return self.criterion(score), stats 302 | 303 | 304 | class DenseCosSimHandler(CosSimHandler): 305 | 306 | def __init__(self, instance_label): 307 | super(DenseCosSimHandler, self).__init__() 308 | 309 | self.target = instance_label.float() 310 | # self.criterion = WeightNormalizedMSELoss(self.target) 311 | self.criterion = WeightNormalizedMarginLoss(self.target) 312 | 313 | def get_feature_pair_score(self, mode0, mode1): 314 | return self.get_feature_cross_pair_score(mode0, mode1) 315 | 316 | def forward(self, mode0, mode1): 317 | score = self.score(mode0, mode1) 318 | assert score.shape == self.target.shape, "Invalid shape: {}, {}".format(score.shape, self.target.shape) 319 | 320 | stats = { 321 | "m": (self.criterion.weight_mask * score).mean(), 322 | "m-": score[self.target <= 0].mean(), 323 | "m+": score[self.target > 0].mean(), 324 | } 325 | 326 | return self.criterion(score), stats 327 | 328 | 329 | class InterModeDotHandler(nn.Module): 330 | 331 | def __init__(self, last_size=1): 332 | super(InterModeDotHandler, self).__init__() 333 | 334 | self.cosSimHandler = CosSimHandler() 335 | self.last_size = last_size 336 | 337 | def contextFetHelper(self, context): 338 | context = context[:, -1, :].unsqueeze(1) 339 | context = F.avg_pool3d(context, (1, self.last_size, self.last_size), stride=1).squeeze(-1).squeeze(-1) 340 | return context 341 | 342 | def fetHelper(self, z): 343 | B, N, D, S, S = z.shape 344 | z = z.permute(0, 1, 3, 4, 2).contiguous().view(B, N * S * S, D) 345 | return z 346 | 347 | def dotProdHelper(self, z, zt): 348 | return self.cosSimHandler.get_feature_cross_pair_score( 349 | self.cosSimHandler.l2NormedVec(z), self.cosSimHandler.l2NormedVec(zt) 350 | ) 351 | 352 | def get_cluster_dots(self, feature): 353 | fet = self.fetHelper(feature) 354 | return self.dotProdHelper(fet, fet) 355 | 356 | def forward(self, context=None, comp_pred=None, comp_fet=None): 357 | cdot = self.fetHelper(comp_fet) 358 | return self.dotProdHelper(cdot, cdot), cdot 359 | -------------------------------------------------------------------------------- /backbone/resnet_2d3d.py: -------------------------------------------------------------------------------- 1 | ## modified from https://github.com/kenshohara/3D-ResNets-PyTorch 
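A minimal usage sketch for the backbones in this file: the constructors near the bottom mix 2D residual blocks (1x3x3 convolutions applied per frame) in the first two stages with full 3D blocks (3x3x3 spatio-temporal convolutions) in the last two, and `neq_load_customized` (defined at the end of this file) restores only the weights whose names match the new model. The checkpoint path and the `'state_dict'` key below are assumptions for illustration, not part of this repository.

```python
import torch
from resnet_2d3d import resnet18_2d3d_full, neq_load_customized  # assumes backbone/ is on sys.path

# Build the mixed 2D-3D ResNet-18 backbone.
model = resnet18_2d3d_full(track_running_stats=True, in_channels=3)

# Partially restore a pretrained checkpoint; unmatched keys are printed and skipped.
checkpoint = torch.load('pretrained.pth.tar', map_location='cpu')  # placeholder path
model = neq_load_customized(model, checkpoint['state_dict'])       # 'state_dict' key is an assumption
```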
2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | import math 7 | 8 | __all__ = [ 9 | 'ResNet2d3d_full', 'resnet18_2d3d_full', 'resnet34_2d3d_full', 'resnet50_2d3d_full', 'resnet101_2d3d_full', 10 | 'resnet152_2d3d_full', 'resnet200_2d3d_full', 11 | ] 12 | 13 | def conv3x3x3(in_planes, out_planes, stride=1, bias=False): 14 | # 3x3x3 convolution with padding 15 | return nn.Conv3d( 16 | in_planes, 17 | out_planes, 18 | kernel_size=3, 19 | stride=stride, 20 | padding=1, 21 | bias=bias) 22 | 23 | def conv1x3x3(in_planes, out_planes, stride=1, bias=False): 24 | # 1x3x3 convolution with padding 25 | return nn.Conv3d( 26 | in_planes, 27 | out_planes, 28 | kernel_size=(1,3,3), 29 | stride=(1,stride,stride), 30 | padding=(0,1,1), 31 | bias=bias) 32 | 33 | 34 | def downsample_basic_block(x, planes, stride): 35 | out = F.avg_pool3d(x, kernel_size=1, stride=stride) 36 | zero_pads = torch.Tensor( 37 | out.size(0), planes - out.size(1), out.size(2), out.size(3), 38 | out.size(4)).zero_() 39 | if isinstance(out.data, torch.cuda.FloatTensor): 40 | zero_pads = zero_pads.cuda() 41 | 42 | out = Variable(torch.cat([out.data, zero_pads], dim=1)) 43 | 44 | return out 45 | 46 | 47 | class BasicBlock3d(nn.Module): 48 | expansion = 1 49 | 50 | def __init__(self, inplanes, planes, stride=1, downsample=None, track_running_stats=True, use_final_relu=True): 51 | super(BasicBlock3d, self).__init__() 52 | bias = False 53 | self.use_final_relu = use_final_relu 54 | self.conv1 = conv3x3x3(inplanes, planes, stride, bias=bias) 55 | self.bn1 = nn.BatchNorm3d(planes, track_running_stats=track_running_stats) 56 | 57 | self.relu = nn.ReLU(inplace=True) 58 | self.conv2 = conv3x3x3(planes, planes, bias=bias) 59 | self.bn2 = nn.BatchNorm3d(planes, track_running_stats=track_running_stats) 60 | 61 | self.downsample = downsample 62 | self.stride = stride 63 | 64 | def forward(self, x): 65 | residual = x 66 | 67 | out = self.conv1(x) 68 | out = self.bn1(out) 69 | out = self.relu(out) 70 | 71 | out = self.conv2(out) 72 | out = self.bn2(out) 73 | 74 | if self.downsample is not None: 75 | residual = self.downsample(x) 76 | 77 | out += residual 78 | if self.use_final_relu: out = self.relu(out) 79 | 80 | return out 81 | 82 | 83 | class BasicBlock2d(nn.Module): 84 | expansion = 1 85 | 86 | def __init__(self, inplanes, planes, stride=1, downsample=None, track_running_stats=True, use_final_relu=True): 87 | super(BasicBlock2d, self).__init__() 88 | bias = False 89 | self.use_final_relu = use_final_relu 90 | self.conv1 = conv1x3x3(inplanes, planes, stride, bias=bias) 91 | self.bn1 = nn.BatchNorm3d(planes, track_running_stats=track_running_stats) 92 | 93 | self.relu = nn.ReLU(inplace=True) 94 | self.conv2 = conv1x3x3(planes, planes, bias=bias) 95 | self.bn2 = nn.BatchNorm3d(planes, track_running_stats=track_running_stats) 96 | 97 | self.downsample = downsample 98 | self.stride = stride 99 | 100 | def forward(self, x): 101 | residual = x 102 | 103 | out = self.conv1(x) 104 | out = self.bn1(out) 105 | out = self.relu(out) 106 | 107 | out = self.conv2(out) 108 | out = self.bn2(out) 109 | 110 | if self.downsample is not None: 111 | residual = self.downsample(x) 112 | 113 | out += residual 114 | if self.use_final_relu: out = self.relu(out) 115 | 116 | return out 117 | 118 | 119 | class Bottleneck3d(nn.Module): 120 | expansion = 4 121 | 122 | def __init__(self, inplanes, planes, stride=1, downsample=None, track_running_stats=True, use_final_relu=True): 123 | super(Bottleneck3d, 
self).__init__() 124 | bias = False 125 | self.use_final_relu = use_final_relu 126 | self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=1, bias=bias) 127 | self.bn1 = nn.BatchNorm3d(planes, track_running_stats=track_running_stats) 128 | 129 | self.conv2 = nn.Conv3d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=bias) 130 | self.bn2 = nn.BatchNorm3d(planes, track_running_stats=track_running_stats) 131 | 132 | self.conv3 = nn.Conv3d(planes, planes * 4, kernel_size=1, bias=bias) 133 | self.bn3 = nn.BatchNorm3d(planes * 4, track_running_stats=track_running_stats) 134 | 135 | self.relu = nn.ReLU(inplace=True) 136 | self.downsample = downsample 137 | self.stride = stride 138 | 139 | def forward(self, x): 140 | residual = x 141 | 142 | out = self.conv1(x) 143 | out = self.bn1(out) 144 | out = self.relu(out) 145 | 146 | out = self.conv2(out) 147 | out = self.bn2(out) 148 | out = self.relu(out) 149 | 150 | out = self.conv3(out) 151 | out = self.bn3(out) 152 | 153 | if self.downsample is not None: 154 | residual = self.downsample(x) 155 | 156 | out += residual 157 | if self.use_final_relu: out = self.relu(out) 158 | 159 | return out 160 | 161 | 162 | class Bottleneck2d(nn.Module): 163 | expansion = 4 164 | 165 | def __init__(self, inplanes, planes, stride=1, downsample=None, track_running_stats=True, use_final_relu=True): 166 | super(Bottleneck2d, self).__init__() 167 | bias = False 168 | self.use_final_relu = use_final_relu 169 | self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=1, bias=bias) 170 | self.bn1 = nn.BatchNorm3d(planes, track_running_stats=track_running_stats) 171 | 172 | self.conv2 = nn.Conv3d(planes, planes, kernel_size=(1,3,3), stride=(1,stride,stride), padding=(0,1,1), bias=bias) 173 | self.bn2 = nn.BatchNorm3d(planes, track_running_stats=track_running_stats) 174 | 175 | self.conv3 = nn.Conv3d(planes, planes * 4, kernel_size=1, bias=bias) 176 | self.bn3 = nn.BatchNorm3d(planes * 4, track_running_stats=track_running_stats) 177 | 178 | self.relu = nn.ReLU(inplace=True) 179 | self.downsample = downsample 180 | self.stride = stride 181 | self.batchnorm = True 182 | 183 | def forward(self, x): 184 | residual = x 185 | 186 | out = self.conv1(x) 187 | if self.batchnorm: out = self.bn1(out) 188 | out = self.relu(out) 189 | 190 | out = self.conv2(out) 191 | if self.batchnorm: out = self.bn2(out) 192 | out = self.relu(out) 193 | 194 | out = self.conv3(out) 195 | if self.batchnorm: out = self.bn3(out) 196 | 197 | if self.downsample is not None: 198 | residual = self.downsample(x) 199 | 200 | out += residual 201 | if self.use_final_relu: out = self.relu(out) 202 | 203 | return out 204 | 205 | 206 | class ResNet2d3d_full(nn.Module): 207 | def __init__(self, block, layers, track_running_stats=True, in_channels=3): 208 | super(ResNet2d3d_full, self).__init__() 209 | self.inplanes = 64 210 | self.track_running_stats = track_running_stats 211 | bias = False 212 | self.conv1 = nn.Conv3d(in_channels, 64, kernel_size=(1,7,7), stride=(1, 2, 2), padding=(0, 3, 3), bias=bias) 213 | self.bn1 = nn.BatchNorm3d(64, track_running_stats=track_running_stats) 214 | self.relu = nn.ReLU(inplace=True) 215 | self.maxpool = nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)) 216 | 217 | if not isinstance(block, list): 218 | block = [block] * 4 219 | 220 | self.layer1 = self._make_layer(block[0], 64, layers[0]) 221 | self.layer2 = self._make_layer(block[1], 128, layers[1], stride=2) 222 | self.layer3 = self._make_layer(block[2], 256, layers[2], stride=2) 223 | self.layer4 = 
self._make_layer(block[3], 256, layers[3], stride=2, is_final=True) 224 |         # modify layer4 from exp=512 to exp=256 225 |         for m in self.modules(): 226 |             if isinstance(m, nn.Conv3d): 227 |                 m.weight = nn.init.kaiming_normal_(m.weight, mode='fan_out') 228 |                 if m.bias is not None: m.bias.data.zero_() 229 |             elif isinstance(m, nn.BatchNorm3d): 230 |                 m.weight.data.fill_(1) 231 |                 m.bias.data.zero_() 232 | 233 |     def _make_layer(self, block, planes, blocks, stride=1, is_final=False): 234 |         downsample = None 235 |         if stride != 1 or self.inplanes != planes * block.expansion: 236 |             # customized_stride to deal with 2d or 3d residual blocks 237 |             if (block == Bottleneck2d) or (block == BasicBlock2d): 238 |                 customized_stride = (1, stride, stride) 239 |             else: 240 |                 customized_stride = stride 241 | 242 |             downsample = nn.Sequential( 243 |                 nn.Conv3d(self.inplanes, planes * block.expansion, kernel_size=1, stride=customized_stride, bias=False), 244 |                 nn.BatchNorm3d(planes * block.expansion, track_running_stats=self.track_running_stats) 245 |             ) 246 | 247 |         layers = [] 248 |         layers.append(block(self.inplanes, planes, stride, downsample, track_running_stats=self.track_running_stats)) 249 |         self.inplanes = planes * block.expansion 250 |         if is_final: # if is final block, no ReLU in the final output 251 |             for i in range(1, blocks-1): 252 |                 layers.append(block(self.inplanes, planes, track_running_stats=self.track_running_stats)) 253 |             layers.append(block(self.inplanes, planes, track_running_stats=self.track_running_stats, use_final_relu=False)) 254 |         else: 255 |             for i in range(1, blocks): 256 |                 layers.append(block(self.inplanes, planes, track_running_stats=self.track_running_stats)) 257 | 258 |         return nn.Sequential(*layers) 259 | 260 |     def forward(self, x): 261 |         x = self.conv1(x) 262 |         x = self.bn1(x) 263 |         x = self.relu(x) 264 |         x = self.maxpool(x) 265 | 266 |         x = self.layer1(x) 267 |         x = self.layer2(x) 268 |         x = self.layer3(x) 269 |         x = self.layer4(x) 270 | 271 |         return x 272 | 273 | 274 | ## full resnet 275 | def resnet18_2d3d_full(**kwargs): 276 |     '''Constructs a ResNet-18 model. ''' 277 |     model = ResNet2d3d_full([BasicBlock2d, BasicBlock2d, BasicBlock3d, BasicBlock3d], 278 |                             [2, 2, 2, 2], **kwargs) 279 |     return model 280 | 281 | def resnet34_2d3d_full(**kwargs): 282 |     '''Constructs a ResNet-34 model. ''' 283 |     model = ResNet2d3d_full([BasicBlock2d, BasicBlock2d, BasicBlock3d, BasicBlock3d], 284 |                             [3, 4, 6, 3], **kwargs) 285 |     return model 286 | 287 | def resnet50_2d3d_full(**kwargs): 288 |     '''Constructs a ResNet-50 model. ''' 289 |     model = ResNet2d3d_full([Bottleneck2d, Bottleneck2d, Bottleneck3d, Bottleneck3d], 290 |                             [3, 4, 6, 3], **kwargs) 291 |     return model 292 | 293 | def resnet101_2d3d_full(**kwargs): 294 |     '''Constructs a ResNet-101 model. ''' 295 |     model = ResNet2d3d_full([Bottleneck2d, Bottleneck2d, Bottleneck3d, Bottleneck3d], 296 |                             [3, 4, 23, 3], **kwargs) 297 |     return model 298 | 299 | def resnet152_2d3d_full(**kwargs): 300 |     '''Constructs a ResNet-152 model. ''' 301 |     model = ResNet2d3d_full([Bottleneck2d, Bottleneck2d, Bottleneck3d, Bottleneck3d], 302 |                             [3, 8, 36, 3], **kwargs) 303 |     return model 304 | 305 | def resnet200_2d3d_full(**kwargs): 306 |     '''Constructs a ResNet-200 model. 
''' 307 | model = ResNet2d3d_full([Bottleneck2d, Bottleneck2d, Bottleneck3d, Bottleneck3d], 308 | [3, 24, 36, 3], **kwargs) 309 | return model 310 | 311 | def neq_load_customized(model, pretrained_dict): 312 | ''' load pre-trained model in a not-equal way, 313 | when new model has been partially modified ''' 314 | model_dict = model.state_dict() 315 | tmp = {} 316 | print('\n=======Check Weights Loading======') 317 | print('Weights not used from pretrained file:') 318 | names = [] 319 | for k, v in pretrained_dict.items(): 320 | if k in model_dict: 321 | tmp[k] = v 322 | else: 323 | names.append(k) 324 | print(set([k.split('.')[-1] for k in names])) 325 | print('---------------------------') 326 | print('Weights not loaded into new model:') 327 | names = [] 328 | for k, v in model_dict.items(): 329 | if k not in pretrained_dict: 330 | names.append(k) 331 | print(set([k.split('.')[-1] for k in names])) 332 | print('===================================\n') 333 | # pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 334 | del pretrained_dict 335 | model_dict.update(tmp) 336 | del tmp 337 | model.load_state_dict(model_dict) 338 | return model 339 | 340 | 341 | if __name__ == '__main__': 342 | mymodel = resnet18_2d3d_full() 343 | mydata = torch.FloatTensor(4, 3, 16, 128, 128) 344 | nn.init.normal_(mydata) 345 | import ipdb; ipdb.set_trace() 346 | mymodel(mydata) 347 | -------------------------------------------------------------------------------- /process_data/src/write_csv.py: -------------------------------------------------------------------------------- 1 | import os 2 | import csv 3 | import glob 4 | import sys 5 | import json 6 | from collections import defaultdict 7 | import pandas as pd 8 | 9 | from joblib import delayed, Parallel 10 | from tqdm import tqdm 11 | from collections import defaultdict 12 | 13 | 14 | def str2bool(s): 15 | """Convert string to bool (in argparse context).""" 16 | if s.lower() not in ['true', 'false']: 17 | raise ValueError('Need bool; got %r' % s) 18 | return {'true': True, 'false': False}[s.lower()] 19 | 20 | 21 | def write_list(data_list, path, ): 22 | with open(path, 'w') as f: 23 | writer = csv.writer(f, delimiter=',') 24 | for row in data_list: 25 | if row: writer.writerow(row) 26 | print('split saved to %s' % path) 27 | 28 | 29 | def main_UCF101(f_root, splits_root, csv_root='../data/ucf101/'): 30 | '''generate training/testing split, count number of available frames, save in csv''' 31 | if not os.path.exists(csv_root): os.makedirs(csv_root) 32 | for which_split in [1,2,3]: 33 | train_set = [] 34 | test_set = [] 35 | train_split_file = os.path.join(splits_root, 'trainlist%02d.txt' % which_split) 36 | with open(train_split_file, 'r') as f: 37 | for line in f: 38 | vpath = os.path.join(f_root, line.split(' ')[0][0:-4]) + '/' 39 | train_set.append([vpath, len(glob.glob(os.path.join(vpath, '*.jpg')))]) 40 | 41 | test_split_file = os.path.join(splits_root, 'testlist%02d.txt' % which_split) 42 | with open(test_split_file, 'r') as f: 43 | for line in f: 44 | vpath = os.path.join(f_root, line.rstrip()[0:-4]) + '/' 45 | test_set.append([vpath, len(glob.glob(os.path.join(vpath, '*.jpg')))]) 46 | 47 | write_list(train_set, os.path.join(csv_root, 'train_split%02d.csv' % which_split)) 48 | write_list(test_set, os.path.join(csv_root, 'test_split%02d.csv' % which_split)) 49 | 50 | 51 | def main_HMDB51(f_root, splits_root, csv_root='../data/hmdb51/'): 52 | '''generate training/testing split, count number of available frames, save in csv''' 53 
| if not os.path.exists(csv_root): os.makedirs(csv_root) 54 | for which_split in [1,2,3]: 55 | train_set = [] 56 | test_set = [] 57 | split_files = sorted(glob.glob(os.path.join(splits_root, '*_test_split%d.txt' % which_split))) 58 | assert len(split_files) == 51 59 | for split_file in split_files: 60 | action_name = os.path.basename(split_file)[0:-16] 61 | with open(split_file, 'r') as f: 62 | for line in f: 63 | video_name = line.split(' ')[0] 64 | _type = line.split(' ')[1] 65 | vpath = os.path.join(f_root, action_name, video_name[0:-4]) + '/' 66 | if _type == '1': 67 | train_set.append([vpath, len(glob.glob(os.path.join(vpath, '*.jpg')))]) 68 | elif _type == '2': 69 | test_set.append([vpath, len(glob.glob(os.path.join(vpath, '*.jpg')))]) 70 | 71 | write_list(train_set, os.path.join(csv_root, 'train_split%02d.csv' % which_split)) 72 | write_list(test_set, os.path.join(csv_root, 'test_split%02d.csv' % which_split)) 73 | 74 | 75 | def main_JHMDB(f_root, splits_root, csv_root='../data/jhmdb/'): 76 | '''generate training/testing split, count number of available frames, save in csv''' 77 | if not os.path.exists(csv_root): os.makedirs(csv_root) 78 | for which_split in [1,2,3]: 79 | train_set = [] 80 | test_set = [] 81 | split_files = sorted(glob.glob(os.path.join(splits_root, '*_test_split%d.txt' % which_split))) 82 | assert len(split_files) == 21 83 | for split_file in split_files: 84 | action_name = os.path.basename(split_file)[0:-16] 85 | with open(split_file, 'r') as f: 86 | for line in f: 87 | video_name = line.split(' ')[0] 88 | _type = line.split(' ')[1].strip('\n') 89 | vpath = os.path.join(f_root, action_name, video_name[0:-4]) + '/' 90 | if _type == '1': 91 | train_set.append([vpath, len(glob.glob(os.path.join(vpath, '*.jpg')))]) 92 | elif _type == '2': 93 | test_set.append([vpath, len(glob.glob(os.path.join(vpath, '*.jpg')))]) 94 | 95 | write_list(train_set, os.path.join(csv_root, 'train_split%02d.csv' % which_split)) 96 | write_list(test_set, os.path.join(csv_root, 'test_split%02d.csv' % which_split)) 97 | 98 | 99 | ### For Kinetics ### 100 | def get_split(root, split_path, mode): 101 | print('processing %s split ...' 
% mode) 102 | print('checking %s' % root) 103 | split_list = [] 104 | split_content = pd.read_csv(split_path).iloc[:,0:4] 105 | split_list = Parallel(n_jobs=64)\ 106 | (delayed(check_exists)(row, root) \ 107 | for i, row in tqdm(split_content.iterrows(), total=len(split_content))) 108 | return split_list 109 | 110 | missedCnt = 0 111 | 112 | def check_exists(row, root): 113 | global missedCnt 114 | 115 | dirname = '_'.join([row['youtube_id'], '%06d' % row['time_start'], '%06d' % row['time_end']]) 116 | full_dirname = os.path.join(root, row['label'], dirname) 117 | # replace spaces with underscores 118 | full_dirname = full_dirname.replace(' ', '_') 119 | if os.path.exists(full_dirname): 120 | n_frames = len(glob.glob(os.path.join(full_dirname, '*.jpg'))) 121 | return [full_dirname, n_frames] 122 | else: 123 | missedCnt += 1 124 | return None 125 | 126 | def main_Kinetics400(mode, k400_path, f_root, csv_root='../data/kinetics400'): 127 | global missedCnt 128 | missedCnt = 0 129 | 130 | train_split_path = os.path.join(k400_path, 'kinetics-400_train.csv') 131 | val_split_path = os.path.join(k400_path, 'kinetics-400_val.csv') 132 | test_split_path = os.path.join(k400_path, 'kinetics-400_test.csv') 133 | 134 | if not os.path.exists(csv_root): 135 | os.makedirs(csv_root) 136 | 137 | if mode == 'train': 138 | train_split = get_split(os.path.join(f_root, 'train'), train_split_path, 'train') 139 | write_list(train_split, os.path.join(csv_root, 'train_split.csv')) 140 | elif mode == 'val': 141 | val_split = get_split(os.path.join(f_root, 'val'), val_split_path, 'val') 142 | write_list(val_split, os.path.join(csv_root, 'val_split.csv')) 143 | elif mode == 'test': 144 | test_split = get_split(f_root, test_split_path, 'test') 145 | write_list(test_split, os.path.join(csv_root, 'test_split.csv')) 146 | else: 147 | raise IOError('wrong mode') 148 | 149 | print("Total files missed:", missedCnt) 150 | 151 | def check_exists_panasonic(row, root, pra_to_prva, atomic_actions_by_pra_dict=None): 152 | # global missedCnt 153 | 154 | pra, cat = row 155 | p, r, a = pra.split('_') 156 | prvas = [prva for prva in pra_to_prva[pra] if os.path.exists(os.path.join(root, p, prva))] 157 | 158 | result = [] 159 | 160 | if atomic_actions_by_pra_dict is not None: 161 | # atomic 162 | for atomic_action in atomic_actions_by_pra_dict[pra]: 163 | result_row = [] 164 | for prva in prvas: 165 | v = prva.split('_')[2] 166 | full_dirname = os.path.join(root, p, prva) 167 | # FIXME: add action length 168 | s, e, c, _ = atomic_action 169 | # FIXME: ignore actions whose end index is greater than video length. Check consistency! 170 | if e > len(os.listdir(full_dirname)): 171 | continue 172 | c = '_'.join(c.split(' ')) 173 | result_row += [full_dirname, s, e, c] 174 | result.append(result_row) 175 | else: 176 | # video-level, a single row 177 | for prva in prvas: 178 | v = prva.split('_')[2] 179 | full_dirname = os.path.join(root, p, prva) 180 | if os.path.exists(full_dirname): 181 | n_frames = len(glob.glob(os.path.join(full_dirname, '*.jpg'))) 182 | # dir name, start frame, end frame, category 183 | result += [full_dirname, 0, n_frames - 1, cat] 184 | result = [result] 185 | 186 | # if len(result) == 0: 187 | # missedCnt += 1 188 | # return None 189 | return result 190 | 191 | 192 | def get_split_panasonic(root, split_path, mode, atomic_actions_by_pra_dict=None): 193 | print('processing %s split ...' 
% mode) 194 | print('checking %s' % root) 195 | split_list = [] 196 | split_content = pd.read_csv(split_path, header=None) 197 | pra_to_prva = defaultdict(list) 198 | for p in os.listdir(root): 199 | for prva in os.listdir(os.path.join(root, p)): 200 | p, r, v, a = prva.split('_') 201 | pra = '_'.join((p, r, a)) 202 | pra_to_prva[pra].append(prva) 203 | 204 | split_list = Parallel(n_jobs=64)\ 205 | (delayed(check_exists_panasonic)(row, root, pra_to_prva, atomic_actions_by_pra_dict) \ 206 | for i, row in tqdm(split_content.iterrows(), total=len(split_content))) 207 | split_list = [j for i in split_list for j in i] 208 | return split_list 209 | 210 | 211 | def _panasonic_get_atomic_actions(annotation_root): 212 | print('reading atomic action data...') 213 | atomic_actions_by_pra_dict = defaultdict(list) 214 | for p in tqdm(os.listdir(annotation_root)): 215 | for j in os.listdir(os.path.join(annotation_root, p)): 216 | r, a = j.split('_')[1:3] 217 | pra = '_'.join((p, r, a)) 218 | with open(os.path.join(annotation_root, p, j)) as f: 219 | atomic_action_list = json.load(f) 220 | for d in atomic_action_list: 221 | atomic_actions_by_pra_dict[pra].append([d['frame_start'], d['frame_end'], d['class'], d['action_length']]) 222 | return atomic_actions_by_pra_dict 223 | 224 | 225 | def main_panasonic(f_root, in_root, out_root='../../data/panasonic/', atomic=False): 226 | '''generate training/testing split, count number of available frames, save in csv''' 227 | 228 | train_split_path = os.path.join(in_root, 'list_with_activity_labels/train_list.csv') 229 | val_split_path = os.path.join(in_root, 'list_with_activity_labels/val_list.csv') 230 | test_split_path = os.path.join(in_root, 'list_with_activity_labels/test_list.csv') 231 | 232 | atomic_actions_by_pra_dict = None 233 | 234 | if atomic: 235 | atomic_actions_by_pra_dict = _panasonic_get_atomic_actions(os.path.join(in_root, 'annotation_files/atomic_actions')) 236 | 237 | if not os.path.exists(out_root): 238 | os.makedirs(out_root) 239 | 240 | train_split = get_split_panasonic(f_root, train_split_path, 'train', atomic_actions_by_pra_dict) 241 | write_list(train_split, os.path.join(out_root, 'train_split{}.csv'.format('_atomic' if atomic else ''))) 242 | 243 | val_split = get_split_panasonic(f_root, val_split_path, 'val', atomic_actions_by_pra_dict) 244 | write_list(val_split, os.path.join(out_root, 'val_split{}.csv'.format('_atomic' if atomic else ''))) 245 | 246 | test_split = get_split_panasonic(f_root, test_split_path, 'test', atomic_actions_by_pra_dict) 247 | write_list(test_split, os.path.join(out_root, 'test_split{}.csv'.format('_atomic' if atomic else ''))) 248 | 249 | 250 | import argparse 251 | 252 | if __name__ == '__main__': 253 | # f_root is the frame path 254 | # edit 'your_path' here: 255 | 256 | parser = argparse.ArgumentParser() 257 | parser.add_argument('--ucf101', action='store_true') 258 | parser.add_argument('--jhmdb', action='store_true') 259 | parser.add_argument('--hmdb51', action='store_true') 260 | parser.add_argument('--kinetics', action='store_true') 261 | parser.add_argument('--panasonic', action='store_true') 262 | parser.add_argument('--panasonic-atomic', action='store_true') 263 | parser.add_argument('--dataset_path', default='/scr/data', type=str) 264 | args = parser.parse_args() 265 | 266 | dataset_path = args.dataset_path 267 | 268 | if args.ucf101: 269 | main_UCF101(f_root=dataset_path + 'ucf101/frame', 270 | splits_root=dataset_path + 'ucf101/splits_classification', 271 | csv_root='../../data/ucf101') 272 | 273 | 
if args.jhmdb: 274 | main_JHMDB(f_root=dataset_path + 'jhmdb/frame', splits_root=dataset_path + 'jhmdb/splits') 275 | 276 | if args.hmdb51: 277 | main_HMDB51(f_root=dataset_path + 'hmdb/frame', splits_root=dataset_path + 'hmdb/splits/') 278 | 279 | if args.kinetics: 280 | main_Kinetics400( 281 | mode='train', # train or val or test 282 | k400_path=dataset_path + 'kinetics/splits', 283 | f_root=dataset_path + 'kinetics/frame256', 284 | csv_root=os.path.join(dataset_path, 'kinetics400_256'), 285 | ) 286 | 287 | main_Kinetics400( 288 | mode='val', # train or val or test 289 | k400_path=dataset_path + 'kinetics/splits', 290 | f_root=dataset_path + 'kinetics/frame256', 291 | csv_root=os.path.join(dataset_path, 'kinetics400_256'), 292 | ) 293 | 294 | # main_Kinetics400(mode='train', # train or val or test 295 | # k400_path='your_path/Kinetics', 296 | # f_root='your_path/Kinetics400_256/frame', 297 | # csv_root='../data/kinetics400_256') 298 | 299 | if args.panasonic: 300 | main_panasonic( 301 | os.path.join(dataset_path, 'panasonic/frame'), 302 | os.path.join(dataset_path, 'panasonic'), 303 | out_root=os.path.join(dataset_path, 'panasonic')) 304 | 305 | if args.panasonic_atomic: 306 | main_panasonic( 307 | os.path.join(dataset_path, 'panasonic/frame'), 308 | os.path.join(dataset_path, 'panasonic'), 309 | out_root=os.path.join(dataset_path, 'panasonic'), atomic=True) 310 | -------------------------------------------------------------------------------- /train/hierarchical_trainer.py: -------------------------------------------------------------------------------- 1 | import operator 2 | import argparse 3 | import torch 4 | import os 5 | import pickle 6 | import data_utils 7 | 8 | import matplotlib.pyplot as plt 9 | import numpy as np 10 | import model_utils as mu 11 | import sim_utils as su 12 | import model_trainer as mt 13 | import dataset_3d as d3d 14 | 15 | from tqdm import tqdm 16 | from torch.utils import data 17 | 18 | ''' 19 | Important components of hierarchical training 20 | 1. Collect per second block features for each video 21 | - Checkpoint 22 | 2. Create list of classes to train on, etc to extend easily to few shot scenario 23 | 3. Add LSTM/Transformer i.e. recurrent net training to perform training easily 24 | 4. Add DataLoader to load necessary classes, instances into train, val, test 25 | 5. Wrap up training with video level and validation 26 | - Checkpoint 27 | 6. Add multi task learning i.e. 
train for both atomic action type and video level action 28 | ''' 29 | 30 | 31 | parser = argparse.ArgumentParser() 32 | parser.add_argument('--save_dir', default='', type=str, help='save dir for model') 33 | parser.add_argument('--prefix', required=True, type=str, help='prefix') 34 | parser.add_argument('--notes', default='', type=str, help='additional notes') 35 | parser.add_argument('--ckpt_path', required=True, type=str, help='Model ckpt path') 36 | parser.add_argument('--img_dim', required=True, type=int) 37 | parser.add_argument('--modality', required=True, type=str) 38 | parser.add_argument('--num_workers', default=0, type=int) 39 | 40 | eps = 1e-3 41 | cuda = torch.device('cuda') 42 | cosSimHandler = su.CosSimHandler() 43 | 44 | 45 | def get_cross_cos_sim_score(list0, list1): 46 | return cosSimHandler.get_feature_cross_pair_score( 47 | cosSimHandler.l2NormedVec(list0), cosSimHandler.l2NormedVec(list1) 48 | ) 49 | 50 | 51 | def get_instances_for_class(fnames, featuresArr, className): 52 | idxs = [i for i in range(len(fnames)) if get_class(fnames[i]) == className.lower()] 53 | return featuresArr[idxs], idxs 54 | 55 | 56 | def plot_histogram(scores, notes=''): 57 | scores = scores[scores < 1 - eps] 58 | plt.hist(scores.flatten().cpu(), bins=50, alpha=0.5) 59 | plt.ylabel('Cosine Similarity: {}'.format(notes)) 60 | 61 | 62 | def gen_and_plot_cossim_for_class(fnames, featuresArr, className): 63 | classFets, classFnames = get_instances_for_class(fnames, featuresArr, className) 64 | classScore = get_cross_cos_sim_score(classFets, classFets) 65 | plot_histogram(classScore, notes="Class - {}".format(className)) 66 | 67 | 68 | def get_context_representations(model, dataloader, modality): 69 | ''' 70 | Returns a single context vector for some random sample of a video 71 | ''' 72 | features = {} 73 | with torch.no_grad(): 74 | tq = tqdm(dataloader, desc="Progress:") 75 | for idx, data in enumerate(tq): 76 | input_seqs = data[modality].to(cuda) 77 | for input_idx in range(input_seqs.shape[0]): 78 | video = dataloader.dataset.get_video_name(data['vnames'][0]) 79 | input_seq = input_seqs[input_idx] 80 | contexts = model.get_representation(input_seq)[0] 81 | features[video] = { 82 | 'fets': contexts.cpu().detach(), 83 | 'video_labels': data['video_labels'], 84 | 'atomic_labels': data['atomic_labels'] 85 | } 86 | return features 87 | 88 | 89 | def get_class(fname): 90 | fname = fname.rstrip('/') 91 | return fname.split('/')[-2].lower() 92 | 93 | 94 | def save_features(features, path_name): 95 | with open(path_name, 'wb') as handle: 96 | pickle.dump(features, handle, protocol=pickle.HIGHEST_PROTOCOL) 97 | 98 | 99 | def load_features(path_name): 100 | with open(path_name, 'rb') as handle: 101 | features = pickle.load(handle) 102 | return features 103 | 104 | 105 | def setup_panasonic_model_args(save_dir, restore_ckpt, img_dim=128, modality="imgs-0"): 106 | 107 | dataset = "panasonic-atomic" if "panasonic-atomic" in restore_ckpt else "panasonic" 108 | 109 | parser = mu.get_multi_modal_model_train_args() 110 | args = parser.parse_args('') 111 | 112 | # Populate dataset and device 113 | args.dataset = dataset 114 | args.num_classes = mu.get_num_classes(args.dataset) 115 | args.device = torch.device('cuda') if torch.cuda.is_available() else torch.device("cpu") 116 | 117 | args.model = "super" 118 | args.batch_size = 8 119 | args.img_dim = img_dim 120 | args.ds = 3 121 | args.num_seq = 8 122 | args.seq_len = 5 123 | args.save_dir = save_dir 124 | 125 | # Populate modalities 126 | args.modalities = 
modality 127 | args.modes = mt.get_modality_list(args.modalities) 128 | 129 | args.losses = ["super"] 130 | args.num_workers = 6 131 | 132 | if modality == "imgs-0": 133 | args.imgs_0_restore_ckpt = restore_ckpt 134 | elif modality == "imgs-1": 135 | args.imgs_1_restore_ckpt = restore_ckpt 136 | elif modality == mu.AudioMode: 137 | args.audio_restore_ckpt = restore_ckpt 138 | 139 | args.restore_ckpts = mt.get_modality_restore_ckpts(args) 140 | 141 | return args 142 | 143 | 144 | def get_hierarchical_panasonic_dataloader(args, split): 145 | dataset = d3d.HierarchicalPanasonic( 146 | mode=split, 147 | transform=mu.get_test_transforms(args), 148 | seq_len=args.seq_len, 149 | num_seq=args.num_seq, 150 | downsample=3, 151 | vals_to_return=args.modes + ["labels"], 152 | ) 153 | 154 | data_loader = data.DataLoader( 155 | dataset, 156 | sampler=data.SequentialSampler(dataset), 157 | batch_size=1, 158 | shuffle=False, 159 | num_workers=args.num_workers, 160 | collate_fn=data_utils.individual_collate, 161 | pin_memory=True, 162 | drop_last=False 163 | ) 164 | 165 | return data_loader 166 | 167 | 168 | import time 169 | import torch.nn as nn 170 | import torch.optim as optim 171 | 172 | 173 | class PredictorRNN(nn.Module): 174 | def __init__(self, input_size, hidden_size, num_classes, device): 175 | super(PredictorRNN, self).__init__() 176 | self.input_size = input_size 177 | self.hidden_size = hidden_size 178 | self.num_classes = num_classes 179 | 180 | self.dropout = 0.5 181 | self.rnn = nn.LSTM(self.input_size, self.hidden_size) 182 | self.final_fc = nn.Sequential( 183 | nn.Dropout(self.dropout), 184 | nn.Linear(self.hidden_size, self.num_classes), 185 | ) 186 | 187 | self.device = device 188 | 189 | def forward(self, input: torch.tensor, hidden: torch.tensor): 190 | B, N, D = input.shape 191 | output, hidden = self.rnn(input, hidden) 192 | print(output.shape, hidden.shape) 193 | return output, hidden 194 | 195 | def initHidden(self, batch): 196 | return torch.zeros(1, batch, self.hidden_size, device=self.device) 197 | 198 | 199 | class HierarchicalLearner(nn.Module): 200 | 201 | def __init__(self, args): 202 | super(HierarchicalLearner, self).__init__() 203 | 204 | self.device = args["device"] 205 | self.use_rep_loss = args["use_rep_loss"] 206 | 207 | self.predict = PredictorRNN(args["input_size"], args["hidden_size"], args["num_classes"], self.device) 208 | 209 | self.optimizer = optim.Adam(self.predict.parameters(), lr=args["lr"], weight_decay=args["wd"]) 210 | self.teacher_forcing_ratio = args["teacher_forcing_ratio"] 211 | 212 | self.writer_train, self.writer_val = mu.get_writers(args["img_path"]) 213 | 214 | self.feature_size = self.predict.hidden_size 215 | self.print_freq = args["print_freq"] 216 | self.iteration = 0 217 | 218 | self.rep_criterion_base = nn.CrossEntropyLoss() 219 | self.rep_criterion = lambda x, y: self.rep_criterion_base(x, y.float().argmax(dim=1)) 220 | 221 | self.criterion = nn.MSELoss() 222 | 223 | def prep_data(self, input_seq, target_seq): 224 | batch, num_seq, seq_len, C, K = input_seq.shape 225 | 226 | input_seq = input_seq.view(batch * num_seq, seq_len, C, K).permute(1, 0, 2, 3) 227 | target_seq = target_seq.view(batch * num_seq, seq_len, C, K).permute(1, 0, 2, 3) 228 | 229 | encoder_hidden = self.predict.initHidden(batch * num_seq) 230 | encoder_outputs = torch.zeros(seq_len, batch * num_seq, self.predict.hidden_size, device=self.device) 231 | 232 | return input_seq, target_seq, encoder_hidden, encoder_outputs, batch, num_seq, seq_len 233 | 234 | def 
get_representation(self, input_seq): 235 | 236 | batch, num_seq, seq_len, C, K = input_seq.shape 237 | input_seq = input_seq.view(batch * num_seq, seq_len, C, K).permute(1, 0, 2, 3) 238 | encoder_hidden = self.predict.initHidden(batch * num_seq) 239 | 240 | for ei in range(seq_len): 241 | encoder_output, encoder_hidden = self.predict(input_seq[ei], encoder_hidden) 242 | 243 | return encoder_hidden.view(batch, num_seq, self.predict.hidden_size).detach() 244 | 245 | def train_step(self, input_seq, label): 246 | self.optimizer.zero_grad() 247 | 248 | B, N, D = input_seq.shape 249 | 250 | hidden = self.encoder.initHidden(B * N) 251 | output, hidden = self.predict(input_seq, hidden) 252 | 253 | loss = self.criterion(output, label) 254 | loss.backward() 255 | 256 | self.optimizer.step() 257 | 258 | return loss 259 | 260 | def val_step(self, input_seq, target_seq, ret_rep=False): 261 | 262 | input_seq, target_seq, encoder_hidden, encoder_outputs, batch, num_seq, seq_len = \ 263 | self.prep_data(input_seq, target_seq) 264 | 265 | loss = 0 266 | 267 | for ei in range(seq_len): 268 | encoder_output, encoder_hidden = self.predict(input_seq[ei], encoder_hidden) 269 | encoder_outputs[ei] = encoder_output 270 | 271 | decoder_input = input_seq[-1].clone() 272 | 273 | decoder_hidden = encoder_hidden 274 | representation = encoder_hidden.reshape(batch * num_seq, -1).clone().detach() 275 | 276 | for di in range(seq_len): 277 | decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden) 278 | decoder_input = decoder_output.detach() # detach from history as input 279 | loss += self.criterion(decoder_output, target_seq[di]) 280 | 281 | if ret_rep: 282 | return loss / seq_len, representation 283 | else: 284 | return loss / seq_len 285 | 286 | def train_epoch(self, epoch): 287 | 288 | self.train() 289 | 290 | trainX, trainY = [], [] 291 | losses = AverageMeter() 292 | 293 | tq = tqdm(self.train_loader, desc="Training progress in epoch: {}".format(epoch)) 294 | 295 | for idx, data in enumerate(tq): 296 | tic = time.time() 297 | 298 | input_seq = data["poses"] 299 | input_seq = input_seq.to(self.device) 300 | B = input_seq.size(0) 301 | NS = input_seq.size(1) 302 | 303 | target_seq = data["tgt_poses"] 304 | target_seq = target_seq.to(self.device) 305 | 306 | loss, X = self.train_step(input_seq, target_seq, ret_rep=True) 307 | losses.update(loss.item(), B) 308 | 309 | trainX.append(X) 310 | trainY.append(data["labels"].repeat(1, NS).reshape(-1)) 311 | 312 | tq.set_postfix({ 313 | "loss_val": losses.val, 314 | "loss_local_avg": losses.local_avg, 315 | "T": time.time()-tic 316 | }) 317 | 318 | if idx % self.print_freq == 0: 319 | self.writer_train.add_scalar('local/loss', losses.val, self.iteration) 320 | self.iteration += 1 321 | 322 | trainX = torch.cat(trainX) 323 | trainY = torch.cat(trainY).reshape(-1) 324 | 325 | return losses.local_avg, {"X": trainX, "Y": trainY} 326 | 327 | def val_epoch(self): 328 | 329 | self.eval() 330 | 331 | valX, valY = [], [] 332 | losses = AverageMeter() 333 | 334 | tq = tqdm(self.val_loader, desc="Val progress:") 335 | 336 | for idx, data in enumerate(tq): 337 | tic = time.time() 338 | 339 | input_seq = data["poses"] 340 | input_seq = input_seq.to(self.device) 341 | B = input_seq.size(0) 342 | NS = input_seq.size(1) 343 | 344 | target_seq = data["tgt_poses"] 345 | target_seq = target_seq.to(self.device) 346 | 347 | loss, X = self.val_step(input_seq, target_seq, ret_rep=True) 348 | losses.update(loss.item(), B) 349 | 350 | valX.append(X) 351 | 
valY.append(data["labels"].repeat(1, NS).reshape(-1)) 352 | 353 | tq.set_postfix({ 354 | "loss_val": losses.val, 355 | "loss_local_avg": losses.local_avg, 356 | "T": time.time()-tic 357 | }) 358 | 359 | if idx % self.print_freq == 0: 360 | self.writer_val.add_scalar('local/loss', losses.val, self.iteration) 361 | self.iteration += 1 362 | 363 | valX = torch.cat(valX) 364 | valY = torch.cat(valY).reshape(-1) 365 | 366 | return losses.local_avg, {"X": valX, "Y": valY} 367 | 368 | 369 | if __name__ == '__main__': 370 | script_args = parser.parse_args() 371 | args = setup_panasonic_model_args( 372 | save_dir=script_args.save_dir, 373 | restore_ckpt=script_args.ckpt_path, 374 | img_dim=script_args.img_dim, 375 | modality=script_args.modality, 376 | ) 377 | 378 | args.data_sources = script_args.modality 379 | args.num_workers = script_args.num_workers 380 | 381 | datasets = {} 382 | 383 | # Create model and switch to eval mode 384 | model = mt.get_backbone_for_modality(args, script_args.modality) 385 | model.eval() 386 | 387 | splits = ['train', 'val', 'test'] 388 | for split in splits: 389 | data_loader = get_hierarchical_panasonic_dataloader(args, split) 390 | features = get_context_representations(model, data_loader, script_args.modality) 391 | datasets[split] = features 392 | 393 | # Create parent file 394 | save_path = '/{}/nishantr/logs/{}/hierarchical_training_notes{}/features.pickle'.format( 395 | os.environ['BASE_DIR'], script_args.prefix, script_args.notes 396 | ) 397 | parent_path = os.path.dirname(save_path) 398 | if not os.path.isdir(parent_path): 399 | os.makedirs(parent_path) 400 | save_features(datasets, save_path) 401 | -------------------------------------------------------------------------------- /process_data/src/extract_features.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import cv2 4 | import os 5 | import glob 6 | import torch 7 | 8 | cv2.setNumThreads(0) 9 | 10 | from tqdm import tqdm 11 | from torch.utils import data 12 | from typing import Dict, List, Union 13 | 14 | # import some common detectron2 utilities 15 | from detectron2 import model_zoo 16 | from detectron2.config import get_cfg 17 | from detectron2.modeling import build_model 18 | from detectron2.checkpoint import DetectionCheckpointer 19 | from detectron2.modeling import ROI_HEADS_REGISTRY, StandardROIHeads 20 | from detectron2.structures import Boxes, ImageList, Instances 21 | from detectron2.layers import interpolate, cat 22 | from detectron2.utils.logger import setup_logger 23 | setup_logger() 24 | 25 | 26 | def str2bool(s): 27 | """Convert string to bool (in argparse context).""" 28 | if s.lower() not in ['true', 'false']: 29 | raise ValueError('Need bool; got %r' % s) 30 | return {'true': True, 'false': False}[s.lower()] 31 | 32 | 33 | imgShape = None 34 | 35 | from typing import Dict, List, Optional, Tuple, Union 36 | from detectron2.modeling import ROI_HEADS_REGISTRY, StandardROIHeads 37 | from detectron2.structures import Boxes, ImageList, Instances 38 | from detectron2.layers import interpolate, cat 39 | 40 | 41 | @torch.no_grad() 42 | def process_heatmaps(maps, rois, img_shapes): 43 | """ 44 | Extract predicted keypoint locations from heatmaps. 45 | Args: 46 | maps (Tensor): (#ROIs, #keypoints, POOL_H, POOL_W). The predicted heatmap of logits for 47 | each ROI and each keypoint. 48 | rois (Tensor): (#ROIs, 4). The box of each ROI. 
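        img_shapes (list[tuple]): (H, W) of the source image for each ROI; used to allocate the full-resolution score maps.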
49 | Returns: 50 | Tensor of shape (#ROIs, #keypoints, POOL_H, POOL_W) representing confidence scores 51 | """ 52 | 53 | offset_i = (rois[:, 1]).int() 54 | offset_j = (rois[:, 0]).int() 55 | 56 | widths = (rois[:, 2] - rois[:, 0]).clamp(min=1) 57 | heights = (rois[:, 3] - rois[:, 1]).clamp(min=1) 58 | widths_ceil = widths.ceil() 59 | heights_ceil = heights.ceil() 60 | 61 | # roi_map_scores = torch.zeros((maps.shape[0], maps.shape[1], imgShape[0], imgShape[1])) 62 | roi_map_scores = [torch.zeros((maps.shape[1], img_shapes[i][0], img_shapes[i][1])) for i in range(maps.shape[0])] 63 | num_rois, num_keypoints = maps.shape[:2] 64 | 65 | for i in range(num_rois): 66 | outsize = (int(heights_ceil[i]), int(widths_ceil[i])) 67 | # #keypoints x H x W 68 | roi_map = interpolate(maps[[i]], size=outsize, mode="bicubic", align_corners=False).squeeze(0) 69 | 70 | # softmax over the spatial region 71 | max_score, _ = roi_map.view(num_keypoints, -1).max(1) 72 | max_score = max_score.view(num_keypoints, 1, 1) 73 | tmp_full_resolution = (roi_map - max_score).exp_() 74 | tmp_pool_resolution = (maps[i] - max_score).exp_() 75 | 76 | norm_score = ((tmp_full_resolution / tmp_pool_resolution.sum((1, 2), keepdim=True)) * 255.0).to(torch.uint8) 77 | 78 | # Produce scores over the region H x W, but normalize with POOL_H x POOL_W, 79 | # so that the scores of objects of different absolute sizes will be more comparable 80 | for idx in range(num_keypoints): 81 | roi_map_scores[i][idx, offset_i[i]:(offset_i[i] + outsize[0]), offset_j[i]:(offset_j[i] + outsize[1])] = \ 82 | norm_score[idx, ...].float() 83 | 84 | return roi_map_scores 85 | 86 | 87 | def heatmap_rcnn_inference(pred_keypoint_logits, pred_instances): 88 | bboxes_flat = cat([b.pred_boxes.tensor for b in pred_instances], dim=0) 89 | 90 | num_instances_per_image = [len(i) for i in pred_instances] 91 | img_shapes = [instance._image_size for instance in pred_instances for _ in range(len(instance))] 92 | hm_results = process_heatmaps(pred_keypoint_logits.detach(), bboxes_flat.detach(), img_shapes) 93 | 94 | hm_logits = [] 95 | cumsum_idx = np.cumsum(num_instances_per_image) 96 | 97 | assert len(hm_results) == cumsum_idx[-1], \ 98 | "Invalid sizes: {}, {}, {}".format(len(hm_results), cumsum_idx[-1], cumsum_idx) 99 | 100 | for idx in range(len(num_instances_per_image)): 101 | l = 0 if idx == 0 else cumsum_idx[idx - 1] 102 | if num_instances_per_image[idx] == 0: 103 | hm_logits.append(torch.zeros((0, 17, 0, 0))) 104 | else: 105 | hm_logits.append(torch.stack(hm_results[l:l + num_instances_per_image[idx]])) 106 | 107 | for idx in range(min(len(pred_instances), len(hm_logits))): 108 | pred_instances[idx].heat_maps = hm_logits[idx] 109 | 110 | 111 | @ROI_HEADS_REGISTRY.register() 112 | class HeatmapROIHeads(StandardROIHeads): 113 | """ 114 | A Standard ROIHeads which contains returns HeatMaps instead of keypoints. 115 | """ 116 | 117 | def __init__(self, cfg, input_shape): 118 | super().__init__(cfg, input_shape) 119 | 120 | def _forward_keypoint( 121 | self, features: List[torch.Tensor], instances: List[Instances] 122 | ) -> Union[Dict[str, torch.Tensor], List[Instances]]: 123 | if not self.keypoint_on: 124 | return {} if self.training else instances 125 | 126 | if self.training: 127 | assert False, "Not implemented yet!" 
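            # This ROI head is only used at inference time to pull out keypoint heatmaps
            # for video frames, so training-time keypoint losses are left unimplemented.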
128 | else: 129 | pred_boxes = [x.pred_boxes for x in instances] 130 | keypoint_features = self.keypoint_pooler(features, pred_boxes) 131 | keypoint_logits = self.keypoint_head(keypoint_features) 132 | heatmap_rcnn_inference(keypoint_logits, instances) 133 | return instances 134 | 135 | 136 | def get_heatmap_detection_module(): 137 | # Inference with a keypoint detection module 138 | cfg = get_cfg() 139 | cfg.merge_from_file(model_zoo.get_config_file("COCO-Keypoints/keypoint_rcnn_R_50_FPN_3x.yaml")) 140 | cfg.MODEL.ROI_HEADS.NAME = "HeatmapROIHeads" 141 | cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.8 # set threshold for this model 142 | cfg.MODEL.WEIGHTS = "detectron2://COCO-Keypoints/keypoint_rcnn_R_50_FPN_3x/137849621/model_final_a6e10b.pkl" 143 | predictor = build_model(cfg) 144 | print("heatmap head:", cfg.MODEL.ROI_HEADS.NAME) 145 | DetectionCheckpointer(predictor).load(cfg.MODEL.WEIGHTS) 146 | predictor.eval() 147 | return cfg, predictor 148 | 149 | 150 | def get_panoptic_segmentation_module(): 151 | # Inference with a segmentation module 152 | cfg = get_cfg() 153 | cfg.merge_from_file(model_zoo.get_config_file("COCO-PanopticSegmentation/panoptic_fpn_R_101_3x.yaml")) 154 | cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-PanopticSegmentation/panoptic_fpn_R_101_3x.yaml") 155 | cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.8 # set threshold for this model 156 | predictor = build_model(cfg) 157 | print("segmask head:", cfg.MODEL.ROI_HEADS.NAME) 158 | DetectionCheckpointer(predictor).load(cfg.MODEL.WEIGHTS) 159 | predictor.eval() 160 | return cfg, predictor 161 | 162 | 163 | def individual_collate(batch): 164 | """ 165 | Custom collation function for collate with new implementation of individual samples in data pipeline 166 | """ 167 | 168 | data = batch 169 | 170 | # Assuming there's at least one instance in the batch 171 | add_data_keys = data[0].keys() 172 | collected_data = {k: [] for k in add_data_keys} 173 | 174 | for i in range(len(list(data))): 175 | for k in add_data_keys: 176 | collected_data[k].extend(data[i][k]) 177 | 178 | return collected_data 179 | 180 | 181 | def resize_dim(w, h, target): 182 | '''resize (w, h), such that the smaller side is target, keep the aspect ratio''' 183 | if w >= h: 184 | return (int(target * w / h), int(target)) 185 | else: 186 | return (int(target), int(target * h / w)) 187 | 188 | 189 | class VideoDataset(data.Dataset): 190 | 191 | def __init__(self, v_root, vid_range, save_path, skip_len=2): 192 | super(VideoDataset, self).__init__() 193 | 194 | self.v_root = v_root 195 | self.vid_range = vid_range 196 | self.save_path = save_path 197 | 198 | self.init_videos() 199 | 200 | self.max_idx = len(self.v_names) 201 | self.skip = skip_len 202 | 203 | self.width, self.height = 320, 240 204 | self.dim = 192 205 | 206 | def num_frames_in_vid(self, v_path): 207 | vidcap = cv2.VideoCapture(v_path) 208 | nb_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT)) 209 | vidcap.release() 210 | return nb_frames 211 | 212 | def extract_video_opencv(self, v_path): 213 | 214 | global imgShape 215 | 216 | v_class = v_path.split('/')[-2] 217 | v_name = os.path.basename(v_path)[0:-4] 218 | 219 | vidcap = cv2.VideoCapture(v_path) 220 | nb_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT)) 221 | width = vidcap.get(cv2.CAP_PROP_FRAME_WIDTH) # float 222 | height = vidcap.get(cv2.CAP_PROP_FRAME_HEIGHT) # float 223 | 224 | if (width == 0) or (height == 0): 225 | print(v_path, 'not successfully loaded, drop ..') 226 | return 227 | 228 | new_dim = resize_dim(width, height, 
self.dim) 229 | 230 | fnames, imgs = [], [] 231 | 232 | success, image = vidcap.read() 233 | count = 1 234 | while success: 235 | image = cv2.resize(image, new_dim, interpolation=cv2.INTER_LINEAR) 236 | if (count % self.skip == 0): 237 | fnames.append((v_class, v_name, count)) 238 | imgs.append(image) 239 | 240 | success, image = vidcap.read() 241 | count += 1 242 | 243 | if int(nb_frames * 0.8) > count: 244 | print(v_path, 'NOT extracted successfully: %df/%df' % (count, nb_frames)) 245 | 246 | vidcap.release() 247 | 248 | return imgs, fnames 249 | 250 | def vid_already_processed(self, v_path): 251 | v_class = v_path.split('/')[-2] 252 | # Remove avi extension 253 | v_name = os.path.basename(v_path)[0:-4] 254 | 255 | out_dir = os.path.join(self.save_path, v_class, v_name) 256 | num_frames = self.num_frames_in_vid(v_path) 257 | for count in range(max(0, num_frames - 10), num_frames): 258 | fpath = os.path.join(out_dir, 'segmask_%05d.npz' % count) 259 | if os.path.exists(fpath): 260 | return True 261 | 262 | return False 263 | 264 | def init_videos(self): 265 | print('processing videos from %s' % self.v_root) 266 | 267 | self.v_names = [] 268 | 269 | v_act_root = sorted(glob.glob(os.path.join(self.v_root, '*/'))) 270 | 271 | num_skip, tot_files = 0, 0 272 | for vid_dir in v_act_root: 273 | v_class = vid_dir.split('/')[-2] 274 | 275 | if (v_class[0].lower() >= self.vid_range[0]) and (v_class[0].lower() <= self.vid_range[1]): 276 | v_paths = glob.glob(os.path.join(vid_dir, '*.avi')) 277 | v_paths = sorted(v_paths) 278 | 279 | for v_path in v_paths: 280 | tot_files += 1 281 | if self.vid_already_processed(v_path): 282 | num_skip += 1 283 | continue 284 | self.v_names.append(v_path) 285 | 286 | print('Processing: {} files. Skipped: {}/{} files.'.format(len(self.v_names), num_skip, tot_files)) 287 | 288 | def __getitem__(self, idx): 289 | vname = self.v_names[idx] 290 | imgs, fnames = self.extract_video_opencv(vname) 291 | return {"img": imgs, "filename": fnames} 292 | 293 | def __len__(self): 294 | return self.max_idx 295 | 296 | 297 | def get_video_data_loader(path, vid_range, save_path, batch_size=2): 298 | dataset = VideoDataset(path, vid_range, save_path) 299 | data_loader = data.DataLoader( 300 | dataset, 301 | batch_size=batch_size, 302 | sampler=data.SequentialSampler(dataset), 303 | shuffle=False, 304 | num_workers=2, 305 | collate_fn=individual_collate, 306 | pin_memory=True, 307 | drop_last=True 308 | ) 309 | return data_loader 310 | 311 | 312 | def write_heatmap_to_file(root, fname, heatmap): 313 | # fname is a list of (class, vname, count) 314 | v_class, v_name, count = fname 315 | out_dir = os.path.join(root, v_class, v_name) 316 | 317 | if not os.path.exists(out_dir): 318 | os.makedirs(out_dir) 319 | 320 | np.savez_compressed(os.path.join(out_dir, 'heatmap_%05d.npz' % count), hm=heatmap) 321 | 322 | 323 | def write_segmask_to_file(root, fname, segmask): 324 | # fname is a list of (class, vname, count) 325 | v_class, v_name, count = fname 326 | out_dir = os.path.join(root, v_class, v_name) 327 | 328 | if not os.path.exists(out_dir): 329 | os.makedirs(out_dir) 330 | 331 | np.savez_compressed(os.path.join(out_dir, 'segmask_%05d.npz' % count), seg=segmask) 332 | 333 | 334 | def convert_to_uint8(x): 335 | x[x < 0.0] = 0.0 336 | x[x > 255.0] = 255.0 337 | nx = x.to(torch.uint8).numpy() 338 | return nx 339 | 340 | 341 | def process_videos(root, vid_provider, args, batch_size=32, debug=False): 342 | 343 | _, modelKP = get_heatmap_detection_module() 344 | _, modelPS = 
get_panoptic_segmentation_module() 345 | 346 | for batch in tqdm(vid_provider): 347 | imgsTot, fnamesTot = batch['img'], batch['filename'] 348 | 349 | for idx in range(0, len(imgsTot), batch_size): 350 | 351 | imgs, fnames = imgsTot[idx: idx + batch_size], fnamesTot[idx: idx + batch_size] 352 | 353 | imgsDict = [{'image': torch.Tensor(img).float().permute(2, 0, 1)} for img in imgs] 354 | 355 | with torch.no_grad(): 356 | if args.heatmap: 357 | outputsKP = modelKP(imgsDict) 358 | if args.segmask: 359 | outputsPS = modelPS(imgsDict) 360 | 361 | for i in range(len(imgs)): 362 | if args.heatmap: 363 | # Process the keypoints 364 | try: 365 | heatmap = outputsKP[i]['instances'].heat_maps.cpu() 366 | scores = outputsKP[i]['instances'].scores.cpu() 367 | avgHeatmap = (heatmap * scores.view(-1, 1, 1, 1)).sum(dim=0) 368 | # Clamp the max values 369 | avgHeatmap = convert_to_uint8(avgHeatmap) 370 | except: 371 | print("Heatmap generation:", fnames[i]) 372 | print(outputsKP[i]) 373 | else: 374 | assert avgHeatmap.shape[0] == 17, "Invalid size: {}".format(heatmap.shape) 375 | if not debug: 376 | write_heatmap_to_file(root, fnames[i], avgHeatmap) 377 | 378 | if args.segmask: 379 | # Process the segmentation mask 380 | try: 381 | semantic_map = torch.softmax(outputsPS[i]['sem_seg'].detach(), dim=0)[0].cpu() * 255.0 382 | semantic_map = convert_to_uint8(semantic_map) 383 | except: 384 | print("Segmask generation:", fnames[i]) 385 | print(outputsPS[i]) 386 | else: 387 | if not debug: 388 | write_segmask_to_file(root, fnames[i], semantic_map) 389 | 390 | 391 | if __name__ == '__main__': 392 | 393 | parser = argparse.ArgumentParser() 394 | parser.add_argument('--save_path', default='/scr/data/ucf101/features/', type=str) 395 | parser.add_argument('--dataset', default='/scr/data/ucf101/videos', type=str) 396 | parser.add_argument('--batch_size', default=32, type=int) 397 | parser.add_argument('--vid_range', default='az', type=str) 398 | parser.add_argument('--debug', default=0, type=int) 399 | parser.add_argument('--heatmap', default=0, type=int) 400 | parser.add_argument('--segmask', default=0, type=int) 401 | args = parser.parse_args() 402 | 403 | vid_provider = get_video_data_loader(args.dataset, args.vid_range, args.save_path) 404 | 405 | process_videos(args.save_path, vid_provider, batch_size=args.batch_size, debug=args.debug, args=args) 406 | -------------------------------------------------------------------------------- /test/dataset_3d_lc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils import data 3 | from torchvision import transforms 4 | import os 5 | import sys 6 | import time 7 | import pickle 8 | import csv 9 | import glob 10 | import pandas as pd 11 | import numpy as np 12 | import cv2 13 | 14 | sys.path.append('../train') 15 | import model_utils as mu 16 | 17 | sys.path.append('../utils') 18 | from augmentation import * 19 | from tqdm import tqdm 20 | from joblib import Parallel, delayed 21 | 22 | 23 | def pil_loader(path): 24 | with open(path, 'rb') as f: 25 | with Image.open(f) as img: 26 | return img.convert('RGB') 27 | 28 | 29 | toTensor = transforms.ToTensor() 30 | toPILImage = transforms.ToPILImage() 31 | def flow_loader(path): 32 | try: 33 | img = Image.open(path) 34 | except: 35 | return None 36 | f = toTensor(img) 37 | if f.mean() > 0.3: 38 | f -= 0.5 39 | return f 40 | 41 | 42 | def fetch_imgs_seq(vpath, idx_block): 43 | seq = [pil_loader(os.path.join(vpath, 'image_%05d.jpg' % (i+1))) for i in idx_block] 44 | 
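# A minimal round-trip sketch, assuming a temporary directory: features written
# by write_heatmap_to_file / write_segmask_to_file in extract_features.py use
# the npz keys 'hm' and 'seg', and load_detectron_feature further down in this
# file reads them back and rescales to [0, 1].
import os
import tempfile
import numpy as np

hm = (np.random.rand(17, 240, 320) * 255).astype(np.uint8)    # stand-in frame heatmap
path = os.path.join(tempfile.mkdtemp(), 'heatmap_00002.npz')
np.savez_compressed(path, hm=hm)                              # same key as write_heatmap_to_file
restored = np.load(path)['hm'].astype(np.float32) / 255.0     # back to [0, 1]
assert restored.shape == (17, 240, 320)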
return seq 45 | 46 | 47 | def fill_nones(l): 48 | l = [l[i-1] if l[i] is None else l[i] for i in range(len(l))] 49 | l = [l[i-1] if l[i] is None else l[i] for i in range(len(l))] 50 | try: 51 | nonNoneL = [item for item in l if item is not None][0] 52 | except: 53 | nonNoneL = torch.zeros((1, 256, 256)) 54 | return [torch.zeros(nonNoneL.shape) if l[i] is None else l[i] for i in range(len(l))] 55 | 56 | 57 | def get_u_flow_path_list(vpath, idx_block): 58 | dataset = 'ucf101' if 'ucf101' in vpath else 'hmdb51' 59 | flow_base_path = os.path.join('/dev/shm/data/nishantr/flow/', dataset + '_flow/') 60 | vid_name = os.path.basename(os.path.normpath(vpath)) 61 | return [os.path.join(flow_base_path, 'u', vid_name, 'frame%06d.jpg' % (i + 1)) for i in idx_block] 62 | 63 | 64 | def get_v_flow_path_list(vpath, idx_block): 65 | dataset = 'ucf101' if 'ucf101' in vpath else 'hmdb51' 66 | flow_base_path = os.path.join('/dev/shm/data/nishantr/flow/', dataset + '_flow/') 67 | vid_name = os.path.basename(os.path.normpath(vpath)) 68 | return [os.path.join(flow_base_path, 'v', vid_name, 'frame%06d.jpg' % (i + 1)) for i in idx_block] 69 | 70 | 71 | def fetch_flow_seq(vpath, idx_block): 72 | u_flow_list = get_u_flow_path_list(vpath, idx_block) 73 | v_flow_list = get_v_flow_path_list(vpath, idx_block) 74 | 75 | u_seq = fill_nones([flow_loader(f) for f in u_flow_list]) 76 | v_seq = fill_nones([flow_loader(f) for f in v_flow_list]) 77 | 78 | seq = [toPILImage(torch.cat([u, v])) for u, v in zip(u_seq, v_seq)] 79 | return seq 80 | 81 | 82 | def get_class_vid(vpath): 83 | return os.path.normpath(vpath).split('/')[-2:] 84 | 85 | 86 | def load_detectron_feature(fdir, idx, opt): 87 | # opt is either hm or seg 88 | 89 | shape = (192, 256) 90 | num_channels = 17 if opt == 'hm' else 1 91 | 92 | def load_feature(path): 93 | try: 94 | x = np.load(path)[opt] 95 | except: 96 | x = np.zeros((0, 0, 0)) 97 | 98 | # Match non-existent values 99 | if x.shape[1] == 0: 100 | x = np.zeros((num_channels, shape[0], shape[1])) 101 | 102 | x = torch.tensor(x, dtype=torch.float) / 255.0 103 | 104 | # Add extra channel in case it's not present 105 | if len(x.shape) < 3: 106 | x = x.unsqueeze(0) 107 | return x 108 | 109 | suffix = 'heatmap' if opt == 'hm' else 'segmask' 110 | fpath = os.path.join(fdir, suffix + '_%05d.npz' % idx) 111 | if os.path.isfile(fpath): 112 | return load_feature(fpath) 113 | else: 114 | # We do not have results lower than idx=2 115 | idx = max(3, idx) 116 | # We assume having all results for every two frames 117 | fpath0 = os.path.join(fdir, suffix + '_%05d.npz' % (idx - 1)) 118 | fpath1 = os.path.join(fdir, suffix + '_%05d.npz' % (idx + 1)) 119 | # This is not guaranteed to exist 120 | if not os.path.isfile(fpath1): 121 | fpath1 = fpath0 122 | a0, a1 = load_feature(fpath0), load_feature(fpath1) 123 | try: 124 | a_avg = (a0 + a1) / 2.0 125 | except: 126 | a_avg = None 127 | return a_avg 128 | 129 | 130 | def fetch_kp_heatmap_seq(vpath, idx_block): 131 | assert '/frame/' in vpath, "Incorrect vpath received: {}".format(vpath) 132 | feature_vpath = vpath.replace('/frame/', '/heatmaps/') 133 | seq = fill_nones([load_detectron_feature(feature_vpath, idx, opt='hm') for idx in idx_block]) 134 | 135 | if len(set([x.shape for x in seq])) > 1: 136 | # We now know the invalid paths, so no need to print them 137 | # print("Invalid path:", vpath) 138 | seq = [seq[len(seq) // 2] for _ in seq] 139 | return seq 140 | 141 | 142 | def fetch_seg_mask_seq(vpath, idx_block): 143 | assert '/frame/' in vpath, "Incorrect vpath received: 
{}".format(vpath) 144 | feature_vpath = vpath.replace('/frame/', '/segmasks/') 145 | seq = fill_nones([load_detectron_feature(feature_vpath, idx, opt='seg') for idx in idx_block]) 146 | return seq 147 | 148 | 149 | class UCF101_3d(data.Dataset): 150 | def __init__(self, 151 | mode='train', 152 | transform=None, 153 | seq_len=10, 154 | num_seq =1, 155 | downsample=3, 156 | epsilon=5, 157 | which_split=1, 158 | modality=mu.ImgMode): 159 | self.mode = mode 160 | self.transform = transform 161 | self.seq_len = seq_len 162 | self.num_seq = num_seq 163 | self.downsample = downsample 164 | self.epsilon = epsilon 165 | self.which_split = which_split 166 | self.modality = modality 167 | 168 | # splits 169 | if mode == 'train': 170 | split = '../data/ucf101/train_split%02d.csv' % self.which_split 171 | video_info = pd.read_csv(split, header=None) 172 | elif (mode == 'val') or (mode == 'test'): 173 | split = '../data/ucf101/test_split%02d.csv' % self.which_split # use test for val, temporary 174 | video_info = pd.read_csv(split, header=None) 175 | else: raise ValueError('wrong mode') 176 | 177 | # get action list 178 | self.action_dict_encode = {} 179 | self.action_dict_decode = {} 180 | 181 | action_file = os.path.join('../data/ucf101', 'classInd.txt') 182 | action_df = pd.read_csv(action_file, sep=' ', header=None) 183 | for _, row in action_df.iterrows(): 184 | act_id, act_name = row 185 | act_id = int(act_id) - 1 # let id start from 0 186 | self.action_dict_decode[act_id] = act_name 187 | self.action_dict_encode[act_name] = act_id 188 | 189 | # filter out too short videos: 190 | drop_idx = [] 191 | for idx, row in video_info.iterrows(): 192 | vpath, vlen = row 193 | if vlen <= 0: 194 | drop_idx.append(idx) 195 | self.video_info = video_info.drop(drop_idx, axis=0) 196 | 197 | # if mode == 'val': self.video_info = self.video_info.sample(frac=0.3) 198 | # shuffle not required 199 | 200 | def idx_sampler(self, vlen, vpath): 201 | '''sample index from a video''' 202 | downsample = self.downsample 203 | if (vlen - (self.num_seq * self.seq_len * self.downsample)) <= 0: 204 | downsample = ((vlen - 1) / (self.num_seq * self.seq_len * 1.0)) * 0.9 205 | 206 | n = 1 207 | if self.mode == 'test': 208 | seq_idx_block = np.arange(0, vlen, downsample) # all possible frames with downsampling 209 | seq_idx_block = seq_idx_block.astype(int) 210 | return [seq_idx_block, vpath] 211 | start_idx = np.random.choice(range(vlen-int(self.num_seq*self.seq_len*downsample)), n) 212 | seq_idx = np.expand_dims(np.arange(self.num_seq), -1)*downsample*self.seq_len + start_idx 213 | seq_idx_block = seq_idx + np.expand_dims(np.arange(self.seq_len),0)*downsample 214 | seq_idx_block = seq_idx_block.astype(int) 215 | return [seq_idx_block, vpath] 216 | 217 | def __getitem__(self, index): 218 | vpath, vlen = self.video_info.iloc[index] 219 | items = self.idx_sampler(vlen, vpath) 220 | if items is None: print(vpath) 221 | 222 | idx_block, vpath = items 223 | if self.mode != 'test': 224 | assert idx_block.shape == (self.num_seq, self.seq_len) 225 | idx_block = idx_block.reshape(self.num_seq*self.seq_len) 226 | 227 | seq = None 228 | if self.modality == mu.ImgMode: 229 | seq = fetch_imgs_seq(vpath, idx_block) 230 | elif self.modality == mu.FlowMode: 231 | seq = fetch_flow_seq(vpath, idx_block) 232 | elif self.modality == mu.KeypointHeatmap: 233 | seq = fetch_kp_heatmap_seq(vpath, idx_block) 234 | elif self.modality == mu.SegMask: 235 | seq = fetch_seg_mask_seq(vpath, idx_block) 236 | 237 | if self.modality in [mu.KeypointHeatmap, 
mu.SegMask]: 238 | seq = torch.stack(seq) 239 | 240 | # if self.mode == 'test': 241 | # # apply same transform 242 | # t_seq = [self.transform(seq) for _ in range(5)] 243 | # else: 244 | t_seq = self.transform(seq) # apply same transform 245 | # Convert tensor into list of tensors 246 | if self.modality in [mu.KeypointHeatmap, mu.SegMask]: 247 | t_seq = [t_seq[idx] for idx in range(t_seq.shape[0])] 248 | 249 | num_crop = None 250 | try: 251 | (C, H, W) = t_seq[0].size() 252 | t_seq = torch.stack(t_seq, 0) 253 | except: 254 | (C, H, W) = t_seq[0][0].size() 255 | tmp = [torch.stack(i, 0) for i in t_seq] 256 | assert len(tmp) == 5 257 | num_crop = 5 258 | t_seq = torch.stack(tmp, 1) 259 | 260 | if self.mode == 'test': 261 | # return all available clips, but cut into length = num_seq 262 | SL = t_seq.size(0) 263 | clips = []; i = 0 264 | while i+self.seq_len <= SL: 265 | clips.append(t_seq[i:i+self.seq_len, :]) 266 | # i += self.seq_len//2 267 | i += self.seq_len 268 | if num_crop: 269 | # half overlap: 270 | clips = [torch.stack(clips[i:i+self.num_seq], 0).permute(2,0,3,1,4,5) for i in range(0,len(clips)+1-self.num_seq,self.num_seq//2)] 271 | NC = len(clips) 272 | t_seq = torch.stack(clips, 0).view(NC*num_crop, self.num_seq, C, self.seq_len, H, W) 273 | else: 274 | # half overlap: 275 | clips = [torch.stack(clips[i:i+self.num_seq], 0).transpose(1,2) for i in range(0,len(clips)+1-self.num_seq,self.num_seq//2)] 276 | t_seq = torch.stack(clips, 0) 277 | else: 278 | t_seq = t_seq.view(self.num_seq, self.seq_len, C, H, W).transpose(1,2) 279 | 280 | try: 281 | vname = vpath.split('/')[-3] 282 | vid = self.encode_action(vname) 283 | except: 284 | vname = vpath.split('/')[-2] 285 | vid = self.encode_action(vname) 286 | 287 | label = torch.LongTensor([vid]) 288 | idx = torch.LongTensor([index]) 289 | 290 | return t_seq, label, idx 291 | 292 | def __len__(self): 293 | return len(self.video_info) 294 | 295 | def encode_action(self, action_name): 296 | '''give action name, return category''' 297 | return self.action_dict_encode[action_name] 298 | 299 | def decode_action(self, action_code): 300 | '''give action code, return action name''' 301 | return self.action_dict_decode[action_code] 302 | 303 | 304 | class HMDB51_3d(data.Dataset): 305 | def __init__(self, 306 | mode='train', 307 | transform=None, 308 | seq_len=10, 309 | num_seq=1, 310 | downsample=1, 311 | epsilon=5, 312 | which_split=1, 313 | modality=mu.ImgMode): 314 | self.mode = mode 315 | self.transform = transform 316 | self.seq_len = seq_len 317 | self.num_seq = num_seq 318 | self.downsample = downsample 319 | self.epsilon = epsilon 320 | self.which_split = which_split 321 | self.modality = modality 322 | 323 | # splits 324 | if mode == 'train': 325 | split = '../data/hmdb51/train_split%02d.csv' % self.which_split 326 | video_info = pd.read_csv(split, header=None) 327 | elif (mode == 'val') or (mode == 'test'): 328 | split = '../data/hmdb51/test_split%02d.csv' % self.which_split # use test for val, temporary 329 | video_info = pd.read_csv(split, header=None) 330 | else: raise ValueError('wrong mode') 331 | 332 | # get action list 333 | self.action_dict_encode = {} 334 | self.action_dict_decode = {} 335 | 336 | action_file = os.path.join('../data/hmdb51', 'classInd.txt') 337 | action_df = pd.read_csv(action_file, sep=' ', header=None) 338 | for _, row in action_df.iterrows(): 339 | act_id, act_name = row 340 | act_id = int(act_id) - 1 # let id start from 0 341 | self.action_dict_decode[act_id] = act_name 342 | self.action_dict_encode[act_name] 
= act_id 343 | 344 | # filter out too short videos: 345 | drop_idx = [] 346 | for idx, row in video_info.iterrows(): 347 | vpath, vlen = row 348 | if vlen <= 0: 349 | drop_idx.append(idx) 350 | self.video_info = video_info.drop(drop_idx, axis=0) 351 | 352 | # if mode == 'val': self.video_info = self.video_info.sample(frac=0.3) 353 | # shuffle not required 354 | 355 | def idx_sampler(self, vlen, vpath): 356 | '''sample index from a video''' 357 | downsample = self.downsample 358 | if (vlen - (self.num_seq * self.seq_len * self.downsample)) <= 0: 359 | downsample = ((vlen - 1) / (self.num_seq * self.seq_len * 1.0)) * 0.9 360 | 361 | n=1 362 | if self.mode == 'test': 363 | seq_idx_block = np.arange(0, vlen, downsample) # all possible frames with downsampling 364 | seq_idx_block = seq_idx_block.astype(int) 365 | return [seq_idx_block, vpath] 366 | start_idx = np.random.choice(range(vlen-int(self.num_seq*self.seq_len*downsample)), n) 367 | seq_idx = np.expand_dims(np.arange(self.num_seq), -1)*downsample*self.seq_len + start_idx 368 | seq_idx_block = seq_idx + np.expand_dims(np.arange(self.seq_len),0)*downsample 369 | seq_idx_block = seq_idx_block.astype(int) 370 | return [seq_idx_block, vpath] 371 | 372 | def __getitem__(self, index): 373 | vpath, vlen = self.video_info.iloc[index] 374 | items = self.idx_sampler(vlen, vpath) 375 | if items is None: print(vpath) 376 | 377 | idx_block, vpath = items 378 | if self.mode != 'test': 379 | assert idx_block.shape == (self.num_seq, self.seq_len) 380 | idx_block = idx_block.reshape(self.num_seq*self.seq_len) 381 | 382 | seq = None 383 | if self.modality == mu.ImgMode: 384 | seq = fetch_imgs_seq(vpath, idx_block) 385 | elif self.modality == mu.FlowMode: 386 | seq = fetch_flow_seq(vpath, idx_block) 387 | elif self.modality == mu.KeypointHeatmap: 388 | seq = fetch_kp_heatmap_seq(vpath, idx_block) 389 | elif self.modality == mu.SegMask: 390 | seq = fetch_seg_mask_seq(vpath, idx_block) 391 | 392 | if self.modality in [mu.KeypointHeatmap, mu.SegMask]: 393 | seq = torch.stack(seq) 394 | 395 | t_seq = self.transform(seq) # apply same transform 396 | # Convert tensor into list of tensors 397 | if self.modality in [mu.KeypointHeatmap, mu.SegMask]: 398 | t_seq = [t_seq[idx] for idx in range(t_seq.shape[0])] 399 | 400 | num_crop = None 401 | try: 402 | (C, H, W) = t_seq[0].size() 403 | t_seq = torch.stack(t_seq, 0) 404 | except: 405 | (C, H, W) = t_seq[0][0].size() 406 | tmp = [torch.stack(i, 0) for i in t_seq] 407 | assert len(tmp) == 5 408 | num_crop = 5 409 | t_seq = torch.stack(tmp, 1) 410 | # print(t_seq.size()) 411 | # import ipdb; ipdb.set_trace() 412 | if self.mode == 'test': 413 | # return all available clips, but cut into length = num_seq 414 | SL = t_seq.size(0) 415 | clips = []; i = 0 416 | while i+self.seq_len <= SL: 417 | clips.append(t_seq[i:i+self.seq_len, :]) 418 | # i += self.seq_len//2 419 | i += self.seq_len 420 | if num_crop: 421 | # half overlap: 422 | clips = [torch.stack(clips[i:i+self.num_seq], 0).permute(2,0,3,1,4,5) for i in range(0,len(clips)+1-self.num_seq,self.num_seq//2)] 423 | NC = len(clips) 424 | t_seq = torch.stack(clips, 0).view(NC*num_crop, self.num_seq, C, self.seq_len, H, W) 425 | else: 426 | # half overlap: 427 | clips = [torch.stack(clips[i:i+self.num_seq], 0).transpose(1,2) for i in range(0,len(clips)+1-self.num_seq,3*self.num_seq//4)] 428 | t_seq = torch.stack(clips, 0) 429 | else: 430 | t_seq = t_seq.view(self.num_seq, self.seq_len, C, H, W).transpose(1,2) 431 | 432 | try: 433 | vname = vpath.split('/')[-3] 434 | vid = 
self.encode_action(vname) 435 | except: 436 | vname = vpath.split('/')[-2] 437 | vid = self.encode_action(vname) 438 | 439 | label = torch.LongTensor([vid]) 440 | idx = torch.LongTensor([index]) 441 | 442 | return t_seq, label, idx 443 | 444 | def __len__(self): 445 | return len(self.video_info) 446 | 447 | def encode_action(self, action_name): 448 | '''give action name, return category''' 449 | return self.action_dict_encode[action_name] 450 | 451 | def decode_action(self, action_code): 452 | '''give action code, return action name''' 453 | return self.action_dict_decode[action_code] 454 | 455 | -------------------------------------------------------------------------------- /train/model_utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from collections import namedtuple 3 | 4 | import data_utils 5 | import os 6 | 7 | 8 | from torch.utils import data 9 | from tensorboardX import SummaryWriter 10 | from torchvision import transforms 11 | from copy import deepcopy 12 | from collections import defaultdict 13 | 14 | from dataset_3d import * 15 | 16 | sys.path.append('../utils') 17 | from utils import AverageMeter, calc_hamming_loss, calc_mAP 18 | 19 | sys.path.append('../backbone') 20 | from resnet_2d3d import neq_load_customized 21 | 22 | import torch.nn as nn 23 | 24 | # Constants for the framework 25 | eps = 1e-7 26 | 27 | CPCLoss = "cpc" 28 | CooperativeLoss = "coop" 29 | SupervisionLoss = "super" 30 | DistillLoss = "distill" 31 | HierarchicalLoss = "hierarchical" 32 | WeighedHierarchicalLoss = "wgt-hier" 33 | 34 | # Losses for mode sync 35 | ModeSim = "sim" 36 | AlignLoss = "align" 37 | CosSimLoss = "cossim" 38 | CorrLoss = "corr" 39 | DenseCosSimLoss = "dcssim" 40 | DenseCorrLoss = "dcrr" 41 | 42 | # Sets of different losses 43 | LossList = [CPCLoss, CosSimLoss, CorrLoss, DenseCorrLoss, DenseCosSimLoss, 44 | CooperativeLoss, SupervisionLoss, DistillLoss, HierarchicalLoss, WeighedHierarchicalLoss] 45 | ModeSyncLossList = [CosSimLoss, CorrLoss, DenseCorrLoss, DenseCosSimLoss] 46 | 47 | # Type of base model 48 | ModelSSL = 'ssl' 49 | ModelSupervised = 'super' 50 | 51 | ImgMode = "imgs" 52 | AudioMode = "audio" 53 | FlowMode = "flow" 54 | FnbFlowMode = "farne" 55 | KeypointHeatmap = "kphm" 56 | SegMask = "seg" 57 | # FIXME: enable multiple views from the same modality 58 | ModeList = [ImgMode, AudioMode, 'imgs-0', 'imgs-1'] 59 | 60 | ModeParams = namedtuple('ModeParams', ['mode', 'img_fet_dim', 'img_fet_segments', 'final_dim']) 61 | 62 | 63 | def str2bool(s): 64 | """Convert string to bool (in argparse context).""" 65 | if s.lower() not in ['true', 'false']: 66 | raise ValueError('Need bool; got %r' % s) 67 | return {'true': True, 'false': False}[s.lower()] 68 | 69 | 70 | def str2list(s): 71 | """Convert string to list of strs, split on _""" 72 | return s.split('_') 73 | 74 | 75 | def get_multi_modal_model_train_args(): 76 | parser = argparse.ArgumentParser() 77 | 78 | # General global training parameters 79 | parser.add_argument('--save_dir', default='', type=str, help='dir to save intermediate results') 80 | parser.add_argument('--dataset', default='ucf101', type=str) 81 | parser.add_argument('--ds', default=3, type=int, help='frame downsampling rate') 82 | parser.add_argument('--seq_len', default=5, type=int, help='number of frames in each video block') 83 | parser.add_argument('--num_seq', default=8, type=int, help='number of video blocks') 84 | parser.add_argument('--pred_step', default=3, type=int) 85 | 
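# A minimal sketch of the str2list / str2bool helpers above, which let
# multi-valued flags be passed as a single underscore-separated string
# (e.g. --losses cpc_coop, --data_sources imgs_flow); the flag values here
# are only illustrative.
assert str2list('cpc_coop') == ['cpc', 'coop']
assert str2bool('True') is True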
parser.add_argument('--batch_size', default=8, type=int) 86 | parser.add_argument('--epochs', default=10, type=int, help='number of total epochs to run') 87 | parser.add_argument('--eval-freq', default=1, type=int, help='frequency of evaluation') 88 | parser.add_argument('--start-epoch', default=0, type=int, help='manual epoch number (useful on restarts)') 89 | parser.add_argument('--print_freq', default=5, type=int, help='frequency of printing output during training') 90 | parser.add_argument('--num_workers', default=4, type=int, help='Number of workers for dataloader') 91 | parser.add_argument('--reset_lr', action='store_true', help='Reset learning rate when resume training?') 92 | parser.add_argument('--notes', default="", type=str, help='Additional notes') 93 | parser.add_argument('--vis_log_freq', default=100, type=int, help='Visualization frequency') 94 | 95 | # Evaluation specific flags 96 | parser.add_argument('--ft_freq', default=10, type=int, help='frequency to perform finetuning') 97 | 98 | # Global network and model details. Can be overriden using specific flags 99 | parser.add_argument('--model', default='super', type=str, help='Options: ssl, super') 100 | parser.add_argument('--net', default='resnet18', type=str) 101 | parser.add_argument('--train_what', default='all', type=str) 102 | parser.add_argument('--img_dim', default=128, type=int) 103 | parser.add_argument('--sampling', default="dynamic", type=str, help='sampling method (disjoint, random, dynamic)') 104 | parser.add_argument('--temp', default=0.07, type=float, help='Temperature to use with L2 normalization') 105 | parser.add_argument('--attention', default=False, type=str2bool, help='Whether to use attention') 106 | 107 | # Knowledge distillation, hierarchical flags 108 | parser.add_argument('--distill', default=False, type=str2bool, help='Whether to distill knowledge') 109 | parser.add_argument('--students', default="imgs-0", type=str2list, help='Modalities which are students') 110 | parser.add_argument('--hierarchical', default=False, type=str2bool, help='Whether to use hierarchical loss') 111 | 112 | # Training specific flags 113 | parser.add_argument('--lr', default=1e-3, type=float, help='learning rate') 114 | parser.add_argument('--wd', default=1e-5, type=float, help='weight decay') 115 | parser.add_argument('--losses', default="cpc", type=str2list, help='Losses to use (CPC, Align, Rep, Sim)') 116 | parser.add_argument('--dropout', default=0.3, type=float, help='Dropout to use for supervised training') 117 | parser.add_argument('--tune_bb', default=-1.0, type=float, 118 | help='Fine-tune back-bone lr degradation. Useful for pretrained weights. (0.5, 0.1, 0.05)') 119 | 120 | # Hyper-parameters 121 | parser.add_argument('--msync_wt', default=1.0, type=float, help='Loss weight to use for mode sync loss') 122 | parser.add_argument('--dot_wt', default=1.0, type=float, help='Dot weight to use for cooperative loss') 123 | 124 | # Multi-modal related flags 125 | parser.add_argument('--data_sources', default='imgs', type=str2list, help='data sources separated by _') 126 | parser.add_argument('--modalities', default="imgs", type=str2list, help='Modalitiles to consider. 
Separate by _') 127 | 128 | # Checkpoint flags 129 | for m in ModeList: 130 | parser.add_argument('--{}_restore_ckpt'.format(m), default=None, type=str, 131 | help='Restore checkpoint for {}'.format(m)) 132 | 133 | # Flags which need not be touched 134 | parser.add_argument('--resume', default='', type=str, help='path of model to resume') 135 | parser.add_argument('--pretrain', default='', type=str, help='path of pretrained model') 136 | parser.add_argument('--prefix', default='noprefix', type=str, help='prefix of checkpoint filename') 137 | 138 | # supervision categories 139 | parser.add_argument('--multilabel_supervision', action='store_true', help='allowing multiple categories in ' 140 | 'supervised training') 141 | # Extra arguments 142 | parser.add_argument('--debug', default=False, type=str2bool, help='Reduces latency for data ops') 143 | 144 | # Testing flags 145 | parser.add_argument('--test', action='store_true', help='Perform testing on the sample') 146 | parser.add_argument('--test_split', default='test', help='Which split to perform testing on (val, test)') 147 | 148 | # wandb 149 | parser.add_argument('--wandb_project_name', default="", type=str, help='wandb project name') 150 | 151 | return parser 152 | 153 | 154 | def get_num_classes(dataset): 155 | if 'kinetics' in dataset: 156 | return 400 157 | elif dataset == 'ucf101': 158 | return 101 159 | elif dataset == 'jhmdb': 160 | return 21 161 | elif dataset == 'hmdb51': 162 | return 51 163 | elif dataset == 'panasonic': 164 | return 75 165 | elif dataset == 'panasonic-atomic': 166 | return 448 167 | elif dataset == 'hierarchical-panasonic': 168 | return 75 169 | else: 170 | return None 171 | 172 | 173 | def get_transforms(args): 174 | return { 175 | ImgMode: get_imgs_transforms(args), 176 | AudioMode: get_audio_transforms(args), 177 | } 178 | 179 | 180 | def get_test_transforms(args): 181 | return { 182 | ImgMode: get_imgs_test_transforms(args), 183 | AudioMode: get_audio_transforms(args), 184 | } 185 | 186 | 187 | def convert_to_dict(args): 188 | if type(args) != dict: 189 | args_dict = vars(args) 190 | else: 191 | args_dict = args 192 | return args_dict 193 | 194 | 195 | def get_imgs_test_transforms(args): 196 | args_dict = convert_to_dict(args) 197 | 198 | transform = transforms.Compose([ 199 | CenterCrop(size=224, consistent=True), 200 | Scale(size=(args_dict["img_dim"], args_dict["img_dim"])), 201 | ToTensor(), 202 | Normalize() 203 | ]) 204 | 205 | return transform 206 | 207 | 208 | def get_imgs_transforms(args): 209 | 210 | args_dict = convert_to_dict(args) 211 | transform = None 212 | 213 | if args_dict["debug"]: 214 | return transforms.Compose([ 215 | CenterCrop(size=224, consistent=True), 216 | Scale(size=(args_dict["img_dim"], args_dict["img_dim"])), 217 | ToTensor(), 218 | Normalize() 219 | ]) 220 | 221 | if 'panasonic' in args_dict["dataset"]: 222 | transform = transforms.Compose([ 223 | RandomHorizontalFlip(consistent=True), 224 | PadToSize(size=(256, 256)), 225 | RandomCrop(size=224, consistent=True), 226 | Scale(size=(args_dict["img_dim"], args_dict["img_dim"])), 227 | RandomGray(consistent=False, p=0.5), 228 | ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.25, p=1.0), 229 | ToTensor(), 230 | Normalize() 231 | ]) 232 | 233 | return transform 234 | 235 | 236 | def get_audio_transforms(args): 237 | 238 | return transforms.Compose([transforms.ToTensor()]) 239 | 240 | 241 | def get_writers(img_path): 242 | 243 | try: # old version 244 | writer_val = SummaryWriter(log_dir=os.path.join(img_path, 
'val')) 245 | writer_train = SummaryWriter(log_dir=os.path.join(img_path, 'train')) 246 | except: # v1.7 247 | writer_val = SummaryWriter(logdir=os.path.join(img_path, 'val')) 248 | writer_train = SummaryWriter(logdir=os.path.join(img_path, 'train')) 249 | 250 | return writer_train, writer_val 251 | 252 | 253 | def get_dataset_loaders(args, transform, mode='train', test_split=None): 254 | ''' 255 | test_split is relevant in case of testing, either val or test 256 | ''' 257 | 258 | print('Loading data for "%s" ...' % mode) 259 | 260 | if type(args) != dict: 261 | args_dict = deepcopy(vars(args)) 262 | else: 263 | args_dict = args 264 | 265 | if args_dict['debug']: 266 | orig_mode = mode 267 | mode = 'train' 268 | 269 | if args_dict['test']: 270 | if test_split is None: 271 | test_split = mode 272 | # Only use the hierarchical panasonic for test 273 | dataset = HierarchicalPanasonic( 274 | mode=test_split, 275 | transform=transform, 276 | seq_len=args_dict["seq_len"], 277 | num_seq=args_dict["num_seq"], 278 | downsample=args_dict["ds"], 279 | vals_to_return=args_dict['data_sources'].split('_') + ["labels"], 280 | sampling='all', 281 | ) 282 | elif args_dict["dataset"] == 'hierarchical-panasonic': 283 | dataset = HierarchicalPanasonic( 284 | mode=mode, 285 | transform=transform, 286 | seq_len=args_dict["seq_len"], 287 | num_seq=args_dict["num_seq"], 288 | downsample=args_dict["ds"], 289 | vals_to_return=args_dict["data_sources"].split('_'), 290 | debug=args_dict["debug"], 291 | sampling='single') 292 | elif args_dict["dataset"].startswith('panasonic'): 293 | dataset = Panasonic_3d( 294 | mode=mode, 295 | transform=transform, 296 | seq_len=args_dict["seq_len"], 297 | num_seq=args_dict["num_seq"], 298 | downsample=args_dict["ds"], 299 | vals_to_return=args_dict["data_sources"].split('_'), 300 | debug=args_dict["debug"], 301 | dataset=args_dict["dataset"].split('-')[0], 302 | postfix='atomic' if args_dict["dataset"].endswith('atomic') else '', 303 | multilabel_supervision=args_dict["multilabel_supervision"]) 304 | else: 305 | raise ValueError('dataset not supported') 306 | 307 | val_sampler = data.SequentialSampler(dataset) 308 | train_sampler = data.RandomSampler(dataset) 309 | 310 | if args_dict["debug"]: 311 | if orig_mode == 'val': 312 | train_sampler = data.RandomSampler(dataset, replacement=True, num_samples=100) 313 | else: 314 | train_sampler = data.RandomSampler(dataset, replacement=True, num_samples=200) 315 | val_sampler = data.RandomSampler(dataset) 316 | 317 | data_loader = None 318 | if mode == 'train': 319 | data_loader = data.DataLoader(dataset, 320 | batch_size=args_dict["batch_size"], 321 | sampler=train_sampler, 322 | shuffle=False, 323 | num_workers=args_dict["num_workers"], 324 | collate_fn=data_utils.individual_collate, 325 | pin_memory=True, 326 | drop_last=True) 327 | elif mode == 'val': 328 | data_loader = data.DataLoader(dataset, 329 | sampler=val_sampler, 330 | batch_size=args_dict["batch_size"], 331 | shuffle=False, 332 | num_workers=args_dict["num_workers"], 333 | collate_fn=data_utils.individual_collate, 334 | pin_memory=True, 335 | # Do not change drop last to false, integration issues 336 | drop_last=True) 337 | elif mode == 'test': 338 | test_sampler = val_sampler 339 | data_loader = data.DataLoader(dataset, 340 | sampler=test_sampler, 341 | batch_size=1, 342 | shuffle=False, 343 | num_workers=args_dict["num_workers"], 344 | collate_fn=data_utils.individual_collate, 345 | pin_memory=True, 346 | # Do not change drop last to false, integration issues 347 | 
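# A minimal sketch of the image pipeline built by get_imgs_transforms above,
# run on a dummy clip; with "debug": True only the CenterCrop / Scale /
# ToTensor / Normalize branch is exercised, and the frame size and clip
# length below are arbitrary.
from PIL import Image

dummy_args = {"debug": True, "img_dim": 128, "dataset": "ucf101"}
clip = [Image.new('RGB', (320, 240)) for _ in range(8)]   # list of PIL frames
pipeline = get_imgs_transforms(dummy_args)
frames = pipeline(clip)                                   # 8 tensors, each 3 x 128 x 128
assert frames[0].shape == (3, 128, 128)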
drop_last=True) 348 | 349 | print('"%s" dataset size: %d' % (mode, len(dataset))) 350 | return data_loader 351 | 352 | 353 | def set_multi_modal_path(args): 354 | if args.resume: 355 | exp_path = os.path.dirname(os.path.dirname(args.resume)) 356 | else: 357 | args.modes_str = '_'.join(args.modes) 358 | args.model_str = 'pred{}'.format(args.pred_step) 359 | args.loss_str = 'loss_{}'.format('_'.join(args.losses)) 360 | if args.model == ModelSupervised: 361 | args.model_str = 'supervised' 362 | exp_path = '{0}/logs/{args.prefix}/{args.dataset}-{args.img_dim}_{1}_' \ 363 | 'bs{args.batch_size}_seq{args.num_seq}_len{args.seq_len}_{args.model_str}_loss-{args.loss_str}' \ 364 | '_ds{args.ds}_train-{args.train_what}{2}_modes-{args.modes_str}_multilabel_' \ 365 | '{args.multilabel_supervision}_attention-{args.attention}_hierarchical-{args.hierarchical}_' \ 366 | '{args.notes}'.format( 367 | os.environ['BASE_DIR'], 368 | 'r%s' % args.net[6::], 369 | '_pt=%s' % args.pretrain.replace('/','-') if args.pretrain else '', 370 | args=args 371 | ) 372 | exp_path = os.path.join(args.save_dir, exp_path) 373 | 374 | img_path = os.path.join(exp_path, 'img') 375 | model_path = os.path.join(exp_path, 'model') 376 | if not os.path.exists(img_path): os.makedirs(img_path) 377 | if not os.path.exists(model_path): os.makedirs(model_path) 378 | return img_path, model_path 379 | 380 | 381 | def process_output(mask): 382 | '''task mask as input, compute the target for contrastive loss''' 383 | # dot product is computed in parallel gpus, so get less easy neg, bounded by batch size in each gpu''' 384 | # mask meaning: -2: omit, -1: temporal neg (hard), 0: easy neg, 1: pos, -3: spatial neg 385 | (B, NP, SQ, B2, NS, _) = mask.size() # [B, P, SQ, B, N, SQ] 386 | target = mask == 1 387 | target.requires_grad = False 388 | return target, (B, B2, NS, NP, SQ) 389 | 390 | 391 | def check_name_to_be_avoided(k): 392 | # modules_to_avoid = ['.agg.', '.network_pred.'] 393 | modules_to_avoid = [] 394 | for m in modules_to_avoid: 395 | if m in k: 396 | return True 397 | return False 398 | 399 | 400 | def load_model(model, model_path): 401 | if os.path.isfile(model_path): 402 | print("=> loading resumed checkpoint '{}'".format(model_path)) 403 | checkpoint = torch.load(model_path, map_location=torch.device('cpu')) 404 | model = neq_load_customized(model, checkpoint['state_dict']) 405 | print("=> loaded resumed checkpoint '{}' (epoch {})".format(model_path, checkpoint['epoch'])) 406 | else: 407 | print("[WARNING] no checkpoint found at '{}'".format(model_path)) 408 | return model 409 | 410 | 411 | def get_stats_dict(losses_dict, stats): 412 | postfix_dict = {} 413 | 414 | def get_short_key(a, b, c): 415 | return '{}/{}/{}'.format(a, b, c) 416 | 417 | # Populate accuracies 418 | for loss in stats.keys(): 419 | for mode in stats[loss].keys(): 420 | for stat, meter in stats[loss][mode].items(): 421 | key = get_short_key(loss, mode, str(stat)) 422 | # FIXME: temporary fix 423 | if stat == 'multilabel_counts': 424 | # FIXME: Removing this as it pollutes logs rn. 
Move this to test instead of val 425 | pass 426 | # val = meter 427 | # postfix_dict[key] = val 428 | else: 429 | val = meter.avg if eval else meter.local_avg 430 | postfix_dict[key] = round(val, 3) 431 | 432 | # Populate losses 433 | for loss in losses_dict.keys(): 434 | for key, meter in losses_dict[loss].items(): 435 | val = meter.avg if eval else meter.local_avg 436 | postfix_dict[get_short_key('loss', loss, key)] = round(val, 3) 437 | 438 | return postfix_dict 439 | 440 | 441 | def compute_val_metrics(stats, prefix=None): 442 | val_stats = dict() 443 | keys_to_remove = [] 444 | for key in stats.keys(): 445 | if 'multilabel_counts' in key: 446 | k = '_'.join(key.split('_')[:2]) if prefix is None else prefix 447 | val_stats['{}_multilabel_hamming_loss'.format(k)] = calc_hamming_loss(stats[key]['tp'], stats[key]['tn'], stats[key]['fp'], stats[key]['fn']) 448 | val_stats['{}_multilabel_mAP'.format(k)] = calc_mAP(stats[key]['tp'], stats[key]['fp']) 449 | val_stats['{}_multilabel_acc1'.format(k)] = stats[key]['top1_single'].cpu().data.numpy() / stats[key]['all_single'].cpu().data.numpy() 450 | val_stats['{}_multilabel_acc3'.format(k)] = stats[key]['top3_single'].cpu().data.numpy() / stats[key]['all_single'].cpu().data.numpy() 451 | keys_to_remove.append(key) 452 | for key in keys_to_remove: 453 | del stats[key] 454 | stats.update(val_stats) 455 | 456 | 457 | def shorten_stats(overall_stats): 458 | 459 | def get_short_key(a, b, c): 460 | return '{}{}{}'.format(a[0], b[0] + b[-1], c[0] + c[-1]) 461 | 462 | shortened_stats = {} 463 | for k, v in overall_stats.items(): 464 | a, b, c = k.split('/') 465 | # Don't include accuracy 3 stats or align 466 | shouldInclude = (c[-1] != '3') and (a != 'align') 467 | if shouldInclude: 468 | shortened_stats[get_short_key(a, b, c)] = round(v, 2) 469 | 470 | return shortened_stats 471 | 472 | 473 | def init_loggers(losses): 474 | losses_dict = defaultdict(lambda: defaultdict(lambda: AverageMeter())) 475 | stats = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: AverageMeter()))) 476 | return losses_dict, stats 477 | 478 | 479 | class BinaryFocalLossWithLogits(nn.Module): 480 | 481 | def __init__(self, gamma=2, alpha_ratio=10, eps=1e-7): 482 | super(BinaryFocalLossWithLogits, self).__init__() 483 | self.gamma = gamma 484 | self.alpha_ratio = alpha_ratio 485 | self.eps = eps 486 | 487 | def forward(self, input, y): 488 | logit = torch.sigmoid(input) 489 | logit = logit.clamp(self.eps, 1. 
- self.eps) 490 | 491 | loss = -(y * torch.log(logit) + (1 - y) * torch.log(1 - logit)) # cross entropy 492 | alpha = self.alpha_ratio * y + (1 - y) 493 | p_t = logit * y + (1 - logit) * (1 - y) 494 | loss = alpha * loss * (1 - p_t) ** self.gamma # focal loss 495 | 496 | return loss.mean() -------------------------------------------------------------------------------- /utils/augmentation.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numbers 3 | import math 4 | import collections 5 | import numpy as np 6 | from PIL import ImageOps, Image 7 | from joblib import Parallel, delayed 8 | 9 | import torchvision 10 | from torchvision import transforms 11 | import torchvision.transforms.functional as F 12 | 13 | 14 | class Padding: 15 | def __init__(self, pad): 16 | self.pad = pad 17 | 18 | def __call__(self, img): 19 | return ImageOps.expand(img, border=self.pad, fill=0) 20 | 21 | 22 | class PadToSize: 23 | def __init__(self, size): 24 | self.size = size 25 | 26 | def __call__(self, img_list, color=None): 27 | return [ImageOps.pad(img, self.size, color=color) for img in img_list] 28 | 29 | 30 | class Scale: 31 | def __init__(self, size, interpolation=Image.NEAREST): 32 | assert isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2) 33 | self.size = size 34 | self.interpolation = interpolation 35 | 36 | def __call__(self, imgmap): 37 | # assert len(imgmap) > 1 # list of images 38 | img1 = imgmap[0] 39 | if isinstance(self.size, int): 40 | w, h = img1.size 41 | if (w <= h and w == self.size) or (h <= w and h == self.size): 42 | return imgmap 43 | if w < h: 44 | ow = self.size 45 | oh = int(self.size * h / w) 46 | return [i.resize((ow, oh), self.interpolation) for i in imgmap] 47 | else: 48 | oh = self.size 49 | ow = int(self.size * w / h) 50 | return [i.resize((ow, oh), self.interpolation) for i in imgmap] 51 | else: 52 | return [i.resize(self.size, self.interpolation) for i in imgmap] 53 | 54 | 55 | class ScaleForTensors: 56 | def __init__(self, size, interpolation=Image.NEAREST): 57 | assert isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2) 58 | self.size = size 59 | self.interpolation = interpolation 60 | self.toTensor = transforms.ToTensor() 61 | self.toPILImage = transforms.ToPILImage() 62 | 63 | def resize_multi_channel_image(self, img_tensor_list, size): 64 | c, h, w = img_tensor_list[0].shape 65 | assert c < 20, "Invalid shape: {}".format(img_tensor_list.shape) 66 | 67 | resized_channels = [ 68 | torch.stack([ 69 | self.toTensor(self.toPILImage(img_tensor_list[idx][c]).resize(size, self.interpolation)).squeeze(0) 70 | for c in range(img_tensor_list[idx].shape[0]) 71 | ]) for idx in range(len(img_tensor_list)) 72 | ] 73 | resized_img_tensor = torch.stack(resized_channels) 74 | assert resized_img_tensor[0].shape == (c, size[0], size[1]), \ 75 | "Invalid shape: {}, orig: {}".format(resized_img_tensor.shape, img_tensor_list[0].shape) 76 | return resized_img_tensor 77 | 78 | def __call__(self, img_tensor_list): 79 | # assert len(imgmap) > 1 # list of images 80 | img1 = img_tensor_list[0] 81 | if isinstance(self.size, int): 82 | c, h, w = img1.shape 83 | if (w <= h and w == self.size) or (h <= w and h == self.size): 84 | return img_tensor_list 85 | if w < h: 86 | ow = self.size 87 | oh = int(self.size * h / w) 88 | return self.resize_multi_channel_image(img_tensor_list, (ow, oh)) 89 | else: 90 | oh = self.size 91 | ow = int(self.size * w / h) 92 | return 
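# A minimal sketch of BinaryFocalLossWithLogits (defined at the end of
# train/model_utils.py above), assuming the class is in scope: it takes raw,
# pre-sigmoid logits and a same-shaped multi-hot target, as used for
# multilabel supervision. The batch size and class count below are arbitrary
# (448 matches the panasonic-atomic label space).
import torch

criterion = BinaryFocalLossWithLogits(gamma=2, alpha_ratio=10)
logits = torch.randn(4, 448)                       # raw scores for 4 clips
targets = torch.randint(0, 2, (4, 448)).float()    # multi-hot labels
loss = criterion(logits, targets)                  # scalar, mean over all entries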
self.resize_multi_channel_image(img_tensor_list, (ow, oh)) 93 | else: 94 | return self.resize_multi_channel_image(img_tensor_list, self.size) 95 | 96 | 97 | class CenterCrop: 98 | def __init__(self, size, consistent=True): 99 | if isinstance(size, numbers.Number): 100 | self.size = (int(size), int(size)) 101 | else: 102 | self.size = size 103 | 104 | def __call__(self, imgmap): 105 | img1 = imgmap[0] 106 | w, h = img1.size 107 | th, tw = self.size 108 | x1 = int(round((w - tw) / 2.)) 109 | y1 = int(round((h - th) / 2.)) 110 | return [i.crop((x1, y1, x1 + tw, y1 + th)) for i in imgmap] 111 | 112 | 113 | class RandomCropWithProb: 114 | def __init__(self, size, p=0.8, consistent=True): 115 | if isinstance(size, numbers.Number): 116 | self.size = (int(size), int(size)) 117 | else: 118 | self.size = size 119 | self.consistent = consistent 120 | self.threshold = p 121 | 122 | def __call__(self, imgmap): 123 | img1 = imgmap[0] 124 | w, h = img1.size 125 | if self.size is not None: 126 | th, tw = self.size 127 | if w == tw and h == th: 128 | return imgmap 129 | if self.consistent: 130 | if random.random() < self.threshold: 131 | x1 = random.randint(0, w - tw) 132 | y1 = random.randint(0, h - th) 133 | else: 134 | x1 = int(round((w - tw) / 2.)) 135 | y1 = int(round((h - th) / 2.)) 136 | return [i.crop((x1, y1, x1 + tw, y1 + th)) for i in imgmap] 137 | else: 138 | result = [] 139 | for i in imgmap: 140 | if random.random() < self.threshold: 141 | x1 = random.randint(0, w - tw) 142 | y1 = random.randint(0, h - th) 143 | else: 144 | x1 = int(round((w - tw) / 2.)) 145 | y1 = int(round((h - th) / 2.)) 146 | result.append(i.crop((x1, y1, x1 + tw, y1 + th))) 147 | return result 148 | else: 149 | return imgmap 150 | 151 | 152 | class RandomCrop: 153 | def __init__(self, size, consistent=True): 154 | if isinstance(size, numbers.Number): 155 | self.size = (int(size), int(size)) 156 | else: 157 | self.size = size 158 | self.consistent = consistent 159 | 160 | def __call__(self, imgmap, flowmap=None): 161 | img1 = imgmap[0] 162 | w, h = img1.size 163 | if self.size is not None: 164 | th, tw = self.size 165 | if w == tw and h == th: 166 | return imgmap 167 | if not flowmap: 168 | if self.consistent: 169 | x1 = random.randint(0, w - tw) 170 | y1 = random.randint(0, h - th) 171 | return [i.crop((x1, y1, x1 + tw, y1 + th)) for i in imgmap] 172 | else: 173 | result = [] 174 | for i in imgmap: 175 | x1 = random.randint(0, w - tw) 176 | y1 = random.randint(0, h - th) 177 | result.append(i.crop((x1, y1, x1 + tw, y1 + th))) 178 | return result 179 | elif flowmap is not None: 180 | assert (not self.consistent) 181 | result = [] 182 | for idx, i in enumerate(imgmap): 183 | proposal = [] 184 | for j in range(3): # number of proposal: use the one with largest optical flow 185 | x = random.randint(0, w - tw) 186 | y = random.randint(0, h - th) 187 | proposal.append([x, y, abs(np.mean(flowmap[idx,y:y+th,x:x+tw,:]))]) 188 | [x1, y1, _] = max(proposal, key=lambda x: x[-1]) 189 | result.append(i.crop((x1, y1, x1 + tw, y1 + th))) 190 | return result 191 | else: 192 | raise ValueError('wrong case') 193 | else: 194 | return imgmap 195 | 196 | 197 | import torch 198 | 199 | 200 | class RandomIntensityCropForTensors: 201 | def __init__(self, size): 202 | if isinstance(size, numbers.Number): 203 | self.size = (int(size), int(size)) 204 | else: 205 | self.size = size 206 | 207 | def __call__(self, img_tensor_list): 208 | img1 = img_tensor_list[0] 209 | # Expected format 210 | c, h, w = img1.shape 211 | assert c < 20, "Invalid channel 
size: {}".format(img1.shape) 212 | 213 | if self.size is not None: 214 | th, tw = self.size 215 | if w == tw and h == th: 216 | return img_tensor_list 217 | 218 | proposals = [] 219 | # number of proposal: use the one with largest sum of values 220 | for j in range(3): 221 | x = random.randint(0, w - tw) 222 | y = random.randint(0, h - th) 223 | val = \ 224 | sum([torch.mean(torch.abs(img_tensor_list[idx][:, y:y + th, x:x + tw])) for idx in range(len(img_tensor_list))]) 225 | proposals.append(((x, y), val)) 226 | 227 | (x, y), _ = max(proposals, key=lambda x: x[1]) 228 | crops = [i[:, y:y + th, x:x + tw] for i in img_tensor_list] 229 | return crops 230 | else: 231 | return img_tensor_list 232 | 233 | 234 | class RandomIntensityCropForFlow: 235 | def __init__(self, size): 236 | if isinstance(size, numbers.Number): 237 | self.size = (int(size), int(size)) 238 | else: 239 | self.size = size 240 | 241 | def __call__(self, imgmap): 242 | img1 = imgmap[0] 243 | w, h = img1.size 244 | if self.size is not None: 245 | th, tw = self.size 246 | if w == tw and h == th: 247 | return imgmap 248 | 249 | proposals = [] 250 | 251 | # Process img_arrs 252 | img_arrs = [np.asarray(img, dtype=float) for img in imgmap] 253 | img_arrs = [(img * 0.0) + 127. if np.max(img) < 10.0 else img for img in img_arrs] 254 | # Assuming that flow data passed has mean > 100.0 255 | img_arrs = np.array(img_arrs) - 127. 256 | 257 | # number of proposal: use the one with largest sum of values 258 | for j in range(3): 259 | try: 260 | x = random.randint(0, w - tw) 261 | y = random.randint(0, h - th) 262 | except: 263 | print("Error:", w, h, tw, th, img_arrs.shape) 264 | val = np.mean(np.abs(img_arrs[:, y:y + th, x:x + tw, :])) 265 | proposals.append(((x, y), val)) 266 | 267 | (x, y), _ = max(proposals, key=lambda x: x[1]) 268 | crops = [i.crop((x, y, x + tw, y + th)) for i in imgmap] 269 | return crops 270 | else: 271 | return imgmap 272 | 273 | 274 | class CenterCropForTensors: 275 | def __init__(self, size): 276 | if isinstance(size, numbers.Number): 277 | self.size = (int(size), int(size)) 278 | else: 279 | self.size = size 280 | 281 | def __call__(self, img_tensor_list): 282 | img1 = img_tensor_list[0] 283 | # Expected format 284 | c, h, w = img1.shape 285 | assert c < 20, "Invalid channel size: {}".format(img1.shape) 286 | 287 | th, tw = self.size 288 | x = int(round((w - tw) / 2.)) 289 | y = int(round((h - th) / 2.)) 290 | try: 291 | result = [img_tensor[:, y:y + th, x:x + tw] for img_tensor in img_tensor_list] 292 | except: 293 | print(img_tensor_list[0].shape, y, th, x, tw) 294 | return result 295 | 296 | 297 | class RandomSizedCrop: 298 | def __init__(self, size, interpolation=Image.BILINEAR, consistent=True, p=1.0): 299 | self.size = size 300 | self.interpolation = interpolation 301 | self.consistent = consistent 302 | self.threshold = p 303 | 304 | def __call__(self, imgmap): 305 | img1 = imgmap[0] 306 | if random.random() < self.threshold: # do RandomSizedCrop 307 | for attempt in range(10): 308 | area = img1.size[0] * img1.size[1] 309 | target_area = random.uniform(0.5, 1) * area 310 | aspect_ratio = random.uniform(3. / 4, 4. 
/ 3) 311 | 312 | w = int(round(math.sqrt(target_area * aspect_ratio))) 313 | h = int(round(math.sqrt(target_area / aspect_ratio))) 314 | 315 | if self.consistent: 316 | if random.random() < 0.5: 317 | w, h = h, w 318 | if w <= img1.size[0] and h <= img1.size[1]: 319 | x1 = random.randint(0, img1.size[0] - w) 320 | y1 = random.randint(0, img1.size[1] - h) 321 | 322 | imgmap = [i.crop((x1, y1, x1 + w, y1 + h)) for i in imgmap] 323 | for i in imgmap: assert(i.size == (w, h)) 324 | 325 | return [i.resize((self.size, self.size), self.interpolation) for i in imgmap] 326 | else: 327 | result = [] 328 | for i in imgmap: 329 | if random.random() < 0.5: 330 | w, h = h, w 331 | if w <= img1.size[0] and h <= img1.size[1]: 332 | x1 = random.randint(0, img1.size[0] - w) 333 | y1 = random.randint(0, img1.size[1] - h) 334 | result.append(i.crop((x1, y1, x1 + w, y1 + h))) 335 | assert(result[-1].size == (w, h)) 336 | else: 337 | result.append(i) 338 | 339 | assert len(result) == len(imgmap) 340 | return [i.resize((self.size, self.size), self.interpolation) for i in result] 341 | 342 | # Fallback 343 | scale = Scale(self.size, interpolation=self.interpolation) 344 | crop = CenterCrop(self.size) 345 | return crop(scale(imgmap)) 346 | else: # don't do RandomSizedCrop, do CenterCrop 347 | crop = CenterCrop(self.size) 348 | return crop(imgmap) 349 | 350 | 351 | class RandomHorizontalFlip: 352 | def __init__(self, consistent=True, command=None): 353 | self.consistent = consistent 354 | if command == 'left': 355 | self.threshold = 0 356 | elif command == 'right': 357 | self.threshold = 1 358 | else: 359 | self.threshold = 0.5 360 | 361 | def __call__(self, imgmap): 362 | if self.consistent: 363 | if random.random() < self.threshold: 364 | return [i.transpose(Image.FLIP_LEFT_RIGHT) for i in imgmap] 365 | else: 366 | return imgmap 367 | else: 368 | result = [] 369 | for i in imgmap: 370 | if random.random() < self.threshold: 371 | result.append(i.transpose(Image.FLIP_LEFT_RIGHT)) 372 | else: 373 | result.append(i) 374 | assert len(result) == len(imgmap) 375 | return result 376 | 377 | 378 | class RandomGray: 379 | '''Actually it is a channel splitting, not strictly grayscale images''' 380 | def __init__(self, consistent=True, p=0.5): 381 | self.consistent = consistent 382 | self.p = p # probability to apply grayscale 383 | 384 | def __call__(self, imgmap): 385 | if self.consistent: 386 | if random.random() < self.p: 387 | return [self.grayscale(i) for i in imgmap] 388 | else: 389 | return imgmap 390 | else: 391 | result = [] 392 | for i in imgmap: 393 | if random.random() < self.p: 394 | result.append(self.grayscale(i)) 395 | else: 396 | result.append(i) 397 | assert len(result) == len(imgmap) 398 | return result 399 | 400 | def grayscale(self, img): 401 | channel = np.random.choice(3) 402 | np_img = np.array(img)[:,:,channel] 403 | np_img = np.dstack([np_img, np_img, np_img]) 404 | img = Image.fromarray(np_img, 'RGB') 405 | return img 406 | 407 | 408 | class ColorJitter(object): 409 | """Randomly change the brightness, contrast and saturation of an image. --modified from pytorch source code 410 | Args: 411 | brightness (float or tuple of float (min, max)): How much to jitter brightness. 412 | brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness] 413 | or the given [min, max]. Should be non negative numbers. 414 | contrast (float or tuple of float (min, max)): How much to jitter contrast. 
415 | contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast] 416 | or the given [min, max]. Should be non negative numbers. 417 | saturation (float or tuple of float (min, max)): How much to jitter saturation. 418 | saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation] 419 | or the given [min, max]. Should be non negative numbers. 420 | hue (float or tuple of float (min, max)): How much to jitter hue. 421 | hue_factor is chosen uniformly from [-hue, hue] or the given [min, max]. 422 | Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5. 423 | """ 424 | def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, consistent=False, p=1.0): 425 | self.brightness = self._check_input(brightness, 'brightness') 426 | self.contrast = self._check_input(contrast, 'contrast') 427 | self.saturation = self._check_input(saturation, 'saturation') 428 | self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5), 429 | clip_first_on_zero=False) 430 | self.consistent = consistent 431 | self.threshold = p 432 | 433 | def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True): 434 | if isinstance(value, numbers.Number): 435 | if value < 0: 436 | raise ValueError("If {} is a single number, it must be non negative.".format(name)) 437 | value = [center - value, center + value] 438 | if clip_first_on_zero: 439 | value[0] = max(value[0], 0) 440 | elif isinstance(value, (tuple, list)) and len(value) == 2: 441 | if not bound[0] <= value[0] <= value[1] <= bound[1]: 442 | raise ValueError("{} values should be between {}".format(name, bound)) 443 | else: 444 | raise TypeError("{} should be a single number or a list/tuple with lenght 2.".format(name)) 445 | 446 | # if value is 0 or (1., 1.) for brightness/contrast/saturation 447 | # or (0., 0.) for hue, do nothing 448 | if value[0] == value[1] == center: 449 | value = None 450 | return value 451 | 452 | @staticmethod 453 | def get_params(brightness, contrast, saturation, hue): 454 | """Get a randomized transform to be applied on image. 455 | Arguments are same as that of __init__. 456 | Returns: 457 | Transform which randomly adjusts brightness, contrast and 458 | saturation in a random order. 
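        Example (editor's illustrative sketch, not part of the original
        docstring; `frames` stands for a hypothetical list of PIL.Image
        frames). This mirrors what __call__ does when consistent=True, i.e.
        one sampled transform reused for every frame of a clip:

            jitter = ColorJitter(brightness=0.5, contrast=0.5,
                                 saturation=0.5, hue=0.25, consistent=True)
            transform = jitter.get_params(jitter.brightness, jitter.contrast,
                                          jitter.saturation, jitter.hue)
            out = [transform(img) for img in frames]  # same jitter applied to every frame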
459 | """ 460 | transforms = [] 461 | 462 | if brightness is not None: 463 | brightness_factor = random.uniform(brightness[0], brightness[1]) 464 | transforms.append(torchvision.transforms.Lambda(lambda img: F.adjust_brightness(img, brightness_factor))) 465 | 466 | if contrast is not None: 467 | contrast_factor = random.uniform(contrast[0], contrast[1]) 468 | transforms.append(torchvision.transforms.Lambda(lambda img: F.adjust_contrast(img, contrast_factor))) 469 | 470 | if saturation is not None: 471 | saturation_factor = random.uniform(saturation[0], saturation[1]) 472 | transforms.append(torchvision.transforms.Lambda(lambda img: F.adjust_saturation(img, saturation_factor))) 473 | 474 | if hue is not None: 475 | hue_factor = random.uniform(hue[0], hue[1]) 476 | transforms.append(torchvision.transforms.Lambda(lambda img: F.adjust_hue(img, hue_factor))) 477 | 478 | random.shuffle(transforms) 479 | transform = torchvision.transforms.Compose(transforms) 480 | 481 | return transform 482 | 483 | def __call__(self, imgmap): 484 | if random.random() < self.threshold: # do ColorJitter 485 | if self.consistent: 486 | transform = self.get_params(self.brightness, self.contrast, 487 | self.saturation, self.hue) 488 | return [transform(i) for i in imgmap] 489 | else: 490 | result = [] 491 | for img in imgmap: 492 | transform = self.get_params(self.brightness, self.contrast, 493 | self.saturation, self.hue) 494 | result.append(transform(img)) 495 | return result 496 | else: # don't do ColorJitter, do nothing 497 | return imgmap 498 | 499 | def __repr__(self): 500 | format_string = self.__class__.__name__ + '(' 501 | format_string += 'brightness={0}'.format(self.brightness) 502 | format_string += ', contrast={0}'.format(self.contrast) 503 | format_string += ', saturation={0}'.format(self.saturation) 504 | format_string += ', hue={0})'.format(self.hue) 505 | return format_string 506 | 507 | 508 | class RandomRotation: 509 | def __init__(self, consistent=True, degree=15, p=1.0): 510 | self.consistent = consistent 511 | self.degree = degree 512 | self.threshold = p 513 | def __call__(self, imgmap): 514 | if random.random() < self.threshold: # do RandomRotation 515 | if self.consistent: 516 | deg = np.random.randint(-self.degree, self.degree, 1)[0] 517 | return [i.rotate(deg, expand=True) for i in imgmap] 518 | else: 519 | return [i.rotate(np.random.randint(-self.degree, self.degree, 1)[0], expand=True) for i in imgmap] 520 | else: # don't do RandomRotation, do nothing 521 | return imgmap 522 | 523 | class ToTensor: 524 | def __call__(self, imgmap): 525 | totensor = transforms.ToTensor() 526 | return [totensor(i) for i in imgmap] 527 | 528 | class Normalize: 529 | def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]): 530 | self.mean = mean 531 | self.std = std 532 | def __call__(self, imgmap): 533 | normalize = transforms.Normalize(mean=self.mean, std=self.std) 534 | return [normalize(i) for i in imgmap] 535 | 536 | 537 | -------------------------------------------------------------------------------- /test/test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import argparse 5 | import pickle 6 | import re 7 | import numpy as np 8 | import transform_utils as tu 9 | 10 | from tqdm import tqdm 11 | from tensorboardX import SummaryWriter 12 | 13 | sys.path.append('../utils') 14 | sys.path.append('../backbone') 15 | from dataset_3d_lc import UCF101_3d, HMDB51_3d 16 | from model_3d_lc import * 17 | from 
resnet_2d3d import neq_load_customized 18 | from augmentation import * 19 | from utils import AverageMeter, AccuracyTable, ConfusionMeter, save_checkpoint, write_log, calc_topk_accuracy, denorm, calc_accuracy 20 | 21 | import torch 22 | import torch.optim as optim 23 | from torch.utils import data 24 | import torch.nn as nn 25 | from torchvision import datasets, models, transforms 26 | import torchvision.utils as vutils 27 | 28 | parser = argparse.ArgumentParser() 29 | parser.add_argument('--save_dir', default='/data/nishantr/svl/', type=str, help='dir to save intermediate results') 30 | parser.add_argument('--net', default='resnet18', type=str) 31 | parser.add_argument('--model', default='lc', type=str) 32 | parser.add_argument('--dataset', default='ucf101', type=str) 33 | parser.add_argument('--modality', required=True, type=str, help="Modality to use") 34 | parser.add_argument('--split', default=1, type=int) 35 | parser.add_argument('--seq_len', default=5, type=int) 36 | parser.add_argument('--num_seq', default=8, type=int) 37 | parser.add_argument('--num_class', default=101, type=int) 38 | parser.add_argument('--dropout', default=0.5, type=float) 39 | parser.add_argument('--ds', default=3, type=int) 40 | parser.add_argument('--batch_size', default=4, type=int) 41 | parser.add_argument('--lr', default=1e-3, type=float) 42 | parser.add_argument('--wd', default=1e-5, type=float, help='weight decay') 43 | parser.add_argument('--resume', default='', type=str) 44 | parser.add_argument('--pretrain', default='random', type=str) 45 | parser.add_argument('--test', default='', type=str) 46 | parser.add_argument('--extensive', default=0, type=int) 47 | parser.add_argument('--epochs', default=50, type=int, help='number of total epochs to run') 48 | parser.add_argument('--start-epoch', default=0, type=int, help='manual epoch number (useful on restarts)') 49 | parser.add_argument('--print_freq', default=5, type=int) 50 | parser.add_argument('--reset_lr', action='store_true', help='Reset learning rate when resume training?') 51 | parser.add_argument('--train_what', default='last', type=str, help='Train what parameters?') 52 | parser.add_argument('--prefix', default='tmp', type=str) 53 | parser.add_argument('--img_dim', default=128, type=int) 54 | parser.add_argument('--full_eval_freq', default=10, type=int) 55 | parser.add_argument('--num_workers', default=8, type=int) 56 | parser.add_argument('--notes', default='', type=str) 57 | 58 | parser.add_argument('--ensemble', default=0, type=int) 59 | parser.add_argument('--prob_imgs', default='', type=str) 60 | parser.add_argument('--prob_flow', default='', type=str) 61 | parser.add_argument('--prob_seg', default='', type=str) 62 | parser.add_argument('--prob_kphm', default='', type=str) 63 | 64 | 65 | def get_data_loader(args, mode='train'): 66 | print("Getting data loader for:", args.modality) 67 | transform = None 68 | if mode == 'train': 69 | transform = tu.get_train_transforms(args) 70 | elif mode == 'val': 71 | transform = tu.get_val_transforms(args) 72 | elif mode == 'test': 73 | transform = tu.get_test_transforms(args) 74 | loader = get_data(transform, mode) 75 | return loader 76 | 77 | 78 | def get_num_channels(modality): 79 | if modality == mu.ImgMode: 80 | return 3 81 | elif modality == mu.FlowMode: 82 | return 2 83 | elif modality == mu.FnbFlowMode: 84 | return 2 85 | elif modality == mu.KeypointHeatmap: 86 | return 17 87 | elif modality == mu.SegMask: 88 | return 1 89 | else: 90 | assert False, "Invalid modality: {}".format(modality) 91 | 92 | 
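# NOTE (editor's annotation, not part of the original script): get_num_channels()
# above refers to modality constants through `mu` (mu.ImgMode, mu.FlowMode, ...),
# but no explicit `import ... as mu` appears among the imports at the top of this
# file. Unless `mu` is re-exported by one of the wildcard imports
# (`from model_3d_lc import *` or `from augmentation import *`), calling
# get_num_channels() would raise a NameError; the constants presumably live in a
# shared modality/utility module on the appended paths (which one is an
# assumption here). For illustration only, assuming `mu` resolves, the same
# mapping could be written as a lookup table:
#
#   _MODALITY_CHANNELS = {
#       mu.ImgMode: 3,           # RGB frames
#       mu.FlowMode: 2,          # optical flow (two components)
#       mu.FnbFlowMode: 2,
#       mu.KeypointHeatmap: 17,  # 17 keypoint heatmaps
#       mu.SegMask: 1,           # segmentation mask
#   }
#
#   def get_num_channels(modality):
#       try:
#           return _MODALITY_CHANNELS[modality]
#       except KeyError:
#           raise ValueError("Invalid modality: {}".format(modality))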
93 | def freeze_backbone(model): 94 | print('Freezing the backbone...') 95 | for name, param in model.module.named_parameters(): 96 | if ('resnet' in name) or ('rnn' in name) or ('agg' in name): 97 | param.requires_grad = False 98 | return model 99 | 100 | 101 | def unfreeze_backbone(model): 102 | print('Unfreezing the backbone...') 103 | for name, param in model.module.named_parameters(): 104 | if ('resnet' in name) or ('rnn' in name): 105 | param.requires_grad = True 106 | return model 107 | 108 | 109 | def main(): 110 | global args; args = parser.parse_args() 111 | global device; device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 112 | 113 | if args.dataset == 'ucf101': args.num_class = 101 114 | elif args.dataset == 'hmdb51': args.num_class = 51 115 | 116 | if args.ensemble: 117 | def read_pkl(fname): 118 | if fname == '': 119 | return None 120 | with open(fname, 'rb') as f: 121 | prob = pickle.load(f) 122 | return prob 123 | ensemble(read_pkl(args.prob_imgs), read_pkl(args.prob_flow), read_pkl(args.prob_seg), read_pkl(args.prob_kphm)) 124 | sys.exit() 125 | 126 | args.in_channels = get_num_channels(args.modality) 127 | 128 | ### classifier model ### 129 | if args.model == 'lc': 130 | model = LC(sample_size=args.img_dim, 131 | num_seq=args.num_seq, 132 | seq_len=args.seq_len, 133 | in_channels=args.in_channels, 134 | network=args.net, 135 | num_class=args.num_class, 136 | dropout=args.dropout) 137 | else: 138 | raise ValueError('wrong model!') 139 | 140 | model = nn.DataParallel(model) 141 | model = model.to(device) 142 | global criterion; criterion = nn.CrossEntropyLoss() 143 | 144 | ### optimizer ### 145 | params = None 146 | if args.train_what == 'ft': 147 | print('=> finetune backbone with smaller lr') 148 | params = [] 149 | for name, param in model.module.named_parameters(): 150 | if ('resnet' in name) or ('rnn' in name): 151 | params.append({'params': param, 'lr': args.lr/10}) 152 | else: 153 | params.append({'params': param}) 154 | elif args.train_what == 'freeze': 155 | print('=> Freeze backbone') 156 | params = [] 157 | for name, param in model.module.named_parameters(): 158 | param.requires_grad = False 159 | else: 160 | pass # train all layers 161 | 162 | print('\n===========Check Grad============') 163 | for name, param in model.named_parameters(): 164 | if param.requires_grad == False: 165 | print(name, param.requires_grad) 166 | print('=================================\n') 167 | 168 | if params is None: 169 | params = model.parameters() 170 | 171 | optimizer = optim.Adam(params, lr=args.lr, weight_decay=args.wd) 172 | # Old version 173 | # if args.dataset == 'hmdb51': 174 | # lr_lambda = lambda ep: MultiStepLR_Restart_Multiplier(ep, gamma=0.1, step=[50,70,90], repeat=1) 175 | # elif args.dataset == 'ucf101': 176 | # if args.img_dim == 224: lr_lambda = lambda ep: MultiStepLR_Restart_Multiplier(ep, gamma=0.1, step=[90,140,180], repeat=1) 177 | # else: lr_lambda = lambda ep: MultiStepLR_Restart_Multiplier(ep, gamma=0.1, step=[50, 70, 90], repeat=1) 178 | if args.img_dim == 224: 179 | lr_lambda = lambda ep: MultiStepLR_Restart_Multiplier(ep, gamma=0.1, step=[60,120,180], repeat=1) 180 | else: 181 | lr_lambda = lambda ep: MultiStepLR_Restart_Multiplier(ep, gamma=0.1, step=[50, 70, 90], repeat=1) 182 | scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda) 183 | 184 | args.old_lr = None 185 | best_acc = 0 186 | global iteration; iteration = 0 187 | global num_epoch; num_epoch = 0 188 | 189 | ### restart training ### 190 | if args.test: 191 | 
if os.path.isfile(args.test): 192 | print("=> loading testing checkpoint '{}'".format(args.test)) 193 | checkpoint = torch.load(args.test) 194 | try: model.load_state_dict(checkpoint['state_dict']) 195 | except: 196 | print('=> [Warning]: weight structure is not equal to test model; Use non-equal load ==') 197 | model = neq_load_customized(model, checkpoint['state_dict']) 198 | print("=> loaded testing checkpoint '{}' (epoch {})".format(args.test, checkpoint['epoch'])) 199 | num_epoch = checkpoint['epoch'] 200 | elif args.test == 'random': 201 | print("=> [Warning] loaded random weights") 202 | else: 203 | raise ValueError() 204 | 205 | test_loader = get_data_loader(args, 'test') 206 | test_loss, test_acc = test(test_loader, model, extensive=args.extensive) 207 | sys.exit() 208 | else: # not test 209 | torch.backends.cudnn.benchmark = True 210 | 211 | if args.resume: 212 | if os.path.isfile(args.resume): 213 | # args.old_lr = float(re.search('_lr(.+?)_', args.resume).group(1)) 214 | args.old_lr = 1e-3 215 | print("=> loading resumed checkpoint '{}'".format(args.resume)) 216 | checkpoint = torch.load(args.resume, map_location=torch.device('cpu')) 217 | args.start_epoch = checkpoint['epoch'] 218 | best_acc = checkpoint['best_acc'] 219 | model.load_state_dict(checkpoint['state_dict']) 220 | if not args.reset_lr: # if didn't reset lr, load old optimizer 221 | optimizer.load_state_dict(checkpoint['optimizer']) 222 | else: print('==== Change lr from %f to %f ====' % (args.old_lr, args.lr)) 223 | iteration = checkpoint['iteration'] 224 | print("=> loaded resumed checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch'])) 225 | else: 226 | print("=> no checkpoint found at '{}'".format(args.resume)) 227 | 228 | if (not args.resume) and args.pretrain: 229 | if args.pretrain == 'random': 230 | print('=> using random weights') 231 | elif os.path.isfile(args.pretrain): 232 | print("=> loading pretrained checkpoint '{}'".format(args.pretrain)) 233 | checkpoint = torch.load(args.pretrain, map_location=torch.device('cpu')) 234 | model = neq_load_customized(model, checkpoint['state_dict']) 235 | print("=> loaded pretrained checkpoint '{}' (epoch {})".format(args.pretrain, checkpoint['epoch'])) 236 | else: 237 | print("=> no checkpoint found at '{}'".format(args.pretrain)) 238 | 239 | ### load data ### 240 | train_loader = get_data_loader(args, 'train') 241 | val_loader = get_data_loader(args, 'val') 242 | test_loader = get_data_loader(args, 'test') 243 | 244 | # setup tools 245 | global de_normalize; de_normalize = denorm() 246 | global img_path; img_path, model_path = set_path(args) 247 | global writer_train 248 | try: # old version 249 | writer_val = SummaryWriter(log_dir=os.path.join(img_path, 'val')) 250 | writer_train = SummaryWriter(log_dir=os.path.join(img_path, 'train')) 251 | except: # v1.7 252 | writer_val = SummaryWriter(logdir=os.path.join(img_path, 'val')) 253 | writer_train = SummaryWriter(logdir=os.path.join(img_path, 'train')) 254 | 255 | args.test = model_path 256 | print("Model path:", model_path) 257 | 258 | # Freeze the model backbone initially 259 | model = freeze_backbone(model) 260 | cooldown = 0 261 | 262 | ### main loop ### 263 | for epoch in range(args.start_epoch, args.epochs): 264 | num_epoch = epoch 265 | 266 | train_loss, train_acc = train(train_loader, model, optimizer, epoch) 267 | val_loss, val_acc = validate(val_loader, model) 268 | scheduler.step(epoch) 269 | 270 | writer_train.add_scalar('global/loss', train_loss, epoch) 271 | 
writer_train.add_scalar('global/accuracy', train_acc, epoch) 272 | writer_val.add_scalar('global/loss', val_loss, epoch) 273 | writer_val.add_scalar('global/accuracy', val_acc, epoch) 274 | 275 | # save check_point 276 | is_best = val_acc > best_acc 277 | best_acc = max(val_acc, best_acc) 278 | 279 | # Perform testing if either the frequency is hit or the model is the best after a few epochs 280 | if (epoch + 1) % args.full_eval_freq == 0: 281 | test(test_loader, model) 282 | elif (epoch > 70) and (cooldown >= 5) and is_best: 283 | test(test_loader, model) 284 | cooldown = 0 285 | else: 286 | cooldown += 1 287 | 288 | save_checkpoint( 289 | state={ 290 | 'epoch': epoch+1, 291 | 'net': args.net, 292 | 'state_dict': model.state_dict(), 293 | 'best_acc': best_acc, 294 | 'optimizer': optimizer.state_dict(), 295 | 'iteration': iteration 296 | }, 297 | mode=args.modality, 298 | is_best=is_best, 299 | gap=5, 300 | filename=os.path.join(model_path, 'epoch%s.pth.tar' % str(epoch+1)), 301 | keep_all=False) 302 | 303 | # Unfreeze the model backbone after the first run 304 | if epoch == (args.start_epoch): 305 | model = unfreeze_backbone(model) 306 | 307 | print('Training from ep %d to ep %d finished' % (args.start_epoch, args.epochs)) 308 | print("Model path:", model_path) 309 | 310 | 311 | def train(data_loader, model, optimizer, epoch): 312 | losses = AverageMeter() 313 | accuracy = AverageMeter() 314 | model.train() 315 | global iteration 316 | 317 | tq = tqdm(data_loader, desc="Train: Ep {}".format(epoch)) 318 | 319 | for idx, (input_seq, target, _) in enumerate(tq): 320 | tic = time.time() 321 | input_seq = input_seq.to(device) 322 | target = target.to(device) 323 | B = input_seq.size(0) 324 | output, _ = model(input_seq) 325 | 326 | # visualize 327 | if (iteration == 0) or (iteration == args.print_freq): 328 | if B > 2: input_seq = input_seq[0:2,:] 329 | writer_train.add_image('input_seq', 330 | de_normalize(vutils.make_grid( 331 | input_seq[:, :3, ...].transpose(2,3).contiguous().view(-1,3,args.img_dim,args.img_dim), 332 | nrow=args.num_seq*args.seq_len)), 333 | iteration) 334 | del input_seq 335 | 336 | [_, N, D] = output.size() 337 | output = output.view(B*N, D) 338 | target = target.repeat(1, N).view(-1) 339 | 340 | loss = criterion(output, target) 341 | acc = calc_accuracy(output, target) 342 | 343 | del target 344 | 345 | losses.update(loss.item(), B) 346 | accuracy.update(acc.item(), B) 347 | 348 | optimizer.zero_grad() 349 | loss.backward() 350 | optimizer.step() 351 | 352 | total_weight = 0.0 353 | decay_weight = 0.0 354 | for m in model.parameters(): 355 | if m.requires_grad: decay_weight += m.norm(2).data 356 | total_weight += m.norm(2).data 357 | 358 | tq_stats = { 359 | 'loss': losses.local_avg, 360 | 'acc': accuracy.local_avg, 361 | 'decay_wt': decay_weight.item(), 362 | 'total_wt': total_weight.item(), 363 | } 364 | 365 | tq.set_postfix(tq_stats) 366 | 367 | if idx % args.print_freq == 0: 368 | writer_train.add_scalar('local/loss', losses.val, iteration) 369 | writer_train.add_scalar('local/accuracy', accuracy.val, iteration) 370 | 371 | iteration += 1 372 | 373 | return losses.local_avg, accuracy.local_avg 374 | 375 | 376 | def validate(data_loader, model): 377 | losses = AverageMeter() 378 | accuracy = AverageMeter() 379 | model.eval() 380 | with torch.no_grad(): 381 | tq = tqdm(data_loader, desc="Val: ") 382 | for idx, (input_seq, target, _) in enumerate(tq): 383 | input_seq = input_seq.to(device) 384 | target = target.to(device) 385 | B = input_seq.size(0) 386 | output, _ = 
model(input_seq) 387 | 388 | [_, N, D] = output.size() 389 | output = output.view(B*N, D) 390 | target = target.repeat(1, N).view(-1) 391 | 392 | loss = criterion(output, target) 393 | acc = calc_accuracy(output, target) 394 | 395 | losses.update(loss.item(), B) 396 | accuracy.update(acc.item(), B) 397 | 398 | tq.set_postfix({ 399 | 'loss': losses.avg, 400 | 'acc': accuracy.avg, 401 | }) 402 | 403 | print('Val - Loss {loss.avg:.4f}\t' 404 | 'Acc: {acc.avg:.4f} \t'.format(loss=losses, acc=accuracy)) 405 | return losses.avg, accuracy.avg 406 | 407 | 408 | def test(data_loader, model, extensive=False): 409 | losses = AverageMeter() 410 | acc_top1 = AverageMeter() 411 | acc_top5 = AverageMeter() 412 | acc_table = AccuracyTable(data_loader.dataset.action_dict_decode) 413 | confusion_mat = ConfusionMeter(args.num_class) 414 | probs = {} 415 | 416 | model.eval() 417 | with torch.no_grad(): 418 | tq = tqdm(data_loader, desc="Test: ") 419 | for idx, (input_seq, target, index) in enumerate(tq): 420 | input_seq = input_seq.to(device) 421 | target = target.to(device) 422 | B = input_seq.size(0) 423 | input_seq = input_seq.squeeze(0) # squeeze the '1' batch dim 424 | output, _ = model(input_seq) 425 | del input_seq 426 | 427 | prob = torch.mean(torch.mean(nn.functional.softmax(output, 2), 0), 0, keepdim=True) 428 | top1, top5 = calc_topk_accuracy(prob, target, (1,5)) 429 | acc_top1.update(top1.item(), B) 430 | acc_top5.update(top5.item(), B) 431 | del top1, top5 432 | 433 | output = torch.mean(torch.mean(output, 0), 0, keepdim=True) 434 | loss = criterion(output, target.squeeze(-1)) 435 | 436 | losses.update(loss.item(), B) 437 | del loss 438 | 439 | _, pred = torch.max(output, 1) 440 | confusion_mat.update(pred, target.view(-1).byte()) 441 | acc_table.update(pred, target) 442 | probs[index] = {'prob': prob.detach().cpu(), 'target': target.detach().cpu()} 443 | 444 | tq.set_postfix({ 445 | 'loss': losses.avg, 446 | 'acc1': acc_top1.avg, 447 | 'acc5': acc_top5.avg, 448 | }) 449 | 450 | print('Test - Loss {loss.avg:.4f}\t' 451 | 'Acc top1: {top1.avg:.4f} Acc top5: {top5.avg:.4f} \t'.format(loss=losses, top1=acc_top1, top5=acc_top5)) 452 | confusion_mat.plot_mat(args.test+'.svg') 453 | write_log(content='Loss {loss.avg:.4f}\t Acc top1: {top1.avg:.4f} Acc top5: {top5.avg:.4f} \t'.format(loss=losses, top1=acc_top1, top5=acc_top5, args=args), 454 | epoch=num_epoch, 455 | filename=os.path.join(os.path.dirname(args.test), 'test_log_{}.md').format(args.notes)) 456 | with open(os.path.join(os.path.dirname(args.test), 'test_probs_{}.pkl').format(args.notes), 'wb') as f: 457 | pickle.dump(probs, f) 458 | 459 | if extensive: 460 | acc_table.print_table() 461 | acc_table.print_dict() 462 | 463 | # import ipdb; ipdb.set_trace() 464 | return losses.avg, [acc_top1.avg, acc_top5.avg] 465 | 466 | 467 | def ensemble(prob_imgs=None, prob_flow=None, prob_seg=None, prob_kphm=None): 468 | acc_top1 = AverageMeter() 469 | acc_top5 = AverageMeter() 470 | 471 | probs = [prob_imgs, prob_flow, prob_seg, prob_kphm] 472 | for idx in range(len(probs)): 473 | if probs[idx] is not None: 474 | probs[idx] = {k[0][0].data: v for k, v in probs[idx].items()} 475 | valid_probs = [x for x in probs if x is not None] 476 | weights = [2, 2, 1, 1] 477 | 478 | ovr_probs = {} 479 | for k in valid_probs[0].keys(): 480 | ovr_probs[k] = valid_probs[0][k]['prob'] * 0.0 481 | total = 0 482 | for idx in range(len(probs)): 483 | p = probs[idx] 484 | if p is not None: 485 | total += weights[idx] 486 | ovr_probs[k] += p[k]['prob'] * weights[idx] 487 | 
ovr_probs[k] /= total 488 | 489 | top1, top5 = calc_topk_accuracy(ovr_probs[k], valid_probs[0][k]['target'], (1, 5)) 490 | acc_top1.update(top1.item(), 1) 491 | acc_top5.update(top5.item(), 1) 492 | 493 | print('Test - Acc top1: {top1.avg:.4f} Acc top5: {top5.avg:.4f} \t'.format(top1=acc_top1, top5=acc_top5)) 494 | 495 | 496 | def get_data(transform, mode='train'): 497 | print('Loading data for "%s" ...' % mode) 498 | global dataset 499 | if args.dataset == 'ucf101': 500 | dataset = UCF101_3d(mode=mode, 501 | transform=transform, 502 | seq_len=args.seq_len, 503 | num_seq=args.num_seq, 504 | downsample=args.ds, 505 | which_split=args.split, 506 | modality=args.modality 507 | ) 508 | elif args.dataset == 'hmdb51': 509 | dataset = HMDB51_3d(mode=mode, 510 | transform=transform, 511 | seq_len=args.seq_len, 512 | num_seq=args.num_seq, 513 | downsample=args.ds, 514 | which_split=args.split, 515 | modality=args.modality 516 | ) 517 | else: 518 | raise ValueError('dataset not supported') 519 | my_sampler = data.RandomSampler(dataset) 520 | if mode == 'train': 521 | data_loader = data.DataLoader(dataset, 522 | batch_size=args.batch_size, 523 | sampler=my_sampler, 524 | shuffle=False, 525 | num_workers=args.num_workers, 526 | pin_memory=True, 527 | drop_last=True) 528 | elif mode == 'val': 529 | data_loader = data.DataLoader(dataset, 530 | batch_size=args.batch_size, 531 | sampler=my_sampler, 532 | shuffle=False, 533 | num_workers=args.num_workers, 534 | pin_memory=True, 535 | drop_last=True) 536 | elif mode == 'test': 537 | data_loader = data.DataLoader(dataset, 538 | batch_size=1, 539 | sampler=my_sampler, 540 | shuffle=False, 541 | num_workers=args.num_workers, 542 | pin_memory=True) 543 | print('"%s" dataset size: %d' % (mode, len(dataset))) 544 | return data_loader 545 | 546 | 547 | def set_path(args): 548 | if args.resume: exp_path = os.path.dirname(os.path.dirname(args.resume)) 549 | else: 550 | exp_path = 'log/{args.prefix}/ft_{args.dataset}-{args.img_dim}_mode-{args.modality}_' \ 551 | 'sp{args.split}_{0}_{args.model}_bs{args.batch_size}_' \ 552 | 'lr{1}_wd{args.wd}_ds{args.ds}_seq{args.num_seq}_len{args.seq_len}_' \ 553 | 'dp{args.dropout}_train-{args.train_what}{2}'.format( 554 | 'r%s' % args.net[6::], 555 | args.old_lr if args.old_lr is not None else args.lr, 556 | '_'+args.notes, 557 | args=args) 558 | exp_path = os.path.join(args.save_dir, exp_path) 559 | img_path = os.path.join(exp_path, 'img') 560 | model_path = os.path.join(exp_path, 'model') 561 | if not os.path.exists(img_path): os.makedirs(img_path) 562 | if not os.path.exists(model_path): os.makedirs(model_path) 563 | return img_path, model_path 564 | 565 | 566 | def MultiStepLR_Restart_Multiplier(epoch, gamma=0.1, step=[10,15,20], repeat=3): 567 | '''return the multipier for LambdaLR, 568 | 0 <= ep < 10: gamma^0 569 | 10 <= ep < 15: gamma^1 570 | 15 <= ep < 20: gamma^2 571 | 20 <= ep < 30: gamma^0 ... repeat 3 cycles and then keep gamma^2''' 572 | max_step = max(step) 573 | effective_epoch = epoch % max_step 574 | if epoch // max_step >= repeat: 575 | exp = len(step) - 1 576 | else: 577 | exp = len([i for i in step if effective_epoch>=i]) 578 | return gamma ** exp 579 | 580 | 581 | if __name__ == '__main__': 582 | main() 583 | --------------------------------------------------------------------------------
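For reference, the LambdaLR multipliers that `test.py` builds are easy to inspect directly. The snippet below is a minimal sketch (an editor's addition, not part of the repository) of the default 128x128 schedule, i.e. `gamma=0.1, step=[50, 70, 90], repeat=1`; it assumes `test.py` and its module-level imports resolve on the Python path:

```
from test import MultiStepLR_Restart_Multiplier

for ep in (0, 49, 50, 69, 70, 89, 90, 150):
    mult = MultiStepLR_Restart_Multiplier(ep, gamma=0.1, step=[50, 70, 90], repeat=1)
    print(ep, mult)
# 1.0 for epochs 0-49, 0.1 for 50-69, then gamma**2 (~0.01) from epoch 70 onward;
# with repeat=1 the schedule never restarts and the multiplier stays at gamma**2.
```

With `--img_dim 224` the same decay pattern is used, only with later milestones (`step=[60, 120, 180]`).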