├── DL ├── dl_ft_1_test_O_ECL.py ├── dl_ft_1_train_O_ECL.py └── dvs_gesture.py ├── README.md ├── configs ├── config_DVS_wECL.py ├── config_DVS_woECL.py └── config_NEK.py ├── images └── Model.PNG ├── nt_xent_original.py ├── train.py └── vtn ├── eventR50_VTN.yaml ├── eventVIT_B_VTN.yaml ├── parser_sf.py ├── vtn_ECL.py └── vtn_helper.py /DL/dl_ft_1_test_O_ECL.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from torch.utils.data import Dataset, DataLoader 5 | from torchvision import transforms, utils 6 | import random 7 | import pickle 8 | 9 | import json 10 | import math 11 | import cv2 12 | import time 13 | import torchvision.transforms as trans 14 | from fnmatch import fnmatch 15 | from pathlib import Path 16 | from itertools import chain 17 | import sys 18 | from DL.dl_ft_1_train_O_ECL import selectFrames 19 | 20 | class ek_test(Dataset): 21 | 22 | def __init__(self, shuffle = False,Test = True, kitchen = 'p01'): 23 | print(f'into initiliazation function of DL') 24 | self.shuffle = shuffle # I still need to add the shuffle functionality 25 | self.Test = Test 26 | self.all_paths = self.get_path(kitchen) 27 | if self.shuffle: 28 | random.shuffle(self.all_paths) 29 | self.data = self.all_paths 30 | self.PIL = trans.ToPILImage() 31 | self.TENSOR = trans.ToTensor() 32 | self.num_frames = 10 # 10 voxels/clip 33 | self.num_clips_test = 5 34 | 35 | 36 | def __len__(self): 37 | return len(self.data) 38 | 39 | def __getitem__(self,index): 40 | #I need one clip at a time i.e. 10 voxels 41 | clip_1,clip_2,clip_3,clip_4,clip_5, clip_class, vid_path = self.process_data(index) 42 | return clip_1,clip_2,clip_3,clip_4,clip_5, clip_class, vid_path 43 | 44 | def get_path(self, kitchen): 45 | PATH = [] 46 | folders = [kitchen]#, 'p08', 'p22'] 47 | for fol in folders: 48 | if(self.Test==False): 49 | root = '/home/ad358172/AY/event_summer/phase_1/N-EPIC-Kitchens/ek_train_test/train/' + fol + '_train/' 50 | else: 51 | root = '/home/ad358172/AY/event_summer/phase_1/N-EPIC-Kitchens/ek_train_test/test/' + fol + '_test/' 52 | for path, subdirs, files in os.walk(root): 53 | for name in files: 54 | #if fnmatch(name, pattern): 55 | PATH.append(path) 56 | PATH = list(set(PATH)) 57 | PATH.sort() 58 | return PATH 59 | 60 | def process_data(self, idx): 61 | vid_path = self.data[idx].split(' ')[0] 62 | clip_1,clip_2,clip_3,clip_4,clip_5, clip_class = self.build_clip(vid_path) 63 | return clip_1,clip_2,clip_3,clip_4,clip_5, clip_class, vid_path 64 | 65 | def build_clip(self, vid_path): 66 | clip_class = [] 67 | actions = ['put','take','open','close','wash','cut','mix','pour'] 68 | for id, k in enumerate(actions): 69 | if(vid_path.find(k)!=-1): 70 | clip_class = id 71 | break 72 | clip_class = np.array(clip_class).repeat(self.num_clips_test) 73 | os.chdir(vid_path) #now we are into the parent directory e.g. P01_01 containg all npy voxels 74 | p = Path.cwd() 75 | 76 | ################################ frame list maker starts here ########################### 77 | files = list(p.glob("*.npy*")) 78 | files.sort() #sorting in ascending order 79 | frame_count = len(files) 80 | frames_dense = selectFrames(frame_count, self.num_frames, self.num_clips_test, False) 81 | 82 | 83 | 84 | #now frame_dense is 5x10 i.e. 
we would have 5 clips 85 | clip_1 = [];clip_2 = [];clip_3 = [];clip_4 = [];clip_5 = [] 86 | files = np.array(files) 87 | frames_dense = np.array(frames_dense) 88 | files = files[frames_dense] 89 | 90 | for iii in files[0]: clip_1.append(self.augmentation(np.load(iii),(224,224))) 91 | for iii in files[1]: clip_2.append(self.augmentation(np.load(iii),(224,224))) 92 | for iii in files[2]: clip_3.append(self.augmentation(np.load(iii),(224,224))) 93 | for iii in files[3]: clip_4.append(self.augmentation(np.load(iii),(224,224))) 94 | for iii in files[4]: clip_5.append(self.augmentation(np.load(iii),(224,224))) 95 | 96 | return clip_1,clip_2,clip_3,clip_4,clip_5, clip_class 97 | 98 | 99 | def augmentation(self, image, resize_size): 100 | x = np.einsum('ijk->jki',image) 101 | x = x + np.abs(np.min(x)) 102 | x *= 255/(x.max()) 103 | x[x>255] = 255 104 | x[x<0] = 0 105 | x = x.astype(np.uint8) 106 | image = self.PIL(x) 107 | transform = trans.transforms.Resize(resize_size) 108 | image = transform(image) 109 | image = trans.functional.to_tensor(image) #range 0-1 110 | return image 111 | 112 | def collate_fn_test(batch): 113 | clip_1 = [];clip_2 = [];clip_3 = [];clip_4 = [];clip_5 = [] 114 | clip_class = [] 115 | vid_path = [] 116 | for item in batch: 117 | clip_1.append(torch.stack(item[0],dim=0)) 118 | clip_2.append(torch.stack(item[1],dim=0)) 119 | clip_3.append(torch.stack(item[2],dim=0)) 120 | clip_4.append(torch.stack(item[3],dim=0)) 121 | clip_5.append(torch.stack(item[4],dim=0)) 122 | clip_class.append(torch.as_tensor(np.asarray(item[5]))) 123 | vid_path.append(item[6]) 124 | 125 | clip_1 = torch.stack(clip_1, dim=0) 126 | clip_2 = torch.stack(clip_2, dim=0) 127 | clip_3 = torch.stack(clip_3, dim=0) 128 | clip_4 = torch.stack(clip_4, dim=0) 129 | clip_5 = torch.stack(clip_5, dim=0) 130 | 131 | return clip_1,clip_2,clip_3,clip_4,clip_5, clip_class, vid_path 132 | 133 | 134 | 135 | 136 | def vis_frames(clip,name,path): 137 | #temp = clip[0,:] 138 | temp = clip.permute(2,3,1,0) 139 | 140 | frame_width = 224 141 | frame_height = 224 142 | frame_size = (frame_width,frame_height) 143 | path = path + name + '.avi' 144 | video = cv2.VideoWriter(path,cv2.VideoWriter_fourcc('p', 'n', 'g', ' '),3,(frame_size[1],frame_size[0])) 145 | 146 | for i in range(temp.shape[3]): 147 | x = np.array(temp[:,:,:,i]) 148 | x *= 255/(x.max()) 149 | x[x>255] = 255 150 | x[x<0] = 0 151 | x = x.astype(np.uint8) 152 | video.write(x) 153 | video.release() 154 | if __name__ == '__main__': 155 | train_dataset = ek_test(shuffle = True) 156 | print(f'Train dataset length: {len(train_dataset)}') 157 | train_dataloader = DataLoader(train_dataset,batch_size=2, collate_fn=collate_fn_test, drop_last = True) 158 | print(f'Steps involved: {len(train_dataset)/2}') 159 | t=time.time() 160 | for i, (clip_1,clip_2,clip_3,clip_4,clip_5, clip_class, vid_path) in enumerate(train_dataloader): 161 | print(i) 162 | 163 | print(f'Time taken to load data is {time.time()-t}') 164 | -------------------------------------------------------------------------------- /DL/dl_ft_1_train_O_ECL.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from torch.utils.data import Dataset, DataLoader 5 | from torchvision import transforms, utils 6 | #import config as cfg 7 | import random 8 | import pickle 9 | 10 | import json 11 | import math 12 | import cv2 13 | import time 14 | import torchvision.transforms as trans 15 | from fnmatch import fnmatch 16 
| from pathlib import Path 17 | from itertools import chain 18 | import torchvision 19 | import imageio as iio 20 | import sys 21 | 22 | class ek_train(Dataset): 23 | 24 | def __init__(self, shuffle = True, trainKitchen = 'p01', eventDrop = False, eventAugs = ['all'], numClips = 1): 25 | print(f'into initiliazation function of DL (O)') 26 | self.shuffle = shuffle # I still need to add the shuffle functionality 27 | self.all_paths = self.get_path(trainKitchen) 28 | if self.shuffle: 29 | random.shuffle(self.all_paths) 30 | self.data = self.all_paths 31 | self.PIL = trans.ToPILImage() 32 | self.TENSOR = trans.ToTensor() 33 | self.num_frames = 10 # 10 voxels/clip 34 | self.eventDrop = eventDrop 35 | self.numClips = numClips 36 | if "all" in eventAugs: 37 | self.eventAugs = ["val", "rand", "time", "rect", "pol"] 38 | else: 39 | self.eventAugs = eventAugs 40 | 41 | 42 | def __len__(self): 43 | return len(self.data) 44 | 45 | def __getitem__(self,index): 46 | #I need one clip at a time i.e. 10 voxels 47 | if self.numClips == 1: 48 | clip, clip_class,vid_path = self.process_data(index) 49 | return clip, clip_class,vid_path 50 | if self.numClips == 2: 51 | clip, clip1, clip_class,vid_path = self.process_data(index) 52 | return clip, clip1, clip_class,vid_path 53 | 54 | def get_path(self, trainKitchen): 55 | PATH = [] 56 | folders = [trainKitchen]#, 'p08', 'p22'] 57 | for fol in folders: 58 | root = '/home/ad358172/AY/event_summer/phase_1/N-EPIC-Kitchens/ek_train_test/train/' + fol + '_train/' 59 | #pattern = "*.npy" 60 | for path, subdirs, files in os.walk(root): 61 | for name in files: 62 | #if fnmatch(name, pattern): 63 | PATH.append(path) 64 | PATH = list(set(PATH)) 65 | PATH.sort() 66 | 67 | return PATH 68 | 69 | def process_data(self, idx): 70 | 71 | vid_path = self.data[idx].split(' ')[0] 72 | if self.numClips == 1: 73 | clip, clip_class = self.build_clip(vid_path) 74 | return clip, clip_class,vid_path 75 | elif self.numClips == 2: 76 | clip, clip1, clip_class = self.build_clip(vid_path) 77 | return clip, clip1, clip_class,vid_path 78 | #print("vid_path", vid_path, "\nclip_class", actions[clip_class]) 79 | 80 | 81 | 82 | def build_clip(self, vid_path): 83 | clip_class = [] 84 | 85 | actions = ['put','take','open','close','wash','cut','mix','pour'] 86 | for id, k in enumerate(actions): 87 | if(vid_path.find(k)!=-1): 88 | clip_class = id 89 | break 90 | os.chdir(vid_path) #now we are into the parent directory e.g. P01_01 containg all npy voxels 91 | p = Path.cwd() 92 | 93 | ################################ frame list maker starts here ########################### 94 | files = list(p.glob("*.npy*")) 95 | files.sort() #sorting in ascending order 96 | files = np.array(files) 97 | frame_count = len(files) 98 | frames_dense = selectFrames(frame_count, self.num_frames, 1, True) 99 | files_1 = files[frames_dense] 100 | height = 256 101 | width = 456 102 | finalHW = 224 103 | clips = [] 104 | for nc in range(self.numClips): 105 | clip = [] 106 | random_array = np.random.rand(10) 107 | x_erase = np.random.randint(0,finalHW, size = (2,)) 108 | y_erase = np.random.randint(0,finalHW, size = (2,)) 109 | 110 | 111 | cropping_factor1 = np.random.uniform(0.8, 1) # on an average cropping factor is 80% i.e. 
covers 64% area 112 | x0 = np.random.randint(0, width - width*cropping_factor1 + 1) 113 | y0 = np.random.randint(0, height - height*cropping_factor1 + 1) 114 | 115 | erase_size1 = np.random.randint(int(height/6),int(height/3), size = (2,)) 116 | erase_size2 = np.random.randint(int(width/6),int(width/3), size = (2,)) 117 | 118 | eventHide = np.random.random((finalHW, finalHW)) 119 | eventHide = np.array([eventHide, eventHide, eventHide]) 120 | eventHide = np.einsum('ijk->jki',eventHide) 121 | ratioHide = np.random.randint(0, 16)/100.00 122 | 123 | timeRatio = 0 124 | intensityThreshold = np.random.randint(0, 21)/100.00 125 | 126 | # + vid_path + "\n" + str1) 127 | for ind, i in enumerate(files_1): 128 | frame = np.load(i)#frame is the individual voxel 129 | x = np.einsum('ijk->jki',frame) 130 | 131 | minsaved = np.min(x) 132 | x = x + np.abs(minsaved) 133 | maxsaved = x.max() 134 | shift = int(np.abs(minsaved) * 255/(maxsaved)) 135 | x *= 255/(maxsaved) 136 | x[x>255] = 255; x[x<0] = 0 137 | x = x.astype(np.uint8) 138 | 139 | fname = vid_path.rsplit('/', 1)[-1] + "_" + str(ind) 140 | y= self.augmentation(x,random_array, x_erase, y_erase, cropping_factor1, x0, y0, erase_size1,erase_size2, height, width, finalHW, eventHide, ratioHide, shift, timeRatio, intensityThreshold) 141 | 142 | clip.append(y) 143 | timeRatio += 0.07 144 | clips.append(clip) 145 | if self.numClips == 2: 146 | return clips[0], clips[1], clip_class 147 | if self.numClips == 1: 148 | return clips[0], clip_class 149 | 150 | 151 | 152 | def augmentation(self, image, random_array, x_erase, y_erase, cropping_factor1, x0, y0, erase_size1,erase_size2, height, width, finalHW, eventHide, ratioHide, shift, timeRatio, intensityThreshold): 153 | image = self.PIL(image) 154 | image = trans.functional.resized_crop(image,y0,x0,int(height*cropping_factor1),int(width*cropping_factor1),(finalHW,finalHW)) 155 | 156 | if random_array[0] > 0.5: 157 | image = trans.functional.hflip(image) 158 | 159 | image = np.array(image) 160 | 161 | if (self.eventDrop): 162 | posThreshold = shift + (255 - shift) * intensityThreshold 163 | negThreshold = shift * (1- intensityThreshold) 164 | 165 | if "val" in self.eventAugs: 166 | #erase by value 167 | if random_array[1] > 0.7: 168 | image[(image < posThreshold) & (image > negThreshold)] = shift 169 | 170 | if "rand" in self.eventAugs: 171 | #random erase 172 | if random_array[3] > 0.8: 173 | #random erase not the same for each channel / time 174 | eventHide = np.random.random(image.shape) 175 | ratioHide = np.random.randint(0, 16)/100.00 176 | if random_array[3] > 0.6: 177 | image[(eventHide < ratioHide) & (image != shift)] = shift 178 | if "time" in self.eventAugs: 179 | #erase with time 180 | if (random_array[3] > 0.4) and (random_array[3] < 0.6): 181 | image[eventHide < timeRatio] = shift 182 | 183 | if "rect" in self.eventAugs: 184 | #erase entire rectangles 185 | if random_array[4] > 0.6: 186 | image[x_erase[0]:x_erase[0] + erase_size1[0],y_erase[0]: y_erase[0] + erase_size2[0],:] = shift 187 | if random_array[5] > 0.6: 188 | image[x_erase[1]:x_erase[1] + erase_size1[1],y_erase[1]: y_erase[1] + erase_size2[1],:] = shift 189 | 190 | if "pol" in self.eventAugs: 191 | #erase based on pos/neg 192 | if random_array[6] > 0.8: 193 | image[image > shift] = shift 194 | elif random_array[6] > 0.6: 195 | image[image < shift] = shift 196 | 197 | image = trans.functional.to_tensor(image) 198 | 199 | 200 | return image 201 | 202 | def collate_fn2(batch): 203 | clip = [] 204 | clip1 = [] 205 | clip_class = [] 206 | 
vid_path = [] 207 | twoClips = False 208 | for item in batch: 209 | if not (None in item): 210 | clip.append(torch.stack(item[0],dim=0)) 211 | if (len(item) == 4): 212 | twoClips = True 213 | clip1.append(torch.stack(item[1],dim=0)) 214 | clip_class.append(torch.as_tensor(np.asarray(item[2]))) 215 | vid_path.append(item[3]) 216 | else: 217 | clip_class.append(torch.as_tensor(np.asarray(item[1]))) 218 | vid_path.append(item[2]) 219 | 220 | 221 | clip = torch.stack(clip, dim=0) 222 | if twoClips: 223 | clip1 = torch.stack(clip1, dim=0) 224 | return clip, clip1, clip_class,vid_path 225 | return clip, clip_class,vid_path 226 | 227 | def vis_frames(clip,name,path): 228 | #temp = clip[0,:] 229 | temp = clip.permute(2,3,1,0) 230 | 231 | frame_width = 224 232 | frame_height = 224 233 | frame_size = (frame_width,frame_height) 234 | path = path + '/' + name + '.avi' 235 | print(path) 236 | video = cv2.VideoWriter(path,cv2.VideoWriter_fourcc('p', 'n', 'g', ' '),2,(frame_size[1],frame_size[0])) 237 | 238 | for i in range(temp.shape[3]): 239 | x = np.array(temp[:,:,:,i]) 240 | x *= 255/(x.max()) 241 | x[x>255] = 255 242 | x[x<0] = 0 243 | x = x.astype(np.uint8) 244 | #x = np.clip(x, a_min = -0.5, a_max = 0.5) 245 | video.write(x) 246 | video.release() 247 | 248 | def find_action(vid_path): 249 | actions = ['put','take','open','close','wash','cut','mix','pour'] 250 | for id, k in enumerate(actions): 251 | if(vid_path.find(k)!=-1): 252 | clip_class = id 253 | break 254 | return clip_class 255 | 256 | def selectFrames(frame_count, num_frames, num_clips_test, isTrain): 257 | if(frame_count num_frames): 265 | repeat_rate = i 266 | break 267 | if (isTrain): 268 | start = np.random.randint(len(s_1) - num_frames + 1) 269 | 270 | frames_dense = np.array(s_1[start:start+num_frames]) 271 | else: 272 | frames_dense = [] 273 | for j in range(num_clips_test): 274 | start = np.random.randint((frame_count*repeat_rate-num_frames)/repeat_rate + 1) * repeat_rate 275 | frames_dense.append(np.array(s_1[start:start+num_frames])) 276 | frames_dense = np.array(frames_dense) 277 | else: 278 | if (isTrain): 279 | skipRate = int(frame_count/num_frames)#np.random.randint(int(frame_count/num_frames)) + 1 280 | frames_dense = np.array(np.linspace(0,num_frames-1,num_frames,dtype=int) * skipRate 281 | + np.random.randint(frame_count - skipRate * (num_frames-1))) 282 | else: 283 | frames_dense = [] 284 | for i in range(num_clips_test): 285 | skipRate = np.random.randint(int(frame_count/num_frames)) + 1 286 | #skipRate = int(frame_count/num_frames) 287 | frames_dense.append(np.linspace(0,num_frames-1,num_frames,dtype=int) * skipRate 288 | + np.random.randint(frame_count - skipRate * (num_frames-1))) 289 | frames_dense = np.array(frames_dense) 290 | return frames_dense 291 | 292 | if __name__ == '__main__': 293 | actions = ['put','take','open','close','wash','cut','mix','pour'] 294 | train_dataset = ek_train(shuffle = True, trainKitchen = 'p01', eventDrop = False, eventAugs = ['all'], numClips = 2) 295 | print(f'Train dataset length: {len(train_dataset)}') 296 | train_dataloader = DataLoader(train_dataset,batch_size=1,shuffle= True, collate_fn=collate_fn2, drop_last = True) 297 | t=time.time() 298 | for i, (clip, clip1, clip_class,vid_path) in enumerate(train_dataloader): 299 | print(clip.shape) 300 | print(clip1.shape) 301 | a1 = find_action(vid_path[0]) 302 | 303 | print(i) 304 | print(f'Time taken to load data is {time.time()-t}') 305 | -------------------------------------------------------------------------------- 
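The dense frame sampler shared by the two N-EPIC-Kitchens dataloaders above (`selectFrames` in `dl_ft_1_train_O_ECL.py`) picks `num_frames` voxel indices per clip: at train time a single window with skip rate `frame_count // num_frames` and a random offset, and at test time `num_clips_test` windows with independently drawn skip rates. Below is a minimal, self-contained sketch of that sampling logic for the common case `frame_count >= num_frames`; the function name `sample_dense_frames` is illustrative only and not part of the repository.

```
import numpy as np

def sample_dense_frames(frame_count, num_frames, num_clips, is_train):
    """Illustrative sketch of the dense sampling in selectFrames (assumes frame_count >= num_frames)."""
    if is_train:
        # One window: evenly spaced indices with a fixed skip rate and a random offset,
        # chosen so the last index stays inside [0, frame_count).
        skip_rate = frame_count // num_frames
        offset = np.random.randint(frame_count - skip_rate * (num_frames - 1))
        return np.arange(num_frames) * skip_rate + offset
    # Test time: num_clips windows, each with its own random skip rate and offset.
    clips = []
    for _ in range(num_clips):
        skip_rate = np.random.randint(frame_count // num_frames) + 1
        offset = np.random.randint(frame_count - skip_rate * (num_frames - 1))
        clips.append(np.arange(num_frames) * skip_rate + offset)
    return np.array(clips)

# Example: 10 voxels per clip, 5 test-time clips, from a 73-voxel action segment.
print(sample_dense_frames(73, 10, 5, is_train=False).shape)  # (5, 10)
```

At test time the resulting 5 x 10 index array is what `ek_test.build_clip` uses to cut the five clips whose softmaxed predictions are averaged in `validate` in `train.py`.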
/DL/dvs_gesture.py: -------------------------------------------------------------------------------- 1 | # DVS Gesture citation: A. Amir, B. Taba, D. Berg, T. Melano, J. McKinstry, C. Di Nolfo, T. Nayak, A. Andreopoulos, G. Garreau, M. Mendoza, J. Kusnitz, M. Debole, S. Esser, T. Delbruck, M. Flickner, and D. Modha, "A Low Power, Fully Event-Based Gesture Recognition System," 2017 IEEE Conference on Computer Vision and Pattern Recognition (CVPR), Honolulu, HI, 2017. 2 | # Dataloader adapted from https://github.com/nmi-lab/torchneuromorphic by Emre Neftci and Clemens Schaefer 3 | 4 | import struct 5 | import time 6 | import numpy as np 7 | import h5py 8 | import torch.utils.data 9 | from ..neuromorphic_dataset import NeuromorphicDataset 10 | from ..events_timeslices import * 11 | from .._transforms import * 12 | import os 13 | from tqdm import tqdm 14 | import glob 15 | from .._utils import * 16 | import torchvision.transforms as trans 17 | 18 | 19 | 20 | mapping = { 21 | 0: "Hand Clapping", 22 | 1: "Right Hand Wave", 23 | 2: "Left Hand Wave", 24 | 3: "Right Arm CW", 25 | 4: "Right Arm CCW", 26 | 5: "Left Arm CW", 27 | 6: "Left Arm CCW", 28 | 7: "Arm Roll", 29 | 8: "Air Drums", 30 | 9: "Air Guitar", 31 | 10: "Other", 32 | } 33 | 34 | 35 | class DVSGesture(NeuromorphicDataset): 36 | 37 | """`DVS Gesture `_ Dataset. 38 | 39 | The data was recorded using a DVS128. The dataset contains 11 hand gestures from 29 subjects under 3 illumination conditions. 40 | 41 | **Number of classes:** 11 42 | 43 | **Number of train samples:** 1176 44 | 45 | **Number of test samples:** 288 46 | 47 | **Dimensions:** ``[num_steps x 2 x 128 x 128]`` 48 | 49 | * **num_steps:** time-dimension of event-based footage 50 | * **2:** number of channels (on-spikes for luminance increasing; off-spikes for luminance decreasing) 51 | * **128x128:** W x H spatial dimensions of event-based footage 52 | 53 | For further reading, see: 54 | 55 | *A. Amir, B. Taba, D. Berg, T. Melano, J. McKinstry, C. Di Nolfo, T. Nayak, A. Andreopoulos, G. Garreau, M. Mendoza, J. Kusnitz, M. Debole, S. Esser, T. Delbruck, M. Flickner, and D. Modha, "A Low Power, Fully Event-Based Gesture Recognition System," 2017 IEEE Conference on Computer Vision and Pattern Recognition (CVPR), Honolulu, HI, 2017.* 56 | 57 | 58 | 59 | 60 | Example:: 61 | 62 | from snntorch.spikevision import spikedata 63 | 64 | train_ds = spikedata.DVSGesture("data/dvsgesture", train=True, num_steps=500, dt=1000) 65 | test_ds = spikedata.DVSGesture("data/dvsgesture", train=False, num_steps=1800, dt=1000) 66 | 67 | # by default, each time step is integrated over 1ms, or dt=1000 microseconds 68 | # dt can be changed to integrate events over a varying number of time steps 69 | # Note that num_steps should be scaled inversely by the same factor 70 | 71 | train_ds = spikedata.DVSGesture("data/dvsgesture", train=True, num_steps=250, dt=2000) 72 | test_ds = spikedata.DVSGesture("data/dvsgesture", train=False, num_steps=900, dt=2000) 73 | 74 | 75 | The dataset can also be manually downloaded, extracted and placed into ``root`` which will allow the dataloader to bypass straight to the generation of a hdf5 file. 76 | 77 | **Direct Download Links:** 78 | 79 | `IBM Box Link `_ 80 | 81 | `Dropbox Link `_ 82 | 83 | 84 | :param root: Root directory of dataset. 85 | :type root: string 86 | 87 | :param train: If True, creates dataset from training set of dvsgesture, otherwise test set. 
88 | :type train: bool, optional 89 | 90 | :param transform: A function/transform that takes in a PIL image and returns a transforms version. By default, a pre-defined set of transforms are applied to all samples to convert them into a time-first tensor with correct orientation. 91 | :type transform: callable, optional 92 | 93 | :param target_transform: A function/transform that takes in the target and transforms it. 94 | :type target_transform: callable, optional 95 | 96 | :param download_and_create: If True, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again. 97 | :type download_and_create: bool, optional 98 | 99 | :param num_steps: Number of time steps, defaults to ``500`` for train set, or ``1800`` for test set 100 | :type num_steps: int, optional 101 | 102 | :param dt: The number of time stamps integrated in microseconds, defaults to ``1000`` 103 | :type dt: int, optional 104 | 105 | :param ds: Rescaling factor, defaults to ``1``. 106 | :type ds: int, optional 107 | 108 | :return_meta: Option to return metadata, defaults to ``False`` 109 | :type return_meta: bool, optional 110 | 111 | :time_shuffle: Option to randomize start time of dataset, defaults to ``False`` 112 | :type time_shuffle: bool, optional 113 | 114 | Dataloader adapted from `torchneuromorphic `_ originally by Emre Neftci and Clemens Schaefer. 115 | 116 | The dataset is released under a Creative Commons Attribution 4.0 license. All rights remain with the original authors. 117 | """ 118 | 119 | # _resources_url = [['Manually Download dataset here: https://ibm.ent.box.com/s/3hiq58ww1pbbjrinh367ykfdf60xsfm8/file/211521748942?sb=/details and place under {0}'.format(directory),None, 'DvsGesture.tar.gz']] 120 | 121 | _resources_url = [ 122 | [ 123 | "https://www.dropbox.com/s/cct5kyilhtsliup/DvsGesture.tar.gz?dl=1", 124 | None, 125 | "DvsGesture.tar.gz", 126 | ] 127 | ] 128 | # directory = 'data/dvsgesture/' 129 | 130 | def __init__( 131 | self, 132 | root, 133 | train=True, 134 | isVal = False, 135 | transform=None, 136 | target_transform=None, 137 | download_and_create=True, 138 | num_steps=None, 139 | final_frames=None, 140 | dt=1000, 141 | ds=None, 142 | return_meta=False, 143 | time_shuffle=False, 144 | eventDrop=True, 145 | numClips = 1, 146 | eventAugs = ['all'], 147 | train_temp_align = True, 148 | skip_rate = 1, 149 | randomcrop = False, 150 | changing_sr = False, 151 | rdCrop_fr = False, 152 | evtDropPol = False, 153 | adv_changing_dt = False, 154 | dvs_imageSize = 128, 155 | val_cr = False, 156 | ): 157 | 158 | self.n = 0 159 | self.download_and_create = download_and_create 160 | self.root = root 161 | self.train = train 162 | self.dt = dt 163 | self.return_meta = return_meta 164 | self.time_shuffle = time_shuffle 165 | self.hdf5_name = "dvs_gesture.hdf5" 166 | self.directory = root.split(self.hdf5_name)[0] 167 | self.resources_local = [self.directory + "/DvsGesture.tar.gz"] 168 | self.resources_local_extracted = [self.directory + "/DvsGesture/"] 169 | self.eventDrop = eventDrop 170 | self.final_frames = final_frames 171 | self.numClips = numClips 172 | self.isVal = isVal 173 | self.train_temp_align = train_temp_align 174 | self.skip_rate = skip_rate 175 | self.randomcrop = randomcrop 176 | self.changing_sr = changing_sr 177 | self.rdCrop_fr = rdCrop_fr 178 | self.evtDropPol = evtDropPol 179 | self.adv_changing_dt = adv_changing_dt 180 | self.dvs_imageSize = dvs_imageSize 181 | self.val_cr = val_cr 182 | if "all" in eventAugs: 183 | 
self.eventAugs = ["val", "rand", "time", "rect", "pol"] 184 | else: 185 | self.eventAugs = eventAugs 186 | 187 | 188 | 189 | if ds is None: 190 | ds = 1 191 | if isinstance(ds, int): 192 | ds = [ds, ds] 193 | 194 | size = [3, 128 // ds[0], 128 // ds[1]] # 128//ds[0], 128//ds[1] 195 | 196 | if num_steps is None: 197 | if self.train: 198 | self.num_steps = 500 199 | else: 200 | self.num_steps = 1800 201 | else: 202 | self.num_steps = num_steps 203 | 204 | if self.adv_changing_dt: 205 | assert self.final_frames == 16 206 | assert self.dt == 5000 207 | self.dt = 1000 208 | self.num_steps = 500 209 | self.final_frames = 80 210 | 211 | if transform is None: 212 | transform = Compose( 213 | [ 214 | Downsample(factor=[self.dt, 1, ds[0], ds[1]]), 215 | ToCountFrame(T=self.num_steps, size=size), 216 | ToTensor(), 217 | dvs_permute(), 218 | ] 219 | ) 220 | 221 | 222 | if target_transform is not None: 223 | target_transform = Compose([Repeat(num_steps), toOneHot(11)]) 224 | 225 | super(DVSGesture, self).__init__( 226 | root=root + "/" + self.hdf5_name, 227 | transform=transform, 228 | target_transform_train=target_transform, 229 | ) 230 | 231 | with h5py.File(self.root, "r", swmr=True, libver="latest") as f: 232 | if train: 233 | self.n = f["extra"].attrs["Ntrain"] 234 | self.keys = f["extra"]["train_keys"][()] 235 | else: 236 | self.n = f["extra"].attrs["Ntest"] 237 | self.keys = f["extra"]["test_keys"][()] 238 | 239 | def _download(self): 240 | isexisting = super(DVSGesture, self)._download() 241 | 242 | def _create_hdf5(self): 243 | create_events_hdf5( 244 | self.directory, 245 | self.resources_local_extracted[0], 246 | self.directory + "/" + self.hdf5_name, 247 | ) 248 | 249 | def __len__(self): 250 | return self.n 251 | 252 | def __getitem__(self, key): 253 | 254 | # print("self.train", self.train) # True 255 | # print("self.transform", self.transform) #function 256 | # print("self.target_transform", self.target_transform) #function 257 | # print("self.return_meta", self.return_meta) # False 258 | 259 | # Important to open and close in getitem to enable num_workers>0 260 | if self.changing_sr: 261 | assert self.train == True 262 | self.skip_rate = np.random.randint(4,6) 263 | self.num_steps = np.random.randint(90,111) 264 | self.dt = int(500000 / self.num_steps) 265 | # frames2skip = 30000/ self.dt 266 | self.transform = Compose( 267 | [ 268 | Downsample(factor=[self.dt, 1, 1, 1]), 269 | ToCountFrame(T=self.num_steps, size=[3, 128, 128]), 270 | ToTensor(), 271 | dvs_permute(), 272 | ] 273 | ) 274 | assert self.skip_rate*self.final_frames < self.num_steps 275 | # if self.adv_changing_dt: 276 | # self.dt = np.random.randint(3,11) * 1000 277 | # self.num_steps = int(500000/self.dt) 278 | # self.skip_rate = np.max(int(5000/self.dt), 1) 279 | 280 | 281 | with h5py.File(self.root, "r", swmr=True, libver="latest") as f: 282 | if not self.train: 283 | key = key + f["extra"].attrs["Ntrain"] 284 | assert key in self.keys 285 | data, target, meta_info_light, meta_info_user = sample( 286 | f, key, T=self.num_steps, shuffle=self.time_shuffle, train=self.train 287 | ) 288 | # print(data[0]) 289 | 290 | if (self.isVal == False): 291 | clips = [] 292 | start = np.random.randint(self.num_steps - (self.final_frames - 1)*self.skip_rate) 293 | lengths = [] 294 | starts = [] 295 | for i in range(self.numClips): 296 | if self.adv_changing_dt: 297 | length = np.random.randint(3,11) 298 | start_changing_dt = np.random.randint(0,11-length) 299 | lengths.append(length) 300 | starts.append(start_changing_dt) 301 | data1 = 
data.copy() 302 | if self.transform is not None: 303 | data1 = self.transform(data1) 304 | if self.dvs_imageSize != 128: 305 | data1 = trans.functional.resize(data1, [self.dvs_imageSize,self.dvs_imageSize]) 306 | if self.train_temp_align == False: 307 | start = np.random.randint(self.num_steps - (self.final_frames - 1)*self.skip_rate) 308 | if self.adv_changing_dt: 309 | data2 = torch.zeros((16,3,self.dvs_imageSize,self.dvs_imageSize)) 310 | frames = torch.tensor(list(range(start, start + self.final_frames*self.skip_rate))) 311 | frames_f0 = torch.tensor(list(range(0, self.final_frames*self.skip_rate))) 312 | frame = 0 313 | 314 | for j in frames_f0[::25]: 315 | frames2concat = frames[j:j+10] 316 | frames2concat = frames2concat[start_changing_dt: start_changing_dt + length] 317 | for k in frames2concat: 318 | data2[frame] += data1[k] 319 | frame += 1 320 | del frames 321 | del frames_f0 322 | del frame 323 | del frames2concat 324 | f_frames = list(range(start+self.final_frames*self.skip_rate + 1))[start:start+self.final_frames*self.skip_rate:self.skip_rate] 325 | data1 = data1[f_frames] 326 | if self.adv_changing_dt: 327 | data1 = data2 328 | del data2 329 | if np.random.random_sample() > 0.5: 330 | data1 = self.eventDropAug(data1, key) 331 | if self.randomcrop: 332 | data1 = self.resizedCrop(data1) 333 | clips.append(data1) 334 | if self.numClips == 2: 335 | # print(lengths, starts) 336 | data, data1 = clips[0], clips[1] 337 | elif self.numClips == 1: 338 | data = clips[0] 339 | 340 | 341 | if self.target_transform is not None: 342 | target = self.target_transform(target) 343 | 344 | if self.return_meta is True: 345 | return data, target, meta_info_light, meta_info_user 346 | else: 347 | if (self.numClips > 1): 348 | return data, data1, target, key 349 | else: 350 | return data, target, key 351 | if (self.isVal == True): 352 | clips = [] 353 | for i in range(self.numClips): 354 | data1 = data.copy() 355 | data1 = self.transform(data1) 356 | if self.dvs_imageSize != 128: 357 | data1 = trans.functional.resize(data1, [self.dvs_imageSize,self.dvs_imageSize]) 358 | start = np.random.randint(len(data1) - (self.final_frames - 1)*self.skip_rate) 359 | f_frames = [item for item in list(range(start,start + (self.final_frames - 1)*self.skip_rate + 1)) 360 | if item % self.skip_rate == start % self.skip_rate] 361 | data1 = data1[f_frames] 362 | if self.val_cr: 363 | rd_height = np.random.randint(int(self.dvs_imageSize*.9), self.dvs_imageSize) 364 | centcrop = trans.transforms.CenterCrop([rd_height, rd_height]) 365 | resize = trans.transforms.Resize([224,224]) 366 | data1 = centcrop(data1) 367 | data1 = resize(data1) 368 | # if np.random.random_sample() > 0.5: 369 | # data1 = self.eventDropAug(data1, key) 370 | clips.append(data1[None, :]) 371 | clips = torch.cat(clips, dim=0) 372 | # print(clips.shape) 373 | return clips, target, key 374 | 375 | def resizedCrop(self, data): 376 | finalHW = self.dvs_imageSize 377 | xmin = int(np.random.randint(0,13)*(self.dvs_imageSize/128)) 378 | xmax = int(np.random.randint(115,128)*(self.dvs_imageSize/128)) 379 | length = xmax-xmin 380 | rd = np.random.rand() 381 | yborder = np.random.randint(self.dvs_imageSize - length) 382 | for i in range(len(data)): 383 | if self.rdCrop_fr: 384 | xmin = int(np.random.randint(0,13)*(self.dvs_imageSize/128)) 385 | xmax = int(np.random.randint(115,128)*(self.dvs_imageSize/128)) 386 | length = xmax-xmin 387 | rd = np.random.rand() 388 | yborder = np.random.randint(self.dvs_imageSize - length) 389 | data[i] = 
trans.functional.resized_crop(data[i],xmin,yborder,length,length,(finalHW,finalHW)) 390 | return data 391 | 392 | def eventDropAug(self, data, key): 393 | height = width = self.dvs_imageSize 394 | finalHW = self.dvs_imageSize 395 | random_array = np.random.rand(10) 396 | eventHide = torch.rand((3, finalHW, finalHW)) 397 | ratioHide = np.random.randint(0, 16)/100.00 398 | timeRatio = 0 399 | maxTR = 0.35 400 | threshold = 1 401 | x_erase = np.random.randint(0,height, size = (2,)) 402 | y_erase = np.random.randint(0,width, size = (2,)) 403 | erase_size1 = np.random.randint(int(height/6),int(height/5), size = (2,)) 404 | erase_size2 = np.random.randint(int(width/6),int(width/5), size = (2,)) 405 | 406 | if self.eventDrop: 407 | for image in data: 408 | timeRatio += maxTR/self.final_frames 409 | 410 | if "val" in self.eventAugs: 411 | #erase by value 412 | if random_array[1] > 0.8: 413 | if not (self.evtDropPol): 414 | image[image < threshold] = 0 415 | else: 416 | rd = np.random.rand() 417 | if rd > 1/3: 418 | image[image < threshold] = 0 419 | elif rd > 2/3: 420 | image[0][image[0] < threshold] = 0 421 | else: 422 | image[1][image[1] < threshold] = 0 423 | 424 | 425 | if "rand" in self.eventAugs: 426 | #random erase 427 | if random_array[3] > 0.8: 428 | #random erase not the same for each channel / time 429 | eventHide = np.random.random(image.shape) 430 | ratioHide = np.random.randint(0, 16)/100.00 431 | elif random_array[3] > 0.6: 432 | if not (self.evtDropPol): 433 | image[(eventHide < ratioHide) & (image != 0)] = 0 434 | else: 435 | rd = np.random.rand() 436 | if rd > 1/3: 437 | image[(eventHide < ratioHide) & (image != 0)] = 0 438 | elif rd > 2/3: 439 | image[0][(eventHide[0] < ratioHide) & (image[0] != 0)] = 0 440 | else: 441 | image[1][(eventHide[1] < ratioHide) & (image[1] != 0)] = 0 442 | if "time" in self.eventAugs: 443 | #erase with time 444 | if (random_array[3] > 0.4) and (random_array[3] < 0.6): 445 | if (random_array[8] > 0.5): 446 | image[eventHide < timeRatio] = 0 447 | else: 448 | # reverse order 449 | image[eventHide > (1-maxTR) + timeRatio] = 0 450 | 451 | 452 | #erase entire rectangles 453 | if "rect" in self.eventAugs: 454 | if random_array[4] > 0.75: 455 | if not (self.evtDropPol): 456 | image[:, x_erase[0]:x_erase[0] + erase_size1[0],y_erase[0]: y_erase[0] + erase_size2[0]] = 0 457 | else: 458 | rd = np.random.rand() 459 | if rd > 1/3: 460 | image[:, x_erase[0]:x_erase[0] + erase_size1[0],y_erase[0]: y_erase[0] + erase_size2[0]] = 0 461 | elif rd > 2/3: 462 | image[0, x_erase[0]:x_erase[0] + erase_size1[0],y_erase[0]: y_erase[0] + erase_size2[0]] = 0 463 | else: 464 | image[1, x_erase[0]:x_erase[0] + erase_size1[0],y_erase[0]: y_erase[0] + erase_size2[0]] = 0 465 | if random_array[5] > 0.75: 466 | if not (self.evtDropPol): 467 | image[:, x_erase[1]:x_erase[1] + erase_size1[1],y_erase[1]: y_erase[1] + erase_size2[1]] = 0 468 | else: 469 | rd = np.random.rand() 470 | if rd > 1/3: 471 | image[:, x_erase[1]:x_erase[1] + erase_size1[1],y_erase[1]: y_erase[1] + erase_size2[1]] = 0 472 | elif rd > 2/3: 473 | image[0:, x_erase[1]:x_erase[1] + erase_size1[1],y_erase[1]: y_erase[1] + erase_size2[1]] = 0 474 | else: 475 | image[1:, x_erase[1]:x_erase[1] + erase_size1[1],y_erase[1]: y_erase[1] + erase_size2[1]] = 0 476 | 477 | if "pol" in self.eventAugs: 478 | # erase pos/neg 479 | if random_array[6] > 0.75: 480 | if random_array[7] > 0.5: 481 | image[0,:,:] = 0 # erase pos events 482 | else: 483 | image[1,:,:] = 0 # erase neg events 484 | 485 | return data 486 | 487 | 488 | def 
sample(hdf5_file, key, T=500, shuffle=False, train=True): 489 | if train: 490 | T_default = 500 491 | else: 492 | T_default = 1800 493 | dset = hdf5_file["data"][str(key)] 494 | label = dset["labels"][()] 495 | tbegin = dset["times"][0] 496 | tend = np.maximum(0, dset["times"][-1] - 2 * T * 1000) 497 | start_time = np.random.randint(tbegin, tend + 1) if shuffle else 0 498 | # print(start_time) 499 | # tmad = get_tmad_slice(dset['times'][()], dset['addrs'][()], start_time, T*1000) 500 | tmad = get_tmad_slice( 501 | dset["times"][()], dset["addrs"][()], start_time, T_default * 1000 502 | ) 503 | tmad[:, 0] -= tmad[0, 0] 504 | meta = eval(dset.attrs["meta_info"]) 505 | return tmad[:, [0, 3, 1, 2]], label, meta["light condition"], meta["subject"] 506 | 507 | 508 | def create_events_hdf5(directory, extracted_directory, hdf5_filename): 509 | fns_train = gather_aedat(directory, extracted_directory, 1, 24) 510 | fns_test = gather_aedat(directory, extracted_directory, 24, 30) 511 | test_keys = [] 512 | train_keys = [] 513 | 514 | assert len(fns_train) == 98 515 | 516 | with h5py.File(hdf5_filename, "w") as f: 517 | f.clear() 518 | 519 | key = 0 520 | metas = [] 521 | data_grp = f.create_group("data") 522 | extra_grp = f.create_group("extra") 523 | print("\nCreating dvs_gesture.hdf5...") 524 | for file_d in tqdm(fns_train + fns_test): 525 | istrain = file_d in fns_train 526 | data, labels_starttime = aedat_to_events(file_d) 527 | tms = data[:, 0] 528 | ads = data[:, 1:] 529 | lbls = labels_starttime[:, 0] 530 | start_tms = labels_starttime[:, 1] 531 | end_tms = labels_starttime[:, 2] 532 | out = [] 533 | 534 | for i, v in enumerate(lbls): 535 | if istrain: 536 | train_keys.append(key) 537 | else: 538 | test_keys.append(key) 539 | s_ = get_slice(tms, ads, start_tms[i], end_tms[i]) 540 | times = s_[0] 541 | addrs = s_[1] 542 | # subj, light = file_d.replace('\\', '/').split('/')[-1].split('.')[0].split('_')[:2] # this line throws an error in get_slice, because idx_beg = idx_end --> empty batch 543 | subj, light = file_d.split("/")[-1].split(".")[0].split("_")[:2] 544 | metas.append( 545 | { 546 | "key": str(key), 547 | "subject": subj, 548 | "light condition": light, 549 | "training sample": istrain, 550 | } 551 | ) 552 | subgrp = data_grp.create_group(str(key)) 553 | tm_dset = subgrp.create_dataset("times", data=times, dtype=np.uint32) 554 | ad_dset = subgrp.create_dataset("addrs", data=addrs, dtype=np.uint8) 555 | lbl_dset = subgrp.create_dataset( 556 | "labels", data=lbls[i] - 1, dtype=np.uint8 557 | ) 558 | subgrp.attrs["meta_info"] = str(metas[-1]) 559 | assert lbls[i] - 1 in range(11) 560 | key += 1 561 | extra_grp.create_dataset("train_keys", data=train_keys) 562 | extra_grp.create_dataset("test_keys", data=test_keys) 563 | extra_grp.attrs["N"] = len(train_keys) + len(test_keys) 564 | extra_grp.attrs["Ntrain"] = len(train_keys) 565 | extra_grp.attrs["Ntest"] = len(test_keys) 566 | 567 | print("dvs_gesture.hdf5 was created successfully.") 568 | 569 | 570 | def gather_aedat( 571 | directory, extracted_directory, start_id, end_id, filename_prefix="user" 572 | ): 573 | if not os.path.isdir(directory): 574 | raise FileNotFoundError( 575 | "DVS Gestures Dataset not found, looked at: {}".format(directory) 576 | ) 577 | 578 | fns = [] 579 | for i in range(start_id, end_id): 580 | search_mask = ( 581 | extracted_directory 582 | + "/" 583 | + filename_prefix 584 | + "{0:02d}".format(i) 585 | + "*.aedat" 586 | ) 587 | glob_out = glob.glob(search_mask) 588 | if len(glob_out) > 0: 589 | fns += glob_out 590 
| return fns 591 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # EventTransAct 2 | ![Example Image](images/Model.PNG) 3 | 4 | ## Downloading dataset 5 | You can find instructions on downloading the 2 datasets used in the paper below 6 | DVS: https://research.ibm.com/interactive/dvsgesture/ 7 | N-Epic-Kitchens: https://github.com/EgocentricVision/N-EPIC-Kitchens 8 | 9 | ## Preparing dataset for DVS 10 | Clone snntorch repository: 11 | ``` 12 | git clone https://github.com/jeshraghian/snntorch/tree/master 13 | ``` 14 | Go to the snntorch/snntorch/spikevision/spikedata folder and replace dvs_gesture.py by the dvs_gesture.py file found in the /DL folder of this repo 15 | 16 | ## Running the code 17 | 18 | Finetuning on N-Epic-Kitchens dataset: 19 | ``` 20 | python train.py -c configs/config_NEK.py 21 | ``` 22 | Finetuning on DVS dataset: 23 | ``` 24 | python train.py -c configs/config_DVS_woECL.py 25 | ``` 26 | Finetuning on DVS dataset with contrastive loss: 27 | ``` 28 | python train.py -c configs/config_DVS_wECL.py 29 | ``` 30 | -------------------------------------------------------------------------------- /configs/config_DVS_wECL.py: -------------------------------------------------------------------------------- 1 | batch_size=4 2 | num_epochs = 50 3 | logFolder = 'DVS_NoAssignedFolder' 4 | 5 | arch = 'vitb'#'r50', 'vitb' 6 | learning_rate = 1e-4 7 | opt = "adam" #'adam', 'sgd' 8 | use_sched = False 9 | sched_ms = [15, 30, 45] 10 | sched_gm = 0.4 11 | three_layer_frozen = False 12 | two_layer_frozen = False 13 | pretrainedVTN = False 14 | pretrained = True 15 | if arch == 'r50': 16 | backbone = 'r50' # 'r18' 'r34' 'r50' 17 | weight_rn50_ssl = '' #"/home/c3-0/ishan/semisup_saved_models/dummy/model_best_e30_loss_2.9230.pth" # '' leave empty to unuse 18 | if (backbone != 'r50'): 19 | weight_rn50_ssl = '' 20 | 21 | cosinelr = True 22 | eventDrop = True 23 | randomcrop = False 24 | dataset = 'DVS' 25 | changing_sr = False 26 | adv_changing_dt = False 27 | rdCrop_fr = False 28 | if rdCrop_fr: 29 | randomcrop = True 30 | # NEK only 31 | trainkit = "p22" 32 | testkit = "p01" 33 | evAugs = ["rand"] 34 | 35 | dvs_imageSize = 128 36 | if arch == 'vitb': 37 | dvs_imageSize = 224 38 | 39 | 40 | 41 | train_temp_align = True 42 | ECL_weight = 1 43 | ECL = True 44 | if ECL: 45 | numClips = 2 46 | else: 47 | numClips = 1 48 | num_segments = 4 # for contrastive loss 49 | 50 | final_frames = 16 51 | num_steps = 100 52 | skip_rate = 5 53 | assert final_frames*skip_rate <= num_steps 54 | 55 | 56 | 57 | -------------------------------------------------------------------------------- /configs/config_DVS_woECL.py: -------------------------------------------------------------------------------- 1 | batch_size=4 2 | num_epochs = 50 3 | logFolder = 'DVS_NoAssignedFolder' 4 | 5 | arch = 'vitb'#'r50', 'vitb' 6 | learning_rate = 1e-4 7 | opt = "adam" #'adam', 'sgd' 8 | use_sched = False 9 | sched_ms = [15, 30, 45] 10 | sched_gm = 0.4 11 | three_layer_frozen = False 12 | two_layer_frozen = False 13 | pretrainedVTN = False 14 | pretrained = True 15 | if arch == 'r50': 16 | backbone = 'r50' # 'r18' 'r34' 'r50' 17 | weight_rn50_ssl = '' #"/home/c3-0/ishan/semisup_saved_models/dummy/model_best_e30_loss_2.9230.pth" # '' leave empty to unuse 18 | if (backbone != 'r50'): 19 | weight_rn50_ssl = '' 20 | 21 | cosinelr = True 22 | eventDrop = True 23 | randomcrop = False 24 | dataset = 'DVS' 25 | 
changing_sr = False 26 | adv_changing_dt = False 27 | rdCrop_fr = False 28 | if rdCrop_fr: 29 | randomcrop = True 30 | # NEK only 31 | trainkit = "p22" 32 | testkit = "p01" 33 | evAugs = ["rand"] 34 | 35 | dvs_imageSize = 128 36 | if arch == 'vitb': 37 | dvs_imageSize = 224 38 | 39 | 40 | 41 | train_temp_align = True 42 | ECL_weight = 1 43 | ECL = False 44 | if ECL: 45 | numClips = 2 46 | else: 47 | numClips = 1 48 | num_segments = 4 # for contrastive loss 49 | 50 | final_frames = 16 51 | num_steps = 100 52 | skip_rate = 5 53 | assert final_frames*skip_rate <= num_steps -------------------------------------------------------------------------------- /configs/config_NEK.py: -------------------------------------------------------------------------------- 1 | batch_size=16 2 | num_epochs = 50 3 | logFolder = 'NEK_NoAssignedFolder' 4 | 5 | arch = 'vitb'#'r50', 'vitb' 6 | learning_rate = 1e-3 7 | opt = "sgd" #'adam', 'sgd' 8 | use_sched = False 9 | sched_ms = [15, 30, 45] 10 | sched_gm = 0.4 11 | three_layer_frozen = False 12 | two_layer_frozen = False 13 | pretrainedVTN = False 14 | pretrained = True 15 | backbone = 'r50' # 'r18' 'r34' 'r50' 16 | weight_rn50_ssl = '' #"/home/c3-0/ishan/semisup_saved_models/dummy/model_best_e30_loss_2.9230.pth" # '' leave empty to unuse 17 | if (backbone != 'r50'): 18 | weight_rn50_ssl = '' 19 | 20 | cosinelr = True 21 | eventDrop = True 22 | randomcrop = False 23 | dataset = 'NEK' 24 | changing_sr = False 25 | adv_changing_dt = False 26 | rdCrop_fr = False 27 | if rdCrop_fr: 28 | randomcrop = True 29 | # NEK only 30 | trainkit = "p22" 31 | testkit = "p01" 32 | evAugs = ["rand"] 33 | 34 | dvs_imageSize = 128 35 | if arch == 'vitb': 36 | dvs_imageSize = 224 37 | 38 | 39 | 40 | train_temp_align = True 41 | ECL_weight = 1 42 | ECL = False 43 | num_segments = 5 # for contrastive loss 44 | 45 | final_frames = 10 46 | num_steps = 100 47 | skip_rate = 5 48 | assert final_frames*skip_rate <= num_steps 49 | 50 | 51 | 52 | -------------------------------------------------------------------------------- /images/Model.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tristandb8/EventTransAct/0264d26f92c86825b8c8b10a7c7a086fa6554b85/images/Model.PNG -------------------------------------------------------------------------------- /nt_xent_original.py: -------------------------------------------------------------------------------- 1 | #This code is taken from https://github.com/sthalles/SimCLR 2 | import torch 3 | import numpy as np 4 | 5 | 6 | class NTXentLoss(torch.nn.Module): 7 | 8 | def __init__(self, device, batch_size, temperature, use_cosine_similarity): 9 | super(NTXentLoss, self).__init__() 10 | self.batch_size = batch_size 11 | self.temperature = temperature 12 | self.device = device 13 | self.softmax = torch.nn.Softmax(dim=-1) 14 | self.mask_samples_from_same_repr = self._get_correlated_mask().type(torch.bool) 15 | self.similarity_function = self._get_similarity_function(use_cosine_similarity) 16 | self.criterion = torch.nn.CrossEntropyLoss(reduction="sum") 17 | 18 | def _get_similarity_function(self, use_cosine_similarity): 19 | if use_cosine_similarity: 20 | self._cosine_similarity = torch.nn.CosineSimilarity(dim=-1) 21 | return self._cosine_simililarity 22 | else: 23 | return self._dot_simililarity 24 | 25 | def _get_correlated_mask(self): 26 | diag = np.eye(2 * self.batch_size) 27 | l1 = np.eye((2 * self.batch_size), 2 * self.batch_size, k=-self.batch_size) 28 | l2 = np.eye((2 * 
self.batch_size), 2 * self.batch_size, k=self.batch_size) 29 | mask = torch.from_numpy((diag + l1 + l2)) 30 | mask = (1 - mask).type(torch.bool) 31 | return mask.to(self.device) 32 | 33 | @staticmethod 34 | def _dot_simililarity(x, y): 35 | v = torch.tensordot(x.unsqueeze(1), y.T.unsqueeze(0), dims=2) 36 | # x shape: (N, 1, C) 37 | # y shape: (1, C, 2N) 38 | # v shape: (N, 2N) 39 | return v 40 | 41 | def _cosine_simililarity(self, x, y): 42 | # x shape: (N, 1, C) 43 | # y shape: (1, 2N, C) 44 | # v shape: (N, 2N) 45 | v = self._cosine_similarity(x.unsqueeze(1), y.unsqueeze(0)) 46 | return v 47 | 48 | def forward(self, zis, zjs): 49 | representations = torch.cat([zjs, zis], dim=0) 50 | 51 | similarity_matrix = self.similarity_function(representations, representations) 52 | # print(f'similarity_matrix shpae is {similarity_matrix.shape}') 53 | 54 | # filter out the scores from the positive samples 55 | l_pos = torch.diag(similarity_matrix, self.batch_size) 56 | # print(f'l_pos shpae is {l_pos.shape}') 57 | 58 | r_pos = torch.diag(similarity_matrix, -self.batch_size) 59 | positives = torch.cat([l_pos, r_pos]).view(2 * self.batch_size, 1) 60 | 61 | negatives = similarity_matrix[self.mask_samples_from_same_repr].view(2 * self.batch_size, -1) 62 | 63 | logits = torch.cat((positives, negatives), dim=1) 64 | logits /= self.temperature 65 | 66 | labels = torch.zeros(2 * self.batch_size).to(self.device).long() 67 | loss = self.criterion(logits, labels) 68 | 69 | return loss / (2 * self.batch_size), logits 70 | 71 | if __name__ == '__main__': 72 | BS = 4 73 | feature = 128 74 | zis = torch.rand(BS, feature).cuda() 75 | zjs = torch.rand(BS, feature).cuda() 76 | print(zis.shape) 77 | con_loss = NTXentLoss(device = 'cuda', batch_size=BS, temperature=0.1, use_cosine_similarity = False) 78 | 79 | loss = con_loss(zis, zjs) 80 | print(loss) 81 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from datetime import datetime 4 | import sys 5 | import math 6 | 7 | from torch.utils.data import DataLoader 8 | import os 9 | import sys 10 | import argparse 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | import torch.optim as optim 14 | from torch.autograd import Variable 15 | import numpy as np 16 | from argparse import Namespace 17 | import shutil 18 | from nt_xent_original import * 19 | 20 | import argparse 21 | sys.path.insert(0, './vtn/') 22 | from parser_sf import parse_args, load_config 23 | 24 | 25 | def main(params): 26 | torch.cuda.empty_cache() 27 | assert params.final_frames*params.skip_rate <= params.num_steps 28 | save_model='./saved_models/' + params.logFolder + '/' 29 | 30 | if not os.path.exists(save_model): 31 | os.makedirs(save_model) 32 | else: 33 | idx = 0 34 | save_model = save_model.replace(params.logFolder, params.logFolder + "_" + str(idx)) 35 | while os.path.exists(save_model): 36 | logFolderb4 = params.logFolder + "_" + str(idx) 37 | idx += 1 38 | logFolder = params.logFolder + "_" + str(idx) 39 | save_model = save_model.replace(logFolderb4, logFolder) 40 | os.makedirs(save_model) 41 | 42 | logFile = open(save_model + 'logfile.txt', 'a') 43 | 44 | if params.dataset == 'NEK': 45 | from DL.dl_ft_1_train_O_ECL import ek_train, collate_fn2 46 | from DL.dl_ft_1_test_O_ECL import ek_test, collate_fn_test 47 | 48 | train_dataset = ek_train(shuffle = True, trainKitchen = 'p01', eventDrop = params.eventDrop, eventAugs = 
params.evAugs, numClips = params.numClips) 49 | train_dataloader = DataLoader(train_dataset, batch_size=params.batch_size, shuffle=True, num_workers=4, 50 | collate_fn=collate_fn2, drop_last = True) 51 | 52 | elif params.dataset == 'DVS': 53 | sys.path.append('snntorch/snntorch') 54 | from spikevision.spikedata.dvs_gesture import DVSGesture 55 | 56 | train_dataset = DVSGesture("/home/tr248228/RP_EvT/October/videoMae/DVS/download", train=True, dt = int(500000/params.num_steps), num_steps=params.num_steps, 57 | eventDrop = params.eventDrop, eventAugs = params.evAugs, skip_rate = params.skip_rate, final_frames=params.final_frames, 58 | randomcrop = params.randomcrop, numClips = params.numClips, train_temp_align = params.train_temp_align, rdCrop_fr = params.rdCrop_fr, 59 | changing_sr = params.changing_sr, adv_changing_dt = params.adv_changing_dt, dvs_imageSize = params.dvs_imageSize) 60 | 61 | train_dataloader = DataLoader(train_dataset, batch_size=params.batch_size, shuffle=False, num_workers=4, drop_last = True) 62 | print(f'Train dataset length: {len(train_dataset)}') 63 | logFile.write(f'Train dataset length: {len(train_dataset)}\n') 64 | 65 | if params.ECL: 66 | from vtn_ECL import VTN 67 | else: 68 | from vtn import VTN 69 | 70 | args = Namespace(cfg_file='configs/Kinetics/SLOWFAST_4x16_R50.yaml', init_method='tcp://localhost:9999', num_shards=1, opts=[], shard_id=0) 71 | 72 | if params.arch == 'r50': 73 | args.cfg_file = 'vtn/eventR50_VTN.yaml' 74 | if params.arch == 'vitb': 75 | args.cfg_file = 'vtn/eventVIT_B_VTN.yaml' 76 | 77 | cfg = load_config(args) 78 | if params.dataset == 'NEK': 79 | cfg.MODEL.NUM_CLASSES = 8 80 | if params.dataset == 'DVS': 81 | cfg.MODEL.NUM_CLASSES = 11 82 | 83 | 84 | if params.arch == 'r50': 85 | model = VTN(cfg, params.weight_rn50_ssl, params.backbone, params.pretrained).cuda() 86 | elif params.arch == 'vitb': 87 | model = VTN(cfg, '', '', True).cuda() 88 | if params.pretrainedVTN: 89 | pretrained_kvpair = torch.load('vtn/VTN_VIT_B_KINETICS.pyth')['model_state'] 90 | model_kvpair = model.state_dict() 91 | for layer_name, weights in pretrained_kvpair.items(): 92 | if 'mlp_head.4' in layer_name or 'temporal_encoder.embeddings.position_ids' in layer_name:# in layer_name or 'temporal_encoder.embeddings.position_embeddings' in layer_name: 93 | print(f'Skipping {layer_name}') 94 | logFile.write(f'Skipping {layer_name}\n') 95 | continue 96 | model_kvpair[layer_name] = weights 97 | model.load_state_dict(model_kvpair, strict=True) 98 | print('model loaded successfully') 99 | logFile.write('model loaded successfully\n') 100 | 101 | 102 | 103 | exclusion_name = [] 104 | if params.three_layer_frozen: 105 | exclusion_name = ['layer4'] 106 | elif params.two_layer_frozen: 107 | exclusion_name = ['layer3', 'layer4'] 108 | if len(exclusion_name) > 0: 109 | for name, par in model.named_parameters(): 110 | if 'backbone' in name: 111 | # still it will have M learnable params 112 | if not any([exclusion_name_el in name for exclusion_name_el in exclusion_name]): 113 | print(f'Freezing {name}') 114 | logFile.write(f'Freezing {name}') 115 | par.requires_grad = False 116 | 117 | 118 | if torch.cuda.device_count()>1: 119 | print(f'Multiple GPUS found!') 120 | logFile.write(f'Multiple GPUS found!\n') 121 | model=nn.DataParallel(model) 122 | model.cuda() 123 | 124 | else: 125 | print('Only 1 GPU is available') 126 | logFile.write('Only 1 GPU is available\n') 127 | model.cuda() 128 | 129 | if params.opt == 'adam': 130 | optimizer = optim.Adam(model.parameters(), 
lr=params.learning_rate) 131 | elif params.opt == 'sgd': 132 | optimizer = optim.SGD(model.parameters(), lr=params.learning_rate) 133 | else: 134 | exit() 135 | 136 | if params.cosinelr: 137 | cosine_lr_array = list(np.linspace(0.01,1, 5)) + [(math.cos(x) + 1)/2 for x in np.linspace(0,math.pi/0.99, params.num_epochs-5)] 138 | 139 | 140 | if (params.use_sched): 141 | lr_sched = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=params.sched_ms, gamma=params.sched_gm) 142 | 143 | 144 | 145 | #num_steps_per_update = 4 # accum gradient 146 | steps = 0 147 | if params.dataset == "NEK": 148 | class_count = [164,679,242,210,119,39,1,113] 149 | weights = 1 - (torch.tensor(class_count)/1567) 150 | weights = weights.cuda() 151 | criterion= torch.nn.CrossEntropyLoss(weight=weights.float()).cuda() 152 | elif params.dataset == "DVS": 153 | criterion= torch.nn.CrossEntropyLoss().cuda() 154 | model.train() 155 | criterion_intra = NTXentLoss(device = 'cuda', batch_size=params.num_segments, temperature=0.1, use_cosine_similarity = False) 156 | acc = 0 157 | bestacc = 0 158 | 159 | for epoch in range(params.num_epochs): 160 | losses, ce_losses, con_losses = [], [], [] 161 | intra_csl2d_logits_predictions = [] 162 | if params.cosinelr: 163 | learning_rate2 = cosine_lr_array[epoch]*params.learning_rate 164 | for param_group in optimizer.param_groups: 165 | param_group['lr']=learning_rate2 166 | print(f"Learning rate is: {param_group['lr']}") 167 | logFile.write(f"Learning rate is: {param_group['lr']}\n") 168 | for i, data in enumerate(train_dataloader, 0): 169 | if params.ECL: 170 | inputs, inputs1, labels, pathBS = data 171 | else: 172 | inputs, labels, pathBS = data 173 | if (i == 0) & (epoch == 0): 174 | print("inputs.shape", inputs.shape, flush = True) 175 | logFile.write(f"inputs.shape {inputs.shape} \n") 176 | optimizer.zero_grad() 177 | 178 | inputs = inputs.permute(0,2,1,3,4) #aug_DL output is [120, 16, 3, 112, 112]], #model expects [8, 3, 16, 112, 112] 179 | inputs = Variable(inputs.cuda()) 180 | if params.ECL: 181 | inputs1 = inputs1.permute(0,2,1,3,4) 182 | inputs1 = Variable(inputs1.cuda()) 183 | labels = torch.as_tensor(labels) 184 | labels = Variable(labels.cuda()) 185 | 186 | frameids1= torch.arange(0, inputs.shape[2],1).to(torch.int).repeat(inputs.shape[0], 1).cuda() 187 | 188 | if params.ECL: 189 | per_frame_logits, twoDrep1 = model([inputs, frameids1]) 190 | _, twoDrep2 = model([inputs1, frameids1]) 191 | else: 192 | per_frame_logits = model([inputs, frameids1]) 193 | 194 | 195 | ce_loss = criterion(per_frame_logits,labels.long()) 196 | ce_losses.append(ce_loss.cpu().detach().numpy()) 197 | 198 | if params.ECL: 199 | con_loss = 0 200 | for ii in range(0, twoDrep1.shape[0], inputs.shape[2]): 201 | temp1, temp2 = criterion_intra(twoDrep1[ii:ii+inputs.shape[2]:params.num_segments,:], twoDrep2[ii:ii+inputs.shape[2]:params.num_segments,:]) 202 | 203 | intra_csl2d_logits_predictions.extend(torch.max(temp2, axis=1).indices.cpu().numpy()) 204 | 205 | con_loss += temp1 206 | con_loss/= (twoDrep1.shape[0]/params.final_frames) 207 | con_losses.append(con_loss.cpu().detach().numpy()) 208 | loss = ce_loss * params.ECL_weight + con_loss 209 | else: 210 | loss = ce_loss 211 | losses.append(loss.cpu().detach().numpy()) 212 | loss.backward() 213 | optimizer.step() 214 | steps += 1 215 | if (steps+1) % 100 == 0: 216 | print('Epoch {} average loss: {:.4f}'.format(epoch,np.mean(losses)), flush = True) 217 | logFile.write('Epoch {} average loss: {:.4f}\n'.format(epoch,np.mean(losses))) 218 | if 
(params.use_sched): 219 | lr_sched.step() 220 | if((epoch%20==0) and (epoch > 0)): 221 | print("optimizer", optimizer) 222 | logFile.write("optimizer\n" + str(optimizer) + "\n") 223 | 224 | signal = "===============================================\n" 225 | if params.ECL: 226 | eoe = "End of epoch " + str(epoch)+ ", ECL : " + str(np.mean(con_losses)) + ", mean loss: " + str(np.mean(losses)) + "\n" 227 | else: 228 | eoe = "End of epoch " + str(epoch) + ", mean loss: " + str(np.mean(losses)) + "\n" 229 | print(signal + eoe + signal) 230 | logFile.write(signal + eoe + signal) 231 | 232 | if params.ECL: 233 | intra_csl2d_logits_predictions = np.asarray(intra_csl2d_logits_predictions) 234 | intracontrastive2d_acc = ((intra_csl2d_logits_predictions == 0).sum())/len(intra_csl2d_logits_predictions) 235 | print(f'intra-2D Contrastive Accuracy at Epoch {epoch} is {intracontrastive2d_acc*100 :0.3f}') 236 | logFile.write(f'intra-2D Contrastive Accuracy at Epoch {epoch} is {intracontrastive2d_acc*100 :0.3f}\n') 237 | logFile.flush() 238 | 239 | if(epoch%4==0) or (epoch + 10 > params.num_epochs): 240 | if (params.dataset == "NEK"): 241 | acc = validate(model, epoch, logFile, ek_test, collate_fn_test, params.testkit, isTest = True) 242 | if (epoch%20==0): 243 | for testk in list(set(["p22", "p08", "p01"]) - set([params.testkit])): 244 | validate(model,epoch, logFile, ek_test, collate_fn_test, testk, isTest = True) 245 | elif params.dataset == 'DVS': 246 | acc = validateDVS(model,epoch, logFile, DVSGesture, params.num_steps, params.final_frames, params.skip_rate, ECL = True, dvs_imageSize = params.dvs_imageSize) 247 | if acc > bestacc: 248 | bestacc = acc 249 | print("BEST!!!!") 250 | logFile.write("BEST!!!! \n") 251 | 252 | model.train() 253 | 254 | torch.save(model.state_dict(), save_model+str(epoch).zfill(6)+'.pt') 255 | now = datetime.now() 256 | d8 = now.strftime("%d%m%Y") 257 | current_time = now.strftime("%H:%M:%S") 258 | weightStatus = d8 + " | " + current_time + " saving weights to: " + save_model +str(epoch) 259 | print(weightStatus) 260 | logFile.write(weightStatus + "\n") 261 | logFile.write("---------------------file close-------------------------------\n") 262 | logFile.close() 263 | 264 | 265 | def validate(model,epoch, logFile, ek_test, collate_fn_test, testKitchen, isTest = True): 266 | 267 | if (isTest): 268 | str1 = "Validation" 269 | else: 270 | str1 = "Training" 271 | print(f"*************************{str1} accuracy at epoch {epoch}********************") 272 | model.eval() 273 | batch_size = 1 274 | test_dataset = ek_test(shuffle = False, Test = isTest, kitchen = testKitchen) 275 | test_dataloader = DataLoader(test_dataset, batch_size=batch_size, num_workers = 8, shuffle=False,collate_fn=collate_fn_test, drop_last = True) 276 | print(f'{str1} dataset length: {len(test_dataset)}') 277 | count = 0 278 | pred_vid = np.zeros((batch_size,1),dtype=(int)) 279 | 280 | for i,data in enumerate(test_dataloader, 0): 281 | clip_1,clip_2,clip_3,clip_4,clip_5,labels,pathBS = data 282 | frameids = torch.arange(0, clip_1.shape[1],1) 283 | frameids = frameids.to(torch.int).repeat(clip_1.shape[0], 1).cuda() 284 | 285 | 286 | clip_1 = clip_1.permute(0,2,1,3,4) 287 | clip_2 = clip_2.permute(0,2,1,3,4) 288 | clip_3 = clip_3.permute(0,2,1,3,4) 289 | clip_4 = clip_4.permute(0,2,1,3,4) 290 | clip_5 = clip_5.permute(0,2,1,3,4) 291 | 292 | clip_1 = Variable(clip_1.cuda()) 293 | clip_2 = Variable(clip_2.cuda()) 294 | clip_3 = Variable(clip_3.cuda()) 295 | clip_4 = Variable(clip_4.cuda()) 296 | clip_5 = 
Variable(clip_5.cuda()) 297 | 298 | labels = [(x.numpy()) for x in labels][0] 299 | 300 | pred_clip_1 = model([clip_1, frameids])[0].squeeze() 301 | pred_clip_2 = model([clip_2, frameids])[0].squeeze() 302 | pred_clip_3 = model([clip_3, frameids])[0].squeeze() 303 | pred_clip_4 = model([clip_4, frameids])[0].squeeze() 304 | pred_clip_5 = model([clip_5, frameids])[0].squeeze() 305 | 306 | sftmx = torch.nn.Softmax(dim=0) 307 | pred_clip_1 = sftmx(pred_clip_1) 308 | pred_clip_2 = sftmx(pred_clip_2) 309 | pred_clip_3 = sftmx(pred_clip_3) 310 | pred_clip_4 = sftmx(pred_clip_4) 311 | pred_clip_5 = sftmx(pred_clip_5) 312 | idxs_mean = [] 313 | for i in range(len(pred_clip_1)): 314 | idxs_mean.append(np.mean([pred_clip_1.cpu().detach().numpy()[i], pred_clip_2.cpu().detach().numpy()[i], pred_clip_3.cpu().detach().numpy()[i], pred_clip_4.cpu().detach().numpy()[i], pred_clip_5.cpu().detach().numpy()[i]])) 315 | pred_vid = idxs_mean.index(max(idxs_mean)) 316 | 317 | if(pred_vid==labels[0]): 318 | count+=1 319 | 320 | acc = count/len(test_dataset)*100 321 | print(str(testKitchen), str1, "accuracy:", acc) 322 | logFile.write(str(testKitchen) + str1 + " accuracy: " + str(acc) + "\n") 323 | print(f'*****************************************************************************') 324 | return acc 325 | 326 | def validateDVS(model,epoch, logFile, DVSGesture, num_steps, final_frames, skip_rate, ECL = False, dvs_imageSize = 128, val_cr = True, numClips = 5): 327 | print(f'*************************Test Accuracy********************') 328 | print(f'Checking Test Accuracy at epoch {epoch}') 329 | model.eval() 330 | bs = 1 331 | num_steps_test = int(np.floor(num_steps / 5 * 18)) 332 | 333 | test_set = DVSGesture("/home/tr248228/RP_EvT/October/videoMae/DVS/download", train=False, 334 | num_steps=num_steps_test, dt=int(500000/num_steps), final_frames = final_frames, 335 | skip_rate = skip_rate, numClips = numClips, isVal = True, dvs_imageSize = dvs_imageSize, val_cr = val_cr) 336 | test_dataloader = DataLoader(test_set, batch_size=bs, shuffle=True, num_workers=4, drop_last = True) 337 | 338 | count = 0 339 | for i, data in enumerate(test_dataloader, 0): 340 | clips, clip_label, pathBS = data 341 | clipPred = [] 342 | for j in range(len(clips[0])): 343 | video = clips[:,j] 344 | video = video.permute(0,2,1,3,4) 345 | input = Variable(video.cuda()) 346 | frameids = torch.arange(0, video.shape[2],1).to(torch.int).repeat(video.shape[0], 1).cuda() 347 | pred = model([input, frameids]) 348 | if ECL: 349 | pred = pred[0] 350 | pred = pred.squeeze() 351 | sftmx = torch.nn.Softmax(dim=0) 352 | pred_clip_1 = sftmx(pred) 353 | clipPred.append(pred_clip_1[None, :]) 354 | clipPred = torch.cat(clipPred, dim=0) 355 | idxs_mean = [] 356 | for k in range(len(pred_clip_1)): 357 | idxs_mean.append(torch.mean(clipPred[:,k])) 358 | pred_vid = idxs_mean.index(max(idxs_mean)) 359 | if(pred_vid==clip_label[0]): 360 | count+=1 361 | acc = count/len(test_set)*100 362 | 363 | print("test accuracy: " + str(acc) + "\n") 364 | print(f'**************************************************************') 365 | logFile.write("test accuracy: " + str(acc) + "\n") 366 | 367 | return acc 368 | 369 | 370 | if __name__ == "__main__": 371 | import argparse, importlib 372 | parser = argparse.ArgumentParser(description='Script to finetune VTN w/ or w/o ECL') 373 | 374 | parser.add_argument('-c', '--config', type=str, help='Path to the config file') 375 | 376 | args = parser.parse_args() 377 | 378 | spec = importlib.util.spec_from_file_location('params', 
args.config) 379 | params = importlib.util.module_from_spec(spec) 380 | spec.loader.exec_module(params) 381 | 382 | main(params) 383 | 384 | -------------------------------------------------------------------------------- /vtn/eventR50_VTN.yaml: -------------------------------------------------------------------------------- 1 | TRAIN: 2 | ENABLE: True 3 | DATASET: kinetics 4 | BATCH_SIZE: 16 5 | EVAL_PERIOD: 1 6 | CHECKPOINT_PERIOD: 1 7 | AUTO_RESUME: True 8 | EVAL_FULL_VIDEO: True 9 | EVAL_NUM_FRAMES: 250 10 | DATA: 11 | NUM_FRAMES: 16 12 | SAMPLING_RATE: 8 13 | TARGET_FPS: 25 14 | TRAIN_JITTER_SCALES: [256, 320] 15 | TRAIN_CROP_SIZE: 128 16 | TEST_CROP_SIZE: 128 17 | INPUT_CHANNEL_NUM: [3] ## 2 18 | SOLVER: 19 | BASE_LR: 0.001 20 | LR_POLICY: 21 | STEPS: [0, 13, 24] 22 | LRS: [1, 0.1, 0.01] 23 | MAX_EPOCH: 25 24 | MOMENTUM: 0.9 25 | OPTIMIZING_METHOD: sgd 26 | MODEL: 27 | NUM_CLASSES: 8 28 | ARCH: R50 #VIT #VIT21k 29 | MODEL_NAME: VTN 30 | LOSS_FUNC: cross_entropy 31 | DROPOUT_RATE: 0.5 32 | VTN: 33 | PRETRAINED: True 34 | MLP_DIM: 768 35 | DROP_PATH_RATE: 0.0 36 | DROP_RATE: 0.0 37 | HIDDEN_DIM: 512 #768 38 | MAX_POSITION_EMBEDDINGS: 15 #I used 9 until Jan 22nd 39 | NUM_ATTENTION_HEADS: 8 #16 #12 40 | NUM_HIDDEN_LAYERS: 3 ## 2 41 | ATTENTION_MODE: 'sliding_chunks' 42 | PAD_TOKEN_ID: -1 43 | ATTENTION_WINDOW: [18, 18, 18] ## remove one 18 44 | INTERMEDIATE_SIZE: 4096 #3072 45 | ATTENTION_PROBS_DROPOUT_PROB: 0.1 46 | HIDDEN_DROPOUT_PROB: 0.1 47 | TEST: 48 | ENABLE: True 49 | DATASET: kinetics 50 | BATCH_SIZE: 16 51 | NUM_ENSEMBLE_VIEWS: 1 52 | NUM_SPATIAL_CROPS: 1 53 | DATA_LOADER: 54 | NUM_WORKERS: 8 55 | PIN_MEMORY: True 56 | NUM_GPUS: 4 57 | NUM_SHARDS: 1 58 | RNG_SEED: 0 59 | OUTPUT_DIR: . 60 | LOG_MODEL_INFO: False -------------------------------------------------------------------------------- /vtn/eventVIT_B_VTN.yaml: -------------------------------------------------------------------------------- 1 | TRAIN: 2 | ENABLE: True 3 | DATASET: kinetics 4 | BATCH_SIZE: 16 5 | EVAL_PERIOD: 1 6 | CHECKPOINT_PERIOD: 1 7 | AUTO_RESUME: True 8 | EVAL_FULL_VIDEO: True 9 | EVAL_NUM_FRAMES: 250 10 | DATA: 11 | NUM_FRAMES: 16 12 | SAMPLING_RATE: 8 13 | TARGET_FPS: 25 14 | TRAIN_JITTER_SCALES: [256, 320] 15 | TRAIN_CROP_SIZE: 128 16 | TEST_CROP_SIZE: 128 17 | INPUT_CHANNEL_NUM: [3] 18 | SOLVER: 19 | BASE_LR: 0.001 20 | LR_POLICY: 21 | STEPS: [0, 13, 24] 22 | LRS: [1, 0.1, 0.01] 23 | MAX_EPOCH: 25 24 | MOMENTUM: 0.9 25 | OPTIMIZING_METHOD: sgd 26 | MODEL: 27 | NUM_CLASSES: 8 28 | ARCH: VIT #VIT21k 29 | MODEL_NAME: VTN 30 | LOSS_FUNC: cross_entropy 31 | DROPOUT_RATE: 0.5 32 | VTN: 33 | PRETRAINED: True 34 | MLP_DIM: 768 35 | DROP_PATH_RATE: 0.0 36 | DROP_RATE: 0.0 37 | HIDDEN_DIM: 768 38 | MAX_POSITION_EMBEDDINGS: 288 39 | NUM_ATTENTION_HEADS: 12 40 | NUM_HIDDEN_LAYERS: 3 41 | ATTENTION_MODE: 'sliding_chunks' 42 | PAD_TOKEN_ID: -1 43 | ATTENTION_WINDOW: [18, 18, 18] 44 | INTERMEDIATE_SIZE: 3072 45 | ATTENTION_PROBS_DROPOUT_PROB: 0.1 46 | HIDDEN_DROPOUT_PROB: 0.1 47 | TEST: 48 | ENABLE: True 49 | DATASET: kinetics 50 | BATCH_SIZE: 16 51 | NUM_ENSEMBLE_VIEWS: 1 52 | NUM_SPATIAL_CROPS: 1 53 | DATA_LOADER: 54 | NUM_WORKERS: 8 55 | PIN_MEMORY: True 56 | NUM_GPUS: 4 57 | NUM_SHARDS: 1 58 | RNG_SEED: 0 59 | OUTPUT_DIR: . 60 | LOG_MODEL_INFO: False -------------------------------------------------------------------------------- /vtn/parser_sf.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved. 3 | 4 | """Argument parser functions.""" 5 | 6 | import argparse 7 | import sys, os 8 | 9 | # import slowfast.utils.checkpoint as cu 10 | from slowfast_config_defaults import get_cfg 11 | 12 | 13 | def parse_args(): 14 | """ 15 | Parse the following arguments for a default parser for PySlowFast users. 16 | Args: 17 | shard_id (int): shard id for the current machine. Starts from 0 to 18 | num_shards - 1. If single machine is used, then set shard id to 0. 19 | num_shards (int): number of shards using by the job. 20 | init_method (str): initialization method to launch the job with multiple 21 | devices. Options includes TCP or shared file-system for 22 | initialization. details can be find in 23 | https://pytorch.org/docs/stable/distributed.html#tcp-initialization 24 | cfg (str): path to the config file. 25 | opts (argument): provide addtional options from the command line, it 26 | overwrites the config loaded from file. 27 | """ 28 | parser = argparse.ArgumentParser( 29 | description="Provide SlowFast video training and testing pipeline." 30 | ) 31 | parser.add_argument( 32 | "--shard_id", 33 | help="The shard id of current node, Starts from 0 to num_shards - 1", 34 | default=0, 35 | type=int, 36 | ) 37 | parser.add_argument( 38 | "--num_shards", 39 | help="Number of shards using by the job", 40 | default=1, 41 | type=int, 42 | ) 43 | parser.add_argument( 44 | "--init_method", 45 | help="Initialization method, includes TCP or shared file-system", 46 | default="tcp://localhost:9999", 47 | type=str, 48 | ) 49 | parser.add_argument( 50 | "--cfg", 51 | dest="cfg_file", 52 | help="Path to the config file", 53 | default="configs/Kinetics/SLOWFAST_4x16_R50.yaml", 54 | type=str, 55 | ) 56 | parser.add_argument( 57 | "opts", 58 | help="See slowfast/config/defaults.py for all options", 59 | default=None, 60 | nargs=argparse.REMAINDER, 61 | ) 62 | if len(sys.argv) == 1: 63 | parser.print_help() 64 | return parser.parse_args() 65 | 66 | 67 | def load_config(args): 68 | """ 69 | Given the arguemnts, load and initialize the configs. 70 | Args: 71 | args (argument): arguments includes `shard_id`, `num_shards`, 72 | `init_method`, `cfg_file`, and `opts`. 73 | """ 74 | # Setup cfg. 75 | cfg = get_cfg() 76 | # Load config from cfg. 77 | if args.cfg_file is not None: 78 | cfg.merge_from_file(args.cfg_file) 79 | # Load config from command line, overwrite config from opts. 80 | if args.opts is not None: 81 | cfg.merge_from_list(args.opts) 82 | 83 | # Inherit parameters from args. 84 | if hasattr(args, "num_shards") and hasattr(args, "shard_id"): 85 | cfg.NUM_SHARDS = args.num_shards 86 | cfg.SHARD_ID = args.shard_id 87 | if hasattr(args, "rng_seed"): 88 | cfg.RNG_SEED = args.rng_seed 89 | if hasattr(args, "output_dir"): 90 | cfg.OUTPUT_DIR = args.output_dir 91 | 92 | # Create the checkpoint dir. 93 | # cu.make_checkpoint_dir(cfg.OUTPUT_DIR) 94 | # I NEED TO MAKE THE CHECKPOINT DIR ON MY OWN 95 | if not os.path.exists(cfg.OUTPUT_DIR): 96 | os.makedirs(cfg.OUTPUT_DIR) 97 | 98 | return cfg -------------------------------------------------------------------------------- /vtn/vtn_ECL.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torchvision 5 | from timm.models.vision_transformer import vit_base_patch16_224, vit_base_patch16_224_in21k 6 | import vtn_helper 7 | from mlp import mlp 8 | 9 | class VTN(nn.Module): 10 | """ 11 | VTN model builder. It uses ViT-Base as the backbone. 
12 | Daniel Neimark, Omri Bar, Maya Zohar and Dotan Asselmann. 13 | "Video Transformer Network." 14 | https://arxiv.org/abs/2102.00719 15 | """ 16 | 17 | def __init__(self, cfg, weight, backbone, pretrained): 18 | """ 19 | The `__init__` method of any subclass should also contain these 20 | arguments. 21 | Args: 22 | cfg (CfgNode): model building configs, details are in the 23 | comments of the config file. 24 | """ 25 | super(VTN, self).__init__() 26 | self._construct_network(cfg, weight, backbone, pretrained) 27 | 28 | def _construct_network(self, cfg, weight, backbone, pretrained): 29 | """ 30 | Builds a VTN model, with a given backbone architecture. 31 | Args: 32 | cfg (CfgNode): model building configs, details are in the 33 | comments of the config file. 34 | """ 35 | #print("cfg.MODEL.ARCH", cfg.MODEL.ARCH) 36 | if cfg.MODEL.ARCH == "VIT": 37 | self.backbone = vit_base_patch16_224(pretrained=cfg.VTN.PRETRAINED, 38 | num_classes=0, 39 | drop_path_rate=cfg.VTN.DROP_PATH_RATE, 40 | drop_rate=cfg.VTN.DROP_RATE) 41 | elif cfg.MODEL.ARCH == "VIT21k": 42 | self.backbone = vit_base_patch16_224_in21k(pretrained=cfg.VTN.PRETRAINED, 43 | num_classes=0, 44 | drop_path_rate=cfg.VTN.DROP_PATH_RATE, 45 | drop_rate=cfg.VTN.DROP_RATE) 46 | elif cfg.MODEL.ARCH == 'R50': 47 | #print("cfg.VTN.PRETRAINED", cfg.VTN.PRETRAINED) 48 | print('---backbone---', backbone) 49 | if (backbone == 'r50'): 50 | self.backbone = torchvision.models.resnet50(pretrained = pretrained) 51 | if weight != '': 52 | pretrained_kvpair = torch.load(weight)['state_dict'] 53 | model_kvpair = self.backbone.state_dict() 54 | 55 | for layer_name, weights in pretrained_kvpair.items(): 56 | if layer_name[:2] == '0.': 57 | layer_name = layer_name[2:] 58 | if layer_name[:2] == '1.': 59 | # print(f'excluding {layer_name}') 60 | continue 61 | model_kvpair[layer_name] = weights 62 | self.backbone.load_state_dict(model_kvpair, strict=True) 63 | if (backbone == 'r18'): 64 | self.backbone = torchvision.models.resnet18(pretrained = pretrained) 65 | if (backbone == 'r34'): 66 | self.backbone = torchvision.models.resnet34(pretrained = pretrained) 67 | self.backbone.fc = nn.Identity() 68 | 69 | 70 | #VTN_VIT_B_KINETICS.pyth 71 | 72 | else: 73 | raise NotImplementedError(f"not supporting {cfg.MODEL.ARCH}") 74 | if cfg.MODEL.ARCH == 'VIT': 75 | embed_dim = 768 76 | else: 77 | embed_dim = 2048 78 | if (backbone == 'r18'): 79 | embed_dim = 512 80 | if (backbone == 'r34'): 81 | embed_dim = 512 82 | 83 | self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim)) 84 | 85 | self.temporal_encoder = vtn_helper.VTNLongformerModel( 86 | embed_dim=embed_dim, 87 | max_position_embeddings=cfg.VTN.MAX_POSITION_EMBEDDINGS, 88 | num_attention_heads=cfg.VTN.NUM_ATTENTION_HEADS, 89 | num_hidden_layers=cfg.VTN.NUM_HIDDEN_LAYERS, 90 | attention_mode=cfg.VTN.ATTENTION_MODE, 91 | pad_token_id=cfg.VTN.PAD_TOKEN_ID, 92 | attention_window=cfg.VTN.ATTENTION_WINDOW, 93 | intermediate_size=cfg.VTN.INTERMEDIATE_SIZE, 94 | attention_probs_dropout_prob=cfg.VTN.ATTENTION_PROBS_DROPOUT_PROB, 95 | hidden_dropout_prob=cfg.VTN.HIDDEN_DROPOUT_PROB) 96 | 97 | self.mlp_head = nn.Sequential( 98 | nn.LayerNorm(embed_dim), 99 | nn.Linear(embed_dim, cfg.VTN.MLP_DIM), 100 | nn.GELU(), 101 | nn.Dropout(cfg.MODEL.DROPOUT_RATE), 102 | nn.Linear(cfg.VTN.MLP_DIM, cfg.MODEL.NUM_CLASSES) 103 | ) 104 | 105 | self.twoDmlp = mlp(feature_size= embed_dim) 106 | self.position_id = 0 107 | 108 | 109 | def forward(self, x, bboxes=None): 110 | 111 | x, position_ids = x 112 | 113 | # spatial backbone 114 | B, 
C, F, H, W = x.shape 115 | x = x.permute(0, 2, 1, 3, 4) 116 | x = x.reshape(B * F, C, H, W) 117 | x = self.backbone(x) 118 | twoDrep = self.twoDmlp(x) 119 | 120 | # max pool over feature --> 5,5,1 --> upscale 121 | # print(f'Shape after backbone {x.shape}') 122 | x = x.reshape(B, F, -1) 123 | 124 | 125 | # temporal encoder (Longformer) 126 | B, D, E = x.shape 127 | attention_mask = torch.ones((B, D), dtype=torch.long, device=x.device) 128 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 129 | x = torch.cat((cls_tokens, x), dim=1) 130 | cls_atten = torch.ones(1).expand(B, -1).to(x.device) 131 | attention_mask = torch.cat((attention_mask, cls_atten), dim=1) 132 | attention_mask[:, 0] = 2 133 | x, attention_mask, position_ids = vtn_helper.pad_to_window_size_local( 134 | x, 135 | attention_mask, 136 | position_ids, 137 | self.temporal_encoder.config.attention_window[0], 138 | self.temporal_encoder.config.pad_token_id) 139 | token_type_ids = torch.zeros(x.size()[:-1], dtype=torch.long, device=x.device) 140 | token_type_ids[:, 0] = 1 141 | 142 | # position_ids 143 | position_ids = position_ids.long() 144 | mask = attention_mask.ne(0).int() 145 | max_position_embeddings = self.temporal_encoder.config.max_position_embeddings 146 | position_ids = position_ids % (max_position_embeddings - 2) 147 | position_ids[:, 0] = max_position_embeddings - 2 148 | # print("position_ids") 149 | # print(position_ids.shape) 150 | # print("mask") 151 | # print(mask.shape) 152 | position_ids[mask == 0] = max_position_embeddings - 1 153 | 154 | x = self.temporal_encoder(input_ids=None, 155 | attention_mask=attention_mask, 156 | token_type_ids=token_type_ids, 157 | position_ids=position_ids, 158 | inputs_embeds=x, 159 | output_attentions=None, 160 | output_hidden_states=None, 161 | return_dict=None) 162 | # MLP head 163 | x = x["last_hidden_state"] 164 | x = self.mlp_head(x[:, 0]) 165 | return x, twoDrep 166 | 167 | if __name__ == '__main__': 168 | import numpy as np 169 | from torchsummary import summary 170 | 171 | bs = 2 172 | num_frames = 8 173 | pos_array = torch.from_numpy(np.asarray(list(range(num_frames)))).unsqueeze(0).repeat(bs,1).cuda() 174 | #print(pos_array.shape) 175 | rand_input = [ torch.randn(bs,3,num_frames,224,224).cuda(), pos_array] 176 | # print(rand_input.shape) 177 | 178 | 179 | import argparse 180 | import sys, os 181 | from slowfast_config_defaults import get_cfg 182 | from parser_sf import parse_args, load_config 183 | 184 | args = parse_args() 185 | args.cfg_file = 'VIT_B_VTN.yaml' 186 | # args.cfg_file = 'eventR50_VTN.yaml' 187 | # 188 | cfg = load_config(args) 189 | # print(cfg) # config read operation seems working for now, not sure what to ignore in the read config 190 | 191 | 192 | # print(vtn_model) 193 | 194 | print("main vtn.py calls to print model") 195 | 196 | 197 | 198 | # vtn_model.load_state_dict(pretrained, strict=True) 199 | # exit() 200 | # for layer_name, weights in pretrained_kvpair.items(): 201 | # if 'temporal_encoder.embeddings' in layer_name: 202 | # print(layer_name) 203 | # if 'temporal_encoder.embeddings.position_ids' in layer_name: 204 | # print(weights) 205 | 206 | 207 | 208 | # for layer_name, weights in model_kvpair.items(): 209 | # if 'temporal_encoder.embeddings' in layer_name: 210 | # print(layer_name) 211 | 212 | # exit() 213 | vtn_model = VTN(cfg, '', '', True).cuda() 214 | pretrained_kvpair = torch.load('VTN_VIT_B_KINETICS.pyth')['model_state'] 215 | model_kvpair = vtn_model.state_dict() 216 | for layer_name, 
weights in pretrained_kvpair.items(): 217 | # layer_name.replace('position_id','position_embeddings') 218 | 219 | if 'mlp_head.4' in layer_name or 'temporal_encoder.embeddings.position_ids' in layer_name or 'temporal_encoder.embeddings.position_embeddings' in layer_name: 220 | print(f'Skipping {layer_name}') 221 | continue 222 | model_kvpair[layer_name] = weights 223 | vtn_model.load_state_dict(model_kvpair, strict=True) 224 | print('model loaded successfully') 225 | 226 | 227 | # model, [(1, 300), (1, 300)], dtypes=[torch.float, torch.long] 228 | 229 | # print(vtn_model) 230 | 231 | # for name, param in vtn_model.named_parameters(): 232 | # if 'backbone' in name: 233 | # param.requires_grad = False 234 | 235 | # for name, param in vtn_model.named_parameters(): 236 | # print(name, param.requires_grad) 237 | 238 | 239 | # for m in list(model.parameters())[:-2]: 240 | # m.requires_grad = False 241 | 242 | 243 | output = vtn_model(rand_input) 244 | print(output[0].shape, output[1].shape) 245 | 246 | 247 | #print(torch.cuda.memory_allocated()/1e9) #3.433 248 | # summary(vtn_model, [(3,16,224,224), (1, 16)], dtypes=[torch.float, torch.float]) 249 | # summary(vtn_model, rand_input) 250 | #print(output.shape) 251 | 252 | 253 | 254 | -------------------------------------------------------------------------------- /vtn/vtn_helper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import LongformerModel, LongformerConfig 3 | import torch.nn.functional as F 4 | 5 | 6 | class VTNLongformerModel(LongformerModel): 7 | 8 | def __init__(self, 9 | embed_dim=768, 10 | max_position_embeddings=2 * 60 * 60, 11 | num_attention_heads=12, 12 | num_hidden_layers=3, 13 | attention_mode='sliding_chunks', 14 | pad_token_id=-1, 15 | attention_window=None, 16 | intermediate_size=3072, 17 | attention_probs_dropout_prob=0.1, 18 | hidden_dropout_prob=0.1): 19 | 20 | self.config = LongformerConfig() 21 | self.config.attention_mode = attention_mode 22 | self.config.intermediate_size = intermediate_size 23 | self.config.attention_probs_dropout_prob = attention_probs_dropout_prob 24 | self.config.hidden_dropout_prob = hidden_dropout_prob 25 | self.config.attention_dilation = [1, ] * num_hidden_layers 26 | self.config.attention_window = [256, ] * num_hidden_layers if attention_window is None else attention_window 27 | self.config.num_hidden_layers = num_hidden_layers 28 | self.config.num_attention_heads = num_attention_heads 29 | self.config.pad_token_id = pad_token_id 30 | self.config.max_position_embeddings = max_position_embeddings 31 | self.config.hidden_size = embed_dim 32 | super(VTNLongformerModel, self).__init__(self.config, add_pooling_layer=False) 33 | self.embeddings.word_embeddings = None # to avoid distributed error of unused parameters 34 | 35 | 36 | def pad_to_window_size_local(input_ids: torch.Tensor, attention_mask: torch.Tensor, position_ids: torch.Tensor, 37 | one_sided_window_size: int, pad_token_id: int): 38 | '''A helper function to pad tokens and mask to work with the sliding_chunks implementation of Longformer self-attention. 
39 | Based on _pad_to_window_size from https://github.com/huggingface/transformers: 40 | https://github.com/huggingface/transformers/blob/71bdc076dd4ba2f3264283d4bc8617755206dccd/src/transformers/models/longformer/modeling_longformer.py#L1516 41 | Input: 42 | input_ids = torch.Tensor(bsz x seqlen x embed_dim): sequence of token embeddings (frame features plus the CLS token) 43 | attention_mask = torch.Tensor(bsz x seqlen): attention mask 44 | one_sided_window_size = int: window size on one side of each token 45 | pad_token_id = int: tokenizer.pad_token_id 46 | Returns 47 | (input_ids, attention_mask, position_ids) padded to length divisible by 2 * one_sided_window_size 48 | ''' 49 | w = 2 * one_sided_window_size 50 | seqlen = input_ids.size(1) 51 | padding_len = (w - seqlen % w) % w 52 | input_ids = F.pad(input_ids.permute(0, 2, 1), (0, padding_len), value=pad_token_id).permute(0, 2, 1) 53 | attention_mask = F.pad(attention_mask, (0, padding_len), value=False) # no attention on the padding tokens 54 | position_ids = F.pad(position_ids, (1, padding_len), value=False) # left pad of 1 reserves a position slot for the CLS token; right pad matches the padded sequence length 55 | return input_ids, attention_mask, position_ids --------------------------------------------------------------------------------
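
Editorial note (not a file from the repository): the sketch below illustrates the padding arithmetic that pad_to_window_size_local applies before the Longformer temporal encoder runs. The shapes are assumptions chosen to match the YAML configs above (DATA.NUM_FRAMES: 16, VTN.ATTENTION_WINDOW: [18, 18, 18], VTN.PAD_TOKEN_ID: -1) and the 2048-dimensional ResNet-50 features used in vtn_ECL.py; with 16 frame tokens plus one CLS token, the 17-token sequence is right-padded to 36 tokens, a multiple of 2 * 18.

# Standalone sketch of the helper's logic under the assumptions named above.
import torch
import torch.nn.functional as F

bsz, num_frames, embed_dim = 2, 16, 2048
seqlen = num_frames + 1                        # frame tokens + CLS token
one_sided_window_size = 18                     # ATTENTION_WINDOW[0] in the YAML configs
pad_token_id = -1                              # VTN.PAD_TOKEN_ID in the YAML configs

x = torch.randn(bsz, seqlen, embed_dim)                      # embeddings entering the temporal encoder
attention_mask = torch.ones(bsz, seqlen, dtype=torch.long)   # attend to every real token
position_ids = torch.arange(num_frames).repeat(bsz, 1)       # one position id per frame token

w = 2 * one_sided_window_size                  # 36
padding_len = (w - seqlen % w) % w             # (36 - 17 % 36) % 36 = 19
x = F.pad(x.permute(0, 2, 1), (0, padding_len), value=pad_token_id).permute(0, 2, 1)
attention_mask = F.pad(attention_mask, (0, padding_len), value=0)  # padded tokens get no attention
position_ids = F.pad(position_ids, (1, padding_len), value=0)      # left slot for CLS, filled later in VTN.forward

print(x.shape, attention_mask.shape, position_ids.shape)
# torch.Size([2, 36, 2048]) torch.Size([2, 36]) torch.Size([2, 36])

VTN.forward then overwrites position_ids[:, 0] for the CLS token and maps the padded positions to the last position-embedding index before calling the temporal encoder, so the sketch stops where that model-specific bookkeeping begins.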