├── .gitignore
├── LICENSE
├── README.md
├── comptab.jpg
├── fig2.png
├── fig5.jpg
├── menovideo
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-37.pyc
│   │   ├── menovideo.cpython-37.pyc
│   │   └── videopre.cpython-37.pyc
│   ├── menovideo.py
│   ├── version.py
│   └── videopre.py
├── package_test.ipynb
├── requirements.txt
└── setup.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
dist
menovideo.egg-info

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2021 almamon rasool abdali

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/data-efficient-video-transformer-for-violence/action-recognition-on-real-life-violence)](https://paperswithcode.com/sota/action-recognition-on-real-life-violence?p=data-efficient-video-transformer-for-violence)


# Data-efficient-video-transformer

This repo hosts `menovideo`, the package associated with the paper ['Data Efficient Video Transformer for Violence Detection' (DeVTr)](https://ieeexplore.ieee.org/abstract/document/9530829).

One of the big challenges facing computer-vision researchers working with transformers, especially on video tasks, is the need for large amounts of data and high computational resources. Our method, DeVTr (Data Efficient Video Transformer for Violence Detection), was designed to overcome these challenges.

In this work, we propose a data-efficient video transformer (DeVTr) based on the transformer network as a spatio-temporal learning method, with a pre-trained 2D convolutional neural network (2D-CNN) as an embedding layer for the input data. The model was trained and tested on the Real-Life Violence Situations (RLVS) dataset and achieved an accuracy of 96.25%. A comparison of the results of the suggested method with previous techniques showed that it provides the best result among all other studies of violence event detection.

### Results and benchmarking

The model achieved 96.25% accuracy on the RLVS dataset. It is also worth mentioning that it beat TimeSformer in memory efficiency, convergence speed, and accuracy.

[Comparison of DeVTr against other methods on the RLVS dataset](https://github.com/mamonraab/Data-efficient-video-transformer/blob/main/comptab.jpg)


[Saliency map for a random video of a violent action](https://github.com/mamonraab/Data-efficient-video-transformer/blob/main/fig5.jpg)


### menovideo package

The [menovideo package](https://pypi.org/project/menovideo/) helps you build video action recognition / video understanding models. It provides:

1. our novel DeVTr model, with full customization
2. a video dataset reader and preprocessing utilities that make it easy to read videos and turn them into PyTorch-ready dataloaders
3. a TimeDistributed wrapper, similar to the Keras TimeDistributed wrapper, which helps you easily build classical CNN+LSTM models (see the sketch near the end of this README)


DeVTr is a novel transformer network combined with a convolutional network that yields a highly accurate video action recognition model with limited data and hardware resources.


### simple usage

Install:
```
pip install menovideo
```

Import it:
```
import menovideo.menovideo as menoformer
import menovideo.videopre as vide_reader
```

Initialize the DeVTr model without pre-trained weights:
```
model = menoformer.DeVTr()
```

Initialize DeVTr with pre-trained weights.
The trained weights can be [downloaded from this url](https://drive.google.com/file/d/1s7Z1c-4zC522BFVM5EiZDMQLe6ZV8QQh/view?usp=sharing)

```
weight = 'drive/MyDrive/Colab Notebooks/transformers/violance-detaction-myresearch/vg19bn40convtransformer-ep-0.pth'
model2 = menoformer.DeVTr(w=weight, base='default')
```


Using the video reader and pre-processing helpers. The parameters are:

1. a pandas dataframe containing the path and label of each video
2. the number of frames per video
3. RGB: the number of color channels
4. h: the height of each video frame
5. w: the width of each video frame

```
valid_dataset = vide_reader.TaskDataset(valid_df, timesep=time_stp, rgb=RGB, h=H, w=W)
```

For a detailed example of using the library, see [package_test.ipynb](https://github.com/mamonraab/Data-efficient-video-transformer/blob/main/package_test.ipynb); a shorter end-to-end sketch is also included at the end of this README.

#### Please use PyTorch 1.9 for the pre-trained model

To cite our paper/code:

```
@INPROCEEDINGS{9530829,
  author={Abdali, Almamon Rasool},
  booktitle={2021 IEEE International Conference on Communication, Networks and Satellite (COMNETSAT)},
  title={Data Efficient Video Transformer for Violence Detection},
  year={2021},
  volume={},
  number={},
  pages={195-199},
  doi={10.1109/COMNETSAT53002.2021.9530829}}
```
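
### End-to-end usage sketch

Putting the helpers above together, here is a minimal sketch that goes from a dataframe of video paths and labels to one training step and a prediction. It is only an illustration: the file name `train.csv`, the 224x224 frame size, the batch size, and the learning rate are assumptions, not values from the paper, and `pandas` must be installed. Keep `timesep` at 40 so the clips match the default classifier head, and make sure the first dataframe column is the video path and the second its label (as `TaskDataset` expects).

```
import pandas as pd
import torch
from torch.utils.data import DataLoader

import menovideo.menovideo as menoformer
import menovideo.videopre as vide_reader

time_stp, RGB, H, W = 40, 3, 224, 224          # 40 frames matches the default DeVTr head; 224x224 is an assumption

train_df = pd.read_csv('train.csv')            # hypothetical file: one row per video -> [path, label]
train_ds = vide_reader.TaskDataset(train_df, timesep=time_stp, rgb=RGB, h=H, w=W)
train_loader = DataLoader(train_ds, batch_size=1, shuffle=True)

model = menoformer.DeVTr()                     # or DeVTr(w=weight, base='default') to start from the released weights
criterion = torch.nn.BCEWithLogitsLoss()       # the default model outputs one raw logit per clip
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-5)

# one illustrative training step
model.train()
batch = next(iter(train_loader))
frames = batch['video'].float()                # (batch, timesteps, channels, height, width)
labels = batch['label'].float().unsqueeze(1)   # (batch, 1) to match the logit shape
loss = criterion(model(frames), labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()

# inference: a sigmoid turns the logit into a violence probability
model.eval()
with torch.no_grad():
    probs = torch.sigmoid(model(frames))
print(loss.item(), probs.squeeze(1).tolist())
```
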
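
### TimeDistributed wrapper sketch (CNN+LSTM)

Besides DeVTr itself, the `TimeWarp` wrapper and the `extractlastcell` helper from `menovideo.menovideo` can be reused to build a classical CNN+LSTM recognizer, much like Keras' TimeDistributed layer. The sketch below only illustrates the pattern: the `resnet18` backbone, the 256-unit LSTM, the 10-frame clips, and the 224x224 frames are assumptions, not the configuration used in the paper.

```
import timm
import torch
from torch import nn

import menovideo.menovideo as menoformer

# 2D CNN applied to every frame; num_classes=0 makes timm return pooled features (512-d for resnet18)
backbone = timm.create_model('resnet18', pretrained=True, num_classes=0)

cnn_lstm = nn.Sequential(
    menoformer.TimeWarp(backbone, method='loop', flatn=True),    # (batch, timesteps, 512)
    nn.LSTM(input_size=512, hidden_size=256, batch_first=True),  # returns (output, (h_n, c_n))
    menoformer.extractlastcell(),                                 # keep only the last time step: (batch, 256)
    nn.Dropout(0.3),
    nn.Linear(256, 1),                                            # single logit for binary violence detection
)

dummy = torch.randn(2, 10, 3, 224, 224)   # (batch, timesteps, channels, height, width)
print(cnn_lstm(dummy).shape)              # torch.Size([2, 1])
```
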
--------------------------------------------------------------------------------
/comptab.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mamonraab/Data-efficient-video-transformer/cf76b17cac20adea811428dd534cece2e7f08a7e/comptab.jpg

--------------------------------------------------------------------------------
/fig2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mamonraab/Data-efficient-video-transformer/cf76b17cac20adea811428dd534cece2e7f08a7e/fig2.png

--------------------------------------------------------------------------------
/fig5.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mamonraab/Data-efficient-video-transformer/cf76b17cac20adea811428dd534cece2e7f08a7e/fig5.jpg

--------------------------------------------------------------------------------
/menovideo/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mamonraab/Data-efficient-video-transformer/cf76b17cac20adea811428dd534cece2e7f08a7e/menovideo/__init__.py

--------------------------------------------------------------------------------
/menovideo/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mamonraab/Data-efficient-video-transformer/cf76b17cac20adea811428dd534cece2e7f08a7e/menovideo/__pycache__/__init__.cpython-37.pyc

--------------------------------------------------------------------------------
/menovideo/__pycache__/menovideo.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mamonraab/Data-efficient-video-transformer/cf76b17cac20adea811428dd534cece2e7f08a7e/menovideo/__pycache__/menovideo.cpython-37.pyc

--------------------------------------------------------------------------------
/menovideo/__pycache__/videopre.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mamonraab/Data-efficient-video-transformer/cf76b17cac20adea811428dd534cece2e7f08a7e/menovideo/__pycache__/videopre.cpython-37.pyc

--------------------------------------------------------------------------------
/menovideo/menovideo.py:
--------------------------------------------------------------------------------
import timm
import torch
from torch.nn import functional as F
from torch import nn
import math


class TimeWarp(nn.Module):
    """Apply a 2D base model to every frame of a clip, like Keras' TimeDistributed wrapper."""

    def __init__(self, baseModel, method='sqeeze', flatn=True):
        super(TimeWarp, self).__init__()
        self.baseModel = baseModel
        self.method = method
        self.flatn = flatn

    def forward(self, x):
        batch_size, time_steps, C, H, W = x.size()
        if self.method == 'loop':
            output = []
            for i in range(time_steps):
                # feed one frame at a time into the base model
                x_t = self.baseModel(x[:, i, :, :, :])
                # flatten the output
                if self.flatn:
                    x_t = x_t.view(x_t.size(0), -1)
                output.append(x_t)
            # end of loop: reshape output to (samples, timesteps, output_size)
            x = torch.stack(output, dim=0).transpose_(0, 1)
            output = None  # clear var to reduce data in memory
            x_t = None  # clear var to reduce data in memory
        else:
            # reshape input to (batch_size * timesteps, input_size)
            x = x.contiguous().view(batch_size * time_steps, C, H, W)
            x = self.baseModel(x)
            if self.flatn:
                x = x.view(x.size(0), -1)
            # reshape output to (samples, timesteps, output_size)
            x = x.contiguous().view(batch_size, time_steps, x.size(-1))
        return x


class extractlastcell(nn.Module):
    """Take the (output, hidden) tuple returned by an RNN and keep only the last time step."""

    def forward(self, x):
        out, _ = x
        return out[:, -1, :]


# The positional encoder injects information about the position (the time of the frame in the sequence);
# it helps the model learn temporal features.
class PostionalEcnoder(nn.Module):
    def __init__(self, embd_dim, dropout=0.1, time_steps=30):
        # embd_dim == d_model
        # time_steps == max_len
        super(PostionalEcnoder, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.embd_dim = embd_dim
        self.time_steps = time_steps

    def do_pos_encode(self):
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        pe = torch.zeros(self.time_steps, self.embd_dim).to(device)
        for pos in range(self.time_steps):
            for i in range(0, self.embd_dim, 2):  # step of two: one sine and one cosine per pair of dims
                pe[pos, i] = math.sin(pos / (10000 ** ((2 * i) / self.embd_dim)))
                pe[pos, i + 1] = math.cos(pos / (10000 ** ((2 * (i + 1)) / self.embd_dim)))
        pe = pe.unsqueeze(0)  # shape (batch size, time steps, embedding dim)
        return pe

    def forward(self, x):
        # x is embedded data and must have shape (batch, time_steps, embedding_dim)
        x = x * math.sqrt(self.embd_dim)
        pe = self.do_pos_encode()
        x += pe[:, :x.size(1)]  # pe is broadcast over the batch dimension
        x = self.dropout(x)
        return x


class memoTransormer(nn.Module):
    def __init__(self, dim, heads=8, layers=6, actv='gelu'):
        super(memoTransormer, self).__init__()
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=dim, nhead=heads, activation=actv)
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=layers)

    def forward(self, x):
        x = self.transformer_encoder(x)
        return x


# original model, with or without weights; num_classes can be changed
# any base model + your own classifier
def DeVTr(w='none', base='default', classifier='default', mid_layer=1024, mid_drop=0.4, num_classes=1,
          dim_embd=512, dr_rate=0.1, time_stp=40, encoder_stack=4, encoder_head=8):

    # default DeVTr, with or without weights
    # default DeVTr with a different number of classes
    # change the base CNN
    # change the classifier

    if base == 'default':
        if w != 'none':
            # the released weights were trained with this exact configuration
            num_classes = 1
            dr_rate = 0.1
            dim_embd = 512
            encoder_stack = 4
            encoder_head = 8
            time_stp = 40

        baseModel = timm.create_model('vgg19_bn', pretrained=True, num_classes=dim_embd)
        # freeze the first 40 layers of the VGG19-bn feature extractor and fine-tune the rest
        i = 0
        for child in baseModel.features.children():
            if i < 40:
                for param in child.parameters():
                    param.requires_grad = False
            else:
                for param in child.parameters():
                    param.requires_grad = True
            i += 1

        bas2 = nn.Sequential(baseModel,
                             nn.ReLU(),)

        model = nn.Sequential(TimeWarp(bas2, method='loop', flatn=False),
                              #PrintLayer(),
                              PostionalEcnoder(dim_embd, dropout=dr_rate, time_steps=time_stp),
                              memoTransormer(dim_embd, heads=encoder_head, layers=encoder_stack, actv='gelu'),
                              #PrintLayer(),
                              nn.Flatten(),
                              #PrintLayer(),
                              # 20480 is frame numbers * dim
                              nn.Linear(time_stp * dim_embd, 1024),
                              nn.Dropout(0.4),
                              nn.ReLU(),
                              nn.Linear(1024, num_classes)
                              )
        if w != 'none':
            if torch.cuda.is_available():
                model.load_state_dict(torch.load(w))
            else:
                model.load_state_dict(torch.load(w, map_location='cpu'))
    else:

        bas2 = nn.Sequential(base,
                             nn.ReLU(),)
        if classifier != 'default':
            model = nn.Sequential(TimeWarp(bas2, method='loop', flatn=False),
                                  #PrintLayer(),
                                  PostionalEcnoder(dim_embd, dropout=dr_rate, time_steps=time_stp),
                                  memoTransormer(dim_embd, heads=encoder_head, layers=encoder_stack, actv='gelu'),
                                  #PrintLayer(),
                                  nn.Flatten(),
                                  #PrintLayer(),
                                  # 20480 is frame numbers * dim
                                  nn.Linear(time_stp * dim_embd, mid_layer),
                                  nn.Dropout(mid_drop),
                                  nn.ReLU(),
                                  classifier
                                  )
        else:
            model = nn.Sequential(TimeWarp(bas2, method='loop', flatn=False),
                                  #PrintLayer(),
                                  PostionalEcnoder(dim_embd, dropout=dr_rate, time_steps=time_stp),
                                  memoTransormer(dim_embd, heads=encoder_head, layers=encoder_stack, actv='gelu'),
                                  #PrintLayer(),
                                  nn.Flatten(),
                                  #PrintLayer(),
                                  # 20480 is frame numbers * dim
                                  nn.Linear(time_stp * dim_embd, mid_layer),
                                  nn.Dropout(mid_drop),
                                  nn.ReLU(),
                                  nn.Linear(mid_layer, num_classes)
                                  )

    return model

--------------------------------------------------------------------------------
/menovideo/version.py:
--------------------------------------------------------------------------------
__version__ = '0.5.0'

--------------------------------------------------------------------------------
/menovideo/videopre.py:
--------------------------------------------------------------------------------
import cv2
import numpy as np
from skimage.io import imread
from skimage.transform import resize
import torch
from torchvision import datasets, models, transforms
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data import Dataset, DataLoader


def capture(filename, timesep, rgb, h, w):
    """Read `timesep` frames from a video file into an array of shape (timesep, rgb, h, w)."""
    tmp = []
    frames = np.zeros((timesep, rgb, h, w), dtype=float)  # np.float was removed in recent NumPy; use the builtin float
    i = 0
    vc = cv2.VideoCapture(filename)
    if vc.isOpened():
        rval, frame = vc.read()
    else:
        rval = False
    frm = resize(frame, (h, w, rgb))
    frm = np.expand_dims(frm, axis=0)
    frm = np.moveaxis(frm, -1, 1)
    if np.max(frm) > 1:
        frm = frm / 255.0
    frames[i][:] = frm
    i += 1
    while i < timesep:
        tmp[:] = frm[:]
        rval, frame = vc.read()
        if not rval or frame is None:
            # the clip is shorter than `timesep`: repeat the last decoded frame instead of crashing
            frames[i][:] = frm
            i += 1
            continue
        frm = resize(frame, (h, w, rgb))
        frm = np.expand_dims(frm, axis=0)
        if np.max(frm) > 1:
            frm = frm / 255.0
        frm = np.moveaxis(frm, -1, 1)
        frames[i][:] = frm  # - tmp
        i += 1
    del tmp
    del frm
    del rval
    return frames


class TaskDataset(Dataset):
    """Video dataset: loads clips and their labels from a dataframe of file paths."""

    def __init__(self, datas, timesep=10, rgb=3, h=90, w=90):
        """
        Args:
            datas: pandas dataframe containing the path to each video file and its label
            timesep: number of frames
            rgb: number of color channels
            h: height
            w: width
        """
        self.dataloctions = datas
        self.timesep, self.rgb, self.h, self.w = timesep, rgb, h, w

    def __len__(self):
        return len(self.dataloctions)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        video = capture(self.dataloctions.iloc[idx, 0], self.timesep, self.rgb, self.h, self.w)
        sample = {'video': torch.from_numpy(video),
                  'label': torch.from_numpy(np.asarray(self.dataloctions.iloc[idx, 1]))}

        return sample

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
torch>=1.4.0
torchvision>=0.5.0
timm
pyyaml
opencv-python
numpy
scikit-image

--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
""" Setup
"""
from setuptools import setup, find_packages
from codecs import open
from os import path

here = path.abspath(path.dirname(__file__))

# Get the long description from the README file
with open(path.join(here, 'README.md'), encoding='utf-8') as f:
    long_description = f.read()

exec(open('menovideo/version.py').read())
setup(
    name='menovideo',
    version=__version__,
    description='(Unofficial) PyTorch library: data-efficient video transformer for video understanding and action recognition',
    long_description=long_description,
    long_description_content_type='text/markdown',
    url='https://github.com/mamonraab/Data-efficient-video-transformer',
    author='almamon rasool abdali',
    author_email='mamonrasoolabdali@gmail.com',
    classifiers=[
        # How mature is this project? Common values are
        #   3 - Alpha
        #   4 - Beta
        #   5 - Production/Stable
        'Development Status :: 5 - Production/Stable',
        'Intended Audience :: Education',
        'Intended Audience :: Science/Research',
        'License :: OSI Approved :: MIT License',
        'Programming Language :: Python :: 3.6',
        'Programming Language :: Python :: 3.7',
        'Programming Language :: Python :: 3.8',
        'Topic :: Scientific/Engineering',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'Topic :: Software Development',
        'Topic :: Software Development :: Libraries',
        'Topic :: Software Development :: Libraries :: Python Modules',
    ],

    # Note that this is a string of words separated by whitespace, not a list.
    keywords='pytorch pretrained video models efficientnet transformer',
    packages=find_packages(exclude=['convert', 'tests', 'results']),
    include_package_data=True,
    install_requires=['torch >= 1.4', 'torchvision', 'timm'],
    python_requires='>=3.6',
)
--------------------------------------------------------------------------------